diff --git a/synapse/__init__.py b/synapse/__init__.py
index 6b0a766391..5c0f2f83aa 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,4 +17,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.22.0-rc1"
+__version__ = "0.33.0"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index f8266d1c81..073229b4c4 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -15,15 +15,19 @@
import logging
+from six import itervalues
+
import pymacaroons
+from netaddr import IPAddress
+
from twisted.internet import defer
import synapse.types
from synapse import event_auth
-from synapse.api.constants import EventTypes, Membership, JoinRules
+from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.errors import AuthError, Codes
from synapse.types import UserID
-from synapse.util.caches import register_cache, CACHE_SIZE_FACTOR
+from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import Measure
@@ -57,16 +61,17 @@ class Auth(object):
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
- register_cache("token_cache", self.token_cache)
+ register_cache("cache", "token_cache", self.token_cache)
@defer.inlineCallbacks
def check_from_context(self, event, context, do_sig_check=True):
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.compute_auth_events(
- event, context.prev_state_ids, for_verification=True,
+ event, prev_state_ids, for_verification=True,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
- (e.type, e.state_key): e for e in auth_events.values()
+ (e.type, e.state_key): e for e in itervalues(auth_events)
}
self.check(event, auth_events=auth_events, do_sig_check=do_sig_check)
@@ -189,7 +194,7 @@ class Auth(object):
synapse.types.create_requester(user_id, app_service=app_service)
)
- access_token = get_access_token_from_request(
+ access_token = self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
@@ -204,12 +209,12 @@ class Auth(object):
ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders(
- "User-Agent",
- default=[""]
+ b"User-Agent",
+ default=[b""]
)[0]
if user and access_token and ip_addr:
self.store.insert_client_ip(
- user=user,
+ user_id=user.to_string(),
access_token=access_token,
ip=ip_addr,
user_agent=user_agent,
@@ -235,13 +240,18 @@ class Auth(object):
@defer.inlineCallbacks
def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token(
- get_access_token_from_request(
+ self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
)
if app_service is None:
defer.returnValue((None, None))
+ if app_service.ip_range_whitelist:
+ ip_address = IPAddress(self.hs.get_ip_from_request(request))
+ if ip_address not in app_service.ip_range_whitelist:
+ defer.returnValue((None, None))
+
if "user_id" not in request.args:
defer.returnValue((app_service.sender, app_service))
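The whitelist test leans on netaddr containment semantics: an IPAddress can be checked against an IPSet of CIDR ranges. A minimal sketch with hypothetical ranges (synapse presumably parses ip_range_whitelist from the appservice registration into such a set):

    from netaddr import IPAddress, IPSet

    whitelist = IPSet(["10.0.0.0/8", "192.168.1.0/24"])  # hypothetical ranges
    assert IPAddress("10.1.2.3") in whitelist
    assert IPAddress("8.8.8.8") not in whitelist  # such a request is rejected
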
@@ -270,7 +280,11 @@ class Auth(object):
rights (str): The operation being performed; the access token must
allow this.
Returns:
- dict : dict that includes the user and the ID of their access token.
+ Deferred[dict]: dict that includes:
+ `user` (UserID)
+ `is_guest` (bool)
+ `token_id` (int|None): access token id. May be None if guest
+ `device_id` (str|None): device corresponding to access token
Raises:
AuthError if no user by that token exists or the token is invalid.
"""
@@ -482,7 +496,7 @@ class Auth(object):
def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token)
if not ret:
- logger.warn("Unrecognised access token - not in store: %s" % (token,))
+ logger.warn("Unrecognised access token - not in store.")
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN
@@ -500,12 +514,12 @@ class Auth(object):
def get_appservice_by_req(self, request):
try:
- token = get_access_token_from_request(
+ token = self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
service = self.store.get_app_service_by_token(token)
if not service:
- logger.warn("Unrecognised appservice access token: %s" % (token,))
+ logger.warn("Unrecognised appservice access token.")
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Unrecognised access token.",
@@ -519,11 +533,20 @@ class Auth(object):
)
def is_server_admin(self, user):
+ """ Check if the given user is a local server admin.
+
+ Args:
+ user (str): mxid of user to check
+
+ Returns:
+ bool: True if the user is an admin
+ """
return self.store.is_server_admin(user)
@defer.inlineCallbacks
def add_auth_events(self, builder, context):
- auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids)
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+ auth_ids = yield self.compute_auth_events(builder, prev_state_ids)
auth_events_entries = yield self.store.add_event_hashes(
auth_ids
@@ -641,7 +664,7 @@ class Auth(object):
auth_events[(EventTypes.PowerLevels, "")] = power_level_event
send_level = event_auth.get_send_level(
- EventTypes.Aliases, "", auth_events
+ EventTypes.Aliases, "", power_level_event,
)
user_level = event_auth.get_user_power_level(user_id, auth_events)
@@ -652,67 +675,101 @@ class Auth(object):
" edit its room list entry"
)
+ @staticmethod
+ def has_access_token(request):
+ """Checks if the request has an access_token.
-def has_access_token(request):
- """Checks if the request has an access_token.
+ Returns:
+ bool: False if no access_token was given, True otherwise.
+ """
+        query_params = request.args.get(b"access_token")
+ auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+ return bool(query_params) or bool(auth_headers)
- Returns:
- bool: False if no access_token was given, True otherwise.
- """
- query_params = request.args.get("access_token")
- auth_headers = request.requestHeaders.getRawHeaders("Authorization")
- return bool(query_params) or bool(auth_headers)
-
-
-def get_access_token_from_request(request, token_not_found_http_status=401):
- """Extracts the access_token from the request.
-
- Args:
- request: The http request.
- token_not_found_http_status(int): The HTTP status code to set in the
- AuthError if the token isn't found. This is used in some of the
- legacy APIs to change the status code to 403 from the default of
- 401 since some of the old clients depended on auth errors returning
- 403.
- Returns:
- str: The access_token
- Raises:
- AuthError: If there isn't an access_token in the request.
- """
+ @staticmethod
+ def get_access_token_from_request(request, token_not_found_http_status=401):
+ """Extracts the access_token from the request.
- auth_headers = request.requestHeaders.getRawHeaders("Authorization")
- query_params = request.args.get("access_token")
- if auth_headers:
- # Try the get the access_token from a "Authorization: Bearer"
- # header
- if query_params is not None:
- raise AuthError(
- token_not_found_http_status,
- "Mixing Authorization headers and access_token query parameters.",
- errcode=Codes.MISSING_TOKEN,
- )
- if len(auth_headers) > 1:
- raise AuthError(
- token_not_found_http_status,
- "Too many Authorization headers.",
- errcode=Codes.MISSING_TOKEN,
- )
- parts = auth_headers[0].split(" ")
- if parts[0] == "Bearer" and len(parts) == 2:
- return parts[1]
+ Args:
+ request: The http request.
+ token_not_found_http_status(int): The HTTP status code to set in the
+ AuthError if the token isn't found. This is used in some of the
+ legacy APIs to change the status code to 403 from the default of
+ 401 since some of the old clients depended on auth errors returning
+ 403.
+ Returns:
+ str: The access_token
+ Raises:
+ AuthError: If there isn't an access_token in the request.
+ """
+
+ auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+ query_params = request.args.get(b"access_token")
+ if auth_headers:
+            # Try to get the access_token from an "Authorization: Bearer"
+ # header
+ if query_params is not None:
+ raise AuthError(
+ token_not_found_http_status,
+ "Mixing Authorization headers and access_token query parameters.",
+ errcode=Codes.MISSING_TOKEN,
+ )
+ if len(auth_headers) > 1:
+ raise AuthError(
+ token_not_found_http_status,
+ "Too many Authorization headers.",
+ errcode=Codes.MISSING_TOKEN,
+ )
+ parts = auth_headers[0].split(" ")
+ if parts[0] == "Bearer" and len(parts) == 2:
+ return parts[1]
+ else:
+ raise AuthError(
+ token_not_found_http_status,
+ "Invalid Authorization header.",
+ errcode=Codes.MISSING_TOKEN,
+ )
else:
- raise AuthError(
- token_not_found_http_status,
- "Invalid Authorization header.",
- errcode=Codes.MISSING_TOKEN,
+ # Try to get the access_token from the query params.
+ if not query_params:
+ raise AuthError(
+ token_not_found_http_status,
+ "Missing access token.",
+ errcode=Codes.MISSING_TOKEN
+ )
+
+ return query_params[0]
+
+ @defer.inlineCallbacks
+ def check_in_room_or_world_readable(self, room_id, user_id):
+ """Checks that the user is or was in the room or the room is world
+        readable. If neither is the case, an exception is raised.
+
+ Returns:
+ Deferred[tuple[str, str|None]]: Resolves to the current membership of
+ the user in the room and the membership event ID of the user. If
+ the user is not in the room and never has been, then
+ `(Membership.JOIN, None)` is returned.
+ """
+
+ try:
+ # check_user_was_in_room will return the most recent membership
+ # event for the user if:
+ # * The user is a non-guest user, and was ever in the room
+ # * The user is a guest user, and has joined the room
+ # else it will throw.
+ member_event = yield self.check_user_was_in_room(room_id, user_id)
+ defer.returnValue((member_event.membership, member_event.event_id))
+ except AuthError:
+ visibility = yield self.state.get_current_state(
+ room_id, EventTypes.RoomHistoryVisibility, ""
)
- else:
- # Try to get the access_token from the query params.
- if not query_params:
+ if (
+ visibility and
+ visibility.content["history_visibility"] == "world_readable"
+ ):
+ defer.returnValue((Membership.JOIN, None))
+ return
raise AuthError(
- token_not_found_http_status,
- "Missing access token.",
- errcode=Codes.MISSING_TOKEN
+ 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
)
-
- return query_params[0]
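Both helpers move from module level onto Auth as staticmethods, which is why the call sites above switch to self.get_access_token_from_request(...). A minimal sketch of the new calling convention (request stands in for any Twisted request object):

    from synapse.api.auth import Auth

    def example_call_site(request):
        # The old module-level functions are gone; reach the helpers via the
        # class (or an Auth instance, as the servlet code above does).
        if Auth.has_access_token(request):
            return Auth.get_access_token_from_request(request)
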
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 489efb7f86..4df930c8d1 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -16,6 +16,9 @@
"""Contains constants from the specification."""
+# the "depth" field on events is limited to 2**63 - 1
+MAX_DEPTH = 2**63 - 1
+
class Membership(object):
@@ -73,6 +76,8 @@ class EventTypes(object):
Topic = "m.room.topic"
Name = "m.room.name"
+ ServerACL = "m.room.server_acl"
+
class RejectedReason(object):
AUTH_ERROR = "auth_error"
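On MAX_DEPTH above: 2**63 - 1 is the top of the signed 64-bit range (presumably so that depth fits a signed BIGINT column); a quick sanity check:

    import struct

    MAX_DEPTH = 2**63 - 1
    # 0x7FFFFFFFFFFFFFFF: the largest value a signed 64-bit integer holds.
    assert MAX_DEPTH == struct.unpack(">q", b"\x7f" + b"\xff" * 7)[0]
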
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index d0dfa959dc..6074df292f 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -15,9 +15,13 @@
"""Contains exceptions and error codes."""
-import json
import logging
+from six import iteritems
+from six.moves import http_client
+
+from canonicaljson import json
+
logger = logging.getLogger(__name__)
@@ -46,8 +50,11 @@ class Codes(object):
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
THREEPID_IN_USE = "M_THREEPID_IN_USE"
THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
+ THREEPID_DENIED = "M_THREEPID_DENIED"
INVALID_USERNAME = "M_INVALID_USERNAME"
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
+ CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN"
+ CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM"
class CodeMessageException(RuntimeError):
@@ -135,11 +142,79 @@ class SynapseError(CodeMessageException):
return res
+class ConsentNotGivenError(SynapseError):
+ """The error returned to the client when the user has not consented to the
+ privacy policy.
+ """
+ def __init__(self, msg, consent_uri):
+ """Constructs a ConsentNotGivenError
+
+ Args:
+ msg (str): The human-readable error message
+            consent_uri (str): The URL where the user can give their consent
+ """
+ super(ConsentNotGivenError, self).__init__(
+ code=http_client.FORBIDDEN,
+ msg=msg,
+ errcode=Codes.CONSENT_NOT_GIVEN
+ )
+ self._consent_uri = consent_uri
+
+ def error_dict(self):
+ return cs_error(
+ self.msg,
+ self.errcode,
+ consent_uri=self._consent_uri
+ )
+
+
class RegistrationError(SynapseError):
"""An error raised when a registration event fails."""
pass
+class FederationDeniedError(SynapseError):
+ """An error raised when the server tries to federate with a server which
+ is not on its federation whitelist.
+
+ Attributes:
+ destination (str): The destination which has been denied
+ """
+
+ def __init__(self, destination):
+ """Raised by federation client or server to indicate that we are
+ are deliberately not attempting to contact a given server because it is
+ not on our federation whitelist.
+
+ Args:
+ destination (str): the domain in question
+ """
+
+ self.destination = destination
+
+ super(FederationDeniedError, self).__init__(
+ code=403,
+ msg="Federation denied with %s." % (self.destination,),
+ errcode=Codes.FORBIDDEN,
+ )
+
+
+class InteractiveAuthIncompleteError(Exception):
+ """An error raised when UI auth is not yet complete
+
+ (This indicates we should return a 401 with 'result' as the body)
+
+ Attributes:
+ result (dict): the server response to the request, which should be
+ passed back to the client
+ """
+ def __init__(self, result):
+ super(InteractiveAuthIncompleteError, self).__init__(
+ "Interactive auth not yet complete",
+ )
+ self.result = result
+
+
class UnrecognizedRequestError(SynapseError):
"""An error indicating we don't understand the request you're trying to make"""
def __init__(self, *args, **kwargs):
@@ -247,13 +322,13 @@ def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
Args:
msg (str): The error message.
- code (int): The error code.
+ code (str): The error code.
kwargs : Additional keys to add to the response.
Returns:
A dict representing the error response JSON.
"""
err = {"error": msg, "errcode": code}
- for key, value in kwargs.iteritems():
+ for key, value in iteritems(kwargs):
err[key] = value
return err
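Combined with the error_dict override above, cs_error yields the JSON body a client receives when consent is outstanding; a sketch with a hypothetical URI:

    body = cs_error(
        "You must accept the privacy policy",
        Codes.CONSENT_NOT_GIVEN,
        consent_uri="https://hs.example.com/_matrix/consent?u=%40alice%3Aexample.com&h=abc123",
    )
    # => {"error": "You must accept the privacy policy",
    #     "errcode": "M_CONSENT_NOT_GIVEN",
    #     "consent_uri": "https://hs.example.com/_matrix/consent?..."}
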
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 83206348e5..25346baa87 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -12,15 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.api.errors import SynapseError
-from synapse.storage.presence import UserPresenceState
-from synapse.types import UserID, RoomID
-from twisted.internet import defer
-
-import ujson as json
import jsonschema
+from canonicaljson import json
from jsonschema import FormatChecker
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.storage.presence import UserPresenceState
+from synapse.types import RoomID, UserID
+
FILTER_SCHEMA = {
"additionalProperties": False,
"type": "object",
@@ -411,7 +412,7 @@ class Filter(object):
return room_ids
def filter(self, events):
- return filter(self.check, events)
+ return list(filter(self.check, events))
def limit(self):
return self.filter_json.get("limit", 10)
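The list() wrapper is a Python 3 compatibility fix: there, the filter builtin returns a one-shot iterator, so any caller that iterates the result twice (or takes its length) would misbehave. Illustration:

    events = ["a", "b", "c"]
    lazy = filter(lambda e: e != "b", events)  # an iterator on Python 3
    assert list(lazy) == ["a", "c"]
    assert list(lazy) == []  # exhausted; a second pass silently sees nothing
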
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index 91a33a3402..71347912f1 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +15,12 @@
# limitations under the License.
"""Contains the URL paths to prefix various aspects of the server with. """
+import hmac
+from hashlib import sha256
+
+from six.moves.urllib.parse import urlencode
+
+from synapse.config import ConfigError
CLIENT_PREFIX = "/_matrix/client/api/v1"
CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha"
@@ -25,3 +32,46 @@ SERVER_KEY_PREFIX = "/_matrix/key/v1"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
MEDIA_PREFIX = "/_matrix/media/r0"
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
+
+
+class ConsentURIBuilder(object):
+ def __init__(self, hs_config):
+ """
+ Args:
+ hs_config (synapse.config.homeserver.HomeServerConfig):
+ """
+ if hs_config.form_secret is None:
+ raise ConfigError(
+ "form_secret not set in config",
+ )
+ if hs_config.public_baseurl is None:
+ raise ConfigError(
+ "public_baseurl not set in config",
+ )
+
+ self._hmac_secret = hs_config.form_secret.encode("utf-8")
+ self._public_baseurl = hs_config.public_baseurl
+
+ def build_user_consent_uri(self, user_id):
+ """Build a URI which we can give to the user to do their privacy
+ policy consent
+
+ Args:
+ user_id (str): mxid or username of user
+
+        Returns:
+            str: the URI where the user can give their consent
+ """
+ mac = hmac.new(
+ key=self._hmac_secret,
+ msg=user_id,
+ digestmod=sha256,
+ ).hexdigest()
+ consent_uri = "%s_matrix/consent?%s" % (
+ self._public_baseurl,
+ urlencode({
+ "u": user_id,
+ "h": mac
+ }),
+ )
+ return consent_uri
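The h parameter is what lets the consent resource check that u really was issued by this server: it recomputes the same SHA-256 HMAC with the shared form_secret and compares in constant time. A sketch of the verifying side, with illustrative names (this is not the actual consent resource code):

    import hmac
    from hashlib import sha256

    def is_valid_consent_mac(form_secret, user_id, presented_mac):
        expected = hmac.new(
            key=form_secret.encode("utf-8"),
            msg=user_id.encode("utf-8"),  # the builder passes the bare str on py2
            digestmod=sha256,
        ).hexdigest()
        # compare_digest avoids a timing side-channel on the comparison.
        return hmac.compare_digest(expected, presented_mac)
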
diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py
index 9c2b627590..3b6b9368b8 100644
--- a/synapse/app/__init__.py
+++ b/synapse/app/__init__.py
@@ -14,9 +14,11 @@
# limitations under the License.
import sys
+
+from synapse import python_dependencies # noqa: E402
+
sys.dont_write_bytecode = True
-from synapse import python_dependencies # noqa: E402
try:
python_dependencies.check_requirements()
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
new file mode 100644
index 0000000000..391bd14c5c
--- /dev/null
+++ b/synapse/app/_base.py
@@ -0,0 +1,194 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import logging
+import sys
+
+from daemonize import Daemonize
+
+from twisted.internet import error, reactor
+
+from synapse.util import PreserveLoggingContext
+from synapse.util.rlimit import change_resource_limit
+
+try:
+ import affinity
+except Exception:
+ affinity = None
+
+
+logger = logging.getLogger(__name__)
+
+
+def start_worker_reactor(appname, config):
+ """ Run the reactor in the main process
+
+ Daemonizes if necessary, and then configures some resources, before starting
+ the reactor. Pulls configuration from the 'worker' settings in 'config'.
+
+ Args:
+ appname (str): application name which will be sent to syslog
+ config (synapse.config.Config): config object
+ """
+
+ logger = logging.getLogger(config.worker_app)
+
+ start_reactor(
+ appname,
+ config.soft_file_limit,
+ config.gc_thresholds,
+ config.worker_pid_file,
+ config.worker_daemonize,
+ config.worker_cpu_affinity,
+ logger,
+ )
+
+
+def start_reactor(
+ appname,
+ soft_file_limit,
+ gc_thresholds,
+ pid_file,
+ daemonize,
+ cpu_affinity,
+ logger,
+):
+ """ Run the reactor in the main process
+
+ Daemonizes if necessary, and then configures some resources, before starting
+ the reactor
+
+ Args:
+ appname (str): application name which will be sent to syslog
+ soft_file_limit (int):
+ gc_thresholds:
+ pid_file (str): name of pid file to write to if daemonize is True
+ daemonize (bool): true to run the reactor in a background process
+ cpu_affinity (int|None): cpu affinity mask
+ logger (logging.Logger): logger instance to pass to Daemonize
+ """
+
+ def run():
+ # make sure that we run the reactor with the sentinel log context,
+ # otherwise other PreserveLoggingContext instances will get confused
+ # and complain when they see the logcontext arbitrarily swapping
+ # between the sentinel and `run` logcontexts.
+ with PreserveLoggingContext():
+ logger.info("Running")
+ if cpu_affinity is not None:
+ if not affinity:
+ quit_with_error(
+ "Missing package 'affinity' required for cpu_affinity\n"
+ "option\n\n"
+ "Install by running:\n\n"
+ " pip install affinity\n\n"
+ )
+ logger.info("Setting CPU affinity to %s" % cpu_affinity)
+ affinity.set_process_affinity_mask(0, cpu_affinity)
+ change_resource_limit(soft_file_limit)
+ if gc_thresholds:
+ gc.set_threshold(*gc_thresholds)
+ reactor.run()
+
+ if daemonize:
+ daemon = Daemonize(
+ app=appname,
+ pid=pid_file,
+ action=run,
+ auto_close_fds=False,
+ verbose=True,
+ logger=logger,
+ )
+ daemon.start()
+ else:
+ run()
+
+
+def quit_with_error(error_string):
+ message_lines = error_string.split("\n")
+ line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2
+ sys.stderr.write("*" * line_length + '\n')
+ for line in message_lines:
+ sys.stderr.write(" %s\n" % (line.rstrip(),))
+ sys.stderr.write("*" * line_length + '\n')
+ sys.exit(1)
+
+
+def listen_metrics(bind_addresses, port):
+ """
+ Start Prometheus metrics server.
+ """
+ from synapse.metrics import RegistryProxy
+ from prometheus_client import start_http_server
+
+ for host in bind_addresses:
+ reactor.callInThread(start_http_server, int(port),
+ addr=host, registry=RegistryProxy)
+ logger.info("Metrics now reporting on %s:%d", host, port)
+
+
+def listen_tcp(bind_addresses, port, factory, backlog=50):
+ """
+ Create a TCP socket for a port and several addresses
+ """
+ for address in bind_addresses:
+ try:
+ reactor.listenTCP(
+ port,
+ factory,
+ backlog,
+ address
+ )
+ except error.CannotListenError as e:
+ check_bind_error(e, address, bind_addresses)
+
+
+def listen_ssl(bind_addresses, port, factory, context_factory, backlog=50):
+ """
+ Create an SSL socket for a port and several addresses
+ """
+ for address in bind_addresses:
+ try:
+ reactor.listenSSL(
+ port,
+ factory,
+ context_factory,
+ backlog,
+ address
+ )
+ except error.CannotListenError as e:
+ check_bind_error(e, address, bind_addresses)
+
+
+def check_bind_error(e, address, bind_addresses):
+ """
+    This method checks whether an exception that occurred while binding on
+    0.0.0.0 can be safely ignored. If :: is present in the bind addresses, only
+    a warning is logged; otherwise the exception is re-raised.
+
+ Binding on both 0.0.0.0 and :: causes an exception on Linux and macOS
+ because :: binds on both IPv4 and IPv6 (as per RFC 3493).
+ When binding on 0.0.0.0 after :: this can safely be ignored.
+
+ Args:
+ e (Exception): Exception that was caught.
+ address (str): Address on which binding was attempted.
+ bind_addresses (list): Addresses on which the service listens.
+ """
+ if address == '0.0.0.0' and '::' in bind_addresses:
+ logger.warn('Failed to listen on 0.0.0.0, continuing because listening on [::]')
+ else:
+ raise e
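check_bind_error covers the common dual-stack case: on Linux and macOS a socket bound to :: also accepts IPv4 (RFC 3493), so a second bind of 0.0.0.0 on the same port raises CannotListenError. A sketch of the configuration listen_tcp now tolerates (the bare Factory is a placeholder for a real SynapseSite):

    from twisted.internet.protocol import Factory

    # "::" binds first and claims IPv4 too; the follow-up "0.0.0.0" bind fails,
    # and check_bind_error downgrades that expected failure to a warning.
    listen_tcp(["::", "0.0.0.0"], 8008, Factory())
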
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index 9a476efa63..9a37384fb7 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -13,38 +13,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import sys
-import synapse
+from twisted.internet import defer, reactor
+from twisted.web.resource import NoResource
-from synapse.server import HomeServer
+import synapse
+from synapse import events
+from synapse.app import _base
from synapse.config._base import ConfigError
-from synapse.config.logger import setup_logging
from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.metrics import RegistryProxy
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
-from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
-from synapse import events
-
-from twisted.internet import reactor
-from twisted.web.resource import Resource
-
-from daemonize import Daemonize
-
-import sys
-import logging
-import gc
-
logger = logging.getLogger("synapse.app.appservice")
@@ -56,19 +51,6 @@ class AppserviceSlaveStore(
class AppserviceServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
@@ -82,21 +64,21 @@ class AppserviceServer(HomeServer):
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(self)
-
- root_resource = create_resource_tree(resources, Resource())
-
- for address in bind_addresses:
- reactor.listenTCP(
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- ),
- interface=address
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
+
+ root_resource = create_resource_tree(resources, NoResource())
+
+ _base.listen_tcp(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
+ self.version_string,
)
+ )
logger.info("Synapse appservice now listening on port %d", port)
@@ -105,18 +87,22 @@ class AppserviceServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
- bind_addresses = listener["bind_addresses"]
-
- for address in bind_addresses:
- reactor.listenTCP(
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- ),
- interface=address
+ _base.listen_tcp(
+ listener["bind_addresses"],
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
)
+ )
+ elif listener["type"] == "metrics":
+ if not self.get_config().enable_metrics:
+ logger.warn(("Metrics listener configured, but "
+ "enable_metrics is not True!"))
+ else:
+ _base.listen_metrics(listener["bind_addresses"],
+ listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -136,9 +122,14 @@ class ASReplicationHandler(ReplicationClientHandler):
if stream_name == "events":
max_stream_id = self.store.get_room_max_stream_ordering()
- preserve_fn(
- self.appservice_handler.notify_interested_services
- )(max_stream_id)
+ run_in_background(self._notify_app_services, max_stream_id)
+
+ @defer.inlineCallbacks
+ def _notify_app_services(self, room_stream_id):
+ try:
+ yield self.appservice_handler.notify_interested_services(room_stream_id)
+ except Exception:
+ logger.exception("Error notifying application services of event")
def start(config_options):
@@ -181,36 +172,13 @@ def start(config_options):
ps.setup()
ps.start_listening(config.worker_listeners)
- def run():
- # make sure that we run the reactor with the sentinel log context,
- # otherwise other PreserveLoggingContext instances will get confused
- # and complain when they see the logcontext arbitrarily swapping
- # between the sentinel and `run` logcontexts.
- with PreserveLoggingContext():
- logger.info("Running")
- change_resource_limit(config.soft_file_limit)
- if config.gc_thresholds:
- gc.set_threshold(*config.gc_thresholds)
- reactor.run()
-
def start():
ps.get_datastore().start_profiling()
ps.get_state_handler().start_caching()
reactor.callWhenRunning(start)
- if config.worker_daemonize:
- daemon = Daemonize(
- app="synapse-appservice",
- pid=config.worker_pid_file,
- action=run,
- auto_close_fds=False,
- verbose=True,
- logger=logger,
- )
- daemon.start()
- else:
- run()
+ _base.start_worker_reactor("synapse-appservice", config)
if __name__ == '__main__':
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 09bc1935f1..398bb36602 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -13,46 +13,46 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import sys
-import synapse
+from twisted.internet import reactor
+from twisted.web.resource import NoResource
+import synapse
+from synapse import events
+from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
-from synapse.http.site import SynapseSite
+from synapse.crypto import context_factory
from synapse.http.server import JsonResource
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.http.site import SynapseSite
+from synapse.metrics import RegistryProxy
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
+from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
-from synapse.replication.slave.storage.room import RoomStore
-from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.rest.client.v1.room import PublicRoomListRestServlet
+from synapse.rest.client.v1.room import (
+ JoinedRoomMemberListRestServlet,
+ PublicRoomListRestServlet,
+ RoomEventContextServlet,
+ RoomMemberListRestServlet,
+ RoomStateRestServlet,
+)
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
-from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
-from synapse.crypto import context_factory
-
-from synapse import events
-
-
-from twisted.internet import reactor
-from twisted.web.resource import Resource
-
-from daemonize import Daemonize
-
-import sys
-import logging
-import gc
logger = logging.getLogger("synapse.app.client_reader")
@@ -72,19 +72,6 @@ class ClientReaderSlavedStore(
class ClientReaderServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)
@@ -98,10 +85,16 @@ class ClientReaderServer(HomeServer):
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(self)
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
elif name == "client":
resource = JsonResource(self, canonical_json=False)
+
PublicRoomListRestServlet(self).register(resource)
+ RoomMemberListRestServlet(self).register(resource)
+ JoinedRoomMemberListRestServlet(self).register(resource)
+ RoomStateRestServlet(self).register(resource)
+ RoomEventContextServlet(self).register(resource)
+
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
@@ -109,19 +102,19 @@ class ClientReaderServer(HomeServer):
"/_matrix/client/api/v1": resource,
})
- root_resource = create_resource_tree(resources, Resource())
-
- for address in bind_addresses:
- reactor.listenTCP(
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- ),
- interface=address
+ root_resource = create_resource_tree(resources, NoResource())
+
+ _base.listen_tcp(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
+ self.version_string,
)
+ )
logger.info("Synapse client reader now listening on port %d", port)
@@ -130,18 +123,22 @@ class ClientReaderServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
- bind_addresses = listener["bind_addresses"]
-
- for address in bind_addresses:
- reactor.listenTCP(
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- ),
- interface=address
+ _base.listen_tcp(
+ listener["bind_addresses"],
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
)
+ )
+ elif listener["type"] == "metrics":
+ if not self.get_config().enable_metrics:
+ logger.warn(("Metrics listener configured, but "
+ "enable_metrics is not True!"))
+ else:
+ _base.listen_metrics(listener["bind_addresses"],
+ listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -180,39 +177,15 @@ def start(config_options):
)
ss.setup()
- ss.get_handlers()
ss.start_listening(config.worker_listeners)
- def run():
- # make sure that we run the reactor with the sentinel log context,
- # otherwise other PreserveLoggingContext instances will get confused
- # and complain when they see the logcontext arbitrarily swapping
- # between the sentinel and `run` logcontexts.
- with PreserveLoggingContext():
- logger.info("Running")
- change_resource_limit(config.soft_file_limit)
- if config.gc_thresholds:
- gc.set_threshold(*config.gc_thresholds)
- reactor.run()
-
def start():
ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling()
reactor.callWhenRunning(start)
- if config.worker_daemonize:
- daemon = Daemonize(
- app="synapse-client-reader",
- pid=config.worker_pid_file,
- action=run,
- auto_close_fds=False,
- verbose=True,
- logger=logger,
- )
- daemon.start()
- else:
- run()
+ _base.start_worker_reactor("synapse-client-reader", config)
if __name__ == '__main__':
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
new file mode 100644
index 0000000000..374f115644
--- /dev/null
+++ b/synapse/app/event_creator.py
@@ -0,0 +1,201 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import sys
+
+from twisted.internet import reactor
+from twisted.web.resource import NoResource
+
+import synapse
+from synapse import events
+from synapse.app import _base
+from synapse.config._base import ConfigError
+from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
+from synapse.crypto import context_factory
+from synapse.http.server import JsonResource
+from synapse.http.site import SynapseSite
+from synapse.metrics import RegistryProxy
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
+from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
+from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
+from synapse.replication.slave.storage.devices import SlavedDeviceStore
+from synapse.replication.slave.storage.directory import DirectoryStore
+from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.replication.slave.storage.profile import SlavedProfileStore
+from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
+from synapse.replication.slave.storage.pushers import SlavedPusherStore
+from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
+from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.slave.storage.room import RoomStore
+from synapse.replication.slave.storage.transactions import TransactionStore
+from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.rest.client.v1.room import (
+ JoinRoomAliasServlet,
+ RoomMembershipRestServlet,
+ RoomSendEventRestServlet,
+ RoomStateEventRestServlet,
+)
+from synapse.server import HomeServer
+from synapse.storage.engines import create_engine
+from synapse.util.httpresourcetree import create_resource_tree
+from synapse.util.logcontext import LoggingContext
+from synapse.util.manhole import manhole
+from synapse.util.versionstring import get_version_string
+
+logger = logging.getLogger("synapse.app.event_creator")
+
+
+class EventCreatorSlavedStore(
+ DirectoryStore,
+ TransactionStore,
+ SlavedProfileStore,
+ SlavedAccountDataStore,
+ SlavedPusherStore,
+ SlavedReceiptsStore,
+ SlavedPushRuleStore,
+ SlavedDeviceStore,
+ SlavedClientIpStore,
+ SlavedApplicationServiceStore,
+ SlavedEventStore,
+ SlavedRegistrationStore,
+ RoomStore,
+ BaseSlavedStore,
+):
+ pass
+
+
+class EventCreatorServer(HomeServer):
+ def setup(self):
+ logger.info("Setting up.")
+ self.datastore = EventCreatorSlavedStore(self.get_db_conn(), self)
+ logger.info("Finished setting up.")
+
+ def _listen_http(self, listener_config):
+ port = listener_config["port"]
+ bind_addresses = listener_config["bind_addresses"]
+ site_tag = listener_config.get("tag", port)
+ resources = {}
+ for res in listener_config["resources"]:
+ for name in res["names"]:
+ if name == "metrics":
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
+ elif name == "client":
+ resource = JsonResource(self, canonical_json=False)
+ RoomSendEventRestServlet(self).register(resource)
+ RoomMembershipRestServlet(self).register(resource)
+ RoomStateEventRestServlet(self).register(resource)
+ JoinRoomAliasServlet(self).register(resource)
+ resources.update({
+ "/_matrix/client/r0": resource,
+ "/_matrix/client/unstable": resource,
+ "/_matrix/client/v2_alpha": resource,
+ "/_matrix/client/api/v1": resource,
+ })
+
+ root_resource = create_resource_tree(resources, NoResource())
+
+ _base.listen_tcp(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
+ self.version_string,
+ )
+ )
+
+ logger.info("Synapse event creator now listening on port %d", port)
+
+ def start_listening(self, listeners):
+ for listener in listeners:
+ if listener["type"] == "http":
+ self._listen_http(listener)
+ elif listener["type"] == "manhole":
+ _base.listen_tcp(
+ listener["bind_addresses"],
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
+ )
+ )
+ elif listener["type"] == "metrics":
+ if not self.get_config().enable_metrics:
+ logger.warn(("Metrics listener configured, but "
+ "enable_metrics is not True!"))
+ else:
+ _base.listen_metrics(listener["bind_addresses"],
+ listener["port"])
+ else:
+ logger.warn("Unrecognized listener type: %s", listener["type"])
+
+ self.get_tcp_replication().start_replication(self)
+
+ def build_tcp_replication(self):
+ return ReplicationClientHandler(self.get_datastore())
+
+
+def start(config_options):
+ try:
+ config = HomeServerConfig.load_config(
+ "Synapse event creator", config_options
+ )
+ except ConfigError as e:
+ sys.stderr.write("\n" + e.message + "\n")
+ sys.exit(1)
+
+ assert config.worker_app == "synapse.app.event_creator"
+
+ assert config.worker_replication_http_port is not None
+
+ setup_logging(config, use_worker_options=True)
+
+ events.USE_FROZEN_DICTS = config.use_frozen_dicts
+
+ database_engine = create_engine(config.database_config)
+
+ tls_server_context_factory = context_factory.ServerContextFactory(config)
+
+ ss = EventCreatorServer(
+ config.server_name,
+ db_config=config.database_config,
+ tls_server_context_factory=tls_server_context_factory,
+ config=config,
+ version_string="Synapse/" + get_version_string(synapse),
+ database_engine=database_engine,
+ )
+
+ ss.setup()
+ ss.start_listening(config.worker_listeners)
+
+ def start():
+ ss.get_state_handler().start_caching()
+ ss.get_datastore().start_profiling()
+
+ reactor.callWhenRunning(start)
+
+ _base.start_worker_reactor("synapse-event-creator", config)
+
+
+if __name__ == '__main__':
+ with LoggingContext("main"):
+ start(sys.argv[1:])
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index eb392e1c9d..7af00b8bcf 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -13,43 +13,37 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import sys
-import synapse
+from twisted.internet import reactor
+from twisted.web.resource import NoResource
+import synapse
+from synapse import events
+from synapse.api.urls import FEDERATION_PREFIX
+from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
+from synapse.crypto import context_factory
+from synapse.federation.transport.server import TransportLayerServer
from synapse.http.site import SynapseSite
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.metrics import RegistryProxy
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import TransactionStore
-from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
-from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
-from synapse.api.urls import FEDERATION_PREFIX
-from synapse.federation.transport.server import TransportLayerServer
-from synapse.crypto import context_factory
-
-from synapse import events
-
-
-from twisted.internet import reactor
-from twisted.web.resource import Resource
-
-from daemonize import Daemonize
-
-import sys
-import logging
-import gc
logger = logging.getLogger("synapse.app.federation_reader")
@@ -66,19 +60,6 @@ class FederationReaderSlavedStore(
class FederationReaderServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
@@ -92,25 +73,25 @@ class FederationReaderServer(HomeServer):
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(self)
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
elif name == "federation":
resources.update({
FEDERATION_PREFIX: TransportLayerServer(self),
})
- root_resource = create_resource_tree(resources, Resource())
-
- for address in bind_addresses:
- reactor.listenTCP(
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- ),
- interface=address
+ root_resource = create_resource_tree(resources, NoResource())
+
+ _base.listen_tcp(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
+ self.version_string,
)
+ )
logger.info("Synapse federation reader now listening on port %d", port)
@@ -119,18 +100,22 @@ class FederationReaderServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
- bind_addresses = listener["bind_addresses"]
-
- for address in bind_addresses:
- reactor.listenTCP(
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- ),
- interface=address
+ _base.listen_tcp(
+ listener["bind_addresses"],
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
)
+ )
+ elif listener["type"] == "metrics":
+ if not self.get_config().enable_metrics:
+ logger.warn(("Metrics listener configured, but "
+ "enable_metrics is not True!"))
+ else:
+ _base.listen_metrics(listener["bind_addresses"],
+ listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -169,39 +154,15 @@ def start(config_options):
)
ss.setup()
- ss.get_handlers()
ss.start_listening(config.worker_listeners)
- def run():
- # make sure that we run the reactor with the sentinel log context,
- # otherwise other PreserveLoggingContext instances will get confused
- # and complain when they see the logcontext arbitrarily swapping
- # between the sentinel and `run` logcontexts.
- with PreserveLoggingContext():
- logger.info("Running")
- change_resource_limit(config.soft_file_limit)
- if config.gc_thresholds:
- gc.set_threshold(*config.gc_thresholds)
- reactor.run()
-
def start():
ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling()
reactor.callWhenRunning(start)
- if config.worker_daemonize:
- daemon = Daemonize(
- app="synapse-federation-reader",
- pid=config.worker_pid_file,
- action=run,
- auto_close_fds=False,
- verbose=True,
- logger=logger,
- )
- daemon.start()
- else:
- run()
+ _base.start_worker_reactor("synapse-federation-reader", config)
if __name__ == '__main__':
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 03327dc47a..18469013fa 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -13,44 +13,39 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import sys
-import synapse
+from twisted.internet import defer, reactor
+from twisted.web.resource import NoResource
-from synapse.server import HomeServer
+import synapse
+from synapse import events
+from synapse.app import _base
from synapse.config._base import ConfigError
-from synapse.config.logger import setup_logging
from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
from synapse.crypto import context_factory
-from synapse.http.site import SynapseSite
from synapse.federation import send_queue
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.http.site import SynapseSite
+from synapse.metrics import RegistryProxy
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
+from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.transactions import TransactionStore
-from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.async import Linearizer
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
-from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
-from synapse import events
-
-from twisted.internet import reactor, defer
-from twisted.web.resource import Resource
-
-from daemonize import Daemonize
-
-import sys
-import logging
-import gc
-
logger = logging.getLogger("synapse.app.federation_sender")
@@ -83,19 +78,6 @@ class FederationSenderSlaveStore(
class FederationSenderServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)
@@ -109,21 +91,21 @@ class FederationSenderServer(HomeServer):
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(self)
-
- root_resource = create_resource_tree(resources, Resource())
-
- for address in bind_addresses:
- reactor.listenTCP(
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- ),
- interface=address
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
+
+ root_resource = create_resource_tree(resources, NoResource())
+
+ _base.listen_tcp(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
+ self.version_string,
)
+ )
logger.info("Synapse federation_sender now listening on port %d", port)
@@ -132,18 +114,22 @@ class FederationSenderServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
- bind_addresses = listener["bind_addresses"]
-
- for address in bind_addresses:
- reactor.listenTCP(
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- ),
- interface=address
+ _base.listen_tcp(
+ listener["bind_addresses"],
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
)
+ )
+ elif listener["type"] == "metrics":
+ if not self.get_config().enable_metrics:
+ logger.warn(("Metrics listener configured, but "
+ "enable_metrics is not True!"))
+ else:
+ _base.listen_metrics(listener["bind_addresses"],
+ listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -213,36 +199,12 @@ def start(config_options):
ps.setup()
ps.start_listening(config.worker_listeners)
- def run():
- # make sure that we run the reactor with the sentinel log context,
- # otherwise other PreserveLoggingContext instances will get confused
- # and complain when they see the logcontext arbitrarily swapping
- # between the sentinel and `run` logcontexts.
- with PreserveLoggingContext():
- logger.info("Running")
- change_resource_limit(config.soft_file_limit)
- if config.gc_thresholds:
- gc.set_threshold(*config.gc_thresholds)
- reactor.run()
-
def start():
ps.get_datastore().start_profiling()
ps.get_state_handler().start_caching()
reactor.callWhenRunning(start)
-
- if config.worker_daemonize:
- daemon = Daemonize(
- app="synapse-federation-sender",
- pid=config.worker_pid_file,
- action=run,
- auto_close_fds=False,
- verbose=True,
- logger=logger,
- )
- daemon.start()
- else:
- run()
+ _base.start_worker_reactor("synapse-federation-sender", config)
class FederationSenderHandler(object):
@@ -277,7 +239,7 @@ class FederationSenderHandler(object):
# presence, typing, etc.
if stream_name == "federation":
send_queue.process_rows_for_federation(self.federation_sender, rows)
- preserve_fn(self.update_token)(token)
+ run_in_background(self.update_token, token)
# We also need to poke the federation sender when new events happen
elif stream_name == "events":
@@ -285,19 +247,22 @@ class FederationSenderHandler(object):
@defer.inlineCallbacks
def update_token(self, token):
- self.federation_position = token
-
- # We linearize here to ensure we don't have races updating the token
- with (yield self._fed_position_linearizer.queue(None)):
- if self._last_ack < self.federation_position:
- yield self.store.update_federation_out_pos(
- "federation", self.federation_position
- )
+ try:
+ self.federation_position = token
+
+ # We linearize here to ensure we don't have races updating the token
+ with (yield self._fed_position_linearizer.queue(None)):
+ if self._last_ack < self.federation_position:
+ yield self.store.update_federation_out_pos(
+ "federation", self.federation_position
+ )
- # We ACK this token over replication so that the master can drop
- # its in memory queues
- self.replication_client.send_federation_ack(self.federation_position)
- self._last_ack = self.federation_position
+ # We ACK this token over replication so that the master can drop
+ # its in memory queues
+ self.replication_client.send_federation_ack(self.federation_position)
+ self._last_ack = self.federation_position
+ except Exception:
+ logger.exception("Error updating federation stream position")
if __name__ == '__main__':
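On the update_token change above: the Linearizer-guarded section is unchanged, only the try/except is new, since run_in_background callers never observe the failure otherwise. The underlying pattern, sketched with synapse.util.async's Linearizer (assumed deferred context-manager API):

    from twisted.internet import defer
    from synapse.util.async import Linearizer

    linearizer = Linearizer(name="fed_position")

    @defer.inlineCallbacks
    def persist_position(store, token):
        # Serialise writers so a slow, stale token can never overwrite a
        # newer position that has already been persisted.
        with (yield linearizer.queue(None)):
            yield store.update_federation_out_pos("federation", token)
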
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
new file mode 100644
index 0000000000..b5f78f4640
--- /dev/null
+++ b/synapse/app/frontend_proxy.py
@@ -0,0 +1,235 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import sys
+
+from twisted.internet import defer, reactor
+from twisted.web.resource import NoResource
+
+import synapse
+from synapse import events
+from synapse.api.errors import SynapseError
+from synapse.app import _base
+from synapse.config._base import ConfigError
+from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
+from synapse.crypto import context_factory
+from synapse.http.server import JsonResource
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseSite
+from synapse.metrics import RegistryProxy
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
+from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
+from synapse.replication.slave.storage.devices import SlavedDeviceStore
+from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.rest.client.v2_alpha._base import client_v2_patterns
+from synapse.server import HomeServer
+from synapse.storage.engines import create_engine
+from synapse.util.httpresourcetree import create_resource_tree
+from synapse.util.logcontext import LoggingContext
+from synapse.util.manhole import manhole
+from synapse.util.versionstring import get_version_string
+
+logger = logging.getLogger("synapse.app.frontend_proxy")
+
+
+class KeyUploadServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(KeyUploadServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.http_client = hs.get_simple_http_client()
+ self.main_uri = hs.config.worker_main_http_uri
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, device_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ user_id = requester.user.to_string()
+ body = parse_json_object_from_request(request)
+
+ if device_id is not None:
+ # passing the device_id here is deprecated; however, we allow it
+ # for now for compatibility with older clients.
+ if (requester.device_id is not None and
+ device_id != requester.device_id):
+ logger.warning("Client uploading keys for a different device "
+ "(logged in as %s, uploading for %s)",
+ requester.device_id, device_id)
+ else:
+ device_id = requester.device_id
+
+ if device_id is None:
+ raise SynapseError(
+ 400,
+ "To upload keys, you must pass device_id when authenticating"
+ )
+
+ if body:
+ # They're actually trying to upload something, proxy to main synapse.
+ # Pass through the auth headers, if any, in case the access token
+ # is there.
+ auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", [])
+ headers = {
+ "Authorization": auth_headers,
+ }
+ result = yield self.http_client.post_json_get_json(
+ self.main_uri + request.uri,
+ body,
+ headers=headers,
+ )
+
+ defer.returnValue((200, result))
+ else:
+ # Just interested in counts.
+ result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
+ defer.returnValue((200, {"one_time_key_counts": result}))
+
+
+class FrontendProxySlavedStore(
+ SlavedDeviceStore,
+ SlavedClientIpStore,
+ SlavedApplicationServiceStore,
+ SlavedRegistrationStore,
+ BaseSlavedStore,
+):
+ pass
+
+
+class FrontendProxyServer(HomeServer):
+ def setup(self):
+ logger.info("Setting up.")
+ self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self)
+ logger.info("Finished setting up.")
+
+ def _listen_http(self, listener_config):
+ port = listener_config["port"]
+ bind_addresses = listener_config["bind_addresses"]
+ site_tag = listener_config.get("tag", port)
+ resources = {}
+ for res in listener_config["resources"]:
+ for name in res["names"]:
+ if name == "metrics":
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
+ elif name == "client":
+ resource = JsonResource(self, canonical_json=False)
+ KeyUploadServlet(self).register(resource)
+ resources.update({
+ "/_matrix/client/r0": resource,
+ "/_matrix/client/unstable": resource,
+ "/_matrix/client/v2_alpha": resource,
+ "/_matrix/client/api/v1": resource,
+ })
+
+ root_resource = create_resource_tree(resources, NoResource())
+
+ _base.listen_tcp(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
+ self.version_string,
+ )
+ )
+
+ logger.info("Synapse client reader now listening on port %d", port)
+
+ def start_listening(self, listeners):
+ for listener in listeners:
+ if listener["type"] == "http":
+ self._listen_http(listener)
+ elif listener["type"] == "manhole":
+ _base.listen_tcp(
+ listener["bind_addresses"],
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
+ )
+ )
+ elif listener["type"] == "metrics":
+ if not self.get_config().enable_metrics:
+ logger.warn(("Metrics listener configured, but "
+ "enable_metrics is not True!"))
+ else:
+ _base.listen_metrics(listener["bind_addresses"],
+ listener["port"])
+ else:
+ logger.warn("Unrecognized listener type: %s", listener["type"])
+
+ self.get_tcp_replication().start_replication(self)
+
+ def build_tcp_replication(self):
+ return ReplicationClientHandler(self.get_datastore())
+
+
+def start(config_options):
+ try:
+ config = HomeServerConfig.load_config(
+ "Synapse frontend proxy", config_options
+ )
+ except ConfigError as e:
+ sys.stderr.write("\n" + e.message + "\n")
+ sys.exit(1)
+
+ assert config.worker_app == "synapse.app.frontend_proxy"
+
+ assert config.worker_main_http_uri is not None
+
+ setup_logging(config, use_worker_options=True)
+
+ events.USE_FROZEN_DICTS = config.use_frozen_dicts
+
+ database_engine = create_engine(config.database_config)
+
+ tls_server_context_factory = context_factory.ServerContextFactory(config)
+
+ ss = FrontendProxyServer(
+ config.server_name,
+ db_config=config.database_config,
+ tls_server_context_factory=tls_server_context_factory,
+ config=config,
+ version_string="Synapse/" + get_version_string(synapse),
+ database_engine=database_engine,
+ )
+
+ ss.setup()
+ ss.start_listening(config.worker_listeners)
+
+ def start():
+ ss.get_state_handler().start_caching()
+ ss.get_datastore().start_profiling()
+
+ reactor.callWhenRunning(start)
+
+ _base.start_worker_reactor("synapse-frontend-proxy", config)
+
+
+if __name__ == '__main__':
+ with LoggingContext("main"):
+ start(sys.argv[1:])
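The new frontend proxy splits key-upload traffic by whether it mutates
state: actual uploads are forwarded to the main process at
worker_main_http_uri, while empty-body requests (queries for one-time key
counts) are answered from the worker's replicated store. A condensed sketch
of that split, using the same calls as the servlet above:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def handle_upload(http_client, store, main_uri, request, body,
                      user_id, device_id):
        if body:
            # Mutating request: proxy to the master, preserving the
            # Authorization header so the access token goes with it.
            headers = {
                "Authorization": request.requestHeaders.getRawHeaders(
                    b"Authorization", []
                ),
            }
            result = yield http_client.post_json_get_json(
                main_uri + request.uri, body, headers=headers,
            )
        else:
            # Read-only request: the slaved store can answer locally.
            counts = yield store.count_e2e_one_time_keys(user_id, device_id)
            result = {"one_time_key_counts": counts}
        defer.returnValue((200, result))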
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 081e7cce59..2ad1beb8d8 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -13,61 +13,62 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import synapse
-
import gc
import logging
import os
import sys
-import synapse.config.logger
-from synapse.config._base import ConfigError
-
-from synapse.python_dependencies import (
- check_requirements, CONDITIONAL_REQUIREMENTS
-)
-
-from synapse.rest import ClientRestResource
-from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
-from synapse.storage import are_all_users_on_domain
-from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
-
-from synapse.server import HomeServer
+from six import iteritems
-from twisted.internet import reactor, defer
from twisted.application import service
-from twisted.web.resource import Resource, EncodingResourceWrapper
-from twisted.web.static import File
+from twisted.internet import defer, reactor
+from twisted.web.resource import EncodingResourceWrapper, NoResource
from twisted.web.server import GzipEncoderFactory
-from synapse.http.server import RootRedirect
-from synapse.rest.media.v0.content_repository import ContentRepoResource
-from synapse.rest.media.v1.media_repository import MediaRepositoryResource
-from synapse.rest.key.v1.server_key_resource import LocalKey
-from synapse.rest.key.v2 import KeyApiV2Resource
+from twisted.web.static import File
+
+import synapse
+import synapse.config.logger
+from synapse import events
from synapse.api.urls import (
- FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
- SERVER_KEY_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, STATIC_PREFIX,
+ CONTENT_REPO_PREFIX,
+ FEDERATION_PREFIX,
+ LEGACY_MEDIA_PREFIX,
+ MEDIA_PREFIX,
+ SERVER_KEY_PREFIX,
SERVER_KEY_V2_PREFIX,
+ STATIC_PREFIX,
+ WEB_CLIENT_PREFIX,
)
+from synapse.app import _base
+from synapse.app._base import listen_ssl, listen_tcp, quit_with_error
+from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
-from synapse.metrics import register_memory_metrics
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
-from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.federation.transport.server import TransportLayerServer
-
-from synapse.util.rlimit import change_resource_limit
-from synapse.util.versionstring import get_version_string
+from synapse.http.additional_resource import AdditionalResource
+from synapse.http.server import RootRedirect
+from synapse.http.site import SynapseSite
+from synapse.metrics import RegistryProxy
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.module_api import ModuleApi
+from synapse.python_dependencies import CONDITIONAL_REQUIREMENTS, check_requirements
+from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.rest import ClientRestResource
+from synapse.rest.key.v1.server_key_resource import LocalKey
+from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.rest.media.v0.content_repository import ContentRepoResource
+from synapse.server import HomeServer
+from synapse.storage import are_all_users_on_domain
+from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
+from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
+from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.httpresourcetree import create_resource_tree
+from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
-
-from synapse.http.site import SynapseSite
-
-from synapse import events
-
-from daemonize import Daemonize
+from synapse.util.module_loader import load_module
+from synapse.util.rlimit import change_resource_limit
+from synapse.util.versionstring import get_version_string
logger = logging.getLogger("synapse.app.homeserver")
@@ -119,87 +120,132 @@ class SynapseHomeServer(HomeServer):
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
- if name == "client":
- client_resource = ClientRestResource(self)
- if res["compress"]:
- client_resource = gz_wrap(client_resource)
-
- resources.update({
- "/_matrix/client/api/v1": client_resource,
- "/_matrix/client/r0": client_resource,
- "/_matrix/client/unstable": client_resource,
- "/_matrix/client/v2_alpha": client_resource,
- "/_matrix/client/versions": client_resource,
- })
-
- if name == "federation":
- resources.update({
- FEDERATION_PREFIX: TransportLayerServer(self),
- })
-
- if name in ["static", "client"]:
- resources.update({
- STATIC_PREFIX: File(
- os.path.join(os.path.dirname(synapse.__file__), "static")
- ),
- })
-
- if name in ["media", "federation", "client"]:
- media_repo = MediaRepositoryResource(self)
- resources.update({
- MEDIA_PREFIX: media_repo,
- LEGACY_MEDIA_PREFIX: media_repo,
- CONTENT_REPO_PREFIX: ContentRepoResource(
- self, self.config.uploads_path
- ),
- })
-
- if name in ["keys", "federation"]:
- resources.update({
- SERVER_KEY_PREFIX: LocalKey(self),
- SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self),
- })
-
- if name == "webclient":
- resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self)
-
- if name == "metrics" and self.get_config().enable_metrics:
- resources[METRICS_PREFIX] = MetricsResource(self)
+ resources.update(self._configure_named_resource(
+ name, res.get("compress", False),
+ ))
+
+ additional_resources = listener_config.get("additional_resources", {})
+ logger.debug("Configuring additional resources: %r",
+ additional_resources)
+ module_api = ModuleApi(self, self.get_auth_handler())
+ for path, resmodule in additional_resources.items():
+ handler_cls, config = load_module(resmodule)
+ handler = handler_cls(config, module_api)
+ resources[path] = AdditionalResource(self, handler.handle_request)
if WEB_CLIENT_PREFIX in resources:
root_resource = RootRedirect(WEB_CLIENT_PREFIX)
else:
- root_resource = Resource()
+ root_resource = NoResource()
root_resource = create_resource_tree(resources, root_resource)
if tls:
- for address in bind_addresses:
- reactor.listenSSL(
- port,
- SynapseSite(
- "synapse.access.https.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- ),
- self.tls_server_context_factory,
- interface=address
- )
+ listen_ssl(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.https.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
+ self.version_string,
+ ),
+ self.tls_server_context_factory,
+ )
+
else:
- for address in bind_addresses:
- reactor.listenTCP(
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- ),
- interface=address
+ listen_tcp(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
+ self.version_string,
)
+ )
logger.info("Synapse now listening on port %d", port)
+ def _configure_named_resource(self, name, compress=False):
+ """Build a resource map for a named resource
+
+ Args:
+ name (str): named resource: one of "client", "federation", etc
+ compress (bool): whether to enable gzip compression for this
+ resource
+
+ Returns:
+ dict[str, Resource]: map from path to HTTP resource
+ """
+ resources = {}
+ if name == "client":
+ client_resource = ClientRestResource(self)
+ if compress:
+ client_resource = gz_wrap(client_resource)
+
+ resources.update({
+ "/_matrix/client/api/v1": client_resource,
+ "/_matrix/client/r0": client_resource,
+ "/_matrix/client/unstable": client_resource,
+ "/_matrix/client/v2_alpha": client_resource,
+ "/_matrix/client/versions": client_resource,
+ })
+
+ if name == "consent":
+ from synapse.rest.consent.consent_resource import ConsentResource
+ consent_resource = ConsentResource(self)
+ if compress:
+ consent_resource = gz_wrap(consent_resource)
+ resources.update({
+ "/_matrix/consent": consent_resource,
+ })
+
+ if name == "federation":
+ resources.update({
+ FEDERATION_PREFIX: TransportLayerServer(self),
+ })
+
+ if name in ["static", "client"]:
+ resources.update({
+ STATIC_PREFIX: File(
+ os.path.join(os.path.dirname(synapse.__file__), "static")
+ ),
+ })
+
+ if name in ["media", "federation", "client"]:
+ if self.get_config().enable_media_repo:
+ media_repo = self.get_media_repository_resource()
+ resources.update({
+ MEDIA_PREFIX: media_repo,
+ LEGACY_MEDIA_PREFIX: media_repo,
+ CONTENT_REPO_PREFIX: ContentRepoResource(
+ self, self.config.uploads_path
+ ),
+ })
+ elif name == "media":
+ raise ConfigError(
+ "'media' resource conflicts with enable_media_repo=False",
+ )
+
+ if name in ["keys", "federation"]:
+ resources.update({
+ SERVER_KEY_PREFIX: LocalKey(self),
+ SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self),
+ })
+
+ if name == "webclient":
+ resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self)
+
+ if name == "metrics" and self.get_config().enable_metrics:
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
+
+ if name == "replication":
+ resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
+
+ return resources
+
def start_listening(self):
config = self.get_config()
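The additional_resources hook above lets deployments mount arbitrary
handlers into the resource tree via the config file. load_module is defined
elsewhere; a plausible sketch of its contract, assuming each entry carries
a dotted "module" path plus a "config" dict:

    # Assumed shape of synapse.util.module_loader.load_module
    # (definition not in this diff).
    import importlib

    def load_module(provider):
        # Split off the class name, import the module, return both the
        # class and whatever config the entry carried.
        module_name, clz = provider["module"].rsplit(".", 1)
        module = importlib.import_module(module_name)
        return getattr(module, clz), provider.get("config")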
@@ -207,18 +253,15 @@ class SynapseHomeServer(HomeServer):
if listener["type"] == "http":
self._listener_http(config, listener)
elif listener["type"] == "manhole":
- bind_addresses = listener["bind_addresses"]
-
- for address in bind_addresses:
- reactor.listenTCP(
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- ),
- interface=address
+ listen_tcp(
+ listener["bind_addresses"],
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
)
+ )
elif listener["type"] == "replication":
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
@@ -229,6 +272,13 @@ class SynapseHomeServer(HomeServer):
reactor.addSystemEventTrigger(
"before", "shutdown", server_listener.stopListening,
)
+ elif listener["type"] == "metrics":
+ if not self.get_config().enable_metrics:
+ logger.warn(("Metrics listener configured, but "
+ "enable_metrics is not True!"))
+ else:
+ _base.listen_metrics(listener["bind_addresses"],
+ listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -248,29 +298,6 @@ class SynapseHomeServer(HomeServer):
except IncorrectDatabaseSetup as e:
quit_with_error(e.message)
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
-
-def quit_with_error(error_string):
- message_lines = error_string.split("\n")
- line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2
- sys.stderr.write("*" * line_length + '\n')
- for line in message_lines:
- sys.stderr.write(" %s\n" % (line.rstrip(),))
- sys.stderr.write("*" * line_length + '\n')
- sys.exit(1)
-
def setup(config_options):
"""
@@ -300,11 +327,6 @@ def setup(config_options):
# check any extra requirements we have now we have a config
check_requirements(config)
- version_string = "Synapse/" + get_version_string(synapse)
-
- logger.info("Server hostname: %s", config.server_name)
- logger.info("Server version: %s", version_string)
-
events.USE_FROZEN_DICTS = config.use_frozen_dicts
tls_server_context_factory = context_factory.ServerContextFactory(config)
@@ -317,7 +339,7 @@ def setup(config_options):
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
- version_string=version_string,
+ version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
@@ -349,9 +371,7 @@ def setup(config_options):
hs.get_state_handler().start_caching()
hs.get_datastore().start_profiling()
hs.get_datastore().start_doing_background_updates()
- hs.get_replication_layer().start_get_pdu_cache()
-
- register_memory_metrics(hs)
+ hs.get_federation_client().start_get_pdu_cache()
reactor.callWhenRunning(start)
@@ -403,6 +423,10 @@ def run(hs):
stats = {}
+    # Contains the list of processes we will be monitoring
+    # (currently either 0 or 1)
+ stats_process = []
+
@defer.inlineCallbacks
def phone_stats_home():
logger.info("Gathering stats for reporting")
@@ -419,6 +443,10 @@ def run(hs):
total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users()
stats["total_nonbridged_users"] = total_nonbridged_users
+ daily_user_type_results = yield hs.get_datastore().count_daily_user_type()
+ for name, count in iteritems(daily_user_type_results):
+ stats["daily_user_type_" + name] = count
+
room_count = yield hs.get_datastore().get_room_count()
stats["total_room_count"] = room_count
@@ -426,8 +454,21 @@ def run(hs):
stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms()
stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
+ r30_results = yield hs.get_datastore().count_r30_users()
+ for name, count in iteritems(r30_results):
+ stats["r30_users_" + name] = count
+
daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
stats["daily_sent_messages"] = daily_sent_messages
+ stats["cache_factor"] = CACHE_SIZE_FACTOR
+ stats["event_cache_size"] = hs.config.event_cache_size
+
+ if len(stats_process) > 0:
+ stats["memory_rss"] = 0
+ stats["cpu_average"] = 0
+ for process in stats_process:
+ stats["memory_rss"] += process.memory_info().rss
+ stats["cpu_average"] += int(process.cpu_percent(interval=None))
logger.info("Reporting stats to matrix.org: %s" % (stats,))
try:
@@ -438,45 +479,56 @@ def run(hs):
except Exception as e:
logger.warn("Error reporting stats: %s", e)
+ def performance_stats_init():
+ try:
+ import psutil
+ process = psutil.Process()
+ # Ensure we can fetch both, and make the initial request for cpu_percent
+ # so the next request will use this as the initial point.
+ process.memory_info().rss
+ process.cpu_percent(interval=None)
+ logger.info("report_stats can use psutil")
+ stats_process.append(process)
+ except (ImportError, AttributeError):
+ logger.warn(
+ "report_stats enabled but psutil is not installed or incorrect version."
+ " Disabling reporting of memory/cpu stats."
+ " Ensuring psutil is available will help matrix.org track performance"
+ " changes across releases."
+ )
+
+ def generate_user_daily_visit_stats():
+ hs.get_datastore().generate_user_daily_visits()
+
+    # Rather than updating on a per-session basis, batch up the requests.
+    # If you increase the loop period, the accuracy of the
+    # user_daily_visits table will decrease.
+ clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000)
+
if hs.config.report_stats:
logger.info("Scheduling stats reporting for 3 hour intervals")
clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000)
+        # We need to defer this init for the case where we daemonize,
+        # otherwise the process ID we get is that of the non-daemon process
+ clock.call_later(0, performance_stats_init)
+
# We wait 5 minutes to send the first set of stats as the server can
# be quite busy the first few minutes
clock.call_later(5 * 60, phone_stats_home)
- def in_thread():
- # Uncomment to enable tracing of log context changes.
- # sys.settrace(logcontext_tracer)
-
- # make sure that we run the reactor with the sentinel log context,
- # otherwise other PreserveLoggingContext instances will get confused
- # and complain when they see the logcontext arbitrarily swapping
- # between the sentinel and `run` logcontexts.
- with PreserveLoggingContext():
- change_resource_limit(hs.config.soft_file_limit)
- if hs.config.gc_thresholds:
- gc.set_threshold(*hs.config.gc_thresholds)
- reactor.run()
-
- if hs.config.daemonize:
-
- if hs.config.print_pidfile:
- print (hs.config.pid_file)
-
- daemon = Daemonize(
- app="synapse-homeserver",
- pid=hs.config.pid_file,
- action=lambda: in_thread(),
- auto_close_fds=False,
- verbose=True,
- logger=logger,
- )
-
- daemon.start()
- else:
- in_thread()
+ if hs.config.daemonize and hs.config.print_pidfile:
+ print (hs.config.pid_file)
+
+ _base.start_reactor(
+ "synapse-homeserver",
+ hs.config.soft_file_limit,
+ hs.config.gc_thresholds,
+ hs.config.pid_file,
+ hs.config.daemonize,
+ hs.config.cpu_affinity,
+ logger,
+ )
def main():
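The psutil calls above follow the library's sampling idiom: the first
cpu_percent(interval=None) call only primes the counter and returns a
meaningless 0.0, so it is made once at startup, and each later call reports
average usage since the previous one. A standalone illustration:

    import psutil

    process = psutil.Process()
    process.cpu_percent(interval=None)      # prime the counter

    def sample():
        rss = process.memory_info().rss     # resident set size, in bytes
        cpu = int(process.cpu_percent(interval=None))
        return rss, cpu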
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index f57ec784fe..749bbf37d0 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -13,14 +13,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import sys
-import synapse
+from twisted.internet import reactor
+from twisted.web.resource import NoResource
+import synapse
+from synapse import events
+from synapse.api.urls import CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX
+from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
+from synapse.crypto import context_factory
from synapse.http.site import SynapseSite
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.metrics import RegistryProxy
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
@@ -28,31 +37,13 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.media.v0.content_repository import ContentRepoResource
-from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.storage.media_repository import MediaRepositoryStore
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
-from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
-from synapse.api.urls import (
- CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX
-)
-from synapse.crypto import context_factory
-
-from synapse import events
-
-
-from twisted.internet import reactor
-from twisted.web.resource import Resource
-
-from daemonize import Daemonize
-
-import sys
-import logging
-import gc
logger = logging.getLogger("synapse.app.media_repository")
@@ -69,19 +60,6 @@ class MediaRepositorySlavedStore(
class MediaRepositoryServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
@@ -95,9 +73,9 @@ class MediaRepositoryServer(HomeServer):
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(self)
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
elif name == "media":
- media_repo = MediaRepositoryResource(self)
+ media_repo = self.get_media_repository_resource()
resources.update({
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
@@ -106,19 +84,19 @@ class MediaRepositoryServer(HomeServer):
),
})
- root_resource = create_resource_tree(resources, Resource())
-
- for address in bind_addresses:
- reactor.listenTCP(
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- ),
- interface=address
+ root_resource = create_resource_tree(resources, NoResource())
+
+ _base.listen_tcp(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
+ self.version_string,
)
+ )
logger.info("Synapse media repository now listening on port %d", port)
@@ -127,18 +105,22 @@ class MediaRepositoryServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
- bind_addresses = listener["bind_addresses"]
-
- for address in bind_addresses:
- reactor.listenTCP(
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- ),
- interface=address
+ _base.listen_tcp(
+ listener["bind_addresses"],
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
)
+ )
+ elif listener["type"] == "metrics":
+ if not self.get_config().enable_metrics:
+ logger.warn(("Metrics listener configured, but "
+ "enable_metrics is not True!"))
+ else:
+ _base.listen_metrics(listener["bind_addresses"],
+ listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -159,6 +141,13 @@ def start(config_options):
assert config.worker_app == "synapse.app.media_repository"
+ if config.enable_media_repo:
+ _base.quit_with_error(
+ "enable_media_repo must be disabled in the main synapse process\n"
+ "before the media repo can be run in a separate worker.\n"
+ "Please add ``enable_media_repo: false`` to the main config\n"
+ )
+
setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
@@ -177,39 +166,15 @@ def start(config_options):
)
ss.setup()
- ss.get_handlers()
ss.start_listening(config.worker_listeners)
- def run():
- # make sure that we run the reactor with the sentinel log context,
- # otherwise other PreserveLoggingContext instances will get confused
- # and complain when they see the logcontext arbitrarily swapping
- # between the sentinel and `run` logcontexts.
- with PreserveLoggingContext():
- logger.info("Running")
- change_resource_limit(config.soft_file_limit)
- if config.gc_thresholds:
- gc.set_threshold(*config.gc_thresholds)
- reactor.run()
-
def start():
ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling()
reactor.callWhenRunning(start)
- if config.worker_daemonize:
- daemon = Daemonize(
- app="synapse-media-repository",
- pid=config.worker_pid_file,
- action=run,
- auto_close_fds=False,
- verbose=True,
- logger=logger,
- )
- daemon.start()
- else:
- run()
+ _base.start_worker_reactor("synapse-media-repository", config)
if __name__ == '__main__':
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index f9114acfcb..9295a51d5b 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -13,41 +13,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import sys
-import synapse
+from twisted.internet import defer, reactor
+from twisted.web.resource import NoResource
-from synapse.server import HomeServer
+import synapse
+from synapse import events
+from synapse.app import _base
from synapse.config._base import ConfigError
-from synapse.config.logger import setup_logging
from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
-from synapse.storage.roommember import RoomMemberStore
+from synapse.metrics import RegistryProxy
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.pushers import SlavedPusherStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
-from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.storage.engines import create_engine
+from synapse.server import HomeServer
from synapse.storage import DataStore
+from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn, \
- PreserveLoggingContext
+from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
-from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
-from synapse import events
-
-from twisted.internet import reactor, defer
-from twisted.web.resource import Resource
-
-from daemonize import Daemonize
-
-import sys
-import logging
-import gc
-
logger = logging.getLogger("synapse.app.pusher")
@@ -83,25 +76,8 @@ class PusherSlaveStore(
DataStore.get_profile_displayname.__func__
)
- who_forgot_in_room = (
- RoomMemberStore.__dict__["who_forgot_in_room"]
- )
-
class PusherServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = PusherSlaveStore(self.get_db_conn(), self)
@@ -118,21 +94,21 @@ class PusherServer(HomeServer):
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(self)
-
- root_resource = create_resource_tree(resources, Resource())
-
- for address in bind_addresses:
- reactor.listenTCP(
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- ),
- interface=address
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
+
+ root_resource = create_resource_tree(resources, NoResource())
+
+ _base.listen_tcp(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
+ self.version_string,
)
+ )
logger.info("Synapse pusher now listening on port %d", port)
@@ -141,18 +117,22 @@ class PusherServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
- bind_addresses = listener["bind_addresses"]
-
- for address in bind_addresses:
- reactor.listenTCP(
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- ),
- interface=address
+ _base.listen_tcp(
+ listener["bind_addresses"],
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
)
+ )
+ elif listener["type"] == "metrics":
+ if not self.get_config().enable_metrics:
+ logger.warn(("Metrics listener configured, but "
+ "enable_metrics is not True!"))
+ else:
+ _base.listen_metrics(listener["bind_addresses"],
+ listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -170,24 +150,27 @@ class PusherReplicationHandler(ReplicationClientHandler):
def on_rdata(self, stream_name, token, rows):
super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
- preserve_fn(self.poke_pushers)(stream_name, token, rows)
+ run_in_background(self.poke_pushers, stream_name, token, rows)
@defer.inlineCallbacks
def poke_pushers(self, stream_name, token, rows):
- if stream_name == "pushers":
- for row in rows:
- if row.deleted:
- yield self.stop_pusher(row.user_id, row.app_id, row.pushkey)
- else:
- yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
- elif stream_name == "events":
- yield self.pusher_pool.on_new_notifications(
- token, token,
- )
- elif stream_name == "receipts":
- yield self.pusher_pool.on_new_receipts(
- token, token, set(row.room_id for row in rows)
- )
+ try:
+ if stream_name == "pushers":
+ for row in rows:
+ if row.deleted:
+ yield self.stop_pusher(row.user_id, row.app_id, row.pushkey)
+ else:
+ yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
+ elif stream_name == "events":
+ yield self.pusher_pool.on_new_notifications(
+ token, token,
+ )
+ elif stream_name == "receipts":
+ yield self.pusher_pool.on_new_receipts(
+ token, token, set(row.room_id for row in rows)
+ )
+ except Exception:
+ logger.exception("Error poking pushers")
def stop_pusher(self, user_id, app_id, pushkey):
key = "%s:%s" % (app_id, pushkey)
@@ -244,18 +227,6 @@ def start(config_options):
ps.setup()
ps.start_listening(config.worker_listeners)
- def run():
- # make sure that we run the reactor with the sentinel log context,
- # otherwise other PreserveLoggingContext instances will get confused
- # and complain when they see the logcontext arbitrarily swapping
- # between the sentinel and `run` logcontexts.
- with PreserveLoggingContext():
- logger.info("Running")
- change_resource_limit(config.soft_file_limit)
- if config.gc_thresholds:
- gc.set_threshold(*config.gc_thresholds)
- reactor.run()
-
def start():
ps.get_pusherpool().start()
ps.get_datastore().start_profiling()
@@ -263,18 +234,7 @@ def start(config_options):
reactor.callWhenRunning(start)
- if config.worker_daemonize:
- daemon = Daemonize(
- app="synapse-pusher",
- pid=config.worker_pid_file,
- action=run,
- auto_close_fds=False,
- verbose=True,
- logger=logger,
- )
- daemon.start()
- else:
- run()
+ _base.start_worker_reactor("synapse-pusher", config)
if __name__ == '__main__':
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 4bdd99a966..26b9ec85f2 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -13,78 +13,74 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import contextlib
+import logging
+import sys
-import synapse
+from six import iteritems
+
+from twisted.internet import defer, reactor
+from twisted.web.resource import NoResource
+import synapse
from synapse.api.constants import EventTypes
+from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.handlers.presence import PresenceHandler, get_interested_parties
-from synapse.http.site import SynapseSite
from synapse.http.server import JsonResource
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
-from synapse.rest.client.v2_alpha import sync
-from synapse.rest.client.v1 import events
-from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
-from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
+from synapse.http.site import SynapseSite
+from synapse.metrics import RegistryProxy
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
-from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.slave.storage.filtering import SlavedFilteringStore
-from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
-from synapse.replication.slave.storage.presence import SlavedPresenceStore
+from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
+from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.replication.slave.storage.filtering import SlavedFilteringStore
+from synapse.replication.slave.storage.groups import SlavedGroupServerStore
+from synapse.replication.slave.storage.presence import SlavedPresenceStore
+from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
+from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
+from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.rest.client.v1 import events
+from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
+from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
+from synapse.rest.client.v2_alpha import sync
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.storage.presence import UserPresenceState
from synapse.storage.roommember import RoomMemberStore
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
-from synapse.util.rlimit import change_resource_limit
from synapse.util.stringutils import random_string
from synapse.util.versionstring import get_version_string
-from twisted.internet import reactor, defer
-from twisted.web.resource import Resource
-
-from daemonize import Daemonize
-
-import sys
-import logging
-import contextlib
-import gc
-
logger = logging.getLogger("synapse.app.synchrotron")
class SynchrotronSlavedStore(
- SlavedPushRuleStore,
- SlavedEventStore,
SlavedReceiptsStore,
SlavedAccountDataStore,
SlavedApplicationServiceStore,
SlavedRegistrationStore,
SlavedFilteringStore,
SlavedPresenceStore,
+ SlavedGroupServerStore,
SlavedDeviceInboxStore,
SlavedDeviceStore,
+ SlavedPushRuleStore,
+ SlavedEventStore,
SlavedClientIpStore,
RoomStore,
BaseSlavedStore,
):
- who_forgot_in_room = (
- RoomMemberStore.__dict__["who_forgot_in_room"]
- )
-
did_forget = (
RoomMemberStore.__dict__["did_forget"]
)
@@ -219,7 +215,7 @@ class SynchrotronPresence(object):
def get_currently_syncing_users(self):
return [
- user_id for user_id, count in self.user_to_num_current_syncs.iteritems()
+ user_id for user_id, count in iteritems(self.user_to_num_current_syncs)
if count > 0
]
@@ -250,19 +246,6 @@ class SynchrotronApplicationService(object):
class SynchrotronServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
@@ -276,7 +259,7 @@ class SynchrotronServer(HomeServer):
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(self)
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
elif name == "client":
resource = JsonResource(self, canonical_json=False)
sync.register_servlets(self, resource)
@@ -290,19 +273,19 @@ class SynchrotronServer(HomeServer):
"/_matrix/client/api/v1": resource,
})
- root_resource = create_resource_tree(resources, Resource())
-
- for address in bind_addresses:
- reactor.listenTCP(
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- ),
- interface=address
+ root_resource = create_resource_tree(resources, NoResource())
+
+ _base.listen_tcp(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
+ self.version_string,
)
+ )
logger.info("Synapse synchrotron now listening on port %d", port)
@@ -311,18 +294,22 @@ class SynchrotronServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
- bind_addresses = listener["bind_addresses"]
-
- for address in bind_addresses:
- reactor.listenTCP(
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- ),
- interface=address
+ _base.listen_tcp(
+ listener["bind_addresses"],
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
)
+ )
+ elif listener["type"] == "metrics":
+ if not self.get_config().enable_metrics:
+ logger.warn(("Metrics listener configured, but "
+ "enable_metrics is not True!"))
+ else:
+ _base.listen_metrics(listener["bind_addresses"],
+ listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -344,15 +331,13 @@ class SyncReplicationHandler(ReplicationClientHandler):
self.store = hs.get_datastore()
self.typing_handler = hs.get_typing_handler()
+ # NB this is a SynchrotronPresence, not a normal PresenceHandler
self.presence_handler = hs.get_presence_handler()
self.notifier = hs.get_notifier()
- self.presence_handler.sync_callback = self.send_user_sync
-
def on_rdata(self, stream_name, token, rows):
super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
-
- preserve_fn(self.process_and_notify)(stream_name, token, rows)
+ run_in_background(self.process_and_notify, stream_name, token, rows)
def get_streams_to_replicate(self):
args = super(SyncReplicationHandler, self).get_streams_to_replicate()
@@ -364,51 +349,58 @@ class SyncReplicationHandler(ReplicationClientHandler):
@defer.inlineCallbacks
def process_and_notify(self, stream_name, token, rows):
- if stream_name == "events":
- # We shouldn't get multiple rows per token for events stream, so
- # we don't need to optimise this for multiple rows.
- for row in rows:
- event = yield self.store.get_event(row.event_id)
- extra_users = ()
- if event.type == EventTypes.Member:
- extra_users = (event.state_key,)
- max_token = self.store.get_room_max_stream_ordering()
- self.notifier.on_new_room_event(
- event, token, max_token, extra_users
+ try:
+ if stream_name == "events":
+ # We shouldn't get multiple rows per token for events stream, so
+ # we don't need to optimise this for multiple rows.
+ for row in rows:
+ event = yield self.store.get_event(row.event_id)
+ extra_users = ()
+ if event.type == EventTypes.Member:
+ extra_users = (event.state_key,)
+ max_token = self.store.get_room_max_stream_ordering()
+ self.notifier.on_new_room_event(
+ event, token, max_token, extra_users
+ )
+ elif stream_name == "push_rules":
+ self.notifier.on_new_event(
+ "push_rules_key", token, users=[row.user_id for row in rows],
)
- elif stream_name == "push_rules":
- self.notifier.on_new_event(
- "push_rules_key", token, users=[row.user_id for row in rows],
- )
- elif stream_name in ("account_data", "tag_account_data",):
- self.notifier.on_new_event(
- "account_data_key", token, users=[row.user_id for row in rows],
- )
- elif stream_name == "receipts":
- self.notifier.on_new_event(
- "receipt_key", token, rooms=[row.room_id for row in rows],
- )
- elif stream_name == "typing":
- self.typing_handler.process_replication_rows(token, rows)
- self.notifier.on_new_event(
- "typing_key", token, rooms=[row.room_id for row in rows],
- )
- elif stream_name == "to_device":
- entities = [row.entity for row in rows if row.entity.startswith("@")]
- if entities:
+ elif stream_name in ("account_data", "tag_account_data",):
self.notifier.on_new_event(
- "to_device_key", token, users=entities,
+ "account_data_key", token, users=[row.user_id for row in rows],
)
- elif stream_name == "device_lists":
- all_room_ids = set()
- for row in rows:
- room_ids = yield self.store.get_rooms_for_user(row.user_id)
- all_room_ids.update(room_ids)
- self.notifier.on_new_event(
- "device_list_key", token, rooms=all_room_ids,
- )
- elif stream_name == "presence":
- yield self.presence_handler.process_replication_rows(token, rows)
+ elif stream_name == "receipts":
+ self.notifier.on_new_event(
+ "receipt_key", token, rooms=[row.room_id for row in rows],
+ )
+ elif stream_name == "typing":
+ self.typing_handler.process_replication_rows(token, rows)
+ self.notifier.on_new_event(
+ "typing_key", token, rooms=[row.room_id for row in rows],
+ )
+ elif stream_name == "to_device":
+ entities = [row.entity for row in rows if row.entity.startswith("@")]
+ if entities:
+ self.notifier.on_new_event(
+ "to_device_key", token, users=entities,
+ )
+ elif stream_name == "device_lists":
+ all_room_ids = set()
+ for row in rows:
+ room_ids = yield self.store.get_rooms_for_user(row.user_id)
+ all_room_ids.update(room_ids)
+ self.notifier.on_new_event(
+ "device_list_key", token, rooms=all_room_ids,
+ )
+ elif stream_name == "presence":
+ yield self.presence_handler.process_replication_rows(token, rows)
+ elif stream_name == "receipts":
+ self.notifier.on_new_event(
+ "groups_key", token, users=[row.user_id for row in rows],
+ )
+ except Exception:
+ logger.exception("Error processing replication")
def start(config_options):
@@ -440,36 +432,13 @@ def start(config_options):
ss.setup()
ss.start_listening(config.worker_listeners)
- def run():
- # make sure that we run the reactor with the sentinel log context,
- # otherwise other PreserveLoggingContext instances will get confused
- # and complain when they see the logcontext arbitrarily swapping
- # between the sentinel and `run` logcontexts.
- with PreserveLoggingContext():
- logger.info("Running")
- change_resource_limit(config.soft_file_limit)
- if config.gc_thresholds:
- gc.set_threshold(*config.gc_thresholds)
- reactor.run()
-
def start():
ss.get_datastore().start_profiling()
ss.get_state_handler().start_caching()
reactor.callWhenRunning(start)
- if config.worker_daemonize:
- daemon = Daemonize(
- app="synapse-synchrotron",
- pid=config.worker_pid_file,
- action=run,
- auto_close_fds=False,
- verbose=True,
- logger=logger,
- )
- daemon.start()
- else:
- run()
+ _base.start_worker_reactor("synapse-synchrotron", config)
if __name__ == '__main__':
diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py
index 3bd7ef7bba..d658f967ba 100755
--- a/synapse/app/synctl.py
+++ b/synapse/app/synctl.py
@@ -16,16 +16,19 @@
import argparse
import collections
+import errno
import glob
import os
import os.path
import signal
import subprocess
import sys
-import yaml
-import errno
import time
+from six import iteritems
+
+import yaml
+
SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
GREEN = "\x1b[1;32m"
@@ -38,7 +41,7 @@ def pid_running(pid):
try:
os.kill(pid, 0)
return True
- except OSError, err:
+ except OSError as err:
if err.errno == errno.EPERM:
return True
return False
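The "except OSError as err" changes in synctl are plain Python 3 syntax
fixes; the probe they guard is worth spelling out. Signal 0 is never
delivered, it only triggers the kernel's existence and permission checks:

    import errno
    import os

    def pid_running(pid):
        try:
            os.kill(pid, 0)
            return True
        except OSError as err:
            # EPERM means the process exists but is owned by another user.
            return err.errno == errno.EPERM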
@@ -98,7 +101,7 @@ def stop(pidfile, app):
try:
os.kill(pid, signal.SIGTERM)
write("stopped %s" % (app,), colour=GREEN)
- except OSError, err:
+ except OSError as err:
if err.errno == errno.ESRCH:
write("%s not running" % (app,), colour=YELLOW)
elif err.errno == errno.EPERM:
@@ -171,6 +174,10 @@ def main():
if cache_factor:
os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
+ cache_factors = config.get("synctl_cache_factors", {})
+ for cache_name, factor in iteritems(cache_factors):
+ os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor)
+
worker_configfiles = []
if options.worker:
start_stop_synapse = False
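The new synctl_cache_factors map is exported as one environment variable
per cache before synapse is spawned. The consuming side lives in
synapse.util.caches, outside this patch; a plausible reader would be:

    # Assumed consumer of the SYNAPSE_CACHE_FACTOR_<NAME> variables
    # that synctl now exports (actual lookup not shown in this diff).
    import os

    def cache_factor_for(cache_name, default_factor):
        env_var = "SYNAPSE_CACHE_FACTOR_" + cache_name.upper()
        try:
            return float(os.environ[env_var])
        except (KeyError, ValueError):
            return default_factor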
@@ -184,6 +191,9 @@ def main():
worker_configfiles.append(worker_configfile)
if options.all_processes:
+ # To start the main synapse with -a you need to add a worker file
+ # with worker_app == "synapse.app.homeserver"
+ start_stop_synapse = False
worker_configdir = options.all_processes
if not os.path.isdir(worker_configdir):
write(
@@ -200,11 +210,29 @@ def main():
with open(worker_configfile) as stream:
worker_config = yaml.load(stream)
worker_app = worker_config["worker_app"]
- worker_pidfile = worker_config["worker_pid_file"]
- worker_daemonize = worker_config["worker_daemonize"]
- assert worker_daemonize, "In config %r: expected '%s' to be True" % (
- worker_configfile, "worker_daemonize")
- worker_cache_factor = worker_config.get("synctl_cache_factor")
+ if worker_app == "synapse.app.homeserver":
+ # We need to special case all of this to pick up options that may
+ # be set in the main config file or in this worker config file.
+ worker_pidfile = (
+ worker_config.get("pid_file")
+ or pidfile
+ )
+ worker_cache_factor = worker_config.get("synctl_cache_factor") or cache_factor
+ daemonize = worker_config.get("daemonize") or config.get("daemonize")
+ assert daemonize, "Main process must have daemonize set to true"
+
+ # The master process doesn't support using worker_* config.
+ for key in worker_config:
+ if key == "worker_app": # But we allow worker_app
+ continue
+ assert not key.startswith("worker_"), \
+ "Main process cannot use worker_* config"
+ else:
+ worker_pidfile = worker_config["worker_pid_file"]
+ worker_daemonize = worker_config["worker_daemonize"]
+ assert worker_daemonize, "In config %r: expected '%s' to be True" % (
+ worker_configfile, "worker_daemonize")
+ worker_cache_factor = worker_config.get("synctl_cache_factor")
workers.append(Worker(
worker_app, worker_configfile, worker_pidfile, worker_cache_factor,
))
@@ -231,6 +259,7 @@ def main():
for running_pid in running_pids:
while pid_running(running_pid):
time.sleep(0.2)
+ write("All processes exited; now restarting...")
if action == "start" or action == "restart":
if start_stop_synapse:
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index 8c6300db9d..637a89530a 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -14,16 +14,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse
+import logging
+import sys
-from synapse.server import HomeServer
+from twisted.internet import defer, reactor
+from twisted.web.resource import NoResource
+
+import synapse
+from synapse import events
+from synapse.app import _base
from synapse.config._base import ConfigError
-from synapse.config.logger import setup_logging
from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
from synapse.crypto import context_factory
-from synapse.http.site import SynapseSite
from synapse.http.server import JsonResource
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.http.site import SynapseSite
+from synapse.metrics import RegistryProxy
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
@@ -31,25 +38,14 @@ from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v2_alpha import user_directory
+from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.storage.user_directory import UserDirectoryStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
-from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
-from synapse.util.caches.stream_change_cache import StreamChangeCache
-
-from synapse import events
-
-from twisted.internet import reactor
-from twisted.web.resource import Resource
-
-from daemonize import Daemonize
-
-import sys
-import logging
-import gc
logger = logging.getLogger("synapse.app.user_dir")
@@ -98,19 +94,6 @@ class UserDirectorySlaveStore(
class UserDirectoryServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self)
@@ -124,7 +107,7 @@ class UserDirectoryServer(HomeServer):
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(self)
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
elif name == "client":
resource = JsonResource(self, canonical_json=False)
user_directory.register_servlets(self, resource)
@@ -135,19 +118,19 @@ class UserDirectoryServer(HomeServer):
"/_matrix/client/api/v1": resource,
})
- root_resource = create_resource_tree(resources, Resource())
-
- for address in bind_addresses:
- reactor.listenTCP(
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- ),
- interface=address
+ root_resource = create_resource_tree(resources, NoResource())
+
+ _base.listen_tcp(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
+ self.version_string,
)
+ )
logger.info("Synapse user_dir now listening on port %d", port)
@@ -156,18 +139,22 @@ class UserDirectoryServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
- bind_addresses = listener["bind_addresses"]
-
- for address in bind_addresses:
- reactor.listenTCP(
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- ),
- interface=address
+ _base.listen_tcp(
+ listener["bind_addresses"],
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
)
+ )
+ elif listener["type"] == "metrics":
+ if not self.get_config().enable_metrics:
+ logger.warn(("Metrics listener configured, but "
+ "enable_metrics is not True!"))
+ else:
+ _base.listen_metrics(listener["bind_addresses"],
+ listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -187,7 +174,14 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler):
stream_name, token, rows
)
if stream_name == "current_state_deltas":
- preserve_fn(self.user_directory.notify_new_event)()
+ run_in_background(self._notify_directory)
+
+ @defer.inlineCallbacks
+ def _notify_directory(self):
+ try:
+ yield self.user_directory.notify_new_event()
+ except Exception:
+ logger.exception("Error notifiying user directory of state update")
def start(config_options):
@@ -233,36 +227,13 @@ def start(config_options):
ps.setup()
ps.start_listening(config.worker_listeners)
- def run():
- # make sure that we run the reactor with the sentinel log context,
- # otherwise other PreserveLoggingContext instances will get confused
- # and complain when they see the logcontext arbitrarily swapping
- # between the sentinel and `run` logcontexts.
- with PreserveLoggingContext():
- logger.info("Running")
- change_resource_limit(config.soft_file_limit)
- if config.gc_thresholds:
- gc.set_threshold(*config.gc_thresholds)
- reactor.run()
-
def start():
ps.get_datastore().start_profiling()
ps.get_state_handler().start_caching()
reactor.callWhenRunning(start)
- if config.worker_daemonize:
- daemon = Daemonize(
- app="synapse-user-dir",
- pid=config.worker_pid_file,
- action=run,
- auto_close_fds=False,
- verbose=True,
- logger=logger,
- )
- daemon.start()
- else:
- run()
+ _base.start_worker_reactor("synapse-user-dir", config)
if __name__ == '__main__':
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index b989007314..57ed8a3ca2 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -12,13 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.api.constants import EventTypes
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+import logging
+import re
+
+from six import string_types
from twisted.internet import defer
-import logging
-import re
+from synapse.api.constants import EventTypes
+from synapse.types import GroupID, get_domain_from_id
+from synapse.util.caches.descriptors import cachedInlineCallbacks
logger = logging.getLogger(__name__)
@@ -81,14 +84,17 @@ class ApplicationService(object):
# values.
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
- def __init__(self, token, url=None, namespaces=None, hs_token=None,
- sender=None, id=None, protocols=None, rate_limited=True):
+ def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None,
+ sender=None, id=None, protocols=None, rate_limited=True,
+ ip_range_whitelist=None):
self.token = token
self.url = url
self.hs_token = hs_token
self.sender = sender
+ self.server_name = hostname
self.namespaces = self._check_namespaces(namespaces)
self.id = id
+ self.ip_range_whitelist = ip_range_whitelist
if "|" in self.id:
raise Exception("application service ID cannot contain '|' character")
@@ -125,8 +131,26 @@ class ApplicationService(object):
raise ValueError(
"Expected bool for 'exclusive' in ns '%s'" % ns
)
+ group_id = regex_obj.get("group_id")
+ if group_id:
+ if not isinstance(group_id, str):
+ raise ValueError(
+ "Expected string for 'group_id' in ns '%s'" % ns
+ )
+ try:
+ GroupID.from_string(group_id)
+ except Exception:
+ raise ValueError(
+ "Expected valid group ID for 'group_id' in ns '%s'" % ns
+ )
+
+ if get_domain_from_id(group_id) != self.server_name:
+ raise ValueError(
+ "Expected 'group_id' to be this host in ns '%s'" % ns
+ )
+
regex = regex_obj.get("regex")
- if isinstance(regex, basestring):
+ if isinstance(regex, string_types):
regex_obj["regex"] = re.compile(regex) # Pre-compile regex
else:
raise ValueError(
@@ -251,8 +275,27 @@ class ApplicationService(object):
if regex_obj["exclusive"]
]
+ def get_groups_for_user(self, user_id):
+ """Get the groups that this user is associated with by this AS
+
+ Args:
+ user_id (str): The ID of the user.
+
+ Returns:
+ iterable[str]: an iterable that yields group_id strings.
+ """
+ return (
+ regex_obj["group_id"]
+ for regex_obj in self.namespaces[ApplicationService.NS_USERS]
+ if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
+ )
+
def is_rate_limited(self):
return self.rate_limited
def __str__(self):
- return "ApplicationService: %s" % (self.__dict__,)
+ # copy dictionary and redact token fields so they don't get logged
+ dict_copy = self.__dict__.copy()
+ dict_copy["token"] = "<redacted>"
+ dict_copy["hs_token"] = "<redacted>"
+ return "ApplicationService: %s" % (dict_copy,)
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 6893610e71..6980e5890e 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -12,20 +12,39 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import urllib
+
+from prometheus_client import Counter
+
from twisted.internet import defer
from synapse.api.constants import ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException
-from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event
-from synapse.util.caches.response_cache import ResponseCache
+from synapse.http.client import SimpleHttpClient
from synapse.types import ThirdPartyInstanceID
-
-import logging
-import urllib
+from synapse.util.caches.response_cache import ResponseCache
logger = logging.getLogger(__name__)
+sent_transactions_counter = Counter(
+ "synapse_appservice_api_sent_transactions",
+ "Number of /transactions/ requests sent",
+ ["service"]
+)
+
+failed_transactions_counter = Counter(
+ "synapse_appservice_api_failed_transactions",
+ "Number of /transactions/ requests that failed to send",
+ ["service"]
+)
+
+sent_events_counter = Counter(
+ "synapse_appservice_api_sent_events",
+ "Number of events sent to the AS",
+ ["service"]
+)
HOUR_IN_MS = 60 * 60 * 1000
@@ -72,7 +91,8 @@ class ApplicationServiceApi(SimpleHttpClient):
super(ApplicationServiceApi, self).__init__(hs)
self.clock = hs.get_clock()
- self.protocol_meta_cache = ResponseCache(hs, timeout_ms=HOUR_IN_MS)
+ self.protocol_meta_cache = ResponseCache(hs, "as_protocol_meta",
+ timeout_ms=HOUR_IN_MS)
@defer.inlineCallbacks
def query_user(self, service, user_id):
@@ -192,9 +212,7 @@ class ApplicationServiceApi(SimpleHttpClient):
defer.returnValue(None)
key = (service.id, protocol)
- return self.protocol_meta_cache.get(key) or (
- self.protocol_meta_cache.set(key, _get())
- )
+ return self.protocol_meta_cache.wrap(key, _get)
@defer.inlineCallbacks
def push_bulk(self, service, events, txn_id=None):
@@ -220,12 +238,15 @@ class ApplicationServiceApi(SimpleHttpClient):
args={
"access_token": service.hs_token
})
+ sent_transactions_counter.labels(service.id).inc()
+ sent_events_counter.labels(service.id).inc(len(events))
defer.returnValue(True)
return
except CodeMessageException as e:
logger.warning("push_bulk to %s received %s", uri, e.code)
except Exception as ex:
logger.warning("push_bulk to %s threw exception %s", uri, ex)
+ failed_transactions_counter.labels(service.id).inc()
defer.returnValue(False)
def _serialize(self, events):
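
The new Prometheus counters are labelled by appservice ID, so each service gets its own time series. A sketch of how `prometheus_client` labelled counters behave (the metric name here is made up, to avoid clashing with the real ones):

```python
from prometheus_client import Counter

# One time series per appservice ID: .labels() selects the child
# counter for that service before incrementing.
demo_sent_events = Counter(
    "demo_appservice_sent_events",
    "Number of events sent to the AS (demo)",
    ["service"],
)

demo_sent_events.labels("irc-bridge").inc()    # one /transactions/ push...
demo_sent_events.labels("irc-bridge").inc(25)  # ...carrying 25 events
```
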
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 68a9de17b8..2430814796 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -48,14 +48,14 @@ UP & quit +---------- YES SUCCESS
This is all tied together by the AppServiceScheduler which DIs the required
components.
"""
+import logging
+
from twisted.internet import defer
from synapse.appservice import ApplicationServiceState
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
from synapse.util.metrics import Measure
-import logging
-
logger = logging.getLogger(__name__)
@@ -106,7 +106,7 @@ class _ServiceQueuer(object):
def enqueue(self, service, event):
# if this service isn't being sent something
self.queued_events.setdefault(service.id, []).append(event)
- preserve_fn(self._send_request)(service)
+ run_in_background(self._send_request, service)
@defer.inlineCallbacks
def _send_request(self, service):
@@ -123,7 +123,7 @@ class _ServiceQueuer(object):
with Measure(self.clock, "servicequeuer.send"):
try:
yield self.txn_ctrl.send(service, events)
- except:
+ except Exception:
logger.exception("AS request failed")
finally:
self.requests_in_flight.discard(service.id)
@@ -152,10 +152,10 @@ class _TransactionController(object):
if sent:
yield txn.complete(self.store)
else:
- preserve_fn(self._start_recoverer)(service)
- except Exception as e:
- logger.exception(e)
- preserve_fn(self._start_recoverer)(service)
+ run_in_background(self._start_recoverer, service)
+ except Exception:
+ logger.exception("Error creating appservice transaction")
+ run_in_background(self._start_recoverer, service)
@defer.inlineCallbacks
def on_recovered(self, recoverer):
@@ -176,17 +176,20 @@ class _TransactionController(object):
@defer.inlineCallbacks
def _start_recoverer(self, service):
- yield self.store.set_appservice_state(
- service,
- ApplicationServiceState.DOWN
- )
- logger.info(
- "Application service falling behind. Starting recoverer. AS ID %s",
- service.id
- )
- recoverer = self.recoverer_fn(service, self.on_recovered)
- self.add_recoverers([recoverer])
- recoverer.recover()
+ try:
+ yield self.store.set_appservice_state(
+ service,
+ ApplicationServiceState.DOWN
+ )
+ logger.info(
+ "Application service falling behind. Starting recoverer. AS ID %s",
+ service.id
+ )
+ recoverer = self.recoverer_fn(service, self.on_recovered)
+ self.add_recoverers([recoverer])
+ recoverer.recover()
+ except Exception:
+ logger.exception("Error starting AS recoverer")
@defer.inlineCallbacks
def _is_service_up(self, service):
diff --git a/synapse/config/__init__.py b/synapse/config/__init__.py
index bfebb0f644..f2a5a41e92 100644
--- a/synapse/config/__init__.py
+++ b/synapse/config/__init__.py
@@ -12,3 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+from ._base import ConfigError
+
+# export ConfigError if somebody does import *
+# this is largely a fudge to stop PEP8 moaning about the import
+__all__ = ["ConfigError"]
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 1ab5593c6e..3d2e90dd5b 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -16,9 +16,12 @@
import argparse
import errno
import os
-import yaml
from textwrap import dedent
+from six import integer_types
+
+import yaml
+
class ConfigError(Exception):
pass
@@ -49,7 +52,7 @@ Missing mandatory `server_name` config option.
class Config(object):
@staticmethod
def parse_size(value):
- if isinstance(value, int) or isinstance(value, long):
+ if isinstance(value, integer_types):
return value
sizes = {"K": 1024, "M": 1024 * 1024}
size = 1
@@ -61,7 +64,7 @@ class Config(object):
@staticmethod
def parse_duration(value):
- if isinstance(value, int) or isinstance(value, long):
+ if isinstance(value, integer_types):
return value
second = 1000
minute = 60 * second
@@ -82,21 +85,37 @@ class Config(object):
return os.path.abspath(file_path) if file_path else file_path
@classmethod
+ def path_exists(cls, file_path):
+ """Check if a file exists
+
+ Unlike os.path.exists, this throws an exception if there is an error
+ checking if the file exists (for example, if there is a perms error on
+ the parent dir).
+
+ Returns:
+ bool: True if the file exists; False if not.
+ """
+ try:
+ os.stat(file_path)
+ return True
+ except OSError as e:
+ if e.errno != errno.ENOENT:
+ raise e
+ return False
+
+ @classmethod
def check_file(cls, file_path, config_name):
if file_path is None:
raise ConfigError(
"Missing config for %s."
- " You must specify a path for the config file. You can "
- "do this with the -c or --config-path option. "
- "Adding --generate-config along with --server-name "
- "<server name> will generate a config file at the given path."
% (config_name,)
)
- if not os.path.exists(file_path):
+ try:
+ os.stat(file_path)
+ except OSError as e:
raise ConfigError(
- "File %s config for %s doesn't exist."
- " Try running again with --generate-config"
- % (file_path, config_name,)
+ "Error accessing file '%s' (config for %s): %s"
+ % (file_path, config_name, e.strerror)
)
return cls.abspath(file_path)
@@ -248,7 +267,7 @@ class Config(object):
" -c CONFIG-FILE\""
)
(config_path,) = config_files
- if not os.path.exists(config_path):
+ if not cls.path_exists(config_path):
if config_args.keys_directory:
config_dir_path = config_args.keys_directory
else:
@@ -261,33 +280,33 @@ class Config(object):
"Must specify a server_name to a generate config for."
" Pass -H server.name."
)
- if not os.path.exists(config_dir_path):
+ if not cls.path_exists(config_dir_path):
os.makedirs(config_dir_path)
- with open(config_path, "wb") as config_file:
- config_bytes, config = obj.generate_config(
+ with open(config_path, "w") as config_file:
+ config_str, config = obj.generate_config(
config_dir_path=config_dir_path,
server_name=server_name,
report_stats=(config_args.report_stats == "yes"),
is_generating_file=True
)
obj.invoke_all("generate_files", config)
- config_file.write(config_bytes)
- print (
+ config_file.write(config_str)
+ print((
"A config file has been generated in %r for server name"
" %r with corresponding SSL keys and self-signed"
" certificates. Please review this file and customise it"
" to your needs."
- ) % (config_path, server_name)
- print (
+ ) % (config_path, server_name))
+ print(
"If this server name is incorrect, you will need to"
" regenerate the SSL certificates"
)
return
else:
- print (
+ print((
"Config file %r already exists. Generating any missing key"
" files."
- ) % (config_path,)
+ ) % (config_path,))
generate_keys = True
parser = argparse.ArgumentParser(
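
`path_exists` deliberately differs from `os.path.exists`: only ENOENT counts as "missing", while any other `OSError` (such as a permissions error on the parent directory) propagates rather than being silently reported as absence. A standalone sketch of that behaviour:

```python
import errno
import os

def path_exists(file_path):
    try:
        os.stat(file_path)
        return True
    except OSError as e:
        if e.errno != errno.ENOENT:
            raise  # eg EACCES: surface the error, don't claim "missing"
        return False

print(path_exists("/nonexistent/homeserver.yaml"))  # False
```
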
diff --git a/synapse/config/api.py b/synapse/config/api.py
index 20ba33226a..403d96ba76 100644
--- a/synapse/config/api.py
+++ b/synapse/config/api.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import Config
-
from synapse.api.constants import EventTypes
+from ._base import Config
+
class ApiConfig(Config):
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 82c50b8240..3b161d708a 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -12,14 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import Config, ConfigError
+import logging
+
+from six import string_types
+from six.moves.urllib import parse as urlparse
+
+import yaml
+from netaddr import IPSet
from synapse.appservice import ApplicationService
from synapse.types import UserID
-import urllib
-import yaml
-import logging
+from ._base import Config, ConfigError
logger = logging.getLogger(__name__)
@@ -89,21 +93,21 @@ def _load_appservice(hostname, as_info, config_filename):
"id", "as_token", "hs_token", "sender_localpart"
]
for field in required_string_fields:
- if not isinstance(as_info.get(field), basestring):
+ if not isinstance(as_info.get(field), string_types):
raise KeyError("Required string field: '%s' (%s)" % (
field, config_filename,
))
# 'url' must either be a string or explicitly null, not missing
# to avoid accidentally turning off push for ASes.
- if (not isinstance(as_info.get("url"), basestring) and
+ if (not isinstance(as_info.get("url"), string_types) and
as_info.get("url", "") is not None):
raise KeyError(
"Required string field or explicit null: 'url' (%s)" % (config_filename,)
)
localpart = as_info["sender_localpart"]
- if urllib.quote(localpart) != localpart:
+ if urlparse.quote(localpart) != localpart:
raise ValueError(
"sender_localpart needs characters which are not URL encoded."
)
@@ -128,7 +132,7 @@ def _load_appservice(hostname, as_info, config_filename):
"Expected namespace entry in %s to be an object,"
" but got %s", ns, regex_obj
)
- if not isinstance(regex_obj.get("regex"), basestring):
+ if not isinstance(regex_obj.get("regex"), string_types):
raise ValueError(
"Missing/bad type 'regex' key in %s", regex_obj
)
@@ -152,13 +156,22 @@ def _load_appservice(hostname, as_info, config_filename):
" will not receive events or queries.",
config_filename,
)
+
+ ip_range_whitelist = None
+ if as_info.get('ip_range_whitelist'):
+ ip_range_whitelist = IPSet(
+ as_info.get('ip_range_whitelist')
+ )
+
return ApplicationService(
token=as_info["as_token"],
+ hostname=hostname,
url=as_info["url"],
namespaces=as_info["namespaces"],
hs_token=as_info["hs_token"],
sender=user_id,
id=as_info["id"],
protocols=protocols,
- rate_limited=rate_limited
+ rate_limited=rate_limited,
+ ip_range_whitelist=ip_range_whitelist,
)
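
`ip_range_whitelist` is parsed into a netaddr `IPSet`, against which the requesting client's IP can later be checked. A sketch of that membership test:

```python
from netaddr import IPAddress, IPSet

whitelist = IPSet(["192.168.0.0/16", "10.0.0.0/8"])

# Only requests originating inside the configured ranges pass.
print(IPAddress("192.168.1.25") in whitelist)  # True
print(IPAddress("8.8.8.8") in whitelist)       # False
```
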
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index 938f6f25f8..8109e5f95e 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -41,7 +41,7 @@ class CasConfig(Config):
#cas_config:
# enabled: true
# server_url: "https://cas-server.com"
- # service_url: "https://homesever.domain.com:8448"
+ # service_url: "https://homeserver.domain.com:8448"
# #required_attributes:
# # name: value
"""
diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py
new file mode 100644
index 0000000000..e22c731aad
--- /dev/null
+++ b/synapse/config/consent_config.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config
+
+DEFAULT_CONFIG = """\
+# User Consent configuration
+#
+# for detailed instructions, see
+# https://github.com/matrix-org/synapse/blob/master/docs/consent_tracking.md
+#
+# Parts of this section are required if enabling the 'consent' resource under
+# 'listeners', in particular 'template_dir' and 'version'.
+#
+# 'template_dir' gives the location of the templates for the HTML forms.
+# This directory should contain one subdirectory per language (eg, 'en', 'fr'),
+# and each language directory should contain the policy document (named as
+# '<version>.html') and a success page (success.html).
+#
+# 'version' specifies the 'current' version of the policy document. It defines
+# the version to be served by the consent resource if there is no 'v'
+# parameter.
+#
+# 'server_notice_content', if enabled, will send a user a "Server Notice"
+# asking them to consent to the privacy policy. The 'server_notices' section
+# must also be configured for this to work. Notices will *not* be sent to
+# guest users unless 'send_server_notice_to_guests' is set to true.
+#
+# 'block_events_error', if set, will block any attempts to send events
+# until the user consents to the privacy policy. The value of the setting is
+# used as the text of the error.
+#
+# user_consent:
+# template_dir: res/templates/privacy
+# version: 1.0
+# server_notice_content:
+# msgtype: m.text
+# body: >-
+# To continue using this homeserver you must review and agree to the
+# terms and conditions at %(consent_uri)s
+# send_server_notice_to_guests: True
+# block_events_error: >-
+# To continue using this homeserver you must review and agree to the
+# terms and conditions at %(consent_uri)s
+#
+"""
+
+
+class ConsentConfig(Config):
+ def __init__(self):
+ super(ConsentConfig, self).__init__()
+
+ self.user_consent_version = None
+ self.user_consent_template_dir = None
+ self.user_consent_server_notice_content = None
+ self.user_consent_server_notice_to_guests = False
+ self.block_events_without_consent_error = None
+
+ def read_config(self, config):
+ consent_config = config.get("user_consent")
+ if consent_config is None:
+ return
+ self.user_consent_version = str(consent_config["version"])
+ self.user_consent_template_dir = consent_config["template_dir"]
+ self.user_consent_server_notice_content = consent_config.get(
+ "server_notice_content",
+ )
+ self.block_events_without_consent_error = consent_config.get(
+ "block_events_error",
+ )
+ self.user_consent_server_notice_to_guests = bool(consent_config.get(
+ "send_server_notice_to_guests", False,
+ ))
+
+ def default_config(self, **kwargs):
+ return DEFAULT_CONFIG
diff --git a/synapse/config/groups.py b/synapse/config/groups.py
new file mode 100644
index 0000000000..997fa2881f
--- /dev/null
+++ b/synapse/config/groups.py
@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config
+
+
+class GroupsConfig(Config):
+ def read_config(self, config):
+ self.enable_group_creation = config.get("enable_group_creation", False)
+ self.group_creation_prefix = config.get("group_creation_prefix", "")
+
+ def default_config(self, **kwargs):
+ return """\
+ # Whether to allow non server admins to create groups on this server
+ enable_group_creation: false
+
+ # If enabled, non server admins can only create groups with local parts
+ # starting with this prefix
+ # group_creation_prefix: "unofficial/"
+ """
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index b22cacf8dc..2fd9c48abf 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,28 +13,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from .tls import TlsConfig
-from .server import ServerConfig
-from .logger import LoggingConfig
-from .database import DatabaseConfig
-from .ratelimiting import RatelimitConfig
-from .repository import ContentRepositoryConfig
-from .captcha import CaptchaConfig
-from .voip import VoipConfig
-from .registration import RegistrationConfig
-from .metrics import MetricsConfig
from .api import ApiConfig
from .appservice import AppServiceConfig
-from .key import KeyConfig
-from .saml2 import SAML2Config
+from .captcha import CaptchaConfig
from .cas import CasConfig
-from .password import PasswordConfig
+from .consent_config import ConsentConfig
+from .database import DatabaseConfig
+from .emailconfig import EmailConfig
+from .groups import GroupsConfig
from .jwt import JWTConfig
+from .key import KeyConfig
+from .logger import LoggingConfig
+from .metrics import MetricsConfig
+from .password import PasswordConfig
from .password_auth_providers import PasswordAuthProviderConfig
-from .emailconfig import EmailConfig
-from .workers import WorkerConfig
from .push import PushConfig
+from .ratelimiting import RatelimitConfig
+from .registration import RegistrationConfig
+from .repository import ContentRepositoryConfig
+from .saml2 import SAML2Config
+from .server import ServerConfig
+from .server_notices_config import ServerNoticesConfig
+from .spam_checker import SpamCheckerConfig
+from .tls import TlsConfig
+from .user_directory import UserDirectoryConfig
+from .voip import VoipConfig
+from .workers import WorkerConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
@@ -41,12 +46,16 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
JWTConfig, PasswordConfig, EmailConfig,
- WorkerConfig, PasswordAuthProviderConfig, PushConfig,):
+ WorkerConfig, PasswordAuthProviderConfig, PushConfig,
+ SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,
+ ConsentConfig,
+ ServerNoticesConfig,
+ ):
pass
if __name__ == '__main__':
import sys
sys.stdout.write(
- HomeServerConfig().generate_config(sys.argv[1], sys.argv[2])[0]
+ HomeServerConfig().generate_config(sys.argv[1], sys.argv[2], True)[0]
)
diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py
index 47f145c589..51e7f7e003 100644
--- a/synapse/config/jwt.py
+++ b/synapse/config/jwt.py
@@ -15,7 +15,6 @@
from ._base import Config, ConfigError
-
MISSING_JWT = (
"""Missing jwt library. This is required for jwt login.
diff --git a/synapse/config/key.py b/synapse/config/key.py
index 6ee643793e..279c47bb48 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -13,21 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import Config, ConfigError
+import hashlib
+import logging
+import os
-from synapse.util.stringutils import random_string
from signedjson.key import (
- generate_signing_key, is_signing_algorithm_supported,
- decode_signing_key_base64, decode_verify_key_bytes,
- read_signing_keys, write_signing_keys, NACL_ED25519
+ NACL_ED25519,
+ decode_signing_key_base64,
+ decode_verify_key_bytes,
+ generate_signing_key,
+ is_signing_algorithm_supported,
+ read_signing_keys,
+ write_signing_keys,
)
from unpaddedbase64 import decode_base64
-from synapse.util.stringutils import random_string_with_symbols
-import os
-import hashlib
-import logging
+from synapse.util.stringutils import random_string, random_string_with_symbols
+from ._base import Config, ConfigError
logger = logging.getLogger(__name__)
@@ -59,14 +62,20 @@ class KeyConfig(Config):
self.expire_access_token = config.get("expire_access_token", False)
+ # a secret which is used to calculate HMACs for form values, to stop
+ # falsification of values
+ self.form_secret = config.get("form_secret", None)
+
def default_config(self, config_dir_path, server_name, is_generating_file=False,
**kwargs):
base_key_name = os.path.join(config_dir_path, server_name)
if is_generating_file:
macaroon_secret_key = random_string_with_symbols(50)
+ form_secret = '"%s"' % random_string_with_symbols(50)
else:
macaroon_secret_key = None
+ form_secret = 'null'
return """\
macaroon_secret_key: "%(macaroon_secret_key)s"
@@ -74,6 +83,10 @@ class KeyConfig(Config):
# Used to enable access token expiration.
expire_access_token: False
+ # a secret which is used to calculate HMACs for form values, to stop
+ # falsification of values
+ form_secret: %(form_secret)s
+
## Signing Keys ##
# Path to the signing key to sign messages with
@@ -118,10 +131,9 @@ class KeyConfig(Config):
signing_keys = self.read_file(signing_key_path, "signing_key")
try:
return read_signing_keys(signing_keys.splitlines(True))
- except Exception:
+ except Exception as e:
raise ConfigError(
- "Error reading signing_key."
- " Try running again with --generate-config"
+ "Error reading signing_key: %s" % (str(e))
)
def read_old_signing_keys(self, old_signing_keys):
@@ -141,7 +153,8 @@ class KeyConfig(Config):
def generate_files(self, config):
signing_key_path = config["signing_key_path"]
- if not os.path.exists(signing_key_path):
+
+ if not self.path_exists(signing_key_path):
with open(signing_key_path, "w") as signing_key_file:
key_id = "a_" + random_string(4)
write_signing_keys(
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 2dbeafa9dd..a87b11a1df 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -12,43 +12,48 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from ._base import Config
-from synapse.util.logcontext import LoggingContextFilter
-from twisted.logger import globalLogBeginner, STDLibLogObserver
import logging
import logging.config
-import yaml
-from string import Template
import os
import signal
+import sys
+from string import Template
+import yaml
+
+from twisted.logger import STDLibLogObserver, globalLogBeginner
+
+import synapse
+from synapse.util.logcontext import LoggingContextFilter
+from synapse.util.versionstring import get_version_string
+
+from ._base import Config
DEFAULT_LOG_CONFIG = Template("""
version: 1
formatters:
- precise:
- format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s\
-- %(message)s'
+ precise:
+ format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - \
+%(request)s - %(message)s'
filters:
- context:
- (): synapse.util.logcontext.LoggingContextFilter
- request: ""
+ context:
+ (): synapse.util.logcontext.LoggingContextFilter
+ request: ""
handlers:
- file:
- class: logging.handlers.RotatingFileHandler
- formatter: precise
- filename: ${log_file}
- maxBytes: 104857600
- backupCount: 10
- filters: [context]
- console:
- class: logging.StreamHandler
- formatter: precise
- filters: [context]
+ file:
+ class: logging.handlers.RotatingFileHandler
+ formatter: precise
+ filename: ${log_file}
+ maxBytes: 104857600
+ backupCount: 10
+ filters: [context]
+ console:
+ class: logging.StreamHandler
+ formatter: precise
+ filters: [context]
loggers:
synapse:
@@ -74,17 +79,10 @@ class LoggingConfig(Config):
self.log_file = self.abspath(config.get("log_file"))
def default_config(self, config_dir_path, server_name, **kwargs):
- log_file = self.abspath("homeserver.log")
log_config = self.abspath(
os.path.join(config_dir_path, server_name + ".log.config")
)
return """
- # Logging verbosity level. Ignored if log_config is specified.
- verbose: 0
-
- # File to write logging to. Ignored if log_config is specified.
- log_file: "%(log_file)s"
-
# A yaml python logging config file
log_config: "%(log_config)s"
""" % locals()
@@ -123,9 +121,10 @@ class LoggingConfig(Config):
def generate_files(self, config):
log_config = config.get("log_config")
if log_config and not os.path.exists(log_config):
- with open(log_config, "wb") as log_config_file:
+ log_file = self.abspath("homeserver.log")
+ with open(log_config, "w") as log_config_file:
log_config_file.write(
- DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"])
+ DEFAULT_LOG_CONFIG.substitute(log_file=log_file)
)
@@ -148,8 +147,11 @@ def setup_logging(config, use_worker_options=False):
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
" - %(message)s"
)
- if log_config is None:
+ if log_config is None:
+ # We don't have a logfile, so fall back to the 'verbosity' param from
+ # the config or cmdline. (Note that we generate a log config for new
+ # installs, so this will be an unusual case)
level = logging.INFO
level_for_storage = logging.INFO
if config.verbosity:
@@ -157,11 +159,10 @@ def setup_logging(config, use_worker_options=False):
if config.verbosity > 1:
level_for_storage = logging.DEBUG
- # FIXME: we need a logging.WARN for a -q quiet option
logger = logging.getLogger('')
logger.setLevel(level)
- logging.getLogger('synapse.storage').setLevel(level_for_storage)
+ logging.getLogger('synapse.storage.SQL').setLevel(level_for_storage)
formatter = logging.Formatter(log_format)
if log_file:
@@ -176,6 +177,10 @@ def setup_logging(config, use_worker_options=False):
logger.info("Opened new log file due to SIGHUP")
else:
handler = logging.StreamHandler()
+
+ def sighup(signum, stack):
+ pass
+
handler.setFormatter(formatter)
handler.addFilter(LoggingContextFilter(request=""))
@@ -202,6 +207,15 @@ def setup_logging(config, use_worker_options=False):
if getattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, sighup)
+ # make sure that the first thing we log is a thing we can grep backwards
+ # for
+ logging.warn("***** STARTING SERVER *****")
+ logging.warn(
+ "Server %s version %s",
+ sys.argv[0], get_version_string(synapse),
+ )
+ logging.info("Server hostname: %s", config.server_name)
+
# It's critical to point twisted's internal logging somewhere, otherwise it
# stacks up and leaks up to 64K objects;
# see: https://twistedmatrix.com/trac/ticket/8164
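
`generate_files` fills `DEFAULT_LOG_CONFIG` via `string.Template` substitution on the `${log_file}` placeholder. A trimmed-down illustration:

```python
from string import Template

# Cut-down version of DEFAULT_LOG_CONFIG, showing the ${log_file}
# substitution performed when a fresh log config file is generated.
template = Template("""\
handlers:
  file:
    class: logging.handlers.RotatingFileHandler
    filename: ${log_file}
""")

print(template.substitute(log_file="/var/log/matrix/homeserver.log"))
```
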
diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index 83762d089a..f4066abc28 100644
--- a/synapse/config/password_auth_providers.py
+++ b/synapse/config/password_auth_providers.py
@@ -13,44 +13,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import Config, ConfigError
+from synapse.util.module_loader import load_module
-import importlib
+from ._base import Config
+
+LDAP_PROVIDER = 'ldap_auth_provider.LdapAuthProvider'
class PasswordAuthProviderConfig(Config):
def read_config(self, config):
self.password_providers = []
+ providers = []
# We want to be backwards compatible with the old `ldap_config`
# param.
ldap_config = config.get("ldap_config", {})
- self.ldap_enabled = ldap_config.get("enabled", False)
- if self.ldap_enabled:
- from ldap_auth_provider import LdapAuthProvider
- parsed_config = LdapAuthProvider.parse_config(ldap_config)
- self.password_providers.append((LdapAuthProvider, parsed_config))
+ if ldap_config.get("enabled", False):
+ providers.append({
+ 'module': LDAP_PROVIDER,
+ 'config': ldap_config,
+ })
- providers = config.get("password_providers", [])
+ providers.extend(config.get("password_providers", []))
for provider in providers:
+ mod_name = provider['module']
+
# This is for backwards compat when the ldap auth provider resided
# in this package.
- if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider":
- from ldap_auth_provider import LdapAuthProvider
- provider_class = LdapAuthProvider
- else:
- # We need to import the module, and then pick the class out of
- # that, so we split based on the last dot.
- module, clz = provider['module'].rsplit(".", 1)
- module = importlib.import_module(module)
- provider_class = getattr(module, clz)
+ if mod_name == "synapse.util.ldap_auth_provider.LdapAuthProvider":
+ mod_name = LDAP_PROVIDER
+
+ (provider_class, provider_config) = load_module({
+ "module": mod_name,
+ "config": provider['config'],
+ })
- try:
- provider_config = provider_class.parse_config(provider["config"])
- except Exception as e:
- raise ConfigError(
- "Failed to parse config for %r: %r" % (provider['module'], e)
- )
self.password_providers.append((provider_class, provider_config))
def default_config(self, **kwargs):
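
`load_module` (from `synapse.util.module_loader`) replaces the hand-rolled importlib logic removed above. A minimal sketch of the same idea; the real helper additionally wraps `parse_config` failures in a `ConfigError`:

```python
import importlib

def load_module(provider):
    # Split "pkg.mod.Class" on the last dot, import the module, then
    # pull the class out and let it parse its own config section.
    module_name, clz = provider["module"].rsplit(".", 1)
    module = importlib.import_module(module_name)
    provider_class = getattr(module, clz)
    provider_config = provider_class.parse_config(provider["config"])
    return provider_class, provider_config
```
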
diff --git a/synapse/config/push.py b/synapse/config/push.py
index 9c68318b40..b7e0d46afa 100644
--- a/synapse/config/push.py
+++ b/synapse/config/push.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,28 +19,43 @@ from ._base import Config
class PushConfig(Config):
def read_config(self, config):
- self.push_redact_content = False
+ push_config = config.get("push", {})
+ self.push_include_content = push_config.get("include_content", True)
+ # There was a 'redact_content' setting, but it was mistakenly read from
+ # the 'email' section. Check for the flag in the 'push' section, and log,
+ # but do not honour it, to avoid nasty surprises when people upgrade.
+ if push_config.get("redact_content") is not None:
+ print(
+ "The push.redact_content content option has never worked. "
+ "Please set push.include_content if you want this behaviour"
+ )
+
+ # Now check for the one in the 'email' section and honour it,
+ # with a warning.
push_config = config.get("email", {})
- self.push_redact_content = push_config.get("redact_content", False)
+ redact_content = push_config.get("redact_content")
+ if redact_content is not None:
+ print(
+ "The 'email.redact_content' option is deprecated: "
+ "please set push.include_content instead"
+ )
+ self.push_include_content = not redact_content
def default_config(self, config_dir_path, server_name, **kwargs):
return """
- # Control how push messages are sent to google/apple to notifications.
- # Normally every message said in a room with one or more people using
- # mobile devices will be posted to a push server hosted by matrix.org
- # which is registered with google and apple in order to allow push
- # notifications to be sent to these mobile devices.
- #
- # Setting redact_content to true will make the push messages contain no
- # message content which will provide increased privacy. This is a
- # temporary solution pending improvements to Android and iPhone apps
- # to get content from the app rather than the notification.
- #
+ # Clients requesting push notifications can either have the body of
+ # the message sent in the notification poke along with other details
+ # like the sender, or just the event ID and room ID (`event_id_only`).
+ # If clients choose the former, this option controls whether the
+ # notification request includes the content of the event (other details
+ # like the sender are still included). For `event_id_only` push, it
+ # has no effect.
+
# For modern Android devices the notification content will still appear
# because it is loaded by the app. iPhone, however, will send a
# notification saying only that a message arrived and who it came from.
#
#push:
- # redact_content: false
+ # include_content: true
"""
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index f7e03c4cde..0fb964eb67 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import Config
+from distutils.util import strtobool
from synapse.util.stringutils import random_string_with_symbols
-from distutils.util import strtobool
+from ._base import Config
class RegistrationConfig(Config):
@@ -31,6 +31,8 @@ class RegistrationConfig(Config):
strtobool(str(config["disable_registration"]))
)
+ self.registrations_require_3pid = config.get("registrations_require_3pid", [])
+ self.allowed_local_3pids = config.get("allowed_local_3pids", [])
self.registration_shared_secret = config.get("registration_shared_secret")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
@@ -41,6 +43,8 @@ class RegistrationConfig(Config):
self.allow_guest_access and config.get("invite_3pid_guest", False)
)
+ self.auto_join_rooms = config.get("auto_join_rooms", [])
+
def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50)
@@ -50,13 +54,32 @@ class RegistrationConfig(Config):
# Enable registration for new users.
enable_registration: False
+ # The user must provide all of the below types of 3PID when registering.
+ #
+ # registrations_require_3pid:
+ # - email
+ # - msisdn
+
+ # Mandate that users are only allowed to associate certain formats of
+ # 3PIDs with accounts on this server.
+ #
+ # allowed_local_3pids:
+ # - medium: email
+ # pattern: ".*@matrix\\.org"
+ # - medium: email
+ # pattern: ".*@vector\\.im"
+ # - medium: msisdn
+ # pattern: "\\+44"
+
# If set, allows registration by anyone who also has the shared
# secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s"
# Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash.
- # The default number of rounds is 12.
+ # The default number is 12 (which equates to 2^12 rounds).
+ # N.B. that increasing this will exponentially increase the time required
+ # to register or login - e.g. 24 => 2^24 rounds which will take >20 mins.
bcrypt_rounds: 12
# Allows users to register as guests without a password/email/etc, and
@@ -70,6 +93,11 @@ class RegistrationConfig(Config):
- matrix.org
- vector.im
- riot.im
+
+ # Users who register on this homeserver will automatically be joined
+ # to these rooms
+ #auto_join_rooms:
+ # - "#example:example.com"
""" % locals()
def add_arguments(self, parser):
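
A sketch of how an `allowed_local_3pids` policy like the commented-out example above could be evaluated; the helper name is hypothetical, not the server's actual enforcement code:

```python
import re

allowed_local_3pids = [
    {"medium": "email", "pattern": r".*@matrix\.org"},
    {"medium": "msisdn", "pattern": r"\+44"},
]

def is_3pid_allowed(medium, address):
    if not allowed_local_3pids:
        return True  # no restriction configured
    return any(
        constraint["medium"] == medium
        and re.match(constraint["pattern"], address)
        for constraint in allowed_local_3pids
    )

print(is_3pid_allowed("email", "alice@matrix.org"))   # True
print(is_3pid_allowed("email", "alice@example.com"))  # False
```
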
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 2c6f57168e..fc909c1fac 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -13,9 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import Config, ConfigError
from collections import namedtuple
+from synapse.util.module_loader import load_module
+
+from ._base import Config, ConfigError
MISSING_NETADDR = (
"Missing netaddr library. This is required for URL preview API."
@@ -36,6 +38,14 @@ ThumbnailRequirement = namedtuple(
"ThumbnailRequirement", ["width", "height", "method", "media_type"]
)
+MediaStorageProviderConfig = namedtuple(
+ "MediaStorageProviderConfig", (
+ "store_local", # Whether to store newly uploaded local files
+ "store_remote", # Whether to store newly downloaded remote files
+ "store_synchronous", # Whether to wait for successful storage for local uploads
+ ),
+)
+
def parse_thumbnail_requirements(thumbnail_sizes):
""" Takes a list of dictionaries with "width", "height", and "method" keys
@@ -70,7 +80,64 @@ class ContentRepositoryConfig(Config):
self.max_upload_size = self.parse_size(config["max_upload_size"])
self.max_image_pixels = self.parse_size(config["max_image_pixels"])
self.max_spider_size = self.parse_size(config["max_spider_size"])
+
self.media_store_path = self.ensure_directory(config["media_store_path"])
+
+ backup_media_store_path = config.get("backup_media_store_path")
+
+ synchronous_backup_media_store = config.get(
+ "synchronous_backup_media_store", False
+ )
+
+ storage_providers = config.get("media_storage_providers", [])
+
+ if backup_media_store_path:
+ if storage_providers:
+ raise ConfigError(
+ "Cannot use both 'backup_media_store_path' and 'storage_providers'"
+ )
+
+ storage_providers = [{
+ "module": "file_system",
+ "store_local": True,
+ "store_synchronous": synchronous_backup_media_store,
+ "store_remote": True,
+ "config": {
+ "directory": backup_media_store_path,
+ }
+ }]
+
+ # This is a list of config that can be used to create the storage
+ # providers. The entries are tuples of (Class, class_config,
+ # MediaStorageProviderConfig), where Class is the class of the provider,
+ # the class_config the config to pass to it, and
+ # MediaStorageProviderConfig are options for StorageProviderWrapper.
+ #
+ # We don't create the storage providers here as not all workers need
+ # them to be started.
+ self.media_storage_providers = []
+
+ for provider_config in storage_providers:
+ # We special case the module "file_system" so as not to need to
+ # expose FileStorageProviderBackend
+ if provider_config["module"] == "file_system":
+ provider_config["module"] = (
+ "synapse.rest.media.v1.storage_provider"
+ ".FileStorageProviderBackend"
+ )
+
+ provider_class, parsed_config = load_module(provider_config)
+
+ wrapper_config = MediaStorageProviderConfig(
+ provider_config.get("store_local", False),
+ provider_config.get("store_remote", False),
+ provider_config.get("store_synchronous", False),
+ )
+
+ self.media_storage_providers.append(
+ (provider_class, parsed_config, wrapper_config,)
+ )
+
self.uploads_path = self.ensure_directory(config["uploads_path"])
self.dynamic_thumbnails = config["dynamic_thumbnails"]
self.thumbnail_requirements = parse_thumbnail_requirements(
@@ -115,6 +182,20 @@ class ContentRepositoryConfig(Config):
# Directory where uploaded images and attachments are stored.
media_store_path: "%(media_store)s"
+ # Media storage providers allow media to be stored in different
+ # locations.
+ # media_storage_providers:
+ # - module: file_system
+ # # Whether to write new local files.
+ # store_local: false
+ # # Whether to write new remote media
+ # store_remote: false
+ # # Whether to block upload requests waiting for write to this
+ # # provider to complete
+ # store_synchronous: false
+ # config:
+ # directory: /mnt/some/other/directory
+
# Directory where in-progress uploads are stored.
uploads_path: "%(uploads_path)s"
@@ -169,6 +250,9 @@ class ContentRepositoryConfig(Config):
# - '192.168.0.0/16'
# - '100.64.0.0/10'
# - '169.254.0.0/16'
+ # - '::1/128'
+ # - 'fe80::/64'
+ # - 'fc00::/7'
#
# List of IP address CIDR ranges that the URL preview spider is allowed
# to access even if they are specified in url_preview_ip_range_blacklist.
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 28b4e5f50c..18102656b0 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,13 +14,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
+from synapse.http.endpoint import parse_and_validate_server_name
+
from ._base import Config, ConfigError
+logger = logging.Logger(__name__)
+
class ServerConfig(Config):
def read_config(self, config):
self.server_name = config["server_name"]
+
+ try:
+ parse_and_validate_server_name(self.server_name)
+ except ValueError as e:
+ raise ConfigError(str(e))
+
self.pid_file = self.abspath(config.get("pid_file"))
self.web_client = config["web_client"]
self.web_client_location = config.get("web_client_location", None)
@@ -29,6 +42,7 @@ class ServerConfig(Config):
self.user_agent_suffix = config.get("user_agent_suffix")
self.use_frozen_dicts = config.get("use_frozen_dicts", False)
self.public_baseurl = config.get("public_baseurl")
+ self.cpu_affinity = config.get("cpu_affinity")
# Whether to send federation traffic out in this process. This only
# applies to some federation traffic, and so shouldn't be used to
@@ -39,8 +53,31 @@ class ServerConfig(Config):
# false only if we are updating the user directory in a worker
self.update_user_directory = config.get("update_user_directory", True)
+ # whether to enable the media repository endpoints. This should be set
+ # to false if the media repository is running as a separate endpoint;
+ # doing so ensures that we will not run cache cleanup jobs on the
+ # master, potentially causing inconsistency.
+ self.enable_media_repo = config.get("enable_media_repo", True)
+
self.filter_timeline_limit = config.get("filter_timeline_limit", -1)
+ # Whether we should block invites sent to users on this server
+ # (other than those sent by local server admins)
+ self.block_non_admin_invites = config.get(
+ "block_non_admin_invites", False,
+ )
+
+ # FIXME: federation_domain_whitelist needs sytests
+ self.federation_domain_whitelist = None
+ federation_domain_whitelist = config.get(
+ "federation_domain_whitelist", None
+ )
+ # turn the whitelist into a hash for speed of lookup
+ if federation_domain_whitelist is not None:
+ self.federation_domain_whitelist = {}
+ for domain in federation_domain_whitelist:
+ self.federation_domain_whitelist[domain] = True
+
if self.public_baseurl is not None:
if self.public_baseurl[-1] != '/':
self.public_baseurl += '/'
@@ -113,6 +150,12 @@ class ServerConfig(Config):
metrics_port = config.get("metrics_port")
if metrics_port:
+ logger.warn(
+ ("The metrics_port configuration option is deprecated in Synapse 0.31 "
+ "in favour of a listener. Please see "
+ "http://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst"
+ " on how to configure the new listener."))
+
self.listeners.append({
"port": metrics_port,
"bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")],
@@ -127,8 +170,8 @@ class ServerConfig(Config):
})
def default_config(self, server_name, **kwargs):
- if ":" in server_name:
- bind_port = int(server_name.split(":")[1])
+ _, bind_port = parse_and_validate_server_name(server_name)
+ if bind_port is not None:
unsecure_port = bind_port - 400
else:
bind_port = 8448
@@ -147,6 +190,27 @@ class ServerConfig(Config):
# When running as a daemon, the file to store the pid in
pid_file: %(pid_file)s
+ # CPU affinity mask. Setting this restricts the CPUs on which the
+ # process will be scheduled. It is represented as a bitmask, with the
+ # lowest order bit corresponding to the first logical CPU and the
+ # highest order bit corresponding to the last logical CPU. Not all CPUs
+ # may exist on a given system but a mask may specify more CPUs than are
+ # present.
+ #
+ # For example:
+ # 0x00000001 is processor #0,
+ # 0x00000003 is processors #0 and #1,
+ # 0xFFFFFFFF is all processors (#0 through #31).
+ #
+ # Pinning a Python process to a single CPU is desirable, because Python
+ # is inherently single-threaded due to the GIL, and can suffer a
+ # 30-40%% slowdown due to cache blow-out and thread context switching
+ # if the scheduler happens to schedule the underlying threads across
+ # different cores. See
+ # https://www.mirantis.com/blog/improve-performance-python-programs-restricting-single-cpu/.
+ #
+ # cpu_affinity: 0xFFFFFFFF
+
# Whether to serve a web client from the HTTP/HTTPS root resource.
web_client: True
@@ -171,6 +235,21 @@ class ServerConfig(Config):
# and sync operations. The default value is -1, means no upper limit.
# filter_timeline_limit: 5000
+ # Whether room invites to users on this server should be blocked
+ # (except those sent by local server admins). The default is False.
+ # block_non_admin_invites: True
+
+ # Restrict federation to the following whitelist of domains.
+ # N.B. we recommend also firewalling your federation listener to limit
+ # inbound federation traffic as early as possible, rather than relying
+ # purely on this application-layer restriction. If not specified, the
+ # default is to whitelist everything.
+ #
+ # federation_domain_whitelist:
+ # - lon.example.com
+ # - nyc.example.com
+ # - syd.example.com
+
# List of ports that Synapse should listen on, their purpose and their
# configuration.
listeners:
@@ -181,13 +260,12 @@ class ServerConfig(Config):
port: %(bind_port)s
# Local addresses to listen on.
- # This will listen on all IPv4 addresses by default.
+ # On Linux and Mac OS, `::` will listen on all IPv4 and IPv6
+ # addresses by default. For most other OSes, this will only listen
+ # on IPv6.
bind_addresses:
+ - '::'
- '0.0.0.0'
- # Uncomment to listen on all IPv6 interfaces
- # N.B: On at least Linux this will also listen on all IPv4
- # addresses, so you will need to comment out the line above.
- # - '::'
# This is a 'http' listener, allows us to specify 'resources'.
type: http
@@ -214,11 +292,18 @@ class ServerConfig(Config):
- names: [federation] # Federation APIs
compress: false
+ # optional list of additional endpoints which can be loaded via
+ # dynamic modules
+ # additional_resources:
+ # "/_matrix/my/custom/endpoint":
+ # module: my_module.CustomRequestHandler
+ # config: {}
+
# Unsecure HTTP listener,
# For when matrix traffic passes through loadbalancer that unwraps TLS.
- port: %(unsecure_port)s
tls: false
- bind_addresses: ['0.0.0.0']
+ bind_addresses: ['::', '0.0.0.0']
type: http
x_forwarded: false
@@ -232,7 +317,7 @@ class ServerConfig(Config):
# Turn on the twisted ssh manhole service on localhost on the given
# port.
# - port: 9000
- # bind_address: 127.0.0.1
+ # bind_addresses: ['::1', '127.0.0.1']
# type: manhole
""" % locals()
@@ -270,7 +355,7 @@ def read_gc_thresholds(thresholds):
return (
int(thresholds[0]), int(thresholds[1]), int(thresholds[2]),
)
- except:
+ except Exception:
raise ConfigError(
"Value of `gc_threshold` must be a list of three integers if set"
)
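
The `cpu_affinity` option documented in the new `server.py` config comments is a bitmask over logical CPUs. A sketch of decoding such a mask into the set of CPUs it selects:

```python
def cpus_in_mask(mask):
    # Lowest-order bit is CPU 0, the next bit CPU 1, and so on.
    cpus = set()
    cpu = 0
    while mask:
        if mask & 1:
            cpus.add(cpu)
        mask >>= 1
        cpu += 1
    return cpus

print(cpus_in_mask(0x00000001))  # {0}
print(cpus_in_mask(0x00000003))  # {0, 1}
```
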
diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices_config.py
new file mode 100644
index 0000000000..3c39850ac6
--- /dev/null
+++ b/synapse/config/server_notices_config.py
@@ -0,0 +1,87 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.types import UserID
+
+from ._base import Config
+
+DEFAULT_CONFIG = """\
+# Server Notices room configuration
+#
+# Uncomment this section to enable a room which can be used to send notices
+# from the server to users. It is a special room which cannot be left; notices
+# come from a special "notices" user id.
+#
+# If you uncomment this section, you *must* define the system_mxid_localpart
+# setting, which defines the id of the user which will be used to send the
+# notices.
+#
+# It's also possible to override the room name, the display name of the
+# "notices" user, and the avatar for the user.
+#
+# server_notices:
+# system_mxid_localpart: notices
+# system_mxid_display_name: "Server Notices"
+# system_mxid_avatar_url: "mxc://server.com/oumMVlgDnLYFaPVkExemNVVZ"
+# room_name: "Server Notices"
+"""
+
+
+class ServerNoticesConfig(Config):
+ """Configuration for the server notices room.
+
+ Attributes:
+ server_notices_mxid (str|None):
+ The MXID to use for server notices.
+ None if server notices are not enabled.
+
+ server_notices_mxid_display_name (str|None):
+ The display name to use for the server notices user.
+ None if server notices are not enabled.
+
+ server_notices_mxid_avatar_url (str|None):
+ The avatar URL to use for the server notices user.
+ None if server notices are not enabled.
+
+ server_notices_room_name (str|None):
+ The name to use for the server notices room.
+ None if server notices are not enabled.
+ """
+ def __init__(self):
+ super(ServerNoticesConfig, self).__init__()
+ self.server_notices_mxid = None
+ self.server_notices_mxid_display_name = None
+ self.server_notices_mxid_avatar_url = None
+ self.server_notices_room_name = None
+
+ def read_config(self, config):
+ c = config.get("server_notices")
+ if c is None:
+ return
+
+ mxid_localpart = c['system_mxid_localpart']
+ self.server_notices_mxid = UserID(
+ mxid_localpart, self.server_name,
+ ).to_string()
+ self.server_notices_mxid_display_name = c.get(
+ 'system_mxid_display_name', None,
+ )
+ self.server_notices_mxid_avatar_url = c.get(
+ 'system_mxid_avatar_url', None,
+ )
+ # todo: i18n
+ self.server_notices_room_name = c.get('room_name', "Server Notices")
+
+ def default_config(self, **kwargs):
+ return DEFAULT_CONFIG
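
A minimal sketch (not part of the patch) of what read_config above computes,
assuming a hypothetical server_name of "example.com":

    from synapse.types import UserID

    server_name = "example.com"          # hypothetical
    c = {"system_mxid_localpart": "notices"}

    # UserID is synapse's (localpart, domain) pair; to_string() adds the sigil.
    notices_mxid = UserID(c["system_mxid_localpart"], server_name).to_string()
    assert notices_mxid == "@notices:example.com"
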
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
new file mode 100644
index 0000000000..3fec42bdb0
--- /dev/null
+++ b/synapse/config/spam_checker.py
@@ -0,0 +1,35 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.module_loader import load_module
+
+from ._base import Config
+
+
+class SpamCheckerConfig(Config):
+ def read_config(self, config):
+ self.spam_checker = None
+
+ provider = config.get("spam_checker", None)
+ if provider is not None:
+ self.spam_checker = load_module(provider)
+
+ def default_config(self, **kwargs):
+ return """\
+ # spam_checker:
+ # module: "my_custom_project.SuperSpamChecker"
+ # config:
+ # example_option: 'things'
+ """
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index e081840a83..fef1ea99cb 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -13,15 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import Config
-
-from OpenSSL import crypto
-import subprocess
import os
-
+import subprocess
from hashlib import sha256
+
from unpaddedbase64 import encode_base64
+from OpenSSL import crypto
+
+from ._base import Config
+
GENERATE_DH_PARAMS = False
@@ -96,7 +97,7 @@ class TlsConfig(Config):
# certificates returned by this server match one of the fingerprints.
#
# Synapse automatically adds the fingerprint of its own certificate
- # to the list. So if federation traffic is handle directly by synapse
+ # to the list. So if federation traffic is handled directly by synapse
# then no modification to the list is required.
#
# If synapse is run behind a load balancer that handles the TLS then it
@@ -109,6 +110,12 @@ class TlsConfig(Config):
# key. It may be necessary to publish the fingerprints of a new
# certificate and wait until the "valid_until_ts" of the previous key
# responses have passed before deploying it.
+ #
+ # You can calculate a fingerprint from a given TLS listener via:
+ # openssl s_client -connect $host:$port < /dev/null 2> /dev/null |
+ # openssl x509 -outform DER | openssl sha256 -binary | base64 | tr -d '='
+ # or by checking matrix.org/federationtester/api/report?server_name=$host
+ #
tls_fingerprints: []
# tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
""" % locals()
@@ -126,8 +133,8 @@ class TlsConfig(Config):
tls_private_key_path = config["tls_private_key_path"]
tls_dh_params_path = config["tls_dh_params_path"]
- if not os.path.exists(tls_private_key_path):
- with open(tls_private_key_path, "w") as private_key_file:
+ if not self.path_exists(tls_private_key_path):
+ with open(tls_private_key_path, "wb") as private_key_file:
tls_private_key = crypto.PKey()
tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
private_key_pem = crypto.dump_privatekey(
@@ -141,8 +148,8 @@ class TlsConfig(Config):
crypto.FILETYPE_PEM, private_key_pem
)
- if not os.path.exists(tls_certificate_path):
- with open(tls_certificate_path, "w") as certificate_file:
+ if not self.path_exists(tls_certificate_path):
+ with open(tls_certificate_path, "wb") as certificate_file:
cert = crypto.X509()
subject = cert.get_subject()
subject.CN = config["server_name"]
@@ -159,7 +166,7 @@ class TlsConfig(Config):
certificate_file.write(cert_pem)
- if not os.path.exists(tls_dh_params_path):
+ if not self.path_exists(tls_dh_params_path):
if GENERATE_DH_PARAMS:
subprocess.check_call([
"openssl", "dhparam",
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
new file mode 100644
index 0000000000..38e8947843
--- /dev/null
+++ b/synapse/config/user_directory.py
@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config
+
+
+class UserDirectoryConfig(Config):
+ """User Directory Configuration
+ Configuration for the behaviour of the /user_directory API
+ """
+
+ def read_config(self, config):
+ self.user_directory_search_all_users = False
+ user_directory_config = config.get("user_directory", None)
+ if user_directory_config:
+ self.user_directory_search_all_users = (
+ user_directory_config.get("search_all_users", False)
+ )
+
+ def default_config(self, config_dir_path, server_name, **kwargs):
+ return """
+ # User Directory configuration
+ #
+ # 'search_all_users' defines whether to search all users visible to your HS
+ # when searching the user directory, rather than limiting to users visible
+    # in public rooms. Defaults to false. If you set it to true, you'll have to run
+ # UPDATE user_directory_stream_pos SET stream_id = NULL;
+ # on your database to tell it to rebuild the user_directory search indexes.
+ #
+ #user_directory:
+ # search_all_users: false
+ """
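
A sketch of the lookup logic read_config applies (inputs hypothetical):

    def search_all_users(config):
        # Mirrors UserDirectoryConfig.read_config: a missing section or a
        # missing key both fall back to False.
        user_directory_config = config.get("user_directory", None)
        if user_directory_config:
            return user_directory_config.get("search_all_users", False)
        return False

    assert search_all_users({}) is False
    assert search_all_users({"user_directory": {"search_all_users": True}})
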
diff --git a/synapse/config/voip.py b/synapse/config/voip.py
index 3a4e16fa96..d07bd24ffd 100644
--- a/synapse/config/voip.py
+++ b/synapse/config/voip.py
@@ -30,10 +30,10 @@ class VoipConfig(Config):
## Turn ##
# The public URIs of the TURN server to give to clients
- turn_uris: []
+ #turn_uris: []
# The shared secret used to compute passwords for the TURN server
- turn_shared_secret: "YOUR_SHARED_SECRET"
+ #turn_shared_secret: "YOUR_SHARED_SECRET"
# The Username and password if the TURN server needs them and
# does not use a token
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index ea48d931a1..80baf0ce0e 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -23,15 +23,31 @@ class WorkerConfig(Config):
def read_config(self, config):
self.worker_app = config.get("worker_app")
+
+ # Canonicalise worker_app so that master always has None
+ if self.worker_app == "synapse.app.homeserver":
+ self.worker_app = None
+
self.worker_listeners = config.get("worker_listeners")
self.worker_daemonize = config.get("worker_daemonize")
self.worker_pid_file = config.get("worker_pid_file")
self.worker_log_file = config.get("worker_log_file")
self.worker_log_config = config.get("worker_log_config")
+
+ # The host used to connect to the main synapse
self.worker_replication_host = config.get("worker_replication_host", None)
+
+ # The port on the main synapse for TCP replication
self.worker_replication_port = config.get("worker_replication_port", None)
+
+ # The port on the main synapse for HTTP replication endpoint
+ self.worker_replication_http_port = config.get("worker_replication_http_port")
+
self.worker_name = config.get("worker_name", self.worker_app)
+ self.worker_main_http_uri = config.get("worker_main_http_uri", None)
+ self.worker_cpu_affinity = config.get("worker_cpu_affinity")
+
if self.worker_listeners:
for listener in self.worker_listeners:
bind_address = listener.pop("bind_address", None)
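
For illustration (all values hypothetical), the options read above correspond
to a worker configuration along these lines, shown as the parsed dict:

    worker_config = {
        "worker_app": "synapse.app.synchrotron",
        # Where to find the main synapse process:
        "worker_replication_host": "127.0.0.1",
        "worker_replication_port": 9092,       # TCP replication stream
        "worker_replication_http_port": 9093,  # HTTP replication endpoint
        "worker_main_http_uri": "http://127.0.0.1:8008",
    }
    # A worker_app of "synapse.app.homeserver" is canonicalised to None, so
    # "is this the master process?" reduces to `worker_app is None`.
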
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index aad4752fe7..a1e1d0d33a 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import ssl
-from OpenSSL import SSL
-from twisted.internet._sslverify import _OpenSSLECCurve, _defaultCurveName
-
import logging
+from OpenSSL import SSL, crypto
+from twisted.internet import ssl
+from twisted.internet._sslverify import _defaultCurveName
+
logger = logging.getLogger(__name__)
@@ -32,9 +32,10 @@ class ServerContextFactory(ssl.ContextFactory):
@staticmethod
def configure_context(context, config):
try:
- _ecCurve = _OpenSSLECCurve(_defaultCurveName)
- _ecCurve.addECKeyToContext(context)
- except:
+ _ecCurve = crypto.get_elliptic_curve(_defaultCurveName)
+ context.set_tmp_ecdh(_ecCurve)
+
+ except Exception:
logger.exception("Failed to enable elliptic curve for TLS")
context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
context.use_certificate_chain_file(config.tls_certificate_file)
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index ec7711ba7d..8774b28967 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -15,15 +15,15 @@
# limitations under the License.
-from synapse.api.errors import SynapseError, Codes
-from synapse.events.utils import prune_event
+import hashlib
+import logging
from canonicaljson import encode_canonical_json
-from unpaddedbase64 import encode_base64, decode_base64
from signedjson.sign import sign_json
+from unpaddedbase64 import decode_base64, encode_base64
-import hashlib
-import logging
+from synapse.api.errors import Codes, SynapseError
+from synapse.events.utils import prune_event
logger = logging.getLogger(__name__)
@@ -32,18 +32,25 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
"""Check whether the hash for this PDU matches the contents"""
name, expected_hash = compute_content_hash(event, hash_algorithm)
logger.debug("Expecting hash: %s", encode_base64(expected_hash))
- if name not in event.hashes:
+
+    # some malformed events lack a 'hashes' property. Protect against it being
+    # missing, or of a weird type, by treating it the same as an unhashed event.
+ hashes = event.get("hashes")
+ if not isinstance(hashes, dict):
+ raise SynapseError(400, "Malformed 'hashes'", Codes.UNAUTHORIZED)
+
+ if name not in hashes:
raise SynapseError(
400,
"Algorithm %s not in hashes %s" % (
- name, list(event.hashes),
+ name, list(hashes),
),
Codes.UNAUTHORIZED,
)
- message_hash_base64 = event.hashes[name]
+ message_hash_base64 = hashes[name]
try:
message_hash_bytes = decode_base64(message_hash_base64)
- except:
+ except Exception:
raise SynapseError(
400,
"Invalid base64: %s" % (message_hash_base64,),
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index c2bd64d6c2..668b4f517d 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -13,17 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
+from canonicaljson import json
-from twisted.web.http import HTTPClient
-from twisted.internet.protocol import Factory
from twisted.internet import defer, reactor
-from synapse.http.endpoint import matrix_federation_endpoint
-from synapse.util.logcontext import (
- preserve_context_over_fn, preserve_context_over_deferred
-)
-import simplejson as json
-import logging
+from twisted.internet.protocol import Factory
+from twisted.web.http import HTTPClient
+from synapse.http.endpoint import matrix_federation_endpoint
+from synapse.util import logcontext
logger = logging.getLogger(__name__)
@@ -43,14 +42,10 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
for i in range(5):
try:
- protocol = yield preserve_context_over_fn(
- endpoint.connect, factory
- )
- server_response, server_certificate = yield preserve_context_over_deferred(
- protocol.remote_key
- )
- defer.returnValue((server_response, server_certificate))
- return
+ with logcontext.PreserveLoggingContext():
+ protocol = yield endpoint.connect(factory)
+ server_response, server_certificate = yield protocol.remote_key
+ defer.returnValue((server_response, server_certificate))
except SynapseKeyClientError as e:
logger.exception("Error getting key for %r" % (server_name,))
if e.status.startswith("4"):
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 1bb27edc0f..e95b9fb43e 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,33 +14,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.crypto.keyclient import fetch_server_key
-from synapse.api.errors import SynapseError, Codes
-from synapse.util import unwrapFirstError
-from synapse.util.async import ObservableDeferred
-from synapse.util.logcontext import (
- preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
- preserve_fn
-)
-from synapse.util.metrics import Measure
-
-from twisted.internet import defer
+import hashlib
+import logging
+import urllib
+from collections import namedtuple
-from signedjson.sign import (
- verify_signed_json, signature_ids, sign_json, encode_canonical_json
-)
from signedjson.key import (
- is_signing_algorithm_supported, decode_verify_key_bytes
+ decode_verify_key_bytes,
+ encode_verify_key_base64,
+ is_signing_algorithm_supported,
+)
+from signedjson.sign import (
+ SignatureVerifyException,
+ encode_canonical_json,
+ sign_json,
+ signature_ids,
+ verify_signed_json,
)
from unpaddedbase64 import decode_base64, encode_base64
from OpenSSL import crypto
+from twisted.internet import defer
-from collections import namedtuple
-import urllib
-import hashlib
-import logging
-
+from synapse.api.errors import Codes, SynapseError
+from synapse.crypto.keyclient import fetch_server_key
+from synapse.util import logcontext, unwrapFirstError
+from synapse.util.logcontext import (
+ PreserveLoggingContext,
+ preserve_fn,
+ run_in_background,
+)
+from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -55,9 +60,10 @@ Attributes:
     key_ids(set(str)): The set of key_ids that could be used to verify the
JSON object
json_object(dict): The JSON object to verify.
- deferred(twisted.internet.defer.Deferred):
+ deferred(Deferred[str, str, nacl.signing.VerifyKey]):
A deferred (server_name, key_id, verify_key) tuple that resolves when
- a verify key has been fetched
+ a verify key has been fetched. The deferreds' callbacks are run with no
+ logcontext.
"""
@@ -74,23 +80,32 @@ class Keyring(object):
self.perspective_servers = self.config.perspectives
self.hs = hs
+ # map from server name to Deferred. Has an entry for each server with
+ # an ongoing key download; the Deferred completes once the download
+ # completes.
+ #
+ # These are regular, logcontext-agnostic Deferreds.
self.key_downloads = {}
def verify_json_for_server(self, server_name, json_object):
- return self.verify_json_objects_for_server(
- [(server_name, json_object)]
- )[0]
+ return logcontext.make_deferred_yieldable(
+ self.verify_json_objects_for_server(
+ [(server_name, json_object)]
+ )[0]
+ )
def verify_json_objects_for_server(self, server_and_json):
- """Bulk verfies signatures of json objects, bulk fetching keys as
+ """Bulk verifies signatures of json objects, bulk fetching keys as
necessary.
Args:
server_and_json (list): List of pairs of (server_name, json_object)
Returns:
- list of deferreds indicating success or failure to verify each
- json object's signature for the given server_name.
+ List<Deferred>: for each input pair, a deferred indicating success
+ or failure to verify each json object's signature for the given
+ server_name. The deferreds run their callbacks in the sentinel
+ logcontext.
"""
verify_requests = []
@@ -117,73 +132,60 @@ class Keyring(object):
verify_requests.append(verify_request)
- @defer.inlineCallbacks
- def handle_key_deferred(verify_request):
- server_name = verify_request.server_name
- try:
- _, key_id, verify_key = yield verify_request.deferred
- except IOError as e:
- logger.warn(
- "Got IOError when downloading keys for %s: %s %s",
- server_name, type(e).__name__, str(e.message),
- )
- raise SynapseError(
- 502,
- "Error downloading keys for %s" % (server_name,),
- Codes.UNAUTHORIZED,
- )
- except Exception as e:
- logger.exception(
- "Got Exception when downloading keys for %s: %s %s",
- server_name, type(e).__name__, str(e.message),
- )
- raise SynapseError(
- 401,
- "No key for %s with id %s" % (server_name, key_ids),
- Codes.UNAUTHORIZED,
- )
+ run_in_background(self._start_key_lookups, verify_requests)
- json_object = verify_request.json_object
+ # Pass those keys to handle_key_deferred so that the json object
+ # signatures can be verified
+ handle = preserve_fn(_handle_key_deferred)
+ return [
+ handle(rq) for rq in verify_requests
+ ]
- logger.debug("Got key %s %s:%s for server %s, verifying" % (
- key_id, verify_key.alg, verify_key.version, server_name,
- ))
- try:
- verify_signed_json(json_object, server_name, verify_key)
- except:
- raise SynapseError(
- 401,
- "Invalid signature for server %s with key %s:%s" % (
- server_name, verify_key.alg, verify_key.version
- ),
- Codes.UNAUTHORIZED,
- )
+ @defer.inlineCallbacks
+ def _start_key_lookups(self, verify_requests):
+ """Sets off the key fetches for each verify request
- server_to_deferred = {
- server_name: defer.Deferred()
- for server_name, _ in server_and_json
- }
+ Once each fetch completes, verify_request.deferred will be resolved.
- with PreserveLoggingContext():
+ Args:
+ verify_requests (List[VerifyKeyRequest]):
+ """
+
+ try:
+ # create a deferred for each server we're going to look up the keys
+ # for; we'll resolve them once we have completed our lookups.
+ # These will be passed into wait_for_previous_lookups to block
+ # any other lookups until we have finished.
+ # The deferreds are called with no logcontext.
+ server_to_deferred = {
+ rq.server_name: defer.Deferred()
+ for rq in verify_requests
+ }
# We want to wait for any previous lookups to complete before
# proceeding.
- wait_on_deferred = self.wait_for_previous_lookups(
- [server_name for server_name, _ in server_and_json],
+ yield self.wait_for_previous_lookups(
+ [rq.server_name for rq in verify_requests],
server_to_deferred,
)
# Actually start fetching keys.
- wait_on_deferred.addBoth(
- lambda _: self.get_server_verify_keys(verify_requests)
- )
+ self._get_server_verify_keys(verify_requests)
# When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed.
+ #
+ # map from server name to a set of request ids
server_to_request_ids = {}
- def remove_deferreds(res, server_name, verify_request):
+ for verify_request in verify_requests:
+ server_name = verify_request.server_name
+ request_id = id(verify_request)
+ server_to_request_ids.setdefault(server_name, set()).add(request_id)
+
+ def remove_deferreds(res, verify_request):
+ server_name = verify_request.server_name
request_id = id(verify_request)
server_to_request_ids[server_name].discard(request_id)
if not server_to_request_ids[server_name]:
@@ -193,17 +195,11 @@ class Keyring(object):
return res
for verify_request in verify_requests:
- server_name = verify_request.server_name
- request_id = id(verify_request)
- server_to_request_ids.setdefault(server_name, set()).add(request_id)
- deferred.addBoth(remove_deferreds, server_name, verify_request)
-
- # Pass those keys to handle_key_deferred so that the json object
- # signatures can be verified
- return [
- preserve_context_over_fn(handle_key_deferred, verify_request)
- for verify_request in verify_requests
- ]
+ verify_request.deferred.addBoth(
+ remove_deferreds, verify_request,
+ )
+ except Exception:
+ logger.exception("Error starting key lookups")
@defer.inlineCallbacks
def wait_for_previous_lookups(self, server_names, server_to_deferred):
@@ -212,7 +208,13 @@ class Keyring(object):
Args:
server_names (list): list of server_names we want to lookup
server_to_deferred (dict): server_name to deferred which gets
- resolved once we've finished looking up keys for that server
+ resolved once we've finished looking up keys for that server.
+ The Deferreds should be regular twisted ones which call their
+ callbacks with no logcontext.
+
+ Returns: a Deferred which resolves once all key lookups for the given
+ servers have completed. Follows the synapse rules of logcontext
+ preservation.
"""
while True:
wait_on = [
@@ -226,17 +228,15 @@ class Keyring(object):
else:
break
- for server_name, deferred in server_to_deferred.items():
- d = ObservableDeferred(preserve_context_over_deferred(deferred))
- self.key_downloads[server_name] = d
-
- def rm(r, server_name):
- self.key_downloads.pop(server_name, None)
- return r
+ def rm(r, server_name_):
+ self.key_downloads.pop(server_name_, None)
+ return r
- d.addBoth(rm, server_name)
+ for server_name, deferred in server_to_deferred.items():
+ self.key_downloads[server_name] = deferred
+ deferred.addBoth(rm, server_name)
- def get_server_verify_keys(self, verify_requests):
+ def _get_server_verify_keys(self, verify_requests):
"""Tries to find at least one key for each verify request
For each verify_request, verify_request.deferred is called back with
@@ -305,21 +305,23 @@ class Keyring(object):
if not missing_keys:
break
- for verify_request in requests_missing_keys.values():
- verify_request.deferred.errback(SynapseError(
- 401,
- "No key for %s with id %s" % (
- verify_request.server_name, verify_request.key_ids,
- ),
- Codes.UNAUTHORIZED,
- ))
+ with PreserveLoggingContext():
+ for verify_request in requests_missing_keys:
+ verify_request.deferred.errback(SynapseError(
+ 401,
+ "No key for %s with id %s" % (
+ verify_request.server_name, verify_request.key_ids,
+ ),
+ Codes.UNAUTHORIZED,
+ ))
def on_err(err):
- for verify_request in verify_requests:
- if not verify_request.deferred.called:
- verify_request.deferred.errback(err)
+ with PreserveLoggingContext():
+ for verify_request in verify_requests:
+ if not verify_request.deferred.called:
+ verify_request.deferred.errback(err)
- do_iterations().addErrback(on_err)
+ run_in_background(do_iterations).addErrback(on_err)
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
@@ -333,15 +335,16 @@ class Keyring(object):
Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from
server_name -> key_id -> VerifyKey
"""
- res = yield preserve_context_over_deferred(defer.gatherResults(
+ res = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store.get_server_verify_keys)(
- server_name, key_ids
+ run_in_background(
+ self.store.get_server_verify_keys,
+ server_name, key_ids,
).addCallback(lambda ks, server: (server, ks), server_name)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
defer.returnValue(dict(res))
@@ -358,17 +361,17 @@ class Keyring(object):
logger.exception(
"Unable to get key from %r: %s %s",
perspective_name,
- type(e).__name__, str(e.message),
+ type(e).__name__, str(e),
)
defer.returnValue({})
- results = yield preserve_context_over_deferred(defer.gatherResults(
+ results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(get_key)(p_name, p_keys)
+ run_in_background(get_key, p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items()
],
consumeErrors=True,
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
union_of_keys = {}
for result in results:
@@ -390,7 +393,7 @@ class Keyring(object):
logger.info(
"Unable to get key %r for %r directly: %s %s",
key_ids, server_name,
- type(e).__name__, str(e.message),
+ type(e).__name__, str(e),
)
if not keys:
@@ -402,13 +405,13 @@ class Keyring(object):
defer.returnValue(keys)
- results = yield preserve_context_over_deferred(defer.gatherResults(
+ results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(get_key)(server_name, key_ids)
+ run_in_background(get_key, server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
merged = {}
for result in results:
@@ -485,9 +488,10 @@ class Keyring(object):
for server_name, response_keys in processed_response.items():
keys.setdefault(server_name, {}).update(response_keys)
- yield preserve_context_over_deferred(defer.gatherResults(
+ yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store_keys)(
+ run_in_background(
+ self.store_keys,
server_name=server_name,
from_server=perspective_name,
verify_keys=response_keys,
@@ -495,7 +499,7 @@ class Keyring(object):
for server_name, response_keys in keys.items()
],
consumeErrors=True
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
defer.returnValue(keys)
@@ -543,9 +547,10 @@ class Keyring(object):
keys.update(response_keys)
- yield preserve_context_over_deferred(defer.gatherResults(
+ yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store_keys)(
+ run_in_background(
+ self.store_keys,
server_name=key_server_name,
from_server=server_name,
verify_keys=verify_keys,
@@ -553,7 +558,7 @@ class Keyring(object):
for key_server_name, verify_keys in keys.items()
],
consumeErrors=True
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
defer.returnValue(keys)
@@ -619,9 +624,10 @@ class Keyring(object):
response_keys.update(verify_keys)
response_keys.update(old_verify_keys)
- yield preserve_context_over_deferred(defer.gatherResults(
+ yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store.store_server_keys_json)(
+ run_in_background(
+ self.store.store_server_keys_json,
server_name=server_name,
key_id=key_id,
from_server=server_name,
@@ -632,7 +638,7 @@ class Keyring(object):
for key_id in updated_key_ids
],
consumeErrors=True,
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
results[server_name] = response_keys
@@ -710,7 +716,6 @@ class Keyring(object):
defer.returnValue(verify_keys)
- @defer.inlineCallbacks
def store_keys(self, server_name, from_server, verify_keys):
"""Store a collection of verify keys for a given server
Args:
@@ -721,12 +726,74 @@ class Keyring(object):
A deferred that completes when the keys are stored.
"""
# TODO(markjh): Store whether the keys have expired.
- yield preserve_context_over_deferred(defer.gatherResults(
+ return logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store.store_server_verify_key)(
+ run_in_background(
+ self.store.store_server_verify_key,
server_name, server_name, key.time_added, key
)
for key_id, key in verify_keys.items()
],
consumeErrors=True,
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
+
+
+@defer.inlineCallbacks
+def _handle_key_deferred(verify_request):
+ """Waits for the key to become available, and then performs a verification
+
+ Args:
+ verify_request (VerifyKeyRequest):
+
+ Returns:
+ Deferred[None]
+
+ Raises:
+ SynapseError if there was a problem performing the verification
+ """
+ server_name = verify_request.server_name
+ try:
+ with PreserveLoggingContext():
+ _, key_id, verify_key = yield verify_request.deferred
+ except IOError as e:
+ logger.warn(
+ "Got IOError when downloading keys for %s: %s %s",
+ server_name, type(e).__name__, str(e),
+ )
+ raise SynapseError(
+ 502,
+ "Error downloading keys for %s" % (server_name,),
+ Codes.UNAUTHORIZED,
+ )
+ except Exception as e:
+ logger.exception(
+ "Got Exception when downloading keys for %s: %s %s",
+ server_name, type(e).__name__, str(e),
+ )
+ raise SynapseError(
+ 401,
+ "No key for %s with id %s" % (server_name, verify_request.key_ids),
+ Codes.UNAUTHORIZED,
+ )
+
+ json_object = verify_request.json_object
+
+ logger.debug("Got key %s %s:%s for server %s, verifying" % (
+ key_id, verify_key.alg, verify_key.version, server_name,
+ ))
+ try:
+ verify_signed_json(json_object, server_name, verify_key)
+ except SignatureVerifyException as e:
+ logger.debug(
+ "Error verifying signature for %s:%s:%s with key %s: %s",
+ server_name, verify_key.alg, verify_key.version,
+ encode_verify_key_base64(verify_key),
+ str(e),
+ )
+ raise SynapseError(
+ 401,
+ "Invalid signature for server %s with key %s:%s: %s" % (
+ server_name, verify_key.alg, verify_key.version, str(e),
+ ),
+ Codes.UNAUTHORIZED,
+ )
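
A usage sketch (not part of the patch): because verify_json_for_server now
wraps its result in make_deferred_yieldable, inlineCallbacks code can yield
it directly; the hypothetical helper below just maps the outcome to a bool.

    from twisted.internet import defer

    from synapse.api.errors import SynapseError

    @defer.inlineCallbacks
    def is_validly_signed(keyring, server_name, signed_json):
        # verification raises SynapseError(401) on a bad signature and
        # SynapseError(502) if the keys could not be downloaded.
        try:
            yield keyring.verify_json_for_server(server_name, signed_json)
            defer.returnValue(True)
        except SynapseError:
            defer.returnValue(False)
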
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 4096c606f1..b32f64e729 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -17,11 +17,11 @@ import logging
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
-from signedjson.sign import verify_signed_json, SignatureVerifyException
+from signedjson.sign import SignatureVerifyException, verify_signed_json
from unpaddedbase64 import decode_base64
-from synapse.api.constants import EventTypes, Membership, JoinRules
-from synapse.api.errors import AuthError, SynapseError, EventSizeError
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.errors import AuthError, EventSizeError, SynapseError
from synapse.types import UserID, get_domain_from_id
logger = logging.getLogger(__name__)
@@ -34,9 +34,11 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
+ Raises:
+ AuthError if the checks fail
Returns:
- True if the auth checks pass.
+        None if the auth checks pass.
"""
if do_size_check:
_check_size_limits(event)
@@ -71,9 +73,10 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
# Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now)
logger.warn("Trusting event: %s", event.event_id)
- return True
+ return
if event.type == EventTypes.Create:
+ sender_domain = get_domain_from_id(event.sender)
room_id_domain = get_domain_from_id(event.room_id)
if room_id_domain != sender_domain:
raise AuthError(
@@ -81,7 +84,8 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
"Creation event's room_id domain does not match sender's"
)
# FIXME
- return True
+ logger.debug("Allowing! %s", event)
+ return
creation_event = auth_events.get((EventTypes.Create, ""), None)
@@ -118,7 +122,8 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
403,
"Alias event's state_key does not match sender's domain"
)
- return True
+ logger.debug("Allowing! %s", event)
+ return
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
@@ -127,14 +132,9 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
)
if event.type == EventTypes.Member:
- allowed = _is_membership_change_allowed(
- event, auth_events
- )
- if allowed:
- logger.debug("Allowing! %s", event)
- else:
- logger.debug("Denying! %s", event)
- return allowed
+ _is_membership_change_allowed(event, auth_events)
+ logger.debug("Allowing! %s", event)
+ return
_check_event_sender_in_room(event, auth_events)
@@ -153,7 +153,8 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
)
)
else:
- return True
+ logger.debug("Allowing! %s", event)
+ return
_can_send_event(event, auth_events)
@@ -200,7 +201,7 @@ def _is_membership_change_allowed(event, auth_events):
create = auth_events.get(key)
if create and event.prev_events[0][0] == create.event_id:
if create.content["creator"] == event.state_key:
- return True
+ return
target_user_id = event.state_key
@@ -265,13 +266,13 @@ def _is_membership_change_allowed(event, auth_events):
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
- return True
+ return
if Membership.JOIN != membership:
if (caller_invited
and Membership.LEAVE == membership
and target_user_id == event.user_id):
- return True
+ return
if not caller_in_room: # caller isn't joined
raise AuthError(
@@ -319,7 +320,7 @@ def _is_membership_change_allowed(event, auth_events):
# TODO (erikj): Implement kicks.
if target_banned and user_level < ban_level:
raise AuthError(
- 403, "You cannot unban user &s." % (target_user_id,)
+ 403, "You cannot unban user %s." % (target_user_id,)
)
elif target_user_id != event.user_id:
kick_level = _get_named_level(auth_events, "kick", 50)
@@ -334,8 +335,6 @@ def _is_membership_change_allowed(event, auth_events):
else:
raise AuthError(500, "Unknown membership %s" % membership)
- return True
-
def _check_event_sender_in_room(event, auth_events):
key = (EventTypes.Member, event.user_id, )
@@ -355,35 +354,46 @@ def _check_joined_room(member, user_id, room_id):
))
-def get_send_level(etype, state_key, auth_events):
- key = (EventTypes.PowerLevels, "", )
- send_level_event = auth_events.get(key)
- send_level = None
- if send_level_event:
- send_level = send_level_event.content.get("events", {}).get(
- etype
- )
- if send_level is None:
- if state_key is not None:
- send_level = send_level_event.content.get(
- "state_default", 50
- )
- else:
- send_level = send_level_event.content.get(
- "events_default", 0
- )
+def get_send_level(etype, state_key, power_levels_event):
+ """Get the power level required to send an event of a given type
+
+ The federation spec [1] refers to this as "Required Power Level".
+
+    [1]: https://matrix.org/docs/spec/server_server/unstable.html#definitions
- if send_level:
- send_level = int(send_level)
+ Args:
+ etype (str): type of event
+ state_key (str|None): state_key of state event, or None if it is not
+ a state event.
+ power_levels_event (synapse.events.EventBase|None): power levels event
+ in force at this point in the room
+ Returns:
+ int: power level required to send this event.
+ """
+
+ if power_levels_event:
+ power_levels_content = power_levels_event.content
else:
- send_level = 0
+ power_levels_content = {}
+
+ # see if we have a custom level for this event type
+ send_level = power_levels_content.get("events", {}).get(etype)
+
+ # otherwise, fall back to the state_default/events_default.
+ if send_level is None:
+ if state_key is not None:
+ send_level = power_levels_content.get("state_default", 50)
+ else:
+ send_level = power_levels_content.get("events_default", 0)
- return send_level
+ return int(send_level)
def _can_send_event(event, auth_events):
+ power_levels_event = _get_power_level_event(auth_events)
+
send_level = get_send_level(
- event.type, event.get("state_key", None), auth_events
+ event.type, event.get("state_key"), power_levels_event,
)
user_level = get_user_power_level(event.user_id, auth_events)
@@ -443,12 +453,12 @@ def _check_power_levels(event, auth_events):
for k, v in user_list.items():
try:
UserID.from_string(k)
- except:
+ except Exception:
raise SynapseError(400, "Not a valid user_id: %s" % (k,))
try:
int(v)
- except:
+ except Exception:
raise SynapseError(400, "Not a valid power level: %s" % (v,))
key = (event.type, event.state_key, )
@@ -470,15 +480,15 @@ def _check_power_levels(event, auth_events):
("invite", None),
]
- old_list = current_state.content.get("users")
- for user in set(old_list.keys() + user_list.keys()):
+ old_list = current_state.content.get("users", {})
+ for user in set(list(old_list) + list(user_list)):
levels_to_check.append(
(user, "users")
)
- old_list = current_state.content.get("events")
- new_list = event.content.get("events")
- for ev_id in set(old_list.keys() + new_list.keys()):
+ old_list = current_state.content.get("events", {})
+ new_list = event.content.get("events", {})
+ for ev_id in set(list(old_list) + list(new_list)):
levels_to_check.append(
(ev_id, "events")
)
@@ -515,7 +525,11 @@ def _check_power_levels(event, auth_events):
"to your own"
)
- if old_level > user_level or new_level > user_level:
+ # Check if the old and new levels are greater than the user level
+ # (if defined)
+ old_level_too_big = old_level is not None and old_level > user_level
+ new_level_too_big = new_level is not None and new_level > user_level
+ if old_level_too_big or new_level_too_big:
raise AuthError(
403,
"You don't have permission to add ops level greater "
@@ -524,13 +538,22 @@ def _check_power_levels(event, auth_events):
def _get_power_level_event(auth_events):
- key = (EventTypes.PowerLevels, "", )
- return auth_events.get(key)
+ return auth_events.get((EventTypes.PowerLevels, ""))
def get_user_power_level(user_id, auth_events):
- power_level_event = _get_power_level_event(auth_events)
+ """Get a user's power level
+
+ Args:
+ user_id (str): user's id to look up in power_levels
+ auth_events (dict[(str, str), synapse.events.EventBase]):
+ state in force at this point in the room (or rather, a subset of
+            it including at least the create event and power levels event).
+ Returns:
+ int: the user's power level in this room.
+ """
+ power_level_event = _get_power_level_event(auth_events)
if power_level_event:
level = power_level_event.content.get("users", {}).get(user_id)
if not level:
@@ -541,6 +564,11 @@ def get_user_power_level(user_id, auth_events):
else:
return int(level)
else:
+ # if there is no power levels event, the creator gets 100 and everyone
+ # else gets 0.
+
+ # some things which call this don't pass the create event: hack around
+ # that.
key = (EventTypes.Create, "", )
create_event = auth_events.get(key)
if (create_event is not None and
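
To make the fallback order in get_send_level concrete, a sketch with
hypothetical power-levels content (get_send_level only reads `.content`, so
a stand-in object suffices):

    from synapse.event_auth import get_send_level

    class FakePowerLevelsEvent(object):
        content = {
            "events": {"m.room.name": 75},  # per-event-type override
            "state_default": 50,
            "events_default": 0,
        }

    pl = FakePowerLevelsEvent()
    assert get_send_level("m.room.name", "", pl) == 75       # explicit override
    assert get_send_level("m.room.topic", "", pl) == 50      # state event default
    assert get_send_level("m.room.message", None, pl) == 0   # non-state default
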
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index e673e96cc0..51f9084b90 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -13,9 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.frozenutils import freeze
from synapse.util.caches import intern_dict
-
+from synapse.util.frozenutils import freeze
# Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
# bugs where we accidentally share e.g. signature dicts. However, converting
@@ -47,14 +46,26 @@ class _EventInternalMetadata(object):
def _event_dict_property(key):
+ # We want to be able to use hasattr with the event dict properties.
+    # However, on python3, hasattr only swallows AttributeError. Hence,
+ # we need to transform the KeyError into an AttributeError
def getter(self):
- return self._event_dict[key]
+ try:
+ return self._event_dict[key]
+ except KeyError:
+ raise AttributeError(key)
def setter(self, v):
- self._event_dict[key] = v
+ try:
+ self._event_dict[key] = v
+ except KeyError:
+ raise AttributeError(key)
def delete(self):
- del self._event_dict[key]
+ try:
+ del self._event_dict[key]
+ except KeyError:
+ raise AttributeError(key)
return property(
getter,
@@ -134,7 +145,7 @@ class EventBase(object):
return field in self._event_dict
def items(self):
- return self._event_dict.items()
+ return list(self._event_dict.items())
class FrozenEvent(EventBase):
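
A self-contained sketch of the pattern introduced above: translating
KeyError into AttributeError is what keeps hasattr() usable on Python 3.

    def dict_property(key):
        def getter(self):
            try:
                return self._d[key]
            except KeyError:
                # hasattr() only treats AttributeError as "attribute missing"
                raise AttributeError(key)
        return property(getter)

    class Example(object):
        type = dict_property("type")
        state_key = dict_property("state_key")

        def __init__(self, d):
            self._d = d

    e = Example({"type": "m.room.message"})
    assert hasattr(e, "type")
    assert not hasattr(e, "state_key")  # KeyError surfaced as AttributeError
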
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 365fd96bd2..e662eaef10 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -13,13 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from . import EventBase, FrozenEvent, _event_dict_property
+import copy
from synapse.types import EventID
-
from synapse.util.stringutils import random_string
-import copy
+from . import EventBase, FrozenEvent, _event_dict_property
class EventBuilder(EventBase):
@@ -55,7 +54,7 @@ class EventBuilderFactory(object):
local_part = str(int(self.clock.time())) + i + random_string(5)
- e_id = EventID.create(local_part, self.hostname)
+ e_id = EventID(local_part, self.hostname)
return e_id.to_string()
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index e9a732ff03..368b5f6ae4 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -13,19 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from six import iteritems
+
+from frozendict import frozendict
+
+from twisted.internet import defer
+
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+
class EventContext(object):
"""
Attributes:
- current_state_ids (dict[(str, str), str]):
- The current state map including the current event.
- (type, state_key) -> event_id
-
- prev_state_ids (dict[(str, str), str]):
- The current state map excluding the current event.
- (type, state_key) -> event_id
-
- state_group (int): state group id
+ state_group (int|None): state group id, if the state has been stored
+ as a state group. This is usually only None if e.g. the event is
+ an outlier.
rejected (bool|str): A rejection reason if the event was rejected, else
False
@@ -39,35 +41,250 @@ class EventContext(object):
prev_state_events (?): XXX: is this ever set to anything other than
the empty list?
+
+ _current_state_ids (dict[(str, str), str]|None):
+ The current state map including the current event. None if outlier
+ or we haven't fetched the state from DB yet.
+ (type, state_key) -> event_id
+
+ _prev_state_ids (dict[(str, str), str]|None):
+ The current state map excluding the current event. None if outlier
+ or we haven't fetched the state from DB yet.
+ (type, state_key) -> event_id
+
+ _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
+ been calculated. None if we haven't started calculating yet
+
+ _event_type (str): The type of the event the context is associated with.
+ Only set when state has not been fetched yet.
+
+ _event_state_key (str|None): The state_key of the event the context is
+ associated with. Only set when state has not been fetched yet.
+
+ _prev_state_id (str|None): If the event associated with the context is
+ a state event, then `_prev_state_id` is the event_id of the state
+ that was replaced.
+ Only set when state has not been fetched yet.
"""
__slots__ = [
- "current_state_ids",
- "prev_state_ids",
"state_group",
"rejected",
- "push_actions",
"prev_group",
"delta_ids",
"prev_state_events",
"app_service",
+ "_current_state_ids",
+ "_prev_state_ids",
+ "_prev_state_id",
+ "_event_type",
+ "_event_state_key",
+ "_fetching_state_deferred",
]
def __init__(self):
+ self.prev_state_events = []
+ self.rejected = False
+ self.app_service = None
+
+ @staticmethod
+ def with_state(state_group, current_state_ids, prev_state_ids,
+ prev_group=None, delta_ids=None):
+ context = EventContext()
+
# The current state including the current event
- self.current_state_ids = None
+ context._current_state_ids = current_state_ids
# The current state excluding the current event
- self.prev_state_ids = None
- self.state_group = None
+ context._prev_state_ids = prev_state_ids
+ context.state_group = state_group
- self.rejected = False
- self.push_actions = []
+ context._prev_state_id = None
+ context._event_type = None
+ context._event_state_key = None
+ context._fetching_state_deferred = defer.succeed(None)
# A previously persisted state group and a delta between that
# and this state.
- self.prev_group = None
- self.delta_ids = None
+ context.prev_group = prev_group
+ context.delta_ids = delta_ids
- self.prev_state_events = None
+ return context
- self.app_service = None
+ @defer.inlineCallbacks
+ def serialize(self, event, store):
+ """Converts self to a type that can be serialized as JSON, and then
+ deserialized by `deserialize`
+
+ Args:
+            event (FrozenEvent): The event that this context relates to
+            store (DataStore): Used to fetch the prev_state_ids of state
+                events
+
+ Returns:
+ dict
+ """
+
+ # We don't serialize the full state dicts, instead they get pulled out
+ # of the DB on the other side. However, the other side can't figure out
+ # the prev_state_ids, so if we're a state event we include the event
+ # id that we replaced in the state.
+ if event.is_state():
+ prev_state_ids = yield self.get_prev_state_ids(store)
+ prev_state_id = prev_state_ids.get((event.type, event.state_key))
+ else:
+ prev_state_id = None
+
+ defer.returnValue({
+ "prev_state_id": prev_state_id,
+ "event_type": event.type,
+ "event_state_key": event.state_key if event.is_state() else None,
+ "state_group": self.state_group,
+ "rejected": self.rejected,
+ "prev_group": self.prev_group,
+ "delta_ids": _encode_state_dict(self.delta_ids),
+ "prev_state_events": self.prev_state_events,
+ "app_service_id": self.app_service.id if self.app_service else None
+ })
+
+ @staticmethod
+ def deserialize(store, input):
+ """Converts a dict that was produced by `serialize` back into a
+ EventContext.
+
+ Args:
+ store (DataStore): Used to convert AS ID to AS object
+ input (dict): A dict produced by `serialize`
+
+ Returns:
+ EventContext
+ """
+ context = EventContext()
+
+ # We use the state_group and prev_state_id stuff to pull the
+ # current_state_ids out of the DB and construct prev_state_ids.
+ context._prev_state_id = input["prev_state_id"]
+ context._event_type = input["event_type"]
+ context._event_state_key = input["event_state_key"]
+
+ context._current_state_ids = None
+ context._prev_state_ids = None
+ context._fetching_state_deferred = None
+
+ context.state_group = input["state_group"]
+ context.prev_group = input["prev_group"]
+ context.delta_ids = _decode_state_dict(input["delta_ids"])
+
+ context.rejected = input["rejected"]
+ context.prev_state_events = input["prev_state_events"]
+
+ app_service_id = input["app_service_id"]
+ if app_service_id:
+ context.app_service = store.get_app_service_by_id(app_service_id)
+
+ return context
+
+ @defer.inlineCallbacks
+ def get_current_state_ids(self, store):
+ """Gets the current state IDs
+
+ Returns:
+ Deferred[dict[(str, str), str]|None]: Returns None if state_group
+ is None, which happens when the associated event is an outlier.
+ """
+
+ if not self._fetching_state_deferred:
+ self._fetching_state_deferred = run_in_background(
+ self._fill_out_state, store,
+ )
+
+ yield make_deferred_yieldable(self._fetching_state_deferred)
+
+ defer.returnValue(self._current_state_ids)
+
+ @defer.inlineCallbacks
+ def get_prev_state_ids(self, store):
+ """Gets the prev state IDs
+
+ Returns:
+ Deferred[dict[(str, str), str]|None]: Returns None if state_group
+ is None, which happens when the associated event is an outlier.
+ """
+
+ if not self._fetching_state_deferred:
+ self._fetching_state_deferred = run_in_background(
+ self._fill_out_state, store,
+ )
+
+ yield make_deferred_yieldable(self._fetching_state_deferred)
+
+ defer.returnValue(self._prev_state_ids)
+
+ def get_cached_current_state_ids(self):
+ """Gets the current state IDs if we have them already cached.
+
+ Returns:
+ dict[(str, str), str]|None: Returns None if we haven't cached the
+ state or if state_group is None, which happens when the associated
+ event is an outlier.
+ """
+
+ return self._current_state_ids
+
+ @defer.inlineCallbacks
+ def _fill_out_state(self, store):
+ """Called to populate the _current_state_ids and _prev_state_ids
+ attributes by loading from the database.
+ """
+ if self.state_group is None:
+ return
+
+ self._current_state_ids = yield store.get_state_ids_for_group(
+ self.state_group,
+ )
+ if self._prev_state_id and self._event_state_key is not None:
+ self._prev_state_ids = dict(self._current_state_ids)
+
+ key = (self._event_type, self._event_state_key)
+ self._prev_state_ids[key] = self._prev_state_id
+ else:
+ self._prev_state_ids = self._current_state_ids
+
+ @defer.inlineCallbacks
+ def update_state(self, state_group, prev_state_ids, current_state_ids,
+ prev_group, delta_ids):
+ """Replace the state in the context
+ """
+
+ # We need to make sure we wait for any ongoing fetching of state
+ # to complete so that the updated state doesn't get clobbered
+ if self._fetching_state_deferred:
+ yield make_deferred_yieldable(self._fetching_state_deferred)
+
+ self.state_group = state_group
+ self._prev_state_ids = prev_state_ids
+ self.prev_group = prev_group
+ self._current_state_ids = current_state_ids
+ self.delta_ids = delta_ids
+
+        # We need to ensure that we've marked the state as having been fetched.
+ self._fetching_state_deferred = defer.succeed(None)
+
+
+def _encode_state_dict(state_dict):
+ """Since dicts of (type, state_key) -> event_id cannot be serialized in
+ JSON we need to convert them to a form that can.
+ """
+ if state_dict is None:
+ return None
+
+ return [
+ (etype, state_key, v)
+ for (etype, state_key), v in iteritems(state_dict)
+ ]
+
+
+def _decode_state_dict(input):
+ """Decodes a state dict encoded using `_encode_state_dict` above
+ """
+ if input is None:
+ return None
+
+ return frozendict({(etype, state_key,): v for etype, state_key, v in input})
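
A round-trip sketch for the helpers above (event IDs hypothetical): tuple
keys cannot survive JSON, so the map is flattened into triples and rebuilt
as a frozendict on the way back in.

    state = {
        ("m.room.member", "@alice:example.com"): "$membership_event_id",
        ("m.room.name", ""): "$name_event_id",
    }
    encoded = _encode_state_dict(state)
    # -> [("m.room.member", "@alice:example.com", "$membership_event_id"),
    #     ("m.room.name", "", "$name_event_id")]   (in some order)
    assert dict(_decode_state_dict(encoded)) == state
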
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
new file mode 100644
index 0000000000..633e068eb8
--- /dev/null
+++ b/synapse/events/spamcheck.py
@@ -0,0 +1,113 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class SpamChecker(object):
+ def __init__(self, hs):
+ self.spam_checker = None
+
+ module = None
+ config = None
+ try:
+ module, config = hs.config.spam_checker
+ except Exception:
+ pass
+
+ if module is not None:
+ self.spam_checker = module(config=config)
+
+ def check_event_for_spam(self, event):
+ """Checks if a given event is considered "spammy" by this server.
+
+ If the server considers an event spammy, then it will be rejected if
+ sent by a local user. If it is sent by a user on another server, then
+ users receive a blank event.
+
+ Args:
+ event (synapse.events.EventBase): the event to be checked
+
+ Returns:
+ bool: True if the event is spammy.
+ """
+ if self.spam_checker is None:
+ return False
+
+ return self.spam_checker.check_event_for_spam(event)
+
+ def user_may_invite(self, inviter_userid, invitee_userid, room_id):
+ """Checks if a given user may send an invite
+
+ If this method returns false, the invite will be rejected.
+
+ Args:
+            inviter_userid (str): The user ID of the sender of the invitation
+            invitee_userid (str): The user ID targeted in the invitation
+            room_id (str): The ID of the room the invite is for
+
+ Returns:
+ bool: True if the user may send an invite, otherwise False
+ """
+ if self.spam_checker is None:
+ return True
+
+ return self.spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
+
+ def user_may_create_room(self, userid):
+ """Checks if a given user may create a room
+
+ If this method returns false, the creation request will be rejected.
+
+ Args:
+ userid (string): The sender's user ID
+
+ Returns:
+ bool: True if the user may create a room, otherwise False
+ """
+ if self.spam_checker is None:
+ return True
+
+ return self.spam_checker.user_may_create_room(userid)
+
+ def user_may_create_room_alias(self, userid, room_alias):
+ """Checks if a given user may create a room alias
+
+ If this method returns false, the association request will be rejected.
+
+ Args:
+ userid (string): The sender's user ID
+ room_alias (string): The alias to be created
+
+ Returns:
+ bool: True if the user may create a room alias, otherwise False
+ """
+ if self.spam_checker is None:
+ return True
+
+ return self.spam_checker.user_may_create_room_alias(userid, room_alias)
+
+ def user_may_publish_room(self, userid, room_id):
+ """Checks if a given user may publish a room to the directory
+
+ If this method returns false, the publish request will be rejected.
+
+ Args:
+ userid (string): The sender's user ID
+ room_id (string): The ID of the room that would be published
+
+ Returns:
+ bool: True if the user may publish the room, otherwise False
+ """
+ if self.spam_checker is None:
+ return True
+
+ return self.spam_checker.user_may_publish_room(userid, room_id)
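
As a sketch of a pluggable checker matching this interface (the class path
and options are hypothetical, echoing the commented-out spam_checker config
example earlier in this change):

    class SuperSpamChecker(object):
        def __init__(self, config):
            # `config` is the dict from the spam_checker config section.
            self.blocked_words = config.get("blocked_words", [])

        def check_event_for_spam(self, event):
            body = event.content.get("body", "")
            return any(word in body for word in self.blocked_words)

        def user_may_invite(self, inviter_userid, invitee_userid, room_id):
            return True

        def user_may_create_room(self, userid):
            return True

        def user_may_create_room_alias(self, userid, room_alias):
            return True

        def user_may_publish_room(self, userid, room_id):
            return True
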
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 824f4a42e3..652941ca0d 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -13,12 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.api.constants import EventTypes
-from . import EventBase
+import re
+
+from six import string_types
from frozendict import frozendict
-import re
+from synapse.api.constants import EventTypes
+
+from . import EventBase
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
# (?<!stuff) matches if the current position in the string is not preceded
@@ -277,7 +280,7 @@ def serialize_event(e, time_now_ms, as_client_event=True,
if only_event_fields:
if (not isinstance(only_event_fields, list) or
- not all(isinstance(f, basestring) for f in only_event_fields)):
+ not all(isinstance(f, string_types) for f in only_event_fields)):
raise TypeError("only_event_fields must be a list of strings")
d = only_fields(d, only_event_fields)
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 2f4c8a1018..cf184748a1 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -13,9 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.types import EventID, RoomID, UserID
-from synapse.api.errors import SynapseError
+from six import string_types
+
from synapse.api.constants import EventTypes, Membership
+from synapse.api.errors import SynapseError
+from synapse.types import EventID, RoomID, UserID
class EventValidator(object):
@@ -49,7 +51,7 @@ class EventValidator(object):
strings.append("state_key")
for s in strings:
- if not isinstance(getattr(event, s), basestring):
+ if not isinstance(getattr(event, s), string_types):
raise SynapseError(400, "Not '%s' a string type" % (s,))
if event.type == EventTypes.Member:
@@ -88,5 +90,5 @@ class EventValidator(object):
for s in keys:
if s not in d:
raise SynapseError(400, "'%s' not in content" % (s,))
- if not isinstance(d[s], basestring):
+ if not isinstance(d[s], string_types):
raise SynapseError(400, "Not '%s' a string type" % (s,))
diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py
index 2e32d245ba..f5f0bdfca3 100644
--- a/synapse/federation/__init__.py
+++ b/synapse/federation/__init__.py
@@ -15,11 +15,3 @@
""" This package includes all the federation specific logic.
"""
-
-from .replication import ReplicationLayer
-
-
-def initialize_http_replication(hs):
- transport = hs.get_federation_transport_client()
-
- return ReplicationLayer(hs, transport)
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 2339cc9034..c11798093d 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -12,28 +12,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import six
from twisted.internet import defer
-from synapse.events.utils import prune_event
-
+from synapse.api.constants import MAX_DEPTH
+from synapse.api.errors import Codes, SynapseError
from synapse.crypto.event_signing import check_event_content_hash
-
-from synapse.api.errors import SynapseError
-
-from synapse.util import unwrapFirstError
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-
-import logging
-
+from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
+from synapse.http.servlet import assert_params_in_dict
+from synapse.util import logcontext, unwrapFirstError
logger = logging.getLogger(__name__)
class FederationBase(object):
def __init__(self, hs):
- pass
+ self.hs = hs
+
+ self.server_name = hs.hostname
+ self.keyring = hs.get_keyring()
+ self.spam_checker = hs.get_spam_checker()
+ self.store = hs.get_datastore()
+ self._clock = hs.get_clock()
@defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
@@ -57,56 +61,52 @@ class FederationBase(object):
"""
deferreds = self._check_sigs_and_hashes(pdus)
- def callback(pdu):
- return pdu
+ @defer.inlineCallbacks
+ def handle_check_result(pdu, deferred):
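+ # fall back through three sources for a usable copy of the event:
+ # the signature-checked PDU itself, then our local database, then
+ # the event's origin server.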
+ try:
+ res = yield logcontext.make_deferred_yieldable(deferred)
+ except SynapseError:
+ res = None
- def errback(failure, pdu):
- failure.trap(SynapseError)
- return None
-
- def try_local_db(res, pdu):
if not res:
# Check local db.
- return self.store.get_event(
+ res = yield self.store.get_event(
pdu.event_id,
allow_rejected=True,
allow_none=True,
)
- return res
- def try_remote(res, pdu):
if not res and pdu.origin != origin:
- return self.get_pdu(
- destinations=[pdu.origin],
- event_id=pdu.event_id,
- outlier=outlier,
- timeout=10000,
- ).addErrback(lambda e: None)
- return res
-
- def warn(res, pdu):
+ try:
+ res = yield self.get_pdu(
+ destinations=[pdu.origin],
+ event_id=pdu.event_id,
+ outlier=outlier,
+ timeout=10000,
+ )
+ except SynapseError:
+ pass
+
if not res:
logger.warn(
"Failed to find copy of %s with valid signature",
pdu.event_id,
)
- return res
- for pdu, deferred in zip(pdus, deferreds):
- deferred.addCallbacks(
- callback, errback, errbackArgs=[pdu]
- ).addCallback(
- try_local_db, pdu
- ).addCallback(
- try_remote, pdu
- ).addCallback(
- warn, pdu
- )
+ defer.returnValue(res)
+
+ handle = logcontext.preserve_fn(handle_check_result)
+ deferreds2 = [
+ handle(pdu, deferred)
+ for pdu, deferred in zip(pdus, deferreds)
+ ]
- valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
- deferreds,
- consumeErrors=True
- )).addErrback(unwrapFirstError)
+ valid_pdus = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ deferreds2,
+ consumeErrors=True,
+ )
+ ).addErrback(unwrapFirstError)
if include_none:
defer.returnValue(valid_pdus)
@@ -114,15 +114,24 @@ class FederationBase(object):
defer.returnValue([p for p in valid_pdus if p])
def _check_sigs_and_hash(self, pdu):
- return self._check_sigs_and_hashes([pdu])[0]
+ return logcontext.make_deferred_yieldable(
+ self._check_sigs_and_hashes([pdu])[0],
+ )
def _check_sigs_and_hashes(self, pdus):
- """Throws a SynapseError if a PDU does not have the correct
- signatures.
+ """Checks that each of the received events is correctly signed by the
+ sending server.
+
+ Args:
+ pdus (list[FrozenEvent]): the events to be checked
Returns:
- FrozenEvent: Either the given event or it redacted if it failed the
- content hash check.
+ list[Deferred]: for each input event, a deferred which:
+ * returns the original event if the checks pass
+ * returns a redacted version of the event (if the signature
+ matched but the hash did not)
+ * throws a SynapseError if the signature check failed.
+ The deferreds run their callbacks in the sentinel logcontext.
"""
redacted_pdus = [
@@ -130,26 +139,38 @@ class FederationBase(object):
for pdu in pdus
]
- deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([
+ deferreds = self.keyring.verify_json_objects_for_server([
(p.origin, p.get_pdu_json())
for p in redacted_pdus
])
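+ # the keyring returns deferreds which run their callbacks in the
+ # sentinel logcontext; capture the caller's context here so the
+ # callbacks below can log into it.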
+ ctx = logcontext.LoggingContext.current_context()
+
def callback(_, pdu, redacted):
- if not check_event_content_hash(pdu):
- logger.warn(
- "Event content has been tampered, redacting %s: %s",
- pdu.event_id, pdu.get_pdu_json()
- )
- return redacted
- return pdu
+ with logcontext.PreserveLoggingContext(ctx):
+ if not check_event_content_hash(pdu):
+ logger.warn(
+ "Event content has been tampered, redacting %s: %s",
+ pdu.event_id, pdu.get_pdu_json()
+ )
+ return redacted
+
+ if self.spam_checker.check_event_for_spam(pdu):
+ logger.warn(
+ "Event contains spam, redacting %s: %s",
+ pdu.event_id, pdu.get_pdu_json()
+ )
+ return redacted
+
+ return pdu
def errback(failure, pdu):
failure.trap(SynapseError)
- logger.warn(
- "Signature check failed for %s",
- pdu.event_id,
- )
+ with logcontext.PreserveLoggingContext(ctx):
+ logger.warn(
+ "Signature check failed for %s",
+ pdu.event_id,
+ )
return failure
for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
@@ -160,3 +181,40 @@ class FederationBase(object):
)
return deferreds
+
+
+def event_from_pdu_json(pdu_json, outlier=False):
+ """Construct a FrozenEvent from an event json received over federation
+
+ Args:
+ pdu_json (object): pdu as received over federation
+ outlier (bool): True to mark this event as an outlier
+
+ Returns:
+ FrozenEvent
+
+ Raises:
+ SynapseError: if the pdu is missing required fields or is otherwise
+ not a valid matrix event
+ """
+ # we could probably enforce a bunch of other fields here (room_id, sender,
+ # origin, etc etc)
+ assert_params_in_dict(pdu_json, ('event_id', 'type', 'depth'))
+
+ depth = pdu_json['depth']
+ if not isinstance(depth, six.integer_types):
+ raise SynapseError(400, "Depth %r not an intger" % (depth, ),
+ Codes.BAD_JSON)
+
+ if depth < 0:
+ raise SynapseError(400, "Depth too small", Codes.BAD_JSON)
+ elif depth > MAX_DEPTH:
+ raise SynapseError(400, "Depth too large", Codes.BAD_JSON)
+
+ event = FrozenEvent(
+ pdu_json
+ )
+
+ event.internal_metadata.outlier = outlier
+
+ return event
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 861441708b..62d7ed13cf 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -14,36 +14,35 @@
# limitations under the License.
+import copy
+import itertools
+import logging
+import random
+
+from six.moves import range
+
+from prometheus_client import Counter
+
from twisted.internet import defer
-from .federation_base import FederationBase
from synapse.api.constants import Membership
-
from synapse.api.errors import (
- CodeMessageException, HttpResponseException, SynapseError,
+ CodeMessageException,
+ FederationDeniedError,
+ HttpResponseException,
+ SynapseError,
)
-from synapse.util import unwrapFirstError
+from synapse.events import builder
+from synapse.federation.federation_base import FederationBase, event_from_pdu_json
+from synapse.util import logcontext, unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.logutils import log_function
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-from synapse.events import FrozenEvent, builder
-import synapse.metrics
-
from synapse.util.retryutils import NotRetryingDestination
-import copy
-import itertools
-import logging
-import random
-
-
logger = logging.getLogger(__name__)
-
-# synapse.federation.federation_client is a silly name
-metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
-
-sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
+sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"])
PDU_RETRY_TIME_MS = 1 * 60 * 1000
@@ -58,6 +57,7 @@ class FederationClient(FederationBase):
self._clear_tried_cache, 60 * 1000,
)
self.state = hs.get_state_handler()
+ self.transport_layer = hs.get_federation_transport_client()
def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache"""
@@ -105,7 +105,7 @@ class FederationClient(FederationBase):
a Deferred which will eventually yield a JSON object from the
response
"""
- sent_queries_counter.inc(query_type)
+ sent_queries_counter.labels(query_type).inc()
return self.transport_layer.make_query(
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail,
@@ -124,7 +124,7 @@ class FederationClient(FederationBase):
a Deferred which will eventually yield a JSON object from the
response
"""
- sent_queries_counter.inc("client_device_keys")
+ sent_queries_counter.labels("client_device_keys").inc()
return self.transport_layer.query_client_keys(
destination, content, timeout
)
@@ -134,7 +134,7 @@ class FederationClient(FederationBase):
"""Query the device keys for a list of user ids hosted on a remote
server.
"""
- sent_queries_counter.inc("user_devices")
+ sent_queries_counter.labels("user_devices").inc()
return self.transport_layer.query_user_devices(
destination, user_id, timeout
)
@@ -151,7 +151,7 @@ class FederationClient(FederationBase):
a Deferred which will eventually yield a JSON object from the
response
"""
- sent_queries_counter.inc("client_one_time_keys")
+ sent_queries_counter.labels("client_one_time_keys").inc()
return self.transport_layer.claim_client_keys(
destination, content, timeout
)
@@ -184,15 +184,15 @@ class FederationClient(FederationBase):
logger.debug("backfill transaction_data=%s", repr(transaction_data))
pdus = [
- self.event_from_pdu_json(p, outlier=False)
+ event_from_pdu_json(p, outlier=False)
for p in transaction_data["pdus"]
]
# FIXME: We should handle signature failures more gracefully.
- pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
+ pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
self._check_sigs_and_hashes(pdus),
consumeErrors=True,
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
defer.returnValue(pdus)
@@ -244,7 +244,7 @@ class FederationClient(FederationBase):
logger.debug("transaction_data %r", transaction_data)
pdu_list = [
- self.event_from_pdu_json(p, outlier=outlier)
+ event_from_pdu_json(p, outlier=outlier)
for p in transaction_data["pdus"]
]
@@ -252,7 +252,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0]
# Check signatures are correct.
- signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
+ signed_pdu = yield self._check_sigs_and_hash(pdu)
break
@@ -266,6 +266,9 @@ class FederationClient(FederationBase):
except NotRetryingDestination as e:
logger.info(e.message)
continue
+ except FederationDeniedError as e:
+ logger.info(e.message)
+ continue
except Exception as e:
pdu_attempts[destination] = now
@@ -336,11 +339,11 @@ class FederationClient(FederationBase):
)
pdus = [
- self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
+ event_from_pdu_json(p, outlier=True) for p in result["pdus"]
]
auth_chain = [
- self.event_from_pdu_json(p, outlier=True)
+ event_from_pdu_json(p, outlier=True)
for p in result.get("auth_chain", [])
]
@@ -388,9 +391,9 @@ class FederationClient(FederationBase):
"""
if return_local:
seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
- signed_events = seen_events.values()
+ signed_events = list(seen_events.values())
else:
- seen_events = yield self.store.have_events(event_ids)
+ seen_events = yield self.store.have_seen_events(event_ids)
signed_events = []
failed_to_fetch = set()
@@ -409,18 +412,19 @@ class FederationClient(FederationBase):
batch_size = 20
missing_events = list(missing_events)
- for i in xrange(0, len(missing_events), batch_size):
+ for i in range(0, len(missing_events), batch_size):
batch = set(missing_events[i:i + batch_size])
deferreds = [
- preserve_fn(self.get_pdu)(
+ run_in_background(
+ self.get_pdu,
destinations=random_server_list(),
event_id=e_id,
)
for e_id in batch
]
- res = yield preserve_context_over_deferred(
+ res = yield make_deferred_yieldable(
defer.DeferredList(deferreds, consumeErrors=True)
)
for success, result in res:
@@ -441,7 +445,7 @@ class FederationClient(FederationBase):
)
auth_chain = [
- self.event_from_pdu_json(p, outlier=True)
+ event_from_pdu_json(p, outlier=True)
for p in res["auth_chain"]
]
@@ -570,12 +574,12 @@ class FederationClient(FederationBase):
logger.debug("Got content: %s", content)
state = [
- self.event_from_pdu_json(p, outlier=True)
+ event_from_pdu_json(p, outlier=True)
for p in content.get("state", [])
]
auth_chain = [
- self.event_from_pdu_json(p, outlier=True)
+ event_from_pdu_json(p, outlier=True)
for p in content.get("auth_chain", [])
]
@@ -585,7 +589,7 @@ class FederationClient(FederationBase):
}
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
- destination, pdus.values(),
+ destination, list(pdus.values()),
outlier=True,
)
@@ -650,7 +654,7 @@ class FederationClient(FederationBase):
logger.debug("Got response to send_invite: %s", pdu_dict)
- pdu = self.event_from_pdu_json(pdu_dict)
+ pdu = event_from_pdu_json(pdu_dict)
# Check signatures are correct.
pdu = yield self._check_sigs_and_hash(pdu)
@@ -740,7 +744,7 @@ class FederationClient(FederationBase):
)
auth_chain = [
- self.event_from_pdu_json(e)
+ event_from_pdu_json(e)
for e in content["auth_chain"]
]
@@ -788,7 +792,7 @@ class FederationClient(FederationBase):
)
events = [
- self.event_from_pdu_json(e)
+ event_from_pdu_json(e)
for e in content.get("events", [])
]
@@ -805,15 +809,6 @@ class FederationClient(FederationBase):
defer.returnValue(signed_events)
- def event_from_pdu_json(self, pdu_json, outlier=False):
- event = FrozenEvent(
- pdu_json
- )
-
- event.internal_metadata.outlier = outlier
-
- return event
-
@defer.inlineCallbacks
def forward_third_party_invite(self, destinations, room_id, event_dict):
for destination in destinations:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 51e3fdea06..e501251b6e 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,92 +13,72 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import re
+import six
+from six import iteritems
-from twisted.internet import defer
-
-from .federation_base import FederationBase
-from .units import Transaction, Edu
-
-from synapse.util.async import Linearizer
-from synapse.util.logutils import log_function
-from synapse.util.caches.response_cache import ResponseCache
-from synapse.events import FrozenEvent
-from synapse.types import get_domain_from_id
-import synapse.metrics
+from canonicaljson import json
+from prometheus_client import Counter
-from synapse.api.errors import AuthError, FederationError, SynapseError
+from twisted.internet import defer
+from twisted.internet.abstract import isIPAddress
+from twisted.python import failure
+from synapse.api.constants import EventTypes
+from synapse.api.errors import AuthError, FederationError, NotFoundError, SynapseError
from synapse.crypto.event_signing import compute_event_signature
+from synapse.federation.federation_base import FederationBase, event_from_pdu_json
+from synapse.federation.persistence import TransactionActions
+from synapse.federation.units import Edu, Transaction
+from synapse.http.endpoint import parse_server_name
+from synapse.types import get_domain_from_id
+from synapse.util import async
+from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.logutils import log_function
-import simplejson as json
-import logging
-
+# when processing incoming transactions, we try to handle multiple rooms in
+# parallel, up to this limit.
+TRANSACTION_CONCURRENCY_LIMIT = 10
logger = logging.getLogger(__name__)
-# synapse.federation.federation_server is a silly name
-metrics = synapse.metrics.get_metrics_for("synapse.federation.server")
-
-received_pdus_counter = metrics.register_counter("received_pdus")
+received_pdus_counter = Counter("synapse_federation_server_received_pdus", "")
-received_edus_counter = metrics.register_counter("received_edus")
+received_edus_counter = Counter("synapse_federation_server_received_edus", "")
-received_queries_counter = metrics.register_counter("received_queries", labels=["type"])
+received_queries_counter = Counter(
+ "synapse_federation_server_received_queries", "", ["type"]
+)
class FederationServer(FederationBase):
+
def __init__(self, hs):
super(FederationServer, self).__init__(hs)
self.auth = hs.get_auth()
+ self.handler = hs.get_handlers().federation_handler
- self._server_linearizer = Linearizer("fed_server")
+ self._server_linearizer = async.Linearizer("fed_server")
+ self._transaction_linearizer = async.Linearizer("fed_txn_handler")
- # We cache responses to state queries, as they take a while and often
- # come in waves.
- self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
+ self.transaction_actions = TransactionActions(self.store)
- def set_handler(self, handler):
- """Sets the handler that the replication layer will use to communicate
- receipt of new PDUs from other home servers. The required methods are
- documented on :py:class:`.ReplicationHandler`.
- """
- self.handler = handler
+ self.registry = hs.get_federation_registry()
- def register_edu_handler(self, edu_type, handler):
- if edu_type in self.edu_handlers:
- raise KeyError("Already have an EDU handler for %s" % (edu_type,))
-
- self.edu_handlers[edu_type] = handler
-
- def register_query_handler(self, query_type, handler):
- """Sets the handler callable that will be used to handle an incoming
- federation Query of the given type.
-
- Args:
- query_type (str): Category name of the query, which should match
- the string used by make_query.
- handler (callable): Invoked to handle incoming queries of this type
-
- handler is invoked as:
- result = handler(args)
-
- where 'args' is a dict mapping strings to strings of the query
- arguments. It should return a Deferred that will eventually yield an
- object to encode as JSON.
- """
- if query_type in self.query_handlers:
- raise KeyError(
- "Already have a Query handler for %s" % (query_type,)
- )
-
- self.query_handlers[query_type] = handler
+ # We cache responses to state queries, as they take a while and often
+ # come in waves.
+ self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
@defer.inlineCallbacks
@log_function
def on_backfill_request(self, origin, room_id, versions, limit):
with (yield self._server_linearizer.queue((origin, room_id))):
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
pdus = yield self.handler.on_backfill_request(
origin, room_id, versions, limit
)
@@ -109,25 +90,41 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
@log_function
def on_incoming_transaction(self, transaction_data):
+ # keep this as early as possible to make the calculated origin ts as
+ # accurate as possible.
+ request_time = self._clock.time_msec()
+
transaction = Transaction(**transaction_data)
- received_pdus_counter.inc_by(len(transaction.pdus))
+ if not transaction.transaction_id:
+ raise Exception("Transaction missing transaction_id")
+ if not transaction.origin:
+ raise Exception("Transaction missing origin")
- for p in transaction.pdus:
- if "unsigned" in p:
- unsigned = p["unsigned"]
- if "age" in unsigned:
- p["age"] = unsigned["age"]
- if "age" in p:
- p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
- del p["age"]
+ logger.debug("[%s] Got transaction", transaction.transaction_id)
- pdu_list = [
- self.event_from_pdu_json(p) for p in transaction.pdus
- ]
+ # use a linearizer to ensure that we don't process the same transaction
+ # multiple times in parallel.
+ with (yield self._transaction_linearizer.queue(
+ (transaction.origin, transaction.transaction_id),
+ )):
+ result = yield self._handle_incoming_transaction(
+ transaction, request_time,
+ )
- logger.debug("[%s] Got transaction", transaction.transaction_id)
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def _handle_incoming_transaction(self, transaction, request_time):
+ """ Process an incoming transaction and return the HTTP response
+
+ Args:
+ transaction (Transaction): incoming transaction
+ request_time (int): timestamp that the HTTP request arrived at
+ Returns:
+ Deferred[(int, object)]: http response code and body
+ """
response = yield self.transaction_actions.have_responded(transaction)
if response:
@@ -140,42 +137,67 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id)
- results = []
-
- for pdu in pdu_list:
- # check that it's actually being sent from a valid destination to
- # workaround bug #1753 in 0.18.5 and 0.18.6
- if transaction.origin != get_domain_from_id(pdu.event_id):
- # We continue to accept join events from any server; this is
- # necessary for the federation join dance to work correctly.
- # (When we join over federation, the "helper" server is
- # responsible for sending out the join event, rather than the
- # origin. See bug #1893).
- if not (
- pdu.type == 'm.room.member' and
- pdu.content and
- pdu.content.get("membership", None) == 'join'
- ):
- logger.info(
- "Discarding PDU %s from invalid origin %s",
- pdu.event_id, transaction.origin
+ received_pdus_counter.inc(len(transaction.pdus))
+
+ origin_host, _ = parse_server_name(transaction.origin)
+
+ pdus_by_room = {}
+
+ for p in transaction.pdus:
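+ # fold the sender-relative "age" (ms since the event was sent)
+ # into an absolute "age_ts" timestamp, so it stays accurate no
+ # matter how long the PDU waits in our queues.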
+ if "unsigned" in p:
+ unsigned = p["unsigned"]
+ if "age" in unsigned:
+ p["age"] = unsigned["age"]
+ if "age" in p:
+ p["age_ts"] = request_time - int(p["age"])
+ del p["age"]
+
+ event = event_from_pdu_json(p)
+ room_id = event.room_id
+ pdus_by_room.setdefault(room_id, []).append(event)
+
+ pdu_results = {}
+
+ # we can process different rooms in parallel (which is useful if they
+ # require callouts to other servers to fetch missing events), but
+ # impose a limit to avoid going too crazy with ram/cpu.
+
+ @defer.inlineCallbacks
+ def process_pdus_for_room(room_id):
+ logger.debug("Processing PDUs for %s", room_id)
+ try:
+ yield self.check_server_matches_acl(origin_host, room_id)
+ except AuthError as e:
+ logger.warn(
+ "Ignoring PDUs for room %s from banned server", room_id,
+ )
+ for pdu in pdus_by_room[room_id]:
+ event_id = pdu.event_id
+ pdu_results[event_id] = e.error_dict()
+ return
+
+ for pdu in pdus_by_room[room_id]:
+ event_id = pdu.event_id
+ try:
+ yield self._handle_received_pdu(
+ transaction.origin, pdu
)
- continue
- else:
- logger.info(
- "Accepting join PDU %s from %s",
- pdu.event_id, transaction.origin
+ pdu_results[event_id] = {}
+ except FederationError as e:
+ logger.warn("Error handling PDU %s: %s", event_id, e)
+ pdu_results[event_id] = {"error": str(e)}
+ except Exception as e:
+ f = failure.Failure()
+ pdu_results[event_id] = {"error": str(e)}
+ logger.error(
+ "Failed to handle PDU %s: %s",
+ event_id, f.getTraceback().rstrip(),
)
- try:
- yield self._handle_received_pdu(transaction.origin, pdu)
- results.append({})
- except FederationError as e:
- self.send_failure(e, transaction.origin)
- results.append({"error": str(e)})
- except Exception as e:
- results.append({"error": str(e)})
- logger.exception("Failed to handle PDU")
+ yield async.concurrently_execute(
+ process_pdus_for_room, pdus_by_room.keys(),
+ TRANSACTION_CONCURRENCY_LIMIT,
+ )
if hasattr(transaction, "edus"):
for edu in (Edu(**x) for x in transaction.edus):
@@ -185,17 +207,16 @@ class FederationServer(FederationBase):
edu.content
)
- for failure in getattr(transaction, "pdu_failures", []):
- logger.info("Got failure %r", failure)
-
- logger.debug("Returning: %s", str(results))
+ pdu_failures = getattr(transaction, "pdu_failures", [])
+ for fail in pdu_failures:
+ logger.info("Got failure %r", fail)
response = {
- "pdus": dict(zip(
- (p.event_id for p in pdu_list), results
- )),
+ "pdus": pdu_results,
}
+ logger.debug("Returning: %s", str(response))
+
yield self.transaction_actions.set_response(
transaction,
200, response
@@ -205,16 +226,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def received_edu(self, origin, edu_type, content):
received_edus_counter.inc()
-
- if edu_type in self.edu_handlers:
- try:
- yield self.edu_handlers[edu_type](origin, content)
- except SynapseError as e:
- logger.info("Failed to handle edu %r: %r", edu_type, e)
- except Exception as e:
- logger.exception("Failed to handle edu %r", edu_type)
- else:
- logger.warn("Received EDU of type %s with no handler", edu_type)
+ yield self.registry.on_edu(edu_type, origin, content)
@defer.inlineCallbacks
@log_function
@@ -222,19 +234,24 @@ class FederationServer(FederationBase):
if not event_id:
raise NotImplementedError("Specify an event")
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
- result = self._state_resp_cache.get((room_id, event_id))
- if not result:
- with (yield self._server_linearizer.queue((origin, room_id))):
- resp = yield self._state_resp_cache.set(
- (room_id, event_id),
- self._on_context_state_request_compute(room_id, event_id)
- )
- else:
- resp = yield result
+ # we grab the linearizer to protect ourselves from servers which hammer
+ # us. In theory we might already have the response to this query
+ # in the cache so we could return it without waiting for the linearizer
+ # - but that's non-trivial to get right, and anyway somewhat defeats
+ # the point of the linearizer.
+ with (yield self._server_linearizer.queue((origin, room_id))):
+ resp = yield self._state_resp_cache.wrap(
+ (room_id, event_id),
+ self._on_context_state_request_compute,
+ room_id, event_id,
+ )
defer.returnValue((200, resp))
@@ -243,6 +260,9 @@ class FederationServer(FederationBase):
if not event_id:
raise NotImplementedError("Specify an event")
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -286,7 +306,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
@log_function
def on_pdu_request(self, origin, event_id):
- pdu = yield self._get_persisted_pdu(origin, event_id)
+ pdu = yield self.handler.get_persisted_pdu(origin, event_id)
if pdu:
defer.returnValue(
@@ -302,25 +322,23 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_query_request(self, query_type, args):
- received_queries_counter.inc(query_type)
-
- if query_type in self.query_handlers:
- response = yield self.query_handlers[query_type](args)
- defer.returnValue((200, response))
- else:
- defer.returnValue(
- (404, "No handler for Query type '%s'" % (query_type,))
- )
+ received_queries_counter.labels(query_type).inc()
+ resp = yield self.registry.on_query(query_type, args)
+ defer.returnValue((200, resp))
@defer.inlineCallbacks
- def on_make_join_request(self, room_id, user_id):
+ def on_make_join_request(self, origin, room_id, user_id):
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
pdu = yield self.handler.on_make_join_request(room_id, user_id)
time_now = self._clock.time_msec()
defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@defer.inlineCallbacks
def on_invite_request(self, origin, content):
- pdu = self.event_from_pdu_json(content)
+ pdu = event_from_pdu_json(content)
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, pdu.room_id)
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
time_now = self._clock.time_msec()
defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
@@ -328,7 +346,11 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_send_join_request(self, origin, content):
logger.debug("on_send_join_request: content: %s", content)
- pdu = self.event_from_pdu_json(content)
+ pdu = event_from_pdu_json(content)
+
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, pdu.room_id)
+
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
@@ -340,7 +362,9 @@ class FederationServer(FederationBase):
}))
@defer.inlineCallbacks
- def on_make_leave_request(self, room_id, user_id):
+ def on_make_leave_request(self, origin, room_id, user_id):
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
pdu = yield self.handler.on_make_leave_request(room_id, user_id)
time_now = self._clock.time_msec()
defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@@ -348,7 +372,11 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_send_leave_request(self, origin, content):
logger.debug("on_send_leave_request: content: %s", content)
- pdu = self.event_from_pdu_json(content)
+ pdu = event_from_pdu_json(content)
+
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, pdu.room_id)
+
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
yield self.handler.on_send_leave_request(origin, pdu)
defer.returnValue((200, {}))
@@ -356,6 +384,9 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_event_auth(self, origin, room_id, event_id):
with (yield self._server_linearizer.queue((origin, room_id))):
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
time_now = self._clock.time_msec()
auth_pdus = yield self.handler.on_event_auth(event_id)
res = {
@@ -384,8 +415,11 @@ class FederationServer(FederationBase):
Deferred: Results in `dict` with the same format as `content`
"""
with (yield self._server_linearizer.queue((origin, room_id))):
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
auth_chain = [
- self.event_from_pdu_json(e)
+ event_from_pdu_json(e)
for e in content["auth_chain"]
]
@@ -444,9 +478,9 @@ class FederationServer(FederationBase):
"Claimed one-time-keys: %s",
",".join((
"%s for %s:%s" % (key_id, user_id, device_id)
- for user_id, user_keys in json_result.iteritems()
- for device_id, device_keys in user_keys.iteritems()
- for key_id, _ in device_keys.iteritems()
+ for user_id, user_keys in iteritems(json_result)
+ for device_id, device_keys in iteritems(user_keys)
+ for key_id, _ in iteritems(device_keys)
)),
)
@@ -457,6 +491,9 @@ class FederationServer(FederationBase):
def on_get_missing_events(self, origin, room_id, earliest_events,
latest_events, limit, min_depth):
with (yield self._server_linearizer.queue((origin, room_id))):
+ origin_host, _ = parse_server_name(origin)
+ yield self.check_server_matches_acl(origin_host, room_id)
+
logger.info(
"on_get_missing_events: earliest_events: %r, latest_events: %r,"
" limit: %d, min_depth: %d",
@@ -485,17 +522,6 @@ class FederationServer(FederationBase):
ts_now_ms = self._clock.time_msec()
return self.store.get_user_id_for_open_id_token(token, ts_now_ms)
- @log_function
- def _get_persisted_pdu(self, origin, event_id, do_auth=True):
- """ Get a PDU from the database with given origin and id.
-
- Returns:
- Deferred: Results in a `Pdu`.
- """
- return self.handler.get_persisted_pdu(
- origin, event_id, do_auth=do_auth
- )
-
def _transaction_from_pdus(self, pdu_list):
"""Returns a new Transaction containing the given PDUs suitable for
transmission.
@@ -513,13 +539,57 @@ class FederationServer(FederationBase):
def _handle_received_pdu(self, origin, pdu):
""" Process a PDU received in a federation /send/ transaction.
+ If the event is invalid, then this method throws a FederationError.
+ (The error will then be logged and sent back to the sender (which
+ probably won't do anything with it), and other events in the
+ transaction will be processed as normal).
+
+ It is likely that we'll then receive other events which refer to
+ this rejected_event in their prev_events, etc. When that happens,
+ we'll attempt to fetch the rejected event again, which will presumably
+ fail, so those second-generation events will also get rejected.
+
+ Eventually, we get to the point where there are more than 10 events
+ between any new events and the original rejected event. Since we
+ only try to backfill 10 events deep on a received PDU, we then accept the
+ new event, possibly introducing a discontinuity in the DAG, with new
+ forward extremities, so normal service is approximately returned,
+ until we try to backfill across the discontinuity.
+
Args:
origin (str): server which sent the pdu
pdu (FrozenEvent): received pdu
Returns (Deferred): completes with None
- Raises: FederationError if the signatures / hash do not match
- """
+
+ Raises: FederationError if the signatures / hash do not match, or
+ if the event was unacceptable for any other reason (eg, too large,
+ too many prev_events, couldn't find the prev_events)
+ """
+ # check that it's actually being sent from a valid destination to
+ # workaround bug #1753 in 0.18.5 and 0.18.6
+ if origin != get_domain_from_id(pdu.event_id):
+ # We continue to accept join events from any server; this is
+ # necessary for the federation join dance to work correctly.
+ # (When we join over federation, the "helper" server is
+ # responsible for sending out the join event, rather than the
+ # origin. See bug #1893).
+ if not (
+ pdu.type == 'm.room.member' and
+ pdu.content and
+ pdu.content.get("membership", None) == 'join'
+ ):
+ logger.info(
+ "Discarding PDU %s from invalid origin %s",
+ pdu.event_id, origin
+ )
+ return
+ else:
+ logger.info(
+ "Accepting join PDU %s from %s",
+ pdu.event_id, origin
+ )
+
# Check signature.
try:
pdu = yield self._check_sigs_and_hash(pdu)
@@ -531,20 +601,13 @@ class FederationServer(FederationBase):
affected=pdu.event_id,
)
- yield self.handler.on_receive_pdu(origin, pdu, get_missing=True)
+ yield self.handler.on_receive_pdu(
+ origin, pdu, get_missing=True, sent_to_us_directly=True,
+ )
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name
- def event_from_pdu_json(self, pdu_json, outlier=False):
- event = FrozenEvent(
- pdu_json
- )
-
- event.internal_metadata.outlier = outlier
-
- return event
-
@defer.inlineCallbacks
def exchange_third_party_invite(
self,
@@ -567,3 +630,161 @@ class FederationServer(FederationBase):
origin, room_id, event_dict
)
defer.returnValue(ret)
+
+ @defer.inlineCallbacks
+ def check_server_matches_acl(self, server_name, room_id):
+ """Check if the given server is allowed by the server ACLs in the room
+
+ Args:
+ server_name (str): name of server, *without any port part*
+ room_id (str): ID of the room to check
+
+ Raises:
+ AuthError if the server does not match the ACL
+ """
+ state_ids = yield self.store.get_current_state_ids(room_id)
+ acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
+
+ if not acl_event_id:
+ return
+
+ acl_event = yield self.store.get_event(acl_event_id)
+ if server_matches_acl_event(server_name, acl_event):
+ return
+
+ raise AuthError(code=403, msg="Server is banned from room")
+
+
+def server_matches_acl_event(server_name, acl_event):
+ """Check if the given server is allowed by the ACL event
+
+ Args:
+ server_name (str): name of server, without any port part
+ acl_event (EventBase): m.room.server_acl event
+
+ Returns:
+ bool: True if this server is allowed by the ACLs
+ """
+ logger.debug("Checking %s against acl %s", server_name, acl_event.content)
+
+ # first of all, check if literal IPs are blocked, and if so, whether the
+ # server name is a literal IP
+ allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
+ if not isinstance(allow_ip_literals, bool):
+ logger.warn("Ignorning non-bool allow_ip_literals flag")
+ allow_ip_literals = True
+ if not allow_ip_literals:
+ # check for ipv6 literals. These start with '['.
+ if server_name[0] == '[':
+ return False
+
+ # check for ipv4 literals. We can just lift the routine from twisted.
+ if isIPAddress(server_name):
+ return False
+
+ # next, check the deny list
+ deny = acl_event.content.get("deny", [])
+ if not isinstance(deny, (list, tuple)):
+ logger.warn("Ignorning non-list deny ACL %s", deny)
+ deny = []
+ for e in deny:
+ if _acl_entry_matches(server_name, e):
+ # logger.info("%s matched deny rule %s", server_name, e)
+ return False
+
+ # then the allow list.
+ allow = acl_event.content.get("allow", [])
+ if not isinstance(allow, (list, tuple)):
+ logger.warn("Ignorning non-list allow ACL %s", allow)
+ allow = []
+ for e in allow:
+ if _acl_entry_matches(server_name, e):
+ # logger.info("%s matched allow rule %s", server_name, e)
+ return True
+
+ # everything else should be rejected.
+ # logger.info("%s fell through", server_name)
+ return False
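+
+# illustrative example: given acl_event.content of
+#     {"allow_ip_literals": False, "deny": ["evil.example.com"],
+#      "allow": ["*.example.com"]},
+# "good.example.com" is allowed, "evil.example.com" is denied (the deny
+# list is checked before the allow list), and the IP literal "1.2.3.4"
+# is rejected outright.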
+
+
+def _acl_entry_matches(server_name, acl_entry):
+ if not isinstance(acl_entry, six.string_types):
+ logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry))
+ return False
+ regex = _glob_to_regex(acl_entry)
+ return regex.match(server_name)
+
+
+def _glob_to_regex(glob):
+ res = ''
+ for c in glob:
+ if c == '*':
+ res = res + '.*'
+ elif c == '?':
+ res = res + '.'
+ else:
+ res = res + re.escape(c)
+ return re.compile(res + "\\Z", re.IGNORECASE)
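+
+# e.g. _glob_to_regex("*.example.com") compiles to ".*\.example\.com\Z"
+# (case-insensitive), which matches "a.example.com" but not
+# "aexample.com" or "a.example.com.evil.org".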
+
+
+class FederationHandlerRegistry(object):
+ """Allows classes to register themselves as handlers for a given EDU or
+ query type for incoming federation traffic.
+ """
+ def __init__(self):
+ self.edu_handlers = {}
+ self.query_handlers = {}
+
+ def register_edu_handler(self, edu_type, handler):
+ """Sets the handler callable that will be used to handle an incoming
+ federation EDU of the given type.
+
+ Args:
+ edu_type (str): The type of the incoming EDU to register handler for
+ handler (Callable[[str, dict]]): A callable invoked on incoming EDU
+ of the given type. The arguments are the origin server name and
+ the EDU contents.
+ """
+ if edu_type in self.edu_handlers:
+ raise KeyError("Already have an EDU handler for %s" % (edu_type,))
+
+ self.edu_handlers[edu_type] = handler
+
+ def register_query_handler(self, query_type, handler):
+ """Sets the handler callable that will be used to handle an incoming
+ federation query of the given type.
+
+ Args:
+ query_type (str): Category name of the query, which should match
+ the string used by make_query.
+ handler (Callable[[dict], Deferred[dict]]): Invoked to handle
+ incoming queries of this type. The return will be yielded
+ on and the result used as the response to the query request.
+ """
+ if query_type in self.query_handlers:
+ raise KeyError(
+ "Already have a Query handler for %s" % (query_type,)
+ )
+
+ self.query_handlers[query_type] = handler
+
+ @defer.inlineCallbacks
+ def on_edu(self, edu_type, origin, content):
+ handler = self.edu_handlers.get(edu_type)
+ if not handler:
+ logger.warn("No handler registered for EDU type %s", edu_type)
+
+ try:
+ yield handler(origin, content)
+ except SynapseError as e:
+ logger.info("Failed to handle edu %r: %r", edu_type, e)
+ except Exception as e:
+ logger.exception("Failed to handle edu %r", edu_type)
+
+ def on_query(self, query_type, args):
+ handler = self.query_handlers.get(query_type)
+ if not handler:
+ logger.warn("No handler registered for query type %s", query_type)
+ raise NotFoundError("No handler for Query type '%s'" % (query_type,))
+
+ return handler(args)
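+
+
+# illustrative usage (handler names here are hypothetical; real handlers
+# register themselves on hs.get_federation_registry()):
+#
+#   registry = hs.get_federation_registry()
+#   registry.register_edu_handler("m.presence", self._on_presence_edu)
+#   registry.register_query_handler("profile", self.on_profile_query)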
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 84dc606673..9146215c21 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -19,13 +19,12 @@ package.
These actions are mostly only used by the :py:mod:`.replication` module.
"""
+import logging
+
from twisted.internet import defer
from synapse.util.logutils import log_function
-import logging
-
-
logger = logging.getLogger(__name__)
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
deleted file mode 100644
index 62d865ec4b..0000000000
--- a/synapse/federation/replication.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""This layer is responsible for replicating with remote home servers using
-a given transport.
-"""
-
-from .federation_client import FederationClient
-from .federation_server import FederationServer
-
-from .persistence import TransactionActions
-
-import logging
-
-
-logger = logging.getLogger(__name__)
-
-
-class ReplicationLayer(FederationClient, FederationServer):
- """This layer is responsible for replicating with remote home servers over
- the given transport. I.e., does the sending and receiving of PDUs to
- remote home servers.
-
- The layer communicates with the rest of the server via a registered
- ReplicationHandler.
-
- In more detail, the layer:
- * Receives incoming data and processes it into transactions and pdus.
- * Fetches any PDUs it thinks it might have missed.
- * Keeps the current state for contexts up to date by applying the
- suitable conflict resolution.
- * Sends outgoing pdus wrapped in transactions.
- * Fills out the references to previous pdus/transactions appropriately
- for outgoing data.
- """
-
- def __init__(self, hs, transport_layer):
- self.server_name = hs.hostname
-
- self.keyring = hs.get_keyring()
-
- self.transport_layer = transport_layer
-
- self.federation_client = self
-
- self.store = hs.get_datastore()
-
- self.handler = None
- self.edu_handlers = {}
- self.query_handlers = {}
-
- self._clock = hs.get_clock()
-
- self.transaction_actions = TransactionActions(self.store)
-
- self.hs = hs
-
- super(ReplicationLayer, self).__init__(hs)
-
- def __str__(self):
- return "<ReplicationLayer(%s)>" % self.server_name
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 93e5acebc1..5157c3860d 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -29,23 +29,22 @@ dead worker doesn't cause the queues to grow limitlessly.
Events are replicated via a separate events stream.
"""
-from .units import Edu
+import logging
+from collections import namedtuple
+from six import iteritems, itervalues
+
+from sortedcontainers import SortedDict
+
+from synapse.metrics import LaterGauge
from synapse.storage.presence import UserPresenceState
from synapse.util.metrics import Measure
-import synapse.metrics
-
-from blist import sorteddict
-from collections import namedtuple
-import logging
+from .units import Edu
logger = logging.getLogger(__name__)
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-
class FederationRemoteSendQueue(object):
"""A drop in replacement for TransactionQueue"""
@@ -56,29 +55,27 @@ class FederationRemoteSendQueue(object):
self.is_mine_id = hs.is_mine_id
self.presence_map = {} # Pending presence map user_id -> UserPresenceState
- self.presence_changed = sorteddict() # Stream position -> user_id
+ self.presence_changed = SortedDict() # Stream position -> user_id
self.keyed_edu = {} # (destination, key) -> EDU
- self.keyed_edu_changed = sorteddict() # stream position -> (destination, key)
+ self.keyed_edu_changed = SortedDict() # stream position -> (destination, key)
- self.edus = sorteddict() # stream position -> Edu
+ self.edus = SortedDict() # stream position -> Edu
- self.failures = sorteddict() # stream position -> (destination, Failure)
+ self.failures = SortedDict() # stream position -> (destination, Failure)
- self.device_messages = sorteddict() # stream position -> destination
+ self.device_messages = SortedDict() # stream position -> destination
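+
+ # (the SortedDicts above keep their keys ordered, so the bisect_left /
+ # bisect_right calls below can run on them directly; blist.sorteddict
+ # exposed bisection on its keys() view instead.)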
self.pos = 1
- self.pos_time = sorteddict()
+ self.pos_time = SortedDict()
# EVERYTHING IS SAD. In particular, python only makes new scopes when
# we make a new function, so we need to make a new function so the inner
# lambda binds to the queue rather than to the name of the queue which
# changes. ARGH.
def register(name, queue):
- metrics.register_callback(
- queue_name + "_size",
- lambda: len(queue),
- )
+ LaterGauge("synapse_federation_send_queue_%s_size" % (queue_name,),
+ "", [], lambda: len(queue))
for queue_name in [
"presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
@@ -101,7 +98,7 @@ class FederationRemoteSendQueue(object):
now = self.clock.time_msec()
keys = self.pos_time.keys()
- time = keys.bisect_left(now - FIVE_MINUTES_AGO)
+ time = self.pos_time.bisect_left(now - FIVE_MINUTES_AGO)
if not keys[:time]:
return
@@ -116,13 +113,13 @@ class FederationRemoteSendQueue(object):
with Measure(self.clock, "send_queue._clear"):
# Delete things out of presence maps
keys = self.presence_changed.keys()
- i = keys.bisect_left(position_to_delete)
+ i = self.presence_changed.bisect_left(position_to_delete)
for key in keys[:i]:
del self.presence_changed[key]
user_ids = set(
user_id
- for uids in self.presence_changed.itervalues()
+ for uids in itervalues(self.presence_changed)
for user_id in uids
)
@@ -134,7 +131,7 @@ class FederationRemoteSendQueue(object):
# Delete things out of keyed edus
keys = self.keyed_edu_changed.keys()
- i = keys.bisect_left(position_to_delete)
+ i = self.keyed_edu_changed.bisect_left(position_to_delete)
for key in keys[:i]:
del self.keyed_edu_changed[key]
@@ -148,19 +145,19 @@ class FederationRemoteSendQueue(object):
# Delete things out of edu map
keys = self.edus.keys()
- i = keys.bisect_left(position_to_delete)
+ i = self.edus.bisect_left(position_to_delete)
for key in keys[:i]:
del self.edus[key]
# Delete things out of failure map
keys = self.failures.keys()
- i = keys.bisect_left(position_to_delete)
+ i = self.failures.bisect_left(position_to_delete)
for key in keys[:i]:
del self.failures[key]
# Delete things out of device map
keys = self.device_messages.keys()
- i = keys.bisect_left(position_to_delete)
+ i = self.device_messages.bisect_left(position_to_delete)
for key in keys[:i]:
del self.device_messages[key]
@@ -200,7 +197,7 @@ class FederationRemoteSendQueue(object):
# We only want to send presence for our own users, so lets always just
# filter here just in case.
- local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
+ local_states = list(filter(lambda s: self.is_mine_id(s.user_id), states))
self.presence_map.update({state.user_id: state for state in local_states})
self.presence_changed[pos] = [state.user_id for state in local_states]
@@ -253,13 +250,12 @@ class FederationRemoteSendQueue(object):
self._clear_queue_before_pos(federation_ack)
# Fetch changed presence
- keys = self.presence_changed.keys()
- i = keys.bisect_right(from_token)
- j = keys.bisect_right(to_token) + 1
+ i = self.presence_changed.bisect_right(from_token)
+ j = self.presence_changed.bisect_right(to_token) + 1
dest_user_ids = [
(pos, user_id)
- for pos in keys[i:j]
- for user_id in self.presence_changed[pos]
+ for pos, user_id_list in self.presence_changed.items()[i:j]
+ for user_id in user_id_list
]
for (key, user_id) in dest_user_ids:
@@ -268,34 +264,31 @@ class FederationRemoteSendQueue(object):
)))
# Fetch changes keyed edus
- keys = self.keyed_edu_changed.keys()
- i = keys.bisect_right(from_token)
- j = keys.bisect_right(to_token) + 1
+ i = self.keyed_edu_changed.bisect_right(from_token)
+ j = self.keyed_edu_changed.bisect_right(to_token) + 1
# We purposefully clobber based on the key here, python dict comprehensions
# always use the last value, so this will correctly point to the last
# stream position.
- keyed_edus = {self.keyed_edu_changed[k]: k for k in keys[i:j]}
+ keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]}
- for ((destination, edu_key), pos) in keyed_edus.iteritems():
+ for ((destination, edu_key), pos) in iteritems(keyed_edus):
rows.append((pos, KeyedEduRow(
key=edu_key,
edu=self.keyed_edu[(destination, edu_key)],
)))
# Fetch changed edus
- keys = self.edus.keys()
- i = keys.bisect_right(from_token)
- j = keys.bisect_right(to_token) + 1
- edus = ((k, self.edus[k]) for k in keys[i:j])
+ i = self.edus.bisect_right(from_token)
+ j = self.edus.bisect_right(to_token) + 1
+ edus = self.edus.items()[i:j]
for (pos, edu) in edus:
rows.append((pos, EduRow(edu)))
# Fetch changed failures
- keys = self.failures.keys()
- i = keys.bisect_right(from_token)
- j = keys.bisect_right(to_token) + 1
- failures = ((k, self.failures[k]) for k in keys[i:j])
+ i = self.failures.bisect_right(from_token)
+ j = self.failures.bisect_right(to_token) + 1
+ failures = self.failures.items()[i:j]
for (pos, (destination, failure)) in failures:
rows.append((pos, FailureRow(
@@ -304,12 +297,11 @@ class FederationRemoteSendQueue(object):
)))
# Fetch changed device messages
- keys = self.device_messages.keys()
- i = keys.bisect_right(from_token)
- j = keys.bisect_right(to_token) + 1
- device_messages = {self.device_messages[k]: k for k in keys[i:j]}
+ i = self.device_messages.bisect_right(from_token)
+ j = self.device_messages.bisect_right(to_token) + 1
+ device_messages = {v: k for k, v in self.device_messages.items()[i:j]}
- for (destination, pos) in device_messages.iteritems():
+ for (destination, pos) in iteritems(device_messages):
rows.append((pos, DeviceRow(
destination=destination,
)))
@@ -528,19 +520,19 @@ def process_rows_for_federation(transaction_queue, rows):
if buff.presence:
transaction_queue.send_presence(buff.presence)
- for destination, edu_map in buff.keyed_edus.iteritems():
+ for destination, edu_map in iteritems(buff.keyed_edus):
for key, edu in edu_map.items():
transaction_queue.send_edu(
edu.destination, edu.edu_type, edu.content, key=key,
)
- for destination, edu_list in buff.edus.iteritems():
+ for destination, edu_list in iteritems(buff.edus):
for edu in edu_list:
transaction_queue.send_edu(
edu.destination, edu.edu_type, edu.content, key=None,
)
- for destination, failure_list in buff.failures.iteritems():
+ for destination, failure_list in iteritems(buff.failures):
for failure in failure_list:
transaction_queue.send_failure(destination, failure)
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 003eaba893..6996d6b695 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -13,34 +13,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
+import logging
-from twisted.internet import defer
+from six import itervalues
-from .persistence import TransactionActions
-from .units import Transaction, Edu
+from prometheus_client import Counter
-from synapse.api.errors import HttpResponseException
-from synapse.util.async import run_on_reactor
-from synapse.util.logcontext import preserve_context_over_fn, preserve_fn
-from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
-from synapse.util.metrics import measure_func
-from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
-import synapse.metrics
+from twisted.internet import defer
-import logging
+import synapse.metrics
+from synapse.api.errors import FederationDeniedError, HttpResponseException
+from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
+from synapse.metrics import (
+ LaterGauge,
+ events_processed_counter,
+ sent_edus_counter,
+ sent_transactions_counter,
+)
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import logcontext
+from synapse.util.metrics import measure_func
+from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
+from .persistence import TransactionActions
+from .units import Edu, Transaction
logger = logging.getLogger(__name__)
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-client_metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
-sent_pdus_destination_dist = client_metrics.register_distribution(
- "sent_pdu_destinations"
+sent_pdus_destination_dist_count = Counter(
+ "synapse_federation_client_sent_pdu_destinations:count", ""
+)
+sent_pdus_destination_dist_total = Counter(
+ "synapse_federation_client_sent_pdu_destinations:total", ""
)
-sent_edus_counter = client_metrics.register_counter("sent_edus")
-
-sent_transactions_counter = client_metrics.register_counter("sent_transactions")
class TransactionQueue(object):
@@ -67,8 +72,10 @@ class TransactionQueue(object):
# done
self.pending_transactions = {}
- metrics.register_callback(
- "pending_destinations",
+ LaterGauge(
+ "synapse_federation_transaction_queue_pending_destinations",
+ "",
+ [],
lambda: len(self.pending_transactions),
)
@@ -92,12 +99,16 @@ class TransactionQueue(object):
# Map of destination -> (edu_type, key) -> Edu
self.pending_edus_keyed_by_dest = edus_keyed = {}
- metrics.register_callback(
- "pending_pdus",
+ LaterGauge(
+ "synapse_federation_transaction_queue_pending_pdus",
+ "",
+ [],
lambda: sum(map(len, pdus.values())),
)
- metrics.register_callback(
- "pending_edus",
+ LaterGauge(
+ "synapse_federation_transaction_queue_pending_edus",
+ "",
+ [],
lambda: (
sum(map(len, edus.values()))
+ sum(map(len, presence.values()))
@@ -146,7 +157,6 @@ class TransactionQueue(object):
else:
return not destination.startswith("localhost")
- @defer.inlineCallbacks
def notify_new_events(self, current_id):
"""This gets called when we have some new events we might want to
send out to other servers.
@@ -156,12 +166,20 @@ class TransactionQueue(object):
if self._is_processing:
return
+ # fire off a processing loop in the background
+ run_as_background_process(
+ "process_event_queue_for_federation",
+ self._process_event_queue_loop,
+ )
+
+ @defer.inlineCallbacks
+ def _process_event_queue_loop(self):
try:
self._is_processing = True
while True:
last_token = yield self.store.get_federation_out_pos("events")
next_token, events = yield self.store.get_all_new_events_stream(
- last_token, self._last_poked_id, limit=20,
+ last_token, self._last_poked_id, limit=100,
)
logger.debug("Handling %s -> %s", last_token, next_token)
@@ -169,24 +187,33 @@ class TransactionQueue(object):
if not events and next_token >= self._last_poked_id:
break
- for event in events:
+ @defer.inlineCallbacks
+ def handle_event(event):
# Only send events for this server.
send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of()
is_mine = self.is_mine_id(event.event_id)
if not is_mine and send_on_behalf_of is None:
- continue
-
- # Get the state from before the event.
- # We need to make sure that this is the state from before
- # the event and not from after it.
- # Otherwise if the last member on a server in a room is
- # banned then it won't receive the event because it won't
- # be in the room after the ban.
- destinations = yield self.state.get_current_hosts_in_room(
- event.room_id, latest_event_ids=[
- prev_id for prev_id, _ in event.prev_events
- ],
- )
+ return
+
+ try:
+ # Get the state from before the event.
+ # We need to make sure that this is the state from before
+ # the event and not from after it.
+ # Otherwise if the last member on a server in a room is
+ # banned then it won't receive the event because it won't
+ # be in the room after the ban.
+ destinations = yield self.state.get_current_hosts_in_room(
+ event.room_id, latest_event_ids=[
+ prev_id for prev_id, _ in event.prev_events
+ ],
+ )
+ except Exception:
+ logger.exception(
+ "Failed to calculate hosts in room for event: %s",
+ event.event_id,
+ )
+ return
+
destinations = set(destinations)
if send_on_behalf_of is not None:
@@ -199,10 +226,41 @@ class TransactionQueue(object):
self._send_pdu(event, destinations)
+ @defer.inlineCallbacks
+ def handle_room_events(events):
+ for event in events:
+ yield handle_event(event)
+
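+ # group the events by room so that different rooms can be handled
+ # in parallel while each room's events are still processed in
+ # order.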
+ events_by_room = {}
+ for event in events:
+ events_by_room.setdefault(event.room_id, []).append(event)
+
+ yield logcontext.make_deferred_yieldable(defer.gatherResults(
+ [
+ logcontext.run_in_background(handle_room_events, evs)
+ for evs in itervalues(events_by_room)
+ ],
+ consumeErrors=True
+ ))
+
yield self.store.update_federation_out_pos(
"events", next_token
)
+ if events:
+ now = self.clock.time_msec()
+ ts = yield self.store.get_received_ts(events[-1].event_id)
+
+ synapse.metrics.event_processing_lag.labels(
+ "federation_sender").set(now - ts)
+ synapse.metrics.event_processing_last_ts.labels(
+ "federation_sender").set(ts)
+
+ events_processed_counter.inc(len(events))
+
+ synapse.metrics.event_processing_positions.labels(
+ "federation_sender").set(next_token)
+
finally:
self._is_processing = False
@@ -224,18 +282,17 @@ class TransactionQueue(object):
if not destinations:
return
- sent_pdus_destination_dist.inc_by(len(destinations))
+ sent_pdus_destination_dist_total.inc(len(destinations))
+ sent_pdus_destination_dist_count.inc()
for destination in destinations:
self.pending_pdus_by_dest.setdefault(destination, []).append(
(pdu, order)
)
- preserve_context_over_fn(
- self._attempt_new_transaction, destination
- )
+ self._attempt_new_transaction(destination)
- @preserve_fn # the caller should not yield on this
+ @logcontext.preserve_fn # the caller should not yield on this
@defer.inlineCallbacks
def send_presence(self, states):
"""Send the new presence states to the appropriate destinations.
@@ -273,7 +330,9 @@ class TransactionQueue(object):
if not states_map:
break
- yield self._process_presence_inner(states_map.values())
+ yield self._process_presence_inner(list(states_map.values()))
+ except Exception:
+ logger.exception("Error sending presence states to servers")
finally:
self._processing_pending_presence = False
@@ -299,7 +358,7 @@ class TransactionQueue(object):
state.user_id: state for state in states
})
- preserve_fn(self._attempt_new_transaction)(destination)
+ self._attempt_new_transaction(destination)
def send_edu(self, destination, edu_type, content, key=None):
edu = Edu(
@@ -321,9 +380,7 @@ class TransactionQueue(object):
else:
self.pending_edus_by_dest.setdefault(destination, []).append(edu)
- preserve_context_over_fn(
- self._attempt_new_transaction, destination
- )
+ self._attempt_new_transaction(destination)
def send_failure(self, failure, destination):
if destination == self.server_name or destination == "localhost":
@@ -336,9 +393,7 @@ class TransactionQueue(object):
destination, []
).append(failure)
- preserve_context_over_fn(
- self._attempt_new_transaction, destination
- )
+ self._attempt_new_transaction(destination)
def send_device_messages(self, destination):
if destination == self.server_name or destination == "localhost":
@@ -347,15 +402,24 @@ class TransactionQueue(object):
if not self.can_send_to(destination):
return
- preserve_context_over_fn(
- self._attempt_new_transaction, destination
- )
+ self._attempt_new_transaction(destination)
def get_current_token(self):
return 0
- @defer.inlineCallbacks
def _attempt_new_transaction(self, destination):
+ """Try to start a new transaction to this destination
+
+ If there is already a transaction in progress to this destination,
+ returns immediately. Otherwise kicks off the process of sending a
+ transaction in the background.
+
+ Args:
+ destination (str):
+
+ Returns:
+ None
+ """
# list of (pending_pdu, deferred, order)
if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending
@@ -368,6 +432,16 @@ class TransactionQueue(object):
)
return
+ logger.debug("TX [%s] Starting transaction loop", destination)
+
+ run_as_background_process(
+ "federation_transaction_transmission_loop",
+ self._transaction_transmission_loop,
+ destination,
+ )
+
+ @defer.inlineCallbacks
+ def _transaction_transmission_loop(self, destination):
pending_pdus = []
try:
self.pending_transactions[destination] = 1
@@ -377,9 +451,6 @@ class TransactionQueue(object):
# hence why we throw the result away.
yield get_retry_limiter(destination, self.clock, self.store)
- # XXX: what's this for?
- yield run_on_reactor()
-
pending_pdus = []
while True:
device_message_edus, device_stream_id, dev_list_id = (
@@ -464,6 +535,8 @@ class TransactionQueue(object):
(e.retry_last_ts + e.retry_interval) / 1000.0
),
)
+ except FederationDeniedError as e:
+ logger.info(e)
except Exception as e:
logger.warn(
"TX [%s] Failed to send transaction: %s",
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 52b2a717d2..4529d454af 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,15 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import urllib
+
from twisted.internet import defer
-from synapse.api.constants import Membership
+from synapse.api.constants import Membership
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.util.logutils import log_function
-import logging
-
-
logger = logging.getLogger(__name__)
@@ -49,7 +50,7 @@ class TransportLayerClient(object):
logger.debug("get_room_state dest=%s, room=%s",
destination, room_id)
- path = PREFIX + "/state/%s/" % room_id
+ path = _create_path(PREFIX, "/state/%s/", room_id)
return self.client.get_json(
destination, path=path, args={"event_id": event_id},
)
@@ -71,7 +72,7 @@ class TransportLayerClient(object):
logger.debug("get_room_state_ids dest=%s, room=%s",
destination, room_id)
- path = PREFIX + "/state_ids/%s/" % room_id
+ path = _create_path(PREFIX, "/state_ids/%s/", room_id)
return self.client.get_json(
destination, path=path, args={"event_id": event_id},
)
@@ -93,7 +94,7 @@ class TransportLayerClient(object):
logger.debug("get_pdu dest=%s, event_id=%s",
destination, event_id)
- path = PREFIX + "/event/%s/" % (event_id, )
+ path = _create_path(PREFIX, "/event/%s/", event_id)
return self.client.get_json(destination, path=path, timeout=timeout)
@log_function
@@ -119,7 +120,7 @@ class TransportLayerClient(object):
# TODO: raise?
return
- path = PREFIX + "/backfill/%s/" % (room_id,)
+ path = _create_path(PREFIX, "/backfill/%s/", room_id)
args = {
"v": event_tuples,
@@ -157,9 +158,11 @@ class TransportLayerClient(object):
# generated by the json_data_callback.
json_data = transaction.get_dict()
+ path = _create_path(PREFIX, "/send/%s/", transaction.transaction_id)
+
response = yield self.client.put_json(
transaction.destination,
- path=PREFIX + "/send/%s/" % transaction.transaction_id,
+ path=path,
data=json_data,
json_data_callback=json_data_callback,
long_retries=True,
@@ -177,7 +180,7 @@ class TransportLayerClient(object):
@log_function
def make_query(self, destination, query_type, args, retry_on_dns_fail,
ignore_backoff=False):
- path = PREFIX + "/query/%s" % query_type
+ path = _create_path(PREFIX, "/query/%s", query_type)
content = yield self.client.get_json(
destination=destination,
@@ -212,6 +215,9 @@ class TransportLayerClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if the remote destination
+ is not in our federation whitelist
"""
valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships:
@@ -219,7 +225,7 @@ class TransportLayerClient(object):
"make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships))
)
- path = PREFIX + "/make_%s/%s/%s" % (membership, room_id, user_id)
+ path = _create_path(PREFIX, "/make_%s/%s/%s", membership, room_id, user_id)
ignore_backoff = False
retry_on_dns_fail = False
@@ -245,7 +251,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_join(self, destination, room_id, event_id, content):
- path = PREFIX + "/send_join/%s/%s" % (room_id, event_id)
+ path = _create_path(PREFIX, "/send_join/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
@@ -258,7 +264,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_leave(self, destination, room_id, event_id, content):
- path = PREFIX + "/send_leave/%s/%s" % (room_id, event_id)
+ path = _create_path(PREFIX, "/send_leave/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
@@ -277,7 +283,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_invite(self, destination, room_id, event_id, content):
- path = PREFIX + "/invite/%s/%s" % (room_id, event_id)
+ path = _create_path(PREFIX, "/invite/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
@@ -319,7 +325,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(self, destination, room_id, event_dict):
- path = PREFIX + "/exchange_third_party_invite/%s" % (room_id,)
+ path = _create_path(PREFIX, "/exchange_third_party_invite/%s", room_id,)
response = yield self.client.put_json(
destination=destination,
@@ -332,7 +338,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
- path = PREFIX + "/event_auth/%s/%s" % (room_id, event_id)
+ path = _create_path(PREFIX, "/event_auth/%s/%s", room_id, event_id)
content = yield self.client.get_json(
destination=destination,
@@ -344,7 +350,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_query_auth(self, destination, room_id, event_id, content):
- path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id)
+ path = _create_path(PREFIX, "/query_auth/%s/%s", room_id, event_id)
content = yield self.client.post_json(
destination=destination,
@@ -406,7 +412,7 @@ class TransportLayerClient(object):
Returns:
A dict containing the device keys.
"""
- path = PREFIX + "/user/devices/" + user_id
+ path = _create_path(PREFIX, "/user/devices/%s", user_id)
content = yield self.client.get_json(
destination=destination,
@@ -456,7 +462,7 @@ class TransportLayerClient(object):
@log_function
def get_missing_events(self, destination, room_id, earliest_events,
latest_events, limit, min_depth, timeout):
- path = PREFIX + "/get_missing_events/%s" % (room_id,)
+ path = _create_path(PREFIX, "/get_missing_events/%s", room_id,)
content = yield self.client.post_json(
destination=destination,
@@ -471,3 +477,475 @@ class TransportLayerClient(object):
)
defer.returnValue(content)
+
+ @log_function
+ def get_group_profile(self, destination, group_id, requester_user_id):
+ """Get a group profile
+ """
+ path = _create_path(PREFIX, "/groups/%s/profile", group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def update_group_profile(self, destination, group_id, requester_user_id, content):
+ """Update a remote group profile
+
+ Args:
+ destination (str)
+ group_id (str)
+ requester_user_id (str)
+ content (dict): The new profile of the group
+ """
+ path = _create_path(PREFIX, "/groups/%s/profile", group_id,)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_group_summary(self, destination, group_id, requester_user_id):
+ """Get a group summary
+ """
+ path = _create_path(PREFIX, "/groups/%s/summary", group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_rooms_in_group(self, destination, group_id, requester_user_id):
+ """Get all rooms in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/rooms", group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ def add_room_to_group(self, destination, group_id, requester_user_id, room_id,
+ content):
+ """Add a room to a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ def update_room_in_group(self, destination, group_id, requester_user_id, room_id,
+ config_key, content):
+ """Update room in group
+ """
+ path = _create_path(
+ PREFIX, "/groups/%s/room/%s/config/%s",
+ group_id, room_id, config_key,
+ )
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
+ """Remove a room from a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,)
+
+ return self.client.delete_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_users_in_group(self, destination, group_id, requester_user_id):
+ """Get users in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/users", group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_invited_users_in_group(self, destination, group_id, requester_user_id):
+ """Get users that have been invited to a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/invited_users", group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def accept_group_invite(self, destination, group_id, user_id, content):
+ """Accept a group invite
+ """
+ path = _create_path(
+ PREFIX, "/groups/%s/users/%s/accept_invite",
+ group_id, user_id,
+ )
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def join_group(self, destination, group_id, user_id, content):
+ """Attempts to join a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/users/%s/join", group_id, user_id)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def invite_to_group(self, destination, group_id, user_id, requester_user_id, content):
+ """Invite a user to a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/users/%s/invite", group_id, user_id)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def invite_to_group_notification(self, destination, group_id, user_id, content):
+ """Sent by group server to inform a user's server that they have been
+ invited.
+ """
+
+ path = _create_path(PREFIX, "/groups/local/%s/users/%s/invite", group_id, user_id)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def remove_user_from_group(self, destination, group_id, requester_user_id,
+ user_id, content):
+ """Remove a user fron a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/users/%s/remove", group_id, user_id)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def remove_user_from_group_notification(self, destination, group_id, user_id,
+ content):
+ """Sent by group server to inform a user's server that they have been
+ kicked from the group.
+ """
+
+ path = _create_path(PREFIX, "/groups/local/%s/users/%s/remove", group_id, user_id)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def renew_group_attestation(self, destination, group_id, user_id, content):
+ """Sent by either a group server or a user's server to periodically update
+ the attestations.
+ """
+
+ path = _create_path(PREFIX, "/groups/%s/renew_attestation/%s", group_id, user_id)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def update_group_summary_room(self, destination, group_id, user_id, room_id,
+ category_id, content):
+ """Update a room entry in a group summary
+ """
+ if category_id:
+ path = _create_path(
+ PREFIX, "/groups/%s/summary/categories/%s/rooms/%s",
+ group_id, category_id, room_id,
+ )
+ else:
+ path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def delete_group_summary_room(self, destination, group_id, user_id, room_id,
+ category_id):
+ """Delete a room entry in a group summary
+ """
+ if category_id:
+ path = _create_path(
+ PREFIX + "/groups/%s/summary/categories/%s/rooms/%s",
+ group_id, category_id, room_id,
+ )
+ else:
+ path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,)
+
+ return self.client.delete_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_group_categories(self, destination, group_id, requester_user_id):
+ """Get all categories in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/categories", group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_group_category(self, destination, group_id, requester_user_id, category_id):
+ """Get category info in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def update_group_category(self, destination, group_id, requester_user_id, category_id,
+ content):
+ """Update a category in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def delete_group_category(self, destination, group_id, requester_user_id,
+ category_id):
+ """Delete a category in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
+
+ return self.client.delete_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_group_roles(self, destination, group_id, requester_user_id):
+ """Get all roles in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/roles", group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_group_role(self, destination, group_id, requester_user_id, role_id):
+ """Get a roles info
+ """
+ path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def update_group_role(self, destination, group_id, requester_user_id, role_id,
+ content):
+ """Update a role in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def delete_group_role(self, destination, group_id, requester_user_id, role_id):
+ """Delete a role in a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
+
+ return self.client.delete_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def update_group_summary_user(self, destination, group_id, requester_user_id,
+ user_id, role_id, content):
+ """Update a users entry in a group
+ """
+ if role_id:
+ path = _create_path(
+ PREFIX, "/groups/%s/summary/roles/%s/users/%s",
+ group_id, role_id, user_id,
+ )
+ else:
+ path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def set_group_join_policy(self, destination, group_id, requester_user_id,
+ content):
+ """Sets the join policy for a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/settings/m.join_policy", group_id,)
+
+ return self.client.put_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def delete_group_summary_user(self, destination, group_id, requester_user_id,
+ user_id, role_id):
+ """Delete a users entry in a group
+ """
+ if role_id:
+ path = _create_path(
+ PREFIX, "/groups/%s/summary/roles/%s/users/%s",
+ group_id, role_id, user_id,
+ )
+ else:
+ path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,)
+
+ return self.client.delete_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ def bulk_get_publicised_groups(self, destination, user_ids):
+ """Get the groups a list of users are publicising
+ """
+
+ path = PREFIX + "/get_groups_publicised"
+
+ content = {"user_ids": user_ids}
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+
+def _create_path(prefix, path, *args):
+ """Creates a path from the prefix, path template and args. Ensures that
+ all args are url encoded.
+
+ Example:
+
+ _create_path(PREFIX, "/event/%s/", event_id)
+
+ Args:
+ prefix (str)
+ path (str): String template for the path
+ args ([str]): Args to insert into the path. Each arg will be URL-encoded.
+
+ Returns:
+ str
+ """
+ return prefix + path % tuple(urllib.quote(arg, "") for arg in args)
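To see why this helper exists: Matrix identifiers routinely contain characters such as `!`, `+` and `:`, and an unescaped `/` inside an identifier would otherwise change the shape of the request path. A hedged sketch of the behaviour (the transaction id is invented; note that `urllib.quote` is the Python 2 spelling, moved to `urllib.parse.quote` in Python 3):

```python
# every arg is quoted with an empty "safe" set, so even "/" becomes %2F
path = _create_path(PREFIX, "/send/%s/", "txn/../../evil")
# -> PREFIX + "/send/txn%2F..%2F..%2Fevil/" rather than a traversed path
```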
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index a78f01e442..c9beca27c2 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,25 +14,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import functools
+import logging
+import re
+
from twisted.internet import defer
+import synapse
+from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
-from synapse.api.errors import Codes, SynapseError
+from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource
from synapse.http.servlet import (
- parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
parse_boolean_from_args,
+ parse_integer_from_args,
+ parse_json_object_from_request,
+ parse_string_from_args,
)
+from synapse.types import ThirdPartyInstanceID, get_domain_from_id
+from synapse.util.logcontext import run_in_background
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string
-from synapse.util.logcontext import preserve_fn
-from synapse.types import ThirdPartyInstanceID
-
-import functools
-import logging
-import re
-import synapse
-
logger = logging.getLogger(__name__)
@@ -81,6 +84,7 @@ class Authenticator(object):
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
self.store = hs.get_datastore()
+ self.federation_domain_whitelist = hs.config.federation_domain_whitelist
# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
@@ -97,26 +101,6 @@ class Authenticator(object):
origin = None
- def parse_auth_header(header_str):
- try:
- params = auth.split(" ")[1].split(",")
- param_dict = dict(kv.split("=") for kv in params)
-
- def strip_quotes(value):
- if value.startswith("\""):
- return value[1:-1]
- else:
- return value
-
- origin = strip_quotes(param_dict["origin"])
- key = strip_quotes(param_dict["key"])
- sig = strip_quotes(param_dict["sig"])
- return (origin, key, sig)
- except:
- raise AuthenticationError(
- 400, "Malformed Authorization header", Codes.UNAUTHORIZED
- )
-
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if not auth_headers:
@@ -125,11 +109,17 @@ class Authenticator(object):
)
for auth in auth_headers:
- if auth.startswith("X-Matrix"):
- (origin, key, sig) = parse_auth_header(auth)
+ if auth.startswith(b"X-Matrix"):
+ (origin, key, sig) = _parse_auth_header(auth)
json_request["origin"] = origin
json_request["signatures"].setdefault(origin, {})[key] = sig
+ if (
+ self.federation_domain_whitelist is not None and
+ origin not in self.federation_domain_whitelist
+ ):
+ raise FederationDeniedError(origin)
+
if not json_request["signatures"]:
raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
@@ -144,11 +134,60 @@ class Authenticator(object):
# alive
retry_timings = yield self.store.get_destination_retry_timings(origin)
if retry_timings and retry_timings["retry_last_ts"]:
- logger.info("Marking origin %r as up", origin)
- preserve_fn(self.store.set_destination_retry_timings)(origin, 0, 0)
+ run_in_background(self._reset_retry_timings, origin)
defer.returnValue(origin)
+ @defer.inlineCallbacks
+ def _reset_retry_timings(self, origin):
+ try:
+ logger.info("Marking origin %r as up", origin)
+ yield self.store.set_destination_retry_timings(origin, 0, 0)
+ except Exception:
+ logger.exception("Error resetting retry timings on %s", origin)
+
+
+def _parse_auth_header(header_bytes):
+ """Parse an X-Matrix auth header
+
+ Args:
+ header_bytes (bytes): header value
+
+ Returns:
+ Tuple[str, str, str]: origin, key id, signature.
+
+ Raises:
+ AuthenticationError if the header could not be parsed
+ """
+ try:
+ header_str = header_bytes.decode('utf-8')
+ params = header_str.split(" ")[1].split(",")
+ param_dict = dict(kv.split("=", 1) for kv in params)
+
+ def strip_quotes(value):
+ if value.startswith("\""):
+ return value[1:-1]
+ else:
+ return value
+
+ origin = strip_quotes(param_dict["origin"])
+
+ # ensure that the origin is a valid server name
+ parse_and_validate_server_name(origin)
+
+ key = strip_quotes(param_dict["key"])
+ sig = strip_quotes(param_dict["sig"])
+ return origin, key, sig
+ except Exception as e:
+ logger.warn(
+ "Error parsing auth header '%s': %s",
+ header_bytes.decode('ascii', 'replace'),
+ e,
+ )
+ raise AuthenticationError(
+ 400, "Malformed Authorization header", Codes.UNAUTHORIZED,
+ )
+
class BaseFederationServlet(object):
REQUIRE_AUTH = True
@@ -177,7 +216,7 @@ class BaseFederationServlet(object):
if self.REQUIRE_AUTH:
logger.exception("authenticate_request failed")
raise
- except:
+ except Exception:
logger.exception("authenticate_request failed")
raise
@@ -270,7 +309,7 @@ class FederationSendServlet(BaseFederationServlet):
code, response = yield self.handler.on_incoming_transaction(
transaction_data
)
- except:
+ except Exception:
logger.exception("on_incoming_transaction failed")
raise
@@ -347,7 +386,9 @@ class FederationMakeJoinServlet(BaseFederationServlet):
@defer.inlineCallbacks
def on_GET(self, origin, content, query, context, user_id):
- content = yield self.handler.on_make_join_request(context, user_id)
+ content = yield self.handler.on_make_join_request(
+ origin, context, user_id,
+ )
defer.returnValue((200, content))
@@ -356,7 +397,9 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
@defer.inlineCallbacks
def on_GET(self, origin, content, query, context, user_id):
- content = yield self.handler.on_make_leave_request(context, user_id)
+ content = yield self.handler.on_make_leave_request(
+ origin, context, user_id,
+ )
defer.returnValue((200, content))
@@ -609,6 +652,549 @@ class FederationVersionServlet(BaseFederationServlet):
}))
+class FederationGroupsProfileServlet(BaseFederationServlet):
+ """Get/set the basic profile of a group on behalf of a user
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/profile$"
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.get_group_profile(
+ group_id, requester_user_id
+ )
+
+ defer.returnValue((200, new_content))
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.update_group_profile(
+ group_id, requester_user_id, content
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsSummaryServlet(BaseFederationServlet):
+ PATH = "/groups/(?P<group_id>[^/]*)/summary$"
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.get_group_summary(
+ group_id, requester_user_id
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsRoomsServlet(BaseFederationServlet):
+ """Get the rooms in a group on behalf of a user
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/rooms$"
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.get_rooms_in_group(
+ group_id, requester_user_id
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsAddRoomsServlet(BaseFederationServlet):
+ """Add/remove room from group
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, room_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.add_room_to_group(
+ group_id, requester_user_id, room_id, content
+ )
+
+ defer.returnValue((200, new_content))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, origin, content, query, group_id, room_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.remove_room_from_group(
+ group_id, requester_user_id, room_id,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
+ """Update room config in group
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
+ "/config/(?P<config_key>[^/]*)$"
+ )
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, room_id, config_key):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ result = yield self.groups_handler.update_room_in_group(
+ group_id, requester_user_id, room_id, config_key, content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class FederationGroupsUsersServlet(BaseFederationServlet):
+ """Get the users in a group on behalf of a user
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/users$"
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.get_users_in_group(
+ group_id, requester_user_id
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
+ """Get the users that have been invited to a group
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/invited_users$"
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.get_invited_users_in_group(
+ group_id, requester_user_id
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsInviteServlet(BaseFederationServlet):
+ """Ask a group server to invite someone to the group
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.invite_to_group(
+ group_id, user_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
+ """Accept an invitation from the group server
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ if get_domain_from_id(user_id) != origin:
+ raise SynapseError(403, "user_id doesn't match origin")
+
+ new_content = yield self.handler.accept_invite(
+ group_id, user_id, content,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsJoinServlet(BaseFederationServlet):
+ """Attempt to join a group
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ if get_domain_from_id(user_id) != origin:
+ raise SynapseError(403, "user_id doesn't match origin")
+
+ new_content = yield self.handler.join_group(
+ group_id, user_id, content,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsRemoveUserServlet(BaseFederationServlet):
+ """Leave or kick a user from the group
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.remove_user_from_group(
+ group_id, user_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsLocalInviteServlet(BaseFederationServlet):
+ """A group server has invited a local user
+ """
+ PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ if get_domain_from_id(group_id) != origin:
+ raise SynapseError(403, "group_id doesn't match origin")
+
+ new_content = yield self.handler.on_invite(
+ group_id, user_id, content,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
+ """A group server has removed a local user
+ """
+ PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ if get_domain_from_id(group_id) != origin:
+ raise SynapseError(403, "user_id doesn't match origin")
+
+ new_content = yield self.handler.user_removed_from_group(
+ group_id, user_id, content,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
+ """A group or user's server renews their attestation
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ # We don't need to check auth here as we check the attestation signatures
+
+ new_content = yield self.handler.on_renew_attestation(
+ group_id, user_id, content
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
+ """Add/remove a room from the group summary, with optional category.
+
+ Matches both:
+ - /groups/:group/summary/rooms/:room_id
+ - /groups/:group/summary/categories/:category/rooms/:room_id
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/summary"
+ "(/categories/(?P<category_id>[^/]+))?"
+ "/rooms/(?P<room_id>[^/]*)$"
+ )
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, category_id, room_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if category_id == "":
+ raise SynapseError(400, "category_id cannot be empty string")
+
+ resp = yield self.handler.update_group_summary_room(
+ group_id, requester_user_id,
+ room_id=room_id,
+ category_id=category_id,
+ content=content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, origin, content, query, group_id, category_id, room_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if category_id == "":
+ raise SynapseError(400, "category_id cannot be empty string")
+
+ resp = yield self.handler.delete_group_summary_room(
+ group_id, requester_user_id,
+ room_id=room_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue((200, resp))
+
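The optional `(/categories/...)?` group in `FederationGroupsSummaryRoomsServlet.PATH` is what lets a single servlet serve both URL shapes; a quick self-contained check of the regex (group and room ids invented):

```python
import re

PATH = (
    "/groups/(?P<group_id>[^/]*)/summary"
    "(/categories/(?P<category_id>[^/]+))?"
    "/rooms/(?P<room_id>[^/]*)$"
)

# without a category, the optional group simply doesn't capture
m = re.match(PATH, "/groups/+g:example.com/summary/rooms/!r:example.com")
assert m.group("category_id") is None

# with a category, it does
m = re.match(
    PATH, "/groups/+g:example.com/summary/categories/labs/rooms/!r:example.com"
)
assert m.group("category_id") == "labs"
```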
+
+class FederationGroupsCategoriesServlet(BaseFederationServlet):
+ """Get all categories for a group
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/categories/$"
+ )
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ resp = yield self.handler.get_group_categories(
+ group_id, requester_user_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsCategoryServlet(BaseFederationServlet):
+ """Add/remove/get a category in a group
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
+ )
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id, category_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ resp = yield self.handler.get_group_category(
+ group_id, requester_user_id, category_id
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, category_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if category_id == "":
+ raise SynapseError(400, "category_id cannot be empty string")
+
+ resp = yield self.handler.upsert_group_category(
+ group_id, requester_user_id, category_id, content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, origin, content, query, group_id, category_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if category_id == "":
+ raise SynapseError(400, "category_id cannot be empty string")
+
+ resp = yield self.handler.delete_group_category(
+ group_id, requester_user_id, category_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsRolesServlet(BaseFederationServlet):
+ """Get roles in a group
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/roles/$"
+ )
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ resp = yield self.handler.get_group_roles(
+ group_id, requester_user_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsRoleServlet(BaseFederationServlet):
+ """Add/remove/get a role in a group
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
+ )
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id, role_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ resp = yield self.handler.get_group_role(
+ group_id, requester_user_id, role_id
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, role_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if role_id == "":
+ raise SynapseError(400, "role_id cannot be empty string")
+
+ resp = yield self.handler.update_group_role(
+ group_id, requester_user_id, role_id, content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, origin, content, query, group_id, role_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if role_id == "":
+ raise SynapseError(400, "role_id cannot be empty string")
+
+ resp = yield self.handler.delete_group_role(
+ group_id, requester_user_id, role_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
+ """Add/remove a user from the group summary, with optional role.
+
+ Matches both:
+ - /groups/:group/summary/users/:user_id
+ - /groups/:group/summary/roles/:role/users/:user_id
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/summary"
+ "(/roles/(?P<role_id>[^/]+))?"
+ "/users/(?P<user_id>[^/]*)$"
+ )
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, role_id, user_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if role_id == "":
+ raise SynapseError(400, "role_id cannot be empty string")
+
+ resp = yield self.handler.update_group_summary_user(
+ group_id, requester_user_id,
+ user_id=user_id,
+ role_id=role_id,
+ content=content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, origin, content, query, group_id, role_id, user_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if role_id == "":
+ raise SynapseError(400, "role_id cannot be empty string")
+
+ resp = yield self.handler.delete_group_summary_user(
+ group_id, requester_user_id,
+ user_id=user_id,
+ role_id=role_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
+ """Get roles in a group
+ """
+ PATH = (
+ "/get_groups_publicised$"
+ )
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query):
+ resp = yield self.handler.bulk_get_publicised_groups(
+ content["user_ids"], proxy=False,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
+ """Sets whether a group is joinable without an invite or knock
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy$"
+
+ @defer.inlineCallbacks
+ def on_PUT(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.set_group_join_policy(
+ group_id, requester_user_id, content
+ )
+
+ defer.returnValue((200, new_content))
+
+
FEDERATION_SERVLET_CLASSES = (
FederationSendServlet,
FederationPullServlet,
@@ -635,15 +1221,49 @@ FEDERATION_SERVLET_CLASSES = (
FederationVersionServlet,
)
+
ROOM_LIST_CLASSES = (
PublicRoomList,
)
+GROUP_SERVER_SERVLET_CLASSES = (
+ FederationGroupsProfileServlet,
+ FederationGroupsSummaryServlet,
+ FederationGroupsRoomsServlet,
+ FederationGroupsUsersServlet,
+ FederationGroupsInvitedUsersServlet,
+ FederationGroupsInviteServlet,
+ FederationGroupsAcceptInviteServlet,
+ FederationGroupsJoinServlet,
+ FederationGroupsRemoveUserServlet,
+ FederationGroupsSummaryRoomsServlet,
+ FederationGroupsCategoriesServlet,
+ FederationGroupsCategoryServlet,
+ FederationGroupsRolesServlet,
+ FederationGroupsRoleServlet,
+ FederationGroupsSummaryUsersServlet,
+ FederationGroupsAddRoomsServlet,
+ FederationGroupsAddRoomsConfigServlet,
+ FederationGroupsSettingJoinPolicyServlet,
+)
+
+
+GROUP_LOCAL_SERVLET_CLASSES = (
+ FederationGroupsLocalInviteServlet,
+ FederationGroupsRemoveLocalUserServlet,
+ FederationGroupsBulkPublicisedServlet,
+)
+
+
+GROUP_ATTESTATION_SERVLET_CLASSES = (
+ FederationGroupsRenewAttestaionServlet,
+)
+
def register_servlets(hs, resource, authenticator, ratelimiter):
for servletclass in FEDERATION_SERVLET_CLASSES:
servletclass(
- handler=hs.get_replication_layer(),
+ handler=hs.get_federation_server(),
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
@@ -656,3 +1276,27 @@ def register_servlets(hs, resource, authenticator, ratelimiter):
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)
+
+ for servletclass in GROUP_SERVER_SERVLET_CLASSES:
+ servletclass(
+ handler=hs.get_groups_server_handler(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
+
+ for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
+ servletclass(
+ handler=hs.get_groups_local_handler(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
+
+ for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
+ servletclass(
+ handler=hs.get_groups_attestation_renewer(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
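For reference, an `X-Matrix` Authorization header of the kind consumed by `_parse_auth_header` looks like the following; a hedged illustration with invented values:

```python
header = b'X-Matrix origin=example.org,key="ed25519:key1",sig="dGVzdA=="'

origin, key, sig = _parse_auth_header(header)
# origin == "example.org", key == "ed25519:key1", sig == "dGVzdA=="
```

The scheme name and a single space come first, then comma-separated `key=value` pairs; only `origin` is validated as a server name, while the signature itself is checked later against the request body.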
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index 3f645acc43..bb1b3b13f7 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -17,10 +17,9 @@
server protocol.
"""
-from synapse.util.jsonobject import JsonEncodedObject
-
import logging
+from synapse.util.jsonobject import JsonEncodedObject
logger = logging.getLogger(__name__)
@@ -74,8 +73,6 @@ class Transaction(JsonEncodedObject):
"previous_ids",
"pdus",
"edus",
- "transaction_id",
- "destination",
"pdu_failures",
]
diff --git a/synapse/groups/__init__.py b/synapse/groups/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/synapse/groups/__init__.py
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
new file mode 100644
index 0000000000..47452700a8
--- /dev/null
+++ b/synapse/groups/attestations.py
@@ -0,0 +1,198 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Attestations ensure that users and groups can't lie about their memberships.
+
+When a user joins a group the HS and GS swap attestations, which allow them
+both to independently prove to third parties their membership.These
+attestations have a validity period so need to be periodically renewed.
+
+If a user leaves (or gets kicked out of) a group, either side can still use
+their attestation to "prove" their membership, until the attestation expires.
+Therefore attestations shouldn't be relied on to prove membership in important
+ cases, but can be used for less important situations, e.g. showing a user's membership
+of groups on their profile, showing flairs, etc.
+
+An attestation is a signed blob of json that looks like:
+
+ {
+ "user_id": "@foo:a.example.com",
+ "group_id": "+bar:b.example.com",
+ "valid_until_ms": 1507994728530,
+ "signatures":{"matrix.org":{"ed25519:auto":"..."}}
+ }
+"""
+
+import logging
+import random
+
+from signedjson.sign import sign_json
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.types import get_domain_from_id
+from synapse.util.logcontext import run_in_background
+
+logger = logging.getLogger(__name__)
+
+
+# Default validity duration for new attestations we create
+DEFAULT_ATTESTATION_LENGTH_MS = 3 * 24 * 60 * 60 * 1000
+
+# We add some jitter to the validity duration of attestations so that if we
+# add lots of users at once we don't need to renew them all at once.
+# The jitter is a multiplier picked randomly between the first and second number
+DEFAULT_ATTESTATION_JITTER = (0.9, 1.3)
+
+# Start trying to update our attestations when they come this close to expiring
+UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
+
+
+class GroupAttestationSigning(object):
+ """Creates and verifies group attestations.
+ """
+ def __init__(self, hs):
+ self.keyring = hs.get_keyring()
+ self.clock = hs.get_clock()
+ self.server_name = hs.hostname
+ self.signing_key = hs.config.signing_key[0]
+
+ @defer.inlineCallbacks
+ def verify_attestation(self, attestation, group_id, user_id, server_name=None):
+ """Verifies that the given attestation matches the given parameters.
+
+ An optional server_name can be supplied to explicitly set which server's
+ signature is expected. Otherwise assumes that either the group_id or user_id
+ is local and uses the other's server as the one to check.
+ """
+
+ if not server_name:
+ if get_domain_from_id(group_id) == self.server_name:
+ server_name = get_domain_from_id(user_id)
+ elif get_domain_from_id(user_id) == self.server_name:
+ server_name = get_domain_from_id(group_id)
+ else:
+ raise Exception("Expected either group_id or user_id to be local")
+
+ if user_id != attestation["user_id"]:
+ raise SynapseError(400, "Attestation has incorrect user_id")
+
+ if group_id != attestation["group_id"]:
+ raise SynapseError(400, "Attestation has incorrect group_id")
+ valid_until_ms = attestation["valid_until_ms"]
+
+ # TODO: We also want to check that *new* attestations that people give
+ # us to store are valid for at least a little while.
+ if valid_until_ms < self.clock.time_msec():
+ raise SynapseError(400, "Attestation expired")
+
+ yield self.keyring.verify_json_for_server(server_name, attestation)
+
+ def create_attestation(self, group_id, user_id):
+ """Create an attestation for the group_id and user_id with default
+ validity length.
+ """
+ validity_period = DEFAULT_ATTESTATION_LENGTH_MS
+ validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER)
+ valid_until_ms = int(self.clock.time_msec() + validity_period)
+
+ return sign_json({
+ "group_id": group_id,
+ "user_id": user_id,
+ "valid_until_ms": valid_until_ms,
+ }, self.server_name, self.signing_key)
+
+
+class GroupAttestionRenewer(object):
+ """Responsible for sending and receiving attestation updates.
+ """
+
+ def __init__(self, hs):
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+ self.transport_client = hs.get_federation_transport_client()
+ self.is_mine_id = hs.is_mine_id
+ self.attestations = hs.get_groups_attestation_signing()
+
+ self._renew_attestations_loop = self.clock.looping_call(
+ self._renew_attestations, 30 * 60 * 1000,
+ )
+
+ @defer.inlineCallbacks
+ def on_renew_attestation(self, group_id, user_id, content):
+ """When a remote updates an attestation
+ """
+ attestation = content["attestation"]
+
+ if not self.is_mine_id(group_id) and not self.is_mine_id(user_id):
+ raise SynapseError(400, "Neither user not group are on this server")
+
+ yield self.attestations.verify_attestation(
+ attestation,
+ user_id=user_id,
+ group_id=group_id,
+ )
+
+ yield self.store.update_remote_attestion(group_id, user_id, attestation)
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def _renew_attestations(self):
+ """Called periodically to check if we need to update any of our attestations
+ """
+
+ now = self.clock.time_msec()
+
+ rows = yield self.store.get_attestations_need_renewals(
+ now + UPDATE_ATTESTATION_TIME_MS
+ )
+
+ @defer.inlineCallbacks
+ def _renew_attestation(group_id, user_id):
+ try:
+ if not self.is_mine_id(group_id):
+ destination = get_domain_from_id(group_id)
+ elif not self.is_mine_id(user_id):
+ destination = get_domain_from_id(user_id)
+ else:
+ logger.warn(
+ "Incorrectly trying to do attestations for user: %r in %r",
+ user_id, group_id,
+ )
+ yield self.store.remove_attestation_renewal(group_id, user_id)
+ return
+
+ attestation = self.attestations.create_attestation(group_id, user_id)
+
+ yield self.transport_client.renew_group_attestation(
+ destination, group_id, user_id,
+ content={"attestation": attestation},
+ )
+
+ yield self.store.update_attestation_renewal(
+ group_id, user_id, attestation
+ )
+ except Exception:
+ logger.exception("Error renewing attestation of %r in %r",
+ user_id, group_id)
+
+ for row in rows:
+ group_id = row["group_id"]
+ user_id = row["user_id"]
+
+ run_in_background(_renew_attestation, group_id, user_id)
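Putting the two classes together: the server that owns the group (or the user) signs, and the other side verifies, fetching the signer's key over federation if needed. A minimal round-trip sketch, assuming a configured homeserver object `hs` and using one signer for both roles purely for brevity:

```python
from twisted.internet import defer


@defer.inlineCallbacks
def attestation_round_trip(hs):
    signer = GroupAttestationSigning(hs)

    # the group server (b.example.com) signs for a remote user
    attestation = signer.create_attestation(
        "+bar:b.example.com", "@foo:a.example.com",
    )

    # verification checks the ids and expiry, then the signature
    yield signer.verify_attestation(
        attestation,
        group_id="+bar:b.example.com",
        user_id="@foo:a.example.com",
        server_name="b.example.com",
    )
```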
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
new file mode 100644
index 0000000000..633c865ed8
--- /dev/null
+++ b/synapse/groups/groups_server.py
@@ -0,0 +1,953 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from six import string_types
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
+
+logger = logging.getLogger(__name__)
+
+
+# TODO: Allow users to "knock" or simply join depending on rules
+# TODO: Federation admin APIs
+# TODO: is_privileged flag to users and is_public to users and rooms
+# TODO: Audit log for admins (profile updates, membership changes, users who tried
+# to join but were rejected, etc)
+# TODO: Flairs
+
+
+class GroupsServerHandler(object):
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.room_list_handler = hs.get_room_list_handler()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.keyring = hs.get_keyring()
+ self.is_mine_id = hs.is_mine_id
+ self.signing_key = hs.config.signing_key[0]
+ self.server_name = hs.hostname
+ self.attestations = hs.get_groups_attestation_signing()
+ self.transport_client = hs.get_federation_transport_client()
+ self.profile_handler = hs.get_profile_handler()
+
+ # Ensure attestations get renewed
+ hs.get_groups_attestation_renewer()
+
+ @defer.inlineCallbacks
+ def check_group_is_ours(self, group_id, requester_user_id,
+ and_exists=False, and_is_admin=None):
+ """Check that the group is ours, and optionally if it exists.
+
+        If the group does exist, it is returned.
+
+        Args:
+            group_id (str)
+            requester_user_id (str): the user making the request
+            and_exists (bool): whether to also check that the group exists
+            and_is_admin (str): if given, a user_id to additionally check is
+                an admin of the group
+        """
+ if not self.is_mine_id(group_id):
+ raise SynapseError(400, "Group not on this server")
+
+ group = yield self.store.get_group(group_id)
+ if and_exists and not group:
+ raise SynapseError(404, "Unknown group")
+
+ is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+ if group and not is_user_in_group and not group["is_public"]:
+ raise SynapseError(404, "Unknown group")
+
+ if and_is_admin:
+ is_admin = yield self.store.is_user_admin_in_group(group_id, and_is_admin)
+ if not is_admin:
+ raise SynapseError(403, "User is not admin in group")
+
+ defer.returnValue(group)
+
+ @defer.inlineCallbacks
+ def get_group_summary(self, group_id, requester_user_id):
+ """Get the summary for a group as seen by requester_user_id.
+
+        The group summary consists of the profile of the group, and a curated
+        list of users and rooms. These lists *may* be organised by role/category.
+ The roles/categories are ordered, and so are the users/rooms within them.
+
+ A user/room may appear in multiple roles/categories.
+ """
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+
+ profile = yield self.get_group_profile(group_id, requester_user_id)
+
+ users, roles = yield self.store.get_users_for_summary_by_role(
+ group_id, include_private=is_user_in_group,
+ )
+
+ # TODO: Add profiles to users
+
+ rooms, categories = yield self.store.get_rooms_for_summary_by_category(
+ group_id, include_private=is_user_in_group,
+ )
+
+ for room_entry in rooms:
+ room_id = room_entry["room_id"]
+ joined_users = yield self.store.get_users_in_room(room_id)
+ entry = yield self.room_list_handler.generate_room_entry(
+ room_id, len(joined_users),
+ with_alias=False, allow_private=True,
+ )
+            entry = dict(entry)  # so we don't change what's cached
+ entry.pop("room_id", None)
+
+ room_entry["profile"] = entry
+
+ rooms.sort(key=lambda e: e.get("order", 0))
+
+ for entry in users:
+ user_id = entry["user_id"]
+
+ if not self.is_mine_id(requester_user_id):
+ attestation = yield self.store.get_remote_attestation(group_id, user_id)
+ if not attestation:
+ continue
+
+ entry["attestation"] = attestation
+ else:
+ entry["attestation"] = self.attestations.create_attestation(
+ group_id, user_id,
+ )
+
+ user_profile = yield self.profile_handler.get_profile_from_cache(user_id)
+ entry.update(user_profile)
+
+ users.sort(key=lambda e: e.get("order", 0))
+
+ membership_info = yield self.store.get_users_membership_info_in_group(
+ group_id, requester_user_id,
+ )
+
+ defer.returnValue({
+ "profile": profile,
+ "users_section": {
+ "users": users,
+ "roles": roles,
+ "total_user_count_estimate": 0, # TODO
+ },
+ "rooms_section": {
+ "rooms": rooms,
+ "categories": categories,
+ "total_room_count_estimate": 0, # TODO
+ },
+ "user": membership_info,
+ })
+
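
For orientation, here is an invented example of the summary assembled above (values, and the exact shapes of the roles/categories and membership info returned by the store, are illustrative rather than normative):

    # Hypothetical get_group_summary result:
    summary = {
        "profile": {"name": "Widget Makers", "is_openly_joinable": False},
        "users_section": {
            "users": [{"user_id": "@alice:example.com", "displayname": "Alice"}],
            "roles": {},
            "total_user_count_estimate": 0,  # TODO in the code above
        },
        "rooms_section": {
            "rooms": [{"room_id": "!widgets:example.com", "profile": {"name": "Widgets"}}],
            "categories": {},
            "total_room_count_estimate": 0,
        },
        "user": {"membership": "join"},
    }
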
+ @defer.inlineCallbacks
+ def update_group_summary_room(self, group_id, requester_user_id,
+ room_id, category_id, content):
+ """Add/update a room to the group summary
+ """
+ yield self.check_group_is_ours(
+ group_id,
+ requester_user_id,
+ and_exists=True,
+ and_is_admin=requester_user_id,
+ )
+
+ RoomID.from_string(room_id) # Ensure valid room id
+
+ order = content.get("order", None)
+
+ is_public = _parse_visibility_from_contents(content)
+
+ yield self.store.add_room_to_summary(
+ group_id=group_id,
+ room_id=room_id,
+ category_id=category_id,
+ order=order,
+ is_public=is_public,
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def delete_group_summary_room(self, group_id, requester_user_id,
+ room_id, category_id):
+ """Remove a room from the summary
+ """
+ yield self.check_group_is_ours(
+ group_id,
+ requester_user_id,
+ and_exists=True,
+ and_is_admin=requester_user_id,
+ )
+
+ yield self.store.remove_room_from_summary(
+ group_id=group_id,
+ room_id=room_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def set_group_join_policy(self, group_id, requester_user_id, content):
+ """Sets the group join policy.
+
+ Currently supported policies are:
+ - "invite": an invite must be received and accepted in order to join.
+ - "open": anyone can join.
+ """
+ yield self.check_group_is_ours(
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
+ )
+
+ join_policy = _parse_join_policy_from_contents(content)
+ if join_policy is None:
+ raise SynapseError(
+ 400, "No value specified for 'm.join_policy'"
+ )
+
+ yield self.store.set_group_join_policy(group_id, join_policy=join_policy)
+
+ defer.returnValue({})
+
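
For reference, some request bodies and how the join-policy parser at the bottom of this file treats them (illustrative examples, not quoted from the spec):

    # Example contents for set_group_join_policy:
    {"m.join_policy": {"type": "open"}}    # anyone may join
    {"m.join_policy": {"type": "invite"}}  # an invite is required
    {"m.join_policy": {}}                  # type omitted: defaults to "invite"
    {}                                     # rejected: "No value specified for 'm.join_policy'"
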
+ @defer.inlineCallbacks
+ def get_group_categories(self, group_id, requester_user_id):
+ """Get all categories in a group (as seen by user)
+ """
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ categories = yield self.store.get_group_categories(
+ group_id=group_id,
+ )
+ defer.returnValue({"categories": categories})
+
+ @defer.inlineCallbacks
+ def get_group_category(self, group_id, requester_user_id, category_id):
+ """Get a specific category in a group (as seen by user)
+ """
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ res = yield self.store.get_group_category(
+ group_id=group_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def update_group_category(self, group_id, requester_user_id, category_id, content):
+ """Add/Update a group category
+ """
+ yield self.check_group_is_ours(
+ group_id,
+ requester_user_id,
+ and_exists=True,
+ and_is_admin=requester_user_id,
+ )
+
+ is_public = _parse_visibility_from_contents(content)
+ profile = content.get("profile")
+
+ yield self.store.upsert_group_category(
+ group_id=group_id,
+ category_id=category_id,
+ is_public=is_public,
+ profile=profile,
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def delete_group_category(self, group_id, requester_user_id, category_id):
+ """Delete a group category
+ """
+ yield self.check_group_is_ours(
+ group_id,
+ requester_user_id,
+ and_exists=True,
+ and_is_admin=requester_user_id
+ )
+
+ yield self.store.remove_group_category(
+ group_id=group_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def get_group_roles(self, group_id, requester_user_id):
+ """Get all roles in a group (as seen by user)
+ """
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ roles = yield self.store.get_group_roles(
+ group_id=group_id,
+ )
+ defer.returnValue({"roles": roles})
+
+ @defer.inlineCallbacks
+ def get_group_role(self, group_id, requester_user_id, role_id):
+ """Get a specific role in a group (as seen by user)
+ """
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ res = yield self.store.get_group_role(
+ group_id=group_id,
+ role_id=role_id,
+ )
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def update_group_role(self, group_id, requester_user_id, role_id, content):
+ """Add/update a role in a group
+ """
+ yield self.check_group_is_ours(
+ group_id,
+ requester_user_id,
+ and_exists=True,
+ and_is_admin=requester_user_id,
+ )
+
+ is_public = _parse_visibility_from_contents(content)
+
+ profile = content.get("profile")
+
+ yield self.store.upsert_group_role(
+ group_id=group_id,
+ role_id=role_id,
+ is_public=is_public,
+ profile=profile,
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def delete_group_role(self, group_id, requester_user_id, role_id):
+ """Remove role from group
+ """
+ yield self.check_group_is_ours(
+ group_id,
+ requester_user_id,
+ and_exists=True,
+ and_is_admin=requester_user_id,
+ )
+
+ yield self.store.remove_group_role(
+ group_id=group_id,
+ role_id=role_id,
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def update_group_summary_user(self, group_id, requester_user_id, user_id, role_id,
+ content):
+        """Add/update a user's entry in the group summary
+ """
+ yield self.check_group_is_ours(
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
+ )
+
+ order = content.get("order", None)
+
+ is_public = _parse_visibility_from_contents(content)
+
+ yield self.store.add_user_to_summary(
+ group_id=group_id,
+ user_id=user_id,
+ role_id=role_id,
+ order=order,
+ is_public=is_public,
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def delete_group_summary_user(self, group_id, requester_user_id, user_id, role_id):
+ """Remove a user from the group summary
+ """
+ yield self.check_group_is_ours(
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
+ )
+
+ yield self.store.remove_user_from_summary(
+ group_id=group_id,
+ user_id=user_id,
+ role_id=role_id,
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def get_group_profile(self, group_id, requester_user_id):
+ """Get the group profile as seen by requester_user_id
+ """
+
+ yield self.check_group_is_ours(group_id, requester_user_id)
+
+ group = yield self.store.get_group(group_id)
+
+ if group:
+ cols = [
+ "name", "short_description", "long_description",
+ "avatar_url", "is_public",
+ ]
+ group_description = {key: group[key] for key in cols}
+ group_description["is_openly_joinable"] = group["join_policy"] == "open"
+
+ defer.returnValue(group_description)
+ else:
+ raise SynapseError(404, "Unknown group")
+
+ @defer.inlineCallbacks
+ def update_group_profile(self, group_id, requester_user_id, content):
+ """Update the group profile
+ """
+ yield self.check_group_is_ours(
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
+ )
+
+ profile = {}
+ for keyname in ("name", "avatar_url", "short_description",
+ "long_description"):
+ if keyname in content:
+ value = content[keyname]
+ if not isinstance(value, string_types):
+ raise SynapseError(400, "%r value is not a string" % (keyname,))
+ profile[keyname] = value
+
+ yield self.store.update_group_profile(group_id, profile)
+
+ @defer.inlineCallbacks
+ def get_users_in_group(self, group_id, requester_user_id):
+        """Get the users in the group as seen by requester_user_id.
+
+ The ordering is arbitrary at the moment
+ """
+
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+
+ user_results = yield self.store.get_users_in_group(
+ group_id, include_private=is_user_in_group,
+ )
+
+ chunk = []
+ for user_result in user_results:
+ g_user_id = user_result["user_id"]
+ is_public = user_result["is_public"]
+ is_privileged = user_result["is_admin"]
+
+ entry = {"user_id": g_user_id}
+
+ profile = yield self.profile_handler.get_profile_from_cache(g_user_id)
+ entry.update(profile)
+
+ entry["is_public"] = bool(is_public)
+ entry["is_privileged"] = bool(is_privileged)
+
+ if not self.is_mine_id(g_user_id):
+ attestation = yield self.store.get_remote_attestation(group_id, g_user_id)
+ if not attestation:
+ continue
+
+ entry["attestation"] = attestation
+ else:
+ entry["attestation"] = self.attestations.create_attestation(
+ group_id, g_user_id,
+ )
+
+ chunk.append(entry)
+
+ # TODO: If admin add lists of users whose attestations have timed out
+
+ defer.returnValue({
+ "chunk": chunk,
+ "total_user_count_estimate": len(user_results),
+ })
+
+ @defer.inlineCallbacks
+ def get_invited_users_in_group(self, group_id, requester_user_id):
+ """Get the users that have been invited to a group as seen by requester_user_id.
+
+ The ordering is arbitrary at the moment
+ """
+
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+
+ if not is_user_in_group:
+ raise SynapseError(403, "User not in group")
+
+ invited_users = yield self.store.get_invited_users_in_group(group_id)
+
+ user_profiles = []
+
+ for user_id in invited_users:
+ user_profile = {
+ "user_id": user_id
+ }
+ try:
+ profile = yield self.profile_handler.get_profile_from_cache(user_id)
+ user_profile.update(profile)
+ except Exception as e:
+ logger.warn("Error getting profile for %s: %s", user_id, e)
+ user_profiles.append(user_profile)
+
+ defer.returnValue({
+ "chunk": user_profiles,
+ "total_user_count_estimate": len(invited_users),
+ })
+
+ @defer.inlineCallbacks
+ def get_rooms_in_group(self, group_id, requester_user_id):
+        """Get the rooms in the group as seen by requester_user_id
+
+ This returns rooms in order of decreasing number of joined users
+ """
+
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+
+ room_results = yield self.store.get_rooms_in_group(
+ group_id, include_private=is_user_in_group,
+ )
+
+ chunk = []
+ for room_result in room_results:
+ room_id = room_result["room_id"]
+
+ joined_users = yield self.store.get_users_in_room(room_id)
+ entry = yield self.room_list_handler.generate_room_entry(
+ room_id, len(joined_users),
+ with_alias=False, allow_private=True,
+ )
+
+ if not entry:
+ continue
+
+ entry["is_public"] = bool(room_result["is_public"])
+
+ chunk.append(entry)
+
+ chunk.sort(key=lambda e: -e["num_joined_members"])
+
+ defer.returnValue({
+ "chunk": chunk,
+ "total_room_count_estimate": len(room_results),
+ })
+
+ @defer.inlineCallbacks
+ def add_room_to_group(self, group_id, requester_user_id, room_id, content):
+ """Add room to group
+ """
+ RoomID.from_string(room_id) # Ensure valid room id
+
+ yield self.check_group_is_ours(
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
+ )
+
+ is_public = _parse_visibility_from_contents(content)
+
+ yield self.store.add_room_to_group(group_id, room_id, is_public=is_public)
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def update_room_in_group(self, group_id, requester_user_id, room_id, config_key,
+ content):
+ """Update room in group
+ """
+ RoomID.from_string(room_id) # Ensure valid room id
+
+ yield self.check_group_is_ours(
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
+ )
+
+ if config_key == "m.visibility":
+ is_public = _parse_visibility_dict(content)
+
+ yield self.store.update_room_in_group_visibility(
+ group_id, room_id,
+ is_public=is_public,
+ )
+ else:
+            raise SynapseError(400, "Unknown config option")
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def remove_room_from_group(self, group_id, requester_user_id, room_id):
+ """Remove room from group
+ """
+ yield self.check_group_is_ours(
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
+ )
+
+ yield self.store.remove_room_from_group(group_id, room_id)
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def invite_to_group(self, group_id, user_id, requester_user_id, content):
+ """Invite user to group
+ """
+
+ group = yield self.check_group_is_ours(
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
+ )
+
+ # TODO: Check if user knocked
+ # TODO: Check if user is already invited
+
+ content = {
+ "profile": {
+ "name": group["name"],
+ "avatar_url": group["avatar_url"],
+ },
+ "inviter": requester_user_id,
+ }
+
+ if self.hs.is_mine_id(user_id):
+ groups_local = self.hs.get_groups_local_handler()
+ res = yield groups_local.on_invite(group_id, user_id, content)
+ local_attestation = None
+ else:
+ local_attestation = self.attestations.create_attestation(group_id, user_id)
+ content.update({
+ "attestation": local_attestation,
+ })
+
+ res = yield self.transport_client.invite_to_group_notification(
+ get_domain_from_id(user_id), group_id, user_id, content
+ )
+
+ user_profile = res.get("user_profile", {})
+ yield self.store.add_remote_profile_cache(
+ user_id,
+ displayname=user_profile.get("displayname"),
+ avatar_url=user_profile.get("avatar_url"),
+ )
+
+ if res["state"] == "join":
+ if not self.hs.is_mine_id(user_id):
+ remote_attestation = res["attestation"]
+
+ yield self.attestations.verify_attestation(
+ remote_attestation,
+ user_id=user_id,
+ group_id=group_id,
+ )
+ else:
+ remote_attestation = None
+
+ yield self.store.add_user_to_group(
+ group_id, user_id,
+ is_admin=False,
+ is_public=False, # TODO
+ local_attestation=local_attestation,
+ remote_attestation=remote_attestation,
+ )
+ elif res["state"] == "invite":
+ yield self.store.add_group_invite(
+ group_id, user_id,
+ )
+ defer.returnValue({
+ "state": "invite"
+ })
+ elif res["state"] == "reject":
+ defer.returnValue({
+ "state": "reject"
+ })
+ else:
+ raise SynapseError(502, "Unknown state returned by HS")
+
+ @defer.inlineCallbacks
+ def _add_user(self, group_id, user_id, content):
+ """Add a user to a group based on a content dict.
+
+ See accept_invite, join_group.
+ """
+ if not self.hs.is_mine_id(user_id):
+ local_attestation = self.attestations.create_attestation(
+ group_id, user_id,
+ )
+
+ remote_attestation = content["attestation"]
+
+ yield self.attestations.verify_attestation(
+ remote_attestation,
+ user_id=user_id,
+ group_id=group_id,
+ )
+ else:
+ local_attestation = None
+ remote_attestation = None
+
+ is_public = _parse_visibility_from_contents(content)
+
+ yield self.store.add_user_to_group(
+ group_id, user_id,
+ is_admin=False,
+ is_public=is_public,
+ local_attestation=local_attestation,
+ remote_attestation=remote_attestation,
+ )
+
+ defer.returnValue(local_attestation)
+
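
For a remote user, the content handled by _add_user is expected to carry the joining server's attestation. An invented example (the field names follow the attestation shape used elsewhere in this PR; every value here is made up):

    # Hypothetical content dict for a remote user joining a local group:
    content = {
        "attestation": {
            "group_id": "+widgets:group.example",
            "user_id": "@alice:remote.example",
            "valid_until_ms": 1514764800000,
            "signatures": {"remote.example": {"ed25519:akey": "..."}},
        },
        "m.visibility": {"type": "public"},
    }
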
+ @defer.inlineCallbacks
+ def accept_invite(self, group_id, requester_user_id, content):
+ """User tries to accept an invite to the group.
+
+ This is different from them asking to join, and so should error if no
+ invite exists (and they're not a member of the group)
+ """
+
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ is_invited = yield self.store.is_user_invited_to_local_group(
+ group_id, requester_user_id,
+ )
+ if not is_invited:
+ raise SynapseError(403, "User not invited to group")
+
+ local_attestation = yield self._add_user(group_id, requester_user_id, content)
+
+ defer.returnValue({
+ "state": "join",
+ "attestation": local_attestation,
+ })
+
+ @defer.inlineCallbacks
+ def join_group(self, group_id, requester_user_id, content):
+ """User tries to join the group.
+
+ This will error if the group requires an invite/knock to join
+ """
+
+ group_info = yield self.check_group_is_ours(
+ group_id, requester_user_id, and_exists=True
+ )
+ if group_info['join_policy'] != "open":
+ raise SynapseError(403, "Group is not publicly joinable")
+
+ local_attestation = yield self._add_user(group_id, requester_user_id, content)
+
+ defer.returnValue({
+ "state": "join",
+ "attestation": local_attestation,
+ })
+
+ @defer.inlineCallbacks
+ def knock(self, group_id, requester_user_id, content):
+ """A user requests becoming a member of the group
+ """
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ raise NotImplementedError()
+
+ @defer.inlineCallbacks
+ def accept_knock(self, group_id, requester_user_id, content):
+        """Accept a user's knock to the group.
+
+ Errors if the user hasn't knocked, rather than inviting them.
+ """
+
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ raise NotImplementedError()
+
+ @defer.inlineCallbacks
+ def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
+ """Remove a user from the group; either a user is leaving or an admin
+ kicked them.
+ """
+
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ is_kick = False
+ if requester_user_id != user_id:
+ is_admin = yield self.store.is_user_admin_in_group(
+ group_id, requester_user_id
+ )
+ if not is_admin:
+ raise SynapseError(403, "User is not admin in group")
+
+ is_kick = True
+
+ yield self.store.remove_user_from_group(
+ group_id, user_id,
+ )
+
+ if is_kick:
+ if self.hs.is_mine_id(user_id):
+ groups_local = self.hs.get_groups_local_handler()
+ yield groups_local.user_removed_from_group(group_id, user_id, {})
+ else:
+ yield self.transport_client.remove_user_from_group_notification(
+ get_domain_from_id(user_id), group_id, user_id, {}
+ )
+
+ if not self.hs.is_mine_id(user_id):
+ yield self.store.maybe_delete_remote_profile_cache(user_id)
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def create_group(self, group_id, requester_user_id, content):
+ group = yield self.check_group_is_ours(group_id, requester_user_id)
+
+ logger.info("Attempting to create group with ID: %r", group_id)
+
+ # parsing the id into a GroupID validates it.
+ group_id_obj = GroupID.from_string(group_id)
+
+ if group:
+ raise SynapseError(400, "Group already exists")
+
+ is_admin = yield self.auth.is_server_admin(UserID.from_string(requester_user_id))
+ if not is_admin:
+ if not self.hs.config.enable_group_creation:
+ raise SynapseError(
+ 403, "Only a server admin can create groups on this server",
+ )
+ localpart = group_id_obj.localpart
+ if not localpart.startswith(self.hs.config.group_creation_prefix):
+ raise SynapseError(
+ 400,
+ "Can only create groups with prefix %r on this server" % (
+ self.hs.config.group_creation_prefix,
+ ),
+ )
+
+ profile = content.get("profile", {})
+ name = profile.get("name")
+ avatar_url = profile.get("avatar_url")
+ short_description = profile.get("short_description")
+ long_description = profile.get("long_description")
+ user_profile = content.get("user_profile", {})
+
+ yield self.store.create_group(
+ group_id,
+ requester_user_id,
+ name=name,
+ avatar_url=avatar_url,
+ short_description=short_description,
+ long_description=long_description,
+ )
+
+ if not self.hs.is_mine_id(requester_user_id):
+ remote_attestation = content["attestation"]
+
+ yield self.attestations.verify_attestation(
+ remote_attestation,
+ user_id=requester_user_id,
+ group_id=group_id,
+ )
+
+ local_attestation = self.attestations.create_attestation(
+ group_id,
+ requester_user_id,
+ )
+ else:
+ local_attestation = None
+ remote_attestation = None
+
+ yield self.store.add_user_to_group(
+ group_id, requester_user_id,
+ is_admin=True,
+ is_public=True, # TODO
+ local_attestation=local_attestation,
+ remote_attestation=remote_attestation,
+ )
+
+ if not self.hs.is_mine_id(requester_user_id):
+ yield self.store.add_remote_profile_cache(
+ requester_user_id,
+ displayname=user_profile.get("displayname"),
+ avatar_url=user_profile.get("avatar_url"),
+ )
+
+ defer.returnValue({
+ "group_id": group_id,
+ })
+
+
+def _parse_join_policy_from_contents(content):
+    """Given the content of a request, return the specified join policy or None
+ """
+
+ join_policy_dict = content.get("m.join_policy")
+ if join_policy_dict:
+ return _parse_join_policy_dict(join_policy_dict)
+ else:
+ return None
+
+
+def _parse_join_policy_dict(join_policy_dict):
+    """Given a dict for the "m.join_policy" config, return the join policy specified
+ """
+ join_policy_type = join_policy_dict.get("type")
+ if not join_policy_type:
+ return "invite"
+
+ if join_policy_type not in ("invite", "open"):
+ raise SynapseError(
+        400, "Synapse only supports 'invite'/'open' join rules"
+ )
+ return join_policy_type
+
+
+def _parse_visibility_from_contents(content):
+    """Given the content of a request, parse out whether the entity should
+    be public or not
+ """
+
+ visibility = content.get("m.visibility")
+ if visibility:
+ return _parse_visibility_dict(visibility)
+    else:
+        return True
+
+
+def _parse_visibility_dict(visibility):
+    """Given a dict for the "m.visibility" config, return whether the entity
+    should be public or not
+ """
+ vis_type = visibility.get("type")
+ if not vis_type:
+ return True
+
+ if vis_type not in ("public", "private"):
+ raise SynapseError(
+ 400, "Synapse only supports 'public'/'private' visibility"
+ )
+ return vis_type == "public"
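
Since the parsing helpers above are pure functions, their contract can be pinned down with a few assertions (runnable against this module as written):

    # Defaults and outcomes of the content parsers above.
    assert _parse_visibility_from_contents({}) is True                 # public by default
    assert _parse_visibility_from_contents(
        {"m.visibility": {"type": "private"}}
    ) is False
    assert _parse_visibility_dict({}) is True                          # missing type -> public
    assert _parse_join_policy_dict({}) == "invite"                     # missing type -> invite
    assert _parse_join_policy_dict({"type": "open"}) == "open"
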
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index 5ad408f549..413425fed1 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -13,17 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .register import RegistrationHandler
-from .room import (
- RoomCreationHandler, RoomContextHandler,
-)
-from .room_member import RoomMemberHandler
-from .message import MessageHandler
-from .federation import FederationHandler
-from .profile import ProfileHandler
-from .directory import DirectoryHandler
from .admin import AdminHandler
+from .directory import DirectoryHandler
+from .federation import FederationHandler
from .identity import IdentityHandler
+from .register import RegistrationHandler
from .search import SearchHandler
@@ -48,13 +42,8 @@ class Handlers(object):
def __init__(self, hs):
self.registration_handler = RegistrationHandler(hs)
- self.message_handler = MessageHandler(hs)
- self.room_creation_handler = RoomCreationHandler(hs)
- self.room_member_handler = RoomMemberHandler(hs)
self.federation_handler = FederationHandler(hs)
- self.profile_handler = ProfileHandler(hs)
self.directory_handler = DirectoryHandler(hs)
self.admin_handler = AdminHandler(hs)
self.identity_handler = IdentityHandler(hs)
self.search_handler = SearchHandler(hs)
- self.room_context_handler = RoomContextHandler(hs)
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index faa5609c0c..704181d2d3 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -18,11 +18,10 @@ import logging
from twisted.internet import defer
import synapse.types
-from synapse.api.constants import Membership, EventTypes
+from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import LimitExceededError
from synapse.types import UserID
-
logger = logging.getLogger(__name__)
@@ -113,15 +112,16 @@ class BaseHandler(object):
guest_access = event.content.get("guest_access", "forbidden")
if guest_access != "can_join":
if context:
+ current_state_ids = yield context.get_current_state_ids(self.store)
current_state = yield self.store.get_events(
- context.current_state_ids.values()
+ list(current_state_ids.values())
)
else:
current_state = yield self.state_handler.get_current_state(
event.room_id
)
- current_state = current_state.values()
+ current_state = list(current_state.values())
logger.info("maybe_kick_guest_users %r", current_state)
yield self.kick_guest_users(current_state)
@@ -158,7 +158,7 @@ class BaseHandler(object):
# homeserver.
requester = synapse.types.create_requester(
target_user, is_guest=True)
- handler = self.hs.get_handlers().room_member_handler
+ handler = self.hs.get_room_member_handler()
yield handler.update_membership(
requester,
target_user,
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index f36b358b45..5d629126fc 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
from ._base import BaseHandler
-import logging
-
logger = logging.getLogger(__name__)
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 05af54d31b..ee41aed69e 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -13,16 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
+from six import itervalues
+
+from prometheus_client import Counter
+
from twisted.internet import defer
+import synapse
from synapse.api.constants import EventTypes
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.metrics import Measure
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-
-import logging
logger = logging.getLogger(__name__)
+events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "")
+
def log_failure(failure):
logger.error(
@@ -70,21 +78,25 @@ class ApplicationServicesHandler(object):
with Measure(self.clock, "notify_interested_services"):
self.is_processing = True
try:
- upper_bound = self.current_max
limit = 100
while True:
upper_bound, events = yield self.store.get_new_events_for_appservice(
- upper_bound, limit
+ self.current_max, limit
)
if not events:
break
+ events_by_room = {}
for event in events:
+ events_by_room.setdefault(event.room_id, []).append(event)
+
+ @defer.inlineCallbacks
+ def handle_event(event):
# Gather interested services
services = yield self._get_services_for_event(event)
if len(services) == 0:
- continue # no services need notifying
+ return # no services need notifying
# Do we know this user exists? If not, poke the user
# query API for all services which match that user regex.
@@ -95,19 +107,39 @@ class ApplicationServicesHandler(object):
yield self._check_user_exists(event.state_key)
if not self.started_scheduler:
- self.scheduler.start().addErrback(log_failure)
+ def start_scheduler():
+ return self.scheduler.start().addErrback(log_failure)
+ run_as_background_process("as_scheduler", start_scheduler)
self.started_scheduler = True
# Fork off pushes to these services
for service in services:
- preserve_fn(self.scheduler.submit_event_for_as)(
- service, event
- )
+ self.scheduler.submit_event_for_as(service, event)
+
+ @defer.inlineCallbacks
+ def handle_room_events(events):
+ for event in events:
+ yield handle_event(event)
+
+ yield make_deferred_yieldable(defer.gatherResults([
+ run_in_background(handle_room_events, evs)
+ for evs in itervalues(events_by_room)
+ ], consumeErrors=True))
yield self.store.set_appservice_last_pos(upper_bound)
- if len(events) < limit:
- break
+ now = self.clock.time_msec()
+ ts = yield self.store.get_received_ts(events[-1].event_id)
+
+ synapse.metrics.event_processing_positions.labels(
+ "appservice_sender").set(upper_bound)
+
+ events_processed_counter.inc(len(events))
+
+ synapse.metrics.event_processing_lag.labels(
+ "appservice_sender").set(now - ts)
+ synapse.metrics.event_processing_last_ts.labels(
+ "appservice_sender").set(ts)
finally:
self.is_processing = False
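
The run_in_background / make_deferred_yieldable(gatherResults(...)) pairing above is the standard Synapse recipe for fanning work out in parallel (here, one task per room) without leaking log contexts. In isolation, with an invented worker, it looks like this:

    # Sketch of the parallel fan-out idiom used above.
    from twisted.internet import defer
    from synapse.util.logcontext import make_deferred_yieldable, run_in_background

    @defer.inlineCallbacks
    def process_rooms(events_by_room, handle_room_events):
        yield make_deferred_yieldable(defer.gatherResults(
            [
                run_in_background(handle_room_events, evs)
                for evs in events_by_room.values()
            ],
            consumeErrors=True,  # one room failing shouldn't cancel the others
        ))
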
@@ -163,8 +195,11 @@ class ApplicationServicesHandler(object):
def query_3pe(self, kind, protocol, fields):
services = yield self._get_services_for_3pn(protocol)
- results = yield preserve_context_over_deferred(defer.DeferredList([
- preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields)
+ results = yield make_deferred_yieldable(defer.DeferredList([
+ run_in_background(
+ self.appservice_api.query_3pe,
+ service, kind, protocol, fields,
+ )
for service in services
], consumeErrors=True))
@@ -225,11 +260,15 @@ class ApplicationServicesHandler(object):
event based on the service regex.
"""
services = self.store.get_app_services()
- interested_list = [
- s for s in services if (
- yield s.is_interested(event, self.store)
- )
- ]
+
+ # we can't use a list comprehension here. Since python 3, list
+ # comprehensions use a generator internally. This means you can't yield
+ # inside of a list comprehension anymore.
+ interested_list = []
+ for s in services:
+ if (yield s.is_interested(event, self.store)):
+ interested_list.append(s)
+
defer.returnValue(interested_list)
def _get_services_for_user(self, user_id):
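
The comment in _get_services_for_event is worth unpacking: in Python 3 a list comprehension is compiled as its own function scope, so a `yield` inside it would suspend that hidden inner function rather than the surrounding inlineCallbacks generator (and from Python 3.8 it is a SyntaxError outright). Side by side:

    # Broken on Python 3: the yield happens in the comprehension's own scope.
    #
    #     interested_list = [
    #         s for s in services if (yield s.is_interested(event, store))
    #     ]
    #
    # Working form: the yield is a statement in the decorated generator.
    #
    #     interested_list = []
    #     for s in services:
    #         if (yield s.is_interested(event, store)):
    #             interested_list.append(s)
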
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index b00446bec0..402e44cdef 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -14,24 +14,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
-from ._base import BaseHandler
-from synapse.api.constants import LoginType
-from synapse.types import UserID
-from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
-from synapse.util.async import run_on_reactor
-from synapse.util.caches.expiringcache import ExpiringCache
-
-from twisted.web.client import PartialDownloadError
-
import logging
+
+import attr
import bcrypt
import pymacaroons
-import simplejson
+from canonicaljson import json
+
+from twisted.internet import defer, threads
+from twisted.web.client import PartialDownloadError
import synapse.util.stringutils as stringutils
+from synapse.api.constants import LoginType
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ InteractiveAuthIncompleteError,
+ LoginError,
+ StoreError,
+ SynapseError,
+)
+from synapse.module_api import ModuleApi
+from synapse.types import UserID
+from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.logcontext import make_deferred_yieldable
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -46,7 +54,6 @@ class AuthHandler(BaseHandler):
"""
super(AuthHandler, self).__init__(hs)
self.checkers = {
- LoginType.PASSWORD: self._check_password_auth,
LoginType.RECAPTCHA: self._check_recaptcha,
LoginType.EMAIL_IDENTITY: self._check_email_identity,
LoginType.MSISDN: self._check_msisdn,
@@ -63,10 +70,7 @@ class AuthHandler(BaseHandler):
reset_expiry_on_get=True,
)
- account_handler = _AccountHandler(
- hs, check_user_exists=self.check_user_exists
- )
-
+ account_handler = ModuleApi(hs, self)
self.password_providers = [
module(config=config, account_handler=account_handler)
for module, config in hs.config.password_providers
@@ -75,39 +79,120 @@ class AuthHandler(BaseHandler):
logger.info("Extra password_providers: %r", self.password_providers)
self.hs = hs # FIXME better possibility to access registrationHandler later?
- self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
+ self._password_enabled = hs.config.password_enabled
+
+ # we keep this as a list despite the O(N^2) implication so that we can
+ # keep PASSWORD first and avoid confusing clients which pick the first
+ # type in the list. (NB that the spec doesn't require us to do so and
+ # clients which favour types that they don't understand over those that
+ # they do are technically broken)
+ login_types = []
+ if self._password_enabled:
+ login_types.append(LoginType.PASSWORD)
+ for provider in self.password_providers:
+ if hasattr(provider, "get_supported_login_types"):
+ for t in provider.get_supported_login_types().keys():
+ if t not in login_types:
+ login_types.append(t)
+ self._supported_login_types = login_types
+
+ @defer.inlineCallbacks
+ def validate_user_via_ui_auth(self, requester, request_body, clientip):
+ """
+ Checks that the user is who they claim to be, via a UI auth.
+
+ This is used for things like device deletion and password reset where
+ the user already has a valid access token, but we want to double-check
+ that it isn't stolen by re-authenticating them.
+
+ Args:
+ requester (Requester): The user, as given by the access token
+
+ request_body (dict): The body of the request sent by the client
+
+ clientip (str): The IP address of the client.
+
+ Returns:
+ defer.Deferred[dict]: the parameters for this request (which may
+ have been given only in a previous call).
+
+ Raises:
+ InteractiveAuthIncompleteError if the client has not yet completed
+ any of the permitted login flows
+
+ AuthError if the client has completed a login flow, and it gives
+ a different user to `requester`
+ """
+
+ # build a list of supported flows
+ flows = [
+ [login_type] for login_type in self._supported_login_types
+ ]
+
+ result, params, _ = yield self.check_auth(
+ flows, request_body, clientip,
+ )
+
+ # find the completed login type
+ for login_type in self._supported_login_types:
+ if login_type not in result:
+ continue
+
+ user_id = result[login_type]
+ break
+ else:
+ # this can't happen
+ raise Exception(
+ "check_auth returned True but no successful login type",
+ )
+
+ # check that the UI auth matched the access token
+ if user_id != requester.user.to_string():
+ raise AuthError(403, "Invalid auth")
+
+ defer.returnValue(params)
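
A sketch of the intended calling convention, with an invented servlet (the parse/auth helpers are the ones synapse servlets normally use; the endpoint itself is hypothetical):

    # Hypothetical servlet using validate_user_via_ui_auth. An
    # InteractiveAuthIncompleteError raised inside it is what ultimately
    # becomes the 401 flows/params response to the client.
    from twisted.internet import defer
    from synapse.http.servlet import parse_json_object_from_request

    class DeleteThingRestServlet(object):
        def __init__(self, hs):
            self.hs = hs
            self.auth = hs.get_auth()
            self.auth_handler = hs.get_auth_handler()

        @defer.inlineCallbacks
        def on_POST(self, request):
            requester = yield self.auth.get_user_by_req(request)
            body = parse_json_object_from_request(request)
            yield self.auth_handler.validate_user_via_ui_auth(
                requester, body, self.hs.get_ip_from_request(request),
            )
            # re-authenticated: safe to perform the sensitive operation here
            defer.returnValue((200, {}))
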
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip):
"""
Takes a dictionary sent by the client in the login / registration
- protocol and handles the login flow.
+ protocol and handles the User-Interactive Auth flow.
As a side effect, this function fills in the 'creds' key on the user's
session with a map, which maps each auth-type (str) to the relevant
identity authenticated by that auth-type (mostly str, but for captcha, bool).
+ If no auth flows have been completed successfully, raises an
+ InteractiveAuthIncompleteError. To handle this, you can use
+ synapse.rest.client.v2_alpha._base.interactive_auth_handler as a
+ decorator.
+
Args:
flows (list): A list of login flows. Each flow is an ordered list of
strings representing auth-types. At least one full
flow must be completed in order for auth to be successful.
+
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
+
clientip (str): The IP address of the client.
+
Returns:
- A tuple of (authed, dict, dict, session_id) where authed is true if
- the client has successfully completed an auth flow. If it is true
- the first dict contains the authenticated credentials of each stage.
+ defer.Deferred[dict, dict, str]: a deferred tuple of
+ (creds, params, session_id).
- If authed is false, the first dictionary is the server response to
- the login request and should be passed back to the client.
+ 'creds' contains the authenticated credentials of each stage.
- In either case, the second dict contains the parameters for this
- request (which may have been given only in a previous call).
+ 'params' contains the parameters for this request (which may
+ have been given only in a previous call).
- session_id is the ID of this session, either passed in by the client
- or assigned by the call to check_auth
+ 'session_id' is the ID of this session, either passed in by the
+ client or assigned by this call
+
+ Raises:
+ InteractiveAuthIncompleteError if the client has not yet completed
+ all the stages in any of the permitted flows.
"""
authdict = None
@@ -135,11 +220,8 @@ class AuthHandler(BaseHandler):
clientdict = session['clientdict']
if not authdict:
- defer.returnValue(
- (
- False, self._auth_dict_for_flows(flows, session),
- clientdict, session['id']
- )
+ raise InteractiveAuthIncompleteError(
+ self._auth_dict_for_flows(flows, session),
)
if 'creds' not in session:
@@ -150,14 +232,12 @@ class AuthHandler(BaseHandler):
errordict = {}
if 'type' in authdict:
login_type = authdict['type']
- if login_type not in self.checkers:
- raise LoginError(400, "", Codes.UNRECOGNIZED)
try:
- result = yield self.checkers[login_type](authdict, clientip)
+ result = yield self._check_auth_dict(authdict, clientip)
if result:
creds[login_type] = result
self._save_session(session)
- except LoginError, e:
+ except LoginError as e:
if login_type == LoginType.EMAIL_IDENTITY:
# riot used to have a bug where it would request a new
# validation token (thus sending a new email) each time it
@@ -166,14 +246,14 @@ class AuthHandler(BaseHandler):
#
# Grandfather in the old behaviour for now to avoid
# breaking old riot deployments.
- raise e
+ raise
# this step failed. Merge the error dict into the response
# so that the client can have another go.
errordict = e.error_dict()
for f in flows:
- if len(set(f) - set(creds.keys())) == 0:
+ if len(set(f) - set(creds)) == 0:
# it's very useful to know what args are stored, but this can
# include the password in the case of registering, so only log
# the keys (confusingly, clientdict may contain a password
@@ -181,14 +261,16 @@ class AuthHandler(BaseHandler):
# and is not sensitive).
logger.info(
"Auth completed with creds: %r. Client dict has keys: %r",
- creds, clientdict.keys()
+ creds, list(clientdict)
)
- defer.returnValue((True, creds, clientdict, session['id']))
+ defer.returnValue((creds, clientdict, session['id']))
ret = self._auth_dict_for_flows(flows, session)
- ret['completed'] = creds.keys()
+ ret['completed'] = list(creds)
ret.update(errordict)
- defer.returnValue((False, ret, clientdict, session['id']))
+ raise InteractiveAuthIncompleteError(
+ ret,
+ )
@defer.inlineCallbacks
def add_oob_auth(self, stagetype, authdict, clientip):
@@ -260,16 +342,37 @@ class AuthHandler(BaseHandler):
sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default)
- def _check_password_auth(self, authdict, _):
- if "user" not in authdict or "password" not in authdict:
- raise LoginError(400, "", Codes.MISSING_PARAM)
+ @defer.inlineCallbacks
+ def _check_auth_dict(self, authdict, clientip):
+ """Attempt to validate the auth dict provided by a client
- user_id = authdict["user"]
- password = authdict["password"]
- if not user_id.startswith('@'):
- user_id = UserID.create(user_id, self.hs.hostname).to_string()
+ Args:
+ authdict (object): auth dict provided by the client
+ clientip (str): IP address of the client
+
+ Returns:
+ Deferred: result of the stage verification.
+
+ Raises:
+ StoreError if there was a problem accessing the database
+ SynapseError if there was a problem with the request
+ LoginError if there was an authentication problem.
+ """
+ login_type = authdict['type']
+ checker = self.checkers.get(login_type)
+ if checker is not None:
+ res = yield checker(authdict, clientip)
+ defer.returnValue(res)
+
+ # build a v1-login-style dict out of the authdict and fall back to the
+ # v1 code
+ user_id = authdict.get("user")
- return self._check_password(user_id, password)
+ if user_id is None:
+ raise SynapseError(400, "", Codes.MISSING_PARAM)
+
+ (canonical_id, callback) = yield self.validate_login(user_id, authdict)
+ defer.returnValue(canonical_id)
@defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip):
@@ -303,7 +406,7 @@ class AuthHandler(BaseHandler):
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
- resp_body = simplejson.loads(data)
+ resp_body = json.loads(data)
if 'success' in resp_body:
# Note that we do NOT check the hostname here: we explicitly
@@ -324,15 +427,11 @@ class AuthHandler(BaseHandler):
def _check_msisdn(self, authdict, _):
return self._check_threepid('msisdn', authdict)
- @defer.inlineCallbacks
def _check_dummy_auth(self, authdict, _):
- yield run_on_reactor()
- defer.returnValue(True)
+ return defer.succeed(True)
@defer.inlineCallbacks
def _check_threepid(self, medium, authdict):
- yield run_on_reactor()
-
if 'threepid_creds' not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
@@ -398,26 +497,8 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id]
- def validate_password_login(self, user_id, password):
- """
- Authenticates the user with their username and password.
-
- Used only by the v1 login API.
-
- Args:
- user_id (str): complete @user:id
- password (str): Password
- Returns:
- defer.Deferred: (str) canonical user id
- Raises:
- StoreError if there was a problem accessing the database
- LoginError if there was an authentication problem.
- """
- return self._check_password(user_id, password)
-
@defer.inlineCallbacks
- def get_access_token_for_user_id(self, user_id, device_id=None,
- initial_display_name=None):
+ def get_access_token_for_user_id(self, user_id, device_id=None):
"""
Creates a new access token for the user with the given user ID.
@@ -431,13 +512,10 @@ class AuthHandler(BaseHandler):
device_id (str|None): the device ID to associate with the tokens.
None to leave the tokens unassociated with a device (deprecated:
we should always have a device ID)
- initial_display_name (str): display name to associate with the
- device if it needs re-registering
Returns:
The access token for the user's session.
Raises:
StoreError if there was a problem storing the token.
- LoginError if there was an authentication problem.
"""
logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id)
@@ -447,9 +525,11 @@ class AuthHandler(BaseHandler):
# really don't want is active access_tokens without a record of the
# device, so we double-check it here.
if device_id is not None:
- yield self.device_handler.check_device_registered(
- user_id, device_id, initial_display_name
- )
+ try:
+ yield self.store.get_device(user_id, device_id)
+ except StoreError:
+ yield self.store.delete_access_token(access_token)
+ raise StoreError(400, "Login raced against device deletion")
defer.returnValue(access_token)
@@ -501,29 +581,115 @@ class AuthHandler(BaseHandler):
)
defer.returnValue(result)
+ def get_supported_login_types(self):
+        """Get the login types supported by the /login API
+
+ By default this is just 'm.login.password' (unless password_enabled is
+ False in the config file), but password auth providers can provide
+ other login types.
+
+ Returns:
+ Iterable[str]: login types
+ """
+ return self._supported_login_types
+
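
For context, the shape of a provider that plugs into this: validate_login below probes providers with hasattr() for get_supported_login_types and check_auth. A minimal invented example:

    # Invented password provider advertising a custom login type.
    from twisted.internet import defer

    class ExampleTokenProvider(object):
        def __init__(self, config, account_handler):
            self._account_handler = account_handler

        def get_supported_login_types(self):
            # maps each login type to the request fields it requires
            return {"com.example.login_token": ("token",)}

        def check_auth(self, username, login_type, login_dict):
            # Stand-in check; a real provider would consult its own backend.
            if login_dict.get("token") == "s3cret":
                # may return a user_id, or a (user_id, callback) tuple
                return defer.succeed("@%s:example.com" % (username,))
            return defer.succeed(False)
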
@defer.inlineCallbacks
- def _check_password(self, user_id, password):
- """Authenticate a user against the LDAP and local databases.
+ def validate_login(self, username, login_submission):
+ """Authenticates the user for the /login API
- user_id is checked case insensitively against the local database, but
- will throw if there are multiple inexact matches.
+ Also used by the user-interactive auth flow to validate
+ m.login.password auth types.
Args:
- user_id (str): complete @user:id
+ username (str): username supplied by the user
+ login_submission (dict): the whole of the login submission
+ (including 'type' and other relevant fields)
Returns:
- (str) the canonical_user_id
+ Deferred[str, func]: canonical user id, and optional callback
+ to be called once the access token and device id are issued
Raises:
- LoginError if login fails
+ StoreError if there was a problem accessing the database
+ SynapseError if there was a problem with the request
+ LoginError if there was an authentication problem.
"""
+
+ if username.startswith('@'):
+ qualified_user_id = username
+ else:
+ qualified_user_id = UserID(
+ username, self.hs.hostname
+ ).to_string()
+
+ login_type = login_submission.get("type")
+ known_login_type = False
+
+ # special case to check for "password" for the check_password interface
+ # for the auth providers
+ password = login_submission.get("password")
+ if login_type == LoginType.PASSWORD:
+ if not self._password_enabled:
+ raise SynapseError(400, "Password login has been disabled.")
+ if not password:
+ raise SynapseError(400, "Missing parameter: password")
+
for provider in self.password_providers:
- is_valid = yield provider.check_password(user_id, password)
- if is_valid:
- defer.returnValue(user_id)
+ if (hasattr(provider, "check_password")
+ and login_type == LoginType.PASSWORD):
+ known_login_type = True
+ is_valid = yield provider.check_password(
+ qualified_user_id, password,
+ )
+ if is_valid:
+ defer.returnValue((qualified_user_id, None))
+
+ if (not hasattr(provider, "get_supported_login_types")
+ or not hasattr(provider, "check_auth")):
+ # this password provider doesn't understand custom login types
+ continue
+
+ supported_login_types = provider.get_supported_login_types()
+ if login_type not in supported_login_types:
+ # this password provider doesn't understand this login type
+ continue
+
+ known_login_type = True
+ login_fields = supported_login_types[login_type]
+
+ missing_fields = []
+ login_dict = {}
+ for f in login_fields:
+ if f not in login_submission:
+ missing_fields.append(f)
+ else:
+ login_dict[f] = login_submission[f]
+ if missing_fields:
+ raise SynapseError(
+ 400, "Missing parameters for login type %s: %s" % (
+ login_type,
+ missing_fields,
+ ),
+ )
+
+ result = yield provider.check_auth(
+ username, login_type, login_dict,
+ )
+ if result:
+ if isinstance(result, str):
+ result = (result, None)
+ defer.returnValue(result)
+
+ if login_type == LoginType.PASSWORD:
+ known_login_type = True
+
+ canonical_user_id = yield self._check_local_password(
+ qualified_user_id, password,
+ )
- canonical_user_id = yield self._check_local_password(user_id, password)
+ if canonical_user_id:
+ defer.returnValue((canonical_user_id, None))
- if canonical_user_id:
- defer.returnValue(canonical_user_id)
+ if not known_login_type:
+ raise SynapseError(400, "Unknown login type %s" % login_type)
# unknown username or invalid password. We raise a 403 here, but note
# that if we're doing user-interactive login, it turns all LoginErrors
@@ -549,7 +715,7 @@ class AuthHandler(BaseHandler):
if not lookupres:
defer.returnValue(None)
(user_id, password_hash) = lookupres
- result = self.validate_hash(password, password_hash)
+ result = yield self.validate_hash(password, password_hash)
if not result:
logger.warn("Failed password login for user %s", user_id)
defer.returnValue(None)
@@ -573,22 +739,65 @@ class AuthHandler(BaseHandler):
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
- def set_password(self, user_id, newpassword, requester=None):
- password_hash = self.hash(newpassword)
+ def delete_access_token(self, access_token):
+ """Invalidate a single access token
- except_access_token_id = requester.access_token_id if requester else None
+ Args:
+ access_token (str): access token to be deleted
- try:
- yield self.store.user_set_password_hash(user_id, password_hash)
- except StoreError as e:
- if e.code == 404:
- raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
- raise e
- yield self.store.user_delete_access_tokens(
- user_id, except_access_token_id
+ Returns:
+ Deferred
+ """
+ user_info = yield self.auth.get_user_by_access_token(access_token)
+ yield self.store.delete_access_token(access_token)
+
+ # see if any of our auth providers want to know about this
+ for provider in self.password_providers:
+ if hasattr(provider, "on_logged_out"):
+ yield provider.on_logged_out(
+ user_id=str(user_info["user"]),
+ device_id=user_info["device_id"],
+ access_token=access_token,
+ )
+
+ # delete pushers associated with this access token
+ if user_info["token_id"] is not None:
+ yield self.hs.get_pusherpool().remove_pushers_by_access_token(
+ str(user_info["user"]), (user_info["token_id"], )
+ )
+
+ @defer.inlineCallbacks
+ def delete_access_tokens_for_user(self, user_id, except_token_id=None,
+ device_id=None):
+ """Invalidate access tokens belonging to a user
+
+ Args:
+ user_id (str): ID of user the tokens belong to
+ except_token_id (str|None): access_token ID which should *not* be
+ deleted
+ device_id (str|None): ID of device the tokens are associated with.
+ If None, tokens associated with any device (or no device) will
+ be deleted
+ Returns:
+ Deferred
+ """
+ tokens_and_devices = yield self.store.user_delete_access_tokens(
+ user_id, except_token_id=except_token_id, device_id=device_id,
)
- yield self.hs.get_pusherpool().remove_pushers_by_user(
- user_id, except_access_token_id
+
+ # see if any of our auth providers want to know about this
+ for provider in self.password_providers:
+ if hasattr(provider, "on_logged_out"):
+ for token, token_id, device_id in tokens_and_devices:
+ yield provider.on_logged_out(
+ user_id=user_id,
+ device_id=device_id,
+ access_token=token,
+ )
+
+ # delete pushers associated with the access tokens
+ yield self.hs.get_pusherpool().remove_pushers_by_access_token(
+ user_id, (token_id for _, token_id, _ in tokens_and_devices),
)
@defer.inlineCallbacks
@@ -616,6 +825,15 @@ class AuthHandler(BaseHandler):
if medium == 'email':
address = address.lower()
+ identity_handler = self.hs.get_handlers().identity_handler
+ yield identity_handler.unbind_threepid(
+ user_id,
+ {
+ 'medium': medium,
+ 'address': address,
+ },
+ )
+
ret = yield self.store.user_delete_threepid(
user_id, medium, address,
)
@@ -634,10 +852,17 @@ class AuthHandler(BaseHandler):
password (str): Password to hash.
Returns:
- Hashed password (str).
+ Deferred(str): Hashed password.
"""
- return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
- bcrypt.gensalt(self.bcrypt_rounds))
+ def _do_hash():
+ return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
+ bcrypt.gensalt(self.bcrypt_rounds))
+
+ return make_deferred_yieldable(
+ threads.deferToThreadPool(
+ self.hs.get_reactor(), self.hs.get_reactor().getThreadPool(), _do_hash
+ ),
+ )
def validate_hash(self, password, stored_hash):
"""Validates that self.hash(password) == stored_hash.
@@ -647,20 +872,31 @@ class AuthHandler(BaseHandler):
stored_hash (str): Expected hash value.
Returns:
- Whether self.hash(password) == stored_hash (bool).
+ Deferred(bool): Whether self.hash(password) == stored_hash.
"""
+
+ def _do_validate_hash():
+ return bcrypt.checkpw(
+ password.encode('utf8') + self.hs.config.password_pepper,
+ stored_hash.encode('utf8')
+ )
+
if stored_hash:
- return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
- stored_hash.encode('utf8')) == stored_hash
+ return make_deferred_yieldable(
+ threads.deferToThreadPool(
+ self.hs.get_reactor(),
+ self.hs.get_reactor().getThreadPool(),
+ _do_validate_hash,
+ ),
+ )
else:
- return False
+ return defer.succeed(False)
-class MacaroonGeneartor(object):
- def __init__(self, hs):
- self.clock = hs.get_clock()
- self.server_name = hs.config.server_name
- self.macaroon_secret_key = hs.config.macaroon_secret_key
+@attr.s
+class MacaroonGenerator(object):
+
+ hs = attr.ib()
def generate_access_token(self, user_id, extra_caveats=None):
extra_caveats = extra_caveats or []
@@ -678,7 +914,7 @@ class MacaroonGeneartor(object):
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
- now = self.clock.time_msec()
+ now = self.hs.get_clock().time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
@@ -690,36 +926,9 @@ class MacaroonGeneartor(object):
def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon(
- location=self.server_name,
+ location=self.hs.config.server_name,
identifier="key",
- key=self.macaroon_secret_key)
+ key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
-
-
-class _AccountHandler(object):
- """A proxy object that gets passed to password auth providers so they
- can register new users etc if necessary.
- """
- def __init__(self, hs, check_user_exists):
- self.hs = hs
-
- self._check_user_exists = check_user_exists
-
- def check_user_exists(self, user_id):
- """Check if user exissts.
-
- Returns:
- Deferred(bool)
- """
- return self._check_user_exists(user_id)
-
- def register(self, localpart):
- """Registers a new user with given localpart
-
- Returns:
- Deferred: a 2-tuple of (user_id, access_token)
- """
- reg = self.hs.get_handlers().registration_handler
- return reg.register(localpart=localpart)
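
The caveats minted above ("gen = 1", "user_id = ...", "type = login", "time < ...") are what any verifier of these tokens has to satisfy. A sketch of checking a short-term login token with pymacaroons (key and user invented; synapse's real verification lives in the auth handler):

    # Illustrative verification of a short-term login token.
    import time

    import pymacaroons

    def verify_login_token(serialized, secret_key, expected_user):
        macaroon = pymacaroons.Macaroon.deserialize(serialized)
        v = pymacaroons.Verifier()
        v.satisfy_exact("gen = 1")
        v.satisfy_exact("type = login")
        v.satisfy_exact("user_id = %s" % (expected_user,))
        # "time < N" is checked as a general caveat: the deadline must be
        # in the future.
        v.satisfy_general(
            lambda c: c.startswith("time < ")
            and int(c[len("time < "):]) > int(time.time() * 1000)
        )
        return v.verify(macaroon, secret_key)
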
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
new file mode 100644
index 0000000000..b3c5a9ee64
--- /dev/null
+++ b/synapse/handlers/deactivate_account.py
@@ -0,0 +1,163 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017, 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.types import UserID, create_requester
+from synapse.util.logcontext import run_in_background
+
+from ._base import BaseHandler
+
+logger = logging.getLogger(__name__)
+
+
+class DeactivateAccountHandler(BaseHandler):
+ """Handler which deals with deactivating user accounts."""
+ def __init__(self, hs):
+ super(DeactivateAccountHandler, self).__init__(hs)
+ self._auth_handler = hs.get_auth_handler()
+ self._device_handler = hs.get_device_handler()
+ self._room_member_handler = hs.get_room_member_handler()
+ self._identity_handler = hs.get_handlers().identity_handler
+ self.user_directory_handler = hs.get_user_directory_handler()
+
+ # Flag that indicates whether the process to part users from rooms is running
+ self._user_parter_running = False
+
+ # Start the user parter loop so it can resume parting users from rooms where
+ # it left off (if it has work left to do).
+ hs.get_reactor().callWhenRunning(self._start_user_parting)
+
+ @defer.inlineCallbacks
+ def deactivate_account(self, user_id, erase_data):
+ """Deactivate a user's account
+
+ Args:
+ user_id (str): ID of user to be deactivated
+ erase_data (bool): whether to GDPR-erase the user's data
+
+ Returns:
+ Deferred
+ """
+        # FIXME: Theoretically there is a race here wherein a user resets
+        # their password using a threepid.
+
+        # Delete threepids first. We remove these from the identity server (IS)
+        # before deactivating anything else, so that if this step fails the
+        # user is left active and can try again.
+ # Ideally we would prevent password resets and then do this in the
+ # background thread.
+ threepids = yield self.store.user_get_threepids(user_id)
+ for threepid in threepids:
+ try:
+ yield self._identity_handler.unbind_threepid(
+ user_id,
+ {
+ 'medium': threepid['medium'],
+ 'address': threepid['address'],
+ },
+ )
+ except Exception:
+ # Do we want this to be a fatal error or should we carry on?
+ logger.exception("Failed to remove threepid from ID server")
+ raise SynapseError(400, "Failed to remove threepid from ID server")
+ yield self.store.user_delete_threepid(
+ user_id, threepid['medium'], threepid['address'],
+ )
+
+ # delete any devices belonging to the user, which will also
+ # delete corresponding access tokens.
+ yield self._device_handler.delete_all_devices_for_user(user_id)
+ # then delete any remaining access tokens which weren't associated with
+ # a device.
+ yield self._auth_handler.delete_access_tokens_for_user(user_id)
+
+ yield self.store.user_set_password_hash(user_id, None)
+
+        # Add the user to a table of users pending deactivation (i.e.
+        # removal from all the rooms they're a member of)
+ yield self.store.add_user_pending_deactivation(user_id)
+
+ # delete from user directory
+ yield self.user_directory_handler.handle_user_deactivated(user_id)
+
+ # Mark the user as erased, if they asked for that
+ if erase_data:
+ logger.info("Marking %s as erased", user_id)
+ yield self.store.mark_user_erased(user_id)
+
+ # Now start the process that goes through that list and
+ # parts users from rooms (if it isn't already running)
+ self._start_user_parting()
+
+ def _start_user_parting(self):
+ """
+ Start the process that goes through the table of users
+ pending deactivation, if it isn't already running.
+
+ Returns:
+ None
+ """
+ if not self._user_parter_running:
+ run_in_background(self._user_parter_loop)
+
+ @defer.inlineCallbacks
+ def _user_parter_loop(self):
+ """Loop that parts deactivated users from rooms
+
+ Returns:
+ None
+ """
+ self._user_parter_running = True
+ logger.info("Starting user parter")
+ try:
+ while True:
+ user_id = yield self.store.get_user_pending_deactivation()
+ if user_id is None:
+ break
+ logger.info("User parter parting %r", user_id)
+ yield self._part_user(user_id)
+ yield self.store.del_user_pending_deactivation(user_id)
+ logger.info("User parter finished parting %r", user_id)
+ logger.info("User parter finished: stopping")
+ finally:
+ self._user_parter_running = False
+
+ @defer.inlineCallbacks
+ def _part_user(self, user_id):
+ """Causes the given user_id to leave all the rooms they're joined to
+
+ Returns:
+ None
+ """
+ user = UserID.from_string(user_id)
+
+ rooms_for_user = yield self.store.get_rooms_for_user(user_id)
+ for room_id in rooms_for_user:
+ logger.info("User parter parting %r from %r", user_id, room_id)
+ try:
+ yield self._room_member_handler.update_membership(
+ create_requester(user),
+ user,
+ room_id,
+ "leave",
+ ratelimit=False,
+ )
+ except Exception:
+ logger.exception(
+ "Failed to part user %r from room %r: ignoring and continuing",
+ user_id, room_id,
+ )
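The new handler follows a reusable shape: a table of pending work, a running-flag guard, and a drain loop kicked both at startup (callWhenRunning) and whenever new work is queued. A stripped-down sketch of that shape, with hypothetical store methods:

    from twisted.internet import defer
    from twisted.python import log

    class PendingWorkLoop(object):
        """Guard-flag plus drain-loop pattern, as used by
        DeactivateAccountHandler._start_user_parting above."""

        def __init__(self, store, process_one):
            self._store = store          # assumed: get_next_pending(), mark_done(item)
            self._process_one = process_one
            self._running = False

        def start(self):
            # Idempotent kick: safe to call on startup and on every enqueue.
            if not self._running:
                self._loop().addErrback(log.err)

        @defer.inlineCallbacks
        def _loop(self):
            self._running = True
            try:
                while True:
                    item = yield self._store.get_next_pending()
                    if item is None:
                        break
                    yield self._process_one(item)
                    yield self._store.mark_done(item)
            finally:
                self._running = False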
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index ed60d494ff..2d44f15da3 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -12,18 +12,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
+from six import iteritems, itervalues
+
+from twisted.internet import defer
+
from synapse.api import errors
from synapse.api.constants import EventTypes
+from synapse.api.errors import FederationDeniedError
+from synapse.types import RoomStreamToken, get_domain_from_id
from synapse.util import stringutils
from synapse.util.async import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.retryutils import NotRetryingDestination
from synapse.util.metrics import measure_func
-from synapse.types import get_domain_from_id, RoomStreamToken
-from twisted.internet import defer
-from ._base import BaseHandler
+from synapse.util.retryutils import NotRetryingDestination
-import logging
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -34,15 +39,17 @@ class DeviceHandler(BaseHandler):
self.hs = hs
self.state = hs.get_state_handler()
+ self._auth_handler = hs.get_auth_handler()
self.federation_sender = hs.get_federation_sender()
- self.federation = hs.get_replication_layer()
self._edu_updater = DeviceListEduUpdater(hs, self)
- self.federation.register_edu_handler(
+ federation_registry = hs.get_federation_registry()
+
+ federation_registry.register_edu_handler(
"m.device_list_update", self._edu_updater.incoming_device_list_update,
)
- self.federation.register_query_handler(
+ federation_registry.register_query_handler(
"user_devices", self.on_federation_query_user_devices,
)
@@ -109,7 +116,7 @@ class DeviceHandler(BaseHandler):
user_id, device_id=None
)
- devices = device_map.values()
+ devices = list(device_map.values())
for device in devices:
_update_device_from_client_ips(device, ips)
@@ -152,16 +159,15 @@ class DeviceHandler(BaseHandler):
try:
yield self.store.delete_device(user_id, device_id)
- except errors.StoreError, e:
+ except errors.StoreError as e:
if e.code == 404:
# no match
pass
else:
raise
- yield self.store.user_delete_access_tokens(
+ yield self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id,
- delete_refresh_tokens=True,
)
yield self.store.delete_e2e_keys_by_device(
@@ -171,12 +177,30 @@ class DeviceHandler(BaseHandler):
yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks
+ def delete_all_devices_for_user(self, user_id, except_device_id=None):
+ """Delete all of the user's devices
+
+ Args:
+ user_id (str):
+ except_device_id (str|None): optional device id which should not
+ be deleted
+
+ Returns:
+ defer.Deferred:
+ """
+ device_map = yield self.store.get_devices_by_user(user_id)
+ device_ids = list(device_map)
+ if except_device_id is not None:
+ device_ids = [d for d in device_ids if d != except_device_id]
+ yield self.delete_devices(user_id, device_ids)
+
+ @defer.inlineCallbacks
def delete_devices(self, user_id, device_ids):
""" Delete several devices
Args:
user_id (str):
- device_ids (str): The list of device IDs to delete
+ device_ids (List[str]): The list of device IDs to delete
Returns:
defer.Deferred:
@@ -184,7 +208,7 @@ class DeviceHandler(BaseHandler):
try:
yield self.store.delete_devices(user_id, device_ids)
- except errors.StoreError, e:
+ except errors.StoreError as e:
if e.code == 404:
# no match
pass
@@ -194,9 +218,8 @@ class DeviceHandler(BaseHandler):
# Delete access tokens and e2e keys for each device. Not optimised as it is not
# considered as part of a critical path.
for device_id in device_ids:
- yield self.store.user_delete_access_tokens(
+ yield self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id,
- delete_refresh_tokens=True,
)
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
@@ -224,7 +247,7 @@ class DeviceHandler(BaseHandler):
new_display_name=content.get("display_name")
)
yield self.notify_device_update(user_id, [device_id])
- except errors.StoreError, e:
+ except errors.StoreError as e:
if e.code == 404:
raise errors.NotFoundError()
else:
@@ -270,6 +293,8 @@ class DeviceHandler(BaseHandler):
user_id (str)
from_token (StreamToken)
"""
+ now_token = yield self.hs.get_event_sources().get_current_token()
+
room_ids = yield self.store.get_rooms_for_user(user_id)
# First we check if any devices have changed
@@ -280,11 +305,30 @@ class DeviceHandler(BaseHandler):
# Then work out if any users have since joined
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
+ member_events = yield self.store.get_membership_changes_for_user(
+ user_id, from_token.room_key, now_token.room_key
+ )
+ rooms_changed.update(event.room_id for event in member_events)
+
stream_ordering = RoomStreamToken.parse_stream_token(
- from_token.room_key).stream
+ from_token.room_key
+ ).stream
possibly_changed = set(changed)
+ possibly_left = set()
for room_id in rooms_changed:
+ current_state_ids = yield self.store.get_current_state_ids(room_id)
+
+ # The user may have left the room
+ # TODO: Check if they actually did or if we were just invited.
+ if room_id not in room_ids:
+ for key, event_id in iteritems(current_state_ids):
+ etype, state_key = key
+ if etype != EventTypes.Member:
+ continue
+ possibly_left.add(state_key)
+ continue
+
# Fetch the current state at the time.
try:
event_ids = yield self.store.get_forward_extremeties_for_room(
@@ -295,44 +339,69 @@ class DeviceHandler(BaseHandler):
# ordering: treat it the same as a new room
event_ids = []
- current_state_ids = yield self.store.get_current_state_ids(room_id)
-
# special-case for an empty prev state: include all members
# in the changed list
if not event_ids:
- for key, event_id in current_state_ids.iteritems():
+ for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
possibly_changed.add(state_key)
continue
+ current_member_id = current_state_ids.get((EventTypes.Member, user_id))
+ if not current_member_id:
+ continue
+
# mapping from event_id -> state_dict
prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
+            # Check whether we've just joined the room; if so, we blindly add
+            # all the users to the "possibly changed" set.
+ for state_dict in itervalues(prev_state_ids):
+ member_event = state_dict.get((EventTypes.Member, user_id), None)
+ if not member_event or member_event != current_member_id:
+ for key, event_id in iteritems(current_state_ids):
+ etype, state_key = key
+ if etype != EventTypes.Member:
+ continue
+ possibly_changed.add(state_key)
+ break
+
# If there has been any change in membership, include them in the
# possibly changed list. We'll check if they are joined below,
# and we're not toooo worried about spuriously adding users.
- for key, event_id in current_state_ids.iteritems():
+ for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
# check if this member has changed since any of the extremities
# at the stream_ordering, and add them to the list if so.
- for state_dict in prev_state_ids.values():
+ for state_dict in itervalues(prev_state_ids):
prev_event_id = state_dict.get(key, None)
if not prev_event_id or prev_event_id != event_id:
- possibly_changed.add(state_key)
+ if state_key != user_id:
+ possibly_changed.add(state_key)
break
- users_who_share_room = yield self.store.get_users_who_share_room_with_user(
- user_id
- )
+ if possibly_changed or possibly_left:
+ users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+ user_id
+ )
- # Take the intersection of the users whose devices may have changed
- # and those that actually still share a room with the user
- defer.returnValue(users_who_share_room & possibly_changed)
+ # Take the intersection of the users whose devices may have changed
+ # and those that actually still share a room with the user
+ possibly_joined = possibly_changed & users_who_share_room
+ possibly_left = (possibly_changed | possibly_left) - users_who_share_room
+ else:
+ possibly_joined = []
+ possibly_left = []
+
+ defer.returnValue({
+ "changed": list(possibly_joined),
+ "left": list(possibly_left),
+ })
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
@@ -366,7 +435,7 @@ class DeviceListEduUpdater(object):
def __init__(self, hs, device_handler):
self.store = hs.get_datastore()
- self.federation = hs.get_replication_layer()
+ self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
self.device_handler = device_handler
@@ -450,6 +519,9 @@ class DeviceListEduUpdater(object):
# This makes it more likely that the device lists will
# eventually become consistent.
return
+ except FederationDeniedError as e:
+ logger.info(e)
+ return
except Exception:
# TODO: Remember that we are now out of sync and try again
# later
@@ -467,7 +539,7 @@ class DeviceListEduUpdater(object):
yield self.device_handler.notify_device_update(user_id, device_ids)
else:
# Simply update the single device, since we know that is the only
- # change (becuase of the single prev_id matching the current cache)
+ # change (because of the single prev_id matching the current cache)
for device_id, stream_id, prev_ids, content in pending_updates:
yield self.store.update_remote_device_list_cache_entry(
user_id, device_id, content, stream_id,
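The new {"changed", "left"} result of the device-list delta computation above reduces to plain set algebra once the candidate sets are built: only users we still share a room with are reported as "changed", and everyone else we were tracking is reported as "left". A small standalone sketch of that final step (hypothetical helper name):

    def partition_device_updates(possibly_changed, possibly_left, sharers):
        # sharers: users who currently share at least one room with us.
        if not (possibly_changed or possibly_left):
            return {"changed": [], "left": []}
        changed = possibly_changed & sharers
        left = (possibly_changed | possibly_left) - sharers
        return {"changed": list(changed), "left": list(left)}

    # e.g. partition_device_updates({"@a:x", "@b:y"}, {"@c:z"}, {"@a:x"})
    # -> {"changed": ["@a:x"], "left": ["@b:y", "@c:z"]} (list order unspecified)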
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index f7fad15c62..2e2e5261de 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -17,10 +17,10 @@ import logging
from twisted.internet import defer
-from synapse.types import get_domain_from_id
+from synapse.api.errors import SynapseError
+from synapse.types import UserID, get_domain_from_id
from synapse.util.stringutils import random_string
-
logger = logging.getLogger(__name__)
@@ -33,10 +33,10 @@ class DeviceMessageHandler(object):
"""
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
- self.is_mine_id = hs.is_mine_id
+ self.is_mine = hs.is_mine
self.federation = hs.get_federation_sender()
- hs.get_replication_layer().register_edu_handler(
+ hs.get_federation_registry().register_edu_handler(
"m.direct_to_device", self.on_direct_to_device_edu
)
@@ -52,6 +52,12 @@ class DeviceMessageHandler(object):
message_type = content["type"]
message_id = content["message_id"]
for user_id, by_device in content["messages"].items():
+ # we use UserID.from_string to catch invalid user ids
+ if not self.is_mine(UserID.from_string(user_id)):
+ logger.warning("Request for keys for non-local user %s",
+ user_id)
+ raise SynapseError(400, "Not a user here")
+
messages_by_device = {
device_id: {
"content": message_content,
@@ -77,7 +83,8 @@ class DeviceMessageHandler(object):
local_messages = {}
remote_messages = {}
for user_id, by_device in messages.items():
- if self.is_mine_id(user_id):
+ # we use UserID.from_string to catch invalid user ids
+ if self.is_mine(UserID.from_string(user_id)):
messages_by_device = {
device_id: {
"content": message_content,
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 943554ce98..ef866da1b6 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -14,15 +14,16 @@
# limitations under the License.
+import logging
+import string
+
from twisted.internet import defer
-from ._base import BaseHandler
-from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError
from synapse.api.constants import EventTypes
+from synapse.api.errors import AuthError, CodeMessageException, Codes, SynapseError
from synapse.types import RoomAlias, UserID, get_domain_from_id
-import logging
-import string
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -34,12 +35,15 @@ class DirectoryHandler(BaseHandler):
self.state = hs.get_state_handler()
self.appservice_handler = hs.get_application_service_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
- self.federation = hs.get_replication_layer()
- self.federation.register_query_handler(
+ self.federation = hs.get_federation_client()
+ hs.get_federation_registry().register_query_handler(
"directory", self.on_directory_query
)
+ self.spam_checker = hs.get_spam_checker()
+
@defer.inlineCallbacks
def _create_association(self, room_alias, room_id, servers=None, creator=None):
# general association creation for both human users and app services
@@ -73,6 +77,11 @@ class DirectoryHandler(BaseHandler):
# association creation for human users
# TODO(erikj): Do user auth.
+ if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
+ raise SynapseError(
+ 403, "This user is not permitted to create this alias",
+ )
+
can_create = yield self.can_modify_alias(
room_alias,
user_id=user_id
@@ -242,8 +251,7 @@ class DirectoryHandler(BaseHandler):
def send_room_alias_update_event(self, requester, user_id, room_id):
aliases = yield self.store.get_aliases_for_room(room_id)
- msg_handler = self.hs.get_handlers().message_handler
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Aliases,
@@ -265,8 +273,7 @@ class DirectoryHandler(BaseHandler):
if not alias_event or alias_event.content.get("alias", "") != alias_str:
return
- msg_handler = self.hs.get_handlers().message_handler
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.CanonicalAlias,
@@ -327,6 +334,14 @@ class DirectoryHandler(BaseHandler):
room_id (str)
visibility (str): "public" or "private"
"""
+ if not self.spam_checker.user_may_publish_room(
+ requester.user.to_string(), room_id
+ ):
+ raise AuthError(
+ 403,
+ "This user is not permitted to publish rooms to the room list"
+ )
+
if requester.is_guest:
raise AuthError(403, "Guests cannot edit the published room list")
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 668a90e495..5816bf8b4f 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,15 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import ujson as json
import logging
-from canonicaljson import encode_canonical_json
+from six import iteritems
+
+from canonicaljson import encode_canonical_json, json
+
from twisted.internet import defer
-from synapse.api.errors import SynapseError, CodeMessageException
-from synapse.types import get_domain_from_id
-from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
+from synapse.api.errors import CodeMessageException, FederationDeniedError, SynapseError
+from synapse.types import UserID, get_domain_from_id
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
@@ -30,15 +33,15 @@ logger = logging.getLogger(__name__)
class E2eKeysHandler(object):
def __init__(self, hs):
self.store = hs.get_datastore()
- self.federation = hs.get_replication_layer()
+ self.federation = hs.get_federation_client()
self.device_handler = hs.get_device_handler()
- self.is_mine_id = hs.is_mine_id
+ self.is_mine = hs.is_mine
self.clock = hs.get_clock()
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
- self.federation.register_query_handler(
+ hs.get_federation_registry().register_query_handler(
"client_keys", self.on_federation_query_client_keys
)
@@ -70,12 +73,13 @@ class E2eKeysHandler(object):
remote_queries = {}
for user_id, device_ids in device_keys_query.items():
- if self.is_mine_id(user_id):
+ # we use UserID.from_string to catch invalid user ids
+ if self.is_mine(UserID.from_string(user_id)):
local_query[user_id] = device_ids
else:
remote_queries[user_id] = device_ids
- # Firt get local devices.
+ # First get local devices.
failures = {}
results = {}
if local_query:
@@ -88,7 +92,7 @@ class E2eKeysHandler(object):
remote_queries_not_in_cache = {}
if remote_queries:
query_list = []
- for user_id, device_ids in remote_queries.iteritems():
+ for user_id, device_ids in iteritems(remote_queries):
if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids)
else:
@@ -99,9 +103,9 @@ class E2eKeysHandler(object):
query_list
)
)
- for user_id, devices in remote_results.iteritems():
+ for user_id, devices in iteritems(remote_results):
user_devices = results.setdefault(user_id, {})
- for device_id, device in devices.iteritems():
+ for device_id, device in iteritems(devices):
keys = device.get("keys", None)
device_display_name = device.get("device_display_name", None)
if keys:
@@ -131,24 +135,13 @@ class E2eKeysHandler(object):
if user_id in destination_query:
results[user_id] = keys
- except CodeMessageException as e:
- failures[destination] = {
- "status": e.code, "message": e.message
- }
- except NotRetryingDestination as e:
- failures[destination] = {
- "status": 503, "message": "Not ready for retry",
- }
except Exception as e:
- # include ConnectionRefused and other errors
- failures[destination] = {
- "status": 503, "message": e.message
- }
+ failures[destination] = _exception_to_failure(e)
yield make_deferred_yieldable(defer.gatherResults([
- preserve_fn(do_remote_query)(destination)
+ run_in_background(do_remote_query, destination)
for destination in remote_queries_not_in_cache
- ]))
+ ], consumeErrors=True))
defer.returnValue({
"device_keys": results, "failures": failures,
@@ -170,7 +163,8 @@ class E2eKeysHandler(object):
result_dict = {}
for user_id, device_ids in query.items():
- if not self.is_mine_id(user_id):
+ # we use UserID.from_string to catch invalid user ids
+ if not self.is_mine(UserID.from_string(user_id)):
logger.warning("Request for keys for non-local user %s",
user_id)
raise SynapseError(400, "Not a user here")
@@ -213,7 +207,8 @@ class E2eKeysHandler(object):
remote_queries = {}
for user_id, device_keys in query.get("one_time_keys", {}).items():
- if self.is_mine_id(user_id):
+ # we use UserID.from_string to catch invalid user ids
+ if self.is_mine(UserID.from_string(user_id)):
for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
@@ -243,32 +238,21 @@ class E2eKeysHandler(object):
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
- except CodeMessageException as e:
- failures[destination] = {
- "status": e.code, "message": e.message
- }
- except NotRetryingDestination as e:
- failures[destination] = {
- "status": 503, "message": "Not ready for retry",
- }
except Exception as e:
- # include ConnectionRefused and other errors
- failures[destination] = {
- "status": 503, "message": e.message
- }
+ failures[destination] = _exception_to_failure(e)
yield make_deferred_yieldable(defer.gatherResults([
- preserve_fn(claim_client_keys)(destination)
+ run_in_background(claim_client_keys, destination)
for destination in remote_queries
- ]))
+ ], consumeErrors=True))
logger.info(
"Claimed one-time-keys: %s",
",".join((
"%s for %s:%s" % (key_id, user_id, device_id)
- for user_id, user_keys in json_result.iteritems()
- for device_id, device_keys in user_keys.iteritems()
- for key_id, _ in device_keys.iteritems()
+ for user_id, user_keys in iteritems(json_result)
+ for device_id, device_keys in iteritems(user_keys)
+ for key_id, _ in iteritems(device_keys)
)),
)
@@ -353,6 +337,31 @@ class E2eKeysHandler(object):
)
+def _exception_to_failure(e):
+ if isinstance(e, CodeMessageException):
+ return {
+ "status": e.code, "message": e.message,
+ }
+
+ if isinstance(e, NotRetryingDestination):
+ return {
+ "status": 503, "message": "Not ready for retry",
+ }
+
+ if isinstance(e, FederationDeniedError):
+ return {
+ "status": 403, "message": "Federation Denied",
+ }
+
+ # include ConnectionRefused and other errors
+ #
+ # Note that some Exceptions (notably twisted's ResponseFailed etc) don't
+ # give a string for e.message, which json then fails to serialize.
+ return {
+ "status": 503, "message": str(e.message),
+ }
+
+
def _one_time_keys_match(old_key_json, new_key):
old_key = json.loads(old_key_json)
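The two remote-query paths above now share one fan-out shape: each destination's query is kicked off with run_in_background, records its own failure via _exception_to_failure, and the gather uses consumeErrors=True so an unexpected failure on one Deferred is not left unobserved. A condensed sketch (do_query stands in for do_remote_query / claim_client_keys):

    from twisted.internet import defer

    from synapse.util.logcontext import make_deferred_yieldable, run_in_background

    @defer.inlineCallbacks
    def query_all_destinations(destinations, do_query):
        # do_query(destination) is expected to catch and record its own
        # errors; consumeErrors=True mops up anything it misses.
        yield make_deferred_yieldable(defer.gatherResults(
            [run_in_background(do_query, d) for d in destinations],
            consumeErrors=True,
        ))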
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index d3685fb12a..c3f2d7feff 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -13,20 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import random
+
from twisted.internet import defer
-from synapse.util.logutils import log_function
-from synapse.types import UserID
-from synapse.events.utils import serialize_event
-from synapse.api.constants import Membership, EventTypes
+from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
+from synapse.events.utils import serialize_event
+from synapse.types import UserID
+from synapse.util.logutils import log_function
from ._base import BaseHandler
-import logging
-import random
-
-
logger = logging.getLogger(__name__)
@@ -48,6 +47,7 @@ class EventStreamHandler(BaseHandler):
self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
+ self._server_notices_sender = hs.get_server_notices_sender()
@defer.inlineCallbacks
@log_function
@@ -58,6 +58,10 @@ class EventStreamHandler(BaseHandler):
If `only_keys` is not None, events from keys will be sent down.
"""
+
+ # send any outstanding server notices to the user.
+ yield self._server_notices_sender.on_user_syncing(auth_user_id)
+
auth_user = UserID.from_string(auth_user_id)
presence_handler = self.hs.get_presence_handler()
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 483cb8eac6..145c1a21d4 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,41 +15,46 @@
# limitations under the License.
"""Contains handlers for federation events."""
-import synapse.util.logcontext
+
+import itertools
+import logging
+import sys
+
+import six
+from six import iteritems, itervalues
+from six.moves import http_client, zip
+
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
-from ._base import BaseHandler
+from twisted.internet import defer
-from synapse.api.errors import (
- AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
-)
from synapse.api.constants import EventTypes, Membership, RejectedReason
-from synapse.events.validator import EventValidator
-from synapse.util import unwrapFirstError
-from synapse.util.logcontext import (
- preserve_fn, preserve_context_over_deferred
+from synapse.api.errors import (
+ AuthError,
+ CodeMessageException,
+ FederationDeniedError,
+ FederationError,
+ StoreError,
+ SynapseError,
)
-from synapse.util.metrics import measure_func
-from synapse.util.logutils import log_function
-from synapse.util.async import run_on_reactor, Linearizer
-from synapse.util.frozenutils import unfreeze
from synapse.crypto.event_signing import (
- compute_event_signature, add_hashes_and_signatures,
+ add_hashes_and_signatures,
+ compute_event_signature,
)
+from synapse.events.validator import EventValidator
+from synapse.state import resolve_events_with_factory
from synapse.types import UserID, get_domain_from_id
-
-from synapse.events.utils import prune_event
-
-from synapse.util.retryutils import NotRetryingDestination
-
+from synapse.util import logcontext, unwrapFirstError
+from synapse.util.async import Linearizer
from synapse.util.distributor import user_joined_room
+from synapse.util.frozenutils import unfreeze
+from synapse.util.logutils import log_function
+from synapse.util.retryutils import NotRetryingDestination
+from synapse.visibility import filter_events_for_server
-from twisted.internet import defer
-
-import itertools
-import logging
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -70,14 +76,16 @@ class FederationHandler(BaseHandler):
self.hs = hs
self.store = hs.get_datastore()
- self.replication_layer = hs.get_replication_layer()
+ self.replication_layer = hs.get_federation_client()
self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
self.action_generator = hs.get_action_generator()
self.is_mine_id = hs.is_mine_id
-
- self.replication_layer.set_handler(self)
+ self.pusher_pool = hs.get_pusherpool()
+ self.spam_checker = hs.get_spam_checker()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self._server_notices_mxid = hs.config.server_notices_mxid
# When joining a room we need to queue any events for that room up
self.room_queues = {}
@@ -85,7 +93,9 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
- def on_receive_pdu(self, origin, pdu, get_missing=True):
+ def on_receive_pdu(
+ self, origin, pdu, get_missing=True, sent_to_us_directly=False,
+ ):
""" Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
@@ -99,8 +109,10 @@ class FederationHandler(BaseHandler):
"""
# We reprocess pdus when we have seen them only as outliers
- existing = yield self.get_persisted_pdu(
- origin, pdu.event_id, do_auth=False
+ existing = yield self.store.get_event(
+ pdu.event_id,
+ allow_none=True,
+ allow_rejected=True,
)
# FIXME: Currently we fetch an event again when we already have it
@@ -116,6 +128,19 @@ class FederationHandler(BaseHandler):
logger.debug("Already seen pdu %s", pdu.event_id)
return
+ # do some initial sanity-checking of the event. In particular, make
+ # sure it doesn't have hundreds of prev_events or auth_events, which
+ # could cause a huge state resolution or cascade of event fetches.
+ try:
+ self._sanity_check_event(pdu)
+ except SynapseError as err:
+ raise FederationError(
+ "ERROR",
+ err.code,
+ err.msg,
+ affected=pdu.event_id,
+ )
+
# If we are currently in the process of joining this room, then we
# queue up events for later processing.
if pdu.room_id in self.room_queues:
@@ -124,15 +149,30 @@ class FederationHandler(BaseHandler):
self.room_queues[pdu.room_id].append((pdu, origin))
return
- state = None
-
- auth_chain = []
-
- have_seen = yield self.store.have_events(
- [ev for ev, _ in pdu.prev_events]
+ # If we're no longer in the room just ditch the event entirely. This
+ # is probably an old server that has come back and thinks we're still
+ # in the room (or we've been rejoined to the room by a state reset).
+ #
+ # If we were never in the room then maybe our database got vaped and
+ # we should check if we *are* in fact in the room. If we are then we
+ # can magically rejoin the room.
+ is_in_room = yield self.auth.check_host_in_room(
+ pdu.room_id,
+ self.server_name
)
+ if not is_in_room:
+ was_in_room = yield self.store.was_host_joined(
+ pdu.room_id, self.server_name,
+ )
+ if was_in_room:
+ logger.info(
+ "Ignoring PDU %s for room %s from %s as we've left the room!",
+ pdu.event_id, pdu.room_id, origin,
+ )
+ defer.returnValue(None)
- fetch_state = False
+ state = None
+ auth_chain = []
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
@@ -147,7 +187,7 @@ class FederationHandler(BaseHandler):
)
prevs = {e_id for e_id, _ in pdu.prev_events}
- seen = set(have_seen.keys())
+ seen = yield self.store.have_seen_events(prevs)
if min_depth and pdu.depth < min_depth:
# This is so that we don't notify the user about this
@@ -175,8 +215,7 @@ class FederationHandler(BaseHandler):
# Update the set of things we've seen after trying to
# fetch the missing stuff
- have_seen = yield self.store.have_events(prevs)
- seen = set(have_seen.iterkeys())
+ seen = yield self.store.have_seen_events(prevs)
if not prevs - seen:
logger.info(
@@ -189,26 +228,60 @@ class FederationHandler(BaseHandler):
list(prevs - seen)[:5],
)
- if prevs - seen:
- logger.info(
- "Still missing %d events for room %r: %r...",
- len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
+ if sent_to_us_directly and prevs - seen:
+ # If they have sent it to us directly, and the server
+ # isn't telling us about the auth events that it's
+ # made a message referencing, we explode
+ raise FederationError(
+ "ERROR",
+ 403,
+ (
+ "Your server isn't divulging details about prev_events "
+ "referenced in this event."
+ ),
+ affected=pdu.event_id,
)
- fetch_state = True
+ elif prevs - seen:
+ # Calculate the state of the previous events, and
+ # de-conflict them to find the current state.
+ state_groups = []
+ auth_chains = set()
+ try:
+ # Get the state of the events we know about
+ ours = yield self.store.get_state_groups(pdu.room_id, list(seen))
+ state_groups.append(ours)
+
+ # Ask the remote server for the states we don't
+ # know about
+ for p in prevs - seen:
+ state, got_auth_chain = (
+ yield self.replication_layer.get_state_for_room(
+ origin, pdu.room_id, p
+ )
+ )
+ auth_chains.update(got_auth_chain)
+ state_group = {(x.type, x.state_key): x.event_id for x in state}
+ state_groups.append(state_group)
+
+ # Resolve any conflicting state
+ def fetch(ev_ids):
+ return self.store.get_events(
+ ev_ids, get_prev_content=False, check_redacted=False
+ )
- if fetch_state:
- # We need to get the state at this event, since we haven't
- # processed all the prev events.
- logger.debug(
- "_handle_new_pdu getting state for %s",
- pdu.room_id
- )
- try:
- state, auth_chain = yield self.replication_layer.get_state_for_room(
- origin, pdu.room_id, pdu.event_id,
- )
- except:
- logger.exception("Failed to get state for event: %s", pdu.event_id)
+ state_map = yield resolve_events_with_factory(
+ state_groups, {pdu.event_id: pdu}, fetch
+ )
+
+ state = (yield self.store.get_events(state_map.values())).values()
+ auth_chain = list(auth_chains)
+ except Exception:
+ raise FederationError(
+ "ERROR",
+ 403,
+ "We can't get valid state history.",
+ affected=pdu.event_id,
+ )
yield self._process_received_pdu(
origin,
@@ -227,8 +300,7 @@ class FederationHandler(BaseHandler):
min_depth (int): Minimum depth of events to return.
"""
# We recalculate seen, since it may have changed.
- have_seen = yield self.store.have_events(prevs)
- seen = set(have_seen.keys())
+ seen = yield self.store.have_seen_events(prevs)
if not prevs - seen:
return
@@ -287,11 +359,17 @@ class FederationHandler(BaseHandler):
for e in missing_events:
logger.info("Handling found event %s", e.event_id)
- yield self.on_receive_pdu(
- origin,
- e,
- get_missing=False
- )
+ try:
+ yield self.on_receive_pdu(
+ origin,
+ e,
+ get_missing=False
+ )
+ except FederationError as e:
+ if e.code == 403:
+ logger.warn("Event %s failed history check.")
+ else:
+ raise
@log_function
@defer.inlineCallbacks
@@ -340,9 +418,7 @@ class FederationHandler(BaseHandler):
if auth_chain:
event_ids |= {e.event_id for e in auth_chain}
- seen_ids = set(
- (yield self.store.have_events(event_ids)).keys()
- )
+ seen_ids = yield self.store.have_seen_events(event_ids)
if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication
@@ -410,7 +486,10 @@ class FederationHandler(BaseHandler):
# joined the room. Don't bother if the user is just
# changing their profile info.
newly_joined = True
- prev_state_id = context.prev_state_ids.get(
+
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+
+ prev_state_id = prev_state_ids.get(
(event.type, event.state_key)
)
if prev_state_id:
@@ -424,91 +503,21 @@ class FederationHandler(BaseHandler):
user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id)
- @measure_func("_filter_events_for_server")
- @defer.inlineCallbacks
- def _filter_events_for_server(self, server_name, room_id, events):
- event_to_state_ids = yield self.store.get_state_ids_for_events(
- frozenset(e.event_id for e in events),
- types=(
- (EventTypes.RoomHistoryVisibility, ""),
- (EventTypes.Member, None),
- )
- )
-
- # We only want to pull out member events that correspond to the
- # server's domain.
-
- def check_match(id):
- try:
- return server_name == get_domain_from_id(id)
- except:
- return False
-
- # Parses mapping `event_id -> (type, state_key) -> state event_id`
- # to get all state ids that we're interested in.
- event_map = yield self.store.get_events([
- e_id
- for key_to_eid in event_to_state_ids.values()
- for key, e_id in key_to_eid.items()
- if key[0] != EventTypes.Member or check_match(key[1])
- ])
-
- event_to_state = {
- e_id: {
- key: event_map[inner_e_id]
- for key, inner_e_id in key_to_eid.items()
- if inner_e_id in event_map
- }
- for e_id, key_to_eid in event_to_state_ids.items()
- }
-
- def redact_disallowed(event, state):
- if not state:
- return event
-
- history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
- if history:
- visibility = history.content.get("history_visibility", "shared")
- if visibility in ["invited", "joined"]:
- # We now loop through all state events looking for
- # membership states for the requesting server to determine
- # if the server is either in the room or has been invited
- # into the room.
- for ev in state.values():
- if ev.type != EventTypes.Member:
- continue
- try:
- domain = get_domain_from_id(ev.state_key)
- except:
- continue
-
- if domain != server_name:
- continue
-
- memtype = ev.membership
- if memtype == Membership.JOIN:
- return event
- elif memtype == Membership.INVITE:
- if visibility == "invited":
- return event
- else:
- return prune_event(event)
-
- return event
-
- defer.returnValue([
- redact_disallowed(e, event_to_state[e.event_id])
- for e in events
- ])
-
@log_function
@defer.inlineCallbacks
def backfill(self, dest, room_id, limit, extremities):
""" Trigger a backfill request to `dest` for the given `room_id`
- This will attempt to get more events from the remote. This may return
- be successfull and still return no events if the other side has no new
- events to offer.
+ This will attempt to get more events from the remote. If the other side
+ has no new events to offer, this will return an empty list.
+
+ As the events are received, we check their signatures, and also do some
+ sanity-checking on them. If any of the backfilled events are invalid,
+ this method throws a SynapseError.
+
+ TODO: make this more useful to distinguish failures of the remote
+ server from invalid events (there is probably no point in trying to
+ re-fetch invalid events from every other HS in the room.)
"""
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
@@ -520,6 +529,16 @@ class FederationHandler(BaseHandler):
extremities=extremities,
)
+ # ideally we'd sanity check the events here for excess prev_events etc,
+ # but it's hard to reject events at this point without completely
+ # breaking backfill in the same way that it is currently broken by
+ # events whose signature we cannot verify (#3121).
+ #
+ # So for now we accept the events anyway. #3124 tracks this.
+ #
+ # for ev in events:
+ # self._sanity_check_event(ev)
+
# Don't bother processing events we already have.
seen_events = yield self.store.have_events_in_timeline(
set(e.event_id for e in events)
@@ -590,9 +609,10 @@ class FederationHandler(BaseHandler):
missing_auth - failed_to_fetch
)
- results = yield preserve_context_over_deferred(defer.gatherResults(
+ results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.replication_layer.get_pdu)(
+ logcontext.run_in_background(
+ self.replication_layer.get_pdu,
[dest],
event_id,
outlier=True,
@@ -612,7 +632,7 @@ class FederationHandler(BaseHandler):
failed_to_fetch = missing_auth - set(auth_events)
- seen_events = yield self.store.have_events(
+ seen_events = yield self.store.have_seen_events(
set(auth_events.keys()) | set(state_events.keys())
)
@@ -702,9 +722,19 @@ class FederationHandler(BaseHandler):
curr_state = yield self.state_handler.get_current_state(room_id)
def get_domains_from_state(state):
+ """Get joined domains from state
+
+ Args:
+ state (dict[tuple, FrozenEvent]): State map from type/state
+ key to event.
+
+ Returns:
+                list[tuple[str, int]]: a list of (domain, depth) tuples,
+                    sorted by lowest join depth first.
+ """
joined_users = [
(state_key, int(event.depth))
- for (e_type, state_key), event in state.items()
+ for (e_type, state_key), event in iteritems(state)
if e_type == EventTypes.Member
and event.membership == Membership.JOIN
]
@@ -718,7 +748,7 @@ class FederationHandler(BaseHandler):
joined_domains[dom] = min(d, old_d)
else:
joined_domains[dom] = d
- except:
+ except Exception:
pass
return sorted(joined_domains.items(), key=lambda d: d[1])
@@ -738,7 +768,7 @@ class FederationHandler(BaseHandler):
yield self.backfill(
dom, room_id,
limit=100,
- extremities=[e for e in extremities.keys()]
+ extremities=extremities,
)
# If this succeeded then we probably already have the
# appropriate stuff.
@@ -762,6 +792,9 @@ class FederationHandler(BaseHandler):
except NotRetryingDestination as e:
logger.info(e.message)
continue
+ except FederationDeniedError as e:
+ logger.info(e)
+ continue
except Exception as e:
logger.exception(
"Failed to backfill from %s because %s",
@@ -784,38 +817,76 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys())
logger.debug("calling resolve_state_groups in _maybe_backfill")
- states = yield preserve_context_over_deferred(defer.gatherResults([
- preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
- for e in event_ids
- ]))
+ resolve = logcontext.preserve_fn(
+ self.state_handler.resolve_state_groups_for_events
+ )
+ states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
+ [resolve(room_id, [e]) for e in event_ids],
+ consumeErrors=True,
+ ))
+
+ # dict[str, dict[tuple, str]], a map from event_id to state map of
+ # event_ids.
states = dict(zip(event_ids, [s.state for s in states]))
state_map = yield self.store.get_events(
- [e_id for ids in states.values() for e_id in ids],
+ [e_id for ids in itervalues(states) for e_id in itervalues(ids)],
get_prev_content=False
)
states = {
key: {
k: state_map[e_id]
- for k, e_id in state_dict.items()
+ for k, e_id in iteritems(state_dict)
if e_id in state_map
- } for key, state_dict in states.items()
+ } for key, state_dict in iteritems(states)
}
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
success = yield try_backfill([
- dom for dom in likely_domains
+ dom for dom, _ in likely_domains
if dom not in tried_domains
])
if success:
defer.returnValue(True)
- tried_domains.update(likely_domains)
+ tried_domains.update(dom for dom, _ in likely_domains)
defer.returnValue(False)
+ def _sanity_check_event(self, ev):
+ """
+ Do some early sanity checks of a received event
+
+ In particular, checks it doesn't have an excessive number of
+ prev_events or auth_events, which could cause a huge state resolution
+ or cascade of event fetches.
+
+ Args:
+ ev (synapse.events.EventBase): event to be checked
+
+ Returns: None
+
+ Raises:
+ SynapseError if the event does not pass muster
+ """
+ if len(ev.prev_events) > 20:
+ logger.warn("Rejecting event %s which has %i prev_events",
+ ev.event_id, len(ev.prev_events))
+ raise SynapseError(
+ http_client.BAD_REQUEST,
+ "Too many prev_events",
+ )
+
+ if len(ev.auth_events) > 10:
+ logger.warn("Rejecting event %s which has %i auth_events",
+ ev.event_id, len(ev.auth_events))
+ raise SynapseError(
+ http_client.BAD_REQUEST,
+ "Too many auth_events",
+ )
+
@defer.inlineCallbacks
def send_invite(self, target_host, event):
""" Sends the invite to the remote server for signing.
@@ -838,16 +909,6 @@ class FederationHandler(BaseHandler):
[auth_id for auth_id, _ in event.auth_events],
include_given=True
)
-
- for event in auth:
- event.signatures.update(
- compute_event_signature(
- event,
- self.hs.hostname,
- self.hs.config.signing_key[0]
- )
- )
-
defer.returnValue([e for e in auth])
@log_function
@@ -916,7 +977,7 @@ class FederationHandler(BaseHandler):
room_creator_user_id="",
is_public=False
)
- except:
+ except Exception:
# FIXME
pass
@@ -940,9 +1001,7 @@ class FederationHandler(BaseHandler):
# lots of requests for missing prev_events which we do actually
# have. Hence we fire off the deferred, but don't wait for it.
- synapse.util.logcontext.preserve_fn(self._handle_queued_pdus)(
- room_queue
- )
+ logcontext.run_in_background(self._handle_queued_pdus, room_queue)
defer.returnValue(True)
@@ -982,8 +1041,7 @@ class FederationHandler(BaseHandler):
})
try:
- message_handler = self.hs.get_handlers().message_handler
- event, context = yield message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder,
)
except AuthError as e:
@@ -1051,13 +1109,15 @@ class FederationHandler(BaseHandler):
user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id)
- state_ids = context.prev_state_ids.values()
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+
+ state_ids = list(prev_state_ids.values())
auth_chain = yield self.store.get_auth_chain(state_ids)
- state = yield self.store.get_events(context.prev_state_ids.values())
+ state = yield self.store.get_events(list(prev_state_ids.values()))
defer.returnValue({
- "state": state.values(),
+ "state": list(state.values()),
"auth_chain": auth_chain,
})
@@ -1069,10 +1129,23 @@ class FederationHandler(BaseHandler):
"""
event = pdu
+ if event.state_key is None:
+ raise SynapseError(400, "The invite event did not have a state key")
+
is_blocked = yield self.store.is_room_blocked(event.room_id)
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
+ if self.hs.config.block_non_admin_invites:
+ raise SynapseError(403, "This server does not accept room invites")
+
+ if not self.spam_checker.user_may_invite(
+ event.sender, event.state_key, event.room_id,
+ ):
+ raise SynapseError(
+ 403, "This user is not permitted to send invites to this server/user"
+ )
+
membership = event.content.get("membership")
if event.type != EventTypes.Member or membership != Membership.INVITE:
raise SynapseError(400, "The event was not an m.room.member invite event")
@@ -1081,12 +1154,16 @@ class FederationHandler(BaseHandler):
if sender_domain != origin:
raise SynapseError(400, "The invite event was not from the server sending it")
- if event.state_key is None:
- raise SynapseError(400, "The invite event did not have a state key")
-
if not self.is_mine_id(event.state_key):
raise SynapseError(400, "The invite event must be for this server")
+ # block any attempts to invite the server notices mxid
+ if event.state_key == self._server_notices_mxid:
+ raise SynapseError(
+ http_client.FORBIDDEN,
+ "Cannot invite this user",
+ )
+
event.internal_metadata.outlier = True
event.internal_metadata.invite_from_remote = True
@@ -1213,8 +1290,7 @@ class FederationHandler(BaseHandler):
"state_key": user_id,
})
- message_handler = self.hs.get_handlers().message_handler
- event, context = yield message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder,
)
@@ -1268,14 +1344,12 @@ class FederationHandler(BaseHandler):
def get_state_for_pdu(self, room_id, event_id):
"""Returns the state at the event. i.e. not including said event.
"""
- yield run_on_reactor()
-
state_groups = yield self.store.get_state_groups(
room_id, [event_id]
)
if state_groups:
- _, state = state_groups.items().pop()
+ _, state = list(iteritems(state_groups)).pop()
results = {
(e.type, e.state_key): e for e in state
}
@@ -1291,19 +1365,7 @@ class FederationHandler(BaseHandler):
else:
del results[(event.type, event.state_key)]
- res = results.values()
- for event in res:
- # We sign these again because there was a bug where we
- # incorrectly signed things the first time round
- if self.is_mine_id(event.event_id):
- event.signatures.update(
- compute_event_signature(
- event,
- self.hs.hostname,
- self.hs.config.signing_key[0]
- )
- )
-
+ res = list(results.values())
defer.returnValue(res)
else:
defer.returnValue([])
@@ -1312,8 +1374,6 @@ class FederationHandler(BaseHandler):
def get_state_ids_for_pdu(self, room_id, event_id):
"""Returns the state at the event. i.e. not including said event.
"""
- yield run_on_reactor()
-
state_groups = yield self.store.get_state_groups_ids(
room_id, [event_id]
)
@@ -1332,7 +1392,7 @@ class FederationHandler(BaseHandler):
else:
results.pop((event.type, event.state_key), None)
- defer.returnValue(results.values())
+ defer.returnValue(list(results.values()))
else:
defer.returnValue([])
@@ -1349,17 +1409,26 @@ class FederationHandler(BaseHandler):
limit
)
- events = yield self._filter_events_for_server(origin, room_id, events)
+ events = yield filter_events_for_server(self.store, origin, events)
defer.returnValue(events)
@defer.inlineCallbacks
@log_function
- def get_persisted_pdu(self, origin, event_id, do_auth=True):
- """ Get a PDU from the database with given origin and id.
+ def get_persisted_pdu(self, origin, event_id):
+ """Get an event from the database for the given server.
+
+ Args:
+            origin (str): hostname of server which is requesting the event; we
+                will check that the server is allowed to see it.
+            event_id (str): id of the event being requested
Returns:
- Deferred: Results in a `Pdu`.
+ Deferred[EventBase|None]: None if we know nothing about the event;
+ otherwise the (possibly-redacted) event.
+
+ Raises:
+ AuthError if the server is not currently in the room
"""
event = yield self.store.get_event(
event_id,
@@ -1368,32 +1437,17 @@ class FederationHandler(BaseHandler):
)
if event:
- if self.is_mine_id(event.event_id):
- # FIXME: This is a temporary work around where we occasionally
- # return events slightly differently than when they were
- # originally signed
- event.signatures.update(
- compute_event_signature(
- event,
- self.hs.hostname,
- self.hs.config.signing_key[0]
- )
- )
-
- if do_auth:
- in_room = yield self.auth.check_host_in_room(
- event.room_id,
- origin
- )
- if not in_room:
- raise AuthError(403, "Host not in room.")
-
- events = yield self._filter_events_for_server(
- origin, event.room_id, [event]
- )
-
- event = events[0]
+ in_room = yield self.auth.check_host_in_room(
+ event.room_id,
+ origin
+ )
+ if not in_room:
+ raise AuthError(403, "Host not in room.")
+ events = yield filter_events_for_server(
+ self.store, origin, [event],
+ )
+ event = events[0]
defer.returnValue(event)
else:
defer.returnValue(None)
@@ -1412,22 +1466,33 @@ class FederationHandler(BaseHandler):
auth_events=auth_events,
)
- if not event.internal_metadata.is_outlier():
- yield self.action_generator.handle_push_actions_for_event(
- event, context
+ try:
+ if not event.internal_metadata.is_outlier() and not backfilled:
+ yield self.action_generator.handle_push_actions_for_event(
+ event, context
+ )
+
+ event_stream_id, max_stream_id = yield self.store.persist_event(
+ event,
+ context=context,
+ backfilled=backfilled,
)
+ except: # noqa: E722, as we reraise the exception this is fine.
+ tp, value, tb = sys.exc_info()
- event_stream_id, max_stream_id = yield self.store.persist_event(
- event,
- context=context,
- backfilled=backfilled,
- )
+ logcontext.run_in_background(
+ self.store.remove_push_actions_from_staging,
+ event.event_id,
+ )
+
+ six.reraise(tp, value, tb)
if not backfilled:
# this intentionally does not yield: we don't care about the result
# and don't need to wait for it.
- preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
- event_stream_id, max_stream_id
+ logcontext.run_in_background(
+ self.pusher_pool.on_new_notifications,
+ event_stream_id, max_stream_id,
)
defer.returnValue((context, event_stream_id, max_stream_id))
@@ -1439,22 +1504,23 @@ class FederationHandler(BaseHandler):
a bunch of outliers, but not a chunk of individual events that depend
on each other for state calculations.
"""
- contexts = yield preserve_context_over_deferred(defer.gatherResults(
+ contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self._prep_event)(
+ logcontext.run_in_background(
+ self._prep_event,
origin,
ev_info["event"],
state=ev_info.get("state"),
auth_events=ev_info.get("auth_events"),
)
for ev_info in event_infos
- ]
+ ], consumeErrors=True,
))
yield self.store.persist_events(
[
(ev_info["event"], context)
- for ev_info, context in itertools.izip(event_infos, contexts)
+ for ev_info, context in zip(event_infos, contexts)
],
backfilled=backfilled,
)
@@ -1574,8 +1640,9 @@ class FederationHandler(BaseHandler):
)
if not auth_events:
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.auth.compute_auth_events(
- event, context.prev_state_ids, for_verification=True,
+ event, prev_state_ids, for_verification=True,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
@@ -1605,7 +1672,7 @@ class FederationHandler(BaseHandler):
context.rejected = RejectedReason.AUTH_ERROR
- if event.type == EventTypes.GuestAccess:
+ if event.type == EventTypes.GuestAccess and not context.rejected:
yield self.maybe_kick_guest_users(event)
defer.returnValue(context)
@@ -1635,15 +1702,6 @@ class FederationHandler(BaseHandler):
local_auth_chain, remote_auth_chain
)
- for event in ret["auth_chain"]:
- event.signatures.update(
- compute_event_signature(
- event,
- self.hs.hostname,
- self.hs.config.signing_key[0]
- )
- )
-
logger.debug("on_query_auth returning: %s", ret)
defer.returnValue(ret)
@@ -1669,11 +1727,26 @@ class FederationHandler(BaseHandler):
min_depth=min_depth,
)
+ missing_events = yield filter_events_for_server(
+ self.store, origin, missing_events,
+ )
+
defer.returnValue(missing_events)
@defer.inlineCallbacks
@log_function
def do_auth(self, origin, event, context, auth_events):
+ """
+
+ Args:
+ origin (str):
+ event (synapse.events.FrozenEvent):
+ context (synapse.events.snapshot.EventContext):
+ auth_events (dict[(str, str)->str]):
+
+ Returns:
+ defer.Deferred[None]
+ """
# Check if we have all the auth events.
current_state = set(e.event_id for e in auth_events.values())
event_auth_events = set(e_id for e_id, _ in event.auth_events)
@@ -1684,7 +1757,8 @@ class FederationHandler(BaseHandler):
event_key = None
if event_auth_events - current_state:
- have_events = yield self.store.have_events(
+ # TODO: can we use store.have_seen_events here instead?
+ have_events = yield self.store.get_seen_events_with_rejections(
event_auth_events - current_state
)
else:
@@ -1707,12 +1781,12 @@ class FederationHandler(BaseHandler):
origin, event.room_id, event.event_id
)
- seen_remotes = yield self.store.have_events(
+ seen_remotes = yield self.store.have_seen_events(
[e.event_id for e in remote_auth_chain]
)
for e in remote_auth_chain:
- if e.event_id in seen_remotes.keys():
+ if e.event_id in seen_remotes:
continue
if e.event_id == event.event_id:
@@ -1739,11 +1813,11 @@ class FederationHandler(BaseHandler):
except AuthError:
pass
- have_events = yield self.store.have_events(
+ have_events = yield self.store.get_seen_events_with_rejections(
[e_id for e_id, _ in event.auth_events]
)
seen_events = set(have_events.keys())
- except:
+ except Exception:
# FIXME:
logger.exception("Failed to get auth chain")
@@ -1756,18 +1830,18 @@ class FederationHandler(BaseHandler):
# Do auth conflict res.
logger.info("Different auth: %s", different_auth)
- different_events = yield preserve_context_over_deferred(defer.gatherResults(
- [
- preserve_fn(self.store.get_event)(
+ different_events = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults([
+ logcontext.run_in_background(
+ self.store.get_event,
d,
allow_none=True,
allow_rejected=False,
)
for d in different_auth
if d in have_events and not have_events[d]
- ],
- consumeErrors=True
- )).addErrback(unwrapFirstError)
+ ], consumeErrors=True)
+ ).addErrback(unwrapFirstError)
if different_events:
local_view = dict(auth_events)
@@ -1777,7 +1851,7 @@ class FederationHandler(BaseHandler):
})
new_state = self.state_handler.resolve_events(
- [local_view.values(), remote_view.values()],
+ [list(local_view.values()), list(remote_view.values())],
event
)
@@ -1786,16 +1860,9 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state
- context.current_state_ids = dict(context.current_state_ids)
- context.current_state_ids.update({
- k: a.event_id for k, a in auth_events.items()
- if k != event_key
- })
- context.prev_state_ids = dict(context.prev_state_ids)
- context.prev_state_ids.update({
- k: a.event_id for k, a in auth_events.items()
- })
- context.state_group = self.store.get_next_state_group()
+ yield self._update_context_for_auth_events(
+ event, context, auth_events, event_key,
+ )
if different_auth and not event.internal_metadata.is_outlier():
logger.info("Different auth after resolution: %s", different_auth)
@@ -1815,9 +1882,10 @@ class FederationHandler(BaseHandler):
break
if do_resolution:
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
# 1. Get what we think is the auth chain.
auth_ids = yield self.auth.compute_auth_events(
- event, context.prev_state_ids
+ event, prev_state_ids
)
local_auth_chain = yield self.store.get_auth_chain(
auth_ids, include_given=True
@@ -1832,13 +1900,13 @@ class FederationHandler(BaseHandler):
local_auth_chain,
)
- seen_remotes = yield self.store.have_events(
+ seen_remotes = yield self.store.have_seen_events(
[e.event_id for e in result["auth_chain"]]
)
# 3. Process any remote auth chain events we haven't seen.
for ev in result["auth_chain"]:
- if ev.event_id in seen_remotes.keys():
+ if ev.event_id in seen_remotes:
continue
if ev.event_id == event.event_id:
@@ -1868,23 +1936,16 @@ class FederationHandler(BaseHandler):
except AuthError:
pass
- except:
+ except Exception:
# FIXME:
logger.exception("Failed to query auth chain")
# 4. Look at rejects and their proofs.
# TODO.
- context.current_state_ids = dict(context.current_state_ids)
- context.current_state_ids.update({
- k: a.event_id for k, a in auth_events.items()
- if k != event_key
- })
- context.prev_state_ids = dict(context.prev_state_ids)
- context.prev_state_ids.update({
- k: a.event_id for k, a in auth_events.items()
- })
- context.state_group = self.store.get_next_state_group()
+ yield self._update_context_for_auth_events(
+ event, context, auth_events, event_key,
+ )
try:
self.auth.check(event, auth_events=auth_events)
@@ -1893,6 +1954,58 @@ class FederationHandler(BaseHandler):
raise e
@defer.inlineCallbacks
+ def _update_context_for_auth_events(self, event, context, auth_events,
+ event_key):
+ """Update the state_ids in an event context after auth event resolution,
+ storing the changes as a new state group.
+
+ Args:
+ event (Event): The event we're handling the context for
+
+ context (synapse.events.snapshot.EventContext): event context
+ to be updated
+
+            auth_events (dict[(str, str)->synapse.events.FrozenEvent]):
+                Events to update in the event context.
+
+            event_key ((str, str)): (type, state_key) for the current event.
+                This will not be included in the current_state in the context.
+ """
+ state_updates = {
+ k: a.event_id for k, a in iteritems(auth_events)
+ if k != event_key
+ }
+ current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = dict(current_state_ids)
+
+ current_state_ids.update(state_updates)
+
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = dict(prev_state_ids)
+
+ prev_state_ids.update({
+ k: a.event_id for k, a in iteritems(auth_events)
+ })
+
+ # create a new state group as a delta from the existing one.
+ prev_group = context.state_group
+ state_group = yield self.store.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=prev_group,
+ delta_ids=state_updates,
+ current_state_ids=current_state_ids,
+ )
+
+ yield context.update_state(
+ state_group=state_group,
+ current_state_ids=current_state_ids,
+ prev_state_ids=prev_state_ids,
+ prev_group=prev_group,
+ delta_ids=state_updates,
+ )
+
+ @defer.inlineCallbacks
def construct_auth_difference(self, local_auth, remote_auth):
""" Given a local and remote auth chain, find the differences. This
assumes that we have already processed all events in remote_auth
@@ -1934,8 +2047,8 @@ class FederationHandler(BaseHandler):
def get_next(it, opt=None):
try:
- return it.next()
- except:
+ return next(it)
+ except Exception:
return opt
current_local = get_next(local_iter)
@@ -2060,8 +2173,7 @@ class FederationHandler(BaseHandler):
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder)
- message_handler = self.hs.get_handlers().message_handler
- event, context = yield message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder
)
@@ -2076,7 +2188,7 @@ class FederationHandler(BaseHandler):
raise e
yield self._check_signature(event, context)
- member_handler = self.hs.get_handlers().room_member_handler
+ member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context)
else:
destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
@@ -2089,10 +2201,17 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
+ """Handle an exchange_third_party_invite request from a remote server
+
+ The remote server will call this when it wants to turn a 3pid invite
+ into a normal m.room.member invite.
+
+ Returns:
+ Deferred: resolves (to None)
+ """
builder = self.event_builder_factory.new(event_dict)
- message_handler = self.hs.get_handlers().message_handler
- event, context = yield message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder,
)
@@ -2107,10 +2226,13 @@ class FederationHandler(BaseHandler):
raise e
yield self._check_signature(event, context)
+ # XXX we send the invite here, but send_membership_event also sends it,
+ # so we end up making two requests. I think this is redundant.
returned_invite = yield self.send_invite(origin, event)
# TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures)
- member_handler = self.hs.get_handlers().room_member_handler
+
+ member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context)
@defer.inlineCallbacks
@@ -2120,7 +2242,8 @@ class FederationHandler(BaseHandler):
event.content["third_party_invite"]["signed"]["token"]
)
original_invite = None
- original_invite_id = context.prev_state_ids.get(key)
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+ original_invite_id = prev_state_ids.get(key)
if original_invite_id:
original_invite = yield self.store.get_event(
original_invite_id, allow_none=True
@@ -2139,8 +2262,9 @@ class FederationHandler(BaseHandler):
builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder)
- message_handler = self.hs.get_handlers().message_handler
- event, context = yield message_handler._create_new_client_event(builder=builder)
+ event, context = yield self.event_creation_handler.create_new_client_event(
+ builder=builder,
+ )
defer.returnValue((event, context))
@defer.inlineCallbacks
@@ -2161,7 +2285,8 @@ class FederationHandler(BaseHandler):
signed = event.content["third_party_invite"]["signed"]
token = signed["token"]
- invite_event_id = context.prev_state_ids.get(
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+ invite_event_id = prev_state_ids.get(
(EventTypes.ThirdPartyInvite, token,)
)
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
new file mode 100644
index 0000000000..53e5e2648b
--- /dev/null
+++ b/synapse/handlers/groups_local.py
@@ -0,0 +1,473 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from six import iteritems
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.types import get_domain_from_id
+
+logger = logging.getLogger(__name__)
+
+
+def _create_rerouter(func_name):
+ """Returns a function that looks at the group id and calls the function
+ on federation or the local group server if the group is local
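+
+    Args:
+        func_name (str): name of the method to call; it must exist on both
+            the local groups server handler and the federation transport
+            client.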
+ """
+ def f(self, group_id, *args, **kwargs):
+ if self.is_mine_id(group_id):
+ return getattr(self.groups_server_handler, func_name)(
+ group_id, *args, **kwargs
+ )
+ else:
+ destination = get_domain_from_id(group_id)
+ return getattr(self.transport_client, func_name)(
+ destination, group_id, *args, **kwargs
+ )
+ return f
+
+
+class GroupsLocalHandler(object):
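+    """Handles group requests from local users, routing each request to
+    the local group server or over federation, depending on where the
+    group lives.
+    """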
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.room_list_handler = hs.get_room_list_handler()
+ self.groups_server_handler = hs.get_groups_server_handler()
+ self.transport_client = hs.get_federation_transport_client()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.keyring = hs.get_keyring()
+ self.is_mine_id = hs.is_mine_id
+ self.signing_key = hs.config.signing_key[0]
+ self.server_name = hs.hostname
+ self.notifier = hs.get_notifier()
+ self.attestations = hs.get_groups_attestation_signing()
+
+ self.profile_handler = hs.get_profile_handler()
+
+ # Ensure attestations get renewed
+ hs.get_groups_attestation_renewer()
+
+ # The following functions merely route the query to the local groups server
+ # or federation depending on if the group is local or remote
+
+ get_group_profile = _create_rerouter("get_group_profile")
+ update_group_profile = _create_rerouter("update_group_profile")
+ get_rooms_in_group = _create_rerouter("get_rooms_in_group")
+
+ get_invited_users_in_group = _create_rerouter("get_invited_users_in_group")
+
+ add_room_to_group = _create_rerouter("add_room_to_group")
+ update_room_in_group = _create_rerouter("update_room_in_group")
+ remove_room_from_group = _create_rerouter("remove_room_from_group")
+
+ update_group_summary_room = _create_rerouter("update_group_summary_room")
+ delete_group_summary_room = _create_rerouter("delete_group_summary_room")
+
+ update_group_category = _create_rerouter("update_group_category")
+ delete_group_category = _create_rerouter("delete_group_category")
+ get_group_category = _create_rerouter("get_group_category")
+ get_group_categories = _create_rerouter("get_group_categories")
+
+ update_group_summary_user = _create_rerouter("update_group_summary_user")
+ delete_group_summary_user = _create_rerouter("delete_group_summary_user")
+
+ update_group_role = _create_rerouter("update_group_role")
+ delete_group_role = _create_rerouter("delete_group_role")
+ get_group_role = _create_rerouter("get_group_role")
+ get_group_roles = _create_rerouter("get_group_roles")
+
+ set_group_join_policy = _create_rerouter("set_group_join_policy")
+
+ @defer.inlineCallbacks
+ def get_group_summary(self, group_id, requester_user_id):
+ """Get the group summary for a group.
+
+ If the group is remote we check that the users have valid attestations.
+ """
+ if self.is_mine_id(group_id):
+ res = yield self.groups_server_handler.get_group_summary(
+ group_id, requester_user_id
+ )
+ else:
+ res = yield self.transport_client.get_group_summary(
+ get_domain_from_id(group_id), group_id, requester_user_id,
+ )
+
+ group_server_name = get_domain_from_id(group_id)
+
+ # Loop through the users and validate the attestations.
+ chunk = res["users_section"]["users"]
+ valid_users = []
+ for entry in chunk:
+ g_user_id = entry["user_id"]
+ attestation = entry.pop("attestation", {})
+ try:
+ if get_domain_from_id(g_user_id) != group_server_name:
+ yield self.attestations.verify_attestation(
+ attestation,
+ group_id=group_id,
+ user_id=g_user_id,
+ server_name=get_domain_from_id(g_user_id),
+ )
+ valid_users.append(entry)
+ except Exception as e:
+ logger.info("Failed to verify user is in group: %s", e)
+
+ res["users_section"]["users"] = valid_users
+
+ res["users_section"]["users"].sort(key=lambda e: e.get("order", 0))
+ res["rooms_section"]["rooms"].sort(key=lambda e: e.get("order", 0))
+
+ # Add `is_publicised` flag to indicate whether the user has publicised their
+ # membership of the group on their profile
+ result = yield self.store.get_publicised_groups_for_user(requester_user_id)
+ is_publicised = group_id in result
+
+ res.setdefault("user", {})["is_publicised"] = is_publicised
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def create_group(self, group_id, user_id, content):
+        """Create a group, either on the local group server or on a
+        remote group server via federation, and record the creating
+        user as a joined admin of the group.
+        """
+
+ logger.info("Asking to create group with ID: %r", group_id)
+
+ if self.is_mine_id(group_id):
+ res = yield self.groups_server_handler.create_group(
+ group_id, user_id, content
+ )
+ local_attestation = None
+ remote_attestation = None
+ else:
+ local_attestation = self.attestations.create_attestation(group_id, user_id)
+ content["attestation"] = local_attestation
+
+ content["user_profile"] = yield self.profile_handler.get_profile(user_id)
+
+ res = yield self.transport_client.create_group(
+ get_domain_from_id(group_id), group_id, user_id, content,
+ )
+
+ remote_attestation = res["attestation"]
+ yield self.attestations.verify_attestation(
+ remote_attestation,
+ group_id=group_id,
+ user_id=user_id,
+ server_name=get_domain_from_id(group_id),
+ )
+
+ is_publicised = content.get("publicise", False)
+ token = yield self.store.register_user_group_membership(
+ group_id, user_id,
+ membership="join",
+ is_admin=True,
+ local_attestation=local_attestation,
+ remote_attestation=remote_attestation,
+ is_publicised=is_publicised,
+ )
+ self.notifier.on_new_event(
+ "groups_key", token, users=[user_id],
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def get_users_in_group(self, group_id, requester_user_id):
+ """Get users in a group
+ """
+ if self.is_mine_id(group_id):
+ res = yield self.groups_server_handler.get_users_in_group(
+ group_id, requester_user_id
+ )
+ defer.returnValue(res)
+
+ group_server_name = get_domain_from_id(group_id)
+
+ res = yield self.transport_client.get_users_in_group(
+ get_domain_from_id(group_id), group_id, requester_user_id,
+ )
+
+ chunk = res["chunk"]
+ valid_entries = []
+ for entry in chunk:
+ g_user_id = entry["user_id"]
+ attestation = entry.pop("attestation", {})
+ try:
+ if get_domain_from_id(g_user_id) != group_server_name:
+ yield self.attestations.verify_attestation(
+ attestation,
+ group_id=group_id,
+ user_id=g_user_id,
+ server_name=get_domain_from_id(g_user_id),
+ )
+ valid_entries.append(entry)
+ except Exception as e:
+ logger.info("Failed to verify user is in group: %s", e)
+
+ res["chunk"] = valid_entries
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def join_group(self, group_id, user_id, content):
+ """Request to join a group
+ """
+ if self.is_mine_id(group_id):
+ yield self.groups_server_handler.join_group(
+ group_id, user_id, content
+ )
+ local_attestation = None
+ remote_attestation = None
+ else:
+ local_attestation = self.attestations.create_attestation(group_id, user_id)
+ content["attestation"] = local_attestation
+
+ res = yield self.transport_client.join_group(
+ get_domain_from_id(group_id), group_id, user_id, content,
+ )
+
+ remote_attestation = res["attestation"]
+
+ yield self.attestations.verify_attestation(
+ remote_attestation,
+ group_id=group_id,
+ user_id=user_id,
+ server_name=get_domain_from_id(group_id),
+ )
+
+        # TODO: Check that the group is public and we're being added publicly
+ is_publicised = content.get("publicise", False)
+
+ token = yield self.store.register_user_group_membership(
+ group_id, user_id,
+ membership="join",
+ is_admin=False,
+ local_attestation=local_attestation,
+ remote_attestation=remote_attestation,
+ is_publicised=is_publicised,
+ )
+ self.notifier.on_new_event(
+ "groups_key", token, users=[user_id],
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def accept_invite(self, group_id, user_id, content):
+ """Accept an invite to a group
+ """
+ if self.is_mine_id(group_id):
+ yield self.groups_server_handler.accept_invite(
+ group_id, user_id, content
+ )
+ local_attestation = None
+ remote_attestation = None
+ else:
+ local_attestation = self.attestations.create_attestation(group_id, user_id)
+ content["attestation"] = local_attestation
+
+ res = yield self.transport_client.accept_group_invite(
+ get_domain_from_id(group_id), group_id, user_id, content,
+ )
+
+ remote_attestation = res["attestation"]
+
+ yield self.attestations.verify_attestation(
+ remote_attestation,
+ group_id=group_id,
+ user_id=user_id,
+ server_name=get_domain_from_id(group_id),
+ )
+
+        # TODO: Check that the group is public and we're being added publicly
+ is_publicised = content.get("publicise", False)
+
+ token = yield self.store.register_user_group_membership(
+ group_id, user_id,
+ membership="join",
+ is_admin=False,
+ local_attestation=local_attestation,
+ remote_attestation=remote_attestation,
+ is_publicised=is_publicised,
+ )
+ self.notifier.on_new_event(
+ "groups_key", token, users=[user_id],
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def invite(self, group_id, user_id, requester_user_id, config):
+ """Invite a user to a group
+ """
+ content = {
+ "requester_user_id": requester_user_id,
+ "config": config,
+ }
+ if self.is_mine_id(group_id):
+ res = yield self.groups_server_handler.invite_to_group(
+ group_id, user_id, requester_user_id, content,
+ )
+ else:
+ res = yield self.transport_client.invite_to_group(
+ get_domain_from_id(group_id), group_id, user_id, requester_user_id,
+ content,
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def on_invite(self, group_id, user_id, content):
+        """One of our users was invited to a group
+ """
+ # TODO: Support auto join and rejection
+
+ if not self.is_mine_id(user_id):
+ raise SynapseError(400, "User not on this server")
+
+ local_profile = {}
+ if "profile" in content:
+ if "name" in content["profile"]:
+ local_profile["name"] = content["profile"]["name"]
+ if "avatar_url" in content["profile"]:
+ local_profile["avatar_url"] = content["profile"]["avatar_url"]
+
+ token = yield self.store.register_user_group_membership(
+ group_id, user_id,
+ membership="invite",
+ content={"profile": local_profile, "inviter": content["inviter"]},
+ )
+ self.notifier.on_new_event(
+ "groups_key", token, users=[user_id],
+ )
+ try:
+ user_profile = yield self.profile_handler.get_profile(user_id)
+ except Exception as e:
+ logger.warn("No profile for user %s: %s", user_id, e)
+ user_profile = {}
+
+ defer.returnValue({"state": "invite", "user_profile": user_profile})
+
+ @defer.inlineCallbacks
+ def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
+ """Remove a user from a group
+ """
+ if user_id == requester_user_id:
+ token = yield self.store.register_user_group_membership(
+ group_id, user_id,
+ membership="leave",
+ )
+ self.notifier.on_new_event(
+ "groups_key", token, users=[user_id],
+ )
+
+ # TODO: Should probably remember that we tried to leave so that we can
+ # retry if the group server is currently down.
+
+ if self.is_mine_id(group_id):
+ res = yield self.groups_server_handler.remove_user_from_group(
+ group_id, user_id, requester_user_id, content,
+ )
+ else:
+ content["requester_user_id"] = requester_user_id
+ res = yield self.transport_client.remove_user_from_group(
+ get_domain_from_id(group_id), group_id, requester_user_id,
+ user_id, content,
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def user_removed_from_group(self, group_id, user_id, content):
+ """One of our users was removed/kicked from a group
+ """
+ # TODO: Check if user in group
+ token = yield self.store.register_user_group_membership(
+ group_id, user_id,
+ membership="leave",
+ )
+ self.notifier.on_new_event(
+ "groups_key", token, users=[user_id],
+ )
+
+ @defer.inlineCallbacks
+ def get_joined_groups(self, user_id):
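+        """Get the IDs of all the groups the given user has joined."""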
+ group_ids = yield self.store.get_joined_groups(user_id)
+ defer.returnValue({"groups": group_ids})
+
+ @defer.inlineCallbacks
+ def get_publicised_groups_for_user(self, user_id):
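+        """Get the groups a user has publicised, querying the user's home
+        server over federation if the user is remote.
+        """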
+ if self.hs.is_mine_id(user_id):
+ result = yield self.store.get_publicised_groups_for_user(user_id)
+
+ # Check AS associated groups for this user - this depends on the
+ # RegExps in the AS registration file (under `users`)
+ for app_service in self.store.get_app_services():
+ result.extend(app_service.get_groups_for_user(user_id))
+
+ defer.returnValue({"groups": result})
+ else:
+ bulk_result = yield self.transport_client.bulk_get_publicised_groups(
+ get_domain_from_id(user_id), [user_id],
+ )
+ result = bulk_result.get("users", {}).get(user_id)
+ # TODO: Verify attestations
+ defer.returnValue({"groups": result})
+
+ @defer.inlineCallbacks
+ def bulk_get_publicised_groups(self, user_ids, proxy=True):
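+        """Get the publicised groups for a list of users.
+
+        Requests for remote users are proxied to their home servers unless
+        proxy is False, in which case all of the user_ids must be local.
+        """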
+ destinations = {}
+ local_users = set()
+
+ for user_id in user_ids:
+ if self.hs.is_mine_id(user_id):
+ local_users.add(user_id)
+ else:
+ destinations.setdefault(
+ get_domain_from_id(user_id), set()
+ ).add(user_id)
+
+ if not proxy and destinations:
+ raise SynapseError(400, "Some user_ids are not local")
+
+ results = {}
+ failed_results = []
+ for destination, dest_user_ids in iteritems(destinations):
+ try:
+ r = yield self.transport_client.bulk_get_publicised_groups(
+ destination, list(dest_user_ids),
+ )
+ results.update(r["users"])
+ except Exception:
+ failed_results.extend(dest_user_ids)
+
+ for uid in local_users:
+ results[uid] = yield self.store.get_publicised_groups_for_user(
+ uid
+ )
+
+ # Check AS associated groups for this user - this depends on the
+ # RegExps in the AS registration file (under `users`)
+ for app_service in self.store.get_app_services():
+ results[uid].extend(app_service.get_groups_for_user(uid))
+
+ defer.returnValue({"users": results})
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 9efcdff1d6..8c8aedb2b8 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,17 +16,21 @@
# limitations under the License.
"""Utilities for interacting with Identity Servers"""
+
+import logging
+
+from canonicaljson import json
+
from twisted.internet import defer
from synapse.api.errors import (
- MatrixCodeMessageException, CodeMessageException
+ CodeMessageException,
+ Codes,
+ MatrixCodeMessageException,
+ SynapseError,
)
-from ._base import BaseHandler
-from synapse.util.async import run_on_reactor
-from synapse.api.errors import SynapseError, Codes
-import json
-import logging
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -36,6 +41,7 @@ class IdentityHandler(BaseHandler):
super(IdentityHandler, self).__init__(hs)
self.http_client = hs.get_simple_http_client()
+ self.federation_http_client = hs.get_http_client()
self.trusted_id_servers = set(hs.config.trusted_third_party_id_servers)
self.trust_any_id_server_just_for_testing_do_not_use = (
@@ -58,8 +64,6 @@ class IdentityHandler(BaseHandler):
@defer.inlineCallbacks
def threepid_from_creds(self, creds):
- yield run_on_reactor()
-
if 'id_server' in creds:
id_server = creds['id_server']
elif 'idServer' in creds:
@@ -102,7 +106,6 @@ class IdentityHandler(BaseHandler):
@defer.inlineCallbacks
def bind_threepid(self, creds, mxid):
- yield run_on_reactor()
logger.debug("binding threepid %r to %s", creds, mxid)
data = None
@@ -137,9 +140,53 @@ class IdentityHandler(BaseHandler):
defer.returnValue(data)
@defer.inlineCallbacks
- def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
- yield run_on_reactor()
+ def unbind_threepid(self, mxid, threepid):
+ """
+ Removes a binding from an identity server
+ Args:
+ mxid (str): Matrix user ID of binding to be removed
+ threepid (dict): Dict with medium & address of binding to be removed
+
+ Returns:
+ Deferred[bool]: True on success, otherwise False
+ """
+ logger.debug("unbinding threepid %r from %s", threepid, mxid)
+ if not self.trusted_id_servers:
+ logger.warn("Can't unbind threepid: no trusted ID servers set in config")
+ defer.returnValue(False)
+
+        # We don't track which ID server we added 3pids on (perhaps we ought to)
+ # but we assume that any of the servers in the trusted list are in the
+ # same ID server federation, so we can pick any one of them to send the
+ # deletion request to.
+ id_server = next(iter(self.trusted_id_servers))
+
+ url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
+ content = {
+ "mxid": mxid,
+ "threepid": threepid,
+ }
+ headers = {}
+ # we abuse the federation http client to sign the request, but we have to send it
+ # using the normal http client since we don't want the SRV lookup and want normal
+ # 'browser-like' HTTPS.
+ self.federation_http_client.sign_request(
+ destination=None,
+ method='POST',
+ url_bytes='/_matrix/identity/api/v1/3pid/unbind'.encode('ascii'),
+ headers_dict=headers,
+ content=content,
+ destination_is=id_server,
+ )
+ yield self.http_client.post_json_get_json(
+ url,
+ content,
+ headers,
+ )
+ defer.returnValue(True)
+ @defer.inlineCallbacks
+ def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
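+        """Ask an identity server to send an email validation token to
+        the given address, refusing if the identity server is not trusted.
+        """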
if not self._should_trust_id_server(id_server):
raise SynapseError(
400, "Untrusted ID server '%s'" % id_server,
@@ -174,8 +221,6 @@ class IdentityHandler(BaseHandler):
self, id_server, country, phone_number,
client_secret, send_attempt, **kwargs
):
- yield run_on_reactor()
-
if not self._should_trust_id_server(id_server):
raise SynapseError(
400, "Untrusted ID server '%s'" % id_server,
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 10f5f35a69..40e7580a61 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
@@ -21,20 +23,15 @@ from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.handlers.presence import format_user_presence_state
from synapse.streams.config import PaginationConfig
-from synapse.types import (
- UserID, StreamToken,
-)
+from synapse.types import StreamToken, UserID
from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.snapshot_cache import SnapshotCache
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
-import logging
-
-
logger = logging.getLogger(__name__)
@@ -151,22 +148,25 @@ class InitialSyncHandler(BaseHandler):
try:
if event.membership == Membership.JOIN:
room_end_token = now_token.room_key
- deferred_room_state = self.state_handler.get_current_state(
- event.room_id
+ deferred_room_state = run_in_background(
+ self.state_handler.get_current_state,
+ event.room_id,
)
elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,)
- deferred_room_state = self.store.get_state_for_events(
- [event.event_id], None
+ deferred_room_state = run_in_background(
+ self.store.get_state_for_events,
+ [event.event_id], None,
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
)
- (messages, token), current_state = yield preserve_context_over_deferred(
+ (messages, token), current_state = yield make_deferred_yieldable(
defer.gatherResults(
[
- preserve_fn(self.store.get_recent_events_for_room)(
+ run_in_background(
+ self.store.get_recent_events_for_room,
event.room_id,
limit=limit,
end_token=room_end_token,
@@ -180,8 +180,8 @@ class InitialSyncHandler(BaseHandler):
self.store, user_id, messages
)
- start_token = now_token.copy_and_replace("room_key", token[0])
- end_token = now_token.copy_and_replace("room_key", token[1])
+ start_token = now_token.copy_and_replace("room_key", token)
+ end_token = now_token.copy_and_replace("room_key", room_end_token)
time_now = self.clock.time_msec()
d["messages"] = {
@@ -214,7 +214,7 @@ class InitialSyncHandler(BaseHandler):
})
d["account_data"] = account_data_events
- except:
+ except Exception:
logger.exception("Failed to get snapshot")
yield concurrently_execute(handle_room, room_list, 10)
@@ -324,8 +324,8 @@ class InitialSyncHandler(BaseHandler):
self.store, user_id, messages, is_peeking=is_peeking
)
- start_token = StreamToken.START.copy_and_replace("room_key", token[0])
- end_token = StreamToken.START.copy_and_replace("room_key", token[1])
+ start_token = StreamToken.START.copy_and_replace("room_key", token)
+ end_token = StreamToken.START.copy_and_replace("room_key", stream_token)
time_now = self.clock.time_msec()
@@ -389,25 +389,28 @@ class InitialSyncHandler(BaseHandler):
receipts = []
defer.returnValue(receipts)
- presence, receipts, (messages, token) = yield defer.gatherResults(
- [
- preserve_fn(get_presence)(),
- preserve_fn(get_receipts)(),
- preserve_fn(self.store.get_recent_events_for_room)(
- room_id,
- limit=limit,
- end_token=now_token.room_key,
- )
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError)
+ presence, receipts, (messages, token) = yield make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(get_presence),
+ run_in_background(get_receipts),
+ run_in_background(
+ self.store.get_recent_events_for_room,
+ room_id,
+ limit=limit,
+ end_token=now_token.room_key,
+ )
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError),
+ )
messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking,
)
- start_token = now_token.copy_and_replace("room_key", token[0])
- end_token = now_token.copy_and_replace("room_key", token[1])
+ start_token = now_token.copy_and_replace("room_key", token)
+ end_token = now_token
time_now = self.clock.time_msec()
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 24c9ffdb20..39d7724778 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2017 - 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,173 +13,185 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import sys
+
+import six
+from six import iteritems, itervalues, string_types
+
+from canonicaljson import encode_canonical_json, json
from twisted.internet import defer
+from twisted.internet.defer import succeed
-from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
+from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
+from synapse.api.urls import ConsentURIBuilder
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
-from synapse.types import (
- UserID, RoomAlias, RoomStreamToken,
-)
-from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
-from synapse.util.logcontext import preserve_fn
+from synapse.replication.http.send_event import send_event_to_master
+from synapse.types import RoomAlias, UserID
+from synapse.util.async import Linearizer
+from synapse.util.frozenutils import frozendict_json_encoder
+from synapse.util.logcontext import run_in_background
from synapse.util.metrics import measure_func
-from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
-from canonicaljson import encode_canonical_json
-
-import logging
-import random
-import ujson
-
logger = logging.getLogger(__name__)
-class MessageHandler(BaseHandler):
+class MessageHandler(object):
+ """Contains some read only APIs to get state about a room
+ """
def __init__(self, hs):
- super(MessageHandler, self).__init__(hs)
- self.hs = hs
- self.state = hs.get_state_handler()
+ self.auth = hs.get_auth()
self.clock = hs.get_clock()
- self.validator = EventValidator()
-
- self.pagination_lock = ReadWriteLock()
-
- # We arbitrarily limit concurrent event creation for a room to 5.
- # This is to stop us from diverging history *too* much.
- self.limiter = Limiter(max_count=5)
-
- self.action_generator = hs.get_action_generator()
+ self.state = hs.get_state_handler()
+ self.store = hs.get_datastore()
@defer.inlineCallbacks
- def purge_history(self, room_id, event_id):
- event = yield self.store.get_event(event_id)
+ def get_room_data(self, user_id=None, room_id=None,
+ event_type=None, state_key="", is_guest=False):
+ """ Get data from a room.
- if event.room_id != room_id:
- raise SynapseError(400, "Event is for wrong room.")
+ Args:
+            user_id (str): The user requesting the state event.
+            room_id (str): The room to get the state event from.
+            event_type (str): The type of the state event.
+            state_key (str): The state key of the state event.
+        Returns:
+            The requested state event, or None if it is not in the
+            room's state.
+ Raises:
+ SynapseError if something went wrong.
+ """
+ membership, membership_event_id = yield self.auth.check_in_room_or_world_readable(
+ room_id, user_id
+ )
- depth = event.depth
+ if membership == Membership.JOIN:
+ data = yield self.state.get_current_state(
+ room_id, event_type, state_key
+ )
+ elif membership == Membership.LEAVE:
+ key = (event_type, state_key)
+ room_state = yield self.store.get_state_for_events(
+ [membership_event_id], [key]
+ )
+ data = room_state[membership_event_id].get(key)
- with (yield self.pagination_lock.write(room_id)):
- yield self.store.delete_old_state(room_id, depth)
+ defer.returnValue(data)
@defer.inlineCallbacks
- def get_messages(self, requester, room_id=None, pagin_config=None,
- as_client_event=True, event_filter=None):
- """Get messages in a room.
+ def get_state_events(self, user_id, room_id, is_guest=False):
+ """Retrieve all state events for a given room. If the user is
+ joined to the room then return the current state. If the user has
+ left the room return the state events from when they left.
Args:
- requester (Requester): The user requesting messages.
- room_id (str): The room they want messages from.
- pagin_config (synapse.api.streams.PaginationConfig): The pagination
- config rules to apply, if any.
- as_client_event (bool): True to get events in client-server format.
- event_filter (Filter): Filter to apply to results or None
+ user_id(str): The user requesting state events.
+ room_id(str): The room ID to get all state events from.
Returns:
- dict: Pagination API results
+ A list of dicts representing state events. [{}, {}, {}]
"""
- user_id = requester.user.to_string()
+ membership, membership_event_id = yield self.auth.check_in_room_or_world_readable(
+ room_id, user_id
+ )
- if pagin_config.from_token:
- room_token = pagin_config.from_token.room_key
- else:
- pagin_config.from_token = (
- yield self.hs.get_event_sources().get_current_token_for_room(
- room_id=room_id
- )
+ if membership == Membership.JOIN:
+ room_state = yield self.state.get_current_state(room_id)
+ elif membership == Membership.LEAVE:
+ room_state = yield self.store.get_state_for_events(
+ [membership_event_id], None
)
- room_token = pagin_config.from_token.room_key
-
- room_token = RoomStreamToken.parse(room_token)
+ room_state = room_state[membership_event_id]
- pagin_config.from_token = pagin_config.from_token.copy_and_replace(
- "room_key", str(room_token)
+ now = self.clock.time_msec()
+ defer.returnValue(
+ [serialize_event(c, now) for c in room_state.values()]
)
- source_config = pagin_config.get_source_config("room")
+ @defer.inlineCallbacks
+ def get_joined_members(self, requester, room_id):
+ """Get all the joined members in the room and their profile information.
+
+ If the user has left the room return the state events from when they left.
- with (yield self.pagination_lock.read(room_id)):
- membership, member_event_id = yield self._check_in_room_or_world_readable(
+ Args:
+ requester(Requester): The user requesting state events.
+ room_id(str): The room ID to get all state events from.
+ Returns:
+ A dict of user_id to profile info
+ """
+ user_id = requester.user.to_string()
+ if not requester.app_service:
+ # We check AS auth after fetching the room membership, as it
+ # requires us to pull out all joined members anyway.
+ membership, _ = yield self.auth.check_in_room_or_world_readable(
room_id, user_id
)
+ if membership != Membership.JOIN:
+ raise NotImplementedError(
+ "Getting joined members after leaving is not implemented"
+ )
- if source_config.direction == 'b':
- # if we're going backwards, we might need to backfill. This
- # requires that we have a topo token.
- if room_token.topological:
- max_topo = room_token.topological
- else:
- max_topo = yield self.store.get_max_topological_token(
- room_id, room_token.stream
- )
+ users_with_profile = yield self.state.get_current_user_in_room(room_id)
- if membership == Membership.LEAVE:
- # If they have left the room then clamp the token to be before
- # they left the room, to save the effort of loading from the
- # database.
- leave_token = yield self.store.get_topological_token_for_event(
- member_event_id
- )
- leave_token = RoomStreamToken.parse(leave_token)
- if leave_token.topological < max_topo:
- source_config.from_key = str(leave_token)
+ # If this is an AS, double check that they are allowed to see the members.
+ # This can either be because the AS user is in the room or because there
+ # is a user in the room that the AS is "interested in"
+ if requester.app_service and user_id not in users_with_profile:
+ for uid in users_with_profile:
+ if requester.app_service.is_interested_in_user(uid):
+ break
+ else:
+ # Loop fell through, AS has no interested users in room
+ raise AuthError(403, "Appservice not in room")
- yield self.hs.get_handlers().federation_handler.maybe_backfill(
- room_id, max_topo
- )
+ defer.returnValue({
+ user_id: {
+ "avatar_url": profile.avatar_url,
+ "display_name": profile.display_name,
+ }
+ for user_id, profile in iteritems(users_with_profile)
+ })
- events, next_key = yield self.store.paginate_room_events(
- room_id=room_id,
- from_key=source_config.from_key,
- to_key=source_config.to_key,
- direction=source_config.direction,
- limit=source_config.limit,
- event_filter=event_filter,
- )
- next_token = pagin_config.from_token.copy_and_replace(
- "room_key", next_key
- )
+class EventCreationHandler(object):
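+    """Handles creating and sending events on behalf of local clients,
+    including the associated consent, spam, auth and ratelimit checks.
+    """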
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.state = hs.get_state_handler()
+ self.clock = hs.get_clock()
+ self.validator = EventValidator()
+ self.profile_handler = hs.get_profile_handler()
+ self.event_builder_factory = hs.get_event_builder_factory()
+ self.server_name = hs.hostname
+ self.ratelimiter = hs.get_ratelimiter()
+ self.notifier = hs.get_notifier()
+ self.config = hs.config
- if not events:
- defer.returnValue({
- "chunk": [],
- "start": pagin_config.from_token.to_string(),
- "end": next_token.to_string(),
- })
-
- if event_filter:
- events = event_filter.filter(events)
-
- events = yield filter_events_for_client(
- self.store,
- user_id,
- events,
- is_peeking=(member_event_id is None),
- )
+ self.http_client = hs.get_simple_http_client()
- time_now = self.clock.time_msec()
+        # This is only used to get at the ratelimit function and maybe_kick_guest_users
+ self.base_handler = BaseHandler(hs)
- chunk = {
- "chunk": [
- serialize_event(e, time_now, as_client_event)
- for e in events
- ],
- "start": pagin_config.from_token.to_string(),
- "end": next_token.to_string(),
- }
+ self.pusher_pool = hs.get_pusherpool()
- defer.returnValue(chunk)
+ # We arbitrarily limit concurrent event creation for a room to 5.
+ # This is to stop us from diverging history *too* much.
+ self.limiter = Linearizer(max_count=5, name="room_event_creation_limit")
+
+ self.action_generator = hs.get_action_generator()
+
+ self.spam_checker = hs.get_spam_checker()
+
+ if self.config.block_events_without_consent_error is not None:
+ self._consent_uri_builder = ConsentURIBuilder(self.config)
@defer.inlineCallbacks
def create_event(self, requester, event_dict, token_id=None, txn_id=None,
- prev_event_ids=None):
+ prev_events_and_hashes=None):
"""
Given a dict from a client, create a new event.
@@ -192,50 +205,143 @@ class MessageHandler(BaseHandler):
event_dict (dict): An entire event
token_id (str)
txn_id (str)
- prev_event_ids (list): The prev event ids to use when creating the event
+
+ prev_events_and_hashes (list[(str, dict[str, str], int)]|None):
+ the forward extremities to use as the prev_events for the
+ new event. For each event, a tuple of (event_id, hashes, depth)
+ where *hashes* is a map from algorithm to hash.
+
+ If None, they will be requested from the database.
Returns:
Tuple of created event (FrozenEvent), Context
"""
builder = self.event_builder_factory.new(event_dict)
- with (yield self.limiter.queue(builder.room_id)):
- self.validator.validate_new(builder)
-
- if builder.type == EventTypes.Member:
- membership = builder.content.get("membership", None)
- target = UserID.from_string(builder.state_key)
-
- if membership in {Membership.JOIN, Membership.INVITE}:
- # If event doesn't include a display name, add one.
- profile = self.hs.get_handlers().profile_handler
- content = builder.content
-
- try:
- if "displayname" not in content:
- content["displayname"] = yield profile.get_displayname(target)
- if "avatar_url" not in content:
- content["avatar_url"] = yield profile.get_avatar_url(target)
- except Exception as e:
- logger.info(
- "Failed to get profile information for %r: %s",
- target, e
- )
+ self.validator.validate_new(builder)
+
+ if builder.type == EventTypes.Member:
+ membership = builder.content.get("membership", None)
+ target = UserID.from_string(builder.state_key)
+
+ if membership in {Membership.JOIN, Membership.INVITE}:
+ # If event doesn't include a display name, add one.
+ profile = self.profile_handler
+ content = builder.content
+
+ try:
+ if "displayname" not in content:
+ content["displayname"] = yield profile.get_displayname(target)
+ if "avatar_url" not in content:
+ content["avatar_url"] = yield profile.get_avatar_url(target)
+ except Exception as e:
+ logger.info(
+ "Failed to get profile information for %r: %s",
+ target, e
+ )
- if token_id is not None:
- builder.internal_metadata.token_id = token_id
+ is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester)
+ if not is_exempt:
+ yield self.assert_accepted_privacy_policy(requester)
- if txn_id is not None:
- builder.internal_metadata.txn_id = txn_id
+ if token_id is not None:
+ builder.internal_metadata.token_id = token_id
- event, context = yield self._create_new_client_event(
- builder=builder,
- requester=requester,
- prev_event_ids=prev_event_ids,
- )
+ if txn_id is not None:
+ builder.internal_metadata.txn_id = txn_id
+
+ event, context = yield self.create_new_client_event(
+ builder=builder,
+ requester=requester,
+ prev_events_and_hashes=prev_events_and_hashes,
+ )
defer.returnValue((event, context))
+ def _is_exempt_from_privacy_policy(self, builder, requester):
+        """Determine if an event to be sent is exempt from having to consent
+ to the privacy policy
+
+ Args:
+ builder (synapse.events.builder.EventBuilder): event being created
+            requester (Requester): user requesting this event
+
+ Returns:
+ Deferred[bool]: true if the event can be sent without the user
+ consenting
+ """
+ # the only thing the user can do is join the server notices room.
+ if builder.type == EventTypes.Member:
+ membership = builder.content.get("membership", None)
+ if membership == Membership.JOIN:
+ return self._is_server_notices_room(builder.room_id)
+ elif membership == Membership.LEAVE:
+ # the user is always allowed to leave (but not kick people)
+ return builder.state_key == requester.user.to_string()
+ return succeed(False)
+
+ @defer.inlineCallbacks
+ def _is_server_notices_room(self, room_id):
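+        """Check whether the given room is the server notices room, by
+        checking whether the configured server notices user is one of
+        its members.
+        """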
+ if self.config.server_notices_mxid is None:
+ defer.returnValue(False)
+ user_ids = yield self.store.get_users_in_room(room_id)
+ defer.returnValue(self.config.server_notices_mxid in user_ids)
+
+ @defer.inlineCallbacks
+ def assert_accepted_privacy_policy(self, requester):
+ """Check if a user has accepted the privacy policy
+
+ Called when the given user is about to do something that requires
+ privacy consent. We see if the user is exempt and otherwise check that
+ they have given consent. If they have not, a ConsentNotGiven error is
+ raised.
+
+ Args:
+ requester (synapse.types.Requester):
+ The user making the request
+
+ Returns:
+ Deferred[None]: returns normally if the user has consented or is
+ exempt
+
+ Raises:
+ ConsentNotGivenError: if the user has not given consent yet
+ """
+ if self.config.block_events_without_consent_error is None:
+ return
+
+ # exempt AS users from needing consent
+ if requester.app_service is not None:
+ return
+
+ user_id = requester.user.to_string()
+
+ # exempt the system notices user
+ if (
+ self.config.server_notices_mxid is not None and
+ user_id == self.config.server_notices_mxid
+ ):
+ return
+
+ u = yield self.store.get_user_by_id(user_id)
+ assert u is not None
+ if u["appservice_id"] is not None:
+ # users registered by an appservice are exempt
+ return
+ if u["consent_version"] == self.config.user_consent_version:
+ return
+
+ consent_uri = self._consent_uri_builder.build_user_consent_uri(
+ requester.user.localpart,
+ )
+ msg = self.config.block_events_without_consent_error % {
+ 'consent_uri': consent_uri,
+ }
+ raise ConsentNotGivenError(
+ msg=msg,
+ consent_uri=consent_uri,
+ )
+
@defer.inlineCallbacks
def send_nonmember_event(self, requester, event, context, ratelimit=True):
"""
@@ -253,11 +359,6 @@ class MessageHandler(BaseHandler):
"Tried to send member event through non-member codepath"
)
- # We check here if we are currently being rate limited, so that we
- # don't do unnecessary work. We check again just before we actually
- # send the event.
- yield self.ratelimit(requester, update=False)
-
user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
@@ -274,12 +375,6 @@ class MessageHandler(BaseHandler):
ratelimit=ratelimit,
)
- if event.type == EventTypes.Message:
- presence = self.hs.get_presence_handler()
- # We don't want to block sending messages on any presence code. This
- # matters as sometimes presence code can take a while.
- preserve_fn(presence.bump_presence_active_time)(user)
-
@defer.inlineCallbacks
def deduplicate_state_event(self, event, context):
"""
@@ -288,7 +383,8 @@ class MessageHandler(BaseHandler):
If so, returns the version of the event in context.
Otherwise, returns None.
"""
- prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_event_id = prev_state_ids.get((event.type, event.state_key))
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
return
@@ -313,145 +409,85 @@ class MessageHandler(BaseHandler):
See self.create_event and self.send_nonmember_event.
"""
- event, context = yield self.create_event(
- requester,
- event_dict,
- token_id=requester.access_token_id,
- txn_id=txn_id
- )
- yield self.send_nonmember_event(
- requester,
- event,
- context,
- ratelimit=ratelimit,
- )
- defer.returnValue(event)
- @defer.inlineCallbacks
- def get_room_data(self, user_id=None, room_id=None,
- event_type=None, state_key="", is_guest=False):
- """ Get data from a room.
-
- Args:
- event : The room path event
- Returns:
- The path data content.
- Raises:
- SynapseError if something went wrong.
- """
- membership, membership_event_id = yield self._check_in_room_or_world_readable(
- room_id, user_id
- )
-
- if membership == Membership.JOIN:
- data = yield self.state_handler.get_current_state(
- room_id, event_type, state_key
- )
- elif membership == Membership.LEAVE:
- key = (event_type, state_key)
- room_state = yield self.store.get_state_for_events(
- [membership_event_id], [key]
+ # We limit the number of concurrent event sends in a room so that we
+ # don't fork the DAG too much. If we don't limit then we can end up in
+ # a situation where event persistence can't keep up, causing
+ # extremities to pile up, which in turn leads to state resolution
+ # taking longer.
+ with (yield self.limiter.queue(event_dict["room_id"])):
+ event, context = yield self.create_event(
+ requester,
+ event_dict,
+ token_id=requester.access_token_id,
+ txn_id=txn_id
)
- data = room_state[membership_event_id].get(key)
- defer.returnValue(data)
+ spam_error = self.spam_checker.check_event_for_spam(event)
+ if spam_error:
+ if not isinstance(spam_error, string_types):
+ spam_error = "Spam is not permitted here"
+ raise SynapseError(
+ 403, spam_error, Codes.FORBIDDEN
+ )
- @defer.inlineCallbacks
- def _check_in_room_or_world_readable(self, room_id, user_id):
- try:
- # check_user_was_in_room will return the most recent membership
- # event for the user if:
- # * The user is a non-guest user, and was ever in the room
- # * The user is a guest user, and has joined the room
- # else it will throw.
- member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
- defer.returnValue((member_event.membership, member_event.event_id))
- return
- except AuthError:
- visibility = yield self.state_handler.get_current_state(
- room_id, EventTypes.RoomHistoryVisibility, ""
- )
- if (
- visibility and
- visibility.content["history_visibility"] == "world_readable"
- ):
- defer.returnValue((Membership.JOIN, None))
- return
- raise AuthError(
- 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
+ yield self.send_nonmember_event(
+ requester,
+ event,
+ context,
+ ratelimit=ratelimit,
)
+ defer.returnValue(event)
+ @measure_func("create_new_client_event")
@defer.inlineCallbacks
- def get_state_events(self, user_id, room_id, is_guest=False):
- """Retrieve all state events for a given room. If the user is
- joined to the room then return the current state. If the user has
- left the room return the state events from when they left.
+ def create_new_client_event(self, builder, requester=None,
+ prev_events_and_hashes=None):
+ """Create a new event for a local client
Args:
- user_id(str): The user requesting state events.
- room_id(str): The room ID to get all state events from.
+ builder (EventBuilder):
+
+ requester (synapse.types.Requester|None):
+
+ prev_events_and_hashes (list[(str, dict[str, str], int)]|None):
+ the forward extremities to use as the prev_events for the
+ new event. For each event, a tuple of (event_id, hashes, depth)
+ where *hashes* is a map from algorithm to hash.
+
+ If None, they will be requested from the database.
+
Returns:
- A list of dicts representing state events. [{}, {}, {}]
+ Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)]
"""
- membership, membership_event_id = yield self._check_in_room_or_world_readable(
- room_id, user_id
- )
- if membership == Membership.JOIN:
- room_state = yield self.state_handler.get_current_state(room_id)
- elif membership == Membership.LEAVE:
- room_state = yield self.store.get_state_for_events(
- [membership_event_id], None
+ if prev_events_and_hashes is not None:
+ assert len(prev_events_and_hashes) <= 10, \
+ "Attempting to create an event with %i prev_events" % (
+ len(prev_events_and_hashes),
)
- room_state = room_state[membership_event_id]
-
- now = self.clock.time_msec()
- defer.returnValue(
- [serialize_event(c, now) for c in room_state.values()]
- )
-
- @measure_func("_create_new_client_event")
- @defer.inlineCallbacks
- def _create_new_client_event(self, builder, requester=None, prev_event_ids=None):
- if prev_event_ids:
- prev_events = yield self.store.add_event_hashes(prev_event_ids)
- prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
- depth = prev_max_depth + 1
else:
- latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
- builder.room_id,
- )
-
- # We want to limit the max number of prev events we point to in our
- # new event
- if len(latest_ret) > 10:
- # Sort by reverse depth, so we point to the most recent.
- latest_ret.sort(key=lambda a: -a[2])
- new_latest_ret = latest_ret[:5]
-
- # We also randomly point to some of the older events, to make
- # sure that we don't completely ignore the older events.
- if latest_ret[5:]:
- sample_size = min(5, len(latest_ret[5:]))
- new_latest_ret.extend(random.sample(latest_ret[5:], sample_size))
- latest_ret = new_latest_ret
-
- if latest_ret:
- depth = max([d for _, _, d in latest_ret]) + 1
- else:
- depth = 1
+ prev_events_and_hashes = \
+ yield self.store.get_prev_events_for_room(builder.room_id)
+
+ if prev_events_and_hashes:
+ depth = max([d for _, _, d in prev_events_and_hashes]) + 1
+ # we cap depth of generated events, to ensure that they are not
+ # rejected by other servers (and so that they can be persisted in
+ # the db)
+ depth = min(depth, MAX_DEPTH)
+ else:
+ depth = 1
- prev_events = [
- (event_id, prev_hashes)
- for event_id, prev_hashes, _ in latest_ret
- ]
+ prev_events = [
+ (event_id, prev_hashes)
+ for event_id, prev_hashes, _ in prev_events_and_hashes
+ ]
builder.prev_events = prev_events
builder.depth = depth
- state_handler = self.state_handler
-
- context = yield state_handler.compute_event_context(builder)
+ context = yield self.state.compute_event_context(builder)
if requester:
context.app_service = requester.app_service
@@ -470,8 +506,8 @@ class MessageHandler(BaseHandler):
event = builder.build()
logger.debug(
- "Created event %s with state: %s",
- event.event_id, context.prev_state_ids,
+ "Created event %s",
+ event.event_id,
)
defer.returnValue(
@@ -486,12 +522,21 @@ class MessageHandler(BaseHandler):
event,
context,
ratelimit=True,
- extra_users=[]
+ extra_users=[],
):
- # We now need to go and hit out to wherever we need to hit out to.
+ """Processes a new event. This includes checking auth, persisting it,
+ notifying users, sending to remote servers, etc.
- if ratelimit:
- yield self.ratelimit(requester)
+ If called from a worker will hit out to the master process for final
+ processing.
+
+ Args:
+ requester (Requester)
+ event (FrozenEvent)
+ context (EventContext)
+ ratelimit (bool)
+ extra_users (list(UserID)): Any extra users to notify about event
+ """
try:
yield self.auth.check_from_context(event, context)
@@ -501,13 +546,72 @@ class MessageHandler(BaseHandler):
# Ensure that we can round trip before trying to persist in db
try:
- dump = ujson.dumps(event.content)
- ujson.loads(dump)
- except:
+ dump = frozendict_json_encoder.encode(event.content)
+ json.loads(dump)
+ except Exception:
logger.exception("Failed to encode content: %r", event.content)
raise
- yield self.maybe_kick_guest_users(event, context)
+ yield self.action_generator.handle_push_actions_for_event(
+ event, context
+ )
+
+ try:
+ # If we're a worker we need to hit out to the master.
+ if self.config.worker_app:
+ yield send_event_to_master(
+ clock=self.hs.get_clock(),
+ store=self.store,
+ client=self.http_client,
+ host=self.config.worker_replication_host,
+ port=self.config.worker_replication_http_port,
+ requester=requester,
+ event=event,
+ context=context,
+ ratelimit=ratelimit,
+ extra_users=extra_users,
+ )
+ return
+
+ yield self.persist_and_notify_client_event(
+ requester,
+ event,
+ context,
+ ratelimit=ratelimit,
+ extra_users=extra_users,
+ )
+ except: # noqa: E722, as we reraise the exception this is fine.
+ # Ensure that we actually remove the entries in the push actions
+ # staging area, if we calculated them.
+ tp, value, tb = sys.exc_info()
+
+ run_in_background(
+ self.store.remove_push_actions_from_staging,
+ event.event_id,
+ )
+
+ six.reraise(tp, value, tb)
+
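
The bare except above is deliberate: the push-actions staging entries must be
cleaned up on any failure, but the original exception has to reach the caller
with its traceback intact. The pattern in isolation, with hypothetical names
(not Synapse API):

    import sys

    import six

    def run_with_cleanup(func, cleanup):
        try:
            return func()
        except BaseException:
            # capture exc_info before cleanup runs, then re-raise with the
            # original traceback (six.reraise works on both py2 and py3)
            tp, value, tb = sys.exc_info()
            cleanup()
            six.reraise(tp, value, tb)
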
+ @defer.inlineCallbacks
+ def persist_and_notify_client_event(
+ self,
+ requester,
+ event,
+ context,
+ ratelimit=True,
+ extra_users=[],
+ ):
+ """Called when we have fully built the event, have already
+ calculated the push actions for the event, and checked auth.
+
+ This should only be run on master.
+ """
+ assert not self.config.worker_app
+
+ if ratelimit:
+ yield self.base_handler.ratelimit(requester)
+
+ yield self.base_handler.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias:
# Check the alias is actually valid (at this time at least)
@@ -535,9 +639,11 @@ class MessageHandler(BaseHandler):
e.sender == event.sender
)
+ current_state_ids = yield context.get_current_state_ids(self.store)
+
state_to_include_ids = [
e_id
- for k, e_id in context.current_state_ids.iteritems()
+ for k, e_id in iteritems(current_state_ids)
if k[0] in self.hs.config.room_invite_state_types
or k == (EventTypes.Member, event.sender)
]
@@ -551,7 +657,7 @@ class MessageHandler(BaseHandler):
"content": e.content,
"sender": e.sender,
}
- for e in state_to_include.itervalues()
+ for e in itervalues(state_to_include)
]
invitee = UserID.from_string(event.state_key)
@@ -573,8 +679,9 @@ class MessageHandler(BaseHandler):
)
if event.type == EventTypes.Redaction:
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.auth.compute_auth_events(
- event, context.prev_state_ids, for_verification=True,
+ event, prev_state_ids, for_verification=True,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
@@ -594,15 +701,13 @@ class MessageHandler(BaseHandler):
"You don't have permission to redact events"
)
- if event.type == EventTypes.Create and context.prev_state_ids:
- raise AuthError(
- 403,
- "Changing the room create event is forbidden",
- )
-
- yield self.action_generator.handle_push_actions_for_event(
- event, context
- )
+ if event.type == EventTypes.Create:
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+ if prev_state_ids:
+ raise AuthError(
+ 403,
+ "Changing the room create event is forbidden",
+ )
(event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
@@ -610,16 +715,31 @@ class MessageHandler(BaseHandler):
# this intentionally does not yield: we don't care about the result
# and don't need to wait for it.
- preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
+ run_in_background(
+ self.pusher_pool.on_new_notifications,
event_stream_id, max_stream_id
)
- @defer.inlineCallbacks
def _notify():
- yield run_on_reactor()
- self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id,
- extra_users=extra_users
- )
+ try:
+ self.notifier.on_new_room_event(
+ event, event_stream_id, max_stream_id,
+ extra_users=extra_users
+ )
+ except Exception:
+ logger.exception("Error notifying about new room event")
+
+ run_in_background(_notify)
- preserve_fn(_notify)()
+ if event.type == EventTypes.Message:
+ # We don't want to block sending messages on any presence code. This
+ # matters as sometimes presence code can take a while.
+ run_in_background(self._bump_active_time, requester.user)
+
+ @defer.inlineCallbacks
+ def _bump_active_time(self, user):
+ try:
+ presence = self.hs.get_presence_handler()
+ yield presence.bump_presence_active_time(user)
+ except Exception:
+ logger.exception("Error bumping presence active time")
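
Both _notify and _bump_active_time follow the same fire-and-forget shape:
run_in_background schedules the work without blocking the send path, and a
try/except ensures a failure is logged rather than propagated. As a generic
sketch (names hypothetical):

    import logging

    from twisted.internet import defer

    logger = logging.getLogger(__name__)

    @defer.inlineCallbacks
    def safe_background_task(f, *args):
        # never let a background failure leak into the request path
        try:
            yield f(*args)
        except Exception:
            logger.exception("Error in background task")
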
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
new file mode 100644
index 0000000000..b2849783ed
--- /dev/null
+++ b/synapse/handlers/pagination.py
@@ -0,0 +1,265 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2017 - 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+from twisted.python.failure import Failure
+
+from synapse.api.constants import Membership
+from synapse.api.errors import SynapseError
+from synapse.events.utils import serialize_event
+from synapse.types import RoomStreamToken
+from synapse.util.async import ReadWriteLock
+from synapse.util.logcontext import run_in_background
+from synapse.util.stringutils import random_string
+from synapse.visibility import filter_events_for_client
+
+logger = logging.getLogger(__name__)
+
+
+class PurgeStatus(object):
+ """Object tracking the status of a purge request
+
+ This class contains information on the progress of a purge request, for
+ return by get_purge_status.
+
+ Attributes:
+ status (int): Tracks whether this request has completed. One of
+ STATUS_{ACTIVE,COMPLETE,FAILED}
+ """
+
+ STATUS_ACTIVE = 0
+ STATUS_COMPLETE = 1
+ STATUS_FAILED = 2
+
+ STATUS_TEXT = {
+ STATUS_ACTIVE: "active",
+ STATUS_COMPLETE: "complete",
+ STATUS_FAILED: "failed",
+ }
+
+ def __init__(self):
+ self.status = PurgeStatus.STATUS_ACTIVE
+
+ def asdict(self):
+ return {
+ "status": PurgeStatus.STATUS_TEXT[self.status]
+ }
+
+
+class PaginationHandler(object):
+ """Handles pagination and purge history requests.
+
+    These live in the same handler because we need to block clients from
+    paginating while a purge is in progress.
+ """
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+
+ self.pagination_lock = ReadWriteLock()
+ self._purges_in_progress_by_room = set()
+ # map from purge id to PurgeStatus
+ self._purges_by_id = {}
+
+ def start_purge_history(self, room_id, token,
+ delete_local_events=False):
+ """Start off a history purge on a room.
+
+ Args:
+ room_id (str): The room to purge from
+
+ token (str): topological token to delete events before
+ delete_local_events (bool): True to delete local events as well as
+ remote ones
+
+ Returns:
+ str: unique ID for this purge transaction.
+ """
+ if room_id in self._purges_in_progress_by_room:
+ raise SynapseError(
+ 400,
+ "History purge already in progress for %s" % (room_id, ),
+ )
+
+ purge_id = random_string(16)
+
+ # we log the purge_id here so that it can be tied back to the
+ # request id in the log lines.
+ logger.info("[purge] starting purge_id %s", purge_id)
+
+ self._purges_by_id[purge_id] = PurgeStatus()
+ run_in_background(
+ self._purge_history,
+ purge_id, room_id, token, delete_local_events,
+ )
+ return purge_id
+
+ @defer.inlineCallbacks
+ def _purge_history(self, purge_id, room_id, token,
+ delete_local_events):
+ """Carry out a history purge on a room.
+
+ Args:
+ purge_id (str): The id for this purge
+ room_id (str): The room to purge from
+ token (str): topological token to delete events before
+ delete_local_events (bool): True to delete local events as well as
+ remote ones
+
+ Returns:
+ Deferred
+ """
+ self._purges_in_progress_by_room.add(room_id)
+ try:
+ with (yield self.pagination_lock.write(room_id)):
+ yield self.store.purge_history(
+ room_id, token, delete_local_events,
+ )
+ logger.info("[purge] complete")
+ self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE
+ except Exception:
+ logger.error("[purge] failed: %s", Failure().getTraceback().rstrip())
+ self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
+ finally:
+ self._purges_in_progress_by_room.discard(room_id)
+
+ # remove the purge from the list 24 hours after it completes
+ def clear_purge():
+ del self._purges_by_id[purge_id]
+ self.hs.get_reactor().callLater(24 * 3600, clear_purge)
+
+ def get_purge_status(self, purge_id):
+ """Get the current status of an active purge
+
+ Args:
+ purge_id (str): purge_id returned by start_purge_history
+
+ Returns:
+ PurgeStatus|None
+ """
+ return self._purges_by_id.get(purge_id)
+
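
Putting the pieces together, a hypothetical caller of the purge API (the
get_pagination_handler accessor name is an assumption; room_id and token are
taken to be in scope):

    def start_and_poll(hs, room_id, token):
        handler = hs.get_pagination_handler()  # assumed accessor name
        purge_id = handler.start_purge_history(room_id, token)
        status = handler.get_purge_status(purge_id)
        if status is not None:
            print(status.asdict())  # {"status": "active"} until it finishes
        return purge_id
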
+ @defer.inlineCallbacks
+ def get_messages(self, requester, room_id=None, pagin_config=None,
+ as_client_event=True, event_filter=None):
+ """Get messages in a room.
+
+ Args:
+ requester (Requester): The user requesting messages.
+ room_id (str): The room they want messages from.
+ pagin_config (synapse.api.streams.PaginationConfig): The pagination
+ config rules to apply, if any.
+ as_client_event (bool): True to get events in client-server format.
+ event_filter (Filter): Filter to apply to results or None
+ Returns:
+ dict: Pagination API results
+ """
+ user_id = requester.user.to_string()
+
+ if pagin_config.from_token:
+ room_token = pagin_config.from_token.room_key
+ else:
+ pagin_config.from_token = (
+ yield self.hs.get_event_sources().get_current_token_for_room(
+ room_id=room_id
+ )
+ )
+ room_token = pagin_config.from_token.room_key
+
+ room_token = RoomStreamToken.parse(room_token)
+
+ pagin_config.from_token = pagin_config.from_token.copy_and_replace(
+ "room_key", str(room_token)
+ )
+
+ source_config = pagin_config.get_source_config("room")
+
+ with (yield self.pagination_lock.read(room_id)):
+ membership, member_event_id = yield self.auth.check_in_room_or_world_readable(
+ room_id, user_id
+ )
+
+ if source_config.direction == 'b':
+ # if we're going backwards, we might need to backfill. This
+ # requires that we have a topo token.
+ if room_token.topological:
+ max_topo = room_token.topological
+ else:
+ max_topo = yield self.store.get_max_topological_token(
+ room_id, room_token.stream
+ )
+
+ if membership == Membership.LEAVE:
+ # If they have left the room then clamp the token to be before
+ # they left the room, to save the effort of loading from the
+ # database.
+ leave_token = yield self.store.get_topological_token_for_event(
+ member_event_id
+ )
+ leave_token = RoomStreamToken.parse(leave_token)
+ if leave_token.topological < max_topo:
+ source_config.from_key = str(leave_token)
+
+ yield self.hs.get_handlers().federation_handler.maybe_backfill(
+ room_id, max_topo
+ )
+
+ events, next_key = yield self.store.paginate_room_events(
+ room_id=room_id,
+ from_key=source_config.from_key,
+ to_key=source_config.to_key,
+ direction=source_config.direction,
+ limit=source_config.limit,
+ event_filter=event_filter,
+ )
+
+ next_token = pagin_config.from_token.copy_and_replace(
+ "room_key", next_key
+ )
+
+ if not events:
+ defer.returnValue({
+ "chunk": [],
+ "start": pagin_config.from_token.to_string(),
+ "end": next_token.to_string(),
+ })
+
+ if event_filter:
+ events = event_filter.filter(events)
+
+ events = yield filter_events_for_client(
+ self.store,
+ user_id,
+ events,
+ is_peeking=(member_event_id is None),
+ )
+
+ time_now = self.clock.time_msec()
+
+ chunk = {
+ "chunk": [
+ serialize_event(e, time_now, as_client_event)
+ for e in events
+ ],
+ "start": pagin_config.from_token.to_string(),
+ "end": next_token.to_string(),
+ }
+
+ defer.returnValue(chunk)
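
The pagination_lock is what ties get_messages to the purge machinery: purges
take the write side, readers take the read side. A minimal sketch of the
assumed ReadWriteLock semantics (many concurrent readers, exclusive writers,
keyed per room):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def paginate(pagination_lock, room_id):
        # readers may overlap each other, but not an in-progress purge
        with (yield pagination_lock.read(room_id)):
            pass  # load and serialize events here

    @defer.inlineCallbacks
    def purge(pagination_lock, room_id):
        # the writer waits for readers to drain and blocks new ones
        with (yield pagination_lock.write(room_id)):
            pass  # delete history here
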
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index c7c0b0a1e2..3732830194 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -22,41 +22,44 @@ The methods that define policy are:
- should_notify
"""
-from twisted.internet import defer, reactor
+import logging
from contextlib import contextmanager
-from synapse.api.errors import SynapseError
+from six import iteritems, itervalues
+
+from prometheus_client import Counter
+
+from twisted.internet import defer
+
from synapse.api.constants import PresenceState
+from synapse.api.errors import SynapseError
+from synapse.metrics import LaterGauge
from synapse.storage.presence import UserPresenceState
-
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.types import UserID, get_domain_from_id
from synapse.util.async import Linearizer
-from synapse.util.logcontext import preserve_fn
+from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.logcontext import run_in_background
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
-from synapse.types import UserID, get_domain_from_id
-import synapse.metrics
-
-import logging
-
logger = logging.getLogger(__name__)
-metrics = synapse.metrics.get_metrics_for(__name__)
-notified_presence_counter = metrics.register_counter("notified_presence")
-federation_presence_out_counter = metrics.register_counter("federation_presence_out")
-presence_updates_counter = metrics.register_counter("presence_updates")
-timers_fired_counter = metrics.register_counter("timers_fired")
-federation_presence_counter = metrics.register_counter("federation_presence")
-bump_active_time_counter = metrics.register_counter("bump_active_time")
+notified_presence_counter = Counter("synapse_handler_presence_notified_presence", "")
+federation_presence_out_counter = Counter(
+ "synapse_handler_presence_federation_presence_out", "")
+presence_updates_counter = Counter("synapse_handler_presence_presence_updates", "")
+timers_fired_counter = Counter("synapse_handler_presence_timers_fired", "")
+federation_presence_counter = Counter("synapse_handler_presence_federation_presence", "")
+bump_active_time_counter = Counter("synapse_handler_presence_bump_active_time", "")
-get_updates_counter = metrics.register_counter("get_updates", labels=["type"])
+get_updates_counter = Counter("synapse_handler_presence_get_updates", "", ["type"])
-notify_reason_counter = metrics.register_counter("notify_reason", labels=["reason"])
-state_transition_counter = metrics.register_counter(
- "state_transition", labels=["from", "to"]
+notify_reason_counter = Counter(
+ "synapse_handler_presence_notify_reason", "", ["reason"])
+state_transition_counter = Counter(
+ "synapse_handler_presence_state_transition", "", ["from", "to"]
)
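
The move from the per-module synapse.metrics registry to prometheus_client
changes the labelled-counter idiom: instead of passing label values to inc(),
you select a child time series with .labels(...) and bump that. A small
sketch with a hypothetical counter:

    from prometheus_client import Counter

    requests = Counter("example_requests", "", ["type"])

    requests.labels("stream").inc()   # one child time series per label value
    requests.labels("full").inc(3)    # inc() takes an optional amount
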
@@ -87,35 +90,40 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
class PresenceHandler(object):
def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer):
+ """
self.is_mine = hs.is_mine
self.is_mine_id = hs.is_mine_id
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.wheel_timer = WheelTimer()
self.notifier = hs.get_notifier()
- self.replication = hs.get_replication_layer()
self.federation = hs.get_federation_sender()
-
self.state = hs.get_state_handler()
- self.replication.register_edu_handler(
+ federation_registry = hs.get_federation_registry()
+
+ federation_registry.register_edu_handler(
"m.presence", self.incoming_presence
)
- self.replication.register_edu_handler(
+ federation_registry.register_edu_handler(
"m.presence_invite",
lambda origin, content: self.invite_presence(
observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]),
)
)
- self.replication.register_edu_handler(
+ federation_registry.register_edu_handler(
"m.presence_accept",
lambda origin, content: self.accept_presence(
observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]),
)
)
- self.replication.register_edu_handler(
+ federation_registry.register_edu_handler(
"m.presence_deny",
lambda origin, content: self.deny_presence(
observed_user=UserID.from_string(content["observed_user"]),
@@ -136,8 +144,9 @@ class PresenceHandler(object):
for state in active_presence
}
- metrics.register_callback(
- "user_to_current_state_size", lambda: len(self.user_to_current_state)
+ LaterGauge(
+ "synapse_handlers_presence_user_to_current_state_size", "", [],
+ lambda: len(self.user_to_current_state)
)
now = self.clock.time_msec()
@@ -169,7 +178,7 @@ class PresenceHandler(object):
# have not yet been persisted
self.unpersisted_users_changes = set()
- reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
+ hs.get_reactor().addSystemEventTrigger("before", "shutdown", self._on_shutdown)
self.serial_to_user = {}
self._next_serial = 1
@@ -207,7 +216,8 @@ class PresenceHandler(object):
60 * 1000,
)
- metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer))
+ LaterGauge("synapse_handlers_presence_wheel_timer_size", "", [],
+ lambda: len(self.wheel_timer))
@defer.inlineCallbacks
def _on_shutdown(self):
@@ -254,6 +264,14 @@ class PresenceHandler(object):
logger.info("Finished _persist_unpersisted_changes")
@defer.inlineCallbacks
+ def _update_states_and_catch_exception(self, new_states):
+ try:
+ res = yield self._update_states(new_states)
+ defer.returnValue(res)
+ except Exception:
+ logger.exception("Error updating presence")
+
+ @defer.inlineCallbacks
def _update_states(self, new_states):
"""Updates presence of users. Sets the appropriate timeouts. Pokes
the notifier and federation if and only if the changed presence state
@@ -302,11 +320,11 @@ class PresenceHandler(object):
# TODO: We should probably ensure there are no races hereafter
- presence_updates_counter.inc_by(len(new_states))
+ presence_updates_counter.inc(len(new_states))
if to_notify:
- notified_presence_counter.inc_by(len(to_notify))
- yield self._persist_and_notify(to_notify.values())
+ notified_presence_counter.inc(len(to_notify))
+ yield self._persist_and_notify(list(to_notify.values()))
self.unpersisted_users_changes |= set(s.user_id for s in new_states)
self.unpersisted_users_changes -= set(to_notify.keys())
@@ -316,7 +334,7 @@ class PresenceHandler(object):
if user_id not in to_notify
}
if to_federation_ping:
- federation_presence_out_counter.inc_by(len(to_federation_ping))
+ federation_presence_out_counter.inc(len(to_federation_ping))
self._push_to_remotes(to_federation_ping.values())
@@ -354,7 +372,7 @@ class PresenceHandler(object):
for user_id in users_to_check
]
- timers_fired_counter.inc_by(len(states))
+ timers_fired_counter.inc(len(states))
changes = handle_timeouts(
states,
@@ -363,8 +381,8 @@ class PresenceHandler(object):
now=now,
)
- preserve_fn(self._update_states)(changes)
- except:
+ run_in_background(self._update_states_and_catch_exception, changes)
+ except Exception:
logger.exception("Exception in _handle_timeouts loop")
@defer.inlineCallbacks
@@ -421,20 +439,23 @@ class PresenceHandler(object):
@defer.inlineCallbacks
def _end():
- if affect_presence:
+ try:
self.user_to_num_current_syncs[user_id] -= 1
prev_state = yield self.current_state_for_user(user_id)
yield self._update_states([prev_state.copy_and_replace(
last_user_sync_ts=self.clock.time_msec(),
)])
+ except Exception:
+ logger.exception("Error updating presence after sync")
@contextmanager
def _user_syncing():
try:
yield
finally:
- preserve_fn(_end)()
+ if affect_presence:
+ run_in_background(_end)
defer.returnValue(_user_syncing())
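
The _user_syncing contextmanager yields exactly once; all the interesting
work happens on exit, and after this change the cleanup is both scheduled in
the background and guarded by affect_presence. The pattern in isolation, with
hypothetical names:

    from contextlib import contextmanager

    @contextmanager
    def user_syncing(affect_presence, end_fn):
        try:
            yield
        finally:
            if affect_presence:
                end_fn()  # run_in_background(_end) in the real code
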
@@ -452,61 +473,6 @@ class PresenceHandler(object):
return syncing_user_ids
@defer.inlineCallbacks
- def update_external_syncs(self, process_id, syncing_user_ids):
- """Update the syncing users for an external process
-
- Args:
- process_id(str): An identifier for the process the users are
- syncing against. This allows synapse to process updates
- as user start and stop syncing against a given process.
- syncing_user_ids(set(str)): The set of user_ids that are
- currently syncing on that server.
- """
-
- # Grab the previous list of user_ids that were syncing on that process
- prev_syncing_user_ids = (
- self.external_process_to_current_syncs.get(process_id, set())
- )
- # Grab the current presence state for both the users that are syncing
- # now and the users that were syncing before this update.
- prev_states = yield self.current_state_for_users(
- syncing_user_ids | prev_syncing_user_ids
- )
- updates = []
- time_now_ms = self.clock.time_msec()
-
- # For each new user that is syncing check if we need to mark them as
- # being online.
- for new_user_id in syncing_user_ids - prev_syncing_user_ids:
- prev_state = prev_states[new_user_id]
- if prev_state.state == PresenceState.OFFLINE:
- updates.append(prev_state.copy_and_replace(
- state=PresenceState.ONLINE,
- last_active_ts=time_now_ms,
- last_user_sync_ts=time_now_ms,
- ))
- else:
- updates.append(prev_state.copy_and_replace(
- last_user_sync_ts=time_now_ms,
- ))
-
- # For each user that is still syncing or stopped syncing update the
- # last sync time so that we will correctly apply the grace period when
- # they stop syncing.
- for old_user_id in prev_syncing_user_ids:
- prev_state = prev_states[old_user_id]
- updates.append(prev_state.copy_and_replace(
- last_user_sync_ts=time_now_ms,
- ))
-
- yield self._update_states(updates)
-
- # Update the last updated time for the process. We expire the entries
- # if we don't receive an update in the given timeframe.
- self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
- self.external_process_to_current_syncs[process_id] = syncing_user_ids
-
- @defer.inlineCallbacks
def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec):
"""Update the syncing users for an external process as a delta.
@@ -569,7 +535,7 @@ class PresenceHandler(object):
prev_state.copy_and_replace(
last_user_sync_ts=time_now_ms,
)
- for prev_state in prev_states.itervalues()
+ for prev_state in itervalues(prev_states)
])
self.external_process_last_updated_ms.pop(process_id, None)
@@ -592,14 +558,14 @@ class PresenceHandler(object):
for user_id in user_ids
}
- missing = [user_id for user_id, state in states.iteritems() if not state]
+ missing = [user_id for user_id, state in iteritems(states) if not state]
if missing:
# There are things not in our in memory cache. Lets pull them out of
# the database.
res = yield self.store.get_presence_for_users(missing)
states.update(res)
- missing = [user_id for user_id, state in states.iteritems() if not state]
+ missing = [user_id for user_id, state in iteritems(states) if not state]
if missing:
new = {
user_id: UserPresenceState.default(user_id)
@@ -695,7 +661,7 @@ class PresenceHandler(object):
updates.append(prev_state.copy_and_replace(**new_fields))
if updates:
- federation_presence_counter.inc_by(len(updates))
+ federation_presence_counter.inc(len(updates))
yield self._update_states(updates)
@defer.inlineCallbacks
@@ -720,7 +686,7 @@ class PresenceHandler(object):
"""
updates = yield self.current_state_for_users(target_user_ids)
- updates = updates.values()
+ updates = list(updates.values())
for user_id in set(target_user_ids) - set(u.user_id for u in updates):
updates.append(UserPresenceState.default(user_id))
@@ -786,11 +752,11 @@ class PresenceHandler(object):
self._push_to_remotes([state])
else:
user_ids = yield self.store.get_users_in_room(room_id)
- user_ids = filter(self.is_mine_id, user_ids)
+ user_ids = list(filter(self.is_mine_id, user_ids))
states = yield self.current_state_for_users(user_ids)
- self._push_to_remotes(states.values())
+ self._push_to_remotes(list(states.values()))
@defer.inlineCallbacks
def get_presence_list(self, observer_user, accepted=None):
@@ -970,28 +936,28 @@ def should_notify(old_state, new_state):
return False
if old_state.status_msg != new_state.status_msg:
- notify_reason_counter.inc("status_msg_change")
+ notify_reason_counter.labels("status_msg_change").inc()
return True
if old_state.state != new_state.state:
- notify_reason_counter.inc("state_change")
- state_transition_counter.inc(old_state.state, new_state.state)
+ notify_reason_counter.labels("state_change").inc()
+ state_transition_counter.labels(old_state.state, new_state.state).inc()
return True
if old_state.state == PresenceState.ONLINE:
if new_state.currently_active != old_state.currently_active:
- notify_reason_counter.inc("current_active_change")
+ notify_reason_counter.labels("current_active_change").inc()
return True
if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
# Only notify about last active bumps if we're not currently active
if not new_state.currently_active:
- notify_reason_counter.inc("last_active_change_online")
+ notify_reason_counter.labels("last_active_change_online").inc()
return True
elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
# Always notify for a transition where last active gets bumped.
- notify_reason_counter.inc("last_active_change_not_online")
+ notify_reason_counter.labels("last_active_change_not_online").inc()
return True
return False
@@ -1065,14 +1031,14 @@ class PresenceEventSource(object):
if changed is not None and len(changed) < 500:
# For small deltas, its quicker to get all changes and then
# work out if we share a room or they're in our presence list
- get_updates_counter.inc("stream")
+ get_updates_counter.labels("stream").inc()
for other_user_id in changed:
if other_user_id in users_interested_in:
user_ids_changed.add(other_user_id)
else:
# Too many possible updates. Find all users we can see and check
# if any of them have changed.
- get_updates_counter.inc("full")
+ get_updates_counter.labels("full").inc()
if from_key:
user_ids_changed = stream_change_cache.get_entities_changed(
@@ -1084,10 +1050,10 @@ class PresenceEventSource(object):
updates = yield presence.current_state_for_users(user_ids_changed)
if include_offline:
- defer.returnValue((updates.values(), max_token))
+ defer.returnValue((list(updates.values()), max_token))
else:
defer.returnValue(([
- s for s in updates.itervalues()
+ s for s in itervalues(updates)
if s.state != PresenceState.OFFLINE
], max_token))
@@ -1145,7 +1111,7 @@ def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
if new_state:
changes[state.user_id] = new_state
- return changes.values()
+ return list(changes.values())
def handle_timeout(state, is_mine, syncing_user_ids, now):
@@ -1199,7 +1165,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now):
)
changed = True
else:
- # We expect to be poked occaisonally by the other side.
+ # We expect to be poked occasionally by the other side.
# This is to protect against forgetful/buggy servers, so that
# no one gets stuck online forever.
if now - state.last_federation_update_ts > FEDERATION_TIMEOUT:
@@ -1344,11 +1310,11 @@ def get_interested_remotes(store, states, state_handler):
# hosts in those rooms.
room_ids_to_states, users_to_states = yield get_interested_parties(store, states)
- for room_id, states in room_ids_to_states.iteritems():
+ for room_id, states in iteritems(room_ids_to_states):
hosts = yield state_handler.get_current_hosts_in_room(room_id)
hosts_and_states.append((hosts, states))
- for user_id, states in users_to_states.iteritems():
+ for user_id, states in iteritems(users_to_states):
host = get_domain_from_id(user_id)
hosts_and_states.append(([host], states))
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 7abee98dea..859f6d2b2e 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -17,25 +17,88 @@ import logging
from twisted.internet import defer
-import synapse.types
-from synapse.api.errors import SynapseError, AuthError, CodeMessageException
-from synapse.types import UserID
-from ._base import BaseHandler
+from synapse.api.errors import AuthError, CodeMessageException, SynapseError
+from synapse.types import UserID, get_domain_from_id
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
class ProfileHandler(BaseHandler):
+ PROFILE_UPDATE_MS = 60 * 1000
+ PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
def __init__(self, hs):
super(ProfileHandler, self).__init__(hs)
- self.federation = hs.get_replication_layer()
- self.federation.register_query_handler(
+ self.federation = hs.get_federation_client()
+ hs.get_federation_registry().register_query_handler(
"profile", self.on_profile_query
)
+ self.user_directory_handler = hs.get_user_directory_handler()
+
+ if hs.config.worker_app is None:
+ self.clock.looping_call(
+ self._update_remote_profile_cache, self.PROFILE_UPDATE_MS,
+ )
+
+ @defer.inlineCallbacks
+ def get_profile(self, user_id):
+ target_user = UserID.from_string(user_id)
+ if self.hs.is_mine(target_user):
+ displayname = yield self.store.get_profile_displayname(
+ target_user.localpart
+ )
+ avatar_url = yield self.store.get_profile_avatar_url(
+ target_user.localpart
+ )
+
+ defer.returnValue({
+ "displayname": displayname,
+ "avatar_url": avatar_url,
+ })
+ else:
+ try:
+ result = yield self.federation.make_query(
+ destination=target_user.domain,
+ query_type="profile",
+ args={
+ "user_id": user_id,
+ },
+ ignore_backoff=True,
+ )
+ defer.returnValue(result)
+ except CodeMessageException as e:
+ if e.code != 404:
+ logger.exception("Failed to get displayname")
+
+ raise
+
+ @defer.inlineCallbacks
+ def get_profile_from_cache(self, user_id):
+ """Get the profile information from our local cache. If the user is
+    ours then the profile information will always be correct. Otherwise,
+ it may be out of date/missing.
+ """
+ target_user = UserID.from_string(user_id)
+ if self.hs.is_mine(target_user):
+ displayname = yield self.store.get_profile_displayname(
+ target_user.localpart
+ )
+ avatar_url = yield self.store.get_profile_avatar_url(
+ target_user.localpart
+ )
+
+ defer.returnValue({
+ "displayname": displayname,
+ "avatar_url": avatar_url,
+ })
+ else:
+ profile = yield self.store.get_from_remote_profile_cache(user_id)
+ defer.returnValue(profile or {})
+
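
A hypothetical caller of the new cache-first lookup, e.g. from the user
directory, which prefers a cheap, possibly-stale answer over a federation
round trip:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def lookup_displayname(profile_handler):
        profile = yield profile_handler.get_profile_from_cache(
            "@alice:remote.example",  # hypothetical remote user
        )
        # may be stale or absent for remote users
        defer.returnValue(profile.get("displayname"))
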
@defer.inlineCallbacks
def get_displayname(self, target_user):
if self.hs.is_mine(target_user):
@@ -60,7 +123,7 @@ class ProfileHandler(BaseHandler):
logger.exception("Failed to get displayname")
raise
- except:
+ except Exception:
logger.exception("Failed to get displayname")
else:
defer.returnValue(result["displayname"])
@@ -82,7 +145,13 @@ class ProfileHandler(BaseHandler):
target_user.localpart, new_displayname
)
- yield self._update_join_states(requester)
+ if self.hs.config.user_directory_search_all_users:
+ profile = yield self.store.get_profileinfo(target_user.localpart)
+ yield self.user_directory_handler.handle_local_profile_change(
+ target_user.to_string(), profile
+ )
+
+ yield self._update_join_states(requester, target_user)
@defer.inlineCallbacks
def get_avatar_url(self, target_user):
@@ -107,7 +176,7 @@ class ProfileHandler(BaseHandler):
if e.code != 404:
logger.exception("Failed to get avatar_url")
raise
- except:
+ except Exception:
logger.exception("Failed to get avatar_url")
defer.returnValue(result["avatar_url"])
@@ -126,7 +195,13 @@ class ProfileHandler(BaseHandler):
target_user.localpart, new_avatar_url
)
- yield self._update_join_states(requester)
+ if self.hs.config.user_directory_search_all_users:
+ profile = yield self.store.get_profileinfo(target_user.localpart)
+ yield self.user_directory_handler.handle_local_profile_change(
+ target_user.to_string(), profile
+ )
+
+ yield self._update_join_states(requester, target_user)
@defer.inlineCallbacks
def on_profile_query(self, args):
@@ -151,28 +226,24 @@ class ProfileHandler(BaseHandler):
defer.returnValue(response)
@defer.inlineCallbacks
- def _update_join_states(self, requester):
- user = requester.user
- if not self.hs.is_mine(user):
+ def _update_join_states(self, requester, target_user):
+ if not self.hs.is_mine(target_user):
return
yield self.ratelimit(requester)
room_ids = yield self.store.get_rooms_for_user(
- user.to_string(),
+ target_user.to_string(),
)
for room_id in room_ids:
- handler = self.hs.get_handlers().room_member_handler
+ handler = self.hs.get_room_member_handler()
try:
- # Assume the user isn't a guest because we don't let guests set
- # profile or avatar data.
- # XXX why are we recreating `requester` here for each room?
- # what was wrong with the `requester` we were passed?
- requester = synapse.types.create_requester(user)
+ # Assume the target_user isn't a guest,
+ # because we don't let guests set profile or avatar data.
yield handler.update_membership(
requester,
- user,
+ target_user,
room_id,
"join", # We treat a profile update like a join.
ratelimit=False, # Try to hide that these events aren't atomic.
@@ -182,3 +253,44 @@ class ProfileHandler(BaseHandler):
"Failed to update join event for room %s - %s",
room_id, str(e.message)
)
+
+    @defer.inlineCallbacks
+    def _update_remote_profile_cache(self):
+ """Called periodically to check profiles of remote users we haven't
+ checked in a while.
+ """
+ entries = yield self.store.get_remote_profile_cache_entries_that_expire(
+ last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS
+ )
+
+ for user_id, displayname, avatar_url in entries:
+ is_subscribed = yield self.store.is_subscribed_remote_profile_for_user(
+ user_id,
+ )
+ if not is_subscribed:
+ yield self.store.maybe_delete_remote_profile_cache(user_id)
+ continue
+
+ try:
+ profile = yield self.federation.make_query(
+ destination=get_domain_from_id(user_id),
+ query_type="profile",
+ args={
+ "user_id": user_id,
+ },
+ ignore_backoff=True,
+ )
+ except Exception:
+ logger.exception("Failed to get avatar_url")
+
+ yield self.store.update_remote_profile_cache(
+ user_id, displayname, avatar_url
+ )
+ continue
+
+ new_name = profile.get("displayname")
+ new_avatar = profile.get("avatar_url")
+
+ # We always hit update to update the last_check timestamp
+ yield self.store.update_remote_profile_cache(
+ user_id, new_name, new_avatar
+ )
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index b5b0303d54..995460f82a 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -13,13 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseHandler
+import logging
from twisted.internet import defer
from synapse.util.async import Linearizer
-import logging
+from ._base import BaseHandler
+
logger = logging.getLogger(__name__)
@@ -41,9 +42,9 @@ class ReadMarkerHandler(BaseHandler):
"""
with (yield self.read_marker_linearizer.queue((room_id, user_id))):
- account_data = yield self.store.get_account_data_for_room(user_id, room_id)
-
- existing_read_marker = account_data.get("m.fully_read", None)
+ existing_read_marker = yield self.store.get_account_data_for_room_and_type(
+ user_id, room_id, "m.fully_read",
+ )
should_update = True
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index e1cd3a48e9..cb905a3903 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -12,16 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from ._base import BaseHandler
+import logging
from twisted.internet import defer
-from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import get_domain_from_id
+from synapse.util import logcontext
+from synapse.util.logcontext import PreserveLoggingContext
-import logging
-
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -34,7 +33,7 @@ class ReceiptsHandler(BaseHandler):
self.store = hs.get_datastore()
self.hs = hs
self.federation = hs.get_federation_sender()
- hs.get_replication_layer().register_edu_handler(
+ hs.get_federation_registry().register_edu_handler(
"m.receipt", self._received_remote_receipt
)
self.clock = self.hs.get_clock()
@@ -59,6 +58,8 @@ class ReceiptsHandler(BaseHandler):
is_new = yield self._handle_new_receipts([receipt])
if is_new:
+ # fire off a process in the background to send the receipt to
+ # remote servers
self._push_remotes([receipt])
@defer.inlineCallbacks
@@ -126,42 +127,46 @@ class ReceiptsHandler(BaseHandler):
defer.returnValue(True)
+ @logcontext.preserve_fn # caller should not yield on this
@defer.inlineCallbacks
def _push_remotes(self, receipts):
"""Given a list of receipts, works out which remote servers should be
poked and pokes them.
"""
- # TODO: Some of this stuff should be coallesced.
- for receipt in receipts:
- room_id = receipt["room_id"]
- receipt_type = receipt["receipt_type"]
- user_id = receipt["user_id"]
- event_ids = receipt["event_ids"]
- data = receipt["data"]
-
- users = yield self.state.get_current_user_in_room(room_id)
- remotedomains = set(get_domain_from_id(u) for u in users)
- remotedomains = remotedomains.copy()
- remotedomains.discard(self.server_name)
-
- logger.debug("Sending receipt to: %r", remotedomains)
-
- for domain in remotedomains:
- self.federation.send_edu(
- destination=domain,
- edu_type="m.receipt",
- content={
- room_id: {
- receipt_type: {
- user_id: {
- "event_ids": event_ids,
- "data": data,
+ try:
+            # TODO: Some of this stuff should be coalesced.
+ for receipt in receipts:
+ room_id = receipt["room_id"]
+ receipt_type = receipt["receipt_type"]
+ user_id = receipt["user_id"]
+ event_ids = receipt["event_ids"]
+ data = receipt["data"]
+
+ users = yield self.state.get_current_user_in_room(room_id)
+ remotedomains = set(get_domain_from_id(u) for u in users)
+ remotedomains = remotedomains.copy()
+ remotedomains.discard(self.server_name)
+
+ logger.debug("Sending receipt to: %r", remotedomains)
+
+ for domain in remotedomains:
+ self.federation.send_edu(
+ destination=domain,
+ edu_type="m.receipt",
+ content={
+ room_id: {
+ receipt_type: {
+ user_id: {
+ "event_ids": event_ids,
+ "data": data,
+ }
}
- }
+ },
},
- },
- key=(room_id, receipt_type, user_id),
- )
+ key=(room_id, receipt_type, user_id),
+ )
+ except Exception:
+ logger.exception("Error pushing receipts to remote servers")
@defer.inlineCallbacks
def get_receipts_for_room(self, room_id, to_key):
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index ee3a2269a8..7caff0cbc8 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -15,16 +15,22 @@
"""Contains functions for registering clients."""
import logging
-import urllib
from twisted.internet import defer
+from synapse import types
from synapse.api.errors import (
- AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
+ AuthError,
+ Codes,
+ InvalidCaptchaError,
+ RegistrationError,
+ SynapseError,
)
from synapse.http.client import CaptchaServerHttpClient
-from synapse.types import UserID
-from synapse.util.async import run_on_reactor
+from synapse.types import RoomAlias, RoomID, UserID, create_requester
+from synapse.util.async import Linearizer
+from synapse.util.threepids import check_3pid_allowed
+
from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -33,24 +39,35 @@ logger = logging.getLogger(__name__)
class RegistrationHandler(BaseHandler):
def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer):
+ """
super(RegistrationHandler, self).__init__(hs)
self.auth = hs.get_auth()
+ self._auth_handler = hs.get_auth_handler()
+ self.profile_handler = hs.get_profile_handler()
+ self.user_directory_handler = hs.get_user_directory_handler()
self.captcha_client = CaptchaServerHttpClient(hs)
self._next_generated_user_id = None
self.macaroon_gen = hs.get_macaroon_generator()
+ self._generate_user_id_linearizer = Linearizer(
+ name="_generate_user_id_linearizer",
+ )
+ self._server_notices_mxid = hs.config.server_notices_mxid
+
@defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None,
assigned_user_id=None):
- yield run_on_reactor()
-
- if urllib.quote(localpart.encode('utf-8')) != localpart:
+ if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
400,
- "User ID can only contain characters a-z, 0-9, or '_-./'",
+ "User ID can only contain characters a-z, 0-9, or '=_-./'",
Codes.INVALID_USERNAME
)
@@ -80,7 +97,7 @@ class RegistrationHandler(BaseHandler):
"A different user ID has already been registered for this session",
)
- yield self.check_user_id_not_appservice_exclusive(user_id)
+ self.check_user_id_not_appservice_exclusive(user_id)
users = yield self.store.get_users_by_id_case_insensitive(user_id)
if users:
@@ -127,10 +144,9 @@ class RegistrationHandler(BaseHandler):
Raises:
RegistrationError if there was a problem registering.
"""
- yield run_on_reactor()
password_hash = None
if password:
- password_hash = self.auth_handler().hash(password)
+ password_hash = yield self.auth_handler().hash(password)
if localpart:
yield self.check_username(localpart, guest_access_token=guest_access_token)
@@ -165,6 +181,13 @@ class RegistrationHandler(BaseHandler):
),
admin=admin,
)
+
+ if self.hs.config.user_directory_search_all_users:
+ profile = yield self.store.get_profileinfo(localpart)
+ yield self.user_directory_handler.handle_local_profile_change(
+ user_id, profile
+ )
+
else:
# autogen a sequential user ID
attempts = 0
@@ -192,10 +215,17 @@ class RegistrationHandler(BaseHandler):
token = None
attempts += 1
+ # auto-join the user to any rooms we're supposed to dump them into
+ fake_requester = create_requester(user_id)
+ for r in self.hs.config.auto_join_rooms:
+ try:
+ yield self._join_user_to_room(fake_requester, r)
+ except Exception as e:
+ logger.error("Failed to join new user to %r: %r", r, e)
+
# We used to generate default identicons here, but nowadays
# we want clients to generate their own as part of their branding
# rather than there being consistent matrix-wide ones, so we don't.
-
defer.returnValue((user_id, token))
@defer.inlineCallbacks
@@ -253,11 +283,10 @@ class RegistrationHandler(BaseHandler):
"""
Registers email_id as SAML2 Based Auth.
"""
- if urllib.quote(localpart) != localpart:
+ if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
400,
- "User ID must only contain characters which do not"
- " require URL encoding."
+ "User ID can only contain characters a-z, 0-9, or '=_-./'",
)
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
@@ -286,12 +315,12 @@ class RegistrationHandler(BaseHandler):
"""
for c in threepidCreds:
- logger.info("validating theeepidcred sid %s on id server %s",
+ logger.info("validating threepidcred sid %s on id server %s",
c['sid'], c['idServer'])
try:
identity_handler = self.hs.get_handlers().identity_handler
threepid = yield identity_handler.threepid_from_creds(c)
- except:
+ except Exception:
logger.exception("Couldn't validate 3pid")
raise RegistrationError(400, "Couldn't validate 3pid")
@@ -300,6 +329,11 @@ class RegistrationHandler(BaseHandler):
logger.info("got threepid with medium '%s' and address '%s'",
threepid['medium'], threepid['address'])
+ if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']):
+ raise RegistrationError(
+ 403, "Third party identifier is not allowed"
+ )
+
@defer.inlineCallbacks
def bind_emails(self, user_id, threepidCreds):
"""Links emails with a user ID and informs an identity server.
@@ -314,6 +348,14 @@ class RegistrationHandler(BaseHandler):
yield identity_handler.bind_threepid(c, user_id)
def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
+ # don't allow people to register the server notices mxid
+ if self._server_notices_mxid is not None:
+ if user_id == self._server_notices_mxid:
+ raise SynapseError(
+ 400, "This user ID is reserved.",
+ errcode=Codes.EXCLUSIVE
+ )
+
# valid user IDs must not clash with any user ID namespaces claimed by
# application services.
services = self.store.get_app_services()
@@ -332,9 +374,11 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks
def _generate_user_id(self, reseed=False):
if reseed or self._next_generated_user_id is None:
- self._next_generated_user_id = (
- yield self.store.find_next_generated_user_id_localpart()
- )
+ with (yield self._generate_user_id_linearizer.queue(())):
+ if reseed or self._next_generated_user_id is None:
+ self._next_generated_user_id = (
+ yield self.store.find_next_generated_user_id_localpart()
+ )
id = self._next_generated_user_id
self._next_generated_user_id += 1
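
The Linearizer turns this into the classic double-checked pattern: check,
queue on the lock, re-check, and only then hit the database. In isolation
(hypothetical names, assuming Linearizer.queue(key) yields a context manager
that admits one caller per key at a time):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def get_seeded_value(self):
        if self._value is None:
            with (yield self._linearizer.queue(())):
                # re-check inside the lock: another caller may have won
                if self._value is None:
                    self._value = yield self._load_from_db()
        defer.returnValue(self._value)
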
@@ -391,8 +435,6 @@ class RegistrationHandler(BaseHandler):
Raises:
RegistrationError if there was a problem registering.
"""
- yield run_on_reactor()
-
if localpart is None:
raise SynapseError(400, "Request must include user id")
@@ -418,13 +460,12 @@ class RegistrationHandler(BaseHandler):
create_profile_with_localpart=user.localpart,
)
else:
- yield self.store.user_delete_access_tokens(user_id=user_id)
+ yield self._auth_handler.delete_access_tokens_for_user(user_id)
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname)
- profile_handler = self.hs.get_handlers().profile_handler
- yield profile_handler.set_displayname(
+ yield self.profile_handler.set_displayname(
user, requester, displayname, by_admin=True,
)
@@ -434,16 +475,59 @@ class RegistrationHandler(BaseHandler):
return self.hs.get_auth_handler()
@defer.inlineCallbacks
- def guest_access_token_for(self, medium, address, inviter_user_id):
+ def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
+ """Get a guest access token for a 3PID, creating a guest account if
+ one doesn't already exist.
+
+ Args:
+ medium (str)
+ address (str)
+ inviter_user_id (str): The user ID who is trying to invite the
+ 3PID
+
+ Returns:
+ Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
+ 3PID guest account.
+ """
access_token = yield self.store.get_3pid_guest_access_token(medium, address)
if access_token:
- defer.returnValue(access_token)
+ user_info = yield self.auth.get_user_by_access_token(
+ access_token
+ )
+
+ defer.returnValue((user_info["user"].to_string(), access_token))
- _, access_token = yield self.register(
+ user_id, access_token = yield self.register(
generate_token=True,
make_guest=True
)
access_token = yield self.store.save_or_get_3pid_guest_access_token(
medium, address, access_token, inviter_user_id
)
- defer.returnValue(access_token)
+
+ defer.returnValue((user_id, access_token))
+
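
Callers of the renamed helper now receive the user ID as well as the token.
A hypothetical call site:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def invite_3pid_guest(registration_handler, inviter_user_id):
        user_id, token = yield registration_handler.get_or_register_3pid_guest(
            medium="email",
            address="invitee@example.com",  # hypothetical values
            inviter_user_id=inviter_user_id,
        )
        defer.returnValue((user_id, token))
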
+ @defer.inlineCallbacks
+ def _join_user_to_room(self, requester, room_identifier):
+ room_id = None
+ room_member_handler = self.hs.get_room_member_handler()
+ if RoomID.is_valid(room_identifier):
+ room_id = room_identifier
+ elif RoomAlias.is_valid(room_identifier):
+ room_alias = RoomAlias.from_string(room_identifier)
+ room_id, remote_room_hosts = (
+ yield room_member_handler.lookup_room_alias(room_alias)
+ )
+ room_id = room_id.to_string()
+ else:
+            raise SynapseError(400, "%s was not a legal room ID or room alias" % (
+ room_identifier,
+ ))
+
+ yield room_member_handler.update_membership(
+ requester=requester,
+ target=requester.user,
+ room_id=room_id,
+ remote_room_hosts=remote_room_hosts,
+ action="join",
+ )
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 5698d28088..6150b7e226 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,23 +15,20 @@
# limitations under the License.
"""Contains functions for performing events on rooms."""
-from twisted.internet import defer
+import logging
+import math
+import string
+from collections import OrderedDict
-from ._base import BaseHandler
+from twisted.internet import defer
-from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
-from synapse.api.constants import (
- EventTypes, JoinRules, RoomCreationPreset
-)
-from synapse.api.errors import AuthError, StoreError, SynapseError
+from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
+from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
+from synapse.types import RoomAlias, RoomID, RoomStreamToken, UserID
from synapse.util import stringutils
from synapse.visibility import filter_events_for_client
-from collections import OrderedDict
-
-import logging
-import math
-import string
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -60,21 +58,43 @@ class RoomCreationHandler(BaseHandler):
},
}
+ def __init__(self, hs):
+ super(RoomCreationHandler, self).__init__(hs)
+
+ self.spam_checker = hs.get_spam_checker()
+ self.event_creation_handler = hs.get_event_creation_handler()
+
@defer.inlineCallbacks
- def create_room(self, requester, config, ratelimit=True):
+ def create_room(self, requester, config, ratelimit=True,
+ creator_join_profile=None):
""" Creates a new room.
Args:
- requester (Requester): The user who requested the room creation.
+ requester (synapse.types.Requester):
+ The user who requested the room creation.
config (dict) : A dict of configuration options.
+ ratelimit (bool): set to False to disable the rate limiter
+
+ creator_join_profile (dict|None):
+ Set to override the displayname and avatar for the creating
+ user in this room. If unset, displayname and avatar will be
+ derived from the user's profile. If set, should contain the
+                values to go in the body of the 'join' event (typically
+                `avatar_url` and/or `displayname`).
+
Returns:
- The new room ID.
+ Deferred[dict]:
+ a dict containing the keys `room_id` and, if an alias was
+ requested, `room_alias`.
Raises:
SynapseError if the room ID couldn't be stored, or something went
horribly wrong.
"""
user_id = requester.user.to_string()
+ if not self.spam_checker.user_may_create_room(user_id):
+ raise SynapseError(403, "You are not permitted to create rooms")
+
if ratelimit:
yield self.ratelimit(requester)
@@ -83,7 +103,7 @@ class RoomCreationHandler(BaseHandler):
if wchar in config["room_alias_name"]:
raise SynapseError(400, "Invalid characters in room alias")
- room_alias = RoomAlias.create(
+ room_alias = RoomAlias(
config["room_alias_name"],
self.hs.hostname,
)
@@ -92,7 +112,11 @@ class RoomCreationHandler(BaseHandler):
)
if mapping:
- raise SynapseError(400, "Room alias already taken")
+ raise SynapseError(
+ 400,
+ "Room alias already taken",
+ Codes.ROOM_IN_USE
+ )
else:
room_alias = None
@@ -100,9 +124,13 @@ class RoomCreationHandler(BaseHandler):
for i in invite_list:
try:
UserID.from_string(i)
- except:
+ except Exception:
raise SynapseError(400, "Invalid user_id: %s" % (i,))
+ yield self.event_creation_handler.assert_accepted_privacy_policy(
+ requester,
+ )
+
invite_3pid_list = config.get("invite_3pid", [])
visibility = config.get("visibility", None)
@@ -115,7 +143,7 @@ class RoomCreationHandler(BaseHandler):
while attempts < 5:
try:
random_string = stringutils.random_string(18)
- gen_room_id = RoomID.create(
+ gen_room_id = RoomID(
random_string,
self.hs.hostname,
)
@@ -155,25 +183,24 @@ class RoomCreationHandler(BaseHandler):
creation_content = config.get("creation_content", {})
- msg_handler = self.hs.get_handlers().message_handler
- room_member_handler = self.hs.get_handlers().room_member_handler
+ room_member_handler = self.hs.get_room_member_handler()
yield self._send_events_for_new_room(
requester,
room_id,
- msg_handler,
room_member_handler,
preset_config=preset_config,
invite_list=invite_list,
initial_state=initial_state,
creation_content=creation_content,
room_alias=room_alias,
- power_level_content_override=config.get("power_level_content_override", {})
+ power_level_content_override=config.get("power_level_content_override", {}),
+ creator_join_profile=creator_join_profile,
)
if "name" in config:
name = config["name"]
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Name,
@@ -186,7 +213,7 @@ class RoomCreationHandler(BaseHandler):
if "topic" in config:
topic = config["topic"]
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Topic,
@@ -197,12 +224,12 @@ class RoomCreationHandler(BaseHandler):
},
ratelimit=False)
- content = {}
- is_direct = config.get("is_direct", None)
- if is_direct:
- content["is_direct"] = is_direct
-
for invitee in invite_list:
+ content = {}
+ is_direct = config.get("is_direct", None)
+ if is_direct:
+ content["is_direct"] = is_direct
+
yield room_member_handler.update_membership(
requester,
UserID.from_string(invitee),
@@ -216,7 +243,7 @@ class RoomCreationHandler(BaseHandler):
id_server = invite_3pid["id_server"]
address = invite_3pid["address"]
medium = invite_3pid["medium"]
- yield self.hs.get_handlers().room_member_handler.do_3pid_invite(
+ yield self.hs.get_room_member_handler().do_3pid_invite(
room_id,
requester.user,
medium,
@@ -241,7 +268,6 @@ class RoomCreationHandler(BaseHandler):
self,
creator, # A Requester object.
room_id,
- msg_handler,
room_member_handler,
preset_config,
invite_list,
@@ -249,6 +275,7 @@ class RoomCreationHandler(BaseHandler):
creation_content,
room_alias,
power_level_content_override,
+ creator_join_profile,
):
def create(etype, content, **kwargs):
e = {
@@ -264,7 +291,7 @@ class RoomCreationHandler(BaseHandler):
@defer.inlineCallbacks
def send(etype, content, **kwargs):
event = create(etype, content, **kwargs)
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
creator,
event,
ratelimit=False
@@ -292,6 +319,7 @@ class RoomCreationHandler(BaseHandler):
room_id,
"join",
ratelimit=False,
+ content=creator_join_profile,
)
# We treat the power levels override specially as this needs to be one
@@ -367,7 +395,11 @@ class RoomCreationHandler(BaseHandler):
)
-class RoomContextHandler(BaseHandler):
+class RoomContextHandler(object):
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+
@defer.inlineCallbacks
def get_event_context(self, user, room_id, event_id, limit):
"""Retrieves events, pagination tokens and state around a given event
@@ -428,7 +460,7 @@ class RoomContextHandler(BaseHandler):
state = yield self.store.get_state_for_events(
[last_event_id], None
)
- results["state"] = state[last_event_id].values()
+ results["state"] = list(state[last_event_id].values())
results["start"] = now_token.copy_and_replace(
"room_key", results["start"]
@@ -468,12 +500,9 @@ class RoomEventSource(object):
user.to_string()
)
if app_service:
- events, end_key = yield self.store.get_appservice_room_stream(
- service=app_service,
- from_key=from_key,
- to_key=to_key,
- limit=limit,
- )
+ # We no longer support AS users using /sync directly.
+ # See https://github.com/matrix-org/matrix-doc/issues/1144
+ raise NotImplementedError()
else:
room_events = yield self.store.get_membership_changes_for_user(
user.to_string(), from_key, to_key
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 516cd9a6ac..828229f5c3 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -13,23 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+import logging
+from collections import namedtuple
-from ._base import BaseHandler
+from six import iteritems
+from six.moves import range
+
+import msgpack
+from unpaddedbase64 import decode_base64, encode_base64
+
+from twisted.internet import defer
-from synapse.api.constants import (
- EventTypes, JoinRules,
-)
+from synapse.api.constants import EventTypes, JoinRules
+from synapse.types import ThirdPartyInstanceID
from synapse.util.async import concurrently_execute
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.response_cache import ResponseCache
-from synapse.types import ThirdPartyInstanceID
-
-from collections import namedtuple
-from unpaddedbase64 import encode_base64, decode_base64
-import logging
-import msgpack
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -37,18 +38,19 @@ REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
# This is used to indicate we should only return rooms published to the main list.
-EMTPY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
+EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
class RoomListHandler(BaseHandler):
def __init__(self, hs):
super(RoomListHandler, self).__init__(hs)
- self.response_cache = ResponseCache(hs)
- self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000)
+ self.response_cache = ResponseCache(hs, "room_list")
+ self.remote_response_cache = ResponseCache(hs, "remote_room_list",
+ timeout_ms=30 * 1000)
def get_local_public_room_list(self, limit=None, since_token=None,
search_filter=None,
- network_tuple=EMTPY_THIRD_PARTY_ID,):
+ network_tuple=EMPTY_THIRD_PARTY_ID,):
"""Generate a local public room list.
There are multiple different lists: the main one plus one per third
@@ -70,25 +72,22 @@ class RoomListHandler(BaseHandler):
if search_filter:
# We explicitly don't bother caching searches or requests for
# appservice specific lists.
+ logger.info("Bypassing cache as search request.")
return self._get_public_room_list(
limit, since_token, search_filter, network_tuple=network_tuple,
)
key = (limit, since_token, network_tuple)
- result = self.response_cache.get(key)
- if not result:
- result = self.response_cache.set(
- key,
- self._get_public_room_list(
- limit, since_token, network_tuple=network_tuple
- )
- )
- return result
+ return self.response_cache.wrap(
+ key,
+ self._get_public_room_list,
+ limit, since_token, network_tuple=network_tuple,
+ )
@defer.inlineCallbacks
def _get_public_room_list(self, limit=None, since_token=None,
search_filter=None,
- network_tuple=EMTPY_THIRD_PARTY_ID,):
+ network_tuple=EMPTY_THIRD_PARTY_ID,):
if since_token and since_token != "END":
since_token = RoomListNextBatch.from_token(since_token)
else:
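
The get/set dance above is replaced throughout this patch by `ResponseCache.wrap`. A minimal sketch of the deduplication `wrap` provides, assuming a simplified cache (the real class also keeps results around for `timeout_ms` and handles logcontexts):

    from twisted.internet import defer

    class MiniResponseCache(object):
        """Shares one in-flight computation per key (sketch only)."""
        def __init__(self):
            self._pending = {}  # key -> Deferred

        def wrap(self, key, callback, *args, **kwargs):
            if key in self._pending:
                # A request for this key is already running: share its
                # result rather than kicking off a duplicate computation.
                return self._pending[key]
            d = defer.maybeDeferred(callback, *args, **kwargs)
            self._pending[key] = d

            def _remove(res):
                self._pending.pop(key, None)
                return res

            d.addBoth(_remove)
            return d
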
@@ -149,6 +148,8 @@ class RoomListHandler(BaseHandler):
# We want larger rooms to be first, hence negating num_joined_users
rooms_to_order_value[room_id] = (-num_joined_users, room_id)
+ logger.info("Getting ordering for %i rooms since %s",
+ len(room_ids), stream_token)
yield concurrently_execute(get_order_for_room, room_ids, 10)
sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
@@ -176,34 +177,43 @@ class RoomListHandler(BaseHandler):
rooms_to_scan = rooms_to_scan[:since_token.current_limit]
rooms_to_scan.reverse()
- # Actually generate the entries. _append_room_entry_to_chunk will append to
- # chunk but will stop if len(chunk) > limit
- chunk = []
- if limit and not search_filter:
+ logger.info("After sorting and filtering, %i rooms remain",
+ len(rooms_to_scan))
+
+ # _append_room_entry_to_chunk will append to chunk but will stop if
+ # len(chunk) > limit
+ #
+ # Normally we will generate enough results on the first iteration here,
+ # but if there is a search filter, _append_room_entry_to_chunk may
+ # filter some results out, in which case we loop again.
+ #
+ # We don't want to scan over the entire range either as that
+ # would potentially waste a lot of work.
+ #
+ # XXX if there is no limit, we may end up DoSing the server with
+ # calls to get_current_state_ids for every single room on the
+ # server. Surely we should cap this somehow?
+ #
+ if limit:
step = limit + 1
- for i in xrange(0, len(rooms_to_scan), step):
- # We iterate here because the vast majority of cases we'll stop
- # at first iteration, but occaisonally _append_room_entry_to_chunk
- # won't append to the chunk and so we need to loop again.
- # We don't want to scan over the entire range either as that
- # would potentially waste a lot of work.
- yield concurrently_execute(
- lambda r: self._append_room_entry_to_chunk(
- r, rooms_to_num_joined[r],
- chunk, limit, search_filter
- ),
- rooms_to_scan[i:i + step], 10
- )
- if len(chunk) >= limit + 1:
- break
else:
+ # step cannot be zero
+ step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1
+
+ chunk = []
+ for i in range(0, len(rooms_to_scan), step):
+ batch = rooms_to_scan[i:i + step]
+ logger.info("Processing %i rooms for result", len(batch))
yield concurrently_execute(
lambda r: self._append_room_entry_to_chunk(
r, rooms_to_num_joined[r],
chunk, limit, search_filter
),
- rooms_to_scan, 5
+ batch, 5,
)
+ logger.info("Now %i rooms in result", len(chunk))
+ if len(chunk) >= limit + 1:
+ break
chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"]))
@@ -276,13 +286,14 @@ class RoomListHandler(BaseHandler):
# We've already got enough, so let's just drop it.
return
- result = yield self._generate_room_entry(room_id, num_joined_users)
+ result = yield self.generate_room_entry(room_id, num_joined_users)
if result and _matches_room_entry(result, search_filter):
chunk.append(result)
@cachedInlineCallbacks(num_args=1, cache_context=True)
- def _generate_room_entry(self, room_id, num_joined_users, cache_context):
+ def generate_room_entry(self, room_id, num_joined_users, cache_context,
+ with_alias=True, allow_private=False):
"""Returns the entry for a room
"""
result = {
@@ -295,7 +306,7 @@ class RoomListHandler(BaseHandler):
)
event_map = yield self.store.get_events([
- event_id for key, event_id in current_state_ids.iteritems()
+ event_id for key, event_id in iteritems(current_state_ids)
if key[0] in (
EventTypes.JoinRules,
EventTypes.Name,
@@ -316,14 +327,15 @@ class RoomListHandler(BaseHandler):
join_rules_event = current_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
join_rule = join_rules_event.content.get("join_rule", None)
- if join_rule and join_rule != JoinRules.PUBLIC:
+ if not allow_private and join_rule and join_rule != JoinRules.PUBLIC:
defer.returnValue(None)
- aliases = yield self.store.get_aliases_for_room(
- room_id, on_invalidate=cache_context.invalidate
- )
- if aliases:
- result["aliases"] = aliases
+ if with_alias:
+ aliases = yield self.store.get_aliases_for_room(
+ room_id, on_invalidate=cache_context.invalidate
+ )
+ if aliases:
+ result["aliases"] = aliases
name_event = yield current_state.get((EventTypes.Name, ""))
if name_event:
@@ -391,7 +403,7 @@ class RoomListHandler(BaseHandler):
def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
search_filter=None, include_all_networks=False,
third_party_instance_id=None,):
- repl_layer = self.hs.get_replication_layer()
+ repl_layer = self.hs.get_federation_client()
if search_filter:
# We can't cache when asking for search
return repl_layer.get_public_rooms(
@@ -404,18 +416,14 @@ class RoomListHandler(BaseHandler):
server_name, limit, since_token, include_all_networks,
third_party_instance_id,
)
- result = self.remote_response_cache.get(key)
- if not result:
- result = self.remote_response_cache.set(
- key,
- repl_layer.get_public_rooms(
- server_name, limit=limit, since_token=since_token,
- search_filter=search_filter,
- include_all_networks=include_all_networks,
- third_party_instance_id=third_party_instance_id,
- )
- )
- return result
+ return self.remote_response_cache.wrap(
+ key,
+ repl_layer.get_public_rooms,
+ server_name, limit=limit, since_token=since_token,
+ search_filter=search_filter,
+ include_all_networks=include_all_networks,
+ third_party_instance_id=third_party_instance_id,
+ )
class RoomListNextBatch(namedtuple("RoomListNextBatch", (
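
The `since_token` decoded by `RoomListNextBatch.from_token` above is a msgpack mapping wrapped in unpadded base64. An illustrative round-trip (the field names here are invented for the example):

    import msgpack
    from unpaddedbase64 import decode_base64, encode_base64

    def to_token(batch):
        return encode_base64(msgpack.dumps(batch))

    def from_token(token):
        return msgpack.loads(decode_base64(token))

    batch = {"stream_ordering": 1234, "current_limit": 20}
    token = to_token(batch)   # opaque string handed back to the client
    print(from_token(token))  # round-trips (key type varies by msgpack version)
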
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 1b8dfa8254..0d4a3f4677 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,63 +14,161 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import abc
import logging
+from six.moves import http_client
+
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
-from twisted.internet import defer
from unpaddedbase64 import decode_base64
+from twisted.internet import defer
+
+import synapse.server
import synapse.types
-from synapse.api.constants import (
- EventTypes, Membership,
-)
-from synapse.api.errors import AuthError, SynapseError, Codes
-from synapse.types import UserID, RoomID
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.types import RoomID, UserID
from synapse.util.async import Linearizer
-from synapse.util.distributor import user_left_room, user_joined_room
-from ._base import BaseHandler
+from synapse.util.distributor import user_joined_room, user_left_room
logger = logging.getLogger(__name__)
id_server_scheme = "https://"
-class RoomMemberHandler(BaseHandler):
+class RoomMemberHandler(object):
# TODO(paul): This handler currently contains a messy conflation of
# low-level API that works on UserID objects and so on, and REST-level
# API that takes ID strings and returns pagination chunks. These concerns
# ought to be separated out a lot better.
+ __metaclass__ = abc.ABCMeta
+
def __init__(self, hs):
- super(RoomMemberHandler, self).__init__(hs)
+ """
+
+ Args:
+ hs (synapse.server.HomeServer):
+ """
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.state_handler = hs.get_state_handler()
+ self.config = hs.config
+ self.simple_http_client = hs.get_simple_http_client()
+
+ self.federation_handler = hs.get_handlers().federation_handler
+ self.directory_handler = hs.get_handlers().directory_handler
+ self.registration_handler = hs.get_handlers().registration_handler
+ self.profile_handler = hs.get_profile_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
self.member_linearizer = Linearizer(name="member")
self.clock = hs.get_clock()
+ self.spam_checker = hs.get_spam_checker()
+ self._server_notices_mxid = self.config.server_notices_mxid
- self.distributor = hs.get_distributor()
- self.distributor.declare("user_joined_room")
- self.distributor.declare("user_left_room")
+ @abc.abstractmethod
+ def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+ """Try and join a room that this server is not in
+
+ Args:
+ requester (Requester)
+ remote_room_hosts (list[str]): List of servers that can be used
+ to join via.
+ room_id (str): Room that we are trying to join
+ user (UserID): User who is trying to join
+ content (dict): A dict that should be used as the content of the
+ join event.
+
+ Returns:
+ Deferred
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
+ """Attempt to reject an invite for a room this server is not in. If we
+ fail to do so, we locally mark the invite as rejected.
+
+ Args:
+ requester (Requester)
+ remote_room_hosts (list[str]): List of servers to use to try and
+ reject invite
+ room_id (str)
+ target (UserID): The user rejecting the invite
+
+ Returns:
+ Deferred[dict]: A dictionary to be returned to the client, may
+ include event_id etc, or nothing if we locally rejected
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
+ """Get a guest access token for a 3PID, creating a guest account if
+ one doesn't already exist.
+
+ Args:
+ requester (Requester)
+ medium (str)
+ address (str)
+ inviter_user_id (str): The user ID who is trying to invite the
+ 3PID
+
+ Returns:
+ Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
+ 3PID guest account.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def _user_joined_room(self, target, room_id):
+ """Notifies distributor on master process that the user has joined the
+ room.
+
+ Args:
+ target (UserID)
+ room_id (str)
+
+ Returns:
+ Deferred|None
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def _user_left_room(self, target, room_id):
+ """Notifies distributor on master process that the user has left the
+ room.
+
+ Args:
+ target (UserID)
+ room_id (str)
+
+ Returns:
+ Deferred|None
+ """
+ raise NotImplementedError()
@defer.inlineCallbacks
def _local_membership_update(
self, requester, target, room_id, membership,
- prev_event_ids,
+ prev_events_and_hashes,
txn_id=None,
ratelimit=True,
content=None,
):
if content is None:
content = {}
- msg_handler = self.hs.get_handlers().message_handler
content["membership"] = membership
if requester.is_guest:
content["kind"] = "guest"
- event, context = yield msg_handler.create_event(
+ event, context = yield self.event_creation_handler.create_event(
requester,
{
"type": EventTypes.Member,
@@ -83,16 +182,18 @@ class RoomMemberHandler(BaseHandler):
},
token_id=requester.access_token_id,
txn_id=txn_id,
- prev_event_ids=prev_event_ids,
+ prev_events_and_hashes=prev_events_and_hashes,
)
# Check if this event matches the previous membership event for the user.
- duplicate = yield msg_handler.deduplicate_state_event(event, context)
+ duplicate = yield self.event_creation_handler.deduplicate_state_event(
+ event, context,
+ )
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
defer.returnValue(duplicate)
- yield msg_handler.handle_new_client_event(
+ yield self.event_creation_handler.handle_new_client_event(
requester,
event,
context,
@@ -100,7 +201,9 @@ class RoomMemberHandler(BaseHandler):
ratelimit=ratelimit,
)
- prev_member_event_id = context.prev_state_ids.get(
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+
+ prev_member_event_id = prev_state_ids.get(
(EventTypes.Member, target.to_string()),
None
)
@@ -114,33 +217,16 @@ class RoomMemberHandler(BaseHandler):
prev_member_event = yield self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
- yield user_joined_room(self.distributor, target, room_id)
+ yield self._user_joined_room(target, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
- user_left_room(self.distributor, target, room_id)
+ yield self._user_left_room(target, room_id)
defer.returnValue(event)
@defer.inlineCallbacks
- def remote_join(self, remote_room_hosts, room_id, user, content):
- if len(remote_room_hosts) == 0:
- raise SynapseError(404, "No known servers")
-
- # We don't do an auth check if we are doing an invite
- # join dance for now, since we're kinda implicitly checking
- # that we are allowed to join when we decide whether or not we
- # need to do the invite/join dance.
- yield self.hs.get_handlers().federation_handler.do_invite_join(
- remote_room_hosts,
- room_id,
- user.to_string(),
- content,
- )
- yield user_joined_room(self.distributor, user, room_id)
-
- @defer.inlineCallbacks
def update_membership(
self,
requester,
@@ -186,14 +272,19 @@ class RoomMemberHandler(BaseHandler):
content_specified = bool(content)
if content is None:
content = {}
+ else:
+ # We do a copy here as we potentially change some keys
+ # later on.
+ content = dict(content)
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
+ # if this is a join with a 3pid signature, we may need to turn a 3pid
+ # invite into a normal invite before we can handle the join.
if third_party_signed is not None:
- replication = self.hs.get_replication_layer()
- yield replication.exchange_third_party_invite(
+ yield self.federation_handler.exchange_third_party_invite(
third_party_signed["sender"],
target.to_string(),
room_id,
@@ -208,7 +299,51 @@ class RoomMemberHandler(BaseHandler):
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
- latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+ if effective_membership_state == Membership.INVITE:
+ # block any attempts to invite the server notices mxid
+ if target.to_string() == self._server_notices_mxid:
+ raise SynapseError(
+ http_client.FORBIDDEN,
+ "Cannot invite this user",
+ )
+
+ block_invite = False
+
+ if (self._server_notices_mxid is not None and
+ requester.user.to_string() == self._server_notices_mxid):
+ # allow the server notices mxid to send invites
+ is_requester_admin = True
+
+ else:
+ is_requester_admin = yield self.auth.is_server_admin(
+ requester.user,
+ )
+
+ if not is_requester_admin:
+ if self.config.block_non_admin_invites:
+ logger.info(
+ "Blocking invite: user is not admin and non-admin "
+ "invites disabled"
+ )
+ block_invite = True
+
+ if not self.spam_checker.user_may_invite(
+ requester.user.to_string(), target.to_string(), room_id,
+ ):
+ logger.info("Blocking invite due to spam checker")
+ block_invite = True
+
+ if block_invite:
+ raise SynapseError(
+ 403, "Invites have been disabled on this server",
+ )
+
+ prev_events_and_hashes = yield self.store.get_prev_events_for_room(
+ room_id,
+ )
+ latest_event_ids = (
+ event_id for (event_id, _, _) in prev_events_and_hashes
+ )
current_state_ids = yield self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids,
)
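
The `spam_checker.user_may_invite` call above delegates to an optionally-configured module. A minimal sketch of such a module, assuming the callback signature used at the call site (`blocked_inviters` is an invented config key):

    class ExampleSpamChecker(object):
        def __init__(self, config):
            self._blocked = set(config.get("blocked_inviters", []))

        def user_may_invite(self, inviter_userid, invitee_userid, room_id):
            # Returning False makes update_membership raise the
            # 403 "Invites have been disabled" error above.
            return inviter_userid not in self._blocked
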
@@ -238,6 +373,20 @@ class RoomMemberHandler(BaseHandler):
if same_sender and same_membership and same_content:
defer.returnValue(old_state)
+ # we don't allow people to reject invites to the server notice
+ # room, but they can leave it once they are joined.
+ if (
+ old_membership == Membership.INVITE and
+ effective_membership_state == Membership.LEAVE
+ ):
+ is_blocked = yield self._is_server_notice_room(room_id)
+ if is_blocked:
+ raise SynapseError(
+ http_client.FORBIDDEN,
+ "You cannot reject this invite",
+ errcode=Codes.CANNOT_LEAVE_SERVER_NOTICE_ROOM,
+ )
+
is_host_in_room = yield self._is_host_in_room(current_state_ids)
if effective_membership_state == Membership.JOIN:
@@ -249,13 +398,13 @@ class RoomMemberHandler(BaseHandler):
raise AuthError(403, "Guest access not allowed")
if not is_host_in_room:
- inviter = yield self.get_inviter(target.to_string(), room_id)
+ inviter = yield self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
content["membership"] = Membership.JOIN
- profile = self.hs.get_handlers().profile_handler
+ profile = self.profile_handler
if not content_specified:
content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target)
@@ -263,15 +412,15 @@ class RoomMemberHandler(BaseHandler):
if requester.is_guest:
content["kind"] = "guest"
- ret = yield self.remote_join(
- remote_room_hosts, room_id, target, content
+ ret = yield self._remote_join(
+ requester, remote_room_hosts, room_id, target, content
)
defer.returnValue(ret)
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
# perhaps we've been invited
- inviter = yield self.get_inviter(target.to_string(), room_id)
+ inviter = yield self._get_inviter(target.to_string(), room_id)
if not inviter:
raise SynapseError(404, "Not a known room")
@@ -285,28 +434,10 @@ class RoomMemberHandler(BaseHandler):
else:
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
- fed_handler = self.hs.get_handlers().federation_handler
- try:
- ret = yield fed_handler.do_remotely_reject_invite(
- remote_room_hosts,
- room_id,
- target.to_string(),
- )
- defer.returnValue(ret)
- except Exception as e:
- # if we were unable to reject the exception, just mark
- # it as rejected on our end and plough ahead.
- #
- # The 'except' clause is very broad, but we need to
- # capture everything from DNS failures upwards
- #
- logger.warn("Failed to reject invite: %s", e)
-
- yield self.store.locally_reject_invite(
- target.to_string(), room_id
- )
-
- defer.returnValue({})
+ res = yield self._remote_reject_invite(
+ requester, remote_room_hosts, room_id, target,
+ )
+ defer.returnValue(res)
res = yield self._local_membership_update(
requester=requester,
@@ -315,7 +446,7 @@ class RoomMemberHandler(BaseHandler):
membership=effective_membership_state,
txn_id=txn_id,
ratelimit=ratelimit,
- prev_event_ids=latest_event_ids,
+ prev_events_and_hashes=prev_events_and_hashes,
content=content,
)
defer.returnValue(res)
@@ -361,14 +492,16 @@ class RoomMemberHandler(BaseHandler):
else:
requester = synapse.types.create_requester(target_user)
- message_handler = self.hs.get_handlers().message_handler
- prev_event = yield message_handler.deduplicate_state_event(event, context)
+ prev_event = yield self.event_creation_handler.deduplicate_state_event(
+ event, context,
+ )
if prev_event is not None:
return
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
if event.membership == Membership.JOIN:
if requester.is_guest:
- guest_can_join = yield self._can_guest_join(context.prev_state_ids)
+ guest_can_join = yield self._can_guest_join(prev_state_ids)
if not guest_can_join:
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
@@ -379,7 +512,7 @@ class RoomMemberHandler(BaseHandler):
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
- yield message_handler.handle_new_client_event(
+ yield self.event_creation_handler.handle_new_client_event(
requester,
event,
context,
@@ -387,7 +520,7 @@ class RoomMemberHandler(BaseHandler):
ratelimit=ratelimit,
)
- prev_member_event_id = context.prev_state_ids.get(
+ prev_member_event_id = prev_state_ids.get(
(EventTypes.Member, event.state_key),
None
)
@@ -401,12 +534,12 @@ class RoomMemberHandler(BaseHandler):
prev_member_event = yield self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
- yield user_joined_room(self.distributor, target_user, room_id)
+ yield self._user_joined_room(target_user, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
- user_left_room(self.distributor, target_user, room_id)
+ yield self._user_left_room(target_user, room_id)
@defer.inlineCallbacks
def _can_guest_join(self, current_state_ids):
@@ -440,7 +573,7 @@ class RoomMemberHandler(BaseHandler):
Raises:
SynapseError if room alias could not be found.
"""
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.directory_handler
mapping = yield directory_handler.get_association(room_alias)
if not mapping:
@@ -452,7 +585,7 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue((RoomID.from_string(room_id), servers))
@defer.inlineCallbacks
- def get_inviter(self, user_id, room_id):
+ def _get_inviter(self, user_id, room_id):
invite = yield self.store.get_invite_for_user_in_room(
user_id=user_id,
room_id=room_id,
@@ -471,6 +604,16 @@ class RoomMemberHandler(BaseHandler):
requester,
txn_id
):
+ if self.config.block_non_admin_invites:
+ is_requester_admin = yield self.auth.is_server_admin(
+ requester.user,
+ )
+ if not is_requester_admin:
+ raise SynapseError(
+ 403, "Invites have been disabled on this server",
+ Codes.FORBIDDEN,
+ )
+
invitee = yield self._lookup_3pid(
id_server, medium, address
)
@@ -508,7 +651,7 @@ class RoomMemberHandler(BaseHandler):
str: the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
- data = yield self.hs.get_simple_http_client().get_json(
+ data = yield self.simple_http_client.get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
{
"medium": medium,
@@ -519,7 +662,7 @@ class RoomMemberHandler(BaseHandler):
if "mxid" in data:
if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding")
- self.verify_any_signature(data, id_server)
+ yield self._verify_any_signature(data, id_server)
defer.returnValue(data["mxid"])
except IOError as e:
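
For reference, the shape of a successful identity-server lookup response as consumed by `_lookup_3pid` above (values illustrative; signature content abridged):

    lookup_response = {
        "medium": "email",
        "address": "alice@example.com",
        "mxid": "@alice:example.com",
        "signatures": {
            "id.example.com": {"ed25519:0": "base64+signature+here"},
        },
    }
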
@@ -527,11 +670,11 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue(None)
@defer.inlineCallbacks
- def verify_any_signature(self, data, server_hostname):
+ def _verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
- key_data = yield self.hs.get_simple_http_client().get_json(
+ key_data = yield self.simple_http_client.get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s" %
(id_server_scheme, server_hostname, key_name,),
)
@@ -556,7 +699,7 @@ class RoomMemberHandler(BaseHandler):
user,
txn_id
):
- room_state = yield self.hs.get_state_handler().get_current_state(room_id)
+ room_state = yield self.state_handler.get_current_state(room_id)
inviter_display_name = ""
inviter_avatar_url = ""
@@ -591,6 +734,7 @@ class RoomMemberHandler(BaseHandler):
token, public_keys, fallback_public_key, display_name = (
yield self._ask_id_server_for_third_party_invite(
+ requester=requester,
id_server=id_server,
medium=medium,
address=address,
@@ -605,8 +749,7 @@ class RoomMemberHandler(BaseHandler):
)
)
- msg_handler = self.hs.get_handlers().message_handler
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
@@ -628,6 +771,7 @@ class RoomMemberHandler(BaseHandler):
@defer.inlineCallbacks
def _ask_id_server_for_third_party_invite(
self,
+ requester,
id_server,
medium,
address,
@@ -644,6 +788,7 @@ class RoomMemberHandler(BaseHandler):
Asks an identity server for a third party invite.
Args:
+ requester (Requester)
id_server (str): hostname + optional port for the identity server.
medium (str): The literal string "email".
address (str): The third party address being invited.
@@ -685,24 +830,20 @@ class RoomMemberHandler(BaseHandler):
"sender_avatar_url": inviter_avatar_url,
}
- if self.hs.config.invite_3pid_guest:
- registration_handler = self.hs.get_handlers().registration_handler
- guest_access_token = yield registration_handler.guest_access_token_for(
+ if self.config.invite_3pid_guest:
+ guest_user_id, guest_access_token = yield self.get_or_register_3pid_guest(
+ requester=requester,
medium=medium,
address=address,
inviter_user_id=inviter_user_id,
)
- guest_user_info = yield self.hs.get_auth().get_user_by_access_token(
- guest_access_token
- )
-
invite_config.update({
"guest_access_token": guest_access_token,
- "guest_user_id": guest_user_info["user"].to_string(),
+ "guest_user_id": guest_user_id,
})
- data = yield self.hs.get_simple_http_client().post_urlencoded_get_json(
+ data = yield self.simple_http_client.post_urlencoded_get_json(
is_url,
invite_config
)
@@ -725,27 +866,6 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue((token, public_keys, fallback_public_key, display_name))
@defer.inlineCallbacks
- def forget(self, user, room_id):
- user_id = user.to_string()
-
- member = yield self.state_handler.get_current_state(
- room_id=room_id,
- event_type=EventTypes.Member,
- state_key=user_id
- )
- membership = member.membership if member else None
-
- if membership is not None and membership not in [
- Membership.LEAVE, Membership.BAN
- ]:
- raise SynapseError(400, "User %s in room %s" % (
- user_id, room_id
- ))
-
- if membership:
- yield self.store.forget(user_id, room_id)
-
- @defer.inlineCallbacks
def _is_host_in_room(self, current_state_ids):
# Have we just created the room, and is this about to be the very
# first member event?
@@ -766,3 +886,109 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue(True)
defer.returnValue(False)
+
+ @defer.inlineCallbacks
+ def _is_server_notice_room(self, room_id):
+ if self._server_notices_mxid is None:
+ defer.returnValue(False)
+ user_ids = yield self.store.get_users_in_room(room_id)
+ defer.returnValue(self._server_notices_mxid in user_ids)
+
+
+class RoomMemberMasterHandler(RoomMemberHandler):
+ def __init__(self, hs):
+ super(RoomMemberMasterHandler, self).__init__(hs)
+
+ self.distributor = hs.get_distributor()
+ self.distributor.declare("user_joined_room")
+ self.distributor.declare("user_left_room")
+
+ @defer.inlineCallbacks
+ def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+ """Implements RoomMemberHandler._remote_join
+ """
+ # filter ourselves out of remote_room_hosts: do_invite_join ignores it
+ # and if it is the only entry we'd like to return a 404 rather than a
+ # 500.
+
+ remote_room_hosts = [
+ host for host in remote_room_hosts if host != self.hs.hostname
+ ]
+
+ if len(remote_room_hosts) == 0:
+ raise SynapseError(404, "No known servers")
+
+ # We don't do an auth check if we are doing an invite
+ # join dance for now, since we're kinda implicitly checking
+ # that we are allowed to join when we decide whether or not we
+ # need to do the invite/join dance.
+ yield self.federation_handler.do_invite_join(
+ remote_room_hosts,
+ room_id,
+ user.to_string(),
+ content,
+ )
+ yield self._user_joined_room(user, room_id)
+
+ @defer.inlineCallbacks
+ def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
+ """Implements RoomMemberHandler._remote_reject_invite
+ """
+ fed_handler = self.federation_handler
+ try:
+ ret = yield fed_handler.do_remotely_reject_invite(
+ remote_room_hosts,
+ room_id,
+ target.to_string(),
+ )
+ defer.returnValue(ret)
+ except Exception as e:
+ # if we were unable to reject the invite, just mark
+ # it as rejected on our end and plough ahead.
+ #
+ # The 'except' clause is very broad, but we need to
+ # capture everything from DNS failures upwards
+ #
+ logger.warn("Failed to reject invite: %s", e)
+
+ yield self.store.locally_reject_invite(
+ target.to_string(), room_id
+ )
+ defer.returnValue({})
+
+ def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
+ """Implements RoomMemberHandler.get_or_register_3pid_guest
+ """
+ rg = self.registration_handler
+ return rg.get_or_register_3pid_guest(medium, address, inviter_user_id)
+
+ def _user_joined_room(self, target, room_id):
+ """Implements RoomMemberHandler._user_joined_room
+ """
+ return user_joined_room(self.distributor, target, room_id)
+
+ def _user_left_room(self, target, room_id):
+ """Implements RoomMemberHandler._user_left_room
+ """
+ return user_left_room(self.distributor, target, room_id)
+
+ @defer.inlineCallbacks
+ def forget(self, user, room_id):
+ user_id = user.to_string()
+
+ member = yield self.state_handler.get_current_state(
+ room_id=room_id,
+ event_type=EventTypes.Member,
+ state_key=user_id
+ )
+ membership = member.membership if member else None
+
+ if membership is not None and membership not in [
+ Membership.LEAVE, Membership.BAN
+ ]:
+ raise SynapseError(400, "User %s in room %s" % (
+ user_id, room_id
+ ))
+
+ if membership:
+ yield self.store.forget(user_id, room_id)
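
RoomMemberMasterHandler is one of two concrete subclasses of the abstract RoomMemberHandler; the other, for worker processes, follows in the next file. A sketch of how the homeserver might select between them (the real wiring lives in synapse/server.py and is assumed here):

    def build_room_member_handler(hs):
        # Workers cannot touch the distributor or federation directly, so
        # they get the implementation that proxies over replication HTTP.
        if hs.config.worker_app:
            from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
            return RoomMemberWorkerHandler(hs)
        from synapse.handlers.room_member import RoomMemberMasterHandler
        return RoomMemberMasterHandler(hs)
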
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
new file mode 100644
index 0000000000..22d8b4b0d3
--- /dev/null
+++ b/synapse/handlers/room_member_worker.py
@@ -0,0 +1,103 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.room_member import RoomMemberHandler
+from synapse.replication.http.membership import (
+ get_or_register_3pid_guest,
+ notify_user_membership_change,
+ remote_join,
+ remote_reject_invite,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class RoomMemberWorkerHandler(RoomMemberHandler):
+ @defer.inlineCallbacks
+ def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+ """Implements RoomMemberHandler._remote_join
+ """
+ if len(remote_room_hosts) == 0:
+ raise SynapseError(404, "No known servers")
+
+ ret = yield remote_join(
+ self.simple_http_client,
+ host=self.config.worker_replication_host,
+ port=self.config.worker_replication_http_port,
+ requester=requester,
+ remote_room_hosts=remote_room_hosts,
+ room_id=room_id,
+ user_id=user.to_string(),
+ content=content,
+ )
+
+ yield self._user_joined_room(user, room_id)
+
+ defer.returnValue(ret)
+
+ def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
+ """Implements RoomMemberHandler._remote_reject_invite
+ """
+ return remote_reject_invite(
+ self.simple_http_client,
+ host=self.config.worker_replication_host,
+ port=self.config.worker_replication_http_port,
+ requester=requester,
+ remote_room_hosts=remote_room_hosts,
+ room_id=room_id,
+ user_id=target.to_string(),
+ )
+
+ def _user_joined_room(self, target, room_id):
+ """Implements RoomMemberHandler._user_joined_room
+ """
+ return notify_user_membership_change(
+ self.simple_http_client,
+ host=self.config.worker_replication_host,
+ port=self.config.worker_replication_http_port,
+ user_id=target.to_string(),
+ room_id=room_id,
+ change="joined",
+ )
+
+ def _user_left_room(self, target, room_id):
+ """Implements RoomMemberHandler._user_left_room
+ """
+ return notify_user_membership_change(
+ self.simple_http_client,
+ host=self.config.worker_replication_host,
+ port=self.config.worker_replication_http_port,
+ user_id=target.to_string(),
+ room_id=room_id,
+ change="left",
+ )
+
+ def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
+ """Implements RoomMemberHandler.get_or_register_3pid_guest
+ """
+ return get_or_register_3pid_guest(
+ self.simple_http_client,
+ host=self.config.worker_replication_host,
+ port=self.config.worker_replication_http_port,
+ requester=requester,
+ medium=medium,
+ address=address,
+ inviter_user_id=inviter_user_id,
+ )
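
The replication helpers imported at the top of this file each POST to an endpoint on the master process and relay the JSON result back. A sketch of what `remote_join` presumably looks like (URI and payload shape are assumptions; error translation omitted):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def remote_join(client, host, port, requester, remote_room_hosts,
                    room_id, user_id, content):
        # Forward the join to the master, which holds the federation
        # machinery, and hand its response back to the worker.
        uri = "http://%s:%s/_synapse/replication/remote_join" % (host, port)
        payload = {
            "requester": requester.serialize(),
            "remote_room_hosts": remote_room_hosts,
            "room_id": room_id,
            "user_id": user_id,
            "content": content,
        }
        result = yield client.post_json_get_json(uri, payload)
        defer.returnValue(result)
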
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index df75d70fac..69ae9731d5 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -13,21 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+import itertools
+import logging
-from ._base import BaseHandler
+from unpaddedbase64 import decode_base64, encode_base64
-from synapse.api.constants import Membership, EventTypes
-from synapse.api.filtering import Filter
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
+from synapse.api.filtering import Filter
from synapse.events.utils import serialize_event
from synapse.visibility import filter_events_for_client
-from unpaddedbase64 import decode_base64, encode_base64
-
-import itertools
-import logging
-
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -61,9 +60,16 @@ class SearchHandler(BaseHandler):
assert batch_group is not None
assert batch_group_key is not None
assert batch_token is not None
- except:
+ except Exception:
raise SynapseError(400, "Invalid batch")
+ logger.info(
+ "Search batch properties: %r, %r, %r",
+ batch_group, batch_group_key, batch_token,
+ )
+
+ logger.info("Search content: %s", content)
+
try:
room_cat = content["search_categories"]["room_events"]
@@ -271,6 +277,8 @@ class SearchHandler(BaseHandler):
# We should never get here due to the guard earlier.
raise NotImplementedError()
+ logger.info("Found %d events to return", len(allowed_events))
+
# If client has asked for "context" for each event (i.e. some surrounding
# events and state), fetch that
if event_context is not None:
@@ -282,6 +290,11 @@ class SearchHandler(BaseHandler):
event.room_id, event.event_id, before_limit, after_limit
)
+ logger.info(
+ "Context for search returned %d and %d events",
+ len(res["events_before"]), len(res["events_after"]),
+ )
+
res["events_before"] = yield filter_events_for_client(
self.store, user.to_string(), res["events_before"]
)
@@ -348,7 +361,7 @@ class SearchHandler(BaseHandler):
rooms = set(e.room_id for e in allowed_events)
for room_id in rooms:
state = yield self.state_handler.get_current_state(room_id)
- state_results[room_id] = state.values()
+ state_results[room_id] = list(state.values())
state_results.values()
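
The opaque `batch` parameter validated near the top of this file packs three newline-separated fields into unpadded base64. An illustrative round-trip, assuming that layout:

    from unpaddedbase64 import decode_base64, encode_base64

    def make_batch(group, group_key, token):
        raw = "%s\n%s\n%s" % (group, group_key, token)
        return encode_base64(raw.encode("utf-8"))

    batch = make_batch("room_id", "!search:example.com", "42")
    group, key, tok = decode_base64(batch).decode("utf-8").split("\n")
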
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
new file mode 100644
index 0000000000..7ecdede4dc
--- /dev/null
+++ b/synapse/handlers/set_password.py
@@ -0,0 +1,57 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.errors import Codes, StoreError, SynapseError
+
+from ._base import BaseHandler
+
+logger = logging.getLogger(__name__)
+
+
+class SetPasswordHandler(BaseHandler):
+ """Handler which deals with changing user account passwords"""
+ def __init__(self, hs):
+ super(SetPasswordHandler, self).__init__(hs)
+ self._auth_handler = hs.get_auth_handler()
+ self._device_handler = hs.get_device_handler()
+
+ @defer.inlineCallbacks
+ def set_password(self, user_id, newpassword, requester=None):
+ password_hash = yield self._auth_handler.hash(newpassword)
+
+ except_device_id = requester.device_id if requester else None
+ except_access_token_id = requester.access_token_id if requester else None
+
+ try:
+ yield self.store.user_set_password_hash(user_id, password_hash)
+ except StoreError as e:
+ if e.code == 404:
+ raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
+ raise e
+
+ # we want to log out all of the user's other sessions. First delete
+ # all their other devices.
+ yield self._device_handler.delete_all_devices_for_user(
+ user_id, except_device_id=except_device_id,
+ )
+
+ # and now delete any access tokens which weren't associated with
+ # devices (or were associated with this device).
+ yield self._auth_handler.delete_access_tokens_for_user(
+ user_id, except_token_id=except_access_token_id,
+ )
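
A usage sketch (the call site is hypothetical): passing the requester keeps the session that asked for the change alive, while `requester=None`, as in an admin reset, invalidates every device and token.

    from twisted.internet import defer

    @defer.inlineCallbacks
    def change_own_password(hs, requester, new_password):
        # The requester's own device and access token survive the reset.
        yield hs.get_set_password_handler().set_password(
            requester.user.to_string(), new_password, requester=requester,
        )
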
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 91c6c6be3c..c24e35362a 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -13,20 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.api.constants import Membership, EventTypes
+import collections
+import itertools
+import logging
+
+from six import iteritems, itervalues
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.push.clientformat import format_push_rules_for_user
+from synapse.types import RoomStreamToken
from synapse.util.async import concurrently_execute
+from synapse.util.caches.response_cache import ResponseCache
from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure, measure_func
-from synapse.util.caches.response_cache import ResponseCache
-from synapse.push.clientformat import format_push_rules_for_user
from synapse.visibility import filter_events_for_client
-from synapse.types import RoomStreamToken
-
-from twisted.internet import defer
-
-import collections
-import logging
-import itertools
logger = logging.getLogger(__name__)
@@ -52,6 +54,7 @@ class TimelineBatch(collections.namedtuple("TimelineBatch", [
to tell if room needs to be part of the sync result.
"""
return bool(self.events)
+ __bool__ = __nonzero__ # python3
class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
@@ -76,6 +79,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
# nb the notification count does not, er, count: if there's nothing
# else in the result, we don't need to send it.
)
+ __bool__ = __nonzero__ # python3
class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [
@@ -95,6 +99,7 @@ class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [
or self.state
or self.account_data
)
+ __bool__ = __nonzero__ # python3
class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
@@ -106,6 +111,30 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
def __nonzero__(self):
"""Invited rooms should always be reported to the client"""
return True
+ __bool__ = __nonzero__ # python3
+
+
+class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [
+ "join",
+ "invite",
+ "leave",
+])):
+ __slots__ = []
+
+ def __nonzero__(self):
+ return bool(self.join or self.invite or self.leave)
+ __bool__ = __nonzero__ # python3
+
+
+class DeviceLists(collections.namedtuple("DeviceLists", [
+ "changed", # list of user_ids whose devices may have changed
+ "left", # list of user_ids whose devices we no longer track
+])):
+ __slots__ = []
+
+ def __nonzero__(self):
+ return bool(self.changed or self.left)
+ __bool__ = __nonzero__ # python3
class SyncResult(collections.namedtuple("SyncResult", [
@@ -116,9 +145,10 @@ class SyncResult(collections.namedtuple("SyncResult", [
"invited", # InvitedSyncResult for each invited room.
"archived", # ArchivedSyncResult for each archived room.
"to_device", # List of direct messages for the device.
- "device_lists", # List of user_ids whose devices have chanegd
+ "device_lists", # List of user_ids whose devices have changed
"device_one_time_keys_count", # Dict of algorithm to count for one time keys
# for this device
+ "groups",
])):
__slots__ = []
@@ -134,8 +164,10 @@ class SyncResult(collections.namedtuple("SyncResult", [
self.archived or
self.account_data or
self.to_device or
- self.device_lists
+ self.device_lists or
+ self.groups
)
+ __bool__ = __nonzero__ # python3
class SyncHandler(object):
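
The `__bool__ = __nonzero__` aliases added above keep the truthiness of these result objects working on both interpreters: Python 2 consults `__nonzero__`, Python 3 consults `__bool__`, and sync relies on truthiness to decide whether a response is worth sending. A self-contained illustration:

    import collections

    class Batch(collections.namedtuple("Batch", ["events"])):
        def __nonzero__(self):        # Python 2 truthiness hook
            return bool(self.events)
        __bool__ = __nonzero__        # Python 3 looks for __bool__ instead

    assert not Batch(events=[])       # an empty batch is falsy on both
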
@@ -146,7 +178,7 @@ class SyncHandler(object):
self.presence_handler = hs.get_presence_handler()
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
- self.response_cache = ResponseCache(hs)
+ self.response_cache = ResponseCache(hs, "sync")
self.state = hs.get_state_handler()
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
@@ -157,15 +189,11 @@ class SyncHandler(object):
Returns:
A Deferred SyncResult.
"""
- result = self.response_cache.get(sync_config.request_key)
- if not result:
- result = self.response_cache.set(
- sync_config.request_key,
- self._wait_for_sync_for_user(
- sync_config, since_token, timeout, full_state
- )
- )
- return result
+ return self.response_cache.wrap(
+ sync_config.request_key,
+ self._wait_for_sync_for_user,
+ sync_config, since_token, timeout, full_state,
+ )
@defer.inlineCallbacks
def _wait_for_sync_for_user(self, sync_config, since_token, timeout,
@@ -212,10 +240,10 @@ class SyncHandler(object):
defer.returnValue(rules)
@defer.inlineCallbacks
- def ephemeral_by_room(self, sync_config, now_token, since_token=None):
+ def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
"""Get the ephemeral events for each room the user is in
Args:
- sync_config (SyncConfig): The flags, filters and user for the sync.
+ sync_result_builder(SyncResultBuilder)
now_token (StreamToken): Where the server is currently up to.
since_token (StreamToken): Where the server was when the client
last synced.
@@ -225,10 +253,12 @@ class SyncHandler(object):
typing events for that room.
"""
+ sync_config = sync_result_builder.sync_config
+
with Measure(self.clock, "ephemeral_by_room"):
typing_key = since_token.typing_key if since_token else "0"
- room_ids = yield self.store.get_rooms_for_user(sync_config.user.to_string())
+ room_ids = sync_result_builder.joined_room_ids
typing_source = self.event_sources.sources["typing"]
typing, typing_key = yield typing_source.get_new_events(
@@ -247,7 +277,7 @@ class SyncHandler(object):
# result returned by the event source is poor form (it might cache
# the object)
room_id = event["room_id"]
- event_copy = {k: v for (k, v) in event.iteritems()
+ event_copy = {k: v for (k, v) in iteritems(event)
if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
@@ -266,7 +296,7 @@ class SyncHandler(object):
for event in receipts:
room_id = event["room_id"]
# exclude room id, as above
- event_copy = {k: v for (k, v) in event.iteritems()
+ event_copy = {k: v for (k, v) in iteritems(event)
if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
@@ -290,10 +320,20 @@ class SyncHandler(object):
if recents:
recents = sync_config.filter_collection.filter_room_timeline(recents)
+
+ # We check if there are any state events; if there are, we pass all
+ # current state events to filter_events_for_client. This is to
+ # ensure that we always include current state in the timeline.
+ current_state_ids = frozenset()
+ if any(e.is_state() for e in recents):
+ current_state_ids = yield self.state.get_current_state_ids(room_id)
+ current_state_ids = frozenset(itervalues(current_state_ids))
+
recents = yield filter_events_for_client(
self.store,
sync_config.user.to_string(),
recents,
+ always_include_ids=current_state_ids,
)
else:
recents = []
@@ -316,19 +356,41 @@ class SyncHandler(object):
since_key = since_token.room_key
while limited and len(recents) < timeline_limit and max_repeat:
- events, end_key = yield self.store.get_room_events_stream_for_room(
- room_id,
- limit=load_limit + 1,
- from_key=since_key,
- to_key=end_key,
- )
+ # If we have a since_key then we are trying to get any events
+ # that have happened since `since_key` up to `end_key`, so we
+ # can just use `get_room_events_stream_for_room`.
+ # Otherwise, we want to return the last N events in the room
+ # in topological ordering.
+ if since_key:
+ events, end_key = yield self.store.get_room_events_stream_for_room(
+ room_id,
+ limit=load_limit + 1,
+ from_key=since_key,
+ to_key=end_key,
+ )
+ else:
+ events, end_key = yield self.store.get_recent_events_for_room(
+ room_id,
+ limit=load_limit + 1,
+ end_token=end_key,
+ )
loaded_recents = sync_config.filter_collection.filter_room_timeline(
events
)
+
+ # We check if there are any state events; if there are, we pass all
+ # current state events to filter_events_for_client. This is to
+ # ensure that we always include current state in the timeline.
+ current_state_ids = frozenset()
+ if any(e.is_state() for e in loaded_recents):
+ current_state_ids = yield self.state.get_current_state_ids(room_id)
+ current_state_ids = frozenset(itervalues(current_state_ids))
+
loaded_recents = yield filter_events_for_client(
self.store,
sync_config.user.to_string(),
loaded_recents,
+ always_include_ids=current_state_ids,
)
loaded_recents.extend(recents)
recents = loaded_recents
@@ -381,7 +443,11 @@ class SyncHandler(object):
Returns:
A Deferred map from ((type, state_key)->Event)
"""
- last_events, token = yield self.store.get_recent_events_for_room(
+ # FIXME this claims to get the state at a stream position, but
+ # get_recent_events_for_room operates by topo ordering. This therefore
+ # does not reliably give you the state at the given stream position.
+ # (https://github.com/matrix-org/synapse/issues/3305)
+ last_events, _ = yield self.store.get_recent_events_for_room(
room_id, end_token=stream_position.room_key, limit=1,
)
@@ -475,11 +541,11 @@ class SyncHandler(object):
state = {}
if state_ids:
- state = yield self.store.get_events(state_ids.values())
+ state = yield self.store.get_events(list(state_ids.values()))
defer.returnValue({
(e.type, e.state_key): e
- for e in sync_config.filter_collection.filter_room_state(state.values())
+ for e in sync_config.filter_collection.filter_room_state(list(state.values()))
})
@defer.inlineCallbacks
@@ -522,10 +588,22 @@ class SyncHandler(object):
# Always use the `now_token` in `SyncResultBuilder`
now_token = yield self.event_sources.get_current_token()
+ user_id = sync_config.user.to_string()
+ app_service = self.store.get_app_service_by_user_id(user_id)
+ if app_service:
+ # We no longer support AS users using /sync directly.
+ # See https://github.com/matrix-org/matrix-doc/issues/1144
+ raise NotImplementedError()
+ else:
+ joined_room_ids = yield self.get_rooms_for_user_at(
+ user_id, now_token.room_stream_id,
+ )
+
sync_result_builder = SyncResultBuilder(
sync_config, full_state,
since_token=since_token,
now_token=now_token,
+ joined_room_ids=joined_room_ids,
)
account_data_by_room = yield self._generate_sync_entry_for_account_data(
@@ -535,7 +613,8 @@ class SyncHandler(object):
res = yield self._generate_sync_entry_for_rooms(
sync_result_builder, account_data_by_room
)
- newly_joined_rooms, newly_joined_users = res
+ (newly_joined_rooms, newly_joined_users,
+ newly_left_rooms, newly_left_users) = res
block_all_presence_data = (
since_token is None and
@@ -549,17 +628,22 @@ class SyncHandler(object):
yield self._generate_sync_entry_for_to_device(sync_result_builder)
device_lists = yield self._generate_sync_entry_for_device_list(
- sync_result_builder
+ sync_result_builder,
+ newly_joined_rooms=newly_joined_rooms,
+ newly_joined_users=newly_joined_users,
+ newly_left_rooms=newly_left_rooms,
+ newly_left_users=newly_left_users,
)
device_id = sync_config.device_id
one_time_key_counts = {}
if device_id:
- user_id = sync_config.user.to_string()
one_time_key_counts = yield self.store.count_e2e_one_time_keys(
user_id, device_id
)
+ yield self._generate_sync_entry_for_groups(sync_result_builder)
+
defer.returnValue(SyncResult(
presence=sync_result_builder.presence,
account_data=sync_result_builder.account_data,
@@ -568,31 +652,103 @@ class SyncHandler(object):
archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device,
device_lists=device_lists,
+ groups=sync_result_builder.groups,
device_one_time_keys_count=one_time_key_counts,
next_batch=sync_result_builder.now_token,
))
+ @measure_func("_generate_sync_entry_for_groups")
+ @defer.inlineCallbacks
+ def _generate_sync_entry_for_groups(self, sync_result_builder):
+ user_id = sync_result_builder.sync_config.user.to_string()
+ since_token = sync_result_builder.since_token
+ now_token = sync_result_builder.now_token
+
+ if since_token and since_token.groups_key:
+ results = yield self.store.get_groups_changes_for_user(
+ user_id, since_token.groups_key, now_token.groups_key,
+ )
+ else:
+ results = yield self.store.get_all_groups_for_user(
+ user_id, now_token.groups_key,
+ )
+
+ invited = {}
+ joined = {}
+ left = {}
+ for result in results:
+ membership = result["membership"]
+ group_id = result["group_id"]
+ gtype = result["type"]
+ content = result["content"]
+
+ if membership == "join":
+ if gtype == "membership":
+ # TODO: Add profile
+ content.pop("membership", None)
+ joined[group_id] = content["content"]
+ else:
+ joined.setdefault(group_id, {})[gtype] = content
+ elif membership == "invite":
+ if gtype == "membership":
+ content.pop("membership", None)
+ invited[group_id] = content["content"]
+ else:
+ if gtype == "membership":
+ left[group_id] = content["content"]
+
+ sync_result_builder.groups = GroupsSyncResult(
+ join=joined,
+ invite=invited,
+ leave=left,
+ )
+
@measure_func("_generate_sync_entry_for_device_list")
@defer.inlineCallbacks
- def _generate_sync_entry_for_device_list(self, sync_result_builder):
+ def _generate_sync_entry_for_device_list(self, sync_result_builder,
+ newly_joined_rooms, newly_joined_users,
+ newly_left_rooms, newly_left_users):
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
if since_token and since_token.device_list_key:
- room_ids = yield self.store.get_rooms_for_user(user_id)
-
- user_ids_changed = set()
changed = yield self.store.get_user_whose_devices_changed(
since_token.device_list_key
)
- for other_user_id in changed:
- other_room_ids = yield self.store.get_rooms_for_user(other_user_id)
- if room_ids.intersection(other_room_ids):
- user_ids_changed.add(other_user_id)
- defer.returnValue(user_ids_changed)
+ # TODO: Be more clever than this, i.e. remove users who we already
+ # share a room with?
+ for room_id in newly_joined_rooms:
+ joined_users = yield self.state.get_current_user_in_room(room_id)
+ newly_joined_users.update(joined_users)
+
+ for room_id in newly_left_rooms:
+ left_users = yield self.state.get_current_user_in_room(room_id)
+ newly_left_users.update(left_users)
+
+ # TODO: Check that these users are actually new, i.e. either they
+ # weren't in the previous sync *or* they left and rejoined.
+ changed.update(newly_joined_users)
+
+ if not changed and not newly_left_users:
+ defer.returnValue(DeviceLists(
+ changed=[],
+ left=newly_left_users,
+ ))
+
+ users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+ user_id
+ )
+
+ defer.returnValue(DeviceLists(
+ changed=users_who_share_room & changed,
+ left=set(newly_left_users) - users_who_share_room,
+ ))
else:
- defer.returnValue([])
+ defer.returnValue(DeviceLists(
+ changed=[],
+ left=[],
+ ))
@defer.inlineCallbacks
def _generate_sync_entry_for_to_device(self, sync_result_builder):
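
A worked example of the set arithmetic in `_generate_sync_entry_for_device_list` above (user IDs invented): only users we still share a room with are reported as changed, and `left` excludes anyone we still share a room with.

    users_who_share_room = {"@a:hs", "@b:hs"}
    changed = {"@b:hs", "@c:hs"}           # includes a user we no longer share with
    newly_left_users = {"@a:hs", "@d:hs"}

    assert users_who_share_room & changed == {"@b:hs"}           # "changed"
    assert newly_left_users - users_who_share_room == {"@d:hs"}  # "left"
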
@@ -738,7 +894,7 @@ class SyncHandler(object):
presence.extend(states)
# Deduplicate the presence entries so that there's at most one per user
- presence = {p.user_id: p for p in presence}.values()
+ presence = list({p.user_id: p for p in presence}.values())
presence = sync_config.filter_collection.filter_presence(
presence
@@ -756,8 +912,8 @@ class SyncHandler(object):
account_data_by_room(dict): Dictionary of per room account data
Returns:
- Deferred(tuple): Returns a 2-tuple of
- `(newly_joined_rooms, newly_joined_users)`
+ Deferred(tuple): Returns a 4-tuple of
+ `(newly_joined_rooms, newly_joined_users, newly_left_rooms, newly_left_users)`
"""
user_id = sync_result_builder.sync_config.user.to_string()
block_all_room_ephemeral = (
@@ -769,7 +925,7 @@ class SyncHandler(object):
ephemeral_by_room = {}
else:
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
- sync_result_builder.sync_config,
+ sync_result_builder,
now_token=sync_result_builder.now_token,
since_token=sync_result_builder.since_token,
)
@@ -788,7 +944,7 @@ class SyncHandler(object):
)
if not tags_by_room:
logger.debug("no-oping sync")
- defer.returnValue(([], []))
+ defer.returnValue(([], [], [], []))
ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
"m.ignored_user_list", user_id=user_id,
@@ -801,7 +957,7 @@ class SyncHandler(object):
if since_token:
res = yield self._get_rooms_changed(sync_result_builder, ignored_users)
- room_entries, invited, newly_joined_rooms = res
+ room_entries, invited, newly_joined_rooms, newly_left_rooms = res
tags_by_room = yield self.store.get_updated_tags(
user_id, since_token.account_data_key,
@@ -809,6 +965,7 @@ class SyncHandler(object):
else:
res = yield self._get_all_rooms(sync_result_builder, ignored_users)
room_entries, invited, newly_joined_rooms = res
+ newly_left_rooms = []
tags_by_room = yield self.store.get_tags_for_user(user_id)
@@ -829,17 +986,30 @@ class SyncHandler(object):
# Now we want to get any newly joined users
newly_joined_users = set()
+ newly_left_users = set()
if since_token:
for joined_sync in sync_result_builder.joined:
it = itertools.chain(
- joined_sync.timeline.events, joined_sync.state.values()
+ joined_sync.timeline.events, itervalues(joined_sync.state)
)
for event in it:
if event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
newly_joined_users.add(event.state_key)
-
- defer.returnValue((newly_joined_rooms, newly_joined_users))
+ else:
+ prev_content = event.unsigned.get("prev_content", {})
+ prev_membership = prev_content.get("membership", None)
+ if prev_membership == Membership.JOIN:
+ newly_left_users.add(event.state_key)
+
+ newly_left_users -= newly_joined_users
+
+ defer.returnValue((
+ newly_joined_rooms,
+ newly_joined_users,
+ newly_left_rooms,
+ newly_left_users,
+ ))
@defer.inlineCallbacks
def _have_rooms_changed(self, sync_result_builder):
@@ -860,15 +1030,8 @@ class SyncHandler(object):
if rooms_changed:
defer.returnValue(True)
- app_service = self.store.get_app_service_by_user_id(user_id)
- if app_service:
- rooms = yield self.store.get_app_service_rooms(app_service)
- joined_room_ids = set(r.room_id for r in rooms)
- else:
- joined_room_ids = yield self.store.get_rooms_for_user(user_id)
-
stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
- for room_id in joined_room_ids:
+ for room_id in sync_result_builder.joined_room_ids:
if self.store.has_room_changed_since(room_id, stream_id):
defer.returnValue(True)
defer.returnValue(False)
@@ -883,7 +1046,13 @@ class SyncHandler(object):
Returns:
Deferred(tuple): Returns a tuple of the form:
- `([RoomSyncResultBuilder], [InvitedSyncResult], newly_joined_rooms)`
+ `(room_entries, invited_rooms, newly_joined_rooms, newly_left_rooms)`
+
+ where:
+ room_entries is a list[RoomSyncResultBuilder]
+ invited_rooms is a list[InvitedSyncResult]
+ newly_joined_rooms is a list[str] of room ids
+ newly_left_rooms is a list[str] of room ids
"""
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
@@ -892,13 +1061,6 @@ class SyncHandler(object):
assert since_token
- app_service = self.store.get_app_service_by_user_id(user_id)
- if app_service:
- rooms = yield self.store.get_app_service_rooms(app_service)
- joined_room_ids = set(r.room_id for r in rooms)
- else:
- joined_room_ids = yield self.store.get_rooms_for_user(user_id)
-
# Get a list of membership change events that have happened.
rooms_changed = yield self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key
@@ -909,16 +1071,29 @@ class SyncHandler(object):
mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
newly_joined_rooms = []
+ newly_left_rooms = []
room_entries = []
invited = []
- for room_id, events in mem_change_events_by_room_id.items():
+ for room_id, events in iteritems(mem_change_events_by_room_id):
non_joins = [e for e in events if e.membership != Membership.JOIN]
has_join = len(non_joins) != len(events)
# We want to figure out if we joined the room at some point since
# the last sync (even if we have since left). This is to make sure
# we do send down the room, and with full state, where necessary
- if room_id in joined_room_ids or has_join:
+
+ old_state_ids = None
+ if room_id in sync_result_builder.joined_room_ids and non_joins:
+ # Always include the room if the user (re)joined it; this is
+ # especially important so that device list changes are calculated
+ # correctly. If there are non-join member events but we are still
+ # in the room, then the user must have left and rejoined.
+ newly_joined_rooms.append(room_id)
+
+ # User is in the room so we don't need to do the invite/leave checks
+ continue
+
+ if room_id in sync_result_builder.joined_room_ids or has_join:
old_state_ids = yield self.get_state_at(room_id, since_token)
old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
old_mem_ev = None
@@ -929,12 +1104,33 @@ class SyncHandler(object):
if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
newly_joined_rooms.append(room_id)
- if room_id in joined_room_ids:
- continue
+ # If user is in the room then we don't need to do the invite/leave checks
+ if room_id in sync_result_builder.joined_room_ids:
+ continue
if not non_joins:
continue
+ # Check if we have left the room. This can be either because we
+ # were joined before, *or* because we joined and then left since
+ # the last sync.
+ if events[-1].membership != Membership.JOIN:
+ if has_join:
+ newly_left_rooms.append(room_id)
+ else:
+ if not old_state_ids:
+ old_state_ids = yield self.get_state_at(room_id, since_token)
+ old_mem_ev_id = old_state_ids.get(
+ (EventTypes.Member, user_id),
+ None,
+ )
+ old_mem_ev = None
+ if old_mem_ev_id:
+ old_mem_ev = yield self.store.get_event(
+ old_mem_ev_id, allow_none=True
+ )
+ if old_mem_ev and old_mem_ev.membership == Membership.JOIN:
+ newly_left_rooms.append(room_id)
+
# Only bother if we're still currently invited
should_invite = non_joins[-1].membership == Membership.INVITE
if should_invite:
@@ -976,7 +1172,7 @@ class SyncHandler(object):
# Get all events for rooms we're currently joined to.
room_to_events = yield self.store.get_room_events_stream_for_rooms(
- room_ids=joined_room_ids,
+ room_ids=sync_result_builder.joined_room_ids,
from_key=since_token.room_key,
to_key=now_token.room_key,
limit=timeline_limit + 1,
@@ -984,7 +1180,7 @@ class SyncHandler(object):
# We loop through all room ids, even if there are no new events, in case
# there are non-room events that we need to notify about.
- for room_id in joined_room_ids:
+ for room_id in sync_result_builder.joined_room_ids:
room_entry = room_to_events.get(room_id, None)
if room_entry:
@@ -1012,7 +1208,7 @@ class SyncHandler(object):
upto_token=since_token,
))
- defer.returnValue((room_entries, invited, newly_joined_rooms))
+ defer.returnValue((room_entries, invited, newly_joined_rooms, newly_left_rooms))
@defer.inlineCallbacks
def _get_all_rooms(self, sync_result_builder, ignored_users):
@@ -1192,6 +1388,54 @@ class SyncHandler(object):
else:
raise Exception("Unrecognized rtype: %r", room_builder.rtype)
+ @defer.inlineCallbacks
+ def get_rooms_for_user_at(self, user_id, stream_ordering):
+ """Get set of joined rooms for a user at the given stream ordering.
+
+ The stream ordering *must* be recent, otherwise this may raise an
+ exception if it is more than a month old. (This function is called
+ with the current token, which should be perfectly fine.)
+
+ Args:
+ user_id (str)
+ stream_ordering (int)
+
+ Returns:
+ Deferred[frozenset[str]]: Set of room_ids the user is in at the
+ given stream_ordering.
+ """
+ joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(
+ user_id,
+ )
+
+ joined_room_ids = set()
+
+ # We need to check that the stream ordering of the join for each room
+ # is before the stream_ordering asked for. This might not be the case
+ # if the user joins a room between us getting the current token and
+ # calling `get_rooms_for_user_with_stream_ordering`.
+ # If the membership's stream ordering is after the given stream
+ # ordering, we need to go and work out if the user was in the room
+ # before.
+ for room_id, membership_stream_ordering in joined_rooms:
+ if membership_stream_ordering <= stream_ordering:
+ joined_room_ids.add(room_id)
+ continue
+
+ logger.info("User joined room after current token: %s", room_id)
+
+ extrems = yield self.store.get_forward_extremeties_for_room(
+ room_id, stream_ordering,
+ )
+ users_in_room = yield self.state.get_current_user_in_room(
+ room_id, extrems,
+ )
+ if user_id in users_in_room:
+ joined_room_ids.add(room_id)
+
+ joined_room_ids = frozenset(joined_room_ids)
+ defer.returnValue(joined_room_ids)
+
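
The interesting part of get_rooms_for_user_at is the two-tier filter: memberships whose stream ordering is at or before the requested point are accepted directly, and only the rare later ones fall back to the expensive extremity/state lookup. A condensed sketch, with a hypothetical was_in_room_before callback standing in for that lookup:

    # Sketch of the filtering step above; `was_in_room_before` is a
    # hypothetical callback standing in for the forward-extremity check.
    def rooms_at_stream_ordering(joined_rooms, stream_ordering, was_in_room_before):
        """joined_rooms: iterable of (room_id, membership_stream_ordering)."""
        result = set()
        for room_id, membership_ordering in joined_rooms:
            if membership_ordering <= stream_ordering:
                # the join happened at or before the token: cheap accept
                result.add(room_id)
            elif was_in_room_before(room_id, stream_ordering):
                # the join is newer than the token: consult room state
                result.add(room_id)
        return frozenset(result)

    rooms = [("!a:hs", 10), ("!b:hs", 99)]
    assert rooms_at_stream_ordering(rooms, 50, lambda r, s: False) == frozenset(["!a:hs"])
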
def _action_has_highlight(actions):
for action in actions:
@@ -1241,7 +1485,8 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
class SyncResultBuilder(object):
"Used to help build up a new SyncResult for a user"
- def __init__(self, sync_config, full_state, since_token, now_token):
+ def __init__(self, sync_config, full_state, since_token, now_token,
+ joined_room_ids):
"""
Args:
sync_config(SyncConfig)
@@ -1253,6 +1498,7 @@ class SyncResultBuilder(object):
self.full_state = full_state
self.since_token = since_token
self.now_token = now_token
+ self.joined_room_ids = joined_room_ids
self.presence = []
self.account_data = []
@@ -1260,6 +1506,8 @@ class SyncResultBuilder(object):
self.invited = []
self.archived = []
self.device = []
+ self.groups = None
+ self.to_device = []
class RoomSyncResultBuilder(object):
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 82dedbbc99..2d2d3d5a0d 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -13,17 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+from collections import namedtuple
+
from twisted.internet import defer
-from synapse.api.errors import SynapseError, AuthError
-from synapse.util.logcontext import preserve_fn
+from synapse.api.errors import AuthError, SynapseError
+from synapse.types import UserID, get_domain_from_id
+from synapse.util.logcontext import run_in_background
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
-from synapse.types import UserID, get_domain_from_id
-
-import logging
-
-from collections import namedtuple
logger = logging.getLogger(__name__)
@@ -56,7 +55,7 @@ class TypingHandler(object):
self.federation = hs.get_federation_sender()
- hs.get_replication_layer().register_edu_handler("m.typing", self._recv_edu)
+ hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
hs.get_distributor().observe("user_left_room", self.user_left_room)
@@ -97,7 +96,8 @@ class TypingHandler(object):
if self.hs.is_mine_id(member.user_id):
last_fed_poke = self._member_last_federation_poke.get(member, None)
if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
- preserve_fn(self._push_remote)(
+ run_in_background(
+ self._push_remote,
member=member,
typing=True
)
@@ -196,7 +196,7 @@ class TypingHandler(object):
def _push_update(self, member, typing):
if self.hs.is_mine_id(member.user_id):
# Only send updates for changes to our own users.
- preserve_fn(self._push_remote)(member, typing)
+ run_in_background(self._push_remote, member, typing)
self._push_update_local(
member=member,
@@ -205,28 +205,31 @@ class TypingHandler(object):
@defer.inlineCallbacks
def _push_remote(self, member, typing):
- users = yield self.state.get_current_user_in_room(member.room_id)
- self._member_last_federation_poke[member] = self.clock.time_msec()
+ try:
+ users = yield self.state.get_current_user_in_room(member.room_id)
+ self._member_last_federation_poke[member] = self.clock.time_msec()
- now = self.clock.time_msec()
- self.wheel_timer.insert(
- now=now,
- obj=member,
- then=now + FEDERATION_PING_INTERVAL,
- )
+ now = self.clock.time_msec()
+ self.wheel_timer.insert(
+ now=now,
+ obj=member,
+ then=now + FEDERATION_PING_INTERVAL,
+ )
- for domain in set(get_domain_from_id(u) for u in users):
- if domain != self.server_name:
- self.federation.send_edu(
- destination=domain,
- edu_type="m.typing",
- content={
- "room_id": member.room_id,
- "user_id": member.user_id,
- "typing": typing,
- },
- key=member,
- )
+ for domain in set(get_domain_from_id(u) for u in users):
+ if domain != self.server_name:
+ self.federation.send_edu(
+ destination=domain,
+ edu_type="m.typing",
+ content={
+ "room_id": member.room_id,
+ "user_id": member.user_id,
+ "typing": typing,
+ },
+ key=member,
+ )
+ except Exception:
+ logger.exception("Error pushing typing notif to remotes")
@defer.inlineCallbacks
def _recv_edu(self, origin, content):
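
The try/except added to _push_remote is the standard companion to run_in_background: nothing ever waits on the returned Deferred, so an unhandled failure would otherwise vanish silently. A minimal sketch of the pattern in plain Twisted (do_network_io is a hypothetical function):

    import logging

    from twisted.internet import defer

    logger = logging.getLogger(__name__)

    @defer.inlineCallbacks
    def push_remote_safely(do_network_io):
        try:
            # everything that can fail lives inside the try block, because
            # no caller will ever yield on this Deferred to observe errors
            yield do_network_io()
        except Exception:
            # log and swallow: a fire-and-forget task has nobody to re-raise to
            logger.exception("Error pushing typing notif to remotes")
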
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 2a49456bfc..37dda64587 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -14,18 +14,20 @@
# limitations under the License.
import logging
+
+from six import iteritems
+
from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.storage.roommember import ProfileInfo
+from synapse.types import get_localpart_from_id
from synapse.util.metrics import Measure
-from synapse.util.async import sleep
-
logger = logging.getLogger(__name__)
-class UserDirectoyHandler(object):
+class UserDirectoryHandler(object):
"""Handles querying of and keeping updated the user_directory.
N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY
@@ -41,9 +43,10 @@ class UserDirectoyHandler(object):
one public room.
"""
- INITIAL_SLEEP_MS = 50
- INITIAL_SLEEP_COUNT = 100
- INITIAL_BATCH_SIZE = 100
+ INITIAL_ROOM_SLEEP_MS = 50
+ INITIAL_ROOM_SLEEP_COUNT = 100
+ INITIAL_ROOM_BATCH_SIZE = 100
+ INITIAL_USER_SLEEP_MS = 10
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -53,6 +56,7 @@ class UserDirectoyHandler(object):
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
self.update_user_directory = hs.config.update_user_directory
+ self.search_all_users = hs.config.user_directory_search_all_users
# When start up for the first time we need to populate the user_directory.
# This is a set of user_id's we've inserted already
@@ -111,6 +115,22 @@ class UserDirectoyHandler(object):
self._is_processing = False
@defer.inlineCallbacks
+ def handle_local_profile_change(self, user_id, profile):
+ """Called to update index of our local user profiles when they change
+ irrespective of any rooms the user may be in.
+ """
+ yield self.store.update_profile_in_user_dir(
+ user_id, profile.display_name, profile.avatar_url, None,
+ )
+
+ @defer.inlineCallbacks
+ def handle_user_deactivated(self, user_id):
+ """Called when a user ID is deactivated
+ """
+ yield self.store.remove_from_user_dir(user_id)
+ yield self.store.remove_from_user_in_public_room(user_id)
+
+ @defer.inlineCallbacks
def _unsafe_process(self):
# If self.pos is None then it means we haven't fetched it from the DB
if self.pos is None:
@@ -148,16 +168,30 @@ class UserDirectoyHandler(object):
room_ids = yield self.store.get_all_rooms()
logger.info("Doing initial update of user directory. %d rooms", len(room_ids))
- num_processed_rooms = 1
+ num_processed_rooms = 0
for room_id in room_ids:
- logger.info("Handling room %d/%d", num_processed_rooms, len(room_ids))
- yield self._handle_intial_room(room_id)
+ logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids))
+ yield self._handle_initial_room(room_id)
num_processed_rooms += 1
- yield sleep(self.INITIAL_SLEEP_MS / 1000.)
+ yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
logger.info("Processed all rooms.")
+ if self.search_all_users:
+ num_processed_users = 0
+ user_ids = yield self.store.get_all_local_users()
+ logger.info("Doing initial update of user directory. %d users", len(user_ids))
+ for user_id in user_ids:
+ # We add profiles for all users even if they don't match the
+ # include pattern, just in case we want to change it in future
+ logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids))
+ yield self._handle_local_user(user_id)
+ num_processed_users += 1
+ yield self.clock.sleep(self.INITIAL_USER_SLEEP_MS / 1000.)
+
+ logger.info("Processed all users")
+
self.initially_handled_users = None
self.initially_handled_users_in_public = None
self.initially_handled_users_share = None
@@ -166,7 +200,7 @@ class UserDirectoyHandler(object):
yield self.store.update_user_directory_stream_pos(new_pos)
@defer.inlineCallbacks
- def _handle_intial_room(self, room_id):
+ def _handle_initial_room(self, room_id):
"""Called when we initially fill out user_directory one room at a time
"""
is_in_room = yield self.store.is_host_joined(room_id, self.server_name)
@@ -201,8 +235,8 @@ class UserDirectoyHandler(object):
to_update = set()
count = 0
for user_id in user_ids:
- if count % self.INITIAL_SLEEP_COUNT == 0:
- yield sleep(self.INITIAL_SLEEP_MS / 1000.)
+ if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
+ yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
if not self.is_mine_id(user_id):
count += 1
@@ -216,8 +250,8 @@ class UserDirectoyHandler(object):
if user_id == other_user_id:
continue
- if count % self.INITIAL_SLEEP_COUNT == 0:
- yield sleep(self.INITIAL_SLEEP_MS / 1000.)
+ if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
+ yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
count += 1
user_set = (user_id, other_user_id)
@@ -237,13 +271,13 @@ class UserDirectoyHandler(object):
else:
self.initially_handled_users_share_private_room.add(user_set)
- if len(to_insert) > self.INITIAL_BATCH_SIZE:
+ if len(to_insert) > self.INITIAL_ROOM_BATCH_SIZE:
yield self.store.add_users_who_share_room(
room_id, not is_public, to_insert,
)
to_insert.clear()
- if len(to_update) > self.INITIAL_BATCH_SIZE:
+ if len(to_update) > self.INITIAL_ROOM_BATCH_SIZE:
yield self.store.update_users_who_share_room(
room_id, not is_public, to_update,
)
@@ -377,7 +411,7 @@ class UserDirectoyHandler(object):
if change:
users_with_profile = yield self.state.get_current_user_in_room(room_id)
- for user_id, profile in users_with_profile.iteritems():
+ for user_id, profile in iteritems(users_with_profile):
yield self._handle_new_user(room_id, user_id, profile)
else:
users = yield self.store.get_users_in_public_due_to_room(room_id)
@@ -385,14 +419,28 @@ class UserDirectoyHandler(object):
yield self._handle_remove_user(room_id, user_id)
@defer.inlineCallbacks
+ def _handle_local_user(self, user_id):
+ """Adds a new local roomless user into the user_directory_search table.
+ Used to populate the user index when
+ user_directory_search_all_users is specified.
+ """
+ logger.debug("Adding new local user to dir, %r", user_id)
+
+ profile = yield self.store.get_profileinfo(get_localpart_from_id(user_id))
+
+ row = yield self.store.get_user_in_directory(user_id)
+ if not row:
+ yield self.store.add_profiles_to_user_dir(None, {user_id: profile})
+
+ @defer.inlineCallbacks
def _handle_new_user(self, room_id, user_id, profile):
"""Called when we might need to add user to directory
Args:
- room_id (str): room_id that user joined or started being public that
+ room_id (str): room_id that user joined or started being public
user_id (str)
"""
- logger.debug("Adding user to dir, %r", user_id)
+ logger.debug("Adding new user to dir, %r", user_id)
row = yield self.store.get_user_in_directory(user_id)
if not row:
@@ -407,7 +455,7 @@ class UserDirectoyHandler(object):
if not row:
yield self.store.add_users_to_public_room(room_id, [user_id])
else:
- logger.debug("Not adding user to public dir, %r", user_id)
+ logger.debug("Not adding new user to public dir, %r", user_id)
# Now we update users who share rooms with users. We do this by getting
# all the current users in the room and seeing which aren't already
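
The INITIAL_ROOM_SLEEP_MS / INITIAL_ROOM_SLEEP_COUNT pattern used throughout the initial population is simple cooperative rate limiting: yield to the reactor every N items so the long-running backfill does not starve request handling. A condensed sketch using plain Twisted (task.deferLater stands in for Synapse's clock.sleep):

    from twisted.internet import defer, reactor, task

    SLEEP_COUNT = 100  # yield to the reactor after this many items
    SLEEP_MS = 50      # how long to yield for

    @defer.inlineCallbacks
    def process_all(items, handle_one):
        for count, item in enumerate(items):
            if count % SLEEP_COUNT == 0:
                # give the reactor a chance to service other work
                yield task.deferLater(reactor, SLEEP_MS / 1000., lambda: None)
            yield handle_one(item)
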
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index bfebb0f644..58ef8d3ce4 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,3 +13,37 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import re
+
+from twisted.internet.defer import CancelledError
+from twisted.python import failure
+
+from synapse.api.errors import SynapseError
+
+
+class RequestTimedOutError(SynapseError):
+ """Exception representing timeout of an outbound request"""
+ def __init__(self):
+ super(RequestTimedOutError, self).__init__(504, "Timed out")
+
+
+def cancelled_to_request_timed_out_error(value, timeout):
+ """Turns CancelledErrors into RequestTimedOutErrors.
+
+ For use with async.add_timeout_to_deferred
+ """
+ if isinstance(value, failure.Failure):
+ value.trap(CancelledError)
+ raise RequestTimedOutError()
+ return value
+
+
+ACCESS_TOKEN_RE = re.compile(br'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
+
+
+def redact_uri(uri):
+ """Strips access tokens from the uri replaces with <redacted>"""
+ return ACCESS_TOKEN_RE.sub(
+ br'\1<redacted>\3',
+ uri
+ )
diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py
new file mode 100644
index 0000000000..0e10e3f8f7
--- /dev/null
+++ b/synapse/http/additional_resource.py
@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+
+from synapse.http.server import wrap_json_request_handler
+
+
+class AdditionalResource(Resource):
+ """Resource wrapper for additional_resources
+
+ If the user has configured additional_resources, we need to wrap the
+ handler class with a Resource so that we can map it into the resource tree.
+
+ This class is also where we wrap the request handler with logging, metrics,
+ and exception handling.
+ """
+ def __init__(self, hs, handler):
+ """Initialise AdditionalResource
+
+ The ``handler`` should return a deferred which completes when it has
+ done handling the request. It should write a response with
+ ``request.write()``, and call ``request.finish()``.
+
+ Args:
+ hs (synapse.server.HomeServer): homeserver
+ handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
+ function to be called to handle the request.
+ """
+ Resource.__init__(self)
+ self._handler = handler
+
+ # required by the request_handler wrapper
+ self.clock = hs.get_clock()
+
+ def render(self, request):
+ self._async_render(request)
+ return NOT_DONE_YET
+
+ @wrap_json_request_handler
+ def _async_render(self, request):
+ return self._handler(request)
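
A handler passed to AdditionalResource is just a callable taking the Twisted request and returning a Deferred; it is responsible for writing and finishing the response itself. A hypothetical example handler (the payload and wiring are made up for illustration):

    import json

    from twisted.internet import defer

    @defer.inlineCallbacks
    def ping_handler(request):
        # a real handler might consult the database here; we just no-op
        yield defer.succeed(None)
        request.setHeader(b"Content-Type", b"application/json")
        request.write(json.dumps({"pong": True}).encode("utf-8"))
        request.finish()

    # resource = AdditionalResource(hs, ping_handler)
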
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 9eba046bbf..25b6307884 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,49 +13,49 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from OpenSSL import SSL
-from OpenSSL.SSL import VERIFY_NONE
+import logging
+import urllib
-from synapse.api.errors import (
- CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
-)
-from synapse.util.logcontext import preserve_context_over_fn
-from synapse.util import logcontext
-import synapse.metrics
-from synapse.http.endpoint import SpiderEndpoint
+from six import StringIO
-from canonicaljson import encode_canonical_json
+from canonicaljson import encode_canonical_json, json
+from prometheus_client import Counter
-from twisted.internet import defer, reactor, ssl, protocol, task
+from OpenSSL import SSL
+from OpenSSL.SSL import VERIFY_NONE
+from twisted.internet import defer, protocol, reactor, ssl, task
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
+from twisted.web._newclient import ResponseDone
from twisted.web.client import (
- BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
- readBody, PartialDownloadError,
+ Agent,
+ BrowserLikeRedirectAgent,
+ ContentDecoderAgent,
+ FileBodyProducer as TwistedFileBodyProducer,
+ GzipDecoder,
+ HTTPConnectionPool,
+ PartialDownloadError,
+ readBody,
)
-from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
-from twisted.web._newclient import ResponseDone
-
-from StringIO import StringIO
-
-import simplejson as json
-import logging
-import urllib
+from synapse.api.errors import (
+ CodeMessageException,
+ Codes,
+ MatrixCodeMessageException,
+ SynapseError,
+)
+from synapse.http import cancelled_to_request_timed_out_error, redact_uri
+from synapse.http.endpoint import SpiderEndpoint
+from synapse.util.async import add_timeout_to_deferred
+from synapse.util.caches import CACHE_SIZE_FACTOR
+from synapse.util.logcontext import make_deferred_yieldable
logger = logging.getLogger(__name__)
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-outgoing_requests_counter = metrics.register_counter(
- "requests",
- labels=["method"],
-)
-incoming_responses_counter = metrics.register_counter(
- "responses",
- labels=["method", "code"],
-)
+outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"])
+incoming_responses_counter = Counter("synapse_http_client_responses", "",
+ ["method", "code"])
class SimpleHttpClient(object):
@@ -64,13 +65,23 @@ class SimpleHttpClient(object):
"""
def __init__(self, hs):
self.hs = hs
+
+ pool = HTTPConnectionPool(reactor)
+
+ # the pusher makes lots of concurrent SSL connections to sygnal, and
+ # tends to do so in batches, so we need to allow the pool to keep lots
+ # of idle connections around.
+ pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
+ pool.cachedConnectionTimeout = 2 * 60
+
# The default context factory in Twisted 14.0.0 (which we require) is
# BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser'
self.agent = Agent(
reactor,
connectTimeout=15,
- contextFactory=hs.get_http_client_context_factory()
+ contextFactory=hs.get_http_client_context_factory(),
+ pool=pool,
)
self.user_agent = hs.version_string
self.clock = hs.get_clock()
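
The pool settings above are stock Twisted APIs and can be reproduced standalone. A sketch of the same policy (the 50-connection figure assumes the default CACHE_SIZE_FACTOR of 0.5; the real value scales with that setting):

    from twisted.internet import reactor
    from twisted.web.client import Agent, HTTPConnectionPool

    pool = HTTPConnectionPool(reactor)
    pool.maxPersistentPerHost = 50         # keep plenty of idle connections per host
    pool.cachedConnectionTimeout = 2 * 60  # expire idle connections after two minutes

    agent = Agent(reactor, connectTimeout=15, pool=pool)
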
@@ -81,76 +92,103 @@ class SimpleHttpClient(object):
def request(self, method, uri, *args, **kwargs):
# A small wrapper around self.agent.request() so we can easily attach
# counters to it
- outgoing_requests_counter.inc(method)
+ outgoing_requests_counter.labels(method).inc()
- def send_request():
+ # log request but strip `access_token` (AS requests for example include this)
+ logger.info("Sending request %s %s", method, redact_uri(uri))
+
+ try:
request_deferred = self.agent.request(
method, uri, *args, **kwargs
)
-
- return self.clock.time_bound_deferred(
- request_deferred,
- time_out=60,
+ add_timeout_to_deferred(
+ request_deferred, 60, self.hs.get_reactor(),
+ cancelled_to_request_timed_out_error,
)
+ response = yield make_deferred_yieldable(request_deferred)
- logger.info("Sending request %s %s", method, uri)
-
- try:
- with logcontext.PreserveLoggingContext():
- response = yield send_request()
-
- incoming_responses_counter.inc(method, response.code)
+ incoming_responses_counter.labels(method, response.code).inc()
logger.info(
"Received response to %s %s: %s",
- method, uri, response.code
+ method, redact_uri(uri), response.code
)
defer.returnValue(response)
except Exception as e:
- incoming_responses_counter.inc(method, "ERR")
+ incoming_responses_counter.labels(method, "ERR").inc()
logger.info(
"Error sending request to %s %s: %s %s",
- method, uri, type(e).__name__, e.message
+ method, redact_uri(uri), type(e).__name__, e.message
)
- raise e
+ raise
@defer.inlineCallbacks
- def post_urlencoded_get_json(self, uri, args={}):
+ def post_urlencoded_get_json(self, uri, args={}, headers=None):
+ """
+ Args:
+ uri (str):
+ args (dict[str, str|List[str]]): query params
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
+
+ Returns:
+ Deferred[object]: parsed json
+ """
+
# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args)
query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
+ actual_headers = {
+ b"Content-Type": [b"application/x-www-form-urlencoded"],
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"POST",
uri.encode("ascii"),
- headers=Headers({
- b"Content-Type": [b"application/x-www-form-urlencoded"],
- b"User-Agent": [self.user_agent],
- }),
+ headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(query_bytes))
)
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
- def post_json_get_json(self, uri, post_json):
+ def post_json_get_json(self, uri, post_json, headers=None):
+ """
+
+ Args:
+ uri (str):
+ post_json (object):
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
+
+ Returns:
+ Deferred[object]: parsed json
+ """
json_str = encode_canonical_json(post_json)
logger.debug("HTTP POST %s -> %s", json_str, uri)
+ actual_headers = {
+ b"Content-Type": [b"application/json"],
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"POST",
uri.encode("ascii"),
- headers=Headers({
- b"Content-Type": [b"application/json"],
- b"User-Agent": [self.user_agent],
- }),
+ headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(json_str))
)
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
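
Each of these methods now follows the same merge pattern: build the default headers, then let caller-supplied headers override them key by key via dict.update. Extracted as a helper, the pattern looks like this (a sketch; the diff keeps it inline in each method):

    def merge_headers(defaults, overrides=None):
        """Caller-supplied headers take precedence over the defaults."""
        merged = dict(defaults)
        if overrides:
            merged.update(overrides)
        return merged

    headers = merge_headers(
        {b"Content-Type": [b"application/json"], b"User-Agent": [b"Synapse"]},
        {b"Content-Type": [b"application/json; charset=utf-8"]},
    )
    assert headers[b"Content-Type"] == [b"application/json; charset=utf-8"]
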
@@ -160,7 +198,7 @@ class SimpleHttpClient(object):
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
- def get_json(self, uri, args={}):
+ def get_json(self, uri, args={}, headers=None):
""" Gets some json from the given URI.
Args:
@@ -169,6 +207,8 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
@@ -177,13 +217,13 @@ class SimpleHttpClient(object):
error message.
"""
try:
- body = yield self.get_raw(uri, args)
+ body = yield self.get_raw(uri, args, headers=headers)
defer.returnValue(json.loads(body))
except CodeMessageException as e:
raise self._exceptionFromFailedRequest(e.code, e.msg)
@defer.inlineCallbacks
- def put_json(self, uri, json_body, args={}):
+ def put_json(self, uri, json_body, args={}, headers=None):
""" Puts some json to the given URI.
Args:
@@ -193,6 +233,8 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
@@ -205,17 +247,21 @@ class SimpleHttpClient(object):
json_str = encode_canonical_json(json_body)
+ actual_headers = {
+ b"Content-Type": [b"application/json"],
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"PUT",
uri.encode("ascii"),
- headers=Headers({
- b"User-Agent": [self.user_agent],
- "Content-Type": ["application/json"]
- }),
+ headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(json_str))
)
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
@@ -226,7 +272,7 @@ class SimpleHttpClient(object):
raise CodeMessageException(response.code, body)
@defer.inlineCallbacks
- def get_raw(self, uri, args={}):
+ def get_raw(self, uri, args={}, headers=None):
""" Gets raw text from the given URI.
Args:
@@ -235,6 +281,8 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as text.
@@ -246,15 +294,19 @@ class SimpleHttpClient(object):
query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
+ actual_headers = {
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"GET",
uri.encode("ascii"),
- headers=Headers({
- b"User-Agent": [self.user_agent],
- })
+ headers=Headers(actual_headers),
)
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
defer.returnValue(body)
@@ -274,27 +326,33 @@ class SimpleHttpClient(object):
# The two should be factored out.
@defer.inlineCallbacks
- def get_file(self, url, output_stream, max_size=None):
+ def get_file(self, url, output_stream, max_size=None, headers=None):
"""GETs a file from a given URL
Args:
url (str): The URL to GET
output_stream (file): File to write the response body to.
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
Returns:
A (int,dict,string,int) tuple of the file length, dict of the response
headers, absolute URI of the response and HTTP response code.
"""
+ actual_headers = {
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"GET",
url.encode("ascii"),
- headers=Headers({
- b"User-Agent": [self.user_agent],
- })
+ headers=Headers(actual_headers),
)
- headers = dict(response.headers.getAllRawHeaders())
+ resp_headers = dict(response.headers.getAllRawHeaders())
- if 'Content-Length' in headers and headers['Content-Length'] > max_size:
+ if 'Content-Length' in resp_headers and resp_headers['Content-Length'] > max_size:
logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
raise SynapseError(
502,
@@ -315,10 +373,9 @@ class SimpleHttpClient(object):
# straight back in again
try:
- length = yield preserve_context_over_fn(
- _readBodyToFile,
- response, output_stream, max_size
- )
+ length = yield make_deferred_yieldable(_readBodyToFile(
+ response, output_stream, max_size,
+ ))
except Exception as e:
logger.exception("Failed to download body")
raise SynapseError(
@@ -327,7 +384,9 @@ class SimpleHttpClient(object):
Codes.UNKNOWN,
)
- defer.returnValue((length, headers, response.request.absoluteURI, response.code))
+ defer.returnValue(
+ (length, resp_headers, response.request.absoluteURI, response.code),
+ )
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
@@ -395,7 +454,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
)
try:
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
defer.returnValue(body)
except PartialDownloadError as e:
# twisted dislikes google's response, no content length.
@@ -446,7 +505,7 @@ class SpiderHttpClient(SimpleHttpClient):
reactor,
SpiderEndpointFactory(hs)
)
- ), [('gzip', GzipDecoder)]
+ ), [(b'gzip', GzipDecoder)]
)
# We could look like Chrome:
# self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko)
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index d8923c9abb..d65daa72bb 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -12,30 +12,97 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.internet import defer, reactor
-from twisted.internet.error import ConnectError
-from twisted.names import client, dns
-from twisted.names.error import DNSNameError, DomainError
-
import collections
import logging
import random
+import re
import time
+from twisted.internet import defer
+from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
+from twisted.internet.error import ConnectError
+from twisted.names import client, dns
+from twisted.names.error import DNSNameError, DomainError
logger = logging.getLogger(__name__)
SERVER_CACHE = {}
-
+# our record of an individual server which can be tried to reach a destination.
+#
+# "host" is the hostname acquired from the SRV record. Except when there's
+# no SRV record, in which case it is the original hostname.
_Server = collections.namedtuple(
"_Server", "priority weight host port expires"
)
+def parse_server_name(server_name):
+ """Split a server name into host/port parts.
+
+ Args:
+ server_name (str): server name to parse
+
+ Returns:
+ Tuple[str, int|None]: host/port parts.
+
+ Raises:
+ ValueError if the server name could not be parsed.
+ """
+ try:
+ if server_name[-1] == ']':
+ # ipv6 literal, hopefully
+ return server_name, None
+
+ domain_port = server_name.rsplit(":", 1)
+ domain = domain_port[0]
+ port = int(domain_port[1]) if domain_port[1:] else None
+ return domain, port
+ except Exception:
+ raise ValueError("Invalid server name '%s'" % server_name)
+
+
+VALID_HOST_REGEX = re.compile(
+ "\\A[0-9a-zA-Z.-]+\\Z",
+)
+
+
+def parse_and_validate_server_name(server_name):
+ """Split a server name into host/port parts and do some basic validation.
+
+ Args:
+ server_name (str): server name to parse
+
+ Returns:
+ Tuple[str, int|None]: host/port parts.
+
+ Raises:
+ ValueError if the server name could not be parsed.
+ """
+ host, port = parse_server_name(server_name)
+
+ # these tests don't need to be bulletproof as we'll find out soon enough
+ # if somebody is giving us invalid data. What we *do* need is to be sure
+ # that nobody is sneaking IP literals in that look like hostnames, etc.
+
+ # look for ipv6 literals
+ if host[0] == '[':
+ if host[-1] != ']':
+ raise ValueError("Mismatched [...] in server name '%s'" % (
+ server_name,
+ ))
+ return host, port
+
+ # otherwise it should only contain alphanumerics, dots and dashes.
+ if not VALID_HOST_REGEX.match(host):
+ raise ValueError("Server name '%s' contains invalid characters" % (
+ server_name,
+ ))
+
+ return host, port
+
+
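
Some worked examples of the split logic above (an rsplit on the last colon, with a carve-out for bracketed IPv6 literals):

    assert parse_server_name("matrix.org") == ("matrix.org", None)
    assert parse_server_name("matrix.org:8448") == ("matrix.org", 8448)
    # bracketed IPv6 literal without a port: returned whole, port is None
    assert parse_server_name("[2001:db8::1]") == ("[2001:db8::1]", None)
    assert parse_server_name("[2001:db8::1]:8448") == ("[2001:db8::1]", 8448)

    # the validating variant additionally rejects mismatched brackets
    try:
        parse_and_validate_server_name("[2001:db8::1")
        assert False, "should have raised"
    except ValueError:
        pass
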
def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
timeout=None):
"""Construct an endpoint for the given matrix destination.
@@ -48,9 +115,7 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
timeout (int): connection timeout in seconds
"""
- domain_port = destination.split(":")
- domain = domain_port[0]
- port = int(domain_port[1]) if domain_port[1:] else None
+ domain, port = parse_server_name(destination)
endpoint_kw_args = {}
@@ -72,21 +137,22 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
reactor, "matrix", domain, protocol="tcp",
default_port=default_port, endpoint=transport_endpoint,
endpoint_kw_args=endpoint_kw_args
- ))
+ ), reactor)
else:
return _WrappingEndpointFac(transport_endpoint(
reactor, domain, port, **endpoint_kw_args
- ))
+ ), reactor)
class _WrappingEndpointFac(object):
- def __init__(self, endpoint_fac):
+ def __init__(self, endpoint_fac, reactor):
self.endpoint_fac = endpoint_fac
+ self.reactor = reactor
@defer.inlineCallbacks
def connect(self, protocolFactory):
conn = yield self.endpoint_fac.connect(protocolFactory)
- conn = _WrappedConnection(conn)
+ conn = _WrappedConnection(conn, self.reactor)
defer.returnValue(conn)
@@ -96,9 +162,10 @@ class _WrappedConnection(object):
"""
__slots__ = ["conn", "last_request"]
- def __init__(self, conn):
+ def __init__(self, conn, reactor):
object.__setattr__(self, "conn", conn)
object.__setattr__(self, "last_request", time.time())
+ self._reactor = reactor
def __getattr__(self, name):
return getattr(self.conn, name)
@@ -113,10 +180,15 @@ class _WrappedConnection(object):
if time.time() - self.last_request >= 2.5 * 60:
self.abort()
# Abort the underlying TLS connection. The abort() method calls
- # loseConnection() on the underlying TLS connection which tries to
+ # loseConnection() on the TLS connection which tries to
# shutdown the connection cleanly. We call abortConnection()
- # since that will promptly close the underlying TCP connection.
- self.transport.abortConnection()
+ # since that will promptly close the TLS connection.
+ #
+ # In Twisted >18.4, the TLS connection will be None if it has closed,
+ # which will make abortConnection() throw. Check that the TLS connection
+ # is not None before trying to close it.
+ if self.transport.getHandle() is not None:
+ self.transport.abortConnection()
def request(self, request):
self.last_request = time.time()
@@ -124,14 +196,14 @@ class _WrappedConnection(object):
# Time this connection out if we haven't sent a request in the last
# N minutes
# TODO: Cancel the previous callLater?
- reactor.callLater(3 * 60, self._time_things_out_maybe)
+ self._reactor.callLater(3 * 60, self._time_things_out_maybe)
d = self.conn.request(request)
def update_request_time(res):
self.last_request = time.time()
# TODO: Cancel the previous callLater?
- reactor.callLater(3 * 60, self._time_things_out_maybe)
+ self._reactor.callLater(3 * 60, self._time_things_out_maybe)
return res
d.addCallback(update_request_time)
@@ -219,9 +291,10 @@ class SRVClientEndpoint(object):
return self.default_server
else:
raise ConnectError(
- "Not server available for %s" % self.service_name
+ "No server available for %s" % self.service_name
)
+ # look for all servers with the same priority
min_priority = self.servers[0].priority
weight_indexes = list(
(index, server.weight + 1)
@@ -231,11 +304,22 @@ class SRVClientEndpoint(object):
total_weight = sum(weight for index, weight in weight_indexes)
target_weight = random.randint(0, total_weight)
-
for index, weight in weight_indexes:
target_weight -= weight
if target_weight <= 0:
server = self.servers[index]
+ # XXX: this looks totally dubious:
+ #
+ # (a) we never reuse a server until we have been through
+ # all of the servers at the same priority, so if the
+ # weights are A: 100, B:1, we always do ABABAB instead of
+ # AAAA...AAAB (approximately).
+ #
+ # (b) After using all the servers at the lowest priority,
+ # we move onto the next priority. We should only use the
+ # second priority if servers at the top priority are
+ # unreachable.
+ #
del self.servers[index]
self.used_servers.append(server)
return server
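
For contrast with the behaviour flagged in the XXX comment, RFC 2782's weighted selection re-runs the pick over the full candidate list on every attempt instead of deleting servers as they are used. A sketch of the RFC-style pick (not what the code above does):

    import random

    def pick_server(servers):
        """servers: objects with a `weight` attribute, all at one priority."""
        total = sum(s.weight + 1 for s in servers)  # +1 lets zero weights win
        target = random.randint(0, total)
        for server in servers:
            target -= server.weight + 1
            if target <= 0:
                return server
        return servers[-1]  # unreachable in practice; kept as a guard
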
@@ -272,7 +356,7 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
if (len(answers) == 1
and answers[0].type == dns.SRV
and answers[0].payload
- and answers[0].payload.target == dns.Name('.')):
+ and answers[0].payload.target == dns.Name(b'.')):
raise ConnectError("Service %s unavailable" % service_name)
for answer in answers:
@@ -280,26 +364,14 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
continue
payload = answer.payload
- host = str(payload.target)
- srv_ttl = answer.ttl
-
- try:
- answers, _, _ = yield dns_client.lookupAddress(host)
- except DNSNameError:
- continue
- for answer in answers:
- if answer.type == dns.A and answer.payload:
- ip = answer.payload.dottedQuad()
- host_ttl = min(srv_ttl, answer.ttl)
-
- servers.append(_Server(
- host=ip,
- port=int(payload.port),
- priority=int(payload.priority),
- weight=int(payload.weight),
- expires=int(clock.time()) + host_ttl,
- ))
+ servers.append(_Server(
+ host=str(payload.target),
+ port=int(payload.port),
+ priority=int(payload.priority),
+ weight=int(payload.weight),
+ expires=int(clock.time()) + answer.ttl,
+ ))
servers.sort()
cache[service_name] = list(servers)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 747a791f83..bf1aa29502 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,48 +13,46 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse.util.retryutils
-from twisted.internet import defer, reactor, protocol
-from twisted.internet.error import DNSLookupError
-from twisted.web.client import readBody, HTTPConnectionPool, Agent
-from twisted.web.http_headers import Headers
-from twisted.web._newclient import ResponseDone
-
-from synapse.http.endpoint import matrix_federation_endpoint
-from synapse.util.async import sleep
-from synapse.util import logcontext
-import synapse.metrics
-
-from canonicaljson import encode_canonical_json
-
-from synapse.api.errors import (
- SynapseError, Codes, HttpResponseException,
-)
-
-from signedjson.sign import sign_json
-
import cgi
-import simplejson as json
import logging
import random
import sys
import urllib
-import urlparse
+from six import string_types
+from six.moves.urllib import parse as urlparse
-logger = logging.getLogger(__name__)
-outbound_logger = logging.getLogger("synapse.http.outbound")
+from canonicaljson import encode_canonical_json, json
+from prometheus_client import Counter
+from signedjson.sign import sign_json
-metrics = synapse.metrics.get_metrics_for(__name__)
+from twisted.internet import defer, protocol, reactor
+from twisted.internet.error import DNSLookupError
+from twisted.web._newclient import ResponseDone
+from twisted.web.client import Agent, HTTPConnectionPool, readBody
+from twisted.web.http_headers import Headers
-outgoing_requests_counter = metrics.register_counter(
- "requests",
- labels=["method"],
-)
-incoming_responses_counter = metrics.register_counter(
- "responses",
- labels=["method", "code"],
+import synapse.metrics
+import synapse.util.retryutils
+from synapse.api.errors import (
+ Codes,
+ FederationDeniedError,
+ HttpResponseException,
+ SynapseError,
)
+from synapse.http import cancelled_to_request_timed_out_error
+from synapse.http.endpoint import matrix_federation_endpoint
+from synapse.util import logcontext
+from synapse.util.async import add_timeout_to_deferred
+from synapse.util.logcontext import make_deferred_yieldable
+
+logger = logging.getLogger(__name__)
+outbound_logger = logging.getLogger("synapse.http.outbound")
+
+outgoing_requests_counter = Counter("synapse_http_matrixfederationclient_requests",
+ "", ["method"])
+incoming_responses_counter = Counter("synapse_http_matrixfederationclient_responses",
+ "", ["method", "code"])
MAX_LONG_RETRIES = 10
@@ -123,11 +122,22 @@ class MatrixFederationHttpClient(object):
Fails with ``HTTPRequestException``: if we get an HTTP response
code >= 300.
+
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
+
(May also fail with plenty of other Exceptions for things like DNS
failures, connection failures, SSL failures.)
"""
+ if (
+ self.hs.config.federation_domain_whitelist and
+ destination not in self.hs.config.federation_domain_whitelist
+ ):
+ raise FederationDeniedError(destination)
+
limiter = yield synapse.util.retryutils.get_retry_limiter(
destination,
self.clock,
@@ -173,21 +183,21 @@ class MatrixFederationHttpClient(object):
producer = body_callback(method, http_url_bytes, headers_dict)
try:
- def send_request():
- request_deferred = self.agent.request(
- method,
- url_bytes,
- Headers(headers_dict),
- producer
- )
-
- return self.clock.time_bound_deferred(
- request_deferred,
- time_out=timeout / 1000. if timeout else 60,
- )
-
- with logcontext.PreserveLoggingContext():
- response = yield send_request()
+ request_deferred = self.agent.request(
+ method,
+ url_bytes,
+ Headers(headers_dict),
+ producer
+ )
+ add_timeout_to_deferred(
+ request_deferred,
+ timeout / 1000. if timeout else 60,
+ self.hs.get_reactor(),
+ cancelled_to_request_timed_out_error,
+ )
+ response = yield make_deferred_yieldable(
+ request_deferred,
+ )
log_result = "%d %s" % (response.code, response.phrase,)
break
@@ -204,18 +214,15 @@ class MatrixFederationHttpClient(object):
raise
logger.warn(
- "{%s} Sending request failed to %s: %s %s: %s - %s",
+ "{%s} Sending request failed to %s: %s %s: %s",
txn_id,
destination,
method,
url_bytes,
- type(e).__name__,
_flatten_response_never_received(e),
)
- log_result = "%s - %s" % (
- type(e).__name__, _flatten_response_never_received(e),
- )
+ log_result = _flatten_response_never_received(e)
if retries_left and not timeout:
if long_retries:
@@ -227,7 +234,7 @@ class MatrixFederationHttpClient(object):
delay = min(delay, 2)
delay *= random.uniform(0.8, 1.4)
- yield sleep(delay)
+ yield self.clock.sleep(delay)
retries_left -= 1
else:
raise
@@ -253,14 +260,35 @@ class MatrixFederationHttpClient(object):
defer.returnValue(response)
def sign_request(self, destination, method, url_bytes, headers_dict,
- content=None):
+ content=None, destination_is=None):
+ """
+ Signs a request by adding an Authorization header to headers_dict
+ Args:
+ destination (bytes|None): The destination home server of the request.
+ May be None if the destination is an identity server, in which case
+ destination_is must be non-None.
+ method (bytes): The HTTP method of the request
+ url_bytes (bytes): The URI path of the request
+ headers_dict (dict): Dictionary of request headers to append to
+ content (bytes): The body of the request
+ destination_is (bytes): As 'destination', but if the destination is an
+ identity server
+
+ Returns:
+ None
+ """
request = {
"method": method,
"uri": url_bytes,
"origin": self.server_name,
- "destination": destination,
}
+ if destination is not None:
+ request["destination"] = destination
+
+ if destination_is is not None:
+ request["destination_is"] = destination_is
+
if content is not None:
request["content"] = content
@@ -278,7 +306,8 @@ class MatrixFederationHttpClient(object):
headers_dict[b"Authorization"] = auth_headers
@defer.inlineCallbacks
- def put_json(self, destination, path, data={}, json_data_callback=None,
+ def put_json(self, destination, path, args={}, data={},
+ json_data_callback=None,
long_retries=False, timeout=None,
ignore_backoff=False,
backoff_on_404=False):
@@ -288,6 +317,7 @@ class MatrixFederationHttpClient(object):
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
+ args (dict): query params
data (dict): A dict containing the data that will be used as
the request body. This will be encoded as JSON.
json_data_callback (callable): A callable returning the dict to
@@ -311,6 +341,9 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
"""
if not json_data_callback:
@@ -331,6 +364,7 @@ class MatrixFederationHttpClient(object):
path,
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
+ query_bytes=encode_query_args(args),
long_retries=long_retries,
timeout=timeout,
ignore_backoff=ignore_backoff,
@@ -347,7 +381,7 @@ class MatrixFederationHttpClient(object):
@defer.inlineCallbacks
def post_json(self, destination, path, data={}, long_retries=False,
- timeout=None, ignore_backoff=False):
+ timeout=None, ignore_backoff=False, args={}):
""" Sends the specifed json data using POST
Args:
@@ -362,6 +396,7 @@ class MatrixFederationHttpClient(object):
giving up. None indicates no timeout.
ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway.
+ args (dict): query params
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body.
@@ -371,6 +406,9 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
"""
def body_callback(method, url_bytes, headers_dict):
@@ -383,6 +421,7 @@ class MatrixFederationHttpClient(object):
destination,
"POST",
path,
+ query_bytes=encode_query_args(args),
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
long_retries=long_retries,
@@ -424,16 +463,12 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
"""
logger.debug("get_json args: %s", args)
- encoded_args = {}
- for k, vs in args.items():
- if isinstance(vs, basestring):
- vs = [vs]
- encoded_args[k] = [v.encode("UTF-8") for v in vs]
-
- query_bytes = urllib.urlencode(encoded_args, True)
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
def body_callback(method, url_bytes, headers_dict):
@@ -444,7 +479,7 @@ class MatrixFederationHttpClient(object):
destination,
"GET",
path,
- query_bytes=query_bytes,
+ query_bytes=encode_query_args(args),
body_callback=body_callback,
retry_on_dns_fail=retry_on_dns_fail,
timeout=timeout,
@@ -461,6 +496,55 @@ class MatrixFederationHttpClient(object):
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
+ def delete_json(self, destination, path, long_retries=False,
+ timeout=None, ignore_backoff=False, args={}):
+ """Send a DELETE request to the remote expecting some json response
+
+ Args:
+ destination (str): The remote server to send the HTTP request
+ to.
+ path (str): The HTTP path.
+ long_retries (bool): A boolean that indicates whether we should
+ retry for a short or long time.
+ timeout (int): How long (in ms) to try the destination for before
+ giving up. None indicates no timeout.
+ ignore_backoff (bool): true to ignore the historical backoff data and
+ try the request anyway.
+ Returns:
+ Deferred: Succeeds when we get a 2xx HTTP response. The result
+ will be the decoded JSON body.
+
+ Fails with ``HTTPRequestException`` if we get an HTTP response
+ code >= 300.
+
+ Fails with ``NotRetryingDestination`` if we are not yet ready
+ to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
+ """
+
+ response = yield self._request(
+ destination,
+ "DELETE",
+ path,
+ query_bytes=encode_query_args(args),
+ headers_dict={"Content-Type": ["application/json"]},
+ long_retries=long_retries,
+ timeout=timeout,
+ ignore_backoff=ignore_backoff,
+ )
+
+ if 200 <= response.code < 300:
+ # We need to update the transactions table to say it was sent?
+ check_content_type_is_json(response.headers)
+
+ with logcontext.PreserveLoggingContext():
+ body = yield readBody(response)
+
+ defer.returnValue(json.loads(body))
+
+ @defer.inlineCallbacks
def get_file(self, destination, path, output_stream, args={},
retry_on_dns_fail=True, max_size=None,
ignore_backoff=False):
@@ -481,11 +565,14 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
"""
encoded_args = {}
for k, vs in args.items():
- if isinstance(vs, basestring):
+ if isinstance(vs, string_types):
vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs]
@@ -513,7 +600,7 @@ class MatrixFederationHttpClient(object):
length = yield _readBodyToFile(
response, output_stream, max_size
)
- except:
+ except Exception:
logger.exception("Failed to download body")
raise
@@ -578,12 +665,14 @@ class _JsonProducer(object):
def _flatten_response_never_received(e):
if hasattr(e, "reasons"):
- return ", ".join(
+ reasons = ", ".join(
_flatten_response_never_received(f.value)
for f in e.reasons
)
+
+ return "%s:[%s]" % (type(e).__name__, reasons)
else:
- return "%s: %s" % (type(e).__name__, e.message,)
+ return repr(e)
def check_content_type_is_json(headers):
@@ -598,7 +687,7 @@ def check_content_type_is_json(headers):
RuntimeError if the
"""
- c_type = headers.getRawHeaders("Content-Type")
+ c_type = headers.getRawHeaders(b"Content-Type")
if c_type is None:
raise RuntimeError(
"No Content-Type header"
@@ -610,3 +699,15 @@ def check_content_type_is_json(headers):
raise RuntimeError(
"Content-Type not application/json: was '%s'" % c_type
)
+
+
+def encode_query_args(args):
+ encoded_args = {}
+ for k, vs in args.items():
+ if isinstance(vs, string_types):
+ vs = [vs]
+ encoded_args[k] = [v.encode("UTF-8") for v in vs]
+
+ query_bytes = urllib.urlencode(encoded_args, True)
+
+ return query_bytes
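
encode_query_args accepts either a single string or a list of strings per key; on the Python 2 codebase this diff targets, urllib.urlencode with doseq=True expands list values into repeated parameters. Illustrative behaviour:

    assert encode_query_args({"limit": "10"}) == "limit=10"
    # list values become repeated query parameters thanks to doseq=True
    assert encode_query_args({"event_id": ["$a:hs", "$b:hs"]}) == \
        "event_id=%24a%3Ahs&event_id=%24b%3Ahs"
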
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
new file mode 100644
index 0000000000..588e280571
--- /dev/null
+++ b/synapse/http/request_metrics.py
@@ -0,0 +1,231 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from prometheus_client.core import Counter, Histogram
+
+from synapse.metrics import LaterGauge
+from synapse.util.logcontext import LoggingContext
+
+logger = logging.getLogger(__name__)
+
+
+# total number of responses served, split by method/servlet/tag
+response_count = Counter(
+ "synapse_http_server_response_count", "", ["method", "servlet", "tag"]
+)
+
+requests_counter = Counter(
+ "synapse_http_server_requests_received", "", ["method", "servlet"]
+)
+
+outgoing_responses_counter = Counter(
+ "synapse_http_server_responses", "", ["method", "code"]
+)
+
+response_timer = Histogram(
+ "synapse_http_server_response_time_seconds", "sec",
+ ["method", "servlet", "tag", "code"],
+)
+
+response_ru_utime = Counter(
+ "synapse_http_server_response_ru_utime_seconds", "sec", ["method", "servlet", "tag"]
+)
+
+response_ru_stime = Counter(
+ "synapse_http_server_response_ru_stime_seconds", "sec", ["method", "servlet", "tag"]
+)
+
+response_db_txn_count = Counter(
+ "synapse_http_server_response_db_txn_count", "", ["method", "servlet", "tag"]
+)
+
+# seconds spent waiting for db txns, excluding scheduling time, when processing
+# this request
+response_db_txn_duration = Counter(
+ "synapse_http_server_response_db_txn_duration_seconds",
+ "",
+ ["method", "servlet", "tag"],
+)
+
+# seconds spent waiting for a db connection, when processing this request
+response_db_sched_duration = Counter(
+ "synapse_http_server_response_db_sched_duration_seconds",
+ "",
+ ["method", "servlet", "tag"],
+)
+
+# size in bytes of the response written
+response_size = Counter(
+ "synapse_http_server_response_size", "", ["method", "servlet", "tag"]
+)
+
+# In flight metrics are incremented while the requests are in flight, rather
+# than when the response is written.
+
+in_flight_requests_ru_utime = Counter(
+ "synapse_http_server_in_flight_requests_ru_utime_seconds",
+ "",
+ ["method", "servlet"],
+)
+
+in_flight_requests_ru_stime = Counter(
+ "synapse_http_server_in_flight_requests_ru_stime_seconds",
+ "",
+ ["method", "servlet"],
+)
+
+in_flight_requests_db_txn_count = Counter(
+ "synapse_http_server_in_flight_requests_db_txn_count", "", ["method", "servlet"]
+)
+
+# seconds spent waiting for db txns, excluding scheduling time, when processing
+# this request
+in_flight_requests_db_txn_duration = Counter(
+ "synapse_http_server_in_flight_requests_db_txn_duration_seconds",
+ "",
+ ["method", "servlet"],
+)
+
+# seconds spent waiting for a db connection, when processing this request
+in_flight_requests_db_sched_duration = Counter(
+ "synapse_http_server_in_flight_requests_db_sched_duration_seconds",
+ "",
+ ["method", "servlet"],
+)
+
+# The set of all in flight requests, set[RequestMetrics]
+_in_flight_requests = set()
+
+
+def _get_in_flight_counts():
+ """Returns a count of all in flight requests by (method, server_name)
+
+ Returns:
+ dict[tuple[str, str], int]
+ """
+ # Cast to a list to prevent it changing while the Prometheus
+ # thread is collecting metrics
+ reqs = list(_in_flight_requests)
+
+ for rm in reqs:
+ rm.update_metrics()
+
+ # Map from (method, name) -> int, the number of in flight requests of that
+ # type
+ counts = {}
+ for rm in reqs:
+ key = (rm.method, rm.name,)
+ counts[key] = counts.get(key, 0) + 1
+
+ return counts
+
+
+LaterGauge(
+ "synapse_http_server_in_flight_requests_count",
+ "",
+ ["method", "servlet"],
+ _get_in_flight_counts,
+)
+
+
+class RequestMetrics(object):
+ def start(self, time_sec, name, method):
+ self.start = time_sec
+ self.start_context = LoggingContext.current_context()
+ self.name = name
+ self.method = method
+
+ # _request_stats records resource usage that we have already added
+ # to the "in flight" metrics.
+ self._request_stats = self.start_context.get_resource_usage()
+
+ _in_flight_requests.add(self)
+
+ def stop(self, time_sec, request):
+ _in_flight_requests.discard(self)
+
+ context = LoggingContext.current_context()
+
+ tag = ""
+ if context:
+ tag = context.tag
+
+ if context != self.start_context:
+ logger.warn(
+ "Context have unexpectedly changed %r, %r",
+ context, self.start_context
+ )
+ return
+
+ response_code = str(request.code)
+
+ outgoing_responses_counter.labels(request.method, response_code).inc()
+
+ response_count.labels(request.method, self.name, tag).inc()
+
+ response_timer.labels(request.method, self.name, tag, response_code).observe(
+ time_sec - self.start
+ )
+
+ resource_usage = context.get_resource_usage()
+
+ response_ru_utime.labels(request.method, self.name, tag).inc(
+ resource_usage.ru_utime,
+ )
+ response_ru_stime.labels(request.method, self.name, tag).inc(
+ resource_usage.ru_stime,
+ )
+ response_db_txn_count.labels(request.method, self.name, tag).inc(
+ resource_usage.db_txn_count
+ )
+ response_db_txn_duration.labels(request.method, self.name, tag).inc(
+ resource_usage.db_txn_duration_sec
+ )
+ response_db_sched_duration.labels(request.method, self.name, tag).inc(
+ resource_usage.db_sched_duration_sec
+ )
+
+ response_size.labels(request.method, self.name, tag).inc(request.sentLength)
+
+ # We always call this at the end to ensure that we update the metrics
+ # regardless of whether there was a call to /metrics while the request
+ # was in flight.
+ self.update_metrics()
+
+ def update_metrics(self):
+ """Updates the in flight metrics with values from this request.
+ """
+ new_stats = self.start_context.get_resource_usage()
+
+ diff = new_stats - self._request_stats
+ self._request_stats = new_stats
+
+ in_flight_requests_ru_utime.labels(self.method, self.name).inc(diff.ru_utime)
+ in_flight_requests_ru_stime.labels(self.method, self.name).inc(diff.ru_stime)
+
+ in_flight_requests_db_txn_count.labels(self.method, self.name).inc(
+ diff.db_txn_count
+ )
+
+ in_flight_requests_db_txn_duration.labels(self.method, self.name).inc(
+ diff.db_txn_duration_sec
+ )
+
+ in_flight_requests_db_sched_duration.labels(self.method, self.name).inc(
+ diff.db_sched_duration_sec
+ )
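The in-flight counters above only stay correct because `update_metrics` reports deltas: it is safe to call it during a `/metrics` scrape and again at `stop()` without double counting. A minimal sketch of the pattern, with invented names:

    class DeltaReporter(object):
        """Report only the not-yet-reported delta of a cumulative stat."""

        def __init__(self, get_cumulative, inc_counter):
            self._get = get_cumulative   # e.g. returns total CPU seconds
            self._inc = inc_counter      # e.g. a prometheus Counter's .inc
            self._reported = get_cumulative()

        def update(self):
            new = self._get()
            self._inc(new - self._reported)  # add only the unseen part
            self._reported = new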
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 7ef3d526b1..c70fdbdfd2 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,148 +13,205 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import cgi
+import collections
+import logging
+import urllib
+from six.moves import http_client
-from synapse.api.errors import (
- cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError, Codes
-)
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
-from synapse.util.caches import intern_dict
-from synapse.util.metrics import Measure
-import synapse.metrics
-import synapse.events
-
-from canonicaljson import (
- encode_canonical_json, encode_pretty_printed_json
-)
+from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
from twisted.internet import defer
-from twisted.web import server, resource
+from twisted.python import failure
+from twisted.web import resource, server
from twisted.web.server import NOT_DONE_YET
from twisted.web.util import redirectTo
-import collections
-import logging
-import urllib
-import ujson
+import synapse.events
+import synapse.metrics
+from synapse.api.errors import (
+ CodeMessageException,
+ Codes,
+ SynapseError,
+ UnrecognizedRequestError,
+ cs_exception,
+)
+from synapse.http.request_metrics import requests_counter
+from synapse.util.caches import intern_dict
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
-metrics = synapse.metrics.get_metrics_for(__name__)
+HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
+<html lang=en>
+ <head>
+ <meta charset="utf-8">
+ <title>Error {code}</title>
+ </head>
+ <body>
+ <p>{msg}</p>
+ </body>
+</html>
+"""
-incoming_requests_counter = metrics.register_counter(
- "requests",
- labels=["method", "servlet", "tag"],
-)
-outgoing_responses_counter = metrics.register_counter(
- "responses",
- labels=["method", "code"],
-)
-response_timer = metrics.register_distribution(
- "response_time",
- labels=["method", "servlet", "tag"]
-)
+def wrap_json_request_handler(h):
+ """Wraps a request handler method with exception handling.
-response_ru_utime = metrics.register_distribution(
- "response_ru_utime", labels=["method", "servlet", "tag"]
-)
+ Also adds logging as per wrap_request_handler_with_logging.
-response_ru_stime = metrics.register_distribution(
- "response_ru_stime", labels=["method", "servlet", "tag"]
-)
+ The handler method must have a signature of "handle_foo(self, request)",
+ where "self" must have a "clock" attribute (and "request" must be a
+ SynapseRequest).
-response_db_txn_count = metrics.register_distribution(
- "response_db_txn_count", labels=["method", "servlet", "tag"]
-)
+ The handler must return a deferred. If the deferred succeeds we assume that
+ a response has been sent. If the deferred fails with a SynapseError we use
+ it to send a JSON response with the appropriate HTTP response code. If the
+ deferred fails with any other type of error we send a 500 response.
+ """
-response_db_txn_duration = metrics.register_distribution(
- "response_db_txn_duration", labels=["method", "servlet", "tag"]
-)
+ @defer.inlineCallbacks
+ def wrapped_request_handler(self, request):
+ try:
+ yield h(self, request)
+ except CodeMessageException as e:
+ code = e.code
+ if isinstance(e, SynapseError):
+ logger.info(
+ "%s SynapseError: %s - %s", request, code, e.msg
+ )
+ else:
+ logger.exception(e)
+ respond_with_json(
+ request, code, cs_exception(e), send_cors=True,
+ pretty_print=_request_user_agent_is_curl(request),
+ )
+ except Exception:
+ # failure.Failure() fishes the original Failure out
+ # of our stack, and thus gives us a sensible stack
+ # trace.
+ f = failure.Failure()
+ logger.error(
+ "Failed handle request via %r: %r: %s",
+ h,
+ request,
+ f.getTraceback().rstrip(),
+ )
+ respond_with_json(
+ request,
+ 500,
+ {
+ "error": "Internal server error",
+ "errcode": Codes.UNKNOWN,
+ },
+ send_cors=True,
+ pretty_print=_request_user_agent_is_curl(request),
+ )
-_next_request_id = 0
+ return wrap_request_handler_with_logging(wrapped_request_handler)
-def request_handler(include_metrics=False):
- """Decorator for ``wrap_request_handler``"""
- return lambda request_handler: wrap_request_handler(request_handler, include_metrics)
+def wrap_html_request_handler(h):
+ """Wraps a request handler method with exception handling.
+ Also adds logging as per wrap_request_handler_with_logging.
-def wrap_request_handler(request_handler, include_metrics=False):
- """Wraps a method that acts as a request handler with the necessary logging
- and exception handling.
+ The handler method must have a signature of "handle_foo(self, request)",
+ where "self" must have a "clock" attribute (and "request" must be a
+ SynapseRequest).
+ """
+ def wrapped_request_handler(self, request):
+ d = defer.maybeDeferred(h, self, request)
+ d.addErrback(_return_html_error, request)
+ return d
- The method must have a signature of "handle_foo(self, request)". The
- argument "self" must have "version_string" and "clock" attributes. The
- argument "request" must be a twisted HTTP request.
+ return wrap_request_handler_with_logging(wrapped_request_handler)
- The method must return a deferred. If the deferred succeeds we assume that
- a response has been sent. If the deferred fails with a SynapseError we use
- it to send a JSON response with the appropriate HTTP reponse code. If the
- deferred fails with any other type of error we send a 500 reponse.
- We insert a unique request-id into the logging context for this request and
- log the response and duration for this request.
+def _return_html_error(f, request):
+ """Sends an HTML error page corresponding to the given failure
+
+ Args:
+ f (twisted.python.failure.Failure):
+ request (twisted.web.iweb.IRequest):
"""
+ if f.check(CodeMessageException):
+ cme = f.value
+ code = cme.code
+ msg = cme.msg
+
+ if isinstance(cme, SynapseError):
+ logger.info(
+ "%s SynapseError: %s - %s", request, code, msg
+ )
+ else:
+ logger.error(
+ "Failed handle request %r: %s",
+ request,
+ f.getTraceback().rstrip(),
+ )
+ else:
+ code = http_client.INTERNAL_SERVER_ERROR
+ msg = "Internal server error"
+
+ logger.error(
+ "Failed handle request %r: %s",
+ request,
+ f.getTraceback().rstrip(),
+ )
+
+ body = HTML_ERROR_TEMPLATE.format(
+ code=code, msg=cgi.escape(msg),
+ ).encode("utf-8")
+ request.setResponseCode(code)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Content-Length", b"%i" % (len(body),))
+ request.write(body)
+ finish_request(request)
+
+def wrap_request_handler_with_logging(h):
+ """Wraps a request handler to provide logging and metrics
+
+ The handler method must have a signature of "handle_foo(self, request)",
+ where "self" must have a "clock" attribute (and "request" must be a
+ SynapseRequest).
+
+ As well as calling `request.processing` (which will log the response and
+ duration for this request), the wrapped request handler will insert the
+ request id into the logging context.
+ """
@defer.inlineCallbacks
def wrapped_request_handler(self, request):
- global _next_request_id
- request_id = "%s-%s" % (request.method, _next_request_id)
- _next_request_id += 1
+ """
+ Args:
+ self:
+ request (synapse.http.site.SynapseRequest):
+ """
+ request_id = request.get_request_id()
with LoggingContext(request_id) as request_context:
+ request_context.request = request_id
with Measure(self.clock, "wrapped_request_handler"):
- request_metrics = RequestMetrics()
- request_metrics.start(self.clock, name=self.__class__.__name__)
-
- request_context.request = request_id
- with request.processing():
- try:
- with PreserveLoggingContext(request_context):
- if include_metrics:
- yield request_handler(self, request, request_metrics)
- else:
- yield request_handler(self, request)
- except CodeMessageException as e:
- code = e.code
- if isinstance(e, SynapseError):
- logger.info(
- "%s SynapseError: %s - %s", request, code, e.msg
- )
- else:
- logger.exception(e)
- outgoing_responses_counter.inc(request.method, str(code))
- respond_with_json(
- request, code, cs_exception(e), send_cors=True,
- pretty_print=_request_user_agent_is_curl(request),
- version_string=self.version_string,
- )
- except:
- logger.exception(
- "Failed handle request %s.%s on %r: %r",
- request_handler.__module__,
- request_handler.__name__,
- self,
- request
- )
- respond_with_json(
- request,
- 500,
- {
- "error": "Internal server error",
- "errcode": Codes.UNKNOWN,
- },
- send_cors=True
- )
- finally:
- try:
- request_metrics.stop(
- self.clock, request
- )
- except Exception as e:
- logger.warn("Failed to stop metrics: %r", e)
+ # we start the request metrics timer here with an initial stab
+ # at the servlet name. For most requests that name will be
+ # JsonResource (or a subclass), and JsonResource._async_render
+ # will update it once it picks a servlet.
+ servlet_name = self.__class__.__name__
+ with request.processing(servlet_name):
+ with PreserveLoggingContext(request_context):
+ d = defer.maybeDeferred(h, self, request)
+
+ # record the arrival of the request *after*
+ # dispatching to the handler, so that the handler
+ # can update the servlet name in the request
+ # metrics
+ requests_counter.labels(request.method,
+ request.request_metrics.name).inc()
+ yield d
return wrapped_request_handler
@@ -183,7 +241,7 @@ class JsonResource(HttpServer, resource.Resource):
""" This implements the HttpServer interface and provides JSON support for
Resources.
- Register callbacks via register_path()
+ Register callbacks via register_paths()
Callbacks can return a tuple of status code and a dict in which case the
dict will automatically be sent to the client as a JSON object.
@@ -203,7 +261,6 @@ class JsonResource(HttpServer, resource.Resource):
self.canonical_json = canonical_json
self.clock = hs.get_clock()
self.path_regexs = {}
- self.version_string = hs.version_string
self.hs = hs
def register_paths(self, method, path_patterns, callback):
@@ -219,122 +276,103 @@ class JsonResource(HttpServer, resource.Resource):
self._async_render(request)
return server.NOT_DONE_YET
- # Disable metric reporting because _async_render does its own metrics.
- # It does its own metric reporting because _async_render dispatches to
- # a callback and it's the class name of that callback we want to report
- # against rather than the JsonResource itself.
- @request_handler(include_metrics=True)
+ @wrap_json_request_handler
@defer.inlineCallbacks
- def _async_render(self, request, request_metrics):
+ def _async_render(self, request):
""" This gets called from render() every time someone sends us a request.
This checks if anyone has registered a callback for that method and
path.
"""
- if request.method == "OPTIONS":
- self._send_response(request, 200, {})
- return
+ callback, group_dict = self._get_handler_for_request(request)
- # Loop through all the registered callbacks to check if the method
- # and path regex match
- for path_entry in self.path_regexs.get(request.method, []):
- m = path_entry.pattern.match(request.path)
- if not m:
- continue
+ servlet_instance = getattr(callback, "__self__", None)
+ if servlet_instance is not None:
+ servlet_classname = servlet_instance.__class__.__name__
+ else:
+ servlet_classname = "%r" % callback
+ request.request_metrics.name = servlet_classname
- # We found a match! Trigger callback and then return the
- # returned response. We pass both the request and any
- # matched groups from the regex to the callback.
+ # Now trigger the callback. If it returns a response, we send it
+ # here. If it throws an exception, that is handled by the wrapper
+ # installed by @wrap_json_request_handler.
- callback = path_entry.callback
+ kwargs = intern_dict({
+ name: urllib.unquote(value).decode("UTF-8") if value else value
+ for name, value in group_dict.items()
+ })
- kwargs = intern_dict({
- name: urllib.unquote(value).decode("UTF-8") if value else value
- for name, value in m.groupdict().items()
- })
+ callback_return = yield callback(request, **kwargs)
+ if callback_return is not None:
+ code, response = callback_return
+ self._send_response(request, code, response)
- callback_return = yield callback(request, **kwargs)
- if callback_return is not None:
- code, response = callback_return
- self._send_response(request, code, response)
+ def _get_handler_for_request(self, request):
+ """Finds a callback method to handle the given request
- servlet_instance = getattr(callback, "__self__", None)
- if servlet_instance is not None:
- servlet_classname = servlet_instance.__class__.__name__
- else:
- servlet_classname = "%r" % callback
+ Args:
+ request (twisted.web.http.Request):
- request_metrics.name = servlet_classname
+ Returns:
+ Tuple[Callable, dict[str, str]]: callback method, and the dict
+ mapping keys to path components as specified in the handler's
+ path match regexp.
- return
+ The callback will normally be a method registered via
+ register_paths, so will return (possibly via Deferred) either
+ None, or a tuple of (http code, response body).
+ """
+ if request.method == b"OPTIONS":
+ return _options_handler, {}
+
+ # Loop through all the registered callbacks to check if the method
+ # and path regex match
+ for path_entry in self.path_regexs.get(request.method, []):
+ m = path_entry.pattern.match(request.path)
+ if m:
+ # We found a match!
+ return path_entry.callback, m.groupdict()
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
- raise UnrecognizedRequestError()
+ return _unrecognised_request_handler, {}
def _send_response(self, request, code, response_json_object,
response_code_message=None):
- # could alternatively use request.notifyFinish() and flip a flag when
- # the Deferred fires, but since the flag is RIGHT THERE it seems like
- # a waste.
- if request._disconnected:
- logger.warn(
- "Not sending response to request %s, already disconnected.",
- request)
- return
-
- outgoing_responses_counter.inc(request.method, str(code))
-
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
request, code, response_json_object,
send_cors=True,
response_code_message=response_code_message,
pretty_print=_request_user_agent_is_curl(request),
- version_string=self.version_string,
canonical_json=self.canonical_json,
)
-class RequestMetrics(object):
- def start(self, clock, name):
- self.start = clock.time_msec()
- self.start_context = LoggingContext.current_context()
- self.name = name
+def _options_handler(request):
+ """Request handler for OPTIONS requests
- def stop(self, clock, request):
- context = LoggingContext.current_context()
+ This is a request handler suitable for return from
+ _get_handler_for_request. It returns a 200 and an empty body.
- tag = ""
- if context:
- tag = context.tag
+ Args:
+ request (twisted.web.http.Request):
- if context != self.start_context:
- logger.warn(
- "Context have unexpectedly changed %r, %r",
- context, self.start_context
- )
- return
+ Returns:
+ Tuple[int, dict]: http code, response body.
+ """
+ return 200, {}
- incoming_requests_counter.inc(request.method, self.name, tag)
- response_timer.inc_by(
- clock.time_msec() - self.start, request.method,
- self.name, tag
- )
+def _unrecognised_request_handler(request):
+ """Request handler for unrecognised requests
- ru_utime, ru_stime = context.get_resource_usage()
+ This is a request handler suitable for return from
+ _get_handler_for_request. It actually just raises an
+ UnrecognizedRequestError.
- response_ru_utime.inc_by(
- ru_utime, request.method, self.name, tag
- )
- response_ru_stime.inc_by(
- ru_stime, request.method, self.name, tag
- )
- response_db_txn_count.inc_by(
- context.db_txn_count, request.method, self.name, tag
- )
- response_db_txn_duration.inc_by(
- context.db_txn_duration, request.method, self.name, tag
- )
+ Args:
+ request (twisted.web.http.Request):
+ """
+ raise UnrecognizedRequestError()
class RootRedirect(resource.Resource):
@@ -355,26 +393,33 @@ class RootRedirect(resource.Resource):
def respond_with_json(request, code, json_object, send_cors=False,
response_code_message=None, pretty_print=False,
- version_string="", canonical_json=True):
+ canonical_json=True):
+ # could alternatively use request.notifyFinish() and flip a flag when
+ # the Deferred fires, but since the flag is RIGHT THERE it seems like
+ # a waste.
+ if request._disconnected:
+ logger.warn(
+ "Not sending response to request %s, already disconnected.",
+ request)
+ return
+
if pretty_print:
json_bytes = encode_pretty_printed_json(json_object) + "\n"
else:
if canonical_json or synapse.events.USE_FROZEN_DICTS:
json_bytes = encode_canonical_json(json_object)
else:
- # ujson doesn't like frozen_dicts.
- json_bytes = ujson.dumps(json_object, ensure_ascii=False)
+ json_bytes = json.dumps(json_object)
return respond_with_json_bytes(
request, code, json_bytes,
send_cors=send_cors,
response_code_message=response_code_message,
- version_string=version_string
)
def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
- version_string="", response_code_message=None):
+ response_code_message=None):
"""Sends encoded JSON in response to the given request.
Args:
@@ -388,8 +433,8 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
request.setResponseCode(code, message=response_code_message)
request.setHeader(b"Content-Type", b"application/json")
- request.setHeader(b"Server", version_string)
request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),))
+ request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")
if send_cors:
set_cors_headers(request)
@@ -437,9 +482,9 @@ def finish_request(request):
def _request_user_agent_is_curl(request):
user_agents = request.requestHeaders.getRawHeaders(
- "User-Agent", default=[]
+ b"User-Agent", default=[]
)
for user_agent in user_agents:
- if "curl" in user_agent:
+ if b"curl" in user_agent:
return True
return False
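For orientation, a hedged sketch of how a servlet plugs into the reworked `JsonResource` (the servlet and path pattern are invented; real Synapse servlets build on `RestServlet`): a callback registered via `register_paths` may return a `(code, dict)` tuple, which `_async_render` sends as the JSON response.

    import re

    from twisted.internet import defer


    class HelloServlet(object):
        PATTERN = re.compile("^/hello/(?P<name>[^/]*)$")

        def register(self, json_resource):
            json_resource.register_paths("GET", [self.PATTERN], self.on_GET)

        @defer.inlineCallbacks
        def on_GET(self, request, name):
            yield defer.succeed(None)  # stand-in for real async work
            # (code, dict) -> JsonResource sends it as a JSON body
            defer.returnValue((200, {"hello": name}))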
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 9a4c36ad5d..882816dc8f 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -15,10 +15,11 @@
""" This module contains base REST classes for constructing REST servlets. """
-from synapse.api.errors import SynapseError, Codes
-
import logging
-import simplejson
+
+from canonicaljson import json
+
+from synapse.api.errors import Codes, SynapseError
logger = logging.getLogger(__name__)
@@ -48,7 +49,7 @@ def parse_integer_from_args(args, name, default=None, required=False):
if name in args:
try:
return int(args[name][0])
- except:
+ except Exception:
message = "Query parameter %r must be an integer" % (name,)
raise SynapseError(400, message)
else:
@@ -88,7 +89,7 @@ def parse_boolean_from_args(args, name, default=None, required=False):
"true": True,
"false": False,
}[args[name][0]]
- except:
+ except Exception:
message = (
"Boolean query parameter %r must be one of"
" ['true', 'false']"
@@ -148,11 +149,13 @@ def parse_string_from_args(args, name, default=None, required=False,
return default
-def parse_json_value_from_request(request):
+def parse_json_value_from_request(request, allow_empty_body=False):
"""Parse a JSON value from the body of a twisted HTTP request.
Args:
request: the twisted HTTP request.
+ allow_empty_body (bool): if True, an empty body will be accepted and
+ turned into None
Returns:
The JSON value.
@@ -162,28 +165,39 @@ def parse_json_value_from_request(request):
"""
try:
content_bytes = request.content.read()
- except:
+ except Exception:
raise SynapseError(400, "Error reading JSON content.")
+ if not content_bytes and allow_empty_body:
+ return None
+
try:
- content = simplejson.loads(content_bytes)
- except simplejson.JSONDecodeError:
+ content = json.loads(content_bytes)
+ except Exception as e:
+ logger.warn("Unable to parse JSON: %s", e)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
return content
-def parse_json_object_from_request(request):
+def parse_json_object_from_request(request, allow_empty_body=False):
"""Parse a JSON object from the body of a twisted HTTP request.
Args:
request: the twisted HTTP request.
+ allow_empty_body (bool): if True, an empty body will be accepted and
+ turned into an empty dict.
Raises:
SynapseError if the request body couldn't be decoded as JSON or
if it wasn't a JSON object.
"""
- content = parse_json_value_from_request(request)
+ content = parse_json_value_from_request(
+ request, allow_empty_body=allow_empty_body,
+ )
+
+ if allow_empty_body and content is None:
+ return {}
if type(content) != dict:
message = "Content must be a JSON object."
@@ -192,7 +206,7 @@ def parse_json_object_from_request(request):
return content
-def assert_params_in_request(body, required):
+def assert_params_in_dict(body, required):
absent = []
for k in required:
if k not in body:
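A hedged usage sketch of the new `allow_empty_body` flag together with the renamed `assert_params_in_dict` (the servlet is invented): with the flag set, a request with no body at all parses to an empty dict instead of failing with a 400.

    from synapse.http.servlet import (
        assert_params_in_dict,
        parse_json_object_from_request,
    )


    class DeactivateServlet(object):  # illustrative
        def on_POST(self, request):
            # {} when the client sent no body; malformed JSON is still a 400
            body = parse_json_object_from_request(request, allow_empty_body=True)
            if body:
                assert_params_in_dict(body, ["erase"])
            return 200, {}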
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 4b09d7ee66..5fd30a4c2c 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -12,27 +12,50 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.logcontext import LoggingContext
-from twisted.web.server import Site, Request
-
import contextlib
import logging
-import re
import time
-ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
+from twisted.web.server import Request, Site
+
+from synapse.http import redact_uri
+from synapse.http.request_metrics import RequestMetrics
+from synapse.util.logcontext import ContextResourceUsage, LoggingContext
+
+logger = logging.getLogger(__name__)
+
+_next_request_seq = 0
class SynapseRequest(Request):
- def __init__(self, site, *args, **kw):
- Request.__init__(self, *args, **kw)
+ """Class which encapsulates an HTTP request to synapse.
+
+ All of the requests processed in synapse are of this type.
+
+ It extends twisted's twisted.web.server.Request, and adds:
+ * Unique request ID
+ * Redaction of access_token query-params in __repr__
+ * Logging at start and end
+ * Metrics to record CPU, wallclock and DB time by endpoint.
+
+ It provides a method `processing` which should be called by the Resource
+ which is handling the request, and returns a context manager.
+
+ """
+ def __init__(self, site, channel, *args, **kw):
+ Request.__init__(self, channel, *args, **kw)
self.site = site
+ self._channel = channel
self.authenticated_entity = None
self.start_time = 0
+ global _next_request_seq
+ self.request_seq = _next_request_seq
+ _next_request_seq += 1
+
def __repr__(self):
# We overwrite this so that we don't log ``access_token``
- return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
+ return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % (
self.__class__.__name__,
id(self),
self.method,
@@ -41,16 +64,27 @@ class SynapseRequest(Request):
self.site.site_tag,
)
+ def get_request_id(self):
+ return "%s-%i" % (self.method, self.request_seq)
+
def get_redacted_uri(self):
- return ACCESS_TOKEN_RE.sub(
- r'\1<redacted>\3',
- self.uri
- )
+ return redact_uri(self.uri)
def get_user_agent(self):
- return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1]
+ return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
+
+ def render(self, resrc):
+ # override the Server header which is set by twisted
+ self.setHeader("Server", self.site.server_version_string)
+ return Request.render(self, resrc)
+
+ def _started_processing(self, servlet_name):
+ self.start_time = time.time()
+ self.request_metrics = RequestMetrics()
+ self.request_metrics.start(
+ self.start_time, name=servlet_name, method=self.method,
+ )
- def started_processing(self):
self.site.access_logger.info(
"%s - %s - Received request: %s %s",
self.getClientIP(),
@@ -58,44 +92,85 @@ class SynapseRequest(Request):
self.method,
self.get_redacted_uri()
)
- self.start_time = int(time.time() * 1000)
-
- def finished_processing(self):
+ def _finished_processing(self):
try:
context = LoggingContext.current_context()
- ru_utime, ru_stime = context.get_resource_usage()
- db_txn_count = context.db_txn_count
- db_txn_duration = context.db_txn_duration
- except:
- ru_utime, ru_stime = (0, 0)
- db_txn_count, db_txn_duration = (0, 0)
+ usage = context.get_resource_usage()
+ except Exception:
+ usage = ContextResourceUsage()
+
+ end_time = time.time()
+
+ # need to decode as it could be raw utf-8 bytes
+ # from an IDN servname in an auth header
+ authenticated_entity = self.authenticated_entity
+ if authenticated_entity is not None:
+ authenticated_entity = authenticated_entity.decode("utf-8", "replace")
+
+ # ...or could be raw utf-8 bytes in the User-Agent header.
+ # N.B. if you don't do this, the logger explodes cryptically
+ # with maximum recursion trying to log errors about
+ # the charset problem.
+ # c.f. https://github.com/matrix-org/synapse/issues/3471
+ user_agent = self.get_user_agent()
+ if user_agent is not None:
+ user_agent = user_agent.decode("utf-8", "replace")
self.site.access_logger.info(
"%s - %s - {%s}"
- " Processed request: %dms (%dms, %dms) (%dms/%d)"
- " %sB %s \"%s %s %s\" \"%s\"",
+ " Processed request: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
+ " %sB %s \"%s %s %s\" \"%s\" [%d dbevts]",
self.getClientIP(),
self.site.site_tag,
- self.authenticated_entity,
- int(time.time() * 1000) - self.start_time,
- int(ru_utime * 1000),
- int(ru_stime * 1000),
- int(db_txn_duration * 1000),
- int(db_txn_count),
+ authenticated_entity,
+ end_time - self.start_time,
+ usage.ru_utime,
+ usage.ru_stime,
+ usage.db_sched_duration_sec,
+ usage.db_txn_duration_sec,
+ int(usage.db_txn_count),
self.sentLength,
self.code,
self.method,
self.get_redacted_uri(),
self.clientproto,
- self.get_user_agent(),
+ user_agent,
+ usage.evt_db_fetch_count,
)
+ try:
+ self.request_metrics.stop(end_time, self)
+ except Exception as e:
+ logger.warn("Failed to stop metrics: %r", e)
+
@contextlib.contextmanager
- def processing(self):
- self.started_processing()
+ def processing(self, servlet_name):
+ """Record the fact that we are processing this request.
+
+ Returns a context manager; the correct way to use this is:
+
+ @defer.inlineCallbacks
+ def handle_request(request):
+ with request.processing("FooServlet"):
+ yield really_handle_the_request()
+
+ This will log the request's arrival. Once the context manager is
+ closed, the completion of the request will be logged, and the various
+ metrics will be updated.
+
+ Args:
+ servlet_name (str): the name of the servlet which will be
+ processing this request. This is used in the metrics.
+
+ It is possible to update this afterwards by updating
+ self.request_metrics.name.
+ """
+ # TODO: we should probably just move this into render() and finish(),
+ # to save having to call a separate method.
+ self._started_processing(servlet_name)
yield
- self.finished_processing()
+ self._finished_processing()
class XForwardedForRequest(SynapseRequest):
@@ -133,7 +208,8 @@ class SynapseSite(Site):
Subclass of a twisted http Site that does access logging with python's
standard logging
"""
- def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs):
+ def __init__(self, logger_name, site_tag, config, resource,
+ server_version_string, *args, **kwargs):
Site.__init__(self, resource, *args, **kwargs)
self.site_tag = site_tag
@@ -141,6 +217,7 @@ class SynapseSite(Site):
proxied = config.get("x_forwarded", False)
self.requestFactory = SynapseRequestFactory(self, proxied)
self.access_logger = logging.getLogger(logger_name)
+ self.server_version_string = server_version_string
def log(self, request):
pass
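A hedged construction sketch (the resource wiring is simplified): `SynapseSite` now takes the server version string explicitly, and `SynapseRequest.render()` writes it into the `Server` response header, replacing the per-response `version_string` plumbing removed from `respond_with_json_bytes`.

    from twisted.internet import reactor
    from twisted.web.resource import Resource

    from synapse.http.site import SynapseSite

    site = SynapseSite(
        "synapse.access.http.demo",  # access-log logger name (illustrative)
        "demo",                      # site_tag
        {"x_forwarded": False},      # listener config
        Resource(),                  # root resource
        "Synapse/0.33.0",            # server_version_string -> Server header
    )
    reactor.listenTCP(8008, site)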
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 2265e6e8d6..a9158fc066 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -13,118 +13,198 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
import functools
-import time
import gc
+import logging
+import os
+import platform
+import time
-from twisted.internet import reactor
-
-from .metric import (
- CounterMetric, CallbackMetric, DistributionMetric, CacheMetric,
- MemoryUsageMetric,
-)
-from .process_collector import register_process_collector
+import attr
+from prometheus_client import Counter, Gauge, Histogram
+from prometheus_client.core import REGISTRY, GaugeMetricFamily
+from twisted.internet import reactor
logger = logging.getLogger(__name__)
-
+running_on_pypy = platform.python_implementation() == "PyPy"
all_metrics = []
all_collectors = []
+all_gauges = {}
+HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
-class Metrics(object):
- """ A single Metrics object gives a (mutable) slice view of the all_metrics
- dict, allowing callers to easily register new metrics that are namespaced
- nicely."""
- def __init__(self, name):
- self.name_prefix = name
+class RegistryProxy(object):
- def make_subspace(self, name):
- return Metrics("%s_%s" % (self.name_prefix, name))
+ @staticmethod
+ def collect():
+ for metric in REGISTRY.collect():
+ if not metric.name.startswith("__"):
+ yield metric
- def register_collector(self, func):
- all_collectors.append(func)
- def _register(self, metric_class, name, *args, **kwargs):
- full_name = "%s_%s" % (self.name_prefix, name)
+@attr.s(hash=True)
+class LaterGauge(object):
- metric = metric_class(full_name, *args, **kwargs)
+ name = attr.ib()
+ desc = attr.ib()
+ labels = attr.ib(hash=False)
+ caller = attr.ib()
- all_metrics.append(metric)
- return metric
+ def collect(self):
- def register_counter(self, *args, **kwargs):
- return self._register(CounterMetric, *args, **kwargs)
+ g = GaugeMetricFamily(self.name, self.desc, labels=self.labels)
- def register_callback(self, *args, **kwargs):
- return self._register(CallbackMetric, *args, **kwargs)
+ try:
+ calls = self.caller()
+ except Exception:
+ logger.exception(
+ "Exception running callback for LaterGauge(%s)",
+ self.name,
+ )
+ yield g
+ return
- def register_distribution(self, *args, **kwargs):
- return self._register(DistributionMetric, *args, **kwargs)
+ if isinstance(calls, dict):
+ for k, v in calls.items():
+ g.add_metric(k, v)
+ else:
+ g.add_metric([], calls)
- def register_cache(self, *args, **kwargs):
- return self._register(CacheMetric, *args, **kwargs)
+ yield g
+ def __attrs_post_init__(self):
+ self._register()
-def register_memory_metrics(hs):
- try:
- import psutil
- process = psutil.Process()
- process.memory_info().rss
- except (ImportError, AttributeError):
- logger.warn(
- "psutil is not installed or incorrect version."
- " Disabling memory metrics."
- )
- return
- metric = MemoryUsageMetric(hs, psutil)
- all_metrics.append(metric)
+ def _register(self):
+ if self.name in all_gauges.keys():
+ logger.warning("%s already registered, reregistering" % (self.name,))
+ REGISTRY.unregister(all_gauges.pop(self.name))
+ REGISTRY.register(self)
+ all_gauges[self.name] = self
-def get_metrics_for(pkg_name):
- """ Returns a Metrics instance for conveniently creating metrics
- namespaced with the given name prefix. """
- # Convert a "package.name" to "package_name" because Prometheus doesn't
- # let us use . in metric names
- return Metrics(pkg_name.replace(".", "_"))
+#
+# Detailed CPU metrics
+#
+class CPUMetrics(object):
-def render_all():
- strs = []
+ def __init__(self):
+ ticks_per_sec = 100
+ try:
+ # Try and get the system config
+ ticks_per_sec = os.sysconf('SC_CLK_TCK')
+ except (ValueError, TypeError, AttributeError):
+ pass
- for collector in all_collectors:
- collector()
+ self.ticks_per_sec = ticks_per_sec
- for metric in all_metrics:
- try:
- strs += metric.render()
- except Exception:
- strs += ["# FAILED to render"]
- logger.exception("Failed to render metric")
+ def collect(self):
+ if not HAVE_PROC_SELF_STAT:
+ return
- strs.append("") # to generate a final CRLF
+ with open("/proc/self/stat") as s:
+ line = s.read()
+ raw_stats = line.split(") ", 1)[1].split(" ")
- return "\n".join(strs)
+ user = GaugeMetricFamily("process_cpu_user_seconds_total", "")
+ user.add_metric([], float(raw_stats[11]) / self.ticks_per_sec)
+ yield user
+ sys = GaugeMetricFamily("process_cpu_system_seconds_total", "")
+ sys.add_metric([], float(raw_stats[12]) / self.ticks_per_sec)
+ yield sys
-register_process_collector(get_metrics_for("process"))
+REGISTRY.register(CPUMetrics())
-python_metrics = get_metrics_for("python")
+#
+# Python GC metrics
+#
-gc_time = python_metrics.register_distribution("gc_time", labels=["gen"])
-gc_unreachable = python_metrics.register_counter("gc_unreachable_total", labels=["gen"])
-python_metrics.register_callback(
- "gc_counts", lambda: {(i,): v for i, v in enumerate(gc.get_count())}, labels=["gen"]
+gc_unreachable = Gauge("python_gc_unreachable_total", "Unreachable GC objects", ["gen"])
+gc_time = Histogram(
+ "python_gc_time",
+ "Time taken to GC (sec)",
+ ["gen"],
+ buckets=[0.0025, 0.005, 0.01, 0.025, 0.05, 0.10, 0.25, 0.50, 1.00, 2.50,
+ 5.00, 7.50, 15.00, 30.00, 45.00, 60.00],
)
-reactor_metrics = get_metrics_for("python.twisted.reactor")
-tick_time = reactor_metrics.register_distribution("tick_time")
-pending_calls_metric = reactor_metrics.register_distribution("pending_calls")
+
+class GCCounts(object):
+
+ def collect(self):
+ cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"])
+ for n, m in enumerate(gc.get_count()):
+ cm.add_metric([str(n)], m)
+
+ yield cm
+
+
+if not running_on_pypy:
+ REGISTRY.register(GCCounts())
+
+#
+# Twisted reactor metrics
+#
+
+tick_time = Histogram(
+ "python_twisted_reactor_tick_time",
+ "Tick time of the Twisted reactor (sec)",
+ buckets=[0.001, 0.002, 0.005, 0.01, 0.025, 0.05, 0.1, 0.2, 0.5, 1, 2, 5],
+)
+pending_calls_metric = Histogram(
+ "python_twisted_reactor_pending_calls",
+ "Pending calls",
+ buckets=[1, 2, 5, 10, 25, 50, 100, 250, 500, 1000],
+)
+
+#
+# Federation Metrics
+#
+
+sent_edus_counter = Counter("synapse_federation_client_sent_edus", "")
+
+sent_transactions_counter = Counter("synapse_federation_client_sent_transactions", "")
+
+events_processed_counter = Counter("synapse_federation_client_events_processed", "")
+
+# Used to track where various components have processed in the event stream,
+# e.g. federation sending, appservice sending, etc.
+event_processing_positions = Gauge("synapse_event_processing_positions", "", ["name"])
+
+# Used to track the current max events stream position
+event_persisted_position = Gauge("synapse_event_persisted_position", "")
+
+# Used to track the received_ts of the last event processed by various
+# components
+event_processing_last_ts = Gauge("synapse_event_processing_last_ts", "", ["name"])
+
+# Used to track the lag processing events. This is the time difference
+# between the last processed event's received_ts and the time it was
+# finished being processed.
+event_processing_lag = Gauge("synapse_event_processing_lag", "", ["name"])
+
+last_ticked = time.time()
+
+
+class ReactorLastSeenMetric(object):
+
+ def collect(self):
+ cm = GaugeMetricFamily(
+ "python_twisted_reactor_last_seen",
+ "Seconds since the Twisted reactor was last seen",
+ )
+ cm.add_metric([], time.time() - last_ticked)
+ yield cm
+
+
+REGISTRY.register(ReactorLastSeenMetric())
def runUntilCurrentTimer(func):
@@ -146,12 +226,25 @@ def runUntilCurrentTimer(func):
num_pending += 1
num_pending += len(reactor.threadCallQueue)
-
- start = time.time() * 1000
+ start = time.time()
ret = func(*args, **kwargs)
- end = time.time() * 1000
- tick_time.inc_by(end - start)
- pending_calls_metric.inc_by(num_pending)
+ end = time.time()
+
+ # record the amount of wallclock time spent running pending calls.
+ # This is a proxy for the actual amount of time between reactor polls,
+ # since about 25% of time is actually spent running things triggered by
+ # I/O events, but that is harder to capture without rewriting half the
+ # reactor.
+ tick_time.observe(end - start)
+ pending_calls_metric.observe(num_pending)
+
+ # Update the time we last ticked, for the metric to test whether
+ # Synapse's reactor has frozen
+ global last_ticked
+ last_ticked = end
+
+ if running_on_pypy:
+ return ret
# Check if we need to do a manual GC (since its been disabled), and do
# one if necessary.
@@ -161,12 +254,12 @@ def runUntilCurrentTimer(func):
if threshold[i] < counts[i]:
logger.info("Collecting gc %d", i)
- start = time.time() * 1000
+ start = time.time()
unreachable = gc.collect(i)
- end = time.time() * 1000
+ end = time.time()
- gc_time.inc_by(end - start, i)
- gc_unreachable.inc_by(unreachable, i)
+ gc_time.labels(i).observe(end - start)
+ gc_unreachable.labels(i).set(unreachable)
return ret
@@ -185,6 +278,7 @@ try:
# We manually run the GC each reactor tick so that we can get some metrics
# about time spent doing GC,
- gc.disable()
+ if not running_on_pypy:
+ gc.disable()
except AttributeError:
pass
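A hedged sketch of `LaterGauge` in use (the metric and queue names are invented): the callback runs at scrape time and may return either a scalar or a dict keyed by tuples of label values, matching the two branches in `collect()`.

    from synapse.metrics import LaterGauge

    job_queues = {"email": [], "federation": [1, 2, 3]}

    LaterGauge(
        "myapp_pending_jobs",    # illustrative metric name
        "Number of pending jobs",
        ["queue"],
        lambda: {(name,): len(jobs) for name, jobs in job_queues.items()},
    )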
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
new file mode 100644
index 0000000000..9d820e44a6
--- /dev/null
+++ b/synapse/metrics/background_process_metrics.py
@@ -0,0 +1,179 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import six
+
+from prometheus_client.core import REGISTRY, Counter, GaugeMetricFamily
+
+from twisted.internet import defer
+
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+
+_background_process_start_count = Counter(
+ "synapse_background_process_start_count",
+ "Number of background processes started",
+ ["name"],
+)
+
+# we set registry=None in all of these to stop them getting registered with
+# the default registry. Instead we collect them all via the CustomCollector,
+# which ensures that we can update them before they are collected.
+#
+_background_process_ru_utime = Counter(
+ "synapse_background_process_ru_utime_seconds",
+ "User CPU time used by background processes, in seconds",
+ ["name"],
+ registry=None,
+)
+
+_background_process_ru_stime = Counter(
+ "synapse_background_process_ru_stime_seconds",
+ "System CPU time used by background processes, in seconds",
+ ["name"],
+ registry=None,
+)
+
+_background_process_db_txn_count = Counter(
+ "synapse_background_process_db_txn_count",
+ "Number of database transactions done by background processes",
+ ["name"],
+ registry=None,
+)
+
+_background_process_db_txn_duration = Counter(
+ "synapse_background_process_db_txn_duration_seconds",
+ ("Seconds spent by background processes waiting for database "
+ "transactions, excluding scheduling time"),
+ ["name"],
+ registry=None,
+)
+
+_background_process_db_sched_duration = Counter(
+ "synapse_background_process_db_sched_duration_seconds",
+ "Seconds spent by background processes waiting for database connections",
+ ["name"],
+ registry=None,
+)
+
+# map from description to a counter, so that we can name our logcontexts
+# incrementally. (It actually duplicates _background_process_start_count, but
+# it's much simpler to do so than to try to combine them.)
+_background_process_counts = dict() # type: dict[str, int]
+
+# map from description to the currently running background processes.
+#
+# it's kept as a dict of sets rather than a big set so that we can keep track
+# of process descriptions that no longer have any active processes.
+_background_processes = dict() # type: dict[str, set[_BackgroundProcess]]
+
+
+class _Collector(object):
+ """A custom metrics collector for the background process metrics.
+
+ Ensures that all of the metrics are up-to-date with any in-flight processes
+ before they are returned.
+ """
+ def collect(self):
+ background_process_in_flight_count = GaugeMetricFamily(
+ "synapse_background_process_in_flight_count",
+ "Number of background processes in flight",
+ labels=["name"],
+ )
+
+ for desc, processes in six.iteritems(_background_processes):
+ background_process_in_flight_count.add_metric(
+ (desc,), len(processes),
+ )
+ for process in processes:
+ process.update_metrics()
+
+ yield background_process_in_flight_count
+
+ # now we need to run collect() over each of the static Counters, and
+ # yield each metric they return.
+ for m in (
+ _background_process_ru_utime,
+ _background_process_ru_stime,
+ _background_process_db_txn_count,
+ _background_process_db_txn_duration,
+ _background_process_db_sched_duration,
+ ):
+ for r in m.collect():
+ yield r
+
+
+REGISTRY.register(_Collector())
+
+
+class _BackgroundProcess(object):
+ def __init__(self, desc, ctx):
+ self.desc = desc
+ self._context = ctx
+ self._reported_stats = None
+
+ def update_metrics(self):
+ """Updates the metrics with values from this process."""
+ new_stats = self._context.get_resource_usage()
+ if self._reported_stats is None:
+ diff = new_stats
+ else:
+ diff = new_stats - self._reported_stats
+ self._reported_stats = new_stats
+
+ _background_process_ru_utime.labels(self.desc).inc(diff.ru_utime)
+ _background_process_ru_stime.labels(self.desc).inc(diff.ru_stime)
+ _background_process_db_txn_count.labels(self.desc).inc(
+ diff.db_txn_count,
+ )
+ _background_process_db_txn_duration.labels(self.desc).inc(
+ diff.db_txn_duration_sec,
+ )
+ _background_process_db_sched_duration.labels(self.desc).inc(
+ diff.db_sched_duration_sec,
+ )
+
+
+def run_as_background_process(desc, func, *args, **kwargs):
+ """Run the given function in its own logcontext, with resource metrics
+
+ This should be used to wrap processes which are fired off to run in the
+ background, instead of being associated with a particular request.
+
+ Args:
+ desc (str): a description for this background process type
+ func: a function, which may return a Deferred
+ args: positional args for func
+ kwargs: keyword args for func
+
+ Returns: None
+ """
+ @defer.inlineCallbacks
+ def run():
+ count = _background_process_counts.get(desc, 0)
+ _background_process_counts[desc] = count + 1
+ _background_process_start_count.labels(desc).inc()
+
+ with LoggingContext(desc) as context:
+ context.request = "%s-%i" % (desc, count)
+ proc = _BackgroundProcess(desc, context)
+ _background_processes.setdefault(desc, set()).add(proc)
+ try:
+ yield func(*args, **kwargs)
+ finally:
+ proc.update_metrics()
+ _background_processes[desc].remove(proc)
+
+ with PreserveLoggingContext():
+ run()
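A hedged usage sketch (the pruning function is invented): the wrapped function runs in its own logcontext named after the process, and its CPU and database usage is attributed to the `synapse_background_process_*` series with `name="prune_old_entries"`.

    from twisted.internet import defer

    from synapse.metrics.background_process_metrics import (
        run_as_background_process,
    )


    @defer.inlineCallbacks
    def _prune_old_entries(store):
        yield store.delete_old_entries()  # illustrative datastore call


    def start_pruning(store):
        # fire-and-forget; safe to call from a request logcontext
        run_as_background_process("prune_old_entries", _prune_old_entries, store)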
diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py
deleted file mode 100644
index e87b2b80a7..0000000000
--- a/synapse/metrics/metric.py
+++ /dev/null
@@ -1,195 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from itertools import chain
-
-
-# TODO(paul): I can't believe Python doesn't have one of these
-def map_concat(func, items):
- # flatten a list-of-lists
- return list(chain.from_iterable(map(func, items)))
-
-
-class BaseMetric(object):
-
- def __init__(self, name, labels=[]):
- self.name = name
- self.labels = labels # OK not to clone as we never write it
-
- def dimension(self):
- return len(self.labels)
-
- def is_scalar(self):
- return not len(self.labels)
-
- def _render_labelvalue(self, value):
- # TODO: some kind of value escape
- return '"%s"' % (value)
-
- def _render_key(self, values):
- if self.is_scalar():
- return ""
- return "{%s}" % (
- ",".join(["%s=%s" % (k, self._render_labelvalue(v))
- for k, v in zip(self.labels, values)])
- )
-
-
-class CounterMetric(BaseMetric):
- """The simplest kind of metric; one that stores a monotonically-increasing
- integer that counts events."""
-
- def __init__(self, *args, **kwargs):
- super(CounterMetric, self).__init__(*args, **kwargs)
-
- self.counts = {}
-
- # Scalar metrics are never empty
- if self.is_scalar():
- self.counts[()] = 0
-
- def inc_by(self, incr, *values):
- if len(values) != self.dimension():
- raise ValueError(
- "Expected as many values to inc() as labels (%d)" % (self.dimension())
- )
-
- # TODO: should assert that the tag values are all strings
-
- if values not in self.counts:
- self.counts[values] = incr
- else:
- self.counts[values] += incr
-
- def inc(self, *values):
- self.inc_by(1, *values)
-
- def render_item(self, k):
- return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])]
-
- def render(self):
- return map_concat(self.render_item, sorted(self.counts.keys()))
-
-
-class CallbackMetric(BaseMetric):
- """A metric that returns the numeric value returned by a callback whenever
- it is rendered. Typically this is used to implement gauges that yield the
- size or other state of some in-memory object by actively querying it."""
-
- def __init__(self, name, callback, labels=[]):
- super(CallbackMetric, self).__init__(name, labels=labels)
-
- self.callback = callback
-
- def render(self):
- value = self.callback()
-
- if self.is_scalar():
- return ["%s %.12g" % (self.name, value)]
-
- return ["%s%s %.12g" % (self.name, self._render_key(k), value[k])
- for k in sorted(value.keys())]
-
-
-class DistributionMetric(object):
- """A combination of an event counter and an accumulator, which counts
- both the number of events and accumulates the total value. Typically this
- could be used to keep track of method-running times, or other distributions
- of values that occur in discrete occurances.
-
- TODO(paul): Try to export some heatmap-style stats?
- """
-
- def __init__(self, name, *args, **kwargs):
- self.counts = CounterMetric(name + ":count", **kwargs)
- self.totals = CounterMetric(name + ":total", **kwargs)
-
- def inc_by(self, inc, *values):
- self.counts.inc(*values)
- self.totals.inc_by(inc, *values)
-
- def render(self):
- return self.counts.render() + self.totals.render()
-
-
-class CacheMetric(object):
- __slots__ = ("name", "cache_name", "hits", "misses", "size_callback")
-
- def __init__(self, name, size_callback, cache_name):
- self.name = name
- self.cache_name = cache_name
-
- self.hits = 0
- self.misses = 0
-
- self.size_callback = size_callback
-
- def inc_hits(self):
- self.hits += 1
-
- def inc_misses(self):
- self.misses += 1
-
- def render(self):
- size = self.size_callback()
- hits = self.hits
- total = self.misses + self.hits
-
- return [
- """%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits),
- """%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
- """%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
- ]
-
-
-class MemoryUsageMetric(object):
- """Keeps track of the current memory usage, using psutil.
-
- The class will keep the current min/max/sum/counts of rss over the last
- WINDOW_SIZE_SEC, by polling UPDATE_HZ times per second
- """
-
- UPDATE_HZ = 2 # number of times to get memory per second
- WINDOW_SIZE_SEC = 30 # the size of the window in seconds
-
- def __init__(self, hs, psutil):
- clock = hs.get_clock()
- self.memory_snapshots = []
-
- self.process = psutil.Process()
-
- clock.looping_call(self._update_curr_values, 1000 / self.UPDATE_HZ)
-
- def _update_curr_values(self):
- max_size = self.UPDATE_HZ * self.WINDOW_SIZE_SEC
- self.memory_snapshots.append(self.process.memory_info().rss)
- self.memory_snapshots[:] = self.memory_snapshots[-max_size:]
-
- def render(self):
- if not self.memory_snapshots:
- return []
-
- max_rss = max(self.memory_snapshots)
- min_rss = min(self.memory_snapshots)
- sum_rss = sum(self.memory_snapshots)
- len_rss = len(self.memory_snapshots)
-
- return [
- "process_psutil_rss:max %d" % max_rss,
- "process_psutil_rss:min %d" % min_rss,
- "process_psutil_rss:total %d" % sum_rss,
- "process_psutil_rss:count %d" % len_rss,
- ]
diff --git a/synapse/metrics/process_collector.py b/synapse/metrics/process_collector.py
deleted file mode 100644
index 6fec3de399..0000000000
--- a/synapse/metrics/process_collector.py
+++ /dev/null
@@ -1,122 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-
-
-TICKS_PER_SEC = 100
-BYTES_PER_PAGE = 4096
-
-HAVE_PROC_STAT = os.path.exists("/proc/stat")
-HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
-HAVE_PROC_SELF_LIMITS = os.path.exists("/proc/self/limits")
-HAVE_PROC_SELF_FD = os.path.exists("/proc/self/fd")
-
-# Field indexes from /proc/self/stat, taken from the proc(5) manpage
-STAT_FIELDS = {
- "utime": 14,
- "stime": 15,
- "starttime": 22,
- "vsize": 23,
- "rss": 24,
-}
-
-
-stats = {}
-
-# In order to report process_start_time_seconds we need to know the
-# machine's boot time, because the value in /proc/self/stat is relative to
-# this
-boot_time = None
-if HAVE_PROC_STAT:
- with open("/proc/stat") as _procstat:
- for line in _procstat:
- if line.startswith("btime "):
- boot_time = int(line.split()[1])
-
-
-def update_resource_metrics():
- if HAVE_PROC_SELF_STAT:
- global stats
- with open("/proc/self/stat") as s:
- line = s.read()
- # line is PID (command) more stats go here ...
- raw_stats = line.split(") ", 1)[1].split(" ")
-
- for (name, index) in STAT_FIELDS.iteritems():
- # subtract 3 from the index, because proc(5) is 1-based, and
- # we've lost the first two fields in PID and COMMAND above
- stats[name] = int(raw_stats[index - 3])
-
-
-def _count_fds():
- # Not every OS will have a /proc/self/fd directory
- if not HAVE_PROC_SELF_FD:
- return 0
-
- return len(os.listdir("/proc/self/fd"))
-
-
-def register_process_collector(process_metrics):
- process_metrics.register_collector(update_resource_metrics)
-
- if HAVE_PROC_SELF_STAT:
- process_metrics.register_callback(
- "cpu_user_seconds_total",
- lambda: float(stats["utime"]) / TICKS_PER_SEC
- )
- process_metrics.register_callback(
- "cpu_system_seconds_total",
- lambda: float(stats["stime"]) / TICKS_PER_SEC
- )
- process_metrics.register_callback(
- "cpu_seconds_total",
- lambda: (float(stats["utime"] + stats["stime"])) / TICKS_PER_SEC
- )
-
- process_metrics.register_callback(
- "virtual_memory_bytes",
- lambda: int(stats["vsize"])
- )
- process_metrics.register_callback(
- "resident_memory_bytes",
- lambda: int(stats["rss"]) * BYTES_PER_PAGE
- )
-
- process_metrics.register_callback(
- "start_time_seconds",
- lambda: boot_time + int(stats["starttime"]) / TICKS_PER_SEC
- )
-
- if HAVE_PROC_SELF_FD:
- process_metrics.register_callback(
- "open_fds",
- lambda: _count_fds()
- )
-
- if HAVE_PROC_SELF_LIMITS:
- def _get_max_fds():
- with open("/proc/self/limits") as limits:
- for line in limits:
- if not line.startswith("Max open files "):
- continue
- # Line is Max open files $SOFT $HARD
- return int(line.split()[3])
- return None
-
- process_metrics.register_callback(
- "max_fds",
- lambda: _get_max_fds()
- )
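
[Editor's note: the hand-rolled /proc parser deleted above is superseded by the equivalent collector that ships with prometheus_client. A minimal sketch of the replacement, assuming a Linux host and a dedicated registry (prometheus_client also auto-registers a process collector on its default registry):]

    from prometheus_client import CollectorRegistry, ProcessCollector, generate_latest

    registry = CollectorRegistry()
    # exposes process_cpu_seconds_total, process_resident_memory_bytes,
    # process_open_fds, process_max_fds, process_start_time_seconds, ...
    ProcessCollector(registry=registry)
    print(generate_latest(registry).decode("ascii"))
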
diff --git a/synapse/metrics/resource.py b/synapse/metrics/resource.py
index 870f400600..9789359077 100644
--- a/synapse/metrics/resource.py
+++ b/synapse/metrics/resource.py
@@ -13,27 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.web.resource import Resource
-
-import synapse.metrics
-
+from prometheus_client.twisted import MetricsResource
METRICS_PREFIX = "/_synapse/metrics"
-
-class MetricsResource(Resource):
- isLeaf = True
-
- def __init__(self, hs):
- Resource.__init__(self) # Resource is old-style, so no super()
-
- self.hs = hs
-
- def render_GET(self, request):
- response = synapse.metrics.render_all()
-
- request.setHeader("Content-Type", "text/plain")
- request.setHeader("Content-Length", str(len(response)))
-
- # Encode as UTF-8 (default)
- return response.encode()
+__all__ = ["MetricsResource", "METRICS_PREFIX"]
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
new file mode 100644
index 0000000000..097c844d31
--- /dev/null
+++ b/synapse/module_api/__init__.py
@@ -0,0 +1,123 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+from synapse.types import UserID
+
+
+class ModuleApi(object):
+ """A proxy object that gets passed to password auth providers so they
+ can register new users etc if necessary.
+ """
+ def __init__(self, hs, auth_handler):
+ self.hs = hs
+
+ self._store = hs.get_datastore()
+ self._auth = hs.get_auth()
+ self._auth_handler = auth_handler
+
+ def get_user_by_req(self, req, allow_guest=False):
+ """Check the access_token provided for a request
+
+ Args:
+ req (twisted.web.server.Request): Incoming HTTP request
+ allow_guest (bool): True if guest users should be allowed. If this
+ is False, and the access token is for a guest user, an
+ AuthError will be thrown
+ Returns:
+ twisted.internet.defer.Deferred[synapse.types.Requester]:
+ the requester for this request
+ Raises:
+ synapse.api.errors.AuthError: if no user by that token exists,
+ or the token is invalid.
+ """
+ return self._auth.get_user_by_req(req, allow_guest)
+
+ def get_qualified_user_id(self, username):
+ """Qualify a user id, if necessary
+
+ Takes a user id provided by the user and adds the @ and :domain to
+ qualify it, if necessary
+
+ Args:
+ username (str): provided user id
+
+ Returns:
+ str: qualified @user:id
+ """
+ if username.startswith('@'):
+ return username
+ return UserID(username, self.hs.hostname).to_string()
+
+ def check_user_exists(self, user_id):
+ """Check if user exists.
+
+ Args:
+ user_id (str): Complete @user:id
+
+ Returns:
+ Deferred[str|None]: Canonical (case-corrected) user_id, or None
+ if the user is not registered.
+ """
+ return self._auth_handler.check_user_exists(user_id)
+
+ def register(self, localpart):
+ """Registers a new user with given localpart
+
+ Returns:
+ Deferred: a 2-tuple of (user_id, access_token)
+ """
+ reg = self.hs.get_handlers().registration_handler
+ return reg.register(localpart=localpart)
+
+ @defer.inlineCallbacks
+ def invalidate_access_token(self, access_token):
+ """Invalidate an access token for a user
+
+ Args:
+ access_token(str): access token
+
+ Returns:
+ twisted.internet.defer.Deferred - resolves once the access token
+ has been removed.
+
+ Raises:
+ synapse.api.errors.AuthError: the access token is invalid
+ """
+ # see if the access token corresponds to a device
+ user_info = yield self._auth.get_user_by_access_token(access_token)
+ device_id = user_info.get("device_id")
+ user_id = user_info["user"].to_string()
+ if device_id:
+ # delete the device, which will also delete its access tokens
+ yield self.hs.get_device_handler().delete_device(user_id, device_id)
+ else:
+ # no associated device. Just delete the access token.
+ yield self._auth_handler.delete_access_token(access_token)
+
+ def run_db_interaction(self, desc, func, *args, **kwargs):
+ """Run a function with a database connection
+
+ Args:
+ desc (str): description for the transaction, for metrics etc
+ func (func): function to be run. Passed a database cursor object
+ as well as *args and **kwargs
+ *args: positional args to be passed to func
+ **kwargs: named args to be passed to func
+
+ Returns:
+ Deferred[object]: result of func
+ """
+ return self._store.runInteraction(desc, func, *args, **kwargs)
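
[Editor's note: a hypothetical password auth provider illustrating how ModuleApi is consumed; synapse passes the ModuleApi instance to the provider's constructor as `account_handler`. The provider class and its policy below are illustrative, not part of this change:]

    from twisted.internet import defer

    class ExamplePasswordProvider(object):
        def __init__(self, config, account_handler):
            # `account_handler` is the ModuleApi defined above
            self.account_handler = account_handler

        @defer.inlineCallbacks
        def check_password(self, user_id, password):
            # demo policy only: accept any non-empty password for known users
            if not password:
                defer.returnValue(False)
            canonical = yield self.account_handler.check_user_exists(user_id)
            defer.returnValue(canonical is not None)
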
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 385208b574..e650c3e494 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -13,34 +13,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+from collections import namedtuple
+
+from prometheus_client import Counter
+
from twisted.internet import defer
+
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError
from synapse.handlers.presence import format_user_presence_state
-
-from synapse.util import DeferredTimedOutError
+from synapse.metrics import LaterGauge
+from synapse.types import StreamToken
+from synapse.util.async import (
+ DeferredTimeoutError,
+ ObservableDeferred,
+ add_timeout_to_deferred,
+)
+from synapse.util.logcontext import PreserveLoggingContext, run_in_background
from synapse.util.logutils import log_function
-from synapse.util.async import ObservableDeferred
-from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
from synapse.util.metrics import Measure
-from synapse.types import StreamToken
from synapse.visibility import filter_events_for_client
-import synapse.metrics
-
-from collections import namedtuple
-
-import logging
-
logger = logging.getLogger(__name__)
-metrics = synapse.metrics.get_metrics_for(__name__)
+notified_events_counter = Counter("synapse_notifier_notified_events", "")
-notified_events_counter = metrics.register_counter("notified_events")
-
-users_woken_by_stream_counter = metrics.register_counter(
- "users_woken_by_stream", labels=["stream"]
-)
+users_woken_by_stream_counter = Counter(
+ "synapse_notifier_users_woken_by_stream", "", ["stream"])
# TODO(paul): Should be shared somewhere
@@ -105,7 +105,7 @@ class _NotifierUserStream(object):
self.last_notified_ms = time_now_ms
noify_deferred = self.notify_deferred
- users_woken_by_stream_counter.inc(stream_key)
+ users_woken_by_stream_counter.labels(stream_key).inc()
with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred())
@@ -144,6 +144,7 @@ class _NotifierUserStream(object):
class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
def __nonzero__(self):
return bool(self.events)
+ __bool__ = __nonzero__ # python3
class Notifier(object):
@@ -159,6 +160,7 @@ class Notifier(object):
self.user_to_user_stream = {}
self.room_to_user_streams = {}
+ self.hs = hs
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore()
self.pending_new_room_events = []
@@ -193,14 +195,14 @@ class Notifier(object):
all_user_streams.add(x)
return sum(stream.count_listeners() for stream in all_user_streams)
- metrics.register_callback("listeners", count_listeners)
+ LaterGauge("synapse_notifier_listeners", "", [], count_listeners)
- metrics.register_callback(
- "rooms",
+ LaterGauge(
+ "synapse_notifier_rooms", "", [],
lambda: count(bool, self.room_to_user_streams.values()),
)
- metrics.register_callback(
- "users",
+ LaterGauge(
+ "synapse_notifier_users", "", [],
lambda: len(self.user_to_user_stream),
)
@@ -250,14 +252,10 @@ class Notifier(object):
def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
"""Notify any user streams that are interested in this room event"""
# poke any interested application service.
- preserve_fn(self.appservice_handler.notify_interested_services)(
- room_stream_id
- )
+ run_in_background(self._notify_app_services, room_stream_id)
if self.federation_sender:
- preserve_fn(self.federation_sender.notify_new_events)(
- room_stream_id
- )
+ self.federation_sender.notify_new_events(room_stream_id)
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
self._user_joined_room(event.state_key, event.room_id)
@@ -268,8 +266,15 @@ class Notifier(object):
rooms=[event.room_id],
)
+ @defer.inlineCallbacks
+ def _notify_app_services(self, room_stream_id):
+ try:
+ yield self.appservice_handler.notify_interested_services(room_stream_id)
+ except Exception:
+ logger.exception("Error notifying application services of event")
+
def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
- """ Used to inform listeners that something has happend event wise.
+ """ Used to inform listeners that something has happened event wise.
Will wake up all listeners for the given users and rooms.
"""
@@ -289,7 +294,7 @@ class Notifier(object):
for user_stream in user_streams:
try:
user_stream.notify(stream_key, new_token, time_now_ms)
- except:
+ except Exception:
logger.exception("Failed to notify listener")
self.notify_replication()
@@ -297,8 +302,7 @@ class Notifier(object):
def on_new_replication_data(self):
"""Used to inform replication listeners that something has happend
without waking up any of the normal user event streams"""
- with PreserveLoggingContext():
- self.notify_replication()
+ self.notify_replication()
@defer.inlineCallbacks
def wait_for_events(self, user_id, timeout, callback, room_ids=None,
@@ -333,11 +337,13 @@ class Notifier(object):
# Now we wait for the _NotifierUserStream to be told there
# is a new token.
listener = user_stream.new_listener(prev_token)
+ add_timeout_to_deferred(
+ listener.deferred,
+ (end_time - now) / 1000.,
+ self.hs.get_reactor(),
+ )
with PreserveLoggingContext():
- yield self.clock.time_bound_deferred(
- listener.deferred,
- time_out=(end_time - now) / 1000.
- )
+ yield listener.deferred
current_token = user_stream.current_token
@@ -348,7 +354,7 @@ class Notifier(object):
# Update the prev_token to the current_token since nothing
# has happened between the old prev_token and the current_token
prev_token = current_token
- except DeferredTimedOutError:
+ except DeferredTimeoutError:
break
except defer.CancelledError:
break
@@ -516,8 +522,14 @@ class Notifier(object):
self.replication_deferred = ObservableDeferred(defer.Deferred())
deferred.callback(None)
- for cb in self.replication_callbacks:
- preserve_fn(cb)()
+ # the callbacks may well outlast the current request, so we run
+ # them in the sentinel logcontext.
+ #
+ # (ideally it would be up to the callbacks to know if they were
+ # starting off background processes and drop the logcontext
+ # accordingly, but that requires more changes)
+ for cb in self.replication_callbacks:
+ cb()
@defer.inlineCallbacks
def wait_for_replication(self, callback, timeout):
@@ -547,13 +559,15 @@ class Notifier(object):
if end_time <= now:
break
+ add_timeout_to_deferred(
+ listener.deferred,
+ (end_time - now) / 1000.,
+ self.hs.get_reactor(),
+ )
try:
with PreserveLoggingContext():
- yield self.clock.time_bound_deferred(
- listener.deferred,
- time_out=(end_time - now) / 1000.
- )
- except DeferredTimedOutError:
+ yield listener.deferred
+ except DeferredTimeoutError:
break
except defer.CancelledError:
break
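
[Editor's note: a standalone illustration of the metrics idiom adopted above: synapse's register_counter(labels=[...]) becomes a prometheus_client Counter with .labels():]

    from prometheus_client import Counter

    woken = Counter("example_users_woken_by_stream", "", ["stream"])
    woken.labels("typing").inc()       # bumps the stream="typing" series
    woken.labels("to_device").inc(2)   # .inc() accepts an optional amount
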
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index fe09d50d55..a5de75c48a 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -13,13 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+import logging
-from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
+from twisted.internet import defer
from synapse.util.metrics import Measure
-import logging
+from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
logger = logging.getLogger(__name__)
@@ -40,10 +40,6 @@ class ActionGenerator(object):
@defer.inlineCallbacks
def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "action_for_event_by_user"):
- actions_by_user = yield self.bulk_evaluator.action_for_event_by_user(
+ yield self.bulk_evaluator.action_for_event_by_user(
event, context
)
-
- context.push_actions = [
- (uid, actions) for uid, actions in actions_by_user.iteritems()
- ]
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 85effdfa46..8f0682c948 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
import copy
+from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
+
def list_with_base_rules(rawrules):
"""Combine the list of rules set by the user with the default push rules
@@ -38,7 +40,7 @@ def list_with_base_rules(rawrules):
rawrules = [r for r in rawrules if r['priority_class'] >= 0]
# shove the server default rules for each kind onto the end of each
- current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1]
+ current_prio_class = list(PRIORITY_CLASS_INVERSE_MAP)[-1]
ruleslist.extend(make_base_prepend_rules(
PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
@@ -238,6 +240,28 @@ BASE_APPEND_OVERRIDE_RULES = [
}
]
},
+ {
+ 'rule_id': 'global/override/.m.rule.roomnotif',
+ 'conditions': [
+ {
+ 'kind': 'event_match',
+ 'key': 'content.body',
+ 'pattern': '@room',
+ '_id': '_roomnotif_content',
+ },
+ {
+ 'kind': 'sender_notification_permission',
+ 'key': 'room',
+ '_id': '_roomnotif_pl',
+ },
+ ],
+ 'actions': [
+ 'notify', {
+ 'set_tweak': 'highlight',
+ 'value': True,
+ }
+ ]
+ }
]
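
[Editor's note: an illustrative event matched by the new .m.rule.roomnotif rule: content.body must contain "@room" and, per the sender_notification_permission condition, the sender must meet the room's notifications["room"] power level (see push_rule_evaluator.py later in this diff):]

    event_content = {
        "msgtype": "m.text",
        "body": "@room the homeserver restarts in five minutes",
    }
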
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 9a96e6fe8f..1d14d3639c 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,18 +15,22 @@
# limitations under the License.
import logging
+from collections import namedtuple
-from twisted.internet import defer
+from six import iteritems, itervalues
-from .push_rule_evaluator import PushRuleEvaluatorForEvent
+from prometheus_client import Counter
+
+from twisted.internet import defer
-from synapse.visibility import filter_events_for_clients_context
from synapse.api.constants import EventTypes, Membership
-from synapse.util.caches.descriptors import cached
+from synapse.event_auth import get_user_power_level
+from synapse.state import POWER_KEY
from synapse.util.async import Linearizer
+from synapse.util.caches import register_cache
+from synapse.util.caches.descriptors import cached
-from collections import namedtuple
-
+from .push_rule_evaluator import PushRuleEvaluatorForEvent
logger = logging.getLogger(__name__)
@@ -33,6 +38,20 @@ logger = logging.getLogger(__name__)
rules_by_room = {}
+push_rules_invalidation_counter = Counter(
+ "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", "")
+push_rules_state_size_counter = Counter(
+ "synapse_push_bulk_push_rule_evaluator_push_rules_state_size_counter", "")
+
+# Measures whether we use the fast path of using state deltas, or if we have to
+# recalculate from scratch
+push_rules_delta_state_cache_metric = register_cache(
+ "cache",
+ "push_rules_delta_state_cache_metric",
+ cache=[], # Meaningless size, as this isn't a cache that stores values
+)
+
+
class BulkPushRuleEvaluator(object):
"""Calculates the outcome of push rules for an event for all users in the
room at once.
@@ -41,6 +60,13 @@ class BulkPushRuleEvaluator(object):
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ self.room_push_rule_cache_metrics = register_cache(
+ "cache",
+ "room_push_rule_cache",
+ cache=[], # Meaningless size, as this isn't a cache that stores values
+ )
@defer.inlineCallbacks
def _get_rules_for_event(self, event, context):
@@ -79,37 +105,69 @@ class BulkPushRuleEvaluator(object):
# It's important that RulesForRoom gets added to self._get_rules_for_room.cache
# before any lookup methods get called on it as otherwise there may be
# a race if invalidate_all gets called (which assumes its in the cache)
- return RulesForRoom(self.hs, room_id, self._get_rules_for_room.cache)
+ return RulesForRoom(
+ self.hs, room_id, self._get_rules_for_room.cache,
+ self.room_push_rule_cache_metrics,
+ )
+
+ @defer.inlineCallbacks
+ def _get_power_levels_and_sender_level(self, event, context):
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+ pl_event_id = prev_state_ids.get(POWER_KEY)
+ if pl_event_id:
+ # fastpath: if there's a power level event, that's all we need, and
+ # not having a power level event is an extreme edge case
+ pl_event = yield self.store.get_event(pl_event_id)
+ auth_events = {POWER_KEY: pl_event}
+ else:
+ auth_events_ids = yield self.auth.compute_auth_events(
+ event, prev_state_ids, for_verification=False,
+ )
+ auth_events = yield self.store.get_events(auth_events_ids)
+ auth_events = {
+ (e.type, e.state_key): e for e in itervalues(auth_events)
+ }
+
+ sender_level = get_user_power_level(event.sender, auth_events)
+
+ pl_event = auth_events.get(POWER_KEY)
+
+ defer.returnValue((pl_event.content if pl_event else {}, sender_level))
@defer.inlineCallbacks
def action_for_event_by_user(self, event, context):
- """Given an event and context, evaluate the push rules and return
- the results
+ """Given an event and context, evaluate the push rules and insert the
+ results into the event_push_actions_staging table.
Returns:
- dict of user_id -> action
+ Deferred
"""
rules_by_user = yield self._get_rules_for_event(event, context)
actions_by_user = {}
- # None of these users can be peeking since this list of users comes
- # from the set of users in the room, so we know for sure they're all
- # actually in the room.
- user_tuples = [(u, False) for u in rules_by_user]
-
- filtered_by_user = yield filter_events_for_clients_context(
- self.store, user_tuples, [event], {event.event_id: context}
- )
-
room_members = yield self.store.get_joined_users_from_context(
event, context
)
- evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
+ (power_levels, sender_power_level) = (
+ yield self._get_power_levels_and_sender_level(event, context)
+ )
+
+ evaluator = PushRuleEvaluatorForEvent(
+ event, len(room_members), sender_power_level, power_levels,
+ )
condition_cache = {}
- for uid, rules in rules_by_user.iteritems():
+ for uid, rules in iteritems(rules_by_user):
+ if event.sender == uid:
+ continue
+
+ if not event.is_state():
+ is_ignored = yield self.store.is_ignored_by(event.sender, uid)
+ if is_ignored:
+ continue
+
display_name = None
profile_info = room_members.get(uid)
if profile_info:
@@ -121,13 +179,6 @@ class BulkPushRuleEvaluator(object):
if event.type == EventTypes.Member and event.state_key == uid:
display_name = event.content.get("displayname", None)
- filtered = filtered_by_user[uid]
- if len(filtered) == 0:
- continue
-
- if filtered[0].sender == uid:
- continue
-
for rule in rules:
if 'enabled' in rule and not rule['enabled']:
continue
@@ -138,9 +189,16 @@ class BulkPushRuleEvaluator(object):
if matches:
actions = [x for x in rule['actions'] if x != 'dont_notify']
if actions and 'notify' in actions:
+ # Push rules say we should notify the user of this event
actions_by_user[uid] = actions
break
- defer.returnValue(actions_by_user)
+
+ # Mark in the DB staging area the push actions for users who should be
+ # notified for this event. (This will then get handled when we persist
+ # the event)
+ yield self.store.add_push_actions_to_staging(
+ event.event_id, actions_by_user,
+ )
def _condition_checker(evaluator, conditions, uid, display_name, cache):
@@ -170,17 +228,19 @@ class RulesForRoom(object):
the entire cache for the room.
"""
- def __init__(self, hs, room_id, rules_for_room_cache):
+ def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metrics):
"""
Args:
hs (HomeServer)
room_id (str)
rules_for_room_cache(Cache): The cache object that caches these
RoomsForUser objects.
+ room_push_rule_cache_metrics (CacheMetric)
"""
self.room_id = room_id
self.is_mine_id = hs.is_mine_id
self.store = hs.get_datastore()
+ self.room_push_rule_cache_metrics = room_push_rule_cache_metrics
self.linearizer = Linearizer(name="rules_for_room")
@@ -222,11 +282,19 @@ class RulesForRoom(object):
"""
state_group = context.state_group
+ if state_group and self.state_group == state_group:
+ logger.debug("Using cached rules for %r", self.room_id)
+ self.room_push_rule_cache_metrics.inc_hits()
+ defer.returnValue(self.rules_by_user)
+
with (yield self.linearizer.queue(())):
if state_group and self.state_group == state_group:
logger.debug("Using cached rules for %r", self.room_id)
+ self.room_push_rule_cache_metrics.inc_hits()
defer.returnValue(self.rules_by_user)
+ self.room_push_rule_cache_metrics.inc_misses()
+
ret_rules_by_user = {}
missing_member_event_ids = {}
if state_group and self.state_group == context.prev_group:
@@ -234,8 +302,13 @@ class RulesForRoom(object):
# results.
ret_rules_by_user = self.rules_by_user
current_state_ids = context.delta_ids
+
+ push_rules_delta_state_cache_metric.inc_hits()
else:
- current_state_ids = context.current_state_ids
+ current_state_ids = yield context.get_current_state_ids(self.store)
+ push_rules_delta_state_cache_metric.inc_misses()
+
+ push_rules_state_size_counter.inc(len(current_state_ids))
logger.debug(
"Looking for member changes in %r %r", state_group, current_state_ids
@@ -282,6 +355,14 @@ class RulesForRoom(object):
yield self._update_rules_with_member_event_ids(
ret_rules_by_user, missing_member_event_ids, state_group, event
)
+ else:
+ # The push rules didn't change but let's update the cache anyway
+ self.update_cache(
+ self.sequence,
+ members={}, # There were no membership changes
+ rules_by_user=ret_rules_by_user,
+ state_group=state_group
+ )
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
@@ -324,7 +405,7 @@ class RulesForRoom(object):
# If the event is a join event then it will be in current state evnts
# map but not in the DB, so we have to explicitly insert it.
if event.type == EventTypes.Member:
- for event_id in member_event_ids.itervalues():
+ for event_id in itervalues(member_event_ids):
if event_id == event.event_id:
members[event_id] = (event.state_key, event.membership)
@@ -332,7 +413,7 @@ class RulesForRoom(object):
logger.debug("Found members %r: %r", self.room_id, members.values())
interested_in_user_ids = set(
- user_id for user_id, membership in members.itervalues()
+ user_id for user_id, membership in itervalues(members)
if membership == Membership.JOIN
)
@@ -344,7 +425,7 @@ class RulesForRoom(object):
)
user_ids = set(
- uid for uid, have_pusher in if_users_with_pushers.iteritems() if have_pusher
+ uid for uid, have_pusher in iteritems(if_users_with_pushers) if have_pusher
)
logger.debug("With pushers: %r", user_ids)
@@ -365,7 +446,7 @@ class RulesForRoom(object):
)
ret_rules_by_user.update(
- item for item in rules_by_user.iteritems() if item[0] is not None
+ item for item in iteritems(rules_by_user) if item[0] is not None
)
self.update_cache(sequence, members, ret_rules_by_user, state_group)
@@ -380,6 +461,7 @@ class RulesForRoom(object):
self.state_group = object()
self.member_map = {}
self.rules_by_user = {}
+ push_rules_invalidation_counter.inc()
def update_cache(self, sequence, members, rules_by_user, state_group):
if sequence == self.sequence:
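
[Editor's note: the power-level fastpath above hinges on POWER_KEY, which is ("m.room.power_levels", ""). A toy illustration of the lookup against a plain dict of state ids (the event ids are made up):]

    POWER_KEY = ("m.room.power_levels", "")

    prev_state_ids = {
        ("m.room.power_levels", ""): "$pl:example.org",
        ("m.room.member", "@alice:example.org"): "$join:example.org",
    }
    pl_event_id = prev_state_ids.get(POWER_KEY)  # "$pl:example.org" -> fastpath
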
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index e0331b2d2d..ecbf364a5e 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -13,12 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.push.rulekinds import (
- PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
-)
-
import copy
+from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
+
def format_push_rules_for_user(user, ruleslist):
"""Converts a list of rawrules and a enabled map into nested dictionaries
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index a69dda7b09..d746371420 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -13,14 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer, reactor
-from twisted.internet.error import AlreadyCalled, AlreadyCancelled
-
import logging
-from synapse.util.metrics import Measure
-from synapse.util.logcontext import LoggingContext
+from twisted.internet import defer
+from twisted.internet.error import AlreadyCalled, AlreadyCancelled
+from synapse.util.logcontext import LoggingContext
+from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -77,10 +76,13 @@ class EmailPusher(object):
@defer.inlineCallbacks
def on_started(self):
if self.mailer is not None:
- self.throttle_params = yield self.store.get_throttle_params_by_room(
- self.pusher_id
- )
- yield self._process()
+ try:
+ self.throttle_params = yield self.store.get_throttle_params_by_room(
+ self.pusher_id
+ )
+ yield self._process()
+ except Exception:
+ logger.exception("Error starting email pusher")
def on_stop(self):
if self.timed_call:
@@ -121,7 +123,7 @@ class EmailPusher(object):
starting_max_ordering = self.max_stream_ordering
try:
yield self._unsafe_process()
- except:
+ except Exception:
logger.exception("Exception processing notifs")
if self.max_stream_ordering == starting_max_ordering:
break
@@ -196,7 +198,7 @@ class EmailPusher(object):
self.timed_call = None
if soonest_due_at is not None:
- self.timed_call = reactor.callLater(
+ self.timed_call = self.hs.get_reactor().callLater(
self.seconds_until(soonest_due_at), self.on_timer
)
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 8a5d473108..81e18bcf7d 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,21 +13,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
-from synapse.push import PusherConfigException
+from prometheus_client import Counter
-from twisted.internet import defer, reactor
+from twisted.internet import defer
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
-import logging
-import push_rule_evaluator
-import push_tools
-
+from synapse.push import PusherConfigException
from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure
+from . import push_rule_evaluator, push_tools
+
logger = logging.getLogger(__name__)
+http_push_processed_counter = Counter("synapse_http_httppusher_http_pushes_processed", "")
+
+http_push_failed_counter = Counter("synapse_http_httppusher_http_pushes_failed", "")
+
class HttpPusher(object):
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
@@ -84,7 +89,10 @@ class HttpPusher(object):
@defer.inlineCallbacks
def on_started(self):
- yield self._process()
+ try:
+ yield self._process()
+ except Exception:
+ logger.exception("Error starting http pusher")
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
@@ -131,7 +139,7 @@ class HttpPusher(object):
starting_max_ordering = self.max_stream_ordering
try:
yield self._unsafe_process()
- except:
+ except Exception:
logger.exception("Exception processing notifs")
if self.max_stream_ordering == starting_max_ordering:
break
@@ -151,9 +159,16 @@ class HttpPusher(object):
self.user_id, self.last_stream_ordering, self.max_stream_ordering
)
+ logger.info(
+ "Processing %i unprocessed push actions for %s starting at "
+ "stream_ordering %s",
+ len(unprocessed), self.name, self.last_stream_ordering,
+ )
+
for push_action in unprocessed:
processed = yield self._process_one(push_action)
if processed:
+ http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action['stream_ordering']
yield self.store.update_pusher_last_stream_ordering_and_success(
@@ -168,6 +183,7 @@ class HttpPusher(object):
self.failing_since
)
else:
+ http_push_failed_counter.inc()
if not self.failing_since:
self.failing_since = self.clock.time_msec()
yield self.store.update_pusher_failing_since(
@@ -204,7 +220,9 @@ class HttpPusher(object):
)
else:
logger.info("Push failed: delaying for %ds", self.backoff_delay)
- self.timed_call = reactor.callLater(self.backoff_delay, self.on_timer)
+ self.timed_call = self.hs.get_reactor().callLater(
+ self.backoff_delay, self.on_timer
+ )
self.backoff_delay = min(self.backoff_delay * 2, self.MAX_BACKOFF_SEC)
break
@@ -244,6 +262,26 @@ class HttpPusher(object):
@defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks, badge):
+ if self.data.get('format') == 'event_id_only':
+ d = {
+ 'notification': {
+ 'event_id': event.event_id,
+ 'room_id': event.room_id,
+ 'counts': {
+ 'unread': badge,
+ },
+ 'devices': [
+ {
+ 'app_id': self.app_id,
+ 'pushkey': self.pushkey,
+ 'pushkey_ts': long(self.pushkey_ts / 1000),
+ 'data': self.data_minus_url,
+ }
+ ]
+ }
+ }
+ defer.returnValue(d)
+
ctx = yield push_tools.get_context_for_event(
self.store, self.state_handler, event, self.user_id
)
@@ -275,7 +313,7 @@ class HttpPusher(object):
if event.type == 'm.room.member':
d['notification']['membership'] = event.content['membership']
d['notification']['user_is_target'] = event.state_key == self.user_id
- if not self.hs.config.push_redact_content and 'content' in event:
+ if self.hs.config.push_include_content and 'content' in event:
d['notification']['content'] = event.content
# We no longer send aliases separately, instead, we send the human
@@ -294,8 +332,11 @@ class HttpPusher(object):
defer.returnValue([])
try:
resp = yield self.http_client.post_json_get_json(self.url, notification_dict)
- except:
- logger.warn("Failed to push %s ", self.url)
+ except Exception:
+ logger.warn(
+ "Failed to push event %s to %s",
+ event.event_id, self.name, exc_info=True,
+ )
defer.returnValue(False)
rejected = []
if 'rejected' in resp:
@@ -304,7 +345,7 @@ class HttpPusher(object):
@defer.inlineCallbacks
def _send_badge(self, badge):
- logger.info("Sending updated badge count %d to %r", badge, self.user_id)
+ logger.info("Sending updated badge count %d to %s", badge, self.name)
d = {
'notification': {
'id': '',
@@ -325,8 +366,11 @@ class HttpPusher(object):
}
try:
resp = yield self.http_client.post_json_get_json(self.url, d)
- except:
- logger.exception("Failed to push %s ", self.url)
+ except Exception:
+ logger.warn(
+ "Failed to send badge count to %s",
+ self.name, exc_info=True,
+ )
defer.returnValue(False)
rejected = []
if 'rejected' in resp:
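
[Editor's note: an illustrative push gateway payload produced by the new event_id_only branch above; all values are made up:]

    notification = {
        "notification": {
            "event_id": "$143273582443PhrSn:example.org",
            "room_id": "!roomid:example.org",
            "counts": {"unread": 2},
            "devices": [{
                "app_id": "org.example.app",
                "pushkey": "APA91bE...",
                "pushkey_ts": 1500000000,
                "data": {},
            }],
        },
    }
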
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index b5cd9b426a..9d601208fd 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -13,30 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-from twisted.mail.smtp import sendmail
-
-import email.utils
import email.mime.multipart
-from email.mime.text import MIMEText
+import email.utils
+import logging
+import time
+import urllib
from email.mime.multipart import MIMEMultipart
+from email.mime.text import MIMEText
-from synapse.util.async import concurrently_execute
+import bleach
+import jinja2
+
+from twisted.internet import defer
+from twisted.mail.smtp import sendmail
+
+from synapse.api.constants import EventTypes
+from synapse.api.errors import StoreError
from synapse.push.presentable_names import (
- calculate_room_name, name_from_member_event, descriptor_from_member_events
+ calculate_room_name,
+ descriptor_from_member_events,
+ name_from_member_event,
)
from synapse.types import UserID
-from synapse.api.errors import StoreError
-from synapse.api.constants import EventTypes
+from synapse.util.async import concurrently_execute
from synapse.visibility import filter_events_for_client
-import jinja2
-import bleach
-
-import time
-import urllib
-
-import logging
logger = logging.getLogger(__name__)
@@ -229,7 +230,8 @@ class Mailer(object):
if room_vars['notifs'] and 'messages' in room_vars['notifs'][-1]:
prev_messages = room_vars['notifs'][-1]['messages']
for message in notifvars['messages']:
- pm = filter(lambda pm: pm['id'] == message['id'], prev_messages)
+ pm = list(filter(lambda pm: pm['id'] == message['id'],
+ prev_messages))
if pm:
if not message["is_historical"]:
pm[0]["is_historical"] = False
diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py
index 277da3cd35..eef6e18c2e 100644
--- a/synapse/push/presentable_names.py
+++ b/synapse/push/presentable_names.py
@@ -13,10 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
-import re
import logging
+import re
+
+from twisted.internet import defer
logger = logging.getLogger(__name__)
@@ -113,7 +113,7 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True
# so find out who is in the room that isn't the user.
if "m.room.member" in room_state_bytype_ids:
member_events = yield store.get_events(
- room_state_bytype_ids["m.room.member"].values()
+ list(room_state_bytype_ids["m.room.member"].values())
)
all_members = [
ev for ev in member_events.values()
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 172c27c137..2bd321d530 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,6 +17,8 @@
import logging
import re
+from six import string_types
+
from synapse.types import UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches.lrucache import LruCache
@@ -29,6 +32,21 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
def _room_member_count(ev, condition, room_member_count):
+ return _test_ineq_condition(condition, room_member_count)
+
+
+def _sender_notification_permission(ev, condition, sender_power_level, power_levels):
+ notif_level_key = condition.get('key')
+ if notif_level_key is None:
+ return False
+
+ notif_levels = power_levels.get('notifications', {})
+ room_notif_level = notif_levels.get(notif_level_key, 50)
+
+ return sender_power_level >= room_notif_level
+
+
+def _test_ineq_condition(condition, number):
if 'is' not in condition:
return False
m = INEQUALITY_EXPR.match(condition['is'])
@@ -41,15 +59,15 @@ def _room_member_count(ev, condition, room_member_count):
rhs = int(rhs)
if ineq == '' or ineq == '==':
- return room_member_count == rhs
+ return number == rhs
elif ineq == '<':
- return room_member_count < rhs
+ return number < rhs
elif ineq == '>':
- return room_member_count > rhs
+ return number > rhs
elif ineq == '>=':
- return room_member_count >= rhs
+ return number >= rhs
elif ineq == '<=':
- return room_member_count <= rhs
+ return number <= rhs
else:
return False
@@ -65,9 +83,11 @@ def tweaks_for_actions(actions):
class PushRuleEvaluatorForEvent(object):
- def __init__(self, event, room_member_count):
+ def __init__(self, event, room_member_count, sender_power_level, power_levels):
self._event = event
self._room_member_count = room_member_count
+ self._sender_power_level = sender_power_level
+ self._power_levels = power_levels
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event)
@@ -81,6 +101,10 @@ class PushRuleEvaluatorForEvent(object):
return _room_member_count(
self._event, condition, self._room_member_count
)
+ elif condition['kind'] == 'sender_notification_permission':
+ return _sender_notification_permission(
+ self._event, condition, self._sender_power_level, self._power_levels,
+ )
else:
return True
@@ -128,7 +152,7 @@ class PushRuleEvaluatorForEvent(object):
# Caches (glob, word_boundary) -> regex for push. See _glob_matches
regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR)
-register_cache("regex_push_cache", regex_cache)
+register_cache("cache", "regex_push_cache", regex_cache)
def _glob_matches(glob, value, word_boundary=False):
@@ -183,7 +207,7 @@ def _glob_to_re(glob, word_boundary):
r,
)
if word_boundary:
- r = r"\b%s\b" % (r,)
+ r = _re_word_boundary(r)
return re.compile(r, flags=re.IGNORECASE)
else:
@@ -192,7 +216,7 @@ def _glob_to_re(glob, word_boundary):
return re.compile(r, flags=re.IGNORECASE)
elif word_boundary:
r = re.escape(glob)
- r = r"\b%s\b" % (r,)
+ r = _re_word_boundary(r)
return re.compile(r, flags=re.IGNORECASE)
else:
@@ -200,11 +224,23 @@ def _glob_to_re(glob, word_boundary):
return re.compile(r, flags=re.IGNORECASE)
+def _re_word_boundary(r):
+ """
+ Adds word boundary characters to the start and end of an
+ expression to require that the match occur as a whole word,
+ but does so respecting the fact that strings starting or ending
+ with non-word characters will change word boundaries.
+ """
+ # we can't use \b as it chokes on unicode. however \W seems to be okay
+ # as shorthand for [^0-9A-Za-z_].
+ return r"(^|\W)%s(\W|$)" % (r,)
+
+
def _flatten_dict(d, prefix=[], result=None):
if result is None:
result = {}
for key, value in d.items():
- if isinstance(value, basestring):
+ if isinstance(value, string_types):
result[".".join(prefix + [key])] = value.lower()
elif hasattr(value, "items"):
_flatten_dict(value, prefix=(prefix + [key]), result=result)
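
[Editor's note: a quick check of the \b -> (^|\W)...(\W|$) change above. The rewritten boundary still requires a whole-word match but copes with patterns that start or end with non-word characters such as "@":]

    import re

    def _re_word_boundary(r):
        return r"(^|\W)%s(\W|$)" % (r,)

    pat = re.compile(_re_word_boundary(re.escape("@room")), flags=re.IGNORECASE)
    assert pat.search("hey @room, meeting now")
    assert not pat.search("user@roomservice.example")
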
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 6835f54e97..8049c298c2 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -14,9 +14,8 @@
# limitations under the License.
from twisted.internet import defer
-from synapse.push.presentable_names import (
- calculate_room_name, name_from_member_event
-)
+
+from synapse.push.presentable_names import calculate_room_name, name_from_member_event
@defer.inlineCallbacks
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index 491f27bded..fcee6d9d7e 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from httppusher import HttpPusher
-
import logging
+
+from .httppusher import HttpPusher
+
logger = logging.getLogger(__name__)
# We try importing this if we can (it will fail if we don't
@@ -27,7 +28,7 @@ logger = logging.getLogger(__name__)
try:
from synapse.push.emailpusher import EmailPusher
from synapse.push.mailer import Mailer, load_jinja2_templates
-except:
+except Exception:
pass
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 43cb6e9c01..36bb5bbc65 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -14,13 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
-from .pusher import PusherFactory
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-from synapse.util.async import run_on_reactor
-
-import logging
+from synapse.push.pusher import PusherFactory
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
@@ -103,23 +102,28 @@ class PusherPool:
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
- def remove_pushers_by_user(self, user_id, except_access_token_id=None):
- all = yield self.store.get_all_pushers()
- logger.info(
- "Removing all pushers for user %s except access tokens id %r",
- user_id, except_access_token_id
- )
- for p in all:
- if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
+ def remove_pushers_by_access_token(self, user_id, access_tokens):
+ """Remove the pushers for a given user corresponding to a set of
+ access_tokens.
+
+ Args:
+ user_id (str): user to remove pushers for
+ access_tokens (Iterable[int]): access token *ids* to remove pushers
+ for
+ """
+ tokens = set(access_tokens)
+ for p in (yield self.store.get_pushers_by_user_id(user_id)):
+ if p['access_token'] in tokens:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name']
)
- yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
+ yield self.remove_pusher(
+ p['app_id'], p['pushkey'], p['user_name'],
+ )
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_id, max_stream_id):
- yield run_on_reactor()
try:
users_affected = yield self.store.get_push_action_users_in_range(
min_stream_id, max_stream_id
@@ -131,18 +135,20 @@ class PusherPool:
if u in self.pushers:
for p in self.pushers[u].values():
deferreds.append(
- preserve_fn(p.on_new_notifications)(
- min_stream_id, max_stream_id
+ run_in_background(
+ p.on_new_notifications,
+ min_stream_id, max_stream_id,
)
)
- yield preserve_context_over_deferred(defer.gatherResults(deferreds))
- except:
+ yield make_deferred_yieldable(
+ defer.gatherResults(deferreds, consumeErrors=True),
+ )
+ except Exception:
logger.exception("Exception in pusher on_new_notifications")
@defer.inlineCallbacks
def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
- yield run_on_reactor()
try:
# Need to subtract 1 from the minimum because the lower bound here
# is not inclusive
@@ -158,11 +164,16 @@ class PusherPool:
if u in self.pushers:
for p in self.pushers[u].values():
deferreds.append(
- preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
+ run_in_background(
+ p.on_new_receipts,
+ min_stream_id, max_stream_id,
+ )
)
- yield preserve_context_over_deferred(defer.gatherResults(deferreds))
- except:
+ yield make_deferred_yieldable(
+ defer.gatherResults(deferreds, consumeErrors=True),
+ )
+ except Exception:
logger.exception("Exception in pusher on_new_receipts")
@defer.inlineCallbacks
@@ -188,7 +199,7 @@ class PusherPool:
for pusherdict in pushers:
try:
p = self.pusher_factory.create_pusher(pusherdict)
- except:
+ except Exception:
logger.exception("Couldn't start a pusher: caught Exception")
continue
if p:
@@ -201,7 +212,7 @@ class PusherPool:
if appid_pushkey in byuser:
byuser[appid_pushkey].on_stop()
byuser[appid_pushkey] = p
- preserve_fn(p.on_started)()
+ run_in_background(p.on_started)
logger.info("Started pushers")
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index ed7f1c89ad..987eec3ef2 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -1,5 +1,6 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,36 +19,52 @@ from distutils.version import LooseVersion
logger = logging.getLogger(__name__)
+# this dict maps from python package name to a list of modules we expect it to
+# provide.
+#
+# the key is a "requirement specifier", as used as a parameter to `pip
+# install`[1], or an `install_requires` argument to `setuptools.setup` [2].
+#
+# the value is a sequence of strings; each entry should be the name of the
+# python module, optionally followed by a version assertion which can be either
+# ">=<ver>" or "==<ver>".
+#
+# [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers.
+# [2] https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-dependencies
REQUIREMENTS = {
"jsonschema>=2.5.1": ["jsonschema>=2.5.1"],
"frozendict>=0.4": ["frozendict"],
"unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
- "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
+ "canonicaljson>=1.1.3": ["canonicaljson>=1.1.3"],
"signedjson>=1.0.0": ["signedjson>=1.0.0"],
- "pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"],
+ "pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"Twisted>=16.0.0": ["twisted>=16.0.0"],
- "pyopenssl>=0.14": ["OpenSSL>=0.14"],
+
+ # We use crypto.get_elliptic_curve which is only supported in >=0.15
+ "pyopenssl>=0.15": ["OpenSSL>=0.15"],
+
"pyyaml": ["yaml"],
"pyasn1": ["pyasn1"],
"daemonize": ["daemonize"],
- "py-bcrypt": ["bcrypt"],
+ "bcrypt": ["bcrypt>=3.1.0"],
"pillow": ["PIL"],
"pydenticon": ["pydenticon"],
- "ujson": ["ujson"],
- "blist": ["blist"],
- "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
+ "sortedcontainers": ["sortedcontainers"],
+ "pysaml2>=3.0.0": ["saml2>=3.0.0"],
"pymacaroons-pynacl": ["pymacaroons"],
"msgpack-python>=0.3.0": ["msgpack"],
"phonenumbers>=8.2.0": ["phonenumbers"],
+ "six": ["six"],
+ "prometheus_client": ["prometheus_client"],
+ "attrs": ["attr"],
+ "netaddr>=0.7.18": ["netaddr"],
}
+
CONDITIONAL_REQUIREMENTS = {
"web_client": {
"matrix_angular_sdk>=0.6.8": ["syweb>=0.6.8"],
},
- "preview_url": {
- "netaddr>=0.7.18": ["netaddr"],
- },
"email.enable_notifs": {
"Jinja2>=2.8": ["Jinja2>=2.8"],
"bleach>=1.4.2": ["bleach>=1.4.2"],
@@ -58,6 +75,9 @@ CONDITIONAL_REQUIREMENTS = {
"psutil": {
"psutil>=2.0.0": ["psutil>=2.0.0"],
},
+ "affinity": {
+ "affinity": ["affinity"],
+ },
}
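
[Editor's note: a loose sketch of how an entry in REQUIREMENTS might be verified, mirroring the comment above; it handles only the ">=" form and assumes the module exposes __version__ (check_requirements in this file is the real implementation):]

    import importlib
    from distutils.version import LooseVersion

    def check_one(module_assertion):
        name, _, version = module_assertion.partition(">=")
        module = importlib.import_module(name)
        if version:
            assert LooseVersion(module.__version__) >= LooseVersion(version)

    check_one("jsonschema>=2.5.1")
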
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
new file mode 100644
index 0000000000..589ee94c66
--- /dev/null
+++ b/synapse/replication/http/__init__.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.http.server import JsonResource
+from synapse.replication.http import membership, send_event
+
+REPLICATION_PREFIX = "/_synapse/replication"
+
+
+class ReplicationRestResource(JsonResource):
+ def __init__(self, hs):
+ JsonResource.__init__(self, hs, canonical_json=False)
+ self.register_servlets(hs)
+
+ def register_servlets(self, hs):
+ send_event.register_servlets(hs, self)
+ membership.register_servlets(hs, self)
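
[Editor's note: a sketch of how a master process might mount the new replication resource; the wiring below is an assumption (the real hook-up lives in the homeserver listener code), but the path matches REPLICATION_PREFIX above:]

    from twisted.web.resource import Resource

    from synapse.replication.http import ReplicationRestResource

    def build_resource_tree(hs):
        root = Resource()
        synapse_res = Resource()
        synapse_res.putChild(b"replication", ReplicationRestResource(hs))
        root.putChild(b"_synapse", synapse_res)
        return root
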
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
new file mode 100644
index 0000000000..6bfc8a5b89
--- /dev/null
+++ b/synapse/replication/http/membership.py
@@ -0,0 +1,334 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+
+from twisted.internet import defer
+
+from synapse.api.errors import MatrixCodeMessageException, SynapseError
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.types import Requester, UserID
+from synapse.util.distributor import user_joined_room, user_left_room
+
+logger = logging.getLogger(__name__)
+
+
+@defer.inlineCallbacks
+def remote_join(client, host, port, requester, remote_room_hosts,
+ room_id, user_id, content):
+ """Ask the master to do a remote join for the given user to the given room
+
+ Args:
+ client (SimpleHttpClient)
+ host (str): host of master
+ port (int): port on master listening for HTTP replication
+ requester (Requester)
+ remote_room_hosts (list[str]): Servers to try and join via
+ room_id (str)
+ user_id (str)
+ content (dict): The event content to use for the join event
+
+ Returns:
+ Deferred
+ """
+ uri = "http://%s:%s/_synapse/replication/remote_join" % (host, port)
+
+ payload = {
+ "requester": requester.serialize(),
+ "remote_room_hosts": remote_room_hosts,
+ "room_id": room_id,
+ "user_id": user_id,
+ "content": content,
+ }
+
+ try:
+ result = yield client.post_json_get_json(uri, payload)
+ except MatrixCodeMessageException as e:
+ # We convert to SynapseError as we know that it was a SynapseError
+ # on the master process that we should send to the client. (And
+ # importantly, not stack traces everywhere)
+ raise SynapseError(e.code, e.msg, e.errcode)
+ defer.returnValue(result)
+
+
+@defer.inlineCallbacks
+def remote_reject_invite(client, host, port, requester, remote_room_hosts,
+ room_id, user_id):
+ """Ask master to reject the invite for the user and room.
+
+ Args:
+ client (SimpleHttpClient)
+ host (str): host of master
+ port (int): port on master listening for HTTP replication
+ requester (Requester)
+ remote_room_hosts (list[str]): Servers to try and reject via
+ room_id (str)
+ user_id (str)
+
+ Returns:
+ Deferred
+ """
+ uri = "http://%s:%s/_synapse/replication/remote_reject_invite" % (host, port)
+
+ payload = {
+ "requester": requester.serialize(),
+ "remote_room_hosts": remote_room_hosts,
+ "room_id": room_id,
+ "user_id": user_id,
+ }
+
+ try:
+ result = yield client.post_json_get_json(uri, payload)
+ except MatrixCodeMessageException as e:
+ # We convert to SynapseError as we know that it was a SynapseError
+ # on the master process that we should send to the client. (And
+ # importantly, not stack traces everywhere)
+ raise SynapseError(e.code, e.msg, e.errcode)
+ defer.returnValue(result)
+
+
+@defer.inlineCallbacks
+def get_or_register_3pid_guest(client, host, port, requester,
+ medium, address, inviter_user_id):
+ """Ask the master to get/create a guest account for given 3PID.
+
+ Args:
+ client (SimpleHttpClient)
+ host (str): host of master
+ port (int): port on master listening for HTTP replication
+ requester (Requester)
+ medium (str)
+ address (str)
+ inviter_user_id (str): The user ID who is trying to invite the
+ 3PID
+
+ Returns:
+ Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
+ 3PID guest account.
+ """
+
+ uri = "http://%s:%s/_synapse/replication/get_or_register_3pid_guest" % (host, port)
+
+ payload = {
+ "requester": requester.serialize(),
+ "medium": medium,
+ "address": address,
+ "inviter_user_id": inviter_user_id,
+ }
+
+ try:
+ result = yield client.post_json_get_json(uri, payload)
+ except MatrixCodeMessageException as e:
+ # We convert to SynapseError as we know that it was a SynapseError
+ # on the master process that we should send to the client. (And
+ # importantly, not stack traces everywhere)
+ raise SynapseError(e.code, e.msg, e.errcode)
+ defer.returnValue(result)
+
+
+@defer.inlineCallbacks
+def notify_user_membership_change(client, host, port, user_id, room_id, change):
+ """Notify master that a user has joined or left the room
+
+ Args:
+ client (SimpleHttpClient)
+ host (str): host of master
+ port (int): port on master listening for HTTP replication.
+ user_id (str)
+ room_id (str)
+ change (str): Either "join" or "left"
+
+ Returns:
+ Deferred
+ """
+ assert change in ("joined", "left")
+
+ uri = "http://%s:%s/_synapse/replication/user_%s_room" % (host, port, change)
+
+ payload = {
+ "user_id": user_id,
+ "room_id": room_id,
+ }
+
+ try:
+ result = yield client.post_json_get_json(uri, payload)
+ except MatrixCodeMessageException as e:
+ # We convert to SynapseError as we know that it was a SynapseError
+ # on the master process that we should send to the client. (And
+ # importantly, not stack traces everywhere)
+ raise SynapseError(e.code, e.msg, e.errcode)
+ defer.returnValue(result)
+
+
+class ReplicationRemoteJoinRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_synapse/replication/remote_join$")]
+
+ def __init__(self, hs):
+ super(ReplicationRemoteJoinRestServlet, self).__init__()
+
+ self.federation_handler = hs.get_handlers().federation_handler
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ content = parse_json_object_from_request(request)
+
+ remote_room_hosts = content["remote_room_hosts"]
+ room_id = content["room_id"]
+ user_id = content["user_id"]
+ event_content = content["content"]
+
+ requester = Requester.deserialize(self.store, content["requester"])
+
+ if requester.user:
+ request.authenticated_entity = requester.user.to_string()
+
+ logger.info(
+ "remote_join: %s into room: %s",
+ user_id, room_id,
+ )
+
+ yield self.federation_handler.do_invite_join(
+ remote_room_hosts,
+ room_id,
+ user_id,
+ event_content,
+ )
+
+ defer.returnValue((200, {}))
+
+
+class ReplicationRemoteRejectInviteRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_synapse/replication/remote_reject_invite$")]
+
+ def __init__(self, hs):
+ super(ReplicationRemoteRejectInviteRestServlet, self).__init__()
+
+ self.federation_handler = hs.get_handlers().federation_handler
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ content = parse_json_object_from_request(request)
+
+ remote_room_hosts = content["remote_room_hosts"]
+ room_id = content["room_id"]
+ user_id = content["user_id"]
+
+ requester = Requester.deserialize(self.store, content["requester"])
+
+ if requester.user:
+ request.authenticated_entity = requester.user.to_string()
+
+ logger.info(
+ "remote_reject_invite: %s out of room: %s",
+ user_id, room_id,
+ )
+
+ try:
+ event = yield self.federation_handler.do_remotely_reject_invite(
+ remote_room_hosts,
+ room_id,
+ user_id,
+ )
+ ret = event.get_pdu_json()
+ except Exception as e:
+ # if we were unable to reject the invite, just mark it as
+ # rejected on our end and plough ahead.
+ #
+ # The 'except' clause is very broad, but we need to
+ # capture everything from DNS failures upwards
+ #
+ logger.warn("Failed to reject invite: %s", e)
+
+ yield self.store.locally_reject_invite(
+ user_id, room_id
+ )
+ ret = {}
+
+ defer.returnValue((200, ret))
+
+
+class ReplicationRegister3PIDGuestRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_synapse/replication/get_or_register_3pid_guest$")]
+
+ def __init__(self, hs):
+ super(ReplicationRegister3PIDGuestRestServlet, self).__init__()
+
+ self.registration_handler = hs.get_handlers().registration_handler
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ content = parse_json_object_from_request(request)
+
+ medium = content["medium"]
+ address = content["address"]
+ inviter_user_id = content["inviter_user_id"]
+
+ requester = Requester.deserialize(self.store, content["requester"])
+
+ if requester.user:
+ request.authenticated_entity = requester.user.to_string()
+
+ logger.info("get_or_register_3pid_guest: %r", content)
+
+ ret = yield self.registration_handler.get_or_register_3pid_guest(
+ medium, address, inviter_user_id,
+ )
+
+ defer.returnValue((200, ret))
+
+
+class ReplicationUserJoinedLeftRoomRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_synapse/replication/user_(?P<change>joined|left)_room$")]
+
+ def __init__(self, hs):
+ super(ReplicationUserJoinedLeftRoomRestServlet, self).__init__()
+
+ self.registration_handler = hs.get_handlers().registration_handler
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self.distributor = hs.get_distributor()
+
+ def on_POST(self, request, change):
+ content = parse_json_object_from_request(request)
+
+ user_id = content["user_id"]
+ room_id = content["room_id"]
+
+ logger.info("user membership change: %s in %s", user_id, room_id)
+
+ user = UserID.from_string(user_id)
+
+ if change == "joined":
+ user_joined_room(self.distributor, user, room_id)
+ elif change == "left":
+ user_left_room(self.distributor, user, room_id)
+ else:
+ raise Exception("Unrecognized change: %r", change)
+
+ return (200, {})
+
+
+def register_servlets(hs, http_server):
+ ReplicationRemoteJoinRestServlet(hs).register(http_server)
+ ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
+ ReplicationRegister3PIDGuestRestServlet(hs).register(http_server)
+ ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)
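
# Illustration (not part of the patch): a minimal worker-side caller for the
# remote_join endpoint above, mirroring the payload keys that on_POST reads.
# `client` is assumed to be a SimpleHttpClient with the post_json_get_json
# helper; imports (defer, MatrixCodeMessageException, SynapseError) as in
# this module.

@defer.inlineCallbacks
def remote_join_on_master(client, host, port, requester,
                          remote_room_hosts, room_id, user_id, content):
    uri = "http://%s:%s/_synapse/replication/remote_join" % (host, port)
    payload = {
        "requester": requester.serialize(),
        "remote_room_hosts": remote_room_hosts,
        "room_id": room_id,
        "user_id": user_id,
        "content": content,
    }
    try:
        result = yield client.post_json_get_json(uri, payload)
    except MatrixCodeMessageException as e:
        # as above: surface the master's SynapseError to the client
        raise SynapseError(e.code, e.msg, e.errcode)
    defer.returnValue(result)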
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
new file mode 100644
index 0000000000..5227bc333d
--- /dev/null
+++ b/synapse/replication/http/send_event.py
@@ -0,0 +1,165 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+
+from twisted.internet import defer
+
+from synapse.api.errors import (
+ CodeMessageException,
+ MatrixCodeMessageException,
+ SynapseError,
+)
+from synapse.events import FrozenEvent
+from synapse.events.snapshot import EventContext
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.types import Requester, UserID
+from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.metrics import Measure
+
+logger = logging.getLogger(__name__)
+
+
+@defer.inlineCallbacks
+def send_event_to_master(clock, store, client, host, port, requester, event, context,
+ ratelimit, extra_users):
+ """Send event to be handled on the master
+
+ Args:
+ clock (synapse.util.Clock)
+ store (DataStore)
+ client (SimpleHttpClient)
+ host (str): host of master
+ port (int): port on master listening for HTTP replication
+ requester (Requester)
+ event (FrozenEvent)
+ context (EventContext)
+ ratelimit (bool)
+ extra_users (list(UserID)): Any extra users to notify about event
+ """
+ uri = "http://%s:%s/_synapse/replication/send_event/%s" % (
+ host, port, event.event_id,
+ )
+
+ serialized_context = yield context.serialize(event, store)
+
+ payload = {
+ "event": event.get_pdu_json(),
+ "internal_metadata": event.internal_metadata.get_dict(),
+ "rejected_reason": event.rejected_reason,
+ "context": serialized_context,
+ "requester": requester.serialize(),
+ "ratelimit": ratelimit,
+ "extra_users": [u.to_string() for u in extra_users],
+ }
+
+ try:
+ # We keep retrying the same request for timeouts. This is so that we
+ # have a good idea that the request has either succeeded or failed on
+ # the master, and so whether we should clean up or not.
+ while True:
+ try:
+ result = yield client.put_json(uri, payload)
+ break
+ except CodeMessageException as e:
+ if e.code != 504:
+ raise
+
+ logger.warn("send_event request timed out")
+
+ # If we timed out we probably don't need to worry about backing
+ # off too much, but let's just wait a little anyway.
+ yield clock.sleep(1)
+ except MatrixCodeMessageException as e:
+ # We convert to SynapseError as we know that it was a SynapseError
+ # on the master process that we should send to the client. (And
+ # importantly, not stack traces everywhere)
+ raise SynapseError(e.code, e.msg, e.errcode)
+ defer.returnValue(result)
+
+
+class ReplicationSendEventRestServlet(RestServlet):
+ """Handles events newly created on workers, including persisting and
+ notifying.
+
+ The API looks like:
+
+ POST /_synapse/replication/send_event/:event_id
+
+ {
+ "event": { .. serialized event .. },
+ "internal_metadata": { .. serialized internal_metadata .. },
+ "rejected_reason": .., // The event.rejected_reason field
+ "context": { .. serialized event context .. },
+ "requester": { .. serialized requester .. },
+ "ratelimit": true,
+ "extra_users": [],
+ }
+ """
+ PATTERNS = [re.compile("^/_synapse/replication/send_event/(?P<event_id>[^/]+)$")]
+
+ def __init__(self, hs):
+ super(ReplicationSendEventRestServlet, self).__init__()
+
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+
+ # The responses are tiny, so we may as well cache them for a while
+ self.response_cache = ResponseCache(hs, "send_event", timeout_ms=30 * 60 * 1000)
+
+ def on_PUT(self, request, event_id):
+ return self.response_cache.wrap(
+ event_id,
+ self._handle_request,
+ request
+ )
+
+ @defer.inlineCallbacks
+ def _handle_request(self, request):
+ with Measure(self.clock, "repl_send_event_parse"):
+ content = parse_json_object_from_request(request)
+
+ event_dict = content["event"]
+ internal_metadata = content["internal_metadata"]
+ rejected_reason = content["rejected_reason"]
+ event = FrozenEvent(event_dict, internal_metadata, rejected_reason)
+
+ requester = Requester.deserialize(self.store, content["requester"])
+ context = yield EventContext.deserialize(self.store, content["context"])
+
+ ratelimit = content["ratelimit"]
+ extra_users = [UserID.from_string(u) for u in content["extra_users"]]
+
+ if requester.user:
+ request.authenticated_entity = requester.user.to_string()
+
+ logger.info(
+ "Got event to send with ID: %s into room: %s",
+ event.event_id, event.room_id,
+ )
+
+ yield self.event_creation_handler.persist_and_notify_client_event(
+ requester, event, context,
+ ratelimit=ratelimit,
+ extra_users=extra_users,
+ )
+
+ defer.returnValue((200, {}))
+
+
+def register_servlets(hs, http_server):
+ ReplicationSendEventRestServlet(hs).register(http_server)
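
# Illustration (not part of the patch): ResponseCache.wrap deduplicates by
# key, so a retried PUT for the same event_id (e.g. after a 504 in
# send_event_to_master above) reuses the in-flight or cached result rather
# than persisting the event twice. A sketch, where `hs` is a HomeServer and
# `do_work` is a hypothetical non-idempotent callback:

cache = ResponseCache(hs, "example", timeout_ms=30 * 60 * 1000)

d1 = cache.wrap("key1", do_work, "key1")
d2 = cache.wrap("key1", do_work, "key1")  # same key: do_work runs only once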
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index b962641166..3f7be74e02 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -13,19 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from synapse.storage._base import SQLBaseStore
from synapse.storage.engines import PostgresEngine
from ._slaved_id_tracker import SlavedIdTracker
-import logging
-
logger = logging.getLogger(__name__)
class BaseSlavedStore(SQLBaseStore):
def __init__(self, db_conn, hs):
- super(BaseSlavedStore, self).__init__(hs)
+ super(BaseSlavedStore, self).__init__(db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker(
db_conn, "cache_invalidation_stream", "stream_id",
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index efbd87918e..d9ba6d69b1 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,50 +14,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
-from synapse.storage import DataStore
-from synapse.storage.account_data import AccountDataStore
-from synapse.storage.tags import TagsStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.storage.account_data import AccountDataWorkerStore
+from synapse.storage.tags import TagsWorkerStore
-class SlavedAccountDataStore(BaseSlavedStore):
+class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
- super(SlavedAccountDataStore, self).__init__(db_conn, hs)
self._account_data_id_gen = SlavedIdTracker(
db_conn, "account_data_max_stream_id", "stream_id",
)
- self._account_data_stream_cache = StreamChangeCache(
- "AccountDataAndTagsChangeCache",
- self._account_data_id_gen.get_current_token(),
- )
-
- get_account_data_for_user = (
- AccountDataStore.__dict__["get_account_data_for_user"]
- )
-
- get_global_account_data_by_type_for_users = (
- AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
- )
- get_global_account_data_by_type_for_user = (
- AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
- )
-
- get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]
- get_tags_for_room = (
- DataStore.get_tags_for_room.__func__
- )
- get_account_data_for_room = (
- DataStore.get_account_data_for_room.__func__
- )
-
- get_updated_tags = DataStore.get_updated_tags.__func__
- get_updated_account_data_for_user = (
- DataStore.get_updated_account_data_for_user.__func__
- )
+ super(SlavedAccountDataStore, self).__init__(db_conn, hs)
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()
@@ -85,6 +56,10 @@ class SlavedAccountDataStore(BaseSlavedStore):
(row.data_type, row.user_id,)
)
self.get_account_data_for_user.invalidate((row.user_id,))
+ self.get_account_data_for_room.invalidate((row.user_id, row.room_id,))
+ self.get_account_data_for_room_and_type.invalidate(
+ (row.user_id, row.room_id, row.data_type,),
+ )
self._account_data_stream_cache.entity_has_changed(
row.user_id, token
)
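
# Illustration (not part of the patch): the worker-store pattern the diffs in
# this directory converge on. Read paths move onto *WorkerStore mixins shared
# with the master's DataStore, replacing the old trick of lifting methods off
# DataStore.__dict__. Hypothetical names throughout:

class FooWorkerStore(SQLBaseStore):
    def get_foo(self, foo_id):
        # shared read-only query, usable by master and workers alike
        pass

class SlavedFooStore(FooWorkerStore, BaseSlavedStore):
    def __init__(self, db_conn, hs):
        # id trackers the mixin's __init__ may rely on are set up first,
        # mirroring SlavedAccountDataStore above
        self._foo_id_gen = SlavedIdTracker(db_conn, "foos", "stream_id")
        super(SlavedFooStore, self).__init__(db_conn, hs)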
diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py
index 0d3f31a50c..b53a4c6bd1 100644
--- a/synapse/replication/slave/storage/appservice.py
+++ b/synapse/replication/slave/storage/appservice.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,33 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
-from synapse.storage import DataStore
-from synapse.config.appservice import load_appservices
-from synapse.storage.appservice import _make_exclusive_regex
+from synapse.storage.appservice import (
+ ApplicationServiceTransactionWorkerStore,
+ ApplicationServiceWorkerStore,
+)
-class SlavedApplicationServiceStore(BaseSlavedStore):
- def __init__(self, db_conn, hs):
- super(SlavedApplicationServiceStore, self).__init__(db_conn, hs)
- self.services_cache = load_appservices(
- hs.config.server_name,
- hs.config.app_service_config_files
- )
- self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
-
- get_app_service_by_token = DataStore.get_app_service_by_token.__func__
- get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
- get_app_services = DataStore.get_app_services.__func__
- get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__
- create_appservice_txn = DataStore.create_appservice_txn.__func__
- get_appservices_by_state = DataStore.get_appservices_by_state.__func__
- get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__
- _get_last_txn = DataStore._get_last_txn.__func__
- complete_appservice_txn = DataStore.complete_appservice_txn.__func__
- get_appservice_state = DataStore.get_appservice_state.__func__
- set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__
- set_appservice_state = DataStore.set_appservice_state.__func__
- get_if_app_services_interested_in_user = (
- DataStore.get_if_app_services_interested_in_user.__func__
- )
+class SlavedApplicationServiceStore(ApplicationServiceTransactionWorkerStore,
+ ApplicationServiceWorkerStore):
+ pass
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 65250285e8..60641f1a49 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
from synapse.storage.client_ips import LAST_SEEN_GRANULARITY
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.descriptors import Cache
+from ._base import BaseSlavedStore
+
class SlavedClientIpStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
@@ -29,9 +30,8 @@ class SlavedClientIpStore(BaseSlavedStore):
max_entries=50000 * CACHE_SIZE_FACTOR,
)
- def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
+ def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec())
- user_id = user.to_string()
key = (user_id, access_token, ip)
try:
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 6f3fb64770..87eaa53004 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
class SlavedDeviceInboxStore(BaseSlavedStore):
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 7687867aee..8206a988f7 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.storage.end_to_end_keys import EndToEndKeyStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
+
class SlavedDeviceStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py
index 7301d885f2..1d1d48709a 100644
--- a/synapse/replication/slave/storage/directory.py
+++ b/synapse/replication/slave/storage/directory.py
@@ -13,11 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.storage.directory import DirectoryWorkerStore
+
from ._base import BaseSlavedStore
-from synapse.storage.directory import DirectoryStore
-class DirectoryStore(BaseSlavedStore):
- get_aliases_for_room = DirectoryStore.__dict__[
- "get_aliases_for_room"
- ]
+class DirectoryStore(DirectoryWorkerStore, BaseSlavedStore):
+ pass
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 94ebbffc1b..bdb5eee4af 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,20 +13,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
+import logging
from synapse.api.constants import EventTypes
-from synapse.storage import DataStore
-from synapse.storage.roommember import RoomMemberStore
-from synapse.storage.event_federation import EventFederationStore
-from synapse.storage.event_push_actions import EventPushActionsStore
-from synapse.storage.state import StateStore
-from synapse.storage.stream import StreamStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
-
-import logging
+from synapse.storage.event_federation import EventFederationWorkerStore
+from synapse.storage.event_push_actions import EventPushActionsWorkerStore
+from synapse.storage.events_worker import EventsWorkerStore
+from synapse.storage.roommember import RoomMemberWorkerStore
+from synapse.storage.signatures import SignatureWorkerStore
+from synapse.storage.state import StateGroupWorkerStore
+from synapse.storage.stream import StreamWorkerStore
+from synapse.storage.user_erasure_store import UserErasureWorkerStore
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
logger = logging.getLogger(__name__)
@@ -39,163 +40,34 @@ logger = logging.getLogger(__name__)
# the method descriptor on the DataStore and chuck them into our class.
-class SlavedEventStore(BaseSlavedStore):
+class SlavedEventStore(EventFederationWorkerStore,
+ RoomMemberWorkerStore,
+ EventPushActionsWorkerStore,
+ StreamWorkerStore,
+ EventsWorkerStore,
+ StateGroupWorkerStore,
+ SignatureWorkerStore,
+ UserErasureWorkerStore,
+ BaseSlavedStore):
def __init__(self, db_conn, hs):
- super(SlavedEventStore, self).__init__(db_conn, hs)
self._stream_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering",
)
self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1
)
- events_max = self._stream_id_gen.get_current_token()
- event_cache_prefill, min_event_val = self._get_cache_dict(
- db_conn, "events",
- entity_column="room_id",
- stream_column="stream_ordering",
- max_value=events_max,
- )
- self._events_stream_cache = StreamChangeCache(
- "EventsRoomStreamChangeCache", min_event_val,
- prefilled_cache=event_cache_prefill,
- )
- self._membership_stream_cache = StreamChangeCache(
- "MembershipStreamChangeCache", events_max,
- )
- self.stream_ordering_month_ago = 0
- self._stream_order_on_start = self.get_room_max_stream_ordering()
+ super(SlavedEventStore, self).__init__(db_conn, hs)
# Cached functions can't be accessed through a class instance so we need
# to reach inside the __dict__ to extract them.
- get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
- get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
- get_hosts_in_room = RoomMemberStore.__dict__["get_hosts_in_room"]
- get_users_who_share_room_with_user = (
- RoomMemberStore.__dict__["get_users_who_share_room_with_user"]
- )
- get_latest_event_ids_in_room = EventFederationStore.__dict__[
- "get_latest_event_ids_in_room"
- ]
- get_invited_rooms_for_user = RoomMemberStore.__dict__[
- "get_invited_rooms_for_user"
- ]
- get_unread_event_push_actions_by_room_for_user = (
- EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"]
- )
- _get_unread_counts_by_receipt_txn = (
- DataStore._get_unread_counts_by_receipt_txn.__func__
- )
- _get_unread_counts_by_pos_txn = (
- DataStore._get_unread_counts_by_pos_txn.__func__
- )
- _get_state_group_for_events = (
- StateStore.__dict__["_get_state_group_for_events"]
- )
- _get_state_group_for_event = (
- StateStore.__dict__["_get_state_group_for_event"]
- )
- _get_state_groups_from_groups = (
- StateStore.__dict__["_get_state_groups_from_groups"]
- )
- _get_state_groups_from_groups_txn = (
- DataStore._get_state_groups_from_groups_txn.__func__
- )
- get_recent_event_ids_for_room = (
- StreamStore.__dict__["get_recent_event_ids_for_room"]
- )
- get_current_state_ids = (
- StateStore.__dict__["get_current_state_ids"]
- )
- get_state_group_delta = StateStore.__dict__["get_state_group_delta"]
- _get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"]
- has_room_changed_since = DataStore.has_room_changed_since.__func__
-
- get_unread_push_actions_for_user_in_range_for_http = (
- DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__
- )
- get_unread_push_actions_for_user_in_range_for_email = (
- DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__
- )
- get_push_action_users_in_range = (
- DataStore.get_push_action_users_in_range.__func__
- )
- get_event = DataStore.get_event.__func__
- get_events = DataStore.get_events.__func__
- get_rooms_for_user_where_membership_is = (
- DataStore.get_rooms_for_user_where_membership_is.__func__
- )
- get_membership_changes_for_user = (
- DataStore.get_membership_changes_for_user.__func__
- )
- get_room_events_max_id = DataStore.get_room_events_max_id.__func__
- get_room_events_stream_for_room = (
- DataStore.get_room_events_stream_for_room.__func__
- )
- get_events_around = DataStore.get_events_around.__func__
- get_state_for_event = DataStore.get_state_for_event.__func__
- get_state_for_events = DataStore.get_state_for_events.__func__
- get_state_groups = DataStore.get_state_groups.__func__
- get_state_groups_ids = DataStore.get_state_groups_ids.__func__
- get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__
- get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__
- get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__
- get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
- _get_joined_users_from_context = (
- RoomMemberStore.__dict__["_get_joined_users_from_context"]
- )
-
- get_joined_hosts = DataStore.get_joined_hosts.__func__
- _get_joined_hosts = RoomMemberStore.__dict__["_get_joined_hosts"]
-
- get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
- get_room_events_stream_for_rooms = (
- DataStore.get_room_events_stream_for_rooms.__func__
- )
- is_host_joined = RoomMemberStore.__dict__["is_host_joined"]
- get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
-
- _set_before_and_after = staticmethod(DataStore._set_before_and_after)
-
- _get_events = DataStore._get_events.__func__
- _get_events_from_cache = DataStore._get_events_from_cache.__func__
-
- _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__
- _enqueue_events = DataStore._enqueue_events.__func__
- _do_fetch = DataStore._do_fetch.__func__
- _fetch_event_rows = DataStore._fetch_event_rows.__func__
- _get_event_from_row = DataStore._get_event_from_row.__func__
- _get_rooms_for_user_where_membership_is_txn = (
- DataStore._get_rooms_for_user_where_membership_is_txn.__func__
- )
- _get_state_for_groups = DataStore._get_state_for_groups.__func__
- _get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__
- _get_events_around_txn = DataStore._get_events_around_txn.__func__
- _get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__
-
- get_backfill_events = DataStore.get_backfill_events.__func__
- _get_backfill_events = DataStore._get_backfill_events.__func__
- get_missing_events = DataStore.get_missing_events.__func__
- _get_missing_events = DataStore._get_missing_events.__func__
-
- get_auth_chain = DataStore.get_auth_chain.__func__
- get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__
- _get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__
-
- get_room_max_stream_ordering = DataStore.get_room_max_stream_ordering.__func__
-
- get_forward_extremeties_for_room = (
- DataStore.get_forward_extremeties_for_room.__func__
- )
- _get_forward_extremeties_for_room = (
- EventFederationStore.__dict__["_get_forward_extremeties_for_room"]
- )
-
- get_all_new_events_stream = DataStore.get_all_new_events_stream.__func__
-
- get_federation_out_pos = DataStore.get_federation_out_pos.__func__
- update_federation_out_pos = DataStore.update_federation_out_pos.__func__
+
+ def get_room_max_stream_ordering(self):
+ return self._stream_id_gen.get_current_token()
+
+ def get_room_min_stream_ordering(self):
+ return self._backfill_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions()
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index 819ed62881..456a14cd5c 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
from synapse.storage.filtering import FilteringStore
+from ._base import BaseSlavedStore
+
class SlavedFilteringStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
new file mode 100644
index 0000000000..5777f07c8d
--- /dev/null
+++ b/synapse/replication/slave/storage/groups.py
@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage import DataStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
+
+
+class SlavedGroupServerStore(BaseSlavedStore):
+ def __init__(self, db_conn, hs):
+ super(SlavedGroupServerStore, self).__init__(db_conn, hs)
+
+ self.hs = hs
+
+ self._group_updates_id_gen = SlavedIdTracker(
+ db_conn, "local_group_updates", "stream_id",
+ )
+ self._group_updates_stream_cache = StreamChangeCache(
+ "_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(),
+ )
+
+ get_groups_changes_for_user = DataStore.get_groups_changes_for_user.__func__
+ get_group_stream_token = DataStore.get_group_stream_token.__func__
+ get_all_groups_for_user = DataStore.get_all_groups_for_user.__func__
+
+ def stream_positions(self):
+ result = super(SlavedGroupServerStore, self).stream_positions()
+ result["groups"] = self._group_updates_id_gen.get_current_token()
+ return result
+
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "groups":
+ self._group_updates_id_gen.advance(token)
+ for row in rows:
+ self._group_updates_stream_cache.entity_has_changed(
+ row.user_id, token
+ )
+
+ return super(SlavedGroupServerStore, self).process_replication_rows(
+ stream_name, token, rows
+ )
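
# Illustration (not part of the patch): readers typically consult the stream
# change cache above before touching the database. A sketch, assuming a
# hypothetical from_token taken from a sync request:
#
#     if not self._group_updates_stream_cache.has_entity_changed(
#         user_id, from_token,
#     ):
#         return []  # no group updates for this user since from_token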
diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py
index dd2ae49e48..05ed168463 100644
--- a/synapse/replication/slave/storage/keys.py
+++ b/synapse/replication/slave/storage/keys.py
@@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
from synapse.storage import DataStore
from synapse.storage.keys import KeyStore
+from ._base import BaseSlavedStore
+
class SlavedKeyStore(BaseSlavedStore):
_get_server_verify_key = KeyStore.__dict__[
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index cfb9280181..80b744082a 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
-
-from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.storage import DataStore
from synapse.storage.presence import PresenceStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
class SlavedPresenceStore(BaseSlavedStore):
diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py
new file mode 100644
index 0000000000..46c28d4171
--- /dev/null
+++ b/synapse/replication/slave/storage/profile.py
@@ -0,0 +1,21 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.storage.profile import ProfileWorkerStore
+
+
+class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore):
+ pass
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 83e880fdd2..f0200c1e98 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,31 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .events import SlavedEventStore
+from synapse.storage.push_rule import PushRulesWorkerStore
+
from ._slaved_id_tracker import SlavedIdTracker
-from synapse.storage import DataStore
-from synapse.storage.push_rule import PushRuleStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+from .events import SlavedEventStore
-class SlavedPushRuleStore(SlavedEventStore):
+class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore):
def __init__(self, db_conn, hs):
- super(SlavedPushRuleStore, self).__init__(db_conn, hs)
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id",
)
- self.push_rules_stream_cache = StreamChangeCache(
- "PushRulesStreamChangeCache",
- self._push_rules_stream_id_gen.get_current_token(),
- )
-
- get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
- get_push_rules_enabled_for_user = (
- PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
- )
- have_push_rules_changed_for_user = (
- DataStore.have_push_rules_changed_for_user.__func__
- )
+ super(SlavedPushRuleStore, self).__init__(db_conn, hs)
def get_push_rules_stream_token(self):
return (
@@ -45,6 +33,9 @@ class SlavedPushRuleStore(SlavedEventStore):
self._stream_id_gen.get_current_token(),
)
+ def get_max_push_rules_stream_id(self):
+ return self._push_rules_stream_id_gen.get_current_token()
+
def stream_positions(self):
result = super(SlavedPushRuleStore, self).stream_positions()
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 4e8d68ece9..3b2213c0d4 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,13 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.storage.pusher import PusherWorkerStore
+
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
-from synapse.storage import DataStore
-
-class SlavedPusherStore(BaseSlavedStore):
+class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedPusherStore, self).__init__(db_conn, hs)
@@ -28,13 +29,6 @@ class SlavedPusherStore(BaseSlavedStore):
extra_tables=[("deleted_pushers", "stream_id")],
)
- get_all_pushers = DataStore.get_all_pushers.__func__
- get_pushers_by = DataStore.get_pushers_by.__func__
- get_pushers_by_app_id_and_pushkey = (
- DataStore.get_pushers_by_app_id_and_pushkey.__func__
- )
- _decode_pushers_rows = DataStore._decode_pushers_rows.__func__
-
def stream_positions(self):
result = super(SlavedPusherStore, self).stream_positions()
result["pushers"] = self._pushers_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index b371574ece..ed12342f40 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,13 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.storage.receipts import ReceiptsWorkerStore
+
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
-from synapse.storage import DataStore
-from synapse.storage.receipts import ReceiptsStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
-
# So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the
# DataStore or are cached and don't have cache invalidation logic.
@@ -29,36 +28,19 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
# the method descriptor on the DataStore and chuck them into our class.
-class SlavedReceiptsStore(BaseSlavedStore):
+class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
- super(SlavedReceiptsStore, self).__init__(db_conn, hs)
-
+ # We instantiate this first as the ReceiptsWorkerStore constructor
+ # needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = SlavedIdTracker(
db_conn, "receipts_linearized", "stream_id"
)
- self._receipts_stream_cache = StreamChangeCache(
- "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
- )
-
- get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
- get_linearized_receipts_for_room = (
- ReceiptsStore.__dict__["get_linearized_receipts_for_room"]
- )
- _get_linearized_receipts_for_rooms = (
- ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"]
- )
- get_last_receipt_event_id_for_user = (
- ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"]
- )
-
- get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
- get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
+ super(SlavedReceiptsStore, self).__init__(db_conn, hs)
- get_linearized_receipts_for_rooms = (
- DataStore.get_linearized_receipts_for_rooms.__func__
- )
+ def get_max_receipt_stream_id(self):
+ return self._receipts_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedReceiptsStore, self).stream_positions()
@@ -67,10 +49,12 @@ class SlavedReceiptsStore(BaseSlavedStore):
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))
- self.get_linearized_receipts_for_room.invalidate_many((room_id,))
+ self._get_linearized_receipts_for_room.invalidate_many((room_id,))
self.get_last_receipt_event_id_for_user.invalidate(
(user_id, room_id, receipt_type)
)
+ self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
+ self.get_receipts_for_room.invalidate((room_id, receipt_type))
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "receipts":
diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py
index e27c7332d2..408d91df1c 100644
--- a/synapse/replication/slave/storage/registration.py
+++ b/synapse/replication/slave/storage/registration.py
@@ -13,21 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
-from synapse.storage import DataStore
-from synapse.storage.registration import RegistrationStore
-
+from synapse.storage.registration import RegistrationWorkerStore
-class SlavedRegistrationStore(BaseSlavedStore):
- def __init__(self, db_conn, hs):
- super(SlavedRegistrationStore, self).__init__(db_conn, hs)
+from ._base import BaseSlavedStore
- # TODO: use the cached version and invalidate deleted tokens
- get_user_by_access_token = RegistrationStore.__dict__[
- "get_user_by_access_token"
- ]
- _query_for_auth = DataStore._query_for_auth.__func__
- get_user_by_id = RegistrationStore.__dict__[
- "get_user_by_id"
- ]
+class SlavedRegistrationStore(RegistrationWorkerStore, BaseSlavedStore):
+ pass
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index f510384033..0cb474928c 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -13,33 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.storage.room import RoomWorkerStore
+
from ._base import BaseSlavedStore
-from synapse.storage import DataStore
-from synapse.storage.room import RoomStore
from ._slaved_id_tracker import SlavedIdTracker
-class RoomStore(BaseSlavedStore):
+class RoomStore(RoomWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
super(RoomStore, self).__init__(db_conn, hs)
self._public_room_id_gen = SlavedIdTracker(
db_conn, "public_room_list_stream", "stream_id"
)
- get_public_room_ids = DataStore.get_public_room_ids.__func__
- get_current_public_room_stream_id = (
- DataStore.get_current_public_room_stream_id.__func__
- )
- get_public_room_ids_at_stream_id = (
- RoomStore.__dict__["get_public_room_ids_at_stream_id"]
- )
- get_public_room_ids_at_stream_id_txn = (
- DataStore.get_public_room_ids_at_stream_id_txn.__func__
- )
- get_published_at_stream_id_txn = (
- DataStore.get_published_at_stream_id_txn.__func__
- )
- get_public_room_changes = DataStore.get_public_room_changes.__func__
+ def get_current_public_room_stream_id(self):
+ return self._public_room_id_gen.get_current_token()
def stream_positions(self):
result = super(RoomStore, self).stream_positions()
diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py
index fbb58f35da..9c9a5eadd9 100644
--- a/synapse/replication/slave/storage/transactions.py
+++ b/synapse/replication/slave/storage/transactions.py
@@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
from synapse.storage import DataStore
from synapse.storage.transactions import TransactionStore
+from ._base import BaseSlavedStore
+
class TransactionStore(BaseSlavedStore):
get_destination_retry_timings = TransactionStore.__dict__[
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 6d2513c4e2..e592ab57bf 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -15,17 +15,20 @@
"""A replication client for use by synapse workers.
"""
-from twisted.internet import reactor, defer
+import logging
+
+from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory
from .commands import (
- FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand,
+ FederationAckCommand,
+ InvalidateCacheCommand,
+ RemovePusherCommand,
UserIpCommand,
+ UserSyncCommand,
)
from .protocol import ClientReplicationStreamProtocol
-import logging
-
logger = logging.getLogger(__name__)
@@ -44,7 +47,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
self.server_name = hs.config.server_name
self._clock = hs.get_clock() # As self.clock is defined in super class
- reactor.addSystemEventTrigger("before", "shutdown", self.stopTrying)
+ hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
def startedConnecting(self, connector):
logger.info("Connecting to replication: %r", connector.getDestination())
@@ -95,7 +98,7 @@ class ReplicationClientHandler(object):
factory = ReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
- reactor.connectTCP(host, port, factory)
+ hs.get_reactor().connectTCP(host, port, factory)
def on_rdata(self, stream_name, token, rows):
"""Called when we get new replication data. By default this just pokes
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index a009214e43..f3908df642 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -19,8 +19,14 @@ allowed to be sent by which side.
"""
import logging
-import ujson as json
+import platform
+if platform.python_implementation() == "PyPy":
+ import json
+ _json_encoder = json.JSONEncoder()
+else:
+ import simplejson as json
+ _json_encoder = json.JSONEncoder(namedtuple_as_object=False)
logger = logging.getLogger(__name__)
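
# Illustration (not part of the patch): why namedtuple_as_object=False above.
# simplejson's default would serialize stream-row namedtuples as JSON
# objects, while the wire parsers expect plain arrays. Roughly:
#
#     Row = namedtuple("Row", ("a", "b"))
#     simplejson.JSONEncoder().encode(Row(1, 2))   # -> '{"a": 1, "b": 2}'
#     _json_encoder.encode(Row(1, 2))              # -> '[1, 2]'
#
# (The stdlib json used on PyPy already encodes tuples as arrays.)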
@@ -107,7 +113,7 @@ class RdataCommand(Command):
return " ".join((
self.stream_name,
str(self.token) if self.token is not None else "batch",
- json.dumps(self.row),
+ _json_encoder.encode(self.row),
))
@@ -301,7 +307,9 @@ class InvalidateCacheCommand(Command):
return cls(cache_func, json.loads(keys_json))
def to_line(self):
- return " ".join((self.cache_func, json.dumps(self.keys)))
+ return " ".join((
+ self.cache_func, _json_encoder.encode(self.keys),
+ ))
class UserIpCommand(Command):
@@ -323,14 +331,18 @@ class UserIpCommand(Command):
@classmethod
def from_line(cls, line):
- user_id, access_token, ip, device_id, last_seen, user_agent = line.split(" ", 5)
+ user_id, jsn = line.split(" ", 1)
+
+ access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
- return cls(user_id, access_token, ip, user_agent, device_id, int(last_seen))
+ return cls(
+ user_id, access_token, ip, user_agent, device_id, last_seen
+ )
def to_line(self):
- return " ".join((
- self.user_id, self.access_token, self.ip, self.device_id,
- str(self.last_seen), self.user_agent,
+ return self.user_id + " " + _json_encoder.encode((
+ self.access_token, self.ip, self.user_agent, self.device_id,
+ self.last_seen,
))
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 062272f8dd..dec5ac0913 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -49,32 +49,40 @@ indicate which side is sending, these are *not* included on the wire::
* connection closed by server *
"""
+import fcntl
+import logging
+import struct
+from collections import defaultdict
+
+from six import iteritems, iterkeys
+
+from prometheus_client import Counter
+
from twisted.internet import defer
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
-from commands import (
- COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS,
- ErrorCommand, ServerCommand, RdataCommand, PositionCommand, PingCommand,
- NameCommand, ReplicateCommand, UserSyncCommand, SyncCommand,
-)
-from streams import STREAMS_MAP
-
+from synapse.metrics import LaterGauge
from synapse.util.stringutils import random_string
-from synapse.metrics.metric import CounterMetric
-
-import logging
-import synapse.metrics
-import struct
-import fcntl
-
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-connection_close_counter = metrics.register_counter(
- "close_reason", labels=["reason_type"],
+from .commands import (
+ COMMAND_MAP,
+ VALID_CLIENT_COMMANDS,
+ VALID_SERVER_COMMANDS,
+ ErrorCommand,
+ NameCommand,
+ PingCommand,
+ PositionCommand,
+ RdataCommand,
+ ReplicateCommand,
+ ServerCommand,
+ SyncCommand,
+ UserSyncCommand,
)
+from .streams import STREAMS_MAP
+connection_close_counter = Counter(
+ "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"])
# A list of all connected protocols. This allows us to send metrics about the
# connections.
@@ -136,12 +144,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# The LoopingCall for sending pings.
self._send_ping_loop = None
- self.inbound_commands_counter = CounterMetric(
- "inbound_commands", labels=["command"],
- )
- self.outbound_commands_counter = CounterMetric(
- "outbound_commands", labels=["command"],
- )
+ self.inbound_commands_counter = defaultdict(int)
+ self.outbound_commands_counter = defaultdict(int)
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
@@ -201,7 +205,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.last_received_command = self.clock.time_msec()
- self.inbound_commands_counter.inc(cmd_name)
+ self.inbound_commands_counter[cmd_name] = (
+ self.inbound_commands_counter[cmd_name] + 1)
cmd_cls = COMMAND_MAP[cmd_name]
try:
@@ -244,15 +249,15 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
becoming full.
"""
if self.state == ConnectionStates.CLOSED:
- logger.info("[%s] Not sending, connection closed", self.id())
+ logger.debug("[%s] Not sending, connection closed", self.id())
return
if do_buffer and self.state != ConnectionStates.ESTABLISHED:
self._queue_command(cmd)
return
- self.outbound_commands_counter.inc(cmd.NAME)
-
+ self.outbound_commands_counter[cmd.NAME] = (
+ self.outbound_commands_counter[cmd.NAME] + 1)
string = "%s %s" % (cmd.NAME, cmd.to_line(),)
if "\n" in string:
raise Exception("Unexpected newline in command: %r", string)
@@ -264,7 +269,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def _queue_command(self, cmd):
"""Queue the command until the connection is ready to write to again.
"""
- logger.info("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd)
+ logger.debug("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd)
self.pending_commands.append(cmd)
if len(self.pending_commands) > self.max_line_buffer:
@@ -317,9 +322,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def connectionLost(self, reason):
logger.info("[%s] Replication connection closed: %r", self.id(), reason)
if isinstance(reason, Failure):
- connection_close_counter.inc(reason.type.__name__)
+ connection_close_counter.labels(reason.type.__name__).inc()
else:
- connection_close_counter.inc(reason.__class__.__name__)
+ connection_close_counter.labels(reason.__class__.__name__).inc()
try:
# Remove us from list of connections to be monitored
@@ -392,7 +397,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
if stream_name == "ALL":
# Subscribe to all streams we're publishing to.
- for stream in self.streamer.streams_by_name.iterkeys():
+ for stream in iterkeys(self.streamer.streams_by_name):
self.subscribe_to_stream(stream, token)
else:
self.subscribe_to_stream(stream_name, token)
@@ -498,7 +503,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
BaseReplicationStreamProtocol.connectionMade(self)
# Once we've connected subscribe to the necessary streams
- for stream_name, token in self.handler.get_streams_to_replicate().iteritems():
+ for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
self.replicate(stream_name, token)
# Tell the server if we have any users currently syncing (should only
@@ -517,25 +522,28 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_error("Wrong remote")
def on_RDATA(self, cmd):
+ stream_name = cmd.stream_name
+ inbound_rdata_count.labels(stream_name).inc()
+
try:
- row = STREAMS_MAP[cmd.stream_name].ROW_TYPE(*cmd.row)
+ row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row)
except Exception:
logger.exception(
"[%s] Failed to parse RDATA: %r %r",
- self.id(), cmd.stream_name, cmd.row
+ self.id(), stream_name, cmd.row
)
raise
if cmd.token is None:
# I.e. this is part of a batch of updates for this stream. Batch
# until we get an update for the stream with a non None token
- self.pending_batches.setdefault(cmd.stream_name, []).append(row)
+ self.pending_batches.setdefault(stream_name, []).append(row)
else:
# Check if this is the last of a batch of updates
- rows = self.pending_batches.pop(cmd.stream_name, [])
+ rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
- self.handler.on_rdata(cmd.stream_name, cmd.token, rows)
+ self.handler.on_rdata(stream_name, cmd.token, rows)
def on_POSITION(self, cmd):
self.handler.on_position(cmd.stream_name, cmd.token)
@@ -563,13 +571,13 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# The following simply registers metrics for the replication connections
-metrics.register_callback(
- "pending_commands",
+pending_commands = LaterGauge(
+ "synapse_replication_tcp_protocol_pending_commands",
+ "",
+ ["name", "conn_id"],
lambda: {
- (p.name, p.conn_id): len(p.pending_commands)
- for p in connected_connections
+ (p.name, p.conn_id): len(p.pending_commands) for p in connected_connections
},
- labels=["name", "conn_id"],
)
@@ -580,13 +588,13 @@ def transport_buffer_size(protocol):
return 0
-metrics.register_callback(
- "transport_send_buffer",
+transport_send_buffer = LaterGauge(
+ "synapse_replication_tcp_protocol_transport_send_buffer",
+ "",
+ ["name", "conn_id"],
lambda: {
- (p.name, p.conn_id): transport_buffer_size(p)
- for p in connected_connections
+ (p.name, p.conn_id): transport_buffer_size(p) for p in connected_connections
},
- labels=["name", "conn_id"],
)
@@ -605,42 +613,51 @@ def transport_kernel_read_buffer_size(protocol, read=True):
return 0
-metrics.register_callback(
- "transport_kernel_send_buffer",
+tcp_transport_kernel_send_buffer = LaterGauge(
+ "synapse_replication_tcp_protocol_transport_kernel_send_buffer",
+ "",
+ ["name", "conn_id"],
lambda: {
(p.name, p.conn_id): transport_kernel_read_buffer_size(p, False)
for p in connected_connections
},
- labels=["name", "conn_id"],
)
-metrics.register_callback(
- "transport_kernel_read_buffer",
+tcp_transport_kernel_read_buffer = LaterGauge(
+ "synapse_replication_tcp_protocol_transport_kernel_read_buffer",
+ "",
+ ["name", "conn_id"],
lambda: {
(p.name, p.conn_id): transport_kernel_read_buffer_size(p, True)
for p in connected_connections
},
- labels=["name", "conn_id"],
)
-metrics.register_callback(
- "inbound_commands",
+tcp_inbound_commands = LaterGauge(
+ "synapse_replication_tcp_protocol_inbound_commands",
+ "",
+ ["command", "name", "conn_id"],
lambda: {
(k[0], p.name, p.conn_id): count
for p in connected_connections
- for k, count in p.inbound_commands_counter.counts.iteritems()
+ for k, count in iteritems(p.inbound_commands_counter)
},
- labels=["command", "name", "conn_id"],
)
-metrics.register_callback(
- "outbound_commands",
+tcp_outbound_commands = LaterGauge(
+ "synapse_replication_tcp_protocol_outbound_commands",
+ "",
+ ["command", "name", "conn_id"],
lambda: {
(k[0], p.name, p.conn_id): count
for p in connected_connections
- for k, count in p.outbound_commands_counter.counts.iteritems()
+ for k, count in iteritems(p.outbound_commands_counter)
},
- labels=["command", "name", "conn_id"],
+)
+
+# number of updates received for each RDATA stream
+inbound_rdata_count = Counter(
+ "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
)
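
# Illustration (not part of the patch): the prometheus_client idiom this file
# migrates to. Labelled Counters are incremented per label set; LaterGauge (a
# synapse.metrics helper) instead evaluates a callback at scrape time. The
# Counter side, self-contained:

from prometheus_client import Counter

example_counter = Counter(
    "synapse_example_total", "An example labelled counter", ["kind"],
)
example_counter.labels("timeout").inc()   # one time series per label value
example_counter.labels("reset").inc(2)    # inc() accepts an amount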
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 3ea3ca5a6f..611fb66e1d 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -15,27 +15,29 @@
"""The server side of the replication stream.
"""
-from twisted.internet import defer, reactor
-from twisted.internet.protocol import Factory
+import logging
-from streams import STREAMS_MAP, FederationStream
-from protocol import ServerReplicationStreamProtocol
+from six import itervalues
-from synapse.util.metrics import Measure, measure_func
+from prometheus_client import Counter
-import logging
-import synapse.metrics
+from twisted.internet import defer
+from twisted.internet.protocol import Factory
+from synapse.metrics import LaterGauge
+from synapse.util.metrics import Measure, measure_func
+
+from .protocol import ServerReplicationStreamProtocol
+from .streams import STREAMS_MAP, FederationStream
-metrics = synapse.metrics.get_metrics_for(__name__)
-stream_updates_counter = metrics.register_counter(
- "stream_updates", labels=["stream_name"]
-)
-user_sync_counter = metrics.register_counter("user_sync")
-federation_ack_counter = metrics.register_counter("federation_ack")
-remove_pusher_counter = metrics.register_counter("remove_pusher")
-invalidate_cache_counter = metrics.register_counter("invalidate_cache")
-user_ip_cache_counter = metrics.register_counter("user_ip_cache")
+stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates",
+ "", ["stream_name"])
+user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
+federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
+remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
+invalidate_cache_counter = Counter("synapse_replication_tcp_resource_invalidate_cache",
+ "")
+user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
logger = logging.getLogger(__name__)
@@ -69,33 +71,34 @@ class ReplicationStreamer(object):
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
+ self._server_notices_sender = hs.get_server_notices_sender()
# Current connections.
self.connections = []
- metrics.register_callback("total_connections", lambda: len(self.connections))
+ LaterGauge("synapse_replication_tcp_resource_total_connections", "", [],
+ lambda: len(self.connections))
# List of streams that clients can subscribe to.
# We only support the federation stream if federation sending has been
# disabled on the master.
self.streams = [
- stream(hs) for stream in STREAMS_MAP.itervalues()
+ stream(hs) for stream in itervalues(STREAMS_MAP)
if stream != FederationStream or not hs.config.send_federation
]
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
- metrics.register_callback(
- "connections_per_stream",
+ LaterGauge(
+ "synapse_replication_tcp_resource_connections_per_stream", "",
+ ["stream_name"],
lambda: {
(stream_name,): len([
conn for conn in self.connections
if stream_name in conn.replication_streams
])
for stream_name in self.streams_by_name
- },
- labels=["stream_name"],
- )
+ })
self.federation_sender = None
if not hs.config.send_federation:
@@ -107,7 +110,7 @@ class ReplicationStreamer(object):
self.is_looping = False
self.pending_updates = False
- reactor.addSystemEventTrigger("before", "shutdown", self.on_shutdown)
+ hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown)
def on_shutdown(self):
# close all connections on shutdown
@@ -160,7 +163,11 @@ class ReplicationStreamer(object):
"Getting stream: %s: %s -> %s",
stream.NAME, stream.last_token, stream.upto_token
)
- updates, current_token = yield stream.get_updates()
+ try:
+ updates, current_token = yield stream.get_updates()
+ except Exception:
+ logger.info("Failed to handle stream %s", stream.NAME)
+ raise
logger.debug(
"Sending %d updates to %d connections",
@@ -171,7 +178,7 @@ class ReplicationStreamer(object):
logger.info(
"Streaming: %s -> %s", stream.NAME, updates[-1][0]
)
- stream_updates_counter.inc_by(len(updates), stream.NAME)
+ stream_updates_counter.labels(stream.NAME).inc(len(updates))
# Some streams return multiple rows with the same stream IDs,
# we need to make sure they get sent out in batches. We do
@@ -212,11 +219,12 @@ class ReplicationStreamer(object):
self.federation_sender.federation_ack(token)
@measure_func("repl.on_user_sync")
+ @defer.inlineCallbacks
def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
"""A client has started/stopped syncing on a worker.
"""
user_sync_counter.inc()
- self.presence_handler.update_external_syncs_row(
+ yield self.presence_handler.update_external_syncs_row(
conn_id, user_id, is_syncing, last_sync_ms,
)
@@ -240,13 +248,15 @@ class ReplicationStreamer(object):
getattr(self.store, cache_func).invalidate(tuple(keys))
@measure_func("repl.on_user_ip")
+ @defer.inlineCallbacks
def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
"""The client saw a user request
"""
user_ip_cache_counter.inc()
- self.store.insert_client_ip(
+ yield self.store.insert_client_ip(
user_id, access_token, ip, user_agent, device_id, last_seen,
)
+ yield self._server_notices_sender.on_user_ip(user_id)
def send_sync_to_all_connections(self, data):
"""Sends a SYNC command to all clients.
diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py
index fbafe12cc2..55fe701c5c 100644
--- a/synapse/replication/tcp/streams.py
+++ b/synapse/replication/tcp/streams.py
@@ -24,11 +24,10 @@ Each stream is defined by the following information:
update_function: The function that returns a list of updates between two tokens
"""
-from twisted.internet import defer
-from collections import namedtuple
-
import logging
+from collections import namedtuple
+from twisted.internet import defer
logger = logging.getLogger(__name__)
@@ -118,6 +117,12 @@ CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", (
"state_key", # str
"event_id", # str, optional
))
+GroupsStreamRow = namedtuple("GroupsStreamRow", (
+ "group_id", # str
+ "user_id", # str
+ "type", # str
+ "content", # dict
+))
class Stream(object):
@@ -464,6 +469,19 @@ class CurrentStateDeltaStream(Stream):
super(CurrentStateDeltaStream, self).__init__(hs)
+class GroupServerStream(Stream):
+ NAME = "groups"
+ ROW_TYPE = GroupsStreamRow
+
+ def __init__(self, hs):
+ store = hs.get_datastore()
+
+ self.current_token = store.get_group_stream_token
+ self.update_function = store.get_all_groups_changes
+
+ super(GroupServerStream, self).__init__(hs)
+
+
STREAMS_MAP = {
stream.NAME: stream
for stream in (
@@ -482,5 +500,6 @@ STREAMS_MAP = {
TagAccountDataStream,
AccountDataStream,
CurrentStateDeltaStream,
+ GroupServerStream,
)
}
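The GroupServerStream added above follows the recipe every stream in this file
uses. A hedged sketch of the three pieces involved; ExampleStream and the two
store methods are hypothetical, not real Synapse APIs:

    from collections import namedtuple

    # 1. A row type describing a single update on the stream.
    ExampleStreamRow = namedtuple("ExampleStreamRow", (
        "entity_id",  # str
        "content",    # dict
    ))

    class ExampleStream(Stream):
        # 2. A unique NAME, plus the ROW_TYPE used to (de)serialise updates.
        NAME = "example"
        ROW_TYPE = ExampleStreamRow

        def __init__(self, hs):
            store = hs.get_datastore()
            # 3. Callables returning the current stream token and the
            # updates between two tokens (assumed to exist on the store).
            self.current_token = store.get_example_stream_token
            self.update_function = store.get_all_example_changes

            super(ExampleStream, self).__init__(hs)

    # The class must then be listed in STREAMS_MAP, as GroupServerStream is
    # above, before workers can subscribe to it.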
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 3d809d181b..3418f06fd6 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,48 +14,50 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.rest.client import (
- versions,
-)
+from six import PY3
+from synapse.http.server import JsonResource
+from synapse.rest.client import versions
from synapse.rest.client.v1 import (
- room,
+ admin,
+ directory,
events,
- profile,
- presence,
initial_sync,
- directory,
- voip,
- admin,
- pusher,
- push_rule,
- register as v1_register,
login as v1_login,
logout,
+ presence,
+ profile,
+ push_rule,
+ pusher,
+ room,
+ voip,
)
-
from synapse.rest.client.v2_alpha import (
- sync,
- filter,
account,
- register,
+ account_data,
auth,
- receipts,
- read_marker,
+ devices,
+ filter,
+ groups,
keys,
- tokenrefresh,
- tags,
- account_data,
- report_event,
- openid,
notifications,
- devices,
- thirdparty,
+ openid,
+ read_marker,
+ receipts,
+ register,
+ report_event,
sendtodevice,
+ sync,
+ tags,
+ thirdparty,
+ tokenrefresh,
user_directory,
)
-from synapse.http.server import JsonResource
+if not PY3:
+ from synapse.rest.client.v1_only import (
+ register as v1_register,
+ )
class ClientRestResource(JsonResource):
@@ -68,14 +71,22 @@ class ClientRestResource(JsonResource):
def register_servlets(client_resource, hs):
versions.register_servlets(client_resource)
- # "v1"
- room.register_servlets(hs, client_resource)
+ if not PY3:
+ # "v1" (Python 2 only)
+ v1_register.register_servlets(hs, client_resource)
+
+ # Deprecated in r0
+ initial_sync.register_servlets(hs, client_resource)
+ room.register_deprecated_servlets(hs, client_resource)
+
+ # Partially deprecated in r0
events.register_servlets(hs, client_resource)
- v1_register.register_servlets(hs, client_resource)
+
+ # "v1" + "r0"
+ room.register_servlets(hs, client_resource)
v1_login.register_servlets(hs, client_resource)
profile.register_servlets(hs, client_resource)
presence.register_servlets(hs, client_resource)
- initial_sync.register_servlets(hs, client_resource)
directory.register_servlets(hs, client_resource)
voip.register_servlets(hs, client_resource)
admin.register_servlets(hs, client_resource)
@@ -102,3 +113,4 @@ class ClientRestResource(JsonResource):
thirdparty.register_servlets(hs, client_resource)
sendtodevice.register_servlets(hs, client_resource)
user_directory.register_servlets(hs, client_resource)
+ groups.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index fceca2edeb..00b1b3066e 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -17,37 +17,20 @@
to ensure idempotency when performing PUTs using the REST API."""
import logging
-from synapse.api.auth import get_access_token_from_request
from synapse.util.async import ObservableDeferred
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
-
-def get_transaction_key(request):
- """A helper function which returns a transaction key that can be used
- with TransactionCache for idempotent requests.
-
- Idempotency is based on the returned key being the same for separate
- requests to the same endpoint. The key is formed from the HTTP request
- path and the access_token for the requesting user.
-
- Args:
- request (twisted.web.http.Request): The incoming request. Must
- contain an access_token.
- Returns:
- str: A transaction key
- """
- token = get_access_token_from_request(request)
- return request.path + "/" + token
-
-
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
class HttpTransactionCache(object):
- def __init__(self, clock):
- self.clock = clock
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = self.hs.get_auth()
+ self.clock = self.hs.get_clock()
self.transactions = {
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
}
@@ -55,6 +38,23 @@ class HttpTransactionCache(object):
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
+ def _get_transaction_key(self, request):
+ """A helper function which returns a transaction key that can be used
+ with TransactionCache for idempotent requests.
+
+ Idempotency is based on the returned key being the same for separate
+ requests to the same endpoint. The key is formed from the HTTP request
+ path and the access_token for the requesting user.
+
+ Args:
+ request (twisted.web.http.Request): The incoming request. Must
+ contain an access_token.
+ Returns:
+ str: A transaction key
+ """
+ token = self.auth.get_access_token_from_request(request)
+ return request.path + "/" + token
+
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts
a transaction key from the given request.
@@ -63,7 +63,7 @@ class HttpTransactionCache(object):
fetch_or_execute
"""
return self.fetch_or_execute(
- get_transaction_key(request), fn, *args, **kwargs
+ self._get_transaction_key(request), fn, *args, **kwargs
)
def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
@@ -80,31 +80,30 @@ class HttpTransactionCache(object):
Returns:
Deferred which resolves to a tuple of (response_code, response_dict).
"""
- try:
- return self.transactions[txn_key][0].observe()
- except (KeyError, IndexError):
- pass # execute the function instead.
-
- deferred = fn(*args, **kwargs)
-
- # if the request fails with a Twisted failure, remove it
- # from the transaction map. This is done to ensure that we don't
- # cache transient errors like rate-limiting errors, etc.
- def remove_from_map(err):
- self.transactions.pop(txn_key, None)
- return err
- deferred.addErrback(remove_from_map)
-
- # We don't add any other errbacks to the raw deferred, so we ask
- # ObservableDeferred to swallow the error. This is fine as the error will
- # still be reported to the observers.
- observable = ObservableDeferred(deferred, consumeErrors=True)
- self.transactions[txn_key] = (observable, self.clock.time_msec())
- return observable.observe()
+ if txn_key in self.transactions:
+ observable = self.transactions[txn_key][0]
+ else:
+ # execute the function instead.
+ deferred = run_in_background(fn, *args, **kwargs)
+
+ observable = ObservableDeferred(deferred)
+ self.transactions[txn_key] = (observable, self.clock.time_msec())
+
+ # if the request fails with an exception, remove it
+ # from the transaction map. This is done to ensure that we don't
+ # cache transient errors like rate-limiting errors, etc.
+ def remove_from_map(err):
+ self.transactions.pop(txn_key, None)
+ # we deliberately do not propagate the error any further, as we
+ # expect the observers to have reported it.
+
+ deferred.addErrback(remove_from_map)
+
+ return make_deferred_yieldable(observable.observe())
def _cleanup(self):
now = self.clock.time_msec()
- for key in self.transactions.keys():
+ for key in list(self.transactions):
ts = self.transactions[key][1]
if now > (ts + CLEANUP_PERIOD_MS): # after cleanup period
del self.transactions[key]
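With the transaction key now derived inside HttpTransactionCache (via the Auth
object), a servlet just routes its PUT through the cache. Roughly how callers
use it, mirroring the pattern in the room servlets; ExampleSendServlet is
illustrative:

    class ExampleSendServlet(ClientV1RestServlet):
        # self.txns is the HttpTransactionCache built by ClientV1RestServlet.
        def on_PUT(self, request, room_id, event_type, txn_id):
            # Replays the cached (code, body) when a txn_id is retried;
            # otherwise runs on_POST once and caches its result for
            # roughly 30-60 minutes.
            return self.txns.fetch_or_execute_request(
                request, self.on_POST, request, room_id, event_type, txn_id,
            )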
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 7d786e8de3..99f6c6e3c3 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,17 +14,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import hashlib
+import hmac
+import logging
+
+from six.moves import http_client
+
from twisted.internet import defer
from synapse.api.constants import Membership
-from synapse.api.errors import AuthError, SynapseError
+from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
+from synapse.http.servlet import (
+ assert_params_in_dict,
+ parse_integer,
+ parse_json_object_from_request,
+ parse_string,
+)
from synapse.types import UserID, create_requester
-from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns
-import logging
-
logger = logging.getLogger(__name__)
@@ -55,6 +65,125 @@ class UsersRestServlet(ClientV1RestServlet):
defer.returnValue((200, ret))
+class UserRegisterServlet(ClientV1RestServlet):
+ """
+ Attributes:
+ NONCE_TIMEOUT (int): Seconds until a generated nonce won't be accepted
+ nonces (dict[str, int]): The nonces that we will accept. A dict of
+ nonce to the time it was generated, in int seconds.
+ """
+ PATTERNS = client_path_patterns("/admin/register")
+ NONCE_TIMEOUT = 60
+
+ def __init__(self, hs):
+ super(UserRegisterServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+ self.reactor = hs.get_reactor()
+ self.nonces = {}
+ self.hs = hs
+
+ def _clear_old_nonces(self):
+ """
+ Clear out old nonces that are older than NONCE_TIMEOUT.
+ """
+ now = int(self.reactor.seconds())
+
+ for k, v in list(self.nonces.items()):
+ if now - v > self.NONCE_TIMEOUT:
+ del self.nonces[k]
+
+ def on_GET(self, request):
+ """
+ Generate a new nonce.
+ """
+ self._clear_old_nonces()
+
+ nonce = self.hs.get_secrets().token_hex(64)
+ self.nonces[nonce] = int(self.reactor.seconds())
+ return (200, {"nonce": nonce.encode('ascii')})
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ self._clear_old_nonces()
+
+ if not self.hs.config.registration_shared_secret:
+ raise SynapseError(400, "Shared secret registration is not enabled")
+
+ body = parse_json_object_from_request(request)
+
+ if "nonce" not in body:
+ raise SynapseError(
+ 400, "nonce must be specified", errcode=Codes.BAD_JSON,
+ )
+
+ nonce = body["nonce"]
+
+ if nonce not in self.nonces:
+ raise SynapseError(
+ 400, "unrecognised nonce",
+ )
+
+ # Delete the nonce, so it can't be reused, even if it's invalid
+ del self.nonces[nonce]
+
+ if "username" not in body:
+ raise SynapseError(
+ 400, "username must be specified", errcode=Codes.BAD_JSON,
+ )
+ else:
+ if (not isinstance(body['username'], str) or len(body['username']) > 512):
+ raise SynapseError(400, "Invalid username")
+
+ username = body["username"].encode("utf-8")
+ if b"\x00" in username:
+ raise SynapseError(400, "Invalid username")
+
+ if "password" not in body:
+ raise SynapseError(
+ 400, "password must be specified", errcode=Codes.BAD_JSON,
+ )
+ else:
+ if (not isinstance(body['password'], str) or len(body['password']) > 512):
+ raise SynapseError(400, "Invalid password")
+
+ password = body["password"].encode("utf-8")
+ if b"\x00" in password:
+ raise SynapseError(400, "Invalid password")
+
+ admin = body.get("admin", None)
+ got_mac = body["mac"]
+
+ want_mac = hmac.new(
+ key=self.hs.config.registration_shared_secret.encode(),
+ digestmod=hashlib.sha1,
+ )
+ want_mac.update(nonce)
+ want_mac.update(b"\x00")
+ want_mac.update(username)
+ want_mac.update(b"\x00")
+ want_mac.update(password)
+ want_mac.update(b"\x00")
+ want_mac.update(b"admin" if admin else b"notadmin")
+ want_mac = want_mac.hexdigest()
+
+ if not hmac.compare_digest(want_mac, got_mac):
+ raise SynapseError(
+ 403, "HMAC incorrect",
+ )
+
+ # Reuse the parts of RegisterRestServlet to reduce code duplication
+ from synapse.rest.client.v2_alpha.register import RegisterRestServlet
+ register = RegisterRestServlet(self.hs)
+
+ (user_id, _) = yield register.registration_handler.register(
+ localpart=username.lower(), password=password, admin=bool(admin),
+ generate_token=False,
+ )
+
+ result = yield register._create_registration_details(user_id, body)
+ defer.returnValue((200, result))
+
+
class WhoisRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)")
@@ -95,16 +224,8 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet):
if not is_admin:
raise AuthError(403, "You are not a server admin")
- before_ts = request.args.get("before_ts", None)
- if not before_ts:
- raise SynapseError(400, "Missing 'before_ts' arg")
-
- logger.info("before_ts: %r", before_ts[0])
-
- try:
- before_ts = int(before_ts[0])
- except Exception:
- raise SynapseError(400, "Invalid 'before_ts' arg")
+ before_ts = parse_integer(request, "before_ts", required=True)
+ logger.info("before_ts: %r", before_ts)
ret = yield self.media_repository.delete_old_remote_media(before_ts)
@@ -113,12 +234,18 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet):
class PurgeHistoryRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns(
- "/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+ "/admin/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
)
def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer)
+ """
super(PurgeHistoryRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.pagination_handler = hs.get_pagination_handler()
+ self.store = hs.get_datastore()
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_id):
@@ -128,20 +255,127 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
if not is_admin:
raise AuthError(403, "You are not a server admin")
- yield self.handlers.message_handler.purge_history(room_id, event_id)
+ body = parse_json_object_from_request(request, allow_empty_body=True)
- defer.returnValue((200, {}))
+ delete_local_events = bool(body.get("delete_local_events", False))
+
+ # establish the topological ordering we should keep events from. The
+ # user can provide an event_id in the URL or the request body, or can
+ # provide a timestamp in the request body.
+ if event_id is None:
+ event_id = body.get('purge_up_to_event_id')
+
+ if event_id is not None:
+ event = yield self.store.get_event(event_id)
+
+ if event.room_id != room_id:
+ raise SynapseError(400, "Event is for wrong room.")
+
+ token = yield self.store.get_topological_token_for_event(event_id)
+
+ logger.info(
+ "[purge] purging up to token %s (event_id %s)",
+ token, event_id,
+ )
+ elif 'purge_up_to_ts' in body:
+ ts = body['purge_up_to_ts']
+ if not isinstance(ts, int):
+ raise SynapseError(
+ 400, "purge_up_to_ts must be an int",
+ errcode=Codes.BAD_JSON,
+ )
+
+ stream_ordering = (
+ yield self.store.find_first_stream_ordering_after_ts(ts)
+ )
+
+ r = (
+ yield self.store.get_room_event_after_stream_ordering(
+ room_id, stream_ordering,
+ )
+ )
+ if not r:
+ logger.warn(
+ "[purge] purging events not possible: No event found "
+ "(received_ts %i => stream_ordering %i)",
+ ts, stream_ordering,
+ )
+ raise SynapseError(
+ 404,
+ "there is no event to be purged",
+ errcode=Codes.NOT_FOUND,
+ )
+ (stream, topo, _event_id) = r
+ token = "t%d-%d" % (topo, stream)
+ logger.info(
+ "[purge] purging up to token %s (received_ts %i => "
+ "stream_ordering %i)",
+ token, ts, stream_ordering,
+ )
+ else:
+ raise SynapseError(
+ 400,
+ "must specify purge_up_to_event_id or purge_up_to_ts",
+ errcode=Codes.BAD_JSON,
+ )
+
+ purge_id = yield self.pagination_handler.start_purge_history(
+ room_id, token,
+ delete_local_events=delete_local_events,
+ )
+
+ defer.returnValue((200, {
+ "purge_id": purge_id,
+ }))
+
+
+class PurgeHistoryStatusRestServlet(ClientV1RestServlet):
+ PATTERNS = client_path_patterns(
+ "/admin/purge_history_status/(?P<purge_id>[^/]+)"
+ )
+
+ def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer)
+ """
+ super(PurgeHistoryStatusRestServlet, self).__init__(hs)
+ self.pagination_handler = hs.get_pagination_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, purge_id):
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ purge_status = self.pagination_handler.get_purge_status(purge_id)
+ if purge_status is None:
+ raise NotFoundError("purge id '%s' not found" % purge_id)
+
+ defer.returnValue((200, purge_status.asdict()))
class DeactivateAccountRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
- self.store = hs.get_datastore()
super(DeactivateAccountRestServlet, self).__init__(hs)
+ self._deactivate_account_handler = hs.get_deactivate_account_handler()
@defer.inlineCallbacks
def on_POST(self, request, target_user_id):
+ body = parse_json_object_from_request(request, allow_empty_body=True)
+ erase = body.get("erase", False)
+ if not isinstance(erase, bool):
+ raise SynapseError(
+ http_client.BAD_REQUEST,
+ "Param 'erase' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
+
UserID.from_string(target_user_id)
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
@@ -149,12 +383,9 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
if not is_admin:
raise AuthError(403, "You are not a server admin")
- # FIXME: Theoretically there is a race here wherein user resets password
- # using threepid.
- yield self.store.user_delete_access_tokens(target_user_id)
- yield self.store.user_delete_threepids(target_user_id)
- yield self.store.user_set_password_hash(target_user_id, None)
-
+ yield self._deactivate_account_handler.deactivate_account(
+ target_user_id, erase,
+ )
defer.returnValue((200, {}))
@@ -168,14 +399,16 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
DEFAULT_MESSAGE = (
"Sharing illegal content on this server is not permitted and rooms in"
- " violatation will be blocked."
+ " violation will be blocked."
)
def __init__(self, hs):
super(ShutdownRoomRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
- self.handlers = hs.get_handlers()
self.state = hs.get_state_handler()
+ self._room_creation_handler = hs.get_room_creation_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.room_member_handler = hs.get_room_member_handler()
@defer.inlineCallbacks
def on_POST(self, request, room_id):
@@ -185,17 +418,15 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
raise AuthError(403, "You are not a server admin")
content = parse_json_object_from_request(request)
-
- new_room_user_id = content.get("new_room_user_id")
- if not new_room_user_id:
- raise SynapseError(400, "Please provide field `new_room_user_id`")
+ assert_params_in_dict(content, ["new_room_user_id"])
+ new_room_user_id = content["new_room_user_id"]
room_creator_requester = create_requester(new_room_user_id)
message = content.get("message", self.DEFAULT_MESSAGE)
room_name = content.get("room_name", "Content Violation Notification")
- info = yield self.handlers.room_creation_handler.create_room(
+ info = yield self._room_creation_handler.create_room(
room_creator_requester,
config={
"preset": "public_chat",
@@ -208,8 +439,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
)
new_room_id = info["room_id"]
- msg_handler = self.handlers.message_handler
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
room_creator_requester,
{
"type": "m.room.message",
@@ -235,7 +465,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
logger.info("Kicking %r from %r...", user_id, room_id)
target_requester = create_requester(user_id)
- yield self.handlers.room_member_handler.update_membership(
+ yield self.room_member_handler.update_membership(
requester=target_requester,
target=target_requester.user,
room_id=room_id,
@@ -244,9 +474,9 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
ratelimit=False
)
- yield self.handlers.room_member_handler.forget(target_requester.user, room_id)
+ yield self.room_member_handler.forget(target_requester.user, room_id)
- yield self.handlers.room_member_handler.update_membership(
+ yield self.room_member_handler.update_membership(
requester=target_requester,
target=target_requester.user,
room_id=new_room_id,
@@ -294,9 +524,30 @@ class QuarantineMediaInRoom(ClientV1RestServlet):
defer.returnValue((200, {"num_quarantined": num_quarantined}))
+class ListMediaInRoom(ClientV1RestServlet):
+ """Lists all of the media in a given room.
+ """
+ PATTERNS = client_path_patterns("/admin/room/(?P<room_id>[^/]+)/media")
+
+ def __init__(self, hs):
+ super(ListMediaInRoom, self).__init__(hs)
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id)
+
+ defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs}))
+
+
class ResetPasswordRestServlet(ClientV1RestServlet):
"""Post request to allow an administrator reset password for a user.
- This need a user have a administrator access in Synapse.
+ This needs user to have administrator access in Synapse.
Example:
http://localhost:8008/_matrix/client/api/v1/admin/reset_password/
@user:to_reset_password?access_token=admin_access_token
@@ -314,12 +565,12 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
super(ResetPasswordRestServlet, self).__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
- self.auth_handler = hs.get_auth_handler()
+ self._set_password_handler = hs.get_set_password_handler()
@defer.inlineCallbacks
def on_POST(self, request, target_user_id):
"""Post request to allow an administrator reset password for a user.
- This need a user have a administrator access in Synapse.
+ This needs user to have administrator access in Synapse.
"""
UserID.from_string(target_user_id)
requester = yield self.auth.get_user_by_req(request)
@@ -329,13 +580,12 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
raise AuthError(403, "You are not a server admin")
params = parse_json_object_from_request(request)
+ assert_params_in_dict(params, ["new_password"])
new_password = params['new_password']
- if not new_password:
- raise SynapseError(400, "Missing 'new_password' arg")
logger.info("new_password: %r", new_password)
- yield self.auth_handler.set_password(
+ yield self._set_password_handler.set_password(
target_user_id, new_password, requester
)
defer.returnValue((200, {}))
@@ -343,7 +593,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
class GetUsersPaginatedRestServlet(ClientV1RestServlet):
"""Get request to get specific number of users from Synapse.
- This need a user have a administrator access in Synapse.
+ This needs user to have administrator access in Synapse.
Example:
http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/
@admin:user?access_token=admin_access_token&start=0&limit=10
@@ -362,7 +612,7 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, target_user_id):
"""Get request to get specific number of users from Synapse.
- This need a user have a administrator access in Synapse.
+ This needs user to have administrator access in Synapse.
"""
target_user = UserID.from_string(target_user_id)
requester = yield self.auth.get_user_by_req(request)
@@ -379,12 +629,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Can only users a local user")
order = "name" # order by name in user table
- start = request.args.get("start")[0]
- limit = request.args.get("limit")[0]
- if not limit:
- raise SynapseError(400, "Missing 'limit' arg")
- if not start:
- raise SynapseError(400, "Missing 'start' arg")
+ start = parse_integer(request, "start", required=True)
+ limit = parse_integer(request, "limit", required=True)
+
logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate(
@@ -395,7 +642,7 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, target_user_id):
"""Post request to get specific number of users from Synapse..
- This need a user have a administrator access in Synapse.
+ This needs user to have administrator access in Synapse.
Example:
http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/
@admin:user?access_token=admin_access_token
@@ -416,12 +663,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
order = "name" # order by name in user table
params = parse_json_object_from_request(request)
+ assert_params_in_dict(params, ["limit", "start"])
limit = params['limit']
start = params['start']
- if not limit:
- raise SynapseError(400, "Missing 'limit' arg")
- if not start:
- raise SynapseError(400, "Missing 'start' arg")
logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate(
@@ -433,7 +677,7 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
class SearchUsersRestServlet(ClientV1RestServlet):
"""Get request to search user table for specific users according to
search term.
- This need a user have a administrator access in Synapse.
+ This needs user to have administrator access in Synapse.
Example:
http://localhost:8008/_matrix/client/api/v1/admin/search_users/
@admin:user?access_token=admin_access_token&term=alice
@@ -453,7 +697,7 @@ class SearchUsersRestServlet(ClientV1RestServlet):
def on_GET(self, request, target_user_id):
"""Get request to search user table for specific users according to
search term.
- This need a user have a administrator access in Synapse.
+ This needs user to have a administrator access in Synapse.
"""
target_user = UserID.from_string(target_user_id)
requester = yield self.auth.get_user_by_req(request)
@@ -469,10 +713,7 @@ class SearchUsersRestServlet(ClientV1RestServlet):
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only users a local user")
- term = request.args.get("term")[0]
- if not term:
- raise SynapseError(400, "Missing 'term' arg")
-
+ term = parse_string(request, "term", required=True)
logger.info("term: %s ", term)
ret = yield self.handlers.admin_handler.search_users(
@@ -484,6 +725,7 @@ class SearchUsersRestServlet(ClientV1RestServlet):
def register_servlets(hs, http_server):
WhoisRestServlet(hs).register(http_server)
PurgeMediaCacheRestServlet(hs).register(http_server)
+ PurgeHistoryStatusRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
PurgeHistoryRestServlet(hs).register(http_server)
UsersRestServlet(hs).register(http_server)
@@ -492,3 +734,5 @@ def register_servlets(hs, http_server):
SearchUsersRestServlet(hs).register(http_server)
ShutdownRoomRestServlet(hs).register(http_server)
QuarantineMediaInRoom(hs).register(http_server)
+ ListMediaInRoom(hs).register(http_server)
+ UserRegisterServlet(hs).register(http_server)
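The new UserRegisterServlet expects the caller to fetch a nonce and then prove
knowledge of the shared secret with an HMAC-SHA1 over the nonce, username,
password and admin flag, NUL-separated. A minimal client-side sketch; the URL
prefix and the use of requests are assumptions, not part of this diff:

    import hashlib
    import hmac

    import requests

    def register_user(base_url, shared_secret, username, password, admin=False):
        url = base_url + "/_matrix/client/api/v1/admin/register"

        # 1. Fetch a single-use nonce (accepted for NONCE_TIMEOUT seconds).
        nonce = requests.get(url).json()["nonce"]

        # 2. Recompute exactly the MAC that on_POST checks above.
        mac = hmac.new(shared_secret.encode("utf-8"), digestmod=hashlib.sha1)
        mac.update(nonce.encode("utf-8"))
        mac.update(b"\x00")
        mac.update(username.encode("utf-8"))
        mac.update(b"\x00")
        mac.update(password.encode("utf-8"))
        mac.update(b"\x00")
        mac.update(b"admin" if admin else b"notadmin")

        # 3. Register; the response carries the new user's credentials.
        resp = requests.post(url, json={
            "nonce": nonce,
            "username": username,
            "password": password,
            "admin": admin,
            "mac": mac.hexdigest(),
        })
        return resp.json()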
diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py
index c7aa0bbf59..c77d7aba68 100644
--- a/synapse/rest/client/v1/base.py
+++ b/synapse/rest/client/v1/base.py
@@ -16,14 +16,12 @@
"""This module contains base REST classes for constructing client v1 servlets.
"""
-from synapse.http.servlet import RestServlet
-from synapse.api.urls import CLIENT_PREFIX
-from synapse.rest.client.transactions import HttpTransactionCache
-
-import re
-
import logging
+import re
+from synapse.api.urls import CLIENT_PREFIX
+from synapse.http.servlet import RestServlet
+from synapse.rest.client.transactions import HttpTransactionCache
logger = logging.getLogger(__name__)
@@ -52,6 +50,10 @@ class ClientV1RestServlet(RestServlet):
"""A base Synapse REST Servlet for the client version 1 API.
"""
+    # This subclass was presumably created to allow the auth for the v1
+    # protocol version to be different; however, that behaviour was removed,
+    # so it may no longer be necessary.
+
def __init__(self, hs):
"""
Args:
@@ -59,5 +61,5 @@ class ClientV1RestServlet(RestServlet):
"""
self.hs = hs
self.builder_factory = hs.get_event_builder_factory()
- self.auth = hs.get_v1auth()
- self.txns = HttpTransactionCache(hs.get_clock())
+ self.auth = hs.get_auth()
+ self.txns = HttpTransactionCache(hs)
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index f15aa5c13f..69dcd618cb 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -14,17 +14,16 @@
# limitations under the License.
+import logging
+
from twisted.internet import defer
-from synapse.api.errors import AuthError, SynapseError, Codes
-from synapse.types import RoomAlias
+from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.http.servlet import parse_json_object_from_request
+from synapse.types import RoomAlias
from .base import ClientV1RestServlet, client_path_patterns
-import logging
-
-
logger = logging.getLogger(__name__)
@@ -53,15 +52,14 @@ class ClientDirectoryServer(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_alias):
+ room_alias = RoomAlias.from_string(room_alias)
+
content = parse_json_object_from_request(request)
if "room_id" not in content:
- raise SynapseError(400, "Missing room_id key",
+ raise SynapseError(400, 'Missing params: ["room_id"]',
errcode=Codes.BAD_JSON)
logger.debug("Got content: %s", content)
-
- room_alias = RoomAlias.from_string(room_alias)
-
logger.debug("Got room name: %s", room_alias.to_string())
room_id = content["room_id"]
@@ -93,7 +91,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
)
except SynapseError as e:
raise e
- except:
+ except Exception:
logger.exception("Failed to create association")
raise
except AuthError:
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index 701b6f549b..b70c9c2806 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -14,15 +14,15 @@
# limitations under the License.
"""This module contains REST servlets to do with event streaming, /events."""
+import logging
+
from twisted.internet import defer
from synapse.api.errors import SynapseError
-from synapse.streams.config import PaginationConfig
-from .base import ClientV1RestServlet, client_path_patterns
from synapse.events.utils import serialize_event
+from synapse.streams.config import PaginationConfig
-import logging
-
+from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__)
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index 478e21eea8..fd5f85b53e 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -15,7 +15,9 @@
from twisted.internet import defer
+from synapse.http.servlet import parse_boolean
from synapse.streams.config import PaginationConfig
+
from .base import ClientV1RestServlet, client_path_patterns
@@ -32,7 +34,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request)
as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request)
- include_archived = request.args.get("archived", None) == ["true"]
+ include_archived = parse_boolean(request, "archived", default=False)
content = yield self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(),
pagin_config=pagination_config,
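This one-line change is an instance of a pattern repeated across the diff (see
the admin.py and push_rule.py hunks above): hand-rolled request.args access is
replaced by the synapse.http.servlet helpers, which apply defaults, coerce
types, and raise a 400 themselves when a required parameter is missing. A rough
sketch of the helpers in use:

    from synapse.http.servlet import parse_boolean, parse_integer, parse_string

    def on_GET_sketch(request):
        # Each call replaces a request.args.get(...)[0] lookup plus manual
        # validation; required=True raises SynapseError(400) when absent.
        archived = parse_boolean(request, "archived", default=False)
        limit = parse_integer(request, "limit", required=True)
        term = parse_string(request, "term", required=True)
        return archived, limit, term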
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index a43410fb37..cb85fa1436 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -13,30 +13,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import urllib
+import xml.etree.ElementTree as ET
+
+from six.moves.urllib import parse as urlparse
+
+from canonicaljson import json
+from saml2 import BINDING_HTTP_POST, config
+from saml2.client import Saml2Client
+
from twisted.internet import defer
+from twisted.web.client import PartialDownloadError
-from synapse.api.errors import SynapseError, LoginError, Codes
-from synapse.types import UserID
+from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.http.server import finish_request
from synapse.http.servlet import parse_json_object_from_request
+from synapse.types import UserID
from synapse.util.msisdn import phone_number_to_msisdn
from .base import ClientV1RestServlet, client_path_patterns
-import simplejson as json
-import urllib
-import urlparse
-
-import logging
-from saml2 import BINDING_HTTP_POST
-from saml2 import config
-from saml2.client import Saml2Client
-
-import xml.etree.ElementTree as ET
-
-from twisted.web.client import PartialDownloadError
-
-
logger = logging.getLogger(__name__)
@@ -85,7 +82,6 @@ def login_id_thirdparty_from_phone(identifier):
class LoginRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login$")
- PASS_TYPE = "m.login.password"
SAML2_TYPE = "m.login.saml2"
CAS_TYPE = "m.login.cas"
TOKEN_TYPE = "m.login.token"
@@ -94,7 +90,6 @@ class LoginRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs)
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
- self.password_enabled = hs.config.password_enabled
self.saml2_enabled = hs.config.saml2_enabled
self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret
@@ -121,8 +116,10 @@ class LoginRestServlet(ClientV1RestServlet):
# fall back to the fallback API if they don't understand one of the
# login flow types returned.
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
- if self.password_enabled:
- flows.append({"type": LoginRestServlet.PASS_TYPE})
+
+ flows.extend((
+ {"type": t} for t in self.auth_handler.get_supported_login_types()
+ ))
return (200, {"flows": flows})
@@ -133,14 +130,8 @@ class LoginRestServlet(ClientV1RestServlet):
def on_POST(self, request):
login_submission = parse_json_object_from_request(request)
try:
- if login_submission["type"] == LoginRestServlet.PASS_TYPE:
- if not self.password_enabled:
- raise SynapseError(400, "Password login has been disabled.")
-
- result = yield self.do_password_login(login_submission)
- defer.returnValue(result)
- elif self.saml2_enabled and (login_submission["type"] ==
- LoginRestServlet.SAML2_TYPE):
+ if self.saml2_enabled and (login_submission["type"] ==
+ LoginRestServlet.SAML2_TYPE):
relay_state = ""
if "relay_state" in login_submission:
relay_state = "&RelayState=" + urllib.quote(
@@ -157,15 +148,31 @@ class LoginRestServlet(ClientV1RestServlet):
result = yield self.do_token_login(login_submission)
defer.returnValue(result)
else:
- raise SynapseError(400, "Bad login type.")
+ result = yield self._do_other_login(login_submission)
+ defer.returnValue(result)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@defer.inlineCallbacks
- def do_password_login(self, login_submission):
- if "password" not in login_submission:
- raise SynapseError(400, "Missing parameter: password")
+ def _do_other_login(self, login_submission):
+ """Handle non-token/saml/jwt logins
+
+ Args:
+            login_submission (dict): the JSON body of the /login request
+ Returns:
+ (int, object): HTTP code/response
+ """
+ # Log the request we got, but only certain fields to minimise the chance of
+ # logging someone's password (even if they accidentally put it in the wrong
+ # field)
+ logger.info(
+ "Got login request with identifier: %r, medium: %r, address: %r, user: %r",
+ login_submission.get('identifier'),
+ login_submission.get('medium'),
+ login_submission.get('address'),
+ login_submission.get('user'),
+ )
login_submission_legacy_convert(login_submission)
if "identifier" not in login_submission:
@@ -181,19 +188,25 @@ class LoginRestServlet(ClientV1RestServlet):
# convert threepid identifiers to user IDs
if identifier["type"] == "m.id.thirdparty":
- if 'medium' not in identifier or 'address' not in identifier:
+ address = identifier.get('address')
+ medium = identifier.get('medium')
+
+ if medium is None or address is None:
raise SynapseError(400, "Invalid thirdparty identifier")
- address = identifier['address']
- if identifier['medium'] == 'email':
+ if medium == 'email':
# For emails, transform the address to lowercase.
                # We store all email addresses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
address = address.lower()
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
- identifier['medium'], address
+ medium, address,
)
if not user_id:
+ logger.warn(
+ "unknown 3pid identifier medium %s, address %r",
+ medium, address,
+ )
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
identifier = {
@@ -208,30 +221,29 @@ class LoginRestServlet(ClientV1RestServlet):
if "user" not in identifier:
raise SynapseError(400, "User identifier is missing 'user' key")
- user_id = identifier["user"]
-
- if not user_id.startswith('@'):
- user_id = UserID.create(
- user_id, self.hs.hostname
- ).to_string()
-
auth_handler = self.auth_handler
- user_id = yield auth_handler.validate_password_login(
- user_id=user_id,
- password=login_submission["password"],
+ canonical_user_id, callback = yield auth_handler.validate_login(
+ identifier["user"],
+ login_submission,
+ )
+
+ device_id = yield self._register_device(
+ canonical_user_id, login_submission,
)
- device_id = yield self._register_device(user_id, login_submission)
access_token = yield auth_handler.get_access_token_for_user_id(
- user_id, device_id,
- login_submission.get("initial_device_display_name"),
+ canonical_user_id, device_id,
)
+
result = {
- "user_id": user_id, # may have changed
+ "user_id": canonical_user_id,
"access_token": access_token,
"home_server": self.hs.hostname,
"device_id": device_id,
}
+ if callback is not None:
+ yield callback(result)
+
defer.returnValue((200, result))
@defer.inlineCallbacks
@@ -244,7 +256,6 @@ class LoginRestServlet(ClientV1RestServlet):
device_id = yield self._register_device(user_id, login_submission)
access_token = yield auth_handler.get_access_token_for_user_id(
user_id, device_id,
- login_submission.get("initial_device_display_name"),
)
result = {
"user_id": user_id, # may have changed
@@ -278,7 +289,7 @@ class LoginRestServlet(ClientV1RestServlet):
if user is None:
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
- user_id = UserID.create(user, self.hs.hostname).to_string()
+ user_id = UserID(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
registered_user_id = yield auth_handler.check_user_exists(user_id)
if registered_user_id:
@@ -287,7 +298,6 @@ class LoginRestServlet(ClientV1RestServlet):
)
access_token = yield auth_handler.get_access_token_for_user_id(
registered_user_id, device_id,
- login_submission.get("initial_device_display_name"),
)
result = {
@@ -444,7 +454,7 @@ class CasTicketServlet(ClientV1RestServlet):
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
- user_id = UserID.create(user, self.hs.hostname).to_string()
+ user_id = UserID(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
registered_user_id = yield auth_handler.check_user_exists(user_id)
if not registered_user_id:
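With password handling folded into validate_login, the /login body is keyed on
an identifier object rather than a bare user field. Hedged examples of the two
identifier shapes handled above (all values illustrative):

    # m.id.user: a localpart or full user ID, canonicalised by validate_login.
    password_login = {
        "type": "m.login.password",
        "identifier": {"type": "m.id.user", "user": "alice"},
        "password": "s3cret",
    }

    # m.id.thirdparty: resolved via get_user_id_by_threepid; email addresses
    # are lowercased first, as the hunk above does.
    threepid_login = {
        "type": "m.login.password",
        "identifier": {
            "type": "m.id.thirdparty",
            "medium": "email",
            "address": "alice@example.com",
        },
        "password": "s3cret",
    }

    # Legacy bodies with top-level "user"/"medium"/"address" keys are first
    # rewritten into this form by login_submission_legacy_convert.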
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index 1358d0acab..430c692336 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -13,15 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
-from synapse.api.auth import get_access_token_from_request
+from synapse.api.errors import AuthError
from .base import ClientV1RestServlet, client_path_patterns
-import logging
-
-
logger = logging.getLogger(__name__)
@@ -30,15 +29,33 @@ class LogoutRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(LogoutRestServlet, self).__init__(hs)
- self.store = hs.get_datastore()
+ self._auth = hs.get_auth()
+ self._auth_handler = hs.get_auth_handler()
+ self._device_handler = hs.get_device_handler()
def on_OPTIONS(self, request):
return (200, {})
@defer.inlineCallbacks
def on_POST(self, request):
- access_token = get_access_token_from_request(request)
- yield self.store.delete_access_token(access_token)
+ try:
+ requester = yield self.auth.get_user_by_req(request)
+ except AuthError:
+ # this implies the access token has already been deleted.
+ defer.returnValue((401, {
+ "errcode": "M_UNKNOWN_TOKEN",
+ "error": "Access Token unknown or expired"
+ }))
+ else:
+ if requester.device_id is None:
+                # the access token wasn't associated with a device.
+ # Just delete the access token
+ access_token = self._auth.get_access_token_from_request(request)
+ yield self._auth_handler.delete_access_token(access_token)
+ else:
+ yield self._device_handler.delete_device(
+ requester.user.to_string(), requester.device_id)
+
defer.returnValue((200, {}))
@@ -47,8 +64,9 @@ class LogoutAllRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(LogoutAllRestServlet, self).__init__(hs)
- self.store = hs.get_datastore()
self.auth = hs.get_auth()
+ self._auth_handler = hs.get_auth_handler()
+ self._device_handler = hs.get_device_handler()
def on_OPTIONS(self, request):
return (200, {})
@@ -57,7 +75,13 @@ class LogoutAllRestServlet(ClientV1RestServlet):
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
- yield self.store.user_delete_access_tokens(user_id)
+
+ # first delete all of the user's devices
+ yield self._device_handler.delete_all_devices_for_user(user_id)
+
+ # .. and then delete any access tokens which weren't associated with
+ # devices.
+ yield self._auth_handler.delete_access_tokens_for_user(user_id)
defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 47b2dc45e7..a14f0c807e 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -15,15 +15,18 @@
""" This module contains REST servlets to do with presence: /presence/<paths>
"""
+import logging
+
+from six import string_types
+
from twisted.internet import defer
-from synapse.api.errors import SynapseError, AuthError
-from synapse.types import UserID
+from synapse.api.errors import AuthError, SynapseError
from synapse.handlers.presence import format_user_presence_state
from synapse.http.servlet import parse_json_object_from_request
-from .base import ClientV1RestServlet, client_path_patterns
+from synapse.types import UserID
-import logging
+from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__)
@@ -71,14 +74,14 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
if "status_msg" in content:
state["status_msg"] = content.pop("status_msg")
- if not isinstance(state["status_msg"], basestring):
+ if not isinstance(state["status_msg"], string_types):
raise SynapseError(400, "status_msg must be a string.")
if content:
raise KeyError()
except SynapseError as e:
raise e
- except:
+ except Exception:
raise SynapseError(400, "Unable to parse state")
yield self.presence_handler.set_state(user, state)
@@ -129,7 +132,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
if "invite" in content:
for u in content["invite"]:
- if not isinstance(u, basestring):
+ if not isinstance(u, string_types):
raise SynapseError(400, "Bad invite value.")
if len(u) == 0:
continue
@@ -140,7 +143,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
if "drop" in content:
for u in content["drop"]:
- if not isinstance(u, basestring):
+ if not isinstance(u, string_types):
raise SynapseError(400, "Bad drop value.")
if len(u) == 0:
continue
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 1a5045c9ec..a23edd8fe5 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -16,9 +16,10 @@
""" This module contains REST servlets to do with profile: /profile/<paths> """
from twisted.internet import defer
-from .base import ClientV1RestServlet, client_path_patterns
-from synapse.types import UserID
from synapse.http.servlet import parse_json_object_from_request
+from synapse.types import UserID
+
+from .base import ClientV1RestServlet, client_path_patterns
class ProfileDisplaynameRestServlet(ClientV1RestServlet):
@@ -26,13 +27,13 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(ProfileDisplaynameRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.profile_handler = hs.get_profile_handler()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
- displayname = yield self.handlers.profile_handler.get_displayname(
+ displayname = yield self.profile_handler.get_displayname(
user,
)
@@ -52,10 +53,10 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
try:
new_name = content["displayname"]
- except:
+ except Exception:
defer.returnValue((400, "Unable to parse name"))
- yield self.handlers.profile_handler.set_displayname(
+ yield self.profile_handler.set_displayname(
user, requester, new_name, is_admin)
defer.returnValue((200, {}))
@@ -69,13 +70,13 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(ProfileAvatarURLRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.profile_handler = hs.get_profile_handler()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
- avatar_url = yield self.handlers.profile_handler.get_avatar_url(
+ avatar_url = yield self.profile_handler.get_avatar_url(
user,
)
@@ -94,10 +95,10 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
content = parse_json_object_from_request(request)
try:
new_name = content["avatar_url"]
- except:
+ except Exception:
defer.returnValue((400, "Unable to parse name"))
- yield self.handlers.profile_handler.set_avatar_url(
+ yield self.profile_handler.set_avatar_url(
user, requester, new_name, is_admin)
defer.returnValue((200, {}))
@@ -111,16 +112,16 @@ class ProfileRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(ProfileRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.profile_handler = hs.get_profile_handler()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
- displayname = yield self.handlers.profile_handler.get_displayname(
+ displayname = yield self.profile_handler.get_displayname(
user,
)
- avatar_url = yield self.handlers.profile_handler.get_avatar_url(
+ avatar_url = yield self.profile_handler.get_avatar_url(
user,
)
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 6bb4821ec6..6e95d9bec2 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -16,16 +16,18 @@
from twisted.internet import defer
from synapse.api.errors import (
- SynapseError, UnrecognizedRequestError, NotFoundError, StoreError
+ NotFoundError,
+ StoreError,
+ SynapseError,
+ UnrecognizedRequestError,
)
-from .base import ClientV1RestServlet, client_path_patterns
-from synapse.storage.push_rule import (
- InconsistentRuleException, RuleNotFoundException
-)
-from synapse.push.clientformat import format_push_rules_for_user
+from synapse.http.servlet import parse_json_value_from_request, parse_string
from synapse.push.baserules import BASE_RULE_IDS
+from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
-from synapse.http.servlet import parse_json_value_from_request
+from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
+
+from .base import ClientV1RestServlet, client_path_patterns
class PushRuleRestServlet(ClientV1RestServlet):
@@ -73,13 +75,13 @@ class PushRuleRestServlet(ClientV1RestServlet):
except InvalidRuleException as e:
raise SynapseError(400, e.message)
- before = request.args.get("before", None)
+ before = parse_string(request, "before")
if before:
- before = _namespaced_rule_id(spec, before[0])
+ before = _namespaced_rule_id(spec, before)
- after = request.args.get("after", None)
+ after = parse_string(request, "after")
if after:
- after = _namespaced_rule_id(spec, after[0])
+ after = _namespaced_rule_id(spec, after)
try:
yield self.store.add_push_rule(
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 9a2ed6ed88..182a68b1e2 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -13,20 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
-from synapse.api.errors import SynapseError, Codes
-from synapse.push import PusherConfigException
+from synapse.api.errors import Codes, StoreError, SynapseError
+from synapse.http.server import finish_request
from synapse.http.servlet import (
- parse_json_object_from_request, parse_string, RestServlet
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+ parse_string,
)
-from synapse.http.server import finish_request
-from synapse.api.errors import StoreError
+from synapse.push import PusherConfigException
from .base import ClientV1RestServlet, client_path_patterns
-import logging
-
logger = logging.getLogger(__name__)
@@ -73,6 +75,7 @@ class PushersSetRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(PushersSetRestServlet, self).__init__(hs)
self.notifier = hs.get_notifier()
+ self.pusher_pool = self.hs.get_pusherpool()
@defer.inlineCallbacks
def on_POST(self, request):
@@ -81,25 +84,19 @@ class PushersSetRestServlet(ClientV1RestServlet):
content = parse_json_object_from_request(request)
- pusher_pool = self.hs.get_pusherpool()
-
if ('pushkey' in content and 'app_id' in content
and 'kind' in content and
content['kind'] is None):
- yield pusher_pool.remove_pusher(
+ yield self.pusher_pool.remove_pusher(
content['app_id'], content['pushkey'], user_id=user.to_string()
)
defer.returnValue((200, {}))
- reqd = ['kind', 'app_id', 'app_display_name',
- 'device_display_name', 'pushkey', 'lang', 'data']
- missing = []
- for i in reqd:
- if i not in content:
- missing.append(i)
- if len(missing):
- raise SynapseError(400, "Missing parameters: " + ','.join(missing),
- errcode=Codes.MISSING_PARAM)
+ assert_params_in_dict(
+ content,
+ ['kind', 'app_id', 'app_display_name',
+ 'device_display_name', 'pushkey', 'lang', 'data']
+ )
logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind'])
logger.debug("Got pushers request with body: %r", content)
@@ -109,14 +106,14 @@ class PushersSetRestServlet(ClientV1RestServlet):
append = content['append']
if not append:
- yield pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
+ yield self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
app_id=content['app_id'],
pushkey=content['pushkey'],
not_user_id=user.to_string()
)
try:
- yield pusher_pool.add_pusher(
+ yield self.pusher_pool.add_pusher(
user_id=user.to_string(),
access_token=requester.access_token_id,
kind=content['kind'],
@@ -148,10 +145,11 @@ class PushersRemoveRestServlet(RestServlet):
SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs):
- super(RestServlet, self).__init__()
+ super(PushersRemoveRestServlet, self).__init__()
self.hs = hs
self.notifier = hs.get_notifier()
- self.auth = hs.get_v1auth()
+ self.auth = hs.get_auth()
+ self.pusher_pool = self.hs.get_pusherpool()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -161,10 +159,8 @@ class PushersRemoveRestServlet(RestServlet):
app_id = parse_string(request, "app_id", required=True)
pushkey = parse_string(request, "pushkey", required=True)
- pusher_pool = self.hs.get_pusherpool()
-
try:
- yield pusher_pool.remove_pusher(
+ yield self.pusher_pool.remove_pusher(
app_id=app_id,
pushkey=pushkey,
user_id=user.to_string(),
@@ -178,7 +174,6 @@ class PushersRemoveRestServlet(RestServlet):
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Server", self.hs.version_string)
request.setHeader(b"Content-Length", b"%d" % (
len(PushersRemoveRestServlet.SUCCESS_HTML),
))
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index cd388770c8..b7bd878c90 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,22 +15,28 @@
# limitations under the License.
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
+import logging
+
+from six.moves.urllib import parse as urlparse
+
+from canonicaljson import json
+
from twisted.internet import defer
-from .base import ClientV1RestServlet, client_path_patterns
-from synapse.api.errors import SynapseError, Codes, AuthError
-from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership
+from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.filtering import Filter
-from synapse.types import UserID, RoomID, RoomAlias, ThirdPartyInstanceID
-from synapse.events.utils import serialize_event, format_event_for_client_v2
+from synapse.events.utils import format_event_for_client_v2, serialize_event
from synapse.http.servlet import (
- parse_json_object_from_request, parse_string, parse_integer
+ assert_params_in_dict,
+ parse_integer,
+ parse_json_object_from_request,
+ parse_string,
)
+from synapse.streams.config import PaginationConfig
+from synapse.types import RoomAlias, RoomID, ThirdPartyInstanceID, UserID
-import logging
-import urllib
-import ujson as json
+from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__)
@@ -39,7 +46,7 @@ class RoomCreateRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomCreateRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self._room_creation_handler = hs.get_room_creation_handler()
def register(self, http_server):
PATTERNS = "/createRoom"
@@ -62,8 +69,7 @@ class RoomCreateRestServlet(ClientV1RestServlet):
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
- handler = self.handlers.room_creation_handler
- info = yield handler.create_room(
+ info = yield self._room_creation_handler.create_room(
requester, self.get_room_config(request)
)
@@ -82,6 +88,9 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.room_member_handler = hs.get_room_member_handler()
+ self.message_handler = hs.get_message_handler()
def register(self, http_server):
# /room/$roomid/state/$eventtype
@@ -116,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
format = parse_string(request, "format", default="content",
allowed_values=["content", "event"])
- msg_handler = self.handlers.message_handler
+ msg_handler = self.message_handler
data = yield msg_handler.get_room_data(
user_id=requester.user.to_string(),
room_id=room_id,
@@ -154,7 +163,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
if event_type == EventTypes.Member:
membership = content.get("membership", None)
- event = yield self.handlers.room_member_handler.update_membership(
+ event = yield self.room_member_handler.update_membership(
requester,
target=UserID.from_string(state_key),
room_id=room_id,
@@ -162,16 +171,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
content=content,
)
else:
- msg_handler = self.handlers.message_handler
- event, context = yield msg_handler.create_event(
+ event = yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
event_dict,
- token_id=requester.access_token_id,
txn_id=txn_id,
)
- yield msg_handler.send_nonmember_event(requester, event, context)
-
ret = {}
if event:
ret = {"event_id": event.event_id}
@@ -183,7 +188,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.event_creation_handler = hs.get_event_creation_handler()
def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id]
@@ -195,15 +200,19 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
- msg_handler = self.handlers.message_handler
- event = yield msg_handler.create_and_send_nonmember_event(
+ event_dict = {
+ "type": event_type,
+ "content": content,
+ "room_id": room_id,
+ "sender": requester.user.to_string(),
+ }
+
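+ # application services may set the event's timestamp via the 'ts'
+ # query parameter ("timestamp massaging"); other callers may not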
+ if 'ts' in request.args and requester.app_service:
+ event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
+
+ event = yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
- {
- "type": event_type,
- "content": content,
- "room_id": room_id,
- "sender": requester.user.to_string(),
- },
+ event_dict,
txn_id=txn_id,
)
@@ -222,7 +231,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
class JoinRoomAliasServlet(ClientV1RestServlet):
def __init__(self, hs):
super(JoinRoomAliasServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.room_member_handler = hs.get_room_member_handler()
def register(self, http_server):
# /join/$room_identifier[/$txn_id]
@@ -238,7 +247,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
try:
content = parse_json_object_from_request(request)
- except:
+ except Exception:
# Turns out we used to ignore the body entirely, and some clients
# cheekily send invalid bodies.
content = {}
@@ -247,10 +256,10 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
room_id = room_identifier
try:
remote_room_hosts = request.args["server_name"]
- except:
+ except Exception:
remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier):
- handler = self.handlers.room_member_handler
+ handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier)
room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
room_id = room_id.to_string()
@@ -259,7 +268,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
room_identifier,
))
- yield self.handlers.room_member_handler.update_membership(
+ yield self.room_member_handler.update_membership(
requester=requester,
target=requester.user,
room_id=room_id,
@@ -369,14 +378,13 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomMemberListRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.message_handler = hs.get_message_handler()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
requester = yield self.auth.get_user_by_req(request)
- handler = self.handlers.message_handler
- events = yield handler.get_state_events(
+ events = yield self.message_handler.get_state_events(
room_id=room_id,
user_id=requester.user.to_string(),
)
@@ -398,22 +406,18 @@ class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(JoinedRoomMemberListRestServlet, self).__init__(hs)
- self.state = hs.get_state_handler()
+ self.message_handler = hs.get_message_handler()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
- yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
- users_with_profile = yield self.state.get_current_user_in_room(room_id)
+ users_with_profile = yield self.message_handler.get_joined_members(
+ requester, room_id,
+ )
defer.returnValue((200, {
- "joined": {
- user_id: {
- "avatar_url": profile.avatar_url,
- "display_name": profile.display_name,
- }
- for user_id, profile in users_with_profile.iteritems()
- }
+ "joined": users_with_profile,
}))
@@ -423,7 +427,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomMessageListRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.pagination_handler = hs.get_pagination_handler()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
@@ -432,14 +436,13 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
request, default_limit=10,
)
as_client_event = "raw" not in request.args
- filter_bytes = request.args.get("filter", None)
+ filter_bytes = parse_string(request, "filter")
if filter_bytes:
- filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8")
+ filter_json = urlparse.unquote(filter_bytes).decode("UTF-8")
event_filter = Filter(json.loads(filter_json))
else:
event_filter = None
- handler = self.handlers.message_handler
- msgs = yield handler.get_messages(
+ msgs = yield self.pagination_handler.get_messages(
room_id=room_id,
requester=requester,
pagin_config=pagination_config,
@@ -456,14 +459,13 @@ class RoomStateRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomStateRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.message_handler = hs.get_message_handler()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- handler = self.handlers.message_handler
# Get all the current state for this room
- events = yield handler.get_state_events(
+ events = yield self.message_handler.get_state_events(
room_id=room_id,
user_id=requester.user.to_string(),
is_guest=requester.is_guest,
@@ -491,23 +493,45 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
defer.returnValue((200, content))
-class RoomEventContext(ClientV1RestServlet):
+class RoomEventServlet(ClientV1RestServlet):
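+ """Fetch a single event, given its room ID and event ID
+ """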
+ PATTERNS = client_path_patterns(
+ "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(RoomEventServlet, self).__init__(hs)
+ self.clock = hs.get_clock()
+ self.event_handler = hs.get_event_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id, event_id):
+ requester = yield self.auth.get_user_by_req(request)
+ event = yield self.event_handler.get_event(requester.user, event_id)
+
+ time_now = self.clock.time_msec()
+ if event:
+ defer.returnValue((200, serialize_event(event, time_now)))
+ else:
+ defer.returnValue((404, "Event not found."))
+
+
+class RoomEventContextServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns(
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$"
)
def __init__(self, hs):
- super(RoomEventContext, self).__init__(hs)
+ super(RoomEventContextServlet, self).__init__(hs)
self.clock = hs.get_clock()
- self.handlers = hs.get_handlers()
+ self.room_context_handler = hs.get_room_context_handler()
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- limit = int(request.args.get("limit", [10])[0])
+ limit = parse_integer(request, "limit", default=10)
- results = yield self.handlers.room_context_handler.get_event_context(
+ results = yield self.room_context_handler.get_event_context(
requester.user,
room_id,
event_id,
@@ -537,7 +561,7 @@ class RoomEventContext(ClientV1RestServlet):
class RoomForgetRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomForgetRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.room_member_handler = hs.get_room_member_handler()
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
@@ -550,7 +574,7 @@ class RoomForgetRestServlet(ClientV1RestServlet):
allow_guest=False,
)
- yield self.handlers.room_member_handler.forget(
+ yield self.room_member_handler.forget(
user=requester.user,
room_id=room_id,
)
@@ -568,12 +592,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomMembershipRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.room_member_handler = hs.get_room_member_handler()
def register(self, http_server):
# /rooms/$roomid/[invite|join|leave]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
- "(?P<membership_action>join|invite|leave|ban|unban|kick|forget)")
+ "(?P<membership_action>join|invite|leave|ban|unban|kick)")
register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks
@@ -591,13 +615,13 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
try:
content = parse_json_object_from_request(request)
- except:
+ except Exception:
# Turns out we used to ignore the body entirely, and some clients
# cheekily send invalid bodies.
content = {}
if membership_action == "invite" and self._has_3pid_invite_keys(content):
- yield self.handlers.room_member_handler.do_3pid_invite(
+ yield self.room_member_handler.do_3pid_invite(
room_id,
requester.user,
content["medium"],
@@ -611,15 +635,14 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
target = requester.user
if membership_action in ["invite", "ban", "unban", "kick"]:
- if "user_id" not in content:
- raise SynapseError(400, "Missing user_id key.")
+ assert_params_in_dict(content, ["user_id"])
target = UserID.from_string(content["user_id"])
event_content = None
if 'reason' in content and membership_action in ['kick', 'ban']:
event_content = {'reason': content['reason']}
- yield self.handlers.room_member_handler.update_membership(
+ yield self.room_member_handler.update_membership(
requester=requester,
target=target,
room_id=room_id,
@@ -629,7 +652,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
content=event_content,
)
- defer.returnValue((200, {}))
+ return_value = {}
+
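+ # the client-server API specifies that a successful join returns the
+ # (resolved) room ID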
+ if membership_action == "join":
+ return_value["room_id"] = room_id
+
+ defer.returnValue((200, return_value))
def _has_3pid_invite_keys(self, content):
for key in {"id_server", "medium", "address"}:
@@ -647,6 +675,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomRedactEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
+ self.event_creation_handler = hs.get_event_creation_handler()
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
@@ -657,8 +686,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
- msg_handler = self.handlers.message_handler
- event = yield msg_handler.create_and_send_nonmember_event(
+ event = yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Redaction,
@@ -692,8 +720,8 @@ class RoomTypingRestServlet(ClientV1RestServlet):
def on_PUT(self, request, room_id, user_id):
requester = yield self.auth.get_user_by_req(request)
- room_id = urllib.unquote(room_id)
- target_user = UserID.from_string(urllib.unquote(user_id))
+ room_id = urlparse.unquote(room_id)
+ target_user = UserID.from_string(urlparse.unquote(user_id))
content = parse_json_object_from_request(request)
@@ -734,7 +762,7 @@ class SearchRestServlet(ClientV1RestServlet):
content = parse_json_object_from_request(request)
- batch = request.args.get("next_batch", [None])[0]
+ batch = parse_string(request, "next_batch")
results = yield self.handlers.search_handler.search(
requester.user,
content,
@@ -802,9 +830,13 @@ def register_servlets(hs, http_server):
RoomSendEventRestServlet(hs).register(http_server)
PublicRoomListRestServlet(hs).register(http_server)
RoomStateRestServlet(hs).register(http_server)
- RoomInitialSyncRestServlet(hs).register(http_server)
RoomRedactEventRestServlet(hs).register(http_server)
RoomTypingRestServlet(hs).register(http_server)
SearchRestServlet(hs).register(http_server)
JoinedRoomsRestServlet(hs).register(http_server)
- RoomEventContext(hs).register(http_server)
+ RoomEventServlet(hs).register(http_server)
+ RoomEventContextServlet(hs).register(http_server)
+
+
+def register_deprecated_servlets(hs, http_server):
+ RoomInitialSyncRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index c43b30b73a..62f4c3d93e 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -13,16 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import base64
+import hashlib
+import hmac
+
from twisted.internet import defer
from .base import ClientV1RestServlet, client_path_patterns
-import hmac
-import hashlib
-import base64
-
-
class VoipRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/voip/turnServer$")
diff --git a/synapse/rest/client/v1_only/__init__.py b/synapse/rest/client/v1_only/__init__.py
new file mode 100644
index 0000000000..936f902ace
--- /dev/null
+++ b/synapse/rest/client/v1_only/__init__.py
@@ -0,0 +1,3 @@
+"""
+REST APIs that are only used in v1 (the legacy API).
+"""
diff --git a/synapse/rest/client/v1_only/base.py b/synapse/rest/client/v1_only/base.py
new file mode 100644
index 0000000000..9d4db7437c
--- /dev/null
+++ b/synapse/rest/client/v1_only/base.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This module contains base REST classes for constructing client v1 servlets.
+"""
+
+import re
+
+from synapse.api.urls import CLIENT_PREFIX
+
+
+def v1_only_client_path_patterns(path_regex, include_in_unstable=True):
+ """Creates a regex compiled client path with the correct client path
+ prefix.
+
+ Args:
+ path_regex (str): The regex string to match. This should NOT have a ^
+ as this will be prefixed.
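+ include_in_unstable (bool): Whether to also match the path under
+ the /unstable prefix.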
+ Returns:
+ list of SRE_Pattern
+ """
+ patterns = [re.compile("^" + CLIENT_PREFIX + path_regex)]
+ if include_in_unstable:
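+ # also register the pattern under the /unstable prefix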
+ unstable_prefix = CLIENT_PREFIX.replace("/api/v1", "/unstable")
+ patterns.append(re.compile("^" + unstable_prefix + path_regex))
+ return patterns
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1_only/register.py
index ecf7e311a9..3439c3c6d4 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1_only/register.py
@@ -14,21 +14,20 @@
# limitations under the License.
"""This module contains REST servlets to do with registration: /register"""
+import hmac
+import logging
+from hashlib import sha1
+
from twisted.internet import defer
-from synapse.api.errors import SynapseError, Codes
-from synapse.api.constants import LoginType
-from synapse.api.auth import get_access_token_from_request
-from .base import ClientV1RestServlet, client_path_patterns
import synapse.util.stringutils as stringutils
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.api.constants import LoginType
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
+from synapse.rest.client.v1.base import ClientV1RestServlet
from synapse.types import create_requester
-from synapse.util.async import run_on_reactor
-
-from hashlib import sha1
-import hmac
-import logging
+from .base import v1_only_client_path_patterns
logger = logging.getLogger(__name__)
@@ -51,7 +50,7 @@ class RegisterRestServlet(ClientV1RestServlet):
handler doesn't have a concept of multi-stages or sessions.
"""
- PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False)
+ PATTERNS = v1_only_client_path_patterns("/register$", include_in_unstable=False)
def __init__(self, hs):
"""
@@ -66,14 +65,20 @@ class RegisterRestServlet(ClientV1RestServlet):
# TODO: persistent storage
self.sessions = {}
self.enable_registration = hs.config.enable_registration
+ self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers()
def on_GET(self, request):
+
+ require_email = 'email' in self.hs.config.registrations_require_3pid
+ require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
+
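+ # advertise only the registration flows which can satisfy the
+ # server's 3PID requirements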
+ flows = []
if self.hs.config.enable_registration_captcha:
- return (
- 200,
- {"flows": [
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if not require_msisdn:
+ flows.extend([
{
"type": LoginType.RECAPTCHA,
"stages": [
@@ -82,27 +87,34 @@ class RegisterRestServlet(ClientV1RestServlet):
LoginType.PASSWORD
]
},
+ ])
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ flows.extend([
{
"type": LoginType.RECAPTCHA,
"stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
}
- ]}
- )
+ ])
else:
- return (
- 200,
- {"flows": [
+ # support the email flow if email is required, or if MSISDN is not
+ if require_email or not require_msisdn:
+ flows.extend([
{
"type": LoginType.EMAIL_IDENTITY,
"stages": [
LoginType.EMAIL_IDENTITY, LoginType.PASSWORD
]
- },
+ }
+ ])
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ flows.extend([
{
"type": LoginType.PASSWORD
}
- ]}
- )
+ ])
+ return (200, {"flows": flows})
@defer.inlineCallbacks
def on_POST(self, request):
@@ -111,8 +123,7 @@ class RegisterRestServlet(ClientV1RestServlet):
session = (register_json["session"]
if "session" in register_json else None)
login_type = None
- if "type" not in register_json:
- raise SynapseError(400, "Missing 'type' key.")
+ assert_params_in_dict(register_json, ["type"])
try:
login_type = register_json["type"]
@@ -258,7 +269,6 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_password(self, request, register_json, session):
- yield run_on_reactor()
if (self.hs.config.enable_registration_captcha and
not session[LoginType.RECAPTCHA]):
# captcha should've been done by this stage!
@@ -298,11 +308,9 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_app_service(self, request, register_json, session):
- as_token = get_access_token_from_request(request)
-
- if "user" not in register_json:
- raise SynapseError(400, "Expected 'user' key.")
+ as_token = self.auth.get_access_token_from_request(request)
+ assert_params_in_dict(register_json, ["user"])
user_localpart = register_json["user"].encode("utf-8")
handler = self.handlers.registration_handler
@@ -319,14 +327,7 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_shared_secret(self, request, register_json, session):
- yield run_on_reactor()
-
- if not isinstance(register_json.get("mac", None), basestring):
- raise SynapseError(400, "Expected mac.")
- if not isinstance(register_json.get("user", None), basestring):
- raise SynapseError(400, "Expected 'user' key.")
- if not isinstance(register_json.get("password", None), basestring):
- raise SynapseError(400, "Expected 'password' key.")
+ assert_params_in_dict(register_json, ["mac", "user", "password"])
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
@@ -336,9 +337,9 @@ class RegisterRestServlet(ClientV1RestServlet):
admin = register_json.get("admin", None)
# It's important to check as we use null bytes as HMAC field separators
- if "\x00" in user:
+ if b"\x00" in user:
raise SynapseError(400, "Invalid user")
- if "\x00" in password:
+ if b"\x00" in password:
raise SynapseError(400, "Invalid password")
# str() because otherwise hmac complains that 'unicode' does not
@@ -346,20 +347,20 @@ class RegisterRestServlet(ClientV1RestServlet):
got_mac = str(register_json["mac"])
want_mac = hmac.new(
- key=self.hs.config.registration_shared_secret,
+ key=self.hs.config.registration_shared_secret.encode(),
digestmod=sha1,
)
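+ # the MAC is computed over user \x00 password \x00 admin-flag, keyed
+ # with the registration shared secret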
want_mac.update(user)
- want_mac.update("\x00")
+ want_mac.update(b"\x00")
want_mac.update(password)
- want_mac.update("\x00")
- want_mac.update("admin" if admin else "notadmin")
+ want_mac.update(b"\x00")
+ want_mac.update(b"admin" if admin else b"notadmin")
want_mac = want_mac.hexdigest()
if compare_digest(want_mac, got_mac):
handler = self.handlers.registration_handler
user_id, token = yield handler.register(
- localpart=user,
+ localpart=user.lower(),
password=password,
admin=bool(admin),
)
@@ -379,7 +380,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
"""Handles user creation via a server-to-server interface
"""
- PATTERNS = client_path_patterns("/createUser$", releases=())
+ PATTERNS = v1_only_client_path_patterns("/createUser$")
def __init__(self, hs):
super(CreateUserRestServlet, self).__init__(hs)
@@ -390,7 +391,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
def on_POST(self, request):
user_json = parse_json_object_from_request(request)
- access_token = get_access_token_from_request(request)
+ access_token = self.auth.get_access_token_from_request(request)
app_service = self.store.get_app_service_by_token(
access_token
)
@@ -409,13 +410,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_create(self, requester, user_json):
- yield run_on_reactor()
-
- if "localpart" not in user_json:
- raise SynapseError(400, "Expected 'localpart' key.")
-
- if "displayname" not in user_json:
- raise SynapseError(400, "Expected 'displayname' key.")
+ assert_params_in_dict(user_json, ["localpart", "displayname"])
localpart = user_json["localpart"].encode("utf-8")
displayname = user_json["displayname"].encode("utf-8")
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index 1f5bc24cc3..77434937ff 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -15,12 +15,13 @@
"""This module contains base REST classes for constructing client v1 servlets.
"""
-
-from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
+import logging
import re
-import logging
+from twisted.internet import defer
+from synapse.api.errors import InteractiveAuthIncompleteError
+from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
logger = logging.getLogger(__name__)
@@ -57,3 +58,37 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit):
filter_json['room']['timeline']["limit"] = min(
filter_json['room']['timeline']['limit'],
filter_timeline_limit)
+
+
+def interactive_auth_handler(orig):
+ """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
+
+ Takes an on_POST method which returns a deferred (errcode, body) response
+ and adds exception handling to turn an InteractiveAuthIncompleteError into
+ a 401 response.
+
+ Normal usage is:
+
+ @interactive_auth_handler
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ # ...
+ yield self.auth_handler.check_auth
+ """
+ def wrapped(*args, **kwargs):
+ res = defer.maybeDeferred(orig, *args, **kwargs)
+ res.addErrback(_catch_incomplete_interactive_auth)
+ return res
+ return wrapped
+
+
+def _catch_incomplete_interactive_auth(f):
+ """helper for interactive_auth_handler
+
+ Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
+
+ Args:
+ f (failure.Failure):
+ """
+ f.trap(InteractiveAuthIncompleteError)
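+ # f.value.result is the JSON body describing the outstanding auth flows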
+ return 401, f.value.result
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 4990b22b9f..eeae466d82 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,21 +14,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
+from six.moves import http_client
from twisted.internet import defer
from synapse.api.constants import LoginType
-from synapse.api.errors import LoginError, SynapseError, Codes
+from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import (
- RestServlet, parse_json_object_from_request, assert_params_in_request
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
)
-from synapse.util.async import run_on_reactor
from synapse.util.msisdn import phone_number_to_msisdn
+from synapse.util.threepids import check_3pid_allowed
-from ._base import client_v2_patterns
-
-import logging
-
+from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
@@ -44,10 +47,15 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
- assert_params_in_request(body, [
+ assert_params_in_dict(body, [
'id_server', 'client_secret', 'email', 'send_attempt'
])
+ if not check_3pid_allowed(self.hs, "email", body['email']):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
@@ -72,13 +80,18 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
- assert_params_in_request(body, [
+ assert_params_in_dict(body, [
'id_server', 'client_secret',
'country', 'phone_number', 'send_attempt',
])
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+ if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.datastore.get_user_id_by_threepid(
'msisdn', msisdn
)
@@ -99,56 +112,60 @@ class PasswordRestServlet(RestServlet):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore()
+ self._set_password_handler = hs.get_set_password_handler()
+ @interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
- yield run_on_reactor()
-
body = parse_json_object_from_request(request)
- authed, result, params, _ = yield self.auth_handler.check_auth([
- [LoginType.PASSWORD],
- [LoginType.EMAIL_IDENTITY],
- [LoginType.MSISDN],
- ], body, self.hs.get_ip_from_request(request))
-
- if not authed:
- defer.returnValue((401, result))
-
- user_id = None
- requester = None
-
- if LoginType.PASSWORD in result:
- # if using password, they should also be logged in
+ # there are two possibilities here. Either the user does not have an
+ # access token, and needs to do a password reset; or they have one and
+ # need to validate their identity.
+ #
+ # In the first case, we offer a couple of means of identifying
+ # themselves (email and msisdn, though it's unclear if msisdn actually
+ # works).
+ #
+ # In the second case, we require a password to confirm their identity.
+
+ if self.auth.has_access_token(request):
requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
- if user_id != result[LoginType.PASSWORD]:
- raise LoginError(400, "", Codes.UNKNOWN)
- elif LoginType.EMAIL_IDENTITY in result:
- threepid = result[LoginType.EMAIL_IDENTITY]
- if 'medium' not in threepid or 'address' not in threepid:
- raise SynapseError(500, "Malformed threepid")
- if threepid['medium'] == 'email':
- # For emails, transform the address to lowercase.
- # We store all email addreses as lowercase in the DB.
- # (See add_threepid in synapse/handlers/auth.py)
- threepid['address'] = threepid['address'].lower()
- # if using email, we must know about the email they're authing with!
- threepid_user_id = yield self.datastore.get_user_id_by_threepid(
- threepid['medium'], threepid['address']
+ params = yield self.auth_handler.validate_user_via_ui_auth(
+ requester, body, self.hs.get_ip_from_request(request),
)
- if not threepid_user_id:
- raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
- user_id = threepid_user_id
+ user_id = requester.user.to_string()
else:
- logger.error("Auth succeeded but no known type!", result.keys())
- raise SynapseError(500, "", Codes.UNKNOWN)
+ requester = None
+ result, params, _ = yield self.auth_handler.check_auth(
+ [[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]],
+ body, self.hs.get_ip_from_request(request),
+ )
- if 'new_password' not in params:
- raise SynapseError(400, "", Codes.MISSING_PARAM)
+ if LoginType.EMAIL_IDENTITY in result:
+ threepid = result[LoginType.EMAIL_IDENTITY]
+ if 'medium' not in threepid or 'address' not in threepid:
+ raise SynapseError(500, "Malformed threepid")
+ if threepid['medium'] == 'email':
+ # For emails, transform the address to lowercase.
+ # We store all email addresses as lowercase in the DB.
+ # (See add_threepid in synapse/handlers/auth.py)
+ threepid['address'] = threepid['address'].lower()
+ # if using email, we must know about the email they're authing with!
+ threepid_user_id = yield self.datastore.get_user_id_by_threepid(
+ threepid['medium'], threepid['address']
+ )
+ if not threepid_user_id:
+ raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
+ user_id = threepid_user_id
+ else:
+ logger.error("Auth succeeded but no known type! %r", result.keys())
+ raise SynapseError(500, "", Codes.UNKNOWN)
+
+ assert_params_in_dict(params, ["new_password"])
new_password = params['new_password']
- yield self.auth_handler.set_password(
+ yield self._set_password_handler.set_password(
user_id, new_password, requester
)
@@ -162,42 +179,39 @@ class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/deactivate$")
def __init__(self, hs):
+ super(DeactivateAccountRestServlet, self).__init__()
self.hs = hs
- self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
- super(DeactivateAccountRestServlet, self).__init__()
+ self._deactivate_account_handler = hs.get_deactivate_account_handler()
+ @interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
+ erase = body.get("erase", False)
+ if not isinstance(erase, bool):
+ raise SynapseError(
+ http_client.BAD_REQUEST,
+ "Param 'erase' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
- authed, result, params, _ = yield self.auth_handler.check_auth([
- [LoginType.PASSWORD],
- ], body, self.hs.get_ip_from_request(request))
-
- if not authed:
- defer.returnValue((401, result))
-
- user_id = None
- requester = None
-
- if LoginType.PASSWORD in result:
- # if using password, they should also be logged in
- requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
- if user_id != result[LoginType.PASSWORD]:
- raise LoginError(400, "", Codes.UNKNOWN)
- else:
- logger.error("Auth succeeded but no known type!", result.keys())
- raise SynapseError(500, "", Codes.UNKNOWN)
+ requester = yield self.auth.get_user_by_req(request)
- # FIXME: Theoretically there is a race here wherein user resets password
- # using threepid.
- yield self.store.user_delete_access_tokens(user_id)
- yield self.store.user_delete_threepids(user_id)
- yield self.store.user_set_password_hash(user_id, None)
+ # allow ASes to deactivate their own users
+ if requester.app_service:
+ yield self._deactivate_account_handler.deactivate_account(
+ requester.user.to_string(), erase,
+ )
+ defer.returnValue((200, {}))
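+
+ # otherwise, the user must confirm their identity via UI auth first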
+ yield self.auth_handler.validate_user_via_ui_auth(
+ requester, body, self.hs.get_ip_from_request(request),
+ )
+ yield self._deactivate_account_handler.deactivate_account(
+ requester.user.to_string(), erase,
+ )
defer.returnValue((200, {}))
@@ -213,15 +227,15 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
+ assert_params_in_dict(
+ body,
+ ['id_server', 'client_secret', 'email', 'send_attempt'],
+ )
- required = ['id_server', 'client_secret', 'email', 'send_attempt']
- absent = []
- for k in required:
- if k not in body:
- absent.append(k)
-
- if absent:
- raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+ if not check_3pid_allowed(self.hs, "email", body['email']):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
existingUid = yield self.datastore.get_user_id_by_threepid(
'email', body['email']
@@ -246,21 +260,18 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
-
- required = [
+ assert_params_in_dict(body, [
'id_server', 'client_secret',
'country', 'phone_number', 'send_attempt',
- ]
- absent = []
- for k in required:
- if k not in body:
- absent.append(k)
-
- if absent:
- raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+ ])
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+ if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.datastore.get_user_id_by_threepid(
'msisdn', msisdn
)
@@ -285,8 +296,6 @@ class ThreepidRestServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
- yield run_on_reactor()
-
requester = yield self.auth.get_user_by_req(request)
threepids = yield self.datastore.user_get_threepids(
@@ -297,8 +306,6 @@ class ThreepidRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
- yield run_on_reactor()
-
body = parse_json_object_from_request(request)
threePidCreds = body.get('threePidCreds')
@@ -350,29 +357,40 @@ class ThreepidDeleteRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
- yield run_on_reactor()
-
body = parse_json_object_from_request(request)
-
- required = ['medium', 'address']
- absent = []
- for k in required:
- if k not in body:
- absent.append(k)
-
- if absent:
- raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+ assert_params_in_dict(body, ['medium', 'address'])
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
- yield self.auth_handler.delete_threepid(
- user_id, body['medium'], body['address']
- )
+ try:
+ yield self.auth_handler.delete_threepid(
+ user_id, body['medium'], body['address']
+ )
+ except Exception:
+ # NB. This endpoint should succeed if there is nothing to
+ # delete, so it should only throw if something is wrong
+ # that we ought to care about.
+ logger.exception("Failed to remove threepid")
+ raise SynapseError(500, "Failed to remove threepid")
defer.returnValue((200, {}))
+class WhoamiRestServlet(RestServlet):
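+ """Report the user ID of the owner of the given access token
+ """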
+ PATTERNS = client_v2_patterns("/account/whoami$")
+
+ def __init__(self, hs):
+ super(WhoamiRestServlet, self).__init__()
+ self.auth = hs.get_auth()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ requester = yield self.auth.get_user_by_req(request)
+
+ defer.returnValue((200, {'user_id': requester.user.to_string()}))
+
+
def register_servlets(hs, http_server):
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
MsisdnPasswordRequestTokenRestServlet(hs).register(http_server)
@@ -382,3 +400,4 @@ def register_servlets(hs, http_server):
MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)
+ WhoamiRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 0e0a187efd..371e9aa354 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import client_v2_patterns
-
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.api.errors import AuthError, SynapseError
+import logging
from twisted.internet import defer
-import logging
+from synapse.api.errors import AuthError, SynapseError
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+
+from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 8e5577148f..bd8b5f4afa 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
from synapse.api.constants import LoginType
@@ -23,9 +25,6 @@ from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns
-import logging
-
-
logger = logging.getLogger(__name__)
RECAPTCHA_TEMPLATE = """
@@ -129,7 +128,6 @@ class AuthRestServlet(RestServlet):
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Server", self.hs.version_string)
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
@@ -175,7 +173,6 @@ class AuthRestServlet(RestServlet):
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Server", self.hs.version_string)
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index b57ba95d24..9b75bb1377 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -17,15 +17,20 @@ import logging
from twisted.internet import defer
-from synapse.api import constants, errors
-from synapse.http import servlet
-from ._base import client_v2_patterns
+from synapse.api import errors
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
+
+from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
-class DevicesRestServlet(servlet.RestServlet):
- PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
+class DevicesRestServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/devices$", v2_alpha=False)
def __init__(self, hs):
"""
@@ -46,12 +51,12 @@ class DevicesRestServlet(servlet.RestServlet):
defer.returnValue((200, {"devices": devices}))
-class DeleteDevicesRestServlet(servlet.RestServlet):
+class DeleteDevicesRestServlet(RestServlet):
"""
API for bulk deletion of devices. Accepts a JSON object with a devices
key which lists the device_ids to delete. Requires user interactive auth.
"""
- PATTERNS = client_v2_patterns("/delete_devices", releases=[], v2_alpha=False)
+ PATTERNS = client_v2_patterns("/delete_devices", v2_alpha=False)
def __init__(self, hs):
super(DeleteDevicesRestServlet, self).__init__()
@@ -60,31 +65,28 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler()
+ @interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
+ requester = yield self.auth.get_user_by_req(request)
+
try:
- body = servlet.parse_json_object_from_request(request)
+ body = parse_json_object_from_request(request)
except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON:
- # deal with older clients which didn't pass a J*DELETESON dict
+ # deal with older clients which didn't pass a JSON dict
# the same as those that pass an empty dict
body = {}
else:
raise e
- if 'devices' not in body:
- raise errors.SynapseError(
- 400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
- )
+ assert_params_in_dict(body, ["devices"])
- authed, result, params, _ = yield self.auth_handler.check_auth([
- [constants.LoginType.PASSWORD],
- ], body, self.hs.get_ip_from_request(request))
-
- if not authed:
- defer.returnValue((401, result))
+ yield self.auth_handler.validate_user_via_ui_auth(
+ requester, body, self.hs.get_ip_from_request(request),
+ )
- requester = yield self.auth.get_user_by_req(request)
yield self.device_handler.delete_devices(
requester.user.to_string(),
body['devices'],
@@ -92,9 +94,8 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
defer.returnValue((200, {}))
-class DeviceRestServlet(servlet.RestServlet):
- PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
- releases=[], v2_alpha=False)
+class DeviceRestServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False)
def __init__(self, hs):
"""
@@ -116,10 +117,13 @@ class DeviceRestServlet(servlet.RestServlet):
)
defer.returnValue((200, device))
+ @interactive_auth_handler
@defer.inlineCallbacks
def on_DELETE(self, request, device_id):
+ requester = yield self.auth.get_user_by_req(request)
+
try:
- body = servlet.parse_json_object_from_request(request)
+ body = parse_json_object_from_request(request)
except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON:
@@ -129,17 +133,12 @@ class DeviceRestServlet(servlet.RestServlet):
else:
raise
- authed, result, params, _ = yield self.auth_handler.check_auth([
- [constants.LoginType.PASSWORD],
- ], body, self.hs.get_ip_from_request(request))
-
- if not authed:
- defer.returnValue((401, result))
+ yield self.auth_handler.validate_user_via_ui_auth(
+ requester, body, self.hs.get_ip_from_request(request),
+ )
- requester = yield self.auth.get_user_by_req(request)
yield self.device_handler.delete_device(
- requester.user.to_string(),
- device_id,
+ requester.user.to_string(), device_id,
)
defer.returnValue((200, {}))
@@ -147,7 +146,7 @@ class DeviceRestServlet(servlet.RestServlet):
def on_PUT(self, request, device_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- body = servlet.parse_json_object_from_request(request)
+ body = parse_json_object_from_request(request)
yield self.device_handler.update_device(
requester.user.to_string(),
device_id,
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index d2b2fd66e6..ae86728879 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -13,17 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
-from synapse.api.errors import AuthError, SynapseError, StoreError, Codes
+from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
-from ._base import client_v2_patterns
-from ._base import set_timeline_upper_limit
-
-import logging
-
+from ._base import client_v2_patterns, set_timeline_upper_limit
logger = logging.getLogger(__name__)
@@ -50,7 +48,7 @@ class GetFilterRestServlet(RestServlet):
try:
filter_id = int(filter_id)
- except:
+ except Exception:
raise SynapseError(400, "Invalid filter_id")
try:
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
new file mode 100644
index 0000000000..21e02c07c0
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -0,0 +1,786 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.types import GroupID
+
+from ._base import client_v2_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class GroupServlet(RestServlet):
+ """Get the group profile
+ """
+ PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/profile$")
+
+ def __init__(self, hs):
+ super(GroupServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ group_description = yield self.groups_handler.get_group_profile(
+ group_id,
+ requester_user_id,
+ )
+
+ defer.returnValue((200, group_description))
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ yield self.groups_handler.update_group_profile(
+ group_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, {}))
+
+
+class GroupSummaryServlet(RestServlet):
+ """Get the full group summary
+ """
+ PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/summary$")
+
+ def __init__(self, hs):
+ super(GroupSummaryServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ group_summary = yield self.groups_handler.get_group_summary(
+ group_id,
+ requester_user_id,
+ )
+
+ defer.returnValue((200, group_summary))
+
+
+class GroupSummaryRoomsCatServlet(RestServlet):
+ """Update/delete a rooms entry in the summary.
+
+ Matches both:
+ - /groups/:group/summary/rooms/:room_id
+ - /groups/:group/summary/categories/:category/rooms/:room_id
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/summary"
+ "(/categories/(?P<category_id>[^/]+))?"
+ "/rooms/(?P<room_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupSummaryRoomsCatServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, category_id, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ resp = yield self.groups_handler.update_group_summary_room(
+ group_id, requester_user_id,
+ room_id=room_id,
+ category_id=category_id,
+ content=content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, request, group_id, category_id, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ resp = yield self.groups_handler.delete_group_summary_room(
+ group_id, requester_user_id,
+ room_id=room_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class GroupCategoryServlet(RestServlet):
+ """Get/add/update/delete a group category
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupCategoryServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id, category_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ category = yield self.groups_handler.get_group_category(
+ group_id, requester_user_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue((200, category))
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, category_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ resp = yield self.groups_handler.update_group_category(
+ group_id, requester_user_id,
+ category_id=category_id,
+ content=content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, request, group_id, category_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ resp = yield self.groups_handler.delete_group_category(
+ group_id, requester_user_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class GroupCategoriesServlet(RestServlet):
+ """Get all group categories
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/categories/$"
+ )
+
+ def __init__(self, hs):
+ super(GroupCategoriesServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ categories = yield self.groups_handler.get_group_categories(
+ group_id, requester_user_id,
+ )
+
+ defer.returnValue((200, categories))
+
+
+class GroupRoleServlet(RestServlet):
+ """Get/add/update/delete a group role
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupRoleServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id, role_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ role = yield self.groups_handler.get_group_role(
+ group_id, requester_user_id,
+ role_id=role_id,
+ )
+
+ defer.returnValue((200, role))
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, role_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ resp = yield self.groups_handler.update_group_role(
+ group_id, requester_user_id,
+ role_id=role_id,
+ content=content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, request, group_id, role_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ resp = yield self.groups_handler.delete_group_role(
+ group_id, requester_user_id,
+ role_id=role_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class GroupRolesServlet(RestServlet):
+ """Get all group roles
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/roles/$"
+ )
+
+ def __init__(self, hs):
+ super(GroupRolesServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ roles = yield self.groups_handler.get_group_roles(
+ group_id, requester_user_id,
+ )
+
+ defer.returnValue((200, roles))
+
+
+class GroupSummaryUsersRoleServlet(RestServlet):
+ """Update/delete a user's entry in the summary.
+
+ Matches both:
+ - /groups/:group/summary/users/:user_id
+ - /groups/:group/summary/roles/:role/users/:user_id
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/summary"
+ "(/roles/(?P<role_id>[^/]+))?"
+ "/users/(?P<user_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupSummaryUsersRoleServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, role_id, user_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ resp = yield self.groups_handler.update_group_summary_user(
+ group_id, requester_user_id,
+ user_id=user_id,
+ role_id=role_id,
+ content=content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, request, group_id, role_id, user_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ resp = yield self.groups_handler.delete_group_summary_user(
+ group_id, requester_user_id,
+ user_id=user_id,
+ role_id=role_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class GroupRoomServlet(RestServlet):
+ """Get all rooms in a group
+ """
+ PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
+
+ def __init__(self, hs):
+ super(GroupRoomServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id)
+
+ defer.returnValue((200, result))
+
+
+class GroupUsersServlet(RestServlet):
+ """Get all users in a group
+ """
+ PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/users$")
+
+ def __init__(self, hs):
+ super(GroupUsersServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id)
+
+ defer.returnValue((200, result))
+
+
+class GroupInvitedUsersServlet(RestServlet):
+ """Get users invited to a group
+ """
+ PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
+
+ def __init__(self, hs):
+ super(GroupInvitedUsersServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ result = yield self.groups_handler.get_invited_users_in_group(
+ group_id,
+ requester_user_id,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupSettingJoinPolicyServlet(RestServlet):
+ """Set group join policy
+ """
+ PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
+
+ def __init__(self, hs):
+ super(GroupSettingJoinPolicyServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+
+ result = yield self.groups_handler.set_group_join_policy(
+ group_id,
+ requester_user_id,
+ content,
+ )
+
+ defer.returnValue((200, result))
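
For illustration — the body schema here comes from the group join-policy proposal, so treat the exact values as assumptions — a request to this endpoint would look roughly like:

    # PUT /_matrix/client/unstable/groups/+g:example.com/settings/m.join_policy
    # {"m.join_policy": {"type": "invite"}}   # assumed values: "invite" or "open"
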
+
+
+class GroupCreateServlet(RestServlet):
+ """Create a group
+ """
+ PATTERNS = client_v2_patterns("/create_group$")
+
+ def __init__(self, hs):
+ super(GroupCreateServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+ self.server_name = hs.hostname
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ # TODO: Create group on remote server
+ content = parse_json_object_from_request(request)
+ localpart = content.pop("localpart")
+ group_id = GroupID(localpart, self.server_name).to_string()
+
+ result = yield self.groups_handler.create_group(
+ group_id,
+ requester_user_id,
+ content,
+ )
+
+ defer.returnValue((200, result))
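
A hedged note on the GroupID construction above: assuming GroupID follows the usual sigil:localpart:domain convention with '+' as the group sigil, a body of {"localpart": "mygroup"} on example.com produces:

    # GroupID("mygroup", "example.com").to_string()  ->  "+mygroup:example.com"
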
+
+
+class GroupAdminRoomsServlet(RestServlet):
+ """Add a room to the group
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupAdminRoomsServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ result = yield self.groups_handler.add_room_to_group(
+ group_id, requester_user_id, room_id, content,
+ )
+
+ defer.returnValue((200, result))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, request, group_id, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ result = yield self.groups_handler.remove_room_from_group(
+ group_id, requester_user_id, room_id,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupAdminRoomsConfigServlet(RestServlet):
+ """Update the config of a room in a group
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
+ "/config/(?P<config_key>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupAdminRoomsConfigServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, room_id, config_key):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ result = yield self.groups_handler.update_room_in_group(
+ group_id, requester_user_id, room_id, config_key, content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupAdminUsersInviteServlet(RestServlet):
+ """Invite a user to the group
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupAdminUsersInviteServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+ self.store = hs.get_datastore()
+ self.is_mine_id = hs.is_mine_id
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, user_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ config = content.get("config", {})
+ result = yield self.groups_handler.invite(
+ group_id, user_id, requester_user_id, config,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupAdminUsersKickServlet(RestServlet):
+ """Kick a user from the group
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupAdminUsersKickServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, user_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ result = yield self.groups_handler.remove_user_from_group(
+ group_id, user_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupSelfLeaveServlet(RestServlet):
+ """Leave a joined group
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/self/leave$"
+ )
+
+ def __init__(self, hs):
+ super(GroupSelfLeaveServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ result = yield self.groups_handler.remove_user_from_group(
+ group_id, requester_user_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupSelfJoinServlet(RestServlet):
+ """Attempt to join a group, or knock
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/self/join$"
+ )
+
+ def __init__(self, hs):
+ super(GroupSelfJoinServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ result = yield self.groups_handler.join_group(
+ group_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupSelfAcceptInviteServlet(RestServlet):
+ """Accept a group invite
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/self/accept_invite$"
+ )
+
+ def __init__(self, hs):
+ super(GroupSelfAcceptInviteServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ result = yield self.groups_handler.accept_invite(
+ group_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupSelfUpdatePublicityServlet(RestServlet):
+    """Update whether we publicise a user's membership of a group
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/self/update_publicity$"
+ )
+
+ def __init__(self, hs):
+ super(GroupSelfUpdatePublicityServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ publicise = content["publicise"]
+ yield self.store.update_group_publicity(
+ group_id, requester_user_id, publicise,
+ )
+
+ defer.returnValue((200, {}))
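
An illustrative request body (the "publicise" key is taken from the handler code above; its value is a plain boolean):

    # PUT /_matrix/client/unstable/groups/+g:example.com/self/update_publicity
    # {"publicise": true}
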
+
+
+class PublicisedGroupsForUserServlet(RestServlet):
+ """Get the list of groups a user is advertising
+ """
+ PATTERNS = client_v2_patterns(
+ "/publicised_groups/(?P<user_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(PublicisedGroupsForUserServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, user_id):
+ yield self.auth.get_user_by_req(request, allow_guest=True)
+
+ result = yield self.groups_handler.get_publicised_groups_for_user(
+ user_id
+ )
+
+ defer.returnValue((200, result))
+
+
+class PublicisedGroupsForUsersServlet(RestServlet):
+    """Get the lists of groups a set of users are advertising
+ """
+ PATTERNS = client_v2_patterns(
+ "/publicised_groups$"
+ )
+
+ def __init__(self, hs):
+ super(PublicisedGroupsForUsersServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ yield self.auth.get_user_by_req(request, allow_guest=True)
+
+ content = parse_json_object_from_request(request)
+ user_ids = content["user_ids"]
+
+ result = yield self.groups_handler.bulk_get_publicised_groups(
+ user_ids
+ )
+
+ defer.returnValue((200, result))
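
Illustratively, the bulk endpoint takes a list of user IDs (the response shape is whatever bulk_get_publicised_groups returns, so it is not sketched here):

    # POST /_matrix/client/unstable/publicised_groups
    # {"user_ids": ["@alice:example.com", "@bob:example.com"]}
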
+
+
+class GroupsForUserServlet(RestServlet):
+    """Get all groups the logged-in user has joined
+ """
+ PATTERNS = client_v2_patterns(
+ "/joined_groups$"
+ )
+
+ def __init__(self, hs):
+ super(GroupsForUserServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ result = yield self.groups_handler.get_joined_groups(requester_user_id)
+
+ defer.returnValue((200, result))
+
+
+def register_servlets(hs, http_server):
+ GroupServlet(hs).register(http_server)
+ GroupSummaryServlet(hs).register(http_server)
+ GroupInvitedUsersServlet(hs).register(http_server)
+ GroupUsersServlet(hs).register(http_server)
+ GroupRoomServlet(hs).register(http_server)
+ GroupSettingJoinPolicyServlet(hs).register(http_server)
+ GroupCreateServlet(hs).register(http_server)
+ GroupAdminRoomsServlet(hs).register(http_server)
+ GroupAdminRoomsConfigServlet(hs).register(http_server)
+ GroupAdminUsersInviteServlet(hs).register(http_server)
+ GroupAdminUsersKickServlet(hs).register(http_server)
+ GroupSelfLeaveServlet(hs).register(http_server)
+ GroupSelfJoinServlet(hs).register(http_server)
+ GroupSelfAcceptInviteServlet(hs).register(http_server)
+ GroupsForUserServlet(hs).register(http_server)
+ GroupCategoryServlet(hs).register(http_server)
+ GroupCategoriesServlet(hs).register(http_server)
+ GroupSummaryRoomsCatServlet(hs).register(http_server)
+ GroupRoleServlet(hs).register(http_server)
+ GroupRolesServlet(hs).register(http_server)
+ GroupSelfUpdatePublicityServlet(hs).register(http_server)
+ GroupSummaryUsersRoleServlet(hs).register(http_server)
+ PublicisedGroupsForUserServlet(hs).register(http_server)
+ PublicisedGroupsForUsersServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 6a3cfe84f8..8486086b51 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -19,10 +19,13 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import (
- RestServlet, parse_json_object_from_request, parse_integer
+ RestServlet,
+ parse_integer,
+ parse_json_object_from_request,
+ parse_string,
)
-from synapse.http.servlet import parse_string
from synapse.types import StreamToken
+
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
@@ -53,8 +56,7 @@ class KeyUploadServlet(RestServlet):
},
}
"""
- PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
- releases=())
+ PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
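
A hedged reading of the releases=() removals in this file (path prefixes assumed from Synapse's usual client API mounting): with releases=() the path was exposed only under the unstable prefix, whereas the default also serves it under the stable r0 prefix:

    # before: /_matrix/client/unstable/keys/upload(/{device_id})
    # after:  /_matrix/client/unstable/keys/upload(/{device_id})
    #         /_matrix/client/r0/keys/upload(/{device_id})
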
def __init__(self, hs):
"""
@@ -128,10 +130,7 @@ class KeyQueryServlet(RestServlet):
} } } } } }
"""
- PATTERNS = client_v2_patterns(
- "/keys/query$",
- releases=()
- )
+ PATTERNS = client_v2_patterns("/keys/query$")
def __init__(self, hs):
"""
@@ -160,10 +159,7 @@ class KeyChangesServlet(RestServlet):
200 OK
{ "changed": ["@foo:example.com"] }
"""
- PATTERNS = client_v2_patterns(
- "/keys/changes$",
- releases=()
- )
+ PATTERNS = client_v2_patterns("/keys/changes$")
def __init__(self, hs):
"""
@@ -188,13 +184,11 @@ class KeyChangesServlet(RestServlet):
user_id = requester.user.to_string()
- changed = yield self.device_handler.get_user_ids_changed(
+ results = yield self.device_handler.get_user_ids_changed(
user_id, from_token,
)
- defer.returnValue((200, {
- "changed": list(changed),
- }))
+ defer.returnValue((200, results))
class OneTimeKeyServlet(RestServlet):
@@ -215,10 +209,7 @@ class OneTimeKeyServlet(RestServlet):
} } } }
"""
- PATTERNS = client_v2_patterns(
- "/keys/claim$",
- releases=()
- )
+ PATTERNS = client_v2_patterns("/keys/claim$")
def __init__(self, hs):
super(OneTimeKeyServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py
index fd2a3d69d4..2a6ea3df5f 100644
--- a/synapse/rest/client/v2_alpha/notifications.py
+++ b/synapse/rest/client/v2_alpha/notifications.py
@@ -13,24 +13,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
-from synapse.http.servlet import (
- RestServlet, parse_string, parse_integer
-)
from synapse.events.utils import (
- serialize_event, format_event_for_client_v2_without_room_id,
+ format_event_for_client_v2_without_room_id,
+ serialize_event,
)
+from synapse.http.servlet import RestServlet, parse_integer, parse_string
from ._base import client_v2_patterns
-import logging
-
logger = logging.getLogger(__name__)
class NotificationsServlet(RestServlet):
- PATTERNS = client_v2_patterns("/notifications$", releases=())
+ PATTERNS = client_v2_patterns("/notifications$")
def __init__(self, hs):
super(NotificationsServlet, self).__init__()
@@ -88,7 +87,7 @@ class NotificationsServlet(RestServlet):
pa["topological_ordering"], pa["stream_ordering"]
)
returned_push_actions.append(returned_pa)
- next_token = pa["stream_ordering"]
+ next_token = str(pa["stream_ordering"])
defer.returnValue((200, {
"notifications": returned_push_actions,
diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py
index aa1cae8e1e..01c90aa2a3 100644
--- a/synapse/rest/client/v2_alpha/openid.py
+++ b/synapse/rest/client/v2_alpha/openid.py
@@ -14,15 +14,15 @@
# limitations under the License.
-from ._base import client_v2_patterns
+import logging
+
+from twisted.internet import defer
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.api.errors import AuthError
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.util.stringutils import random_string
-from twisted.internet import defer
-
-import logging
+from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py
index 2f8784fe06..a6e582a5ae 100644
--- a/synapse/rest/client/v2_alpha/read_marker.py
+++ b/synapse/rest/client/v2_alpha/read_marker.py
@@ -13,13 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from ._base import client_v2_patterns
-
-import logging
+from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
index 1fbff2edd8..de370cac45 100644
--- a/synapse/rest/client/v2_alpha/receipts.py
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
-from ._base import client_v2_patterns
-
-import logging
+from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 1421c18152..d6cf915d86 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -14,25 +14,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import hmac
+import logging
+from hashlib import sha1
+
+from six import string_types
+
from twisted.internet import defer
import synapse
-from synapse.api.auth import get_access_token_from_request, has_access_token
+import synapse.types
from synapse.api.constants import LoginType
-from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
+from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError
from synapse.http.servlet import (
- RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+ parse_string,
)
from synapse.util.msisdn import phone_number_to_msisdn
-
-from ._base import client_v2_patterns
-
-import logging
-import hmac
-from hashlib import sha1
-from synapse.util.async import run_on_reactor
from synapse.util.ratelimitutils import FederationRateLimiter
+from synapse.util.threepids import check_3pid_allowed
+from ._base import client_v2_patterns, interactive_auth_handler
# We ought to be using hmac.compare_digest() but on older pythons it doesn't
# exist. It's a _really minor_ security flaw to use plain string comparison
@@ -64,10 +68,15 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
- assert_params_in_request(body, [
+ assert_params_in_dict(body, [
'id_server', 'client_secret', 'email', 'send_attempt'
])
+ if not check_3pid_allowed(self.hs, "email", body['email']):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
@@ -95,7 +104,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
- assert_params_in_request(body, [
+ assert_params_in_dict(body, [
'id_server', 'client_secret',
'country', 'phone_number',
'send_attempt',
@@ -103,6 +112,11 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+ if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'msisdn', msisdn
)
@@ -170,13 +184,13 @@ class RegisterRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler
+ self.room_member_handler = hs.get_room_member_handler()
self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
+ @interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
- yield run_on_reactor()
-
body = parse_json_object_from_request(request)
kind = "user"
@@ -196,20 +210,20 @@ class RegisterRestServlet(RestServlet):
# in sessions. Pull out the username/password provided to us.
desired_password = None
if 'password' in body:
- if (not isinstance(body['password'], basestring) or
+ if (not isinstance(body['password'], string_types) or
len(body['password']) > 512):
raise SynapseError(400, "Invalid password")
desired_password = body["password"]
desired_username = None
if 'username' in body:
- if (not isinstance(body['username'], basestring) or
+ if (not isinstance(body['username'], string_types) or
len(body['username']) > 512):
raise SynapseError(400, "Invalid username")
desired_username = body['username']
appservice = None
- if has_access_token(request):
+ if self.auth.has_access_token(request):
appservice = yield self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes and shared secret auth which
@@ -221,15 +235,30 @@ class RegisterRestServlet(RestServlet):
# 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one.
desired_username = body.get("user", desired_username)
- access_token = get_access_token_from_request(request)
- if isinstance(desired_username, basestring):
+ # XXX we should check that desired_username is valid. Currently
+ # we give appservices carte blanche for any insanity in mxids,
+ # because the IRC bridges rely on being able to register stupid
+ # IDs.
+
+ access_token = self.auth.get_access_token_from_request(request)
+
+ if isinstance(desired_username, string_types):
result = yield self._do_appservice_registration(
desired_username, access_token, body
)
defer.returnValue((200, result)) # we throw for non 200 responses
return
+ # for either shared secret or regular registration, downcase the
+ # provided username before attempting to register it. This should mean
+ # that people who try to register with upper-case in their usernames
+ # don't get a nasty surprise. (Note that we treat username
+        # case-insensitively in login, so they are free to carry on imagining
+ # that their username is CrAzYh4cKeR if that keeps them happy)
+ if desired_username is not None:
+ desired_username = desired_username.lower()
+
# == Shared Secret Registration == (e.g. create new user scripts)
if 'mac' in body:
# FIXME: Should we really be determining if this is shared secret
@@ -286,34 +315,66 @@ class RegisterRestServlet(RestServlet):
if 'x_show_msisdn' in body and body['x_show_msisdn']:
show_msisdn = True
+ # FIXME: need a better error than "no auth flow found" for scenarios
+ # where we required 3PID for registration but the user didn't give one
+ require_email = 'email' in self.hs.config.registrations_require_3pid
+ require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
+
+ flows = []
if self.hs.config.enable_registration_captcha:
- flows = [
- [LoginType.RECAPTCHA],
- [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
- ]
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ flows.extend([[LoginType.RECAPTCHA]])
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if not require_msisdn:
+ flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]])
+
if show_msisdn:
+ # only support the MSISDN-only flow if we don't require email 3PIDs
+ if not require_email:
+ flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]])
+ # always let users provide both MSISDN & email
flows.extend([
- [LoginType.MSISDN, LoginType.RECAPTCHA],
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
])
else:
- flows = [
- [LoginType.DUMMY],
- [LoginType.EMAIL_IDENTITY],
- ]
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ flows.extend([[LoginType.DUMMY]])
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if not require_msisdn:
+ flows.extend([[LoginType.EMAIL_IDENTITY]])
+
if show_msisdn:
+ # only support the MSISDN-only flow if we don't require email 3PIDs
+                if not require_email:
+ flows.extend([[LoginType.MSISDN]])
+ # always let users provide both MSISDN & email
flows.extend([
- [LoginType.MSISDN],
- [LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
+ [LoginType.MSISDN, LoginType.EMAIL_IDENTITY]
])
- authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
+ auth_result, params, session_id = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
)
- if not authed:
- defer.returnValue((401, auth_result))
- return
+ # Check that we're not trying to register a denied 3pid.
+ #
+ # the user-facing checks will probably already have happened in
+ # /register/email/requestToken when we requested a 3pid, but that's not
+ # guaranteed.
+
+ if auth_result:
+ for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
+ if login_type in auth_result:
+ medium = auth_result[login_type]['medium']
+ address = auth_result[login_type]['address']
+
+ if not check_3pid_allowed(self.hs, medium, address):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed",
+ Codes.THREEPID_DENIED,
+ )
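
A rough sketch of the kind of check check_3pid_allowed performs — assuming, since it is not shown in this patch, that allowed_local_3pids in the homeserver config is a list of medium/pattern pairs:

    import re

    def check_3pid_allowed_sketch(allowed_local_3pids, medium, address):
        # With no restrictions configured, every 3pid is allowed.
        if not allowed_local_3pids:
            return True
        return any(
            constraint["medium"] == medium
            and re.match(constraint["pattern"], address)
            for constraint in allowed_local_3pids
        )

    allowed = [{"medium": "email", "pattern": r".*@example\.com"}]
    assert check_3pid_allowed_sketch(allowed, "email", "bob@example.com")
    assert not check_3pid_allowed_sketch(allowed, "email", "bob@evil.org")
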
if registered_user_id is not None:
logger.info(
@@ -325,14 +386,15 @@ class RegisterRestServlet(RestServlet):
add_msisdn = False
else:
# NB: This may be from the auth handler and NOT from the POST
- if 'password' not in params:
- raise SynapseError(400, "Missing password.",
- Codes.MISSING_PARAM)
+ assert_params_in_dict(params, ["password"])
desired_username = params.get("username", None)
new_password = params.get("password", None)
guest_access_token = params.get("guest_access_token", None)
+ if desired_username is not None:
+ desired_username = desired_username.lower()
+
(registered_user_id, _) = yield self.registration_handler.register(
localpart=desired_username,
password=new_password,
@@ -383,15 +445,24 @@ class RegisterRestServlet(RestServlet):
def _do_shared_secret_registration(self, username, password, body):
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
+ if not username:
+ raise SynapseError(
+ 400, "username must be specified", errcode=Codes.BAD_JSON,
+ )
- user = username.encode("utf-8")
+ # use the username from the original request rather than the
+ # downcased one in `username` for the mac calculation
+ user = body["username"].encode("utf-8")
# str() because otherwise hmac complains that 'unicode' does not
# have the buffer interface
got_mac = str(body["mac"])
+ # FIXME this is different to the /v1/register endpoint, which
+ # includes the password and admin flag in the hashed text. Why are
+ # these different?
want_mac = hmac.new(
- key=self.hs.config.registration_shared_secret,
+ key=self.hs.config.registration_shared_secret.encode(),
msg=user,
digestmod=sha1,
).hexdigest()
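
To illustrate, a client-side helper producing a valid mac for this endpoint mirrors the calculation above (the shared secret and username are placeholders; per the FIXME, only the username goes into the hash):

    import hmac
    from hashlib import sha1

    def registration_mac(shared_secret, username):
        # HMAC-SHA1 over the raw, pre-downcasing username only.
        return hmac.new(
            key=shared_secret.encode(),
            msg=username.encode("utf-8"),
            digestmod=sha1,
        ).hexdigest()

    # body = {"username": "Alice", "mac": registration_mac(secret, "Alice"), ...}
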
@@ -492,11 +563,14 @@ class RegisterRestServlet(RestServlet):
Returns:
defer.Deferred:
"""
- reqd = ('medium', 'address', 'validated_at')
- if any(x not in threepid for x in reqd):
- # This will only happen if the ID server returns a malformed response
- logger.info("Can't add incomplete 3pid")
- defer.returnValue()
+ try:
+ assert_params_in_dict(threepid, ['medium', 'address', 'validated_at'])
+ except SynapseError as ex:
+ if ex.errcode == Codes.MISSING_PARAM:
+ # This will only happen if the ID server returns a malformed response
+ logger.info("Can't add incomplete 3pid")
+ defer.returnValue(None)
+ raise
yield self.auth_handler.add_threepid(
user_id,
@@ -523,25 +597,28 @@ class RegisterRestServlet(RestServlet):
Args:
(str) user_id: full canonical @user:id
(object) params: registration parameters, from which we pull
- device_id and initial_device_name
+ device_id, initial_device_name and inhibit_login
Returns:
defer.Deferred: (object) dictionary for response from /register
"""
- device_id = yield self._register_device(user_id, params)
+ result = {
+ "user_id": user_id,
+ "home_server": self.hs.hostname,
+ }
+ if not params.get("inhibit_login", False):
+ device_id = yield self._register_device(user_id, params)
- access_token = (
- yield self.auth_handler.get_access_token_for_user_id(
- user_id, device_id=device_id,
- initial_display_name=params.get("initial_device_display_name")
+ access_token = (
+ yield self.auth_handler.get_access_token_for_user_id(
+ user_id, device_id=device_id,
+ )
)
- )
- defer.returnValue({
- "user_id": user_id,
- "access_token": access_token,
- "home_server": self.hs.hostname,
- "device_id": device_id,
- })
+ result.update({
+ "access_token": access_token,
+ "device_id": device_id,
+ })
+ defer.returnValue(result)
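
The inhibit_login branch, illustrated with placeholder values:

    # {"inhibit_login": true}  ->  {"user_id": "@alice:hs", "home_server": "hs"}
    # otherwise                ->  the same dict plus "access_token" and "device_id"
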
def _register_device(self, user_id, params):
"""Register a device for a user.
@@ -566,7 +643,7 @@ class RegisterRestServlet(RestServlet):
@defer.inlineCallbacks
def _do_guest_registration(self, params):
if not self.hs.config.allow_guest_access:
- defer.returnValue((403, "Guest access is disabled"))
+ raise SynapseError(403, "Guest access is disabled")
user_id, _ = yield self.registration_handler.register(
generate_token=False,
make_guest=True
diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py
index 8903e12405..95d2a71ec2 100644
--- a/synapse/rest/client/v2_alpha/report_event.py
+++ b/synapse/rest/client/v2_alpha/report_event.py
@@ -13,13 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+import logging
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from ._base import client_v2_patterns
+from six import string_types
+from six.moves import http_client
-import logging
+from twisted.internet import defer
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
+from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
@@ -42,12 +50,26 @@ class ReportEventRestServlet(RestServlet):
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
+ assert_params_in_dict(body, ("reason", "score"))
+
+ if not isinstance(body["reason"], string_types):
+ raise SynapseError(
+ http_client.BAD_REQUEST,
+ "Param 'reason' must be a string",
+ Codes.BAD_JSON,
+ )
+ if not isinstance(body["score"], int):
+ raise SynapseError(
+ http_client.BAD_REQUEST,
+ "Param 'score' must be an integer",
+ Codes.BAD_JSON,
+ )
yield self.store.add_event_report(
room_id=room_id,
event_id=event_id,
user_id=user_id,
- reason=body.get("reason"),
+ reason=body["reason"],
content=body,
received_ts=self.clock.time_msec(),
)
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index d607bd2970..a9e9a47a0b 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class SendToDeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns(
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
- releases=[], v2_alpha=False
+ v2_alpha=False
)
def __init__(self, hs):
@@ -40,7 +40,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
super(SendToDeviceRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
- self.txns = HttpTransactionCache(hs.get_clock())
+ self.txns = HttpTransactionCache(hs)
self.device_message_handler = hs.get_device_message_handler()
def on_PUT(self, request, message_type, txn_id):
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 83e209d18f..8aa06faf23 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -13,27 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import itertools
+import logging
+
+from canonicaljson import json
+
from twisted.internet import defer
-from synapse.http.servlet import (
- RestServlet, parse_string, parse_integer, parse_boolean
+from synapse.api.constants import PresenceState
+from synapse.api.errors import SynapseError
+from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
+from synapse.events.utils import (
+ format_event_for_client_v2_without_room_id,
+ serialize_event,
)
from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.sync import SyncConfig
+from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.types import StreamToken
-from synapse.events.utils import (
- serialize_event, format_event_for_client_v2_without_room_id,
-)
-from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION
-from synapse.api.errors import SynapseError
-from synapse.api.constants import PresenceState
-from ._base import client_v2_patterns
-from ._base import set_timeline_upper_limit
-
-import itertools
-import logging
-import ujson as json
+from ._base import client_v2_patterns, set_timeline_upper_limit
logger = logging.getLogger(__name__)
@@ -85,6 +84,7 @@ class SyncRestServlet(RestServlet):
self.clock = hs.get_clock()
self.filtering = hs.get_filtering()
self.presence_handler = hs.get_presence_handler()
+ self._server_notices_sender = hs.get_server_notices_sender()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -110,7 +110,7 @@ class SyncRestServlet(RestServlet):
filter_id = parse_string(request, "filter", default=None)
full_state = parse_boolean(request, "full_state", default=False)
- logger.info(
+ logger.debug(
"/sync: user=%r, timeout=%r, since=%r,"
" set_presence=%r, filter_id=%r, device_id=%r" % (
user, timeout, since, set_presence, filter_id, device_id
@@ -125,7 +125,7 @@ class SyncRestServlet(RestServlet):
filter_object = json.loads(filter_id)
set_timeline_upper_limit(filter_object,
self.hs.config.filter_timeline_limit)
- except:
+ except Exception:
raise SynapseError(400, "Invalid filter JSON")
self.filtering.check_valid_filter(filter_object)
filter = FilterCollection(filter_object)
@@ -149,6 +149,9 @@ class SyncRestServlet(RestServlet):
else:
since_token = None
+ # send any outstanding server notices to the user.
+ yield self._server_notices_sender.on_user_syncing(user.to_string())
+
affect_presence = set_presence != PresenceState.OFFLINE
if affect_presence:
@@ -164,27 +167,35 @@ class SyncRestServlet(RestServlet):
)
time_now = self.clock.time_msec()
+ response_content = self.encode_response(
+ time_now, sync_result, requester.access_token_id, filter
+ )
+
+ defer.returnValue((200, response_content))
- joined = self.encode_joined(
- sync_result.joined, time_now, requester.access_token_id, filter.event_fields
+ @staticmethod
+ def encode_response(time_now, sync_result, access_token_id, filter):
+ joined = SyncRestServlet.encode_joined(
+ sync_result.joined, time_now, access_token_id, filter.event_fields
)
- invited = self.encode_invited(
- sync_result.invited, time_now, requester.access_token_id
+ invited = SyncRestServlet.encode_invited(
+ sync_result.invited, time_now, access_token_id,
)
- archived = self.encode_archived(
- sync_result.archived, time_now, requester.access_token_id,
+ archived = SyncRestServlet.encode_archived(
+ sync_result.archived, time_now, access_token_id,
filter.event_fields,
)
- response_content = {
+ return {
"account_data": {"events": sync_result.account_data},
"to_device": {"events": sync_result.to_device},
"device_lists": {
- "changed": list(sync_result.device_lists),
+ "changed": list(sync_result.device_lists.changed),
+ "left": list(sync_result.device_lists.left),
},
- "presence": self.encode_presence(
+ "presence": SyncRestServlet.encode_presence(
sync_result.presence, time_now
),
"rooms": {
@@ -192,13 +203,17 @@ class SyncRestServlet(RestServlet):
"invite": invited,
"leave": archived,
},
+ "groups": {
+ "join": sync_result.groups.join,
+ "invite": sync_result.groups.invite,
+ "leave": sync_result.groups.leave,
+ },
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
"next_batch": sync_result.next_batch.to_string(),
}
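
An illustrative fragment of the resulting /sync body showing the new fields (IDs are placeholders and group payloads are elided):

    # "device_lists": {"changed": ["@bob:hs"], "left": ["@carol:hs"]},
    # "groups": {"join": {}, "invite": {}, "leave": {}},
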
- defer.returnValue((200, response_content))
-
- def encode_presence(self, events, time_now):
+ @staticmethod
+ def encode_presence(events, time_now):
return {
"events": [
{
@@ -212,7 +227,8 @@ class SyncRestServlet(RestServlet):
]
}
- def encode_joined(self, rooms, time_now, token_id, event_fields):
+ @staticmethod
+ def encode_joined(rooms, time_now, token_id, event_fields):
"""
Encode the joined rooms in a sync result
@@ -231,13 +247,14 @@ class SyncRestServlet(RestServlet):
"""
joined = {}
for room in rooms:
- joined[room.room_id] = self.encode_room(
+ joined[room.room_id] = SyncRestServlet.encode_room(
room, time_now, token_id, only_fields=event_fields
)
return joined
- def encode_invited(self, rooms, time_now, token_id):
+ @staticmethod
+ def encode_invited(rooms, time_now, token_id):
"""
Encode the invited rooms in a sync result
@@ -270,7 +287,8 @@ class SyncRestServlet(RestServlet):
return invited
- def encode_archived(self, rooms, time_now, token_id, event_fields):
+ @staticmethod
+ def encode_archived(rooms, time_now, token_id, event_fields):
"""
Encode the archived rooms in a sync result
@@ -289,7 +307,7 @@ class SyncRestServlet(RestServlet):
"""
joined = {}
for room in rooms:
- joined[room.room_id] = self.encode_room(
+ joined[room.room_id] = SyncRestServlet.encode_room(
room, time_now, token_id, joined=False, only_fields=event_fields
)
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index dac8603b07..4fea614e95 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import client_v2_patterns
-
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.api.errors import AuthError
+import logging
from twisted.internet import defer
-import logging
+from synapse.api.errors import AuthError
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+
+from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index 6fceb23e26..d9d379182e 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -20,13 +20,14 @@ from twisted.internet import defer
from synapse.api.constants import ThirdPartyEntityKind
from synapse.http.servlet import RestServlet
+
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
class ThirdPartyProtocolsServlet(RestServlet):
- PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=())
+ PATTERNS = client_v2_patterns("/thirdparty/protocols")
def __init__(self, hs):
super(ThirdPartyProtocolsServlet, self).__init__()
@@ -43,8 +44,7 @@ class ThirdPartyProtocolsServlet(RestServlet):
class ThirdPartyProtocolServlet(RestServlet):
- PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$",
- releases=())
+ PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
def __init__(self, hs):
super(ThirdPartyProtocolServlet, self).__init__()
@@ -66,8 +66,7 @@ class ThirdPartyProtocolServlet(RestServlet):
class ThirdPartyUserServlet(RestServlet):
- PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$",
- releases=())
+ PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
def __init__(self, hs):
super(ThirdPartyUserServlet, self).__init__()
@@ -90,8 +89,7 @@ class ThirdPartyUserServlet(RestServlet):
class ThirdPartyLocationServlet(RestServlet):
- PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$",
- releases=())
+ PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
def __init__(self, hs):
super(ThirdPartyLocationServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index 6e012da4aa..cac0624ba7 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
@@ -65,7 +66,7 @@ class UserDirectorySearchRestServlet(RestServlet):
try:
search_term = body["search_term"]
- except:
+ except Exception:
raise SynapseError(400, "`search_term` is required field")
results = yield self.user_directory_handler.search_users(
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index e984ea47db..6ac2987b98 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.http.servlet import RestServlet
-
import logging
import re
+from synapse.http.servlet import RestServlet
+
logger = logging.getLogger(__name__)
@@ -30,6 +30,7 @@ class VersionsRestServlet(RestServlet):
"r0.0.1",
"r0.1.0",
"r0.2.0",
+ "r0.3.0",
]
})
diff --git a/synapse/rest/consent/__init__.py b/synapse/rest/consent/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/synapse/rest/consent/__init__.py
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
new file mode 100644
index 0000000000..147ff7d79b
--- /dev/null
+++ b/synapse/rest/consent/consent_resource.py
@@ -0,0 +1,220 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import hmac
+import logging
+from hashlib import sha256
+from os import path
+
+from six.moves import http_client
+
+import jinja2
+from jinja2 import TemplateNotFound
+
+from twisted.internet import defer
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+
+from synapse.api.errors import NotFoundError, StoreError, SynapseError
+from synapse.config import ConfigError
+from synapse.http.server import finish_request, wrap_html_request_handler
+from synapse.http.servlet import parse_string
+from synapse.types import UserID
+
+# language to use for the templates. TODO: figure this out from Accept-Language
+TEMPLATE_LANGUAGE = "en"
+
+logger = logging.getLogger(__name__)
+
+# use hmac.compare_digest if we have it (python 2.7.7), else just use equality
+if hasattr(hmac, "compare_digest"):
+ compare_digest = hmac.compare_digest
+else:
+ def compare_digest(a, b):
+ return a == b
+
+
+class ConsentResource(Resource):
+ """A twisted Resource to display a privacy policy and gather consent to it
+
+ When accessed via GET, returns the privacy policy via a template.
+
+ When accessed via POST, records the user's consent in the database and
+ displays a success page.
+
+ The config should include a template_dir setting which contains templates
+ for the HTML. The directory should contain one subdirectory per language
+ (eg, 'en', 'fr'), and each language directory should contain the policy
+ document (named as '<version>.html') and a success page (success.html).
+
+ Both forms take a set of parameters from the browser. For the POST form,
+ these are normally sent as form parameters (but may be query-params); for
+ GET requests they must be query params. These are:
+
+ u: the complete mxid, or the localpart of the user giving their
+ consent. Required for both GET (where it is used as an input to the
+ template) and for POST (where it is used to find the row in the db
+ to update).
+
+ h: hmac_sha256(secret, u), where 'secret' is the privacy_secret in the
+ config file. If it doesn't match, the request is 403ed.
+
+ v: the version of the privacy policy being agreed to.
+
+ For GET: optional, and defaults to whatever was set in the config
+ file. Used to choose the version of the policy to pick from the
+ templates directory.
+
+ For POST: required; gives the value to be recorded in the database
+ against the user.
+ """
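
As a sketch of the other side of this contract — building a link that will pass _check_hash below — assuming the resource is mounted at /_matrix/consent, with placeholder secret and server values:

    import hmac
    from hashlib import sha256
    from six.moves.urllib.parse import urlencode

    def consent_url(base_url, form_secret, username, version):
        h = hmac.new(
            key=form_secret.encode("utf-8"),
            msg=username.encode("utf-8"),
            digestmod=sha256,
        ).hexdigest()
        return "%s/_matrix/consent?%s" % (
            base_url, urlencode({"u": username, "h": h, "v": version}),
        )

    # consent_url("https://hs.example.com", form_secret, "@alice:hs.example.com", "1.0")
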
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): homeserver
+ """
+ Resource.__init__(self)
+
+ self.hs = hs
+ self.store = hs.get_datastore()
+
+ # this is required by the request_handler wrapper
+ self.clock = hs.get_clock()
+
+ self._default_consent_version = hs.config.user_consent_version
+ if self._default_consent_version is None:
+ raise ConfigError(
+ "Consent resource is enabled but user_consent section is "
+ "missing in config file.",
+ )
+
+ # daemonize changes the cwd to /, so make the path absolute now.
+ consent_template_directory = path.abspath(
+ hs.config.user_consent_template_dir,
+ )
+ if not path.isdir(consent_template_directory):
+ raise ConfigError(
+ "Could not find template directory '%s'" % (
+ consent_template_directory,
+ ),
+ )
+
+ loader = jinja2.FileSystemLoader(consent_template_directory)
+ self._jinja_env = jinja2.Environment(
+ loader=loader,
+ autoescape=jinja2.select_autoescape(['html', 'htm', 'xml']),
+ )
+
+ if hs.config.form_secret is None:
+ raise ConfigError(
+ "Consent resource is enabled but form_secret is not set in "
+ "config file. It should be set to an arbitrary secret string.",
+ )
+
+ self._hmac_secret = hs.config.form_secret.encode("utf-8")
+
+ def render_GET(self, request):
+ self._async_render_GET(request)
+ return NOT_DONE_YET
+
+ @wrap_html_request_handler
+ @defer.inlineCallbacks
+ def _async_render_GET(self, request):
+ """
+ Args:
+ request (twisted.web.http.Request):
+ """
+
+ version = parse_string(request, "v",
+ default=self._default_consent_version)
+ username = parse_string(request, "u", required=True)
+ userhmac = parse_string(request, "h", required=True)
+
+ self._check_hash(username, userhmac)
+
+ if username.startswith('@'):
+ qualified_user_id = username
+ else:
+ qualified_user_id = UserID(username, self.hs.hostname).to_string()
+
+ u = yield self.store.get_user_by_id(qualified_user_id)
+ if u is None:
+ raise NotFoundError("Unknown user")
+
+ try:
+ self._render_template(
+ request, "%s.html" % (version,),
+ user=username, userhmac=userhmac, version=version,
+ has_consented=(u["consent_version"] == version),
+ )
+ except TemplateNotFound:
+ raise NotFoundError("Unknown policy version")
+
+ def render_POST(self, request):
+ self._async_render_POST(request)
+ return NOT_DONE_YET
+
+ @wrap_html_request_handler
+ @defer.inlineCallbacks
+ def _async_render_POST(self, request):
+ """
+ Args:
+ request (twisted.web.http.Request):
+ """
+ version = parse_string(request, "v", required=True)
+ username = parse_string(request, "u", required=True)
+ userhmac = parse_string(request, "h", required=True)
+
+ self._check_hash(username, userhmac)
+
+ if username.startswith('@'):
+ qualified_user_id = username
+ else:
+ qualified_user_id = UserID(username, self.hs.hostname).to_string()
+
+ try:
+ yield self.store.user_set_consent_version(qualified_user_id, version)
+ except StoreError as e:
+ if e.code != 404:
+ raise
+ raise NotFoundError("Unknown user")
+
+ try:
+ self._render_template(request, "success.html")
+ except TemplateNotFound:
+ raise NotFoundError("success.html not found")
+
+ def _render_template(self, request, template_name, **template_args):
+ # get_template checks for ".." so we don't need to worry too much
+ # about path traversal here.
+ template_html = self._jinja_env.get_template(
+ path.join(TEMPLATE_LANGUAGE, template_name)
+ )
+ html_bytes = template_html.render(**template_args).encode("utf8")
+
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Content-Length", b"%i" % len(html_bytes))
+ request.write(html_bytes)
+ finish_request(request)
+
+ def _check_hash(self, userid, userhmac):
+ want_mac = hmac.new(
+ key=self._hmac_secret,
+ msg=userid,
+ digestmod=sha256,
+ ).hexdigest()
+
+ if not compare_digest(want_mac, userhmac):
+ raise SynapseError(http_client.FORBIDDEN, "HMAC incorrect")
diff --git a/synapse/rest/key/v1/server_key_resource.py b/synapse/rest/key/v1/server_key_resource.py
index bd4fea5774..b9ee6e1c13 100644
--- a/synapse/rest/key/v1/server_key_resource.py
+++ b/synapse/rest/key/v1/server_key_resource.py
@@ -14,14 +14,16 @@
# limitations under the License.
-from twisted.web.resource import Resource
-from synapse.http.server import respond_with_json_bytes
+import logging
+
+from canonicaljson import encode_canonical_json
from signedjson.sign import sign_json
from unpaddedbase64 import encode_base64
-from canonicaljson import encode_canonical_json
+
from OpenSSL import crypto
-import logging
+from twisted.web.resource import Resource
+from synapse.http.server import respond_with_json_bytes
logger = logging.getLogger(__name__)
@@ -49,7 +51,6 @@ class LocalKey(Resource):
"""
def __init__(self, hs):
- self.version_string = hs.version_string
self.response_body = encode_canonical_json(
self.response_json_object(hs.config)
)
@@ -84,7 +85,6 @@ class LocalKey(Resource):
def render_GET(self, request):
return respond_with_json_bytes(
request, 200, self.response_body,
- version_string=self.version_string
)
def getChild(self, name, request):
diff --git a/synapse/rest/key/v2/__init__.py b/synapse/rest/key/v2/__init__.py
index a07224148c..3491fd2118 100644
--- a/synapse/rest/key/v2/__init__.py
+++ b/synapse/rest/key/v2/__init__.py
@@ -14,6 +14,7 @@
# limitations under the License.
from twisted.web.resource import Resource
+
from .local_key_resource import LocalKey
from .remote_key_resource import RemoteKey
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index be68d9a096..ec0ec7b431 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -14,13 +14,15 @@
# limitations under the License.
-from twisted.web.resource import Resource
-from synapse.http.server import respond_with_json_bytes
+import logging
+
+from canonicaljson import encode_canonical_json
from signedjson.sign import sign_json
from unpaddedbase64 import encode_base64
-from canonicaljson import encode_canonical_json
-import logging
+from twisted.web.resource import Resource
+
+from synapse.http.server import respond_with_json_bytes
logger = logging.getLogger(__name__)
@@ -63,7 +65,6 @@ class LocalKey(Resource):
isLeaf = True
def __init__(self, hs):
- self.version_string = hs.version_string
self.config = hs.config
self.clock = hs.clock
self.update_response_body(self.clock.time_msec())
@@ -115,5 +116,4 @@ class LocalKey(Resource):
self.update_response_body(time_now)
return respond_with_json_bytes(
request, 200, self.response_body,
- version_string=self.version_string
)
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 9fe2013657..7d67e4b064 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -12,18 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.http.server import request_handler, respond_with_json_bytes
-from synapse.http.servlet import parse_integer, parse_json_object_from_request
-from synapse.api.errors import SynapseError, Codes
-from synapse.crypto.keyring import KeyLookupError
+import logging
+from io import BytesIO
+from twisted.internet import defer
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
-from twisted.internet import defer
+from synapse.api.errors import Codes, SynapseError
+from synapse.crypto.keyring import KeyLookupError
+from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler
+from synapse.http.servlet import parse_integer, parse_json_object_from_request
-from io import BytesIO
-import logging
logger = logging.getLogger(__name__)
@@ -91,14 +91,14 @@ class RemoteKey(Resource):
def __init__(self, hs):
self.keyring = hs.get_keyring()
self.store = hs.get_datastore()
- self.version_string = hs.version_string
self.clock = hs.get_clock()
+ self.federation_domain_whitelist = hs.config.federation_domain_whitelist
def render_GET(self, request):
self.async_render_GET(request)
return NOT_DONE_YET
- @request_handler()
+ @wrap_json_request_handler
@defer.inlineCallbacks
def async_render_GET(self, request):
if len(request.postpath) == 1:
@@ -123,7 +123,7 @@ class RemoteKey(Resource):
self.async_render_POST(request)
return NOT_DONE_YET
- @request_handler()
+ @wrap_json_request_handler
@defer.inlineCallbacks
def async_render_POST(self, request):
content = parse_json_object_from_request(request)
@@ -137,6 +137,13 @@ class RemoteKey(Resource):
logger.info("Handling query for keys %r", query)
store_queries = []
for server_name, key_ids in query.items():
+ if (
+ self.federation_domain_whitelist is not None and
+ server_name not in self.federation_domain_whitelist
+ ):
+ logger.debug("Federation denied with %s", server_name)
+ continue
+
if not key_ids:
key_ids = (None,)
for key_id in key_ids:
@@ -213,7 +220,7 @@ class RemoteKey(Resource):
)
except KeyLookupError as e:
logger.info("Failed to fetch key: %s", e)
- except:
+ except Exception:
logger.exception("Failed to get key for %r", server_name)
yield self.query_keys(
request, query, query_remote_on_cache_miss=False
@@ -232,5 +239,4 @@ class RemoteKey(Resource):
respond_with_json_bytes(
request, 200, result_io.getvalue(),
- version_string=self.version_string
)
diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py
index 956bd5da75..f255f2883f 100644
--- a/synapse/rest/media/v0/content_repository.py
+++ b/synapse/rest/media/v0/content_repository.py
@@ -13,21 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.http.server import respond_with_json_bytes, finish_request
-
-from synapse.api.errors import (
- Codes, cs_error
-)
-
-from twisted.protocols.basic import FileSender
-from twisted.web import server, resource
-
import base64
-import simplejson as json
import logging
import os
import re
+from canonicaljson import json
+
+from twisted.protocols.basic import FileSender
+from twisted.web import resource, server
+
+from synapse.api.errors import Codes, cs_error
+from synapse.http.server import finish_request, respond_with_json_bytes
+
logger = logging.getLogger(__name__)
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index b9600f2167..65f4bd2910 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -13,22 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.http.server import respond_with_json, finish_request
-from synapse.api.errors import (
- cs_error, Codes, SynapseError
-)
+import logging
+import os
+import urllib
+
+from six.moves.urllib import parse as urlparse
from twisted.internet import defer
from twisted.protocols.basic import FileSender
+from synapse.api.errors import Codes, SynapseError, cs_error
+from synapse.http.server import finish_request, respond_with_json
+from synapse.util import logcontext
from synapse.util.stringutils import is_ascii
-import os
-
-import logging
-import urllib
-import urlparse
-
logger = logging.getLogger(__name__)
@@ -44,7 +42,7 @@ def parse_media_id(request):
except UnicodeDecodeError:
pass
return server_name, media_id, file_name
- except:
+ except Exception:
raise SynapseError(
404,
"Invalid media id token %r" % (request.postpath,),
@@ -69,42 +67,133 @@ def respond_with_file(request, media_type, file_path,
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
- request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
- if upload_name:
- if is_ascii(upload_name):
- request.setHeader(
- b"Content-Disposition",
- b"inline; filename=%s" % (
- urllib.quote(upload_name.encode("utf-8")),
- ),
- )
- else:
- request.setHeader(
- b"Content-Disposition",
- b"inline; filename*=utf-8''%s" % (
- urllib.quote(upload_name.encode("utf-8")),
- ),
- )
-
- # cache for at least a day.
- # XXX: we might want to turn this off for data we don't want to
- # recommend caching as it's sensitive or private - or at least
- # select private. don't bother setting Expires as all our
- # clients are smart enough to be happy with Cache-Control
- request.setHeader(
- b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
- )
if file_size is None:
stat = os.stat(file_path)
file_size = stat.st_size
- request.setHeader(
- b"Content-Length", b"%d" % (file_size,)
- )
+ add_file_headers(request, media_type, file_size, upload_name)
with open(file_path, "rb") as f:
- yield FileSender().beginFileTransfer(f, request)
+ yield logcontext.make_deferred_yieldable(
+ FileSender().beginFileTransfer(f, request)
+ )
finish_request(request)
else:
respond_404(request)
+
+
+def add_file_headers(request, media_type, file_size, upload_name):
+ """Adds the correct response headers in preparation for responding with the
+ media.
+
+ Args:
+ request (twisted.web.http.Request)
+ media_type (str): The media/content type.
+ file_size (int): Size in bytes of the media, if known.
+ upload_name (str): The name of the requested file, if any.
+ """
+ request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
+ if upload_name:
+ if is_ascii(upload_name):
+ request.setHeader(
+ b"Content-Disposition",
+ b"inline; filename=%s" % (
+ urllib.quote(upload_name.encode("utf-8")),
+ ),
+ )
+ else:
+ request.setHeader(
+ b"Content-Disposition",
+ b"inline; filename*=utf-8''%s" % (
+ urllib.quote(upload_name.encode("utf-8")),
+ ),
+ )
+
+ # cache for at least a day.
+ # XXX: we might want to turn this off for data we don't want to
+ # recommend caching as it's sensitive or private - or at least
+ # select private. don't bother setting Expires as all our
+ # clients are smart enough to be happy with Cache-Control
+ request.setHeader(
+ b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
+ )
+
+ request.setHeader(
+ b"Content-Length", b"%d" % (file_size,)
+ )
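
# A quick sanity check of the two Content-Disposition branches above
# (plain `filename=` for ASCII names, RFC 5987 `filename*=utf-8''...`
# otherwise). This sketch uses six's quote so it runs on both Python 2
# and 3; the module itself still calls urllib.quote directly.

from six.moves.urllib.parse import quote

print("inline; filename=%s" % quote("report.pdf".encode("utf-8")))
print("inline; filename*=utf-8''%s" % quote(u"\u20ac rates.pdf".encode("utf-8")))
# -> inline; filename=report.pdf
# -> inline; filename*=utf-8''%E2%82%AC%20rates.pdf
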
+
+
+@defer.inlineCallbacks
+def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
+ """Responds to the request with given responder. If responder is None then
+ returns 404.
+
+ Args:
+ request (twisted.web.http.Request)
+ responder (Responder|None)
+ media_type (str): The media/content type.
+ file_size (int|None): Size in bytes of the media. If not known it should be None
+ upload_name (str|None): The name of the requested file, if any.
+ """
+ if not responder:
+ respond_404(request)
+ return
+
+ logger.debug("Responding to media request with responder %s")
+ add_file_headers(request, media_type, file_size, upload_name)
+ with responder:
+ yield responder.write_to_consumer(request)
+ finish_request(request)
+
+
+class Responder(object):
+ """Represents a response that can be streamed to the requester.
+
+ Responder is a context manager which *must* be used, so that any resources
+ held can be cleaned up.
+ """
+ def write_to_consumer(self, consumer):
+ """Stream response into consumer
+
+ Args:
+ consumer (IConsumer)
+
+ Returns:
+ Deferred: Resolves once the response has finished being written
+ """
+ pass
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ pass
+
+
+class FileInfo(object):
+ """Details about a requested/uploaded file.
+
+ Attributes:
+ server_name (str): The server name where the media originated from,
+ or None if local.
+ file_id (str): The local ID of the file. For local files this is the
+ same as the media_id
+ url_cache (bool): If the file is for the url preview cache
+ thumbnail (bool): Whether the file is a thumbnail or not.
+ thumbnail_width (int)
+ thumbnail_height (int)
+ thumbnail_method (str)
+ thumbnail_type (str): Content type of thumbnail, e.g. image/png
+ """
+ def __init__(self, server_name, file_id, url_cache=False,
+ thumbnail=False, thumbnail_width=None, thumbnail_height=None,
+ thumbnail_method=None, thumbnail_type=None):
+ self.server_name = server_name
+ self.file_id = file_id
+ self.url_cache = url_cache
+ self.thumbnail = thumbnail
+ self.thumbnail_width = thumbnail_width
+ self.thumbnail_height = thumbnail_height
+ self.thumbnail_method = thumbnail_method
+ self.thumbnail_type = thumbnail_type
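
# The Responder contract above, made concrete with a hypothetical
# in-memory implementation (the real file-backed one, FileResponder, is
# added in media_storage.py later in this diff). respond_with_responder
# enters the context manager and waits on write_to_consumer, so any
# cleanup belongs in __exit__.

from twisted.internet import defer
from synapse.rest.media.v1._base import Responder


class BytesResponder(Responder):
    def __init__(self, data):
        self.data = data

    def write_to_consumer(self, consumer):
        # a small payload, so a single synchronous write is enough here
        consumer.write(self.data)
        return defer.succeed(None)

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass  # nothing to release for in-memory data
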
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 6879249c8a..fbfa85f74f 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -12,16 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse.http.servlet
+import logging
-from ._base import parse_media_id, respond_with_file, respond_404
+from twisted.internet import defer
from twisted.web.resource import Resource
-from synapse.http.server import request_handler, set_cors_headers
-
from twisted.web.server import NOT_DONE_YET
-from twisted.internet import defer
-import logging
+import synapse.http.servlet
+from synapse.http.server import set_cors_headers, wrap_json_request_handler
+
+from ._base import parse_media_id, respond_404
logger = logging.getLogger(__name__)
@@ -32,18 +32,17 @@ class DownloadResource(Resource):
def __init__(self, hs, media_repo):
Resource.__init__(self)
- self.filepaths = media_repo.filepaths
self.media_repo = media_repo
self.server_name = hs.hostname
- self.store = hs.get_datastore()
- self.version_string = hs.version_string
+
+ # this is expected by @wrap_json_request_handler
self.clock = hs.get_clock()
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
- @request_handler()
+ @wrap_json_request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
set_cors_headers(request)
@@ -57,59 +56,16 @@ class DownloadResource(Resource):
)
server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name:
- yield self._respond_local_file(request, media_id, name)
+ yield self.media_repo.get_local_media(request, media_id, name)
else:
- yield self._respond_remote_file(
- request, server_name, media_id, name
- )
-
- @defer.inlineCallbacks
- def _respond_local_file(self, request, media_id, name):
- media_info = yield self.store.get_local_media(media_id)
- if not media_info or media_info["quarantined_by"]:
- respond_404(request)
- return
-
- media_type = media_info["media_type"]
- media_length = media_info["media_length"]
- upload_name = name if name else media_info["upload_name"]
- if media_info["url_cache"]:
- # TODO: Check the file still exists, if it doesn't we can redownload
- # it from the url `media_info["url_cache"]`
- file_path = self.filepaths.url_cache_filepath(media_id)
- else:
- file_path = self.filepaths.local_media_filepath(media_id)
-
- yield respond_with_file(
- request, media_type, file_path, media_length,
- upload_name=upload_name,
- )
-
- @defer.inlineCallbacks
- def _respond_remote_file(self, request, server_name, media_id, name):
- # don't forward requests for remote media if allow_remote is false
- allow_remote = synapse.http.servlet.parse_boolean(
- request, "allow_remote", default=True)
- if not allow_remote:
- logger.info(
- "Rejecting request for remote media %s/%s due to allow_remote",
- server_name, media_id,
- )
- respond_404(request)
- return
-
- media_info = yield self.media_repo.get_remote_media(server_name, media_id)
-
- media_type = media_info["media_type"]
- media_length = media_info["media_length"]
- filesystem_id = media_info["filesystem_id"]
- upload_name = name if name else media_info["upload_name"]
-
- file_path = self.filepaths.remote_media_filepath(
- server_name, filesystem_id
- )
-
- yield respond_with_file(
- request, media_type, file_path, media_length,
- upload_name=upload_name,
- )
+ allow_remote = synapse.http.servlet.parse_boolean(
+ request, "allow_remote", default=True)
+ if not allow_remote:
+ logger.info(
+ "Rejecting request for remote media %s/%s due to allow_remote",
+ server_name, media_id,
+ )
+ respond_404(request)
+ return
+
+ yield self.media_repo.get_remote_media(request, server_name, media_id, name)
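
# How the allow_remote flag above behaves from the caller's side, as a
# sketch: a client (or, crucially, another homeserver re-requesting
# media) can append ?allow_remote=false to say "serve this from your
# local cache only", which is exactly what the downloader sets on its
# own outbound fetch to avoid routing loops.
#
#   GET /_matrix/media/v1/download/remote.example/abc123?allow_remote=false
#
def should_proxy_remote_media(server_name, local_server_name, allow_remote):
    if server_name == local_server_name:
        return False  # local media is always served directly
    return allow_remote  # remote media honours the client's flag
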
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index d92b7ff337..c8586fa280 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -13,79 +13,201 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import functools
import os
+import re
+
+NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
+
+
+def _wrap_in_base_path(func):
+ """Takes a function that returns a relative path and turns it into an
+ absolute path based on the location of the primary media store
+ """
+ @functools.wraps(func)
+ def _wrapped(self, *args, **kwargs):
+ path = func(self, *args, **kwargs)
+ return os.path.join(self.base_path, path)
+
+ return _wrapped
class MediaFilePaths(object):
+ """Describes where files are stored on disk.
- def __init__(self, base_path):
- self.base_path = base_path
+ Most of the functions have a `*_rel` variant which returns a file path that
+ is relative to the base media store path. This is mainly used when we want
+ to write to the backup media store (when one is configured).
+ """
- def default_thumbnail(self, default_top_level, default_sub_type, width,
- height, content_type, method):
+ def __init__(self, primary_base_path):
+ self.base_path = primary_base_path
+
+ def default_thumbnail_rel(self, default_top_level, default_sub_type, width,
+ height, content_type, method):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (
width, height, top_level_type, sub_type, method
)
return os.path.join(
- self.base_path, "default_thumbnails", default_top_level,
+ "default_thumbnails", default_top_level,
default_sub_type, file_name
)
- def local_media_filepath(self, media_id):
+ default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
+
+ def local_media_filepath_rel(self, media_id):
return os.path.join(
- self.base_path, "local_content",
+ "local_content",
media_id[0:2], media_id[2:4], media_id[4:]
)
- def local_media_thumbnail(self, media_id, width, height, content_type,
- method):
+ local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
+
+ def local_media_thumbnail_rel(self, media_id, width, height, content_type,
+ method):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (
width, height, top_level_type, sub_type, method
)
return os.path.join(
- self.base_path, "local_thumbnails",
+ "local_thumbnails",
media_id[0:2], media_id[2:4], media_id[4:],
file_name
)
- def remote_media_filepath(self, server_name, file_id):
+ local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel)
+
+ def remote_media_filepath_rel(self, server_name, file_id):
return os.path.join(
- self.base_path, "remote_content", server_name,
+ "remote_content", server_name,
file_id[0:2], file_id[2:4], file_id[4:]
)
- def remote_media_thumbnail(self, server_name, file_id, width, height,
- content_type, method):
+ remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
+
+ def remote_media_thumbnail_rel(self, server_name, file_id, width, height,
+ content_type, method):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
return os.path.join(
- self.base_path, "remote_thumbnail", server_name,
+ "remote_thumbnail", server_name,
file_id[0:2], file_id[2:4], file_id[4:],
file_name
)
+ remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
+
def remote_media_thumbnail_dir(self, server_name, file_id):
return os.path.join(
self.base_path, "remote_thumbnail", server_name,
file_id[0:2], file_id[2:4], file_id[4:],
)
- def url_cache_filepath(self, media_id):
- return os.path.join(
- self.base_path, "url_cache",
- media_id[0:2], media_id[2:4], media_id[4:]
- )
+ def url_cache_filepath_rel(self, media_id):
+ if NEW_FORMAT_ID_RE.match(media_id):
+ # Media id is of the form <DATE><RANDOM_STRING>
+ # E.g.: 2017-09-28-fsdRDt24DS234dsf
+ return os.path.join(
+ "url_cache",
+ media_id[:10], media_id[11:]
+ )
+ else:
+ return os.path.join(
+ "url_cache",
+ media_id[0:2], media_id[2:4], media_id[4:],
+ )
+
+ url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
+
+ def url_cache_filepath_dirs_to_delete(self, media_id):
+ "The dirs to try and remove if we delete the media_id file"
+ if NEW_FORMAT_ID_RE.match(media_id):
+ return [
+ os.path.join(
+ self.base_path, "url_cache",
+ media_id[:10],
+ ),
+ ]
+ else:
+ return [
+ os.path.join(
+ self.base_path, "url_cache",
+ media_id[0:2], media_id[2:4],
+ ),
+ os.path.join(
+ self.base_path, "url_cache",
+ media_id[0:2],
+ ),
+ ]
+
+ def url_cache_thumbnail_rel(self, media_id, width, height, content_type,
+ method):
+ # Media id is of the form <DATE><RANDOM_STRING>
+ # E.g.: 2017-09-28-fsdRDt24DS234dsf
- def url_cache_thumbnail(self, media_id, width, height, content_type,
- method):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (
width, height, top_level_type, sub_type, method
)
- return os.path.join(
- self.base_path, "url_cache_thumbnails",
- media_id[0:2], media_id[2:4], media_id[4:],
- file_name
- )
+
+ if NEW_FORMAT_ID_RE.match(media_id):
+ return os.path.join(
+ "url_cache_thumbnails",
+ media_id[:10], media_id[11:],
+ file_name
+ )
+ else:
+ return os.path.join(
+ "url_cache_thumbnails",
+ media_id[0:2], media_id[2:4], media_id[4:],
+ file_name
+ )
+
+ url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
+
+ def url_cache_thumbnail_directory(self, media_id):
+ # Media id is of the form <DATE><RANDOM_STRING>
+ # E.g.: 2017-09-28-fsdRDt24DS234dsf
+
+ if NEW_FORMAT_ID_RE.match(media_id):
+ return os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[:10], media_id[11:],
+ )
+ else:
+ return os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[0:2], media_id[2:4], media_id[4:],
+ )
+
+ def url_cache_thumbnail_dirs_to_delete(self, media_id):
+ "The dirs to try and remove if we delete the media_id thumbnails"
+ # Media id is of the form <DATE><RANDOM_STRING>
+ # E.g.: 2017-09-28-fsdRDt24DS234dsf
+ if NEW_FORMAT_ID_RE.match(media_id):
+ return [
+ os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[:10], media_id[11:],
+ ),
+ os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[:10],
+ ),
+ ]
+ else:
+ return [
+ os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[0:2], media_id[2:4], media_id[4:],
+ ),
+ os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[0:2], media_id[2:4],
+ ),
+ os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[0:2],
+ ),
+ ]
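
# A worked example of the two url_cache layouts handled above:
# new-style ids start with a date and shard on it, old-style ids shard
# on their leading characters.

import os
import re

NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")

def url_cache_rel(media_id):
    if NEW_FORMAT_ID_RE.match(media_id):
        return os.path.join("url_cache", media_id[:10], media_id[11:])
    return os.path.join("url_cache", media_id[0:2], media_id[2:4], media_id[4:])

assert url_cache_rel("2017-09-28-fsdRDt24DS234dsf") == \
    os.path.join("url_cache", "2017-09-28", "fsdRDt24DS234dsf")
assert url_cache_rel("GerZNDnDZVjsOtar") == \
    os.path.join("url_cache", "Ge", "rZ", "NDnDZVjsOtar")
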
diff --git a/synapse/rest/media/v1/identicon_resource.py b/synapse/rest/media/v1/identicon_resource.py
index 66f2b6bd30..bdbd8d50dd 100644
--- a/synapse/rest/media/v1/identicon_resource.py
+++ b/synapse/rest/media/v1/identicon_resource.py
@@ -13,8 +13,11 @@
# limitations under the License.
from pydenticon import Generator
+
from twisted.web.resource import Resource
+from synapse.http.servlet import parse_integer
+
FOREGROUND = [
"rgb(45,79,255)",
"rgb(254,180,44)",
@@ -55,8 +58,8 @@ class IdenticonResource(Resource):
def render_GET(self, request):
name = "/".join(request.postpath)
- width = int(request.args.get("width", [96])[0])
- height = int(request.args.get("height", [96])[0])
+ width = parse_integer(request, "width", default=96)
+ height = parse_integer(request, "height", default=96)
identicon_bytes = self.generate_identicon(name, width, height)
request.setHeader(b"Content-Type", b"image/png")
request.setHeader(
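
# Rough equivalent of what parse_integer (used above) buys over the old
# int(request.args...) pattern: a default when the parameter is absent,
# and a clean client error instead of an unhandled ValueError when it
# is malformed. The real helper raises a 400 SynapseError; this
# standalone sketch raises ValueError instead.

def parse_integer_like(args, name, default):
    values = args.get(name)
    if not values:
        return default
    try:
        return int(values[0])
    except ValueError:
        raise ValueError("Query parameter %r must be an integer" % (name,))

assert parse_integer_like({}, "width", 96) == 96
assert parse_integer_like({"width": ["128"]}, "width", 96) == 128
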
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 0ea1248ce6..30242c525a 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,45 +14,52 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer, threads
+import cgi
+import errno
+import logging
+import os
+import shutil
+
+from six import iteritems
+from six.moves.urllib import parse as urlparse
+
import twisted.internet.error
import twisted.web.http
+from twisted.internet import defer, threads
from twisted.web.resource import Resource
-from .upload_resource import UploadResource
-from .download_resource import DownloadResource
-from .thumbnail_resource import ThumbnailResource
-from .identicon_resource import IdenticonResource
-from .preview_url_resource import PreviewUrlResource
-from .filepath import MediaFilePaths
-from .thumbnailer import Thumbnailer
-
+from synapse.api.errors import (
+ FederationDeniedError,
+ HttpResponseException,
+ NotFoundError,
+ SynapseError,
+)
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
-from synapse.util.stringutils import random_string
-from synapse.api.errors import SynapseError, HttpResponseException, \
- NotFoundError
-
from synapse.util.async import Linearizer
-from synapse.util.stringutils import is_ascii
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.retryutils import NotRetryingDestination
+from synapse.util.stringutils import is_ascii, random_string
-import os
-import errno
-import shutil
-
-import cgi
-import logging
-import urlparse
+from ._base import FileInfo, respond_404, respond_with_responder
+from .download_resource import DownloadResource
+from .filepath import MediaFilePaths
+from .identicon_resource import IdenticonResource
+from .media_storage import MediaStorage
+from .preview_url_resource import PreviewUrlResource
+from .storage_provider import StorageProviderWrapper
+from .thumbnail_resource import ThumbnailResource
+from .thumbnailer import Thumbnailer
+from .upload_resource import UploadResource
logger = logging.getLogger(__name__)
-UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000
+UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
class MediaRepository(object):
def __init__(self, hs):
+ self.hs = hs
self.auth = hs.get_auth()
self.client = MatrixFederationHttpClient(hs)
self.clock = hs.get_clock()
@@ -59,46 +67,90 @@ class MediaRepository(object):
self.store = hs.get_datastore()
self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels
- self.filepaths = MediaFilePaths(hs.config.media_store_path)
+
+ self.primary_base_path = hs.config.media_store_path
+ self.filepaths = MediaFilePaths(self.primary_base_path)
+
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
self.remote_media_linearizer = Linearizer(name="media_remote")
self.recently_accessed_remotes = set()
+ self.recently_accessed_locals = set()
+
+ self.federation_domain_whitelist = hs.config.federation_domain_whitelist
+
+ # List of StorageProviders where we should search for media and
+ # potentially upload to.
+ storage_providers = []
+
+ for clz, provider_config, wrapper_config in hs.config.media_storage_providers:
+ backend = clz(hs, provider_config)
+ provider = StorageProviderWrapper(
+ backend,
+ store_local=wrapper_config.store_local,
+ store_remote=wrapper_config.store_remote,
+ store_synchronous=wrapper_config.store_synchronous,
+ )
+ storage_providers.append(provider)
+
+ self.media_storage = MediaStorage(
+ self.hs, self.primary_base_path, self.filepaths, storage_providers,
+ )
self.clock.looping_call(
- self._update_recently_accessed_remotes,
- UPDATE_RECENTLY_ACCESSED_REMOTES_TS
+ self._update_recently_accessed,
+ UPDATE_RECENTLY_ACCESSED_TS,
)
@defer.inlineCallbacks
- def _update_recently_accessed_remotes(self):
- media = self.recently_accessed_remotes
+ def _update_recently_accessed(self):
+ remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set()
+ local_media = self.recently_accessed_locals
+ self.recently_accessed_locals = set()
+
yield self.store.update_cached_last_access_time(
- media, self.clock.time_msec()
+ local_media, remote_media, self.clock.time_msec()
)
- @staticmethod
- def _makedirs(filepath):
- dirname = os.path.dirname(filepath)
- if not os.path.exists(dirname):
- os.makedirs(dirname)
+ def mark_recently_accessed(self, server_name, media_id):
+ """Mark the given media as recently accessed.
+
+ Args:
+ server_name (str|None): Origin server of media, or None if local
+ media_id (str): The media ID of the content
+ """
+ if server_name:
+ self.recently_accessed_remotes.add((server_name, media_id))
+ else:
+ self.recently_accessed_locals.add(media_id)
@defer.inlineCallbacks
def create_content(self, media_type, upload_name, content, content_length,
auth_user):
+ """Store uploaded content for a local user and return the mxc URL
+
+ Args:
+ media_type(str): The content type of the file
+ upload_name(str): The name of the file
+ content: A file like object that is the content to store
+ content_length(int): The length of the content
+ auth_user(str): The user_id of the uploader
+
+ Returns:
+ Deferred[str]: The mxc url of the stored content
+ """
media_id = random_string(24)
- fname = self.filepaths.local_media_filepath(media_id)
- self._makedirs(fname)
+ file_info = FileInfo(
+ server_name=None,
+ file_id=media_id,
+ )
- # This shouldn't block for very long because the content will have
- # already been uploaded at this point.
- with open(fname, "wb") as f:
- f.write(content)
+ fname = yield self.media_storage.store_file(content, file_info)
logger.info("Stored local media in file %r", fname)
@@ -110,131 +162,275 @@ class MediaRepository(object):
media_length=content_length,
user_id=auth_user,
)
- media_info = {
- "media_type": media_type,
- "media_length": content_length,
- }
- yield self._generate_local_thumbnails(media_id, media_info)
+ yield self._generate_thumbnails(
+ None, media_id, media_id, media_type,
+ )
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
@defer.inlineCallbacks
- def get_remote_media(self, server_name, media_id):
+ def get_local_media(self, request, media_id, name):
+ """Responds to reqests for local media, if exists, or returns 404.
+
+ Args:
+ request(twisted.web.http.Request)
+ media_id (str): The media ID of the content. (This is the same as
+ the file_id for local content.)
+ name (str|None): Optional name that, if specified, will be used as
+ the filename in the Content-Disposition header of the response.
+
+ Returns:
+ Deferred: Resolves once a response has successfully been written
+ to request
+ """
+ media_info = yield self.store.get_local_media(media_id)
+ if not media_info or media_info["quarantined_by"]:
+ respond_404(request)
+ return
+
+ self.mark_recently_accessed(None, media_id)
+
+ media_type = media_info["media_type"]
+ media_length = media_info["media_length"]
+ upload_name = name if name else media_info["upload_name"]
+ url_cache = media_info["url_cache"]
+
+ file_info = FileInfo(
+ None, media_id,
+ url_cache=url_cache,
+ )
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ yield respond_with_responder(
+ request, responder, media_type, media_length, upload_name,
+ )
+
+ @defer.inlineCallbacks
+ def get_remote_media(self, request, server_name, media_id, name):
+ """Respond to requests for remote media.
+
+ Args:
+ request(twisted.web.http.Request)
+ server_name (str): Remote server_name where the media originated.
+ media_id (str): The media ID of the content (as defined by the
+ remote server).
+ name (str|None): Optional name that, if specified, will be used as
+ the filename in the Content-Disposition header of the response.
+
+ Returns:
+ Deferred: Resolves once a response has successfully been written
+ to request
+ """
+ if (
+ self.federation_domain_whitelist is not None and
+ server_name not in self.federation_domain_whitelist
+ ):
+ raise FederationDeniedError(server_name)
+
+ self.mark_recently_accessed(server_name, media_id)
+
+ # We linearize here to ensure that we don't try and download remote
+ # media multiple times concurrently
+ key = (server_name, media_id)
+ with (yield self.remote_media_linearizer.queue(key)):
+ responder, media_info = yield self._get_remote_media_impl(
+ server_name, media_id,
+ )
+
+ # We deliberately stream the file outside the lock
+ if responder:
+ media_type = media_info["media_type"]
+ media_length = media_info["media_length"]
+ upload_name = name if name else media_info["upload_name"]
+ yield respond_with_responder(
+ request, responder, media_type, media_length, upload_name,
+ )
+ else:
+ respond_404(request)
+
+ @defer.inlineCallbacks
+ def get_remote_media_info(self, server_name, media_id):
+ """Gets the media info associated with the remote file, downloading
+ if necessary.
+
+ Args:
+ server_name (str): Remote server_name where the media originated.
+ media_id (str): The media ID of the content (as defined by the
+ remote server).
+
+ Returns:
+ Deferred[dict]: The media_info of the file
+ """
+ if (
+ self.federation_domain_whitelist is not None and
+ server_name not in self.federation_domain_whitelist
+ ):
+ raise FederationDeniedError(server_name)
+
+ # We linearize here to ensure that we don't try and download remote
+ # media multiple times concurrently
key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)):
- media_info = yield self._get_remote_media_impl(server_name, media_id)
+ responder, media_info = yield self._get_remote_media_impl(
+ server_name, media_id,
+ )
+
+ # Ensure we actually use the responder so that it releases resources
+ if responder:
+ with responder:
+ pass
+
defer.returnValue(media_info)
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
+ """Looks for media in local cache, if not there then attempt to
+ download from remote server.
+
+ Args:
+ server_name (str): Remote server_name where the media originated.
+ media_id (str): The media ID of the content (as defined by the
+ remote server).
+
+ Returns:
+ Deferred[(Responder, media_info)]
+ """
media_info = yield self.store.get_cached_remote_media(
server_name, media_id
)
- if not media_info:
- media_info = yield self._download_remote_file(
- server_name, media_id
- )
- elif media_info["quarantined_by"]:
- raise NotFoundError()
+
+ # file_id is the ID we use to track the file locally. If we've already
+ # seen the file then reuse the existing ID, otherwise generate a new
+ # one.
+ if media_info:
+ file_id = media_info["filesystem_id"]
else:
- self.recently_accessed_remotes.add((server_name, media_id))
- yield self.store.update_cached_last_access_time(
- [(server_name, media_id)], self.clock.time_msec()
- )
- defer.returnValue(media_info)
+ file_id = random_string(24)
- @defer.inlineCallbacks
- def _download_remote_file(self, server_name, media_id):
- file_id = random_string(24)
+ file_info = FileInfo(server_name, file_id)
+
+ # If we have an entry in the DB, try and look for it
+ if media_info:
+ if media_info["quarantined_by"]:
+ logger.info("Media is quarantined")
+ raise NotFoundError()
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ if responder:
+ defer.returnValue((responder, media_info))
+
+ # Failed to find the file anywhere, let's download it.
- fname = self.filepaths.remote_media_filepath(
- server_name, file_id
+ media_info = yield self._download_remote_file(
+ server_name, media_id, file_id
)
- self._makedirs(fname)
- try:
- with open(fname, "wb") as f:
- request_path = "/".join((
- "/_matrix/media/v1/download", server_name, media_id,
- ))
+ responder = yield self.media_storage.fetch_media(file_info)
+ defer.returnValue((responder, media_info))
+
+ @defer.inlineCallbacks
+ def _download_remote_file(self, server_name, media_id, file_id):
+ """Attempt to download the remote file from the given server name,
+ using the given file_id as the local id.
+
+ Args:
+ server_name (str): Originating server
+ media_id (str): The media ID of the content (as defined by the
+ remote server). This is different than the file_id, which is
+ locally generated.
+ file_id (str): Local file ID
+
+ Returns:
+ Deferred[dict]: The media info dict of the downloaded file
+ """
+
+ file_info = FileInfo(
+ server_name=server_name,
+ file_id=file_id,
+ )
+
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ request_path = "/".join((
+ "/_matrix/media/v1/download", server_name, media_id,
+ ))
+ try:
+ length, headers = yield self.client.get_file(
+ server_name, request_path, output_stream=f,
+ max_size=self.max_upload_size, args={
+ # tell the remote server to 404 if it doesn't
+ # recognise the server_name, to make sure we don't
+ # end up with a routing loop.
+ "allow_remote": "false",
+ }
+ )
+ except twisted.internet.error.DNSLookupError as e:
+ logger.warn("HTTP error fetching remote media %s/%s: %r",
+ server_name, media_id, e)
+ raise NotFoundError()
+
+ except HttpResponseException as e:
+ logger.warn("HTTP error fetching remote media %s/%s: %s",
+ server_name, media_id, e.response)
+ if e.code == twisted.web.http.NOT_FOUND:
+ raise SynapseError.from_http_response_exception(e)
+ raise SynapseError(502, "Failed to fetch remote media")
+
+ except SynapseError:
+ logger.exception("Failed to fetch remote media %s/%s",
+ server_name, media_id)
+ raise
+ except NotRetryingDestination:
+ logger.warn("Not retrying destination %r", server_name)
+ raise SynapseError(502, "Failed to fetch remote media")
+ except Exception:
+ logger.exception("Failed to fetch remote media %s/%s",
+ server_name, media_id)
+ raise SynapseError(502, "Failed to fetch remote media")
+
+ yield finish()
+
+ media_type = headers["Content-Type"][0]
+
+ time_now_ms = self.clock.time_msec()
+
+ content_disposition = headers.get("Content-Disposition", None)
+ if content_disposition:
+ _, params = cgi.parse_header(content_disposition[0],)
+ upload_name = None
+
+ # First check if there is a valid UTF-8 filename
+ upload_name_utf8 = params.get("filename*", None)
+ if upload_name_utf8:
+ if upload_name_utf8.lower().startswith("utf-8''"):
+ upload_name = upload_name_utf8[7:]
+
+ # If there isn't, check for an ascii name.
+ if not upload_name:
+ upload_name_ascii = params.get("filename", None)
+ if upload_name_ascii and is_ascii(upload_name_ascii):
+ upload_name = upload_name_ascii
+
+ if upload_name:
+ upload_name = urlparse.unquote(upload_name)
try:
- length, headers = yield self.client.get_file(
- server_name, request_path, output_stream=f,
- max_size=self.max_upload_size, args={
- # tell the remote server to 404 if it doesn't
- # recognise the server_name, to make sure we don't
- # end up with a routing loop.
- "allow_remote": "false",
- }
- )
- except twisted.internet.error.DNSLookupError as e:
- logger.warn("HTTP error fetching remote media %s/%s: %r",
- server_name, media_id, e)
- raise NotFoundError()
-
- except HttpResponseException as e:
- logger.warn("HTTP error fetching remote media %s/%s: %s",
- server_name, media_id, e.response)
- if e.code == twisted.web.http.NOT_FOUND:
- raise SynapseError.from_http_response_exception(e)
- raise SynapseError(502, "Failed to fetch remote media")
-
- except SynapseError:
- logger.exception("Failed to fetch remote media %s/%s",
- server_name, media_id)
- raise
- except NotRetryingDestination:
- logger.warn("Not retrying destination %r", server_name)
- raise SynapseError(502, "Failed to fetch remote media")
- except Exception:
- logger.exception("Failed to fetch remote media %s/%s",
- server_name, media_id)
- raise SynapseError(502, "Failed to fetch remote media")
-
- media_type = headers["Content-Type"][0]
- time_now_ms = self.clock.time_msec()
-
- content_disposition = headers.get("Content-Disposition", None)
- if content_disposition:
- _, params = cgi.parse_header(content_disposition[0],)
- upload_name = None
-
- # First check if there is a valid UTF-8 filename
- upload_name_utf8 = params.get("filename*", None)
- if upload_name_utf8:
- if upload_name_utf8.lower().startswith("utf-8''"):
- upload_name = upload_name_utf8[7:]
-
- # If there isn't check for an ascii name.
- if not upload_name:
- upload_name_ascii = params.get("filename", None)
- if upload_name_ascii and is_ascii(upload_name_ascii):
- upload_name = upload_name_ascii
-
- if upload_name:
- upload_name = urlparse.unquote(upload_name)
- try:
- upload_name = upload_name.decode("utf-8")
- except UnicodeDecodeError:
- upload_name = None
- else:
- upload_name = None
-
- logger.info("Stored remote media in file %r", fname)
-
- yield self.store.store_cached_remote_media(
- origin=server_name,
- media_id=media_id,
- media_type=media_type,
- time_now_ms=self.clock.time_msec(),
- upload_name=upload_name,
- media_length=length,
- filesystem_id=file_id,
- )
- except:
- os.remove(fname)
- raise
+ upload_name = upload_name.decode("utf-8")
+ except UnicodeDecodeError:
+ upload_name = None
+ else:
+ upload_name = None
+
+ logger.info("Stored remote media in file %r", fname)
+
+ yield self.store.store_cached_remote_media(
+ origin=server_name,
+ media_id=media_id,
+ media_type=media_type,
+ time_now_ms=self.clock.time_msec(),
+ upload_name=upload_name,
+ media_length=length,
+ filesystem_id=file_id,
+ )
media_info = {
"media_type": media_type,
@@ -244,8 +440,8 @@ class MediaRepository(object):
"filesystem_id": file_id,
}
- yield self._generate_remote_thumbnails(
- server_name, media_id, media_info
+ yield self._generate_thumbnails(
+ server_name, media_id, file_id, media_type,
)
defer.returnValue(media_info)
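
# The linearizer pattern used throughout this file, in miniature: all
# requests for the same (server_name, media_id) key queue behind one
# another, so only the first actually downloads and later waiters find
# the cache already populated. A hypothetical standalone sketch:

from twisted.internet import defer
from synapse.util.async import Linearizer

_linearizer = Linearizer(name="example")
_cache = {}

@defer.inlineCallbacks
def fetch_once(key, expensive_fetch):
    with (yield _linearizer.queue(key)):
        if key not in _cache:  # later waiters skip the expensive call
            _cache[key] = yield expensive_fetch()
    defer.returnValue(_cache[key])
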
@@ -253,9 +449,8 @@ class MediaRepository(object):
def _get_thumbnail_requirements(self, media_type):
return self.thumbnail_requirements.get(media_type, ())
- def _generate_thumbnail(self, input_path, t_path, t_width, t_height,
+ def _generate_thumbnail(self, thumbnailer, t_width, t_height,
t_method, t_type):
- thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
@@ -267,75 +462,125 @@ class MediaRepository(object):
return
if t_method == "crop":
- t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
+ t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
elif t_method == "scale":
t_width, t_height = thumbnailer.aspect(t_width, t_height)
t_width = min(m_width, t_width)
t_height = min(m_height, t_height)
- t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
+ t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
else:
- t_len = None
+ t_byte_source = None
- return t_len
+ return t_byte_source
@defer.inlineCallbacks
def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
- t_method, t_type):
- input_path = self.filepaths.local_media_filepath(media_id)
-
- t_path = self.filepaths.local_media_thumbnail(
- media_id, t_width, t_height, t_type, t_method
- )
- self._makedirs(t_path)
+ t_method, t_type, url_cache):
+ input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
+ None, media_id, url_cache=url_cache,
+ ))
- t_len = yield preserve_context_over_fn(
- threads.deferToThread,
+ thumbnailer = Thumbnailer(input_path)
+ t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
self._generate_thumbnail,
- input_path, t_path, t_width, t_height, t_method, t_type
- )
+ thumbnailer, t_width, t_height, t_method, t_type
+ ))
+
+ if t_byte_source:
+ try:
+ file_info = FileInfo(
+ server_name=None,
+ file_id=media_id,
+ url_cache=url_cache,
+ thumbnail=True,
+ thumbnail_width=t_width,
+ thumbnail_height=t_height,
+ thumbnail_method=t_method,
+ thumbnail_type=t_type,
+ )
+
+ output_path = yield self.media_storage.store_file(
+ t_byte_source, file_info,
+ )
+ finally:
+ t_byte_source.close()
+
+ logger.info("Stored thumbnail in file %r", output_path)
+
+ t_len = os.path.getsize(output_path)
- if t_len:
yield self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
- defer.returnValue(t_path)
+ defer.returnValue(output_path)
@defer.inlineCallbacks
def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
t_width, t_height, t_method, t_type):
- input_path = self.filepaths.remote_media_filepath(server_name, file_id)
-
- t_path = self.filepaths.remote_media_thumbnail(
- server_name, file_id, t_width, t_height, t_type, t_method
- )
- self._makedirs(t_path)
+ input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
+ server_name, file_id, url_cache=False,
+ ))
- t_len = yield preserve_context_over_fn(
- threads.deferToThread,
+ thumbnailer = Thumbnailer(input_path)
+ t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
self._generate_thumbnail,
- input_path, t_path, t_width, t_height, t_method, t_type
- )
+ thumbnailer, t_width, t_height, t_method, t_type
+ ))
+
+ if t_byte_source:
+ try:
+ file_info = FileInfo(
+ server_name=server_name,
+ file_id=media_id,
+ thumbnail=True,
+ thumbnail_width=t_width,
+ thumbnail_height=t_height,
+ thumbnail_method=t_method,
+ thumbnail_type=t_type,
+ )
+
+ output_path = yield self.media_storage.store_file(
+ t_byte_source, file_info,
+ )
+ finally:
+ t_byte_source.close()
+
+ logger.info("Stored thumbnail in file %r", output_path)
+
+ t_len = os.path.getsize(output_path)
- if t_len:
yield self.store.store_remote_media_thumbnail(
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
)
- defer.returnValue(t_path)
+ defer.returnValue(output_path)
@defer.inlineCallbacks
- def _generate_local_thumbnails(self, media_id, media_info, url_cache=False):
- media_type = media_info["media_type"]
+ def _generate_thumbnails(self, server_name, media_id, file_id, media_type,
+ url_cache=False):
+ """Generate and store thumbnails for an image.
+
+ Args:
+ server_name (str|None): The server name if remote media, else None if local
+ media_id (str): The media ID of the content. (This is the same as
+ the file_id for local content)
+ file_id (str): Local file ID
+ media_type (str): The content type of the file
+ url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
+ used exclusively by the url previewer
+
+ Returns:
+ Deferred[dict]: Dict with "width" and "height" keys of original image
+ """
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
- if url_cache:
- input_path = self.filepaths.url_cache_filepath(media_id)
- else:
- input_path = self.filepaths.local_media_filepath(media_id)
+ input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
+ server_name, file_id, url_cache=url_cache,
+ ))
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
@@ -348,135 +593,68 @@ class MediaRepository(object):
)
return
- local_thumbnails = []
-
- def generate_thumbnails():
- scales = set()
- crops = set()
- for r_width, r_height, r_method, r_type in requirements:
- if r_method == "scale":
- t_width, t_height = thumbnailer.aspect(r_width, r_height)
- scales.add((
- min(m_width, t_width), min(m_height, t_height), r_type,
- ))
- elif r_method == "crop":
- crops.add((r_width, r_height, r_type))
-
- for t_width, t_height, t_type in scales:
- t_method = "scale"
- if url_cache:
- t_path = self.filepaths.url_cache_thumbnail(
- media_id, t_width, t_height, t_type, t_method
- )
- else:
- t_path = self.filepaths.local_media_thumbnail(
- media_id, t_width, t_height, t_type, t_method
- )
- self._makedirs(t_path)
- t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
-
- local_thumbnails.append((
- media_id, t_width, t_height, t_type, t_method, t_len
+ # We deduplicate the thumbnail sizes by ignoring the cropped versions if
+ # they have the same dimensions as a scaled one.
+ thumbnails = {}
+ for r_width, r_height, r_method, r_type in requirements:
+ if r_method == "crop":
+ thumbnails.setdefault((r_width, r_height, r_type), r_method)
+ elif r_method == "scale":
+ t_width, t_height = thumbnailer.aspect(r_width, r_height)
+ t_width = min(m_width, t_width)
+ t_height = min(m_height, t_height)
+ thumbnails[(t_width, t_height, r_type)] = r_method
+
+ # Now we generate the thumbnails for each dimension and store them
+ for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
+ # Generate the thumbnail
+ if t_method == "crop":
+ t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
+ thumbnailer.crop,
+ t_width, t_height, t_type,
))
-
- for t_width, t_height, t_type in crops:
- if (t_width, t_height, t_type) in scales:
- # If the aspect ratio of the cropped thumbnail matches a purely
- # scaled one then there is no point in calculating a separate
- # thumbnail.
- continue
- t_method = "crop"
- if url_cache:
- t_path = self.filepaths.url_cache_thumbnail(
- media_id, t_width, t_height, t_type, t_method
- )
- else:
- t_path = self.filepaths.local_media_thumbnail(
- media_id, t_width, t_height, t_type, t_method
- )
- self._makedirs(t_path)
- t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
- local_thumbnails.append((
- media_id, t_width, t_height, t_type, t_method, t_len
+ elif t_method == "scale":
+ t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
+ thumbnailer.scale,
+ t_width, t_height, t_type,
))
+ else:
+ logger.error("Unrecognized method: %r", t_method)
+ continue
+
+ if not t_byte_source:
+ continue
+
+ try:
+ file_info = FileInfo(
+ server_name=server_name,
+ file_id=file_id,
+ thumbnail=True,
+ thumbnail_width=t_width,
+ thumbnail_height=t_height,
+ thumbnail_method=t_method,
+ thumbnail_type=t_type,
+ url_cache=url_cache,
+ )
- yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
-
- for l in local_thumbnails:
- yield self.store.store_local_thumbnail(*l)
-
- defer.returnValue({
- "width": m_width,
- "height": m_height,
- })
-
- @defer.inlineCallbacks
- def _generate_remote_thumbnails(self, server_name, media_id, media_info):
- media_type = media_info["media_type"]
- file_id = media_info["filesystem_id"]
- requirements = self._get_thumbnail_requirements(media_type)
- if not requirements:
- return
-
- remote_thumbnails = []
+ output_path = yield self.media_storage.store_file(
+ t_byte_source, file_info,
+ )
+ finally:
+ t_byte_source.close()
- input_path = self.filepaths.remote_media_filepath(server_name, file_id)
- thumbnailer = Thumbnailer(input_path)
- m_width = thumbnailer.width
- m_height = thumbnailer.height
+ t_len = os.path.getsize(output_path)
- def generate_thumbnails():
- if m_width * m_height >= self.max_image_pixels:
- logger.info(
- "Image too large to thumbnail %r x %r > %r",
- m_width, m_height, self.max_image_pixels
- )
- return
-
- scales = set()
- crops = set()
- for r_width, r_height, r_method, r_type in requirements:
- if r_method == "scale":
- t_width, t_height = thumbnailer.aspect(r_width, r_height)
- scales.add((
- min(m_width, t_width), min(m_height, t_height), r_type,
- ))
- elif r_method == "crop":
- crops.add((r_width, r_height, r_type))
-
- for t_width, t_height, t_type in scales:
- t_method = "scale"
- t_path = self.filepaths.remote_media_thumbnail(
- server_name, file_id, t_width, t_height, t_type, t_method
- )
- self._makedirs(t_path)
- t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
- remote_thumbnails.append([
+ # Write to database
+ if server_name:
+ yield self.store.store_remote_media_thumbnail(
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
- ])
-
- for t_width, t_height, t_type in crops:
- if (t_width, t_height, t_type) in scales:
- # If the aspect ratio of the cropped thumbnail matches a purely
- # scaled one then there is no point in calculating a separate
- # thumbnail.
- continue
- t_method = "crop"
- t_path = self.filepaths.remote_media_thumbnail(
- server_name, file_id, t_width, t_height, t_type, t_method
)
- self._makedirs(t_path)
- t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
- remote_thumbnails.append([
- server_name, media_id, file_id,
- t_width, t_height, t_type, t_method, t_len
- ])
-
- yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
-
- for r in remote_thumbnails:
- yield self.store.store_remote_media_thumbnail(*r)
+ else:
+ yield self.store.store_local_thumbnail(
+ media_id, t_width, t_height, t_type, t_method, t_len
+ )
defer.returnValue({
"width": m_width,
@@ -497,6 +675,8 @@ class MediaRepository(object):
logger.info("Deleting: %r", key)
+ # TODO: Should we delete from the backup store?
+
with (yield self.remote_media_linearizer.queue(key)):
full_path = self.filepaths.remote_media_filepath(origin, file_id)
try:
@@ -571,7 +751,11 @@ class MediaRepositoryResource(Resource):
self.putChild("upload", UploadResource(hs, media_repo))
self.putChild("download", DownloadResource(hs, media_repo))
- self.putChild("thumbnail", ThumbnailResource(hs, media_repo))
+ self.putChild("thumbnail", ThumbnailResource(
+ hs, media_repo, media_repo.media_storage,
+ ))
self.putChild("identicon", IdenticonResource())
if hs.config.url_preview_enabled:
- self.putChild("preview_url", PreviewUrlResource(hs, media_repo))
+ self.putChild("preview_url", PreviewUrlResource(
+ hs, media_repo, media_repo.media_storage,
+ ))
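
# The thumbnail dedup logic from _generate_thumbnails above, isolated:
# scale targets are clamped to the source size, and a crop is only kept
# when no scale produces the same (width, height, type). `aspect` here
# stands in for Thumbnailer.aspect, which fits the requested box while
# preserving the source aspect ratio; the helper name is hypothetical.

def dedupe_thumbnails(requirements, m_width, m_height, aspect):
    thumbnails = {}
    for r_width, r_height, r_method, r_type in requirements:
        if r_method == "crop":
            # setdefault: never displace an equivalent "scale" entry
            thumbnails.setdefault((r_width, r_height, r_type), r_method)
        elif r_method == "scale":
            t_width, t_height = aspect(r_width, r_height)
            key = (min(m_width, t_width), min(m_height, t_height), r_type)
            thumbnails[key] = r_method
    return thumbnails
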
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
new file mode 100644
index 0000000000..b25993fcb5
--- /dev/null
+++ b/synapse/rest/media/v1/media_storage.py
@@ -0,0 +1,265 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import logging
+import os
+import shutil
+import sys
+
+import six
+
+from twisted.internet import defer, threads
+from twisted.protocols.basic import FileSender
+
+from synapse.util.file_consumer import BackgroundFileConsumer
+from synapse.util.logcontext import make_deferred_yieldable
+
+from ._base import Responder
+
+logger = logging.getLogger(__name__)
+
+
+class MediaStorage(object):
+ """Responsible for storing/fetching files from local sources.
+
+ Args:
+ hs (synapse.server.Homeserver)
+ local_media_directory (str): Base path where we store media on disk
+ filepaths (MediaFilePaths)
+ storage_providers ([StorageProvider]): List of StorageProvider that are
+ used to fetch and store files.
+ """
+
+ def __init__(self, hs, local_media_directory, filepaths, storage_providers):
+ self.hs = hs
+ self.local_media_directory = local_media_directory
+ self.filepaths = filepaths
+ self.storage_providers = storage_providers
+
+ @defer.inlineCallbacks
+ def store_file(self, source, file_info):
+ """Write `source` to the on disk media store, and also any other
+ configured storage providers
+
+ Args:
+ source: A file like object that should be written
+ file_info (FileInfo): Info about the file to store
+
+ Returns:
+ Deferred[str]: the file path written to in the primary media store
+ """
+
+ with self.store_into_file(file_info) as (f, fname, finish_cb):
+ # Write to the main repository
+ yield make_deferred_yieldable(threads.deferToThread(
+ _write_file_synchronously, source, f,
+ ))
+ yield finish_cb()
+
+ defer.returnValue(fname)
+
+ @contextlib.contextmanager
+ def store_into_file(self, file_info):
+ """Context manager used to get a file like object to write into, as
+ described by file_info.
+
+ Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
+ like object that can be written to, fname is the absolute path of file
+ on disk, and finish_cb is a function that returns a Deferred.
+
+ fname can be used to read the contents from after upload, e.g. to
+ generate thumbnails.
+
+ finish_cb must be called and waited on after the file has been
+ successfully written to. It should not be called if there was an
+ error.
+
+ Args:
+ file_info (FileInfo): Info about the file to store
+
+ Example:
+
+ with media_storage.store_into_file(info) as (f, fname, finish_cb):
+ # .. write into f ...
+ yield finish_cb()
+ """
+
+ path = self._file_info_to_path(file_info)
+ fname = os.path.join(self.local_media_directory, path)
+
+ dirname = os.path.dirname(fname)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ finished_called = [False]
+
+ @defer.inlineCallbacks
+ def finish():
+ for provider in self.storage_providers:
+ yield provider.store_file(path, file_info)
+
+ finished_called[0] = True
+
+ try:
+ with open(fname, "wb") as f:
+ yield f, fname, finish
+ except Exception:
+ t, v, tb = sys.exc_info()
+ try:
+ os.remove(fname)
+ except Exception:
+ pass
+ six.reraise(t, v, tb)
+
+ if not finished_called[0]:
+ raise Exception("Finished callback not called")
+
+ @defer.inlineCallbacks
+ def fetch_media(self, file_info):
+ """Attempts to fetch media described by file_info from the local cache
+ and configured storage providers.
+
+ Args:
+ file_info (FileInfo)
+
+ Returns:
+ Deferred[Responder|None]: Returns a Responder if the file was found,
+ otherwise None.
+ """
+
+ path = self._file_info_to_path(file_info)
+ local_path = os.path.join(self.local_media_directory, path)
+ if os.path.exists(local_path):
+ defer.returnValue(FileResponder(open(local_path, "rb")))
+
+ for provider in self.storage_providers:
+ res = yield provider.fetch(path, file_info)
+ if res:
+ defer.returnValue(res)
+
+ defer.returnValue(None)
+
+ @defer.inlineCallbacks
+ def ensure_media_is_in_local_cache(self, file_info):
+ """Ensures that the given file is in the local cache. Attempts to
+ download it from storage providers if it isn't.
+
+ Args:
+ file_info (FileInfo)
+
+ Returns:
+ Deferred[str]: Full path to local file
+ """
+ path = self._file_info_to_path(file_info)
+ local_path = os.path.join(self.local_media_directory, path)
+ if os.path.exists(local_path):
+ defer.returnValue(local_path)
+
+ dirname = os.path.dirname(local_path)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ for provider in self.storage_providers:
+ res = yield provider.fetch(path, file_info)
+ if res:
+ with res:
+ consumer = BackgroundFileConsumer(
+ open(local_path, "w"), self.hs.get_reactor())
+ yield res.write_to_consumer(consumer)
+ yield consumer.wait()
+ defer.returnValue(local_path)
+
+ raise Exception("file could not be found")
+
+ def _file_info_to_path(self, file_info):
+ """Converts file_info into a relative path.
+
+ The path is suitable for storing files under a directory, e.g. used to
+ store files on local FS under the base media repository directory.
+
+ Args:
+ file_info (FileInfo)
+
+ Returns:
+ str
+ """
+ if file_info.url_cache:
+ if file_info.thumbnail:
+ return self.filepaths.url_cache_thumbnail_rel(
+ media_id=file_info.file_id,
+ width=file_info.thumbnail_width,
+ height=file_info.thumbnail_height,
+ content_type=file_info.thumbnail_type,
+ method=file_info.thumbnail_method,
+ )
+ return self.filepaths.url_cache_filepath_rel(file_info.file_id)
+
+ if file_info.server_name:
+ if file_info.thumbnail:
+ return self.filepaths.remote_media_thumbnail_rel(
+ server_name=file_info.server_name,
+ file_id=file_info.file_id,
+ width=file_info.thumbnail_width,
+ height=file_info.thumbnail_height,
+ content_type=file_info.thumbnail_type,
+ method=file_info.thumbnail_method
+ )
+ return self.filepaths.remote_media_filepath_rel(
+ file_info.server_name, file_info.file_id,
+ )
+
+ if file_info.thumbnail:
+ return self.filepaths.local_media_thumbnail_rel(
+ media_id=file_info.file_id,
+ width=file_info.thumbnail_width,
+ height=file_info.thumbnail_height,
+ content_type=file_info.thumbnail_type,
+ method=file_info.thumbnail_method
+ )
+ return self.filepaths.local_media_filepath_rel(
+ file_info.file_id,
+ )
+
+
+def _write_file_synchronously(source, dest):
+ """Write `source` to the file like `dest` synchronously. Should be called
+ from a thread.
+
+ Args:
+ source: A file like object that's to be written
+ dest: A file like object to be written to
+ """
+ source.seek(0) # Ensure we read from the start of the file
+ shutil.copyfileobj(source, dest)
+
+
+class FileResponder(Responder):
+ """Wraps an open file that can be sent to a request.
+
+ Args:
+ open_file (file): A file like object to be streamed to the client,
+ is closed when finished streaming.
+ """
+ def __init__(self, open_file):
+ self.open_file = open_file
+
+ def write_to_consumer(self, consumer):
+ return make_deferred_yieldable(
+ FileSender().beginFileTransfer(self.open_file, consumer)
+ )
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.open_file.close()
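
# Usage sketch for the store_into_file contract documented above: write
# into the yielded file object, then call and wait on finish_cb so the
# file is also replicated to the storage providers. If the body raises,
# the partially written primary file is removed. (store_file does the
# same dance, deferring the write to a thread; a synchronous write is
# fine here for a small payload.)

from twisted.internet import defer

@defer.inlineCallbacks
def save_bytes(media_storage, file_info, data):
    with media_storage.store_into_file(file_info) as (f, fname, finish_cb):
        f.write(data)
        yield finish_cb()
    defer.returnValue(fname)  # absolute path in the primary store
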
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index b81a336c5d..b70b15c4c2 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -12,80 +12,98 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import cgi
+import datetime
+import errno
+import fnmatch
+import itertools
+import logging
+import os
+import re
+import shutil
+import sys
+import traceback
+
+from six import string_types
+from six.moves import urllib_parse as urlparse
+
+from canonicaljson import json
-from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
-from synapse.api.errors import (
- SynapseError, Codes,
-)
-from synapse.util.stringutils import random_string
-from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SpiderHttpClient
from synapse.http.server import (
- request_handler, respond_with_json_bytes
+ respond_with_json,
+ respond_with_json_bytes,
+ wrap_json_request_handler,
)
+from synapse.http.servlet import parse_integer, parse_string
from synapse.util.async import ObservableDeferred
-from synapse.util.stringutils import is_ascii
+from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+from synapse.util.stringutils import is_ascii, random_string
-import os
-import re
-import fnmatch
-import cgi
-import ujson as json
-import urlparse
-import itertools
+from ._base import FileInfo
-import logging
logger = logging.getLogger(__name__)
class PreviewUrlResource(Resource):
isLeaf = True
- def __init__(self, hs, media_repo):
+ def __init__(self, hs, media_repo, media_storage):
Resource.__init__(self)
self.auth = hs.get_auth()
self.clock = hs.get_clock()
- self.version_string = hs.version_string
self.filepaths = media_repo.filepaths
self.max_spider_size = hs.config.max_spider_size
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.client = SpiderHttpClient(hs)
self.media_repo = media_repo
+ self.primary_base_path = media_repo.primary_base_path
+ self.media_storage = media_storage
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
- # simple memory cache mapping urls to OG metadata
- self.cache = ExpiringCache(
+ # memory cache mapping urls to an ObservableDeferred returning
+ # JSON-encoded OG metadata
+ self._cache = ExpiringCache(
cache_name="url_previews",
clock=self.clock,
# don't spider URLs more often than once an hour
expiry_ms=60 * 60 * 1000,
)
- self.cache.start()
+ self._cache.start()
- self.downloads = {}
+ self._cleaner_loop = self.clock.looping_call(
+ self._expire_url_cache_data, 10 * 1000
+ )
+
+ def render_OPTIONS(self, request):
+ return respond_with_json(request, 200, {}, send_cors=True)
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
- @request_handler()
+ @wrap_json_request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
# XXX: if get_user_by_req fails, what should we do in an async render?
requester = yield self.auth.get_user_by_req(request)
- url = request.args.get("url")[0]
+ url = parse_string(request, "url")
if "ts" in request.args:
- ts = int(request.args.get("ts")[0])
+ ts = parse_integer(request, "ts")
else:
ts = self.clock.time_msec()
+ # XXX: we could move this into _do_preview if we wanted.
url_tuple = urlparse.urlsplit(url)
for entry in self.url_preview_url_blacklist:
match = True
@@ -118,53 +136,62 @@ class PreviewUrlResource(Resource):
Codes.UNKNOWN
)
- # first check the memory cache - good to handle all the clients on this
- # HS thundering away to preview the same URL at the same time.
- og = self.cache.get(url)
- if og:
- respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
- return
+ # the in-memory cache:
+ # * ensures that only one request is active at a time
+ # * takes load off the DB for the thundering herds
+ # * also caches any failures (unlike the DB) so we don't keep
+ # requesting the same endpoint
+
+ observable = self._cache.get(url)
+
+ if not observable:
+ download = run_in_background(
+ self._do_preview,
+ url, requester.user, ts,
+ )
+ observable = ObservableDeferred(
+ download,
+ consumeErrors=True
+ )
+ self._cache[url] = observable
+ else:
+ logger.info("Returning cached response")
+
+ og = yield make_deferred_yieldable(observable.observe())
+ respond_with_json_bytes(request, 200, og, send_cors=True)
- # then check the URL cache in the DB (which will also provide us with
+ @defer.inlineCallbacks
+ def _do_preview(self, url, user, ts):
+ """Check the db, and download the URL and build a preview
+
+ Args:
+ url (str):
+ user (str):
+ ts (int):
+
+ Returns:
+ Deferred[str]: json-encoded og data
+ """
+ # check the URL cache in the DB (which will also provide us with
# historical previews, if we have any)
cache_result = yield self.store.get_url_cache(url, ts)
if (
cache_result and
- cache_result["download_ts"] + cache_result["expires"] > ts and
+ cache_result["expires_ts"] > ts and
cache_result["response_code"] / 100 == 2
):
- respond_with_json_bytes(
- request, 200, cache_result["og"].encode('utf-8'),
- send_cors=True
- )
+ defer.returnValue(cache_result["og"])
return
- # Ensure only one download for a given URL is active at a time
- download = self.downloads.get(url)
- if download is None:
- download = self._download_url(url, requester.user)
- download = ObservableDeferred(
- download,
- consumeErrors=True
- )
- self.downloads[url] = download
-
- @download.addBoth
- def callback(media_info):
- del self.downloads[url]
- return media_info
- media_info = yield download.observe()
-
- # FIXME: we should probably update our cache now anyway, so that
- # even if the OG calculation raises, we don't keep hammering on the
- # remote server. For now, leave it uncached to aid debugging OG
- # calculation problems
+ media_info = yield self._download_url(url, user)
logger.debug("got media_info of '%s'" % media_info)
if _is_media(media_info['media_type']):
- dims = yield self.media_repo._generate_local_thumbnails(
- media_info['filesystem_id'], media_info, url_cache=True,
+ file_id = media_info['filesystem_id']
+ dims = yield self.media_repo._generate_thumbnails(
+ None, file_id, file_id, media_info["media_type"],
+ url_cache=True,
)
og = {
@@ -204,13 +231,15 @@ class PreviewUrlResource(Resource):
# just rely on the caching on the master request to speed things up.
if 'og:image' in og and og['og:image']:
image_info = yield self._download_url(
- _rebase_url(og['og:image'], media_info['uri']), requester.user
+ _rebase_url(og['og:image'], media_info['uri']), user
)
if _is_media(image_info['media_type']):
# TODO: make sure we don't choke on white-on-transparent images
- dims = yield self.media_repo._generate_local_thumbnails(
- image_info['filesystem_id'], image_info, url_cache=True,
+ file_id = image_info['filesystem_id']
+ dims = yield self.media_repo._generate_thumbnails(
+ None, file_id, file_id, image_info["media_type"],
+ url_cache=True,
)
if dims:
og["og:image:width"] = dims['width']
@@ -231,21 +260,20 @@ class PreviewUrlResource(Resource):
logger.debug("Calculated OG for %s as %s" % (url, og))
- # store OG in ephemeral in-memory cache
- self.cache[url] = og
+ jsonog = json.dumps(og)
# store OG in history-aware DB cache
yield self.store.store_url_cache(
url,
media_info["response_code"],
media_info["etag"],
- media_info["expires"],
- json.dumps(og),
+ media_info["expires"] + media_info["created_ts"],
+ jsonog,
media_info["filesystem_id"],
media_info["created_ts"],
)
- respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
+ defer.returnValue(jsonog)
@defer.inlineCallbacks
def _download_url(self, url, user):
@@ -253,21 +281,36 @@ class PreviewUrlResource(Resource):
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
- # XXX: horrible duplication with base_resource's _download_remote_file()
- file_id = random_string(24)
+ file_id = datetime.date.today().isoformat() + '_' + random_string(16)
- fname = self.filepaths.url_cache_filepath(file_id)
- self.media_repo._makedirs(fname)
+ file_info = FileInfo(
+ server_name=None,
+ file_id=file_id,
+ url_cache=True,
+ )
- try:
- with open(fname, "wb") as f:
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ try:
logger.debug("Trying to get url '%s'" % url)
length, headers, uri, code = yield self.client.get_file(
url, output_stream=f, max_size=self.max_spider_size,
)
+ except Exception as e:
# FIXME: pass through 404s and other error messages nicely
+ logger.warn("Error downloading %s: %r", url, e)
+ raise SynapseError(
+ 500, "Failed to download content: %s" % (
+ traceback.format_exception_only(type(e), e),  # sys.exc_type is Python 2-only
+ ),
+ Codes.UNKNOWN,
+ )
+ yield finish()
- media_type = headers["Content-Type"][0]
+ try:
+ if "Content-Type" in headers:
+ media_type = headers["Content-Type"][0]
+ else:
+ media_type = "application/octet-stream"
time_now_ms = self.clock.time_msec()
content_disposition = headers.get("Content-Disposition", None)
@@ -307,11 +350,11 @@ class PreviewUrlResource(Resource):
)
except Exception as e:
- os.remove(fname)
- raise SynapseError(
- 500, ("Failed to download content: %s" % e),
- Codes.UNKNOWN
- )
+ logger.error("Error handling downloaded %s: %r", url, e)
+ # TODO: we really ought to delete the downloaded file in this
+ # case, since we won't have recorded it in the db, and will
+ # therefore not expire it.
+ raise
defer.returnValue({
"media_type": media_type,
@@ -328,6 +371,95 @@ class PreviewUrlResource(Resource):
"etag": headers["ETag"][0] if "ETag" in headers else None,
})
+ @defer.inlineCallbacks
+ def _expire_url_cache_data(self):
+ """Clean up expired url cache content, media and thumbnails.
+ """
+ # TODO: Delete from backup media store
+
+ now = self.clock.time_msec()
+
+ logger.info("Running url preview cache expiry")
+
+ if not (yield self.store.has_completed_background_updates()):
+ logger.info("Still running DB updates; skipping expiry")
+ return
+
+ # First we delete expired url cache entries
+ media_ids = yield self.store.get_expired_url_cache(now)
+
+ removed_media = []
+ for media_id in media_ids:
+ fname = self.filepaths.url_cache_filepath(media_id)
+ try:
+ os.remove(fname)
+ except OSError as e:
+ # If the path doesn't exist, meh
+ if e.errno != errno.ENOENT:
+ logger.warn("Failed to remove media: %r: %s", media_id, e)
+ continue
+
+ removed_media.append(media_id)
+
+ try:
+ dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
+ for dir in dirs:
+ os.rmdir(dir)
+ except Exception:
+ pass
+
+ yield self.store.delete_url_cache(removed_media)
+
+ if removed_media:
+ logger.info("Deleted %d entries from url cache", len(removed_media))
+
+ # Now we delete old images associated with the url cache.
+ # These may be cached for a bit on the client (e.g., the user
+ # may still have a room open with the preview showing).
+ # So we wait a couple of days before deleting, just in case.
+ expire_before = now - 2 * 24 * 60 * 60 * 1000
+ media_ids = yield self.store.get_url_cache_media_before(expire_before)
+
+ removed_media = []
+ for media_id in media_ids:
+ fname = self.filepaths.url_cache_filepath(media_id)
+ try:
+ os.remove(fname)
+ except OSError as e:
+ # If the path doesn't exist, meh
+ if e.errno != errno.ENOENT:
+ logger.warn("Failed to remove media: %r: %s", media_id, e)
+ continue
+
+ try:
+ dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
+ for dir in dirs:
+ os.rmdir(dir)
+ except Exception:
+ pass
+
+ thumbnail_dir = self.filepaths.url_cache_thumbnail_directory(media_id)
+ try:
+ shutil.rmtree(thumbnail_dir)
+ except OSError as e:
+ # If the path doesn't exist, meh
+ if e.errno != errno.ENOENT:
+ logger.warn("Failed to remove media: %r: %s", media_id, e)
+ continue
+
+ removed_media.append(media_id)
+
+ try:
+ dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id)
+ for dir in dirs:
+ os.rmdir(dir)
+ except Exception:
+ pass
+
+ yield self.store.delete_url_cache_media(removed_media)
+
+ logger.info("Deleted %d media from url cache", len(removed_media))
+
def decode_and_calc_og(body, media_uri, request_encoding=None):
from lxml import etree
@@ -425,7 +557,14 @@ def _calc_og(tree, media_uri):
from lxml import etree
TAGS_TO_REMOVE = (
- "header", "nav", "aside", "footer", "script", "style", etree.Comment
+ "header",
+ "nav",
+ "aside",
+ "footer",
+ "script",
+ "noscript",
+ "style",
+ etree.Comment
)
# Split all the text nodes into paragraphs (by splitting on new
@@ -452,8 +591,8 @@ def _iterate_over_text(tree, *tags_to_ignore):
# to be returned.
elements = iter([tree])
while True:
- el = elements.next()
- if isinstance(el, basestring):
+ el = next(elements)
+ if isinstance(el, string_types):
yield el
elif el is not None and el.tag not in tags_to_ignore:
# el.text is the text before the first child, so we can immediately
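The `_cache` of ObservableDeferreds introduced above lets every concurrent request for the same URL share a single in-flight download, and caches failures as well as successes. A minimal standalone sketch of that observable pattern follows; it is a simplified stand-in for synapse.util.async.ObservableDeferred, not the real class:

from twisted.internet import defer

class MiniObservable(object):
    """Lets many observers wait on one underlying deferred (sketch only)."""

    def __init__(self, deferred):
        self._result = None
        self._observers = []

        def callback(r):
            self._result = (True, r)
            observers, self._observers = self._observers, []
            for obs in observers:
                obs.callback(r)
            return r

        def errback(f):
            self._result = (False, f)
            observers, self._observers = self._observers, []
            for obs in observers:
                obs.errback(f)
            # returning None consumes the error, like consumeErrors=True

        deferred.addCallbacks(callback, errback)

    def observe(self):
        # each caller gets its own deferred; if the underlying deferred has
        # already fired, they get the stored result immediately
        if self._result is None:
            d = defer.Deferred()
            self._observers.append(d)
            return d
        success, value = self._result
        return defer.succeed(value) if success else defer.fail(value)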
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
new file mode 100644
index 0000000000..7b9f8b4d79
--- /dev/null
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -0,0 +1,144 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import shutil
+
+from twisted.internet import defer, threads
+
+from synapse.config._base import Config
+from synapse.util.logcontext import run_in_background
+
+from .media_storage import FileResponder
+
+logger = logging.getLogger(__name__)
+
+
+class StorageProvider(object):
+ """A storage provider is a service that can store uploaded media and
+ retrieve them.
+ """
+ def store_file(self, path, file_info):
+ """Store the file described by file_info. The actual contents can be
+ retrieved by reading the file in file_info.upload_path.
+
+ Args:
+ path (str): Relative path of file in local cache
+ file_info (FileInfo)
+
+ Returns:
+ Deferred
+ """
+ pass
+
+ def fetch(self, path, file_info):
+ """Attempt to fetch the file described by file_info and stream it
+ into writer.
+
+ Args:
+ path (str): Relative path of file in local cache
+ file_info (FileInfo)
+
+ Returns:
+ Deferred[Responder]: Returns a Responder if the provider has the file,
+ otherwise returns None.
+ """
+ pass
+
+
+class StorageProviderWrapper(StorageProvider):
+ """Wraps a storage provider and provides various config options
+
+ Args:
+ backend (StorageProvider)
+ store_local (bool): Whether to store new local files or not.
+ store_synchronous (bool): Whether to wait for the file to be successfully
+ uploaded, or to do the upload in the background.
+ store_remote (bool): Whether remote media should be uploaded
+ """
+ def __init__(self, backend, store_local, store_synchronous, store_remote):
+ self.backend = backend
+ self.store_local = store_local
+ self.store_synchronous = store_synchronous
+ self.store_remote = store_remote
+
+ def store_file(self, path, file_info):
+ if not file_info.server_name and not self.store_local:
+ return defer.succeed(None)
+
+ if file_info.server_name and not self.store_remote:
+ return defer.succeed(None)
+
+ if self.store_synchronous:
+ return self.backend.store_file(path, file_info)
+ else:
+ # TODO: Handle errors.
+ def store():
+ try:
+ return self.backend.store_file(path, file_info)
+ except Exception:
+ logger.exception("Error storing file")
+ run_in_background(store)
+ return defer.succeed(None)
+
+ def fetch(self, path, file_info):
+ return self.backend.fetch(path, file_info)
+
+
+class FileStorageProviderBackend(StorageProvider):
+ """A storage provider that stores files in a directory on a filesystem.
+
+ Args:
+ hs (HomeServer)
+ config: The config returned by `parse_config`.
+ """
+
+ def __init__(self, hs, config):
+ self.cache_directory = hs.config.media_store_path
+ self.base_directory = config
+
+ def store_file(self, path, file_info):
+ """See StorageProvider.store_file"""
+
+ primary_fname = os.path.join(self.cache_directory, path)
+ backup_fname = os.path.join(self.base_directory, path)
+
+ dirname = os.path.dirname(backup_fname)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ return threads.deferToThread(
+ shutil.copyfile, primary_fname, backup_fname,
+ )
+
+ def fetch(self, path, file_info):
+ """See StorageProvider.fetch"""
+
+ backup_fname = os.path.join(self.base_directory, path)
+ if os.path.isfile(backup_fname):
+ return FileResponder(open(backup_fname, "rb"))
+
+ @staticmethod
+ def parse_config(config):
+ """Called on startup to parse config supplied. This should parse
+ the config and raise if there is a problem.
+
+ The returned value is passed into the constructor.
+
+ In this case we only care about a single param, the directory, so let's
+ just pull that out.
+ """
+ return Config.ensure_directory(config["directory"])
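For illustration, here is a hypothetical second backend implementing the same interface; the name and archive layout are invented for this sketch (only FileStorageProviderBackend above is real), and it reuses the StorageProvider base class and Config import from this module:

import os
import shutil

from twisted.internet import threads


class ArchiveStorageProviderBackend(StorageProvider):
    """Sketch: copies every stored file into a flat archive directory."""

    def __init__(self, hs, config):
        self.cache_directory = hs.config.media_store_path
        self.archive_directory = config

    def store_file(self, path, file_info):
        src = os.path.join(self.cache_directory, path)
        dst = os.path.join(self.archive_directory, path.replace(os.sep, "_"))
        # do the copy off the reactor thread, as FileStorageProviderBackend does
        return threads.deferToThread(shutil.copyfile, src, dst)

    def fetch(self, path, file_info):
        # returning None signals that this provider cannot serve the file
        return None

    @staticmethod
    def parse_config(config):
        # Config here is synapse.config._base.Config, imported above
        return Config.ensure_directory(config["directory"])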
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 68d56b2b10..5305e9175f 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -14,15 +14,22 @@
# limitations under the License.
-from ._base import parse_media_id, respond_404, respond_with_file
-from twisted.web.resource import Resource
-from synapse.http.servlet import parse_string, parse_integer
-from synapse.http.server import request_handler, set_cors_headers
+import logging
-from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
-import logging
+from synapse.http.server import set_cors_headers, wrap_json_request_handler
+from synapse.http.servlet import parse_integer, parse_string
+
+from ._base import (
+ FileInfo,
+ parse_media_id,
+ respond_404,
+ respond_with_file,
+ respond_with_responder,
+)
logger = logging.getLogger(__name__)
@@ -30,22 +37,21 @@ logger = logging.getLogger(__name__)
class ThumbnailResource(Resource):
isLeaf = True
- def __init__(self, hs, media_repo):
+ def __init__(self, hs, media_repo, media_storage):
Resource.__init__(self)
self.store = hs.get_datastore()
- self.filepaths = media_repo.filepaths
self.media_repo = media_repo
+ self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname
- self.version_string = hs.version_string
self.clock = hs.get_clock()
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
- @request_handler()
+ @wrap_json_request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
set_cors_headers(request)
@@ -64,6 +70,7 @@ class ThumbnailResource(Resource):
yield self._respond_local_thumbnail(
request, media_id, width, height, method, m_type
)
+ self.media_repo.mark_recently_accessed(None, media_id)
else:
if self.dynamic_thumbnails:
yield self._select_or_generate_remote_thumbnail(
@@ -75,20 +82,20 @@ class ThumbnailResource(Resource):
request, server_name, media_id,
width, height, method, m_type
)
+ self.media_repo.mark_recently_accessed(server_name, media_id)
@defer.inlineCallbacks
def _respond_local_thumbnail(self, request, media_id, width, height,
method, m_type):
media_info = yield self.store.get_local_media(media_id)
- if not media_info or media_info["quarantined_by"]:
+ if not media_info:
+ respond_404(request)
+ return
+ if media_info["quarantined_by"]:
+ logger.info("Media is quarantined")
respond_404(request)
return
-
- # if media_info["media_type"] == "image/svg+xml":
- # file_path = self.filepaths.local_media_filepath(media_id)
- # yield respond_with_file(request, media_info["media_type"], file_path)
- # return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
@@ -96,42 +103,39 @@ class ThumbnailResource(Resource):
thumbnail_info = self._select_thumbnail(
width, height, method, m_type, thumbnail_infos
)
- t_width = thumbnail_info["thumbnail_width"]
- t_height = thumbnail_info["thumbnail_height"]
- t_type = thumbnail_info["thumbnail_type"]
- t_method = thumbnail_info["thumbnail_method"]
-
- if media_info["url_cache"]:
- # TODO: Check the file still exists, if it doesn't we can redownload
- # it from the url `media_info["url_cache"]`
- file_path = self.filepaths.url_cache_thumbnail(
- media_id, t_width, t_height, t_type, t_method,
- )
- else:
- file_path = self.filepaths.local_media_thumbnail(
- media_id, t_width, t_height, t_type, t_method,
- )
- yield respond_with_file(request, t_type, file_path)
- else:
- yield self._respond_default_thumbnail(
- request, media_info, width, height, method, m_type,
+ file_info = FileInfo(
+ server_name=None, file_id=media_id,
+ url_cache=media_info["url_cache"],
+ thumbnail=True,
+ thumbnail_width=thumbnail_info["thumbnail_width"],
+ thumbnail_height=thumbnail_info["thumbnail_height"],
+ thumbnail_type=thumbnail_info["thumbnail_type"],
+ thumbnail_method=thumbnail_info["thumbnail_method"],
)
+ t_type = file_info.thumbnail_type
+ t_length = thumbnail_info["thumbnail_length"]
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ yield respond_with_responder(request, responder, t_type, t_length)
+ else:
+ logger.info("Couldn't find any generated thumbnails")
+ respond_404(request)
+
@defer.inlineCallbacks
def _select_or_generate_local_thumbnail(self, request, media_id, desired_width,
desired_height, desired_method,
desired_type):
media_info = yield self.store.get_local_media(media_id)
- if not media_info or media_info["quarantined_by"]:
+ if not media_info:
+ respond_404(request)
+ return
+ if media_info["quarantined_by"]:
+ logger.info("Media is quarantined")
respond_404(request)
return
-
- # if media_info["media_type"] == "image/svg+xml":
- # file_path = self.filepaths.local_media_filepath(media_id)
- # yield respond_with_file(request, media_info["media_type"], file_path)
- # return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
for info in thumbnail_infos:
@@ -141,46 +145,43 @@ class ThumbnailResource(Resource):
t_type = info["thumbnail_type"] == desired_type
if t_w and t_h and t_method and t_type:
- if media_info["url_cache"]:
- # TODO: Check the file still exists, if it doesn't we can redownload
- # it from the url `media_info["url_cache"]`
- file_path = self.filepaths.url_cache_thumbnail(
- media_id, desired_width, desired_height, desired_type,
- desired_method,
- )
- else:
- file_path = self.filepaths.local_media_thumbnail(
- media_id, desired_width, desired_height, desired_type,
- desired_method,
- )
- yield respond_with_file(request, desired_type, file_path)
- return
-
- logger.debug("We don't have a local thumbnail of that size. Generating")
+ file_info = FileInfo(
+ server_name=None, file_id=media_id,
+ url_cache=media_info["url_cache"],
+ thumbnail=True,
+ thumbnail_width=info["thumbnail_width"],
+ thumbnail_height=info["thumbnail_height"],
+ thumbnail_type=info["thumbnail_type"],
+ thumbnail_method=info["thumbnail_method"],
+ )
+
+ t_type = file_info.thumbnail_type
+ t_length = info["thumbnail_length"]
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ if responder:
+ yield respond_with_responder(request, responder, t_type, t_length)
+ return
+
+ logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = yield self.media_repo.generate_local_exact_thumbnail(
- media_id, desired_width, desired_height, desired_method, desired_type
+ media_id, desired_width, desired_height, desired_method, desired_type,
+ url_cache=media_info["url_cache"],
)
if file_path:
yield respond_with_file(request, desired_type, file_path)
else:
- yield self._respond_default_thumbnail(
- request, media_info, desired_width, desired_height,
- desired_method, desired_type,
- )
+ logger.warn("Failed to generate thumbnail")
+ respond_404(request)
@defer.inlineCallbacks
def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
desired_width, desired_height,
desired_method, desired_type):
- media_info = yield self.media_repo.get_remote_media(server_name, media_id)
-
- # if media_info["media_type"] == "image/svg+xml":
- # file_path = self.filepaths.remote_media_filepath(server_name, media_id)
- # yield respond_with_file(request, media_info["media_type"], file_path)
- # return
+ media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
server_name, media_id,
@@ -195,14 +196,24 @@ class ThumbnailResource(Resource):
t_type = info["thumbnail_type"] == desired_type
if t_w and t_h and t_method and t_type:
- file_path = self.filepaths.remote_media_thumbnail(
- server_name, file_id, desired_width, desired_height,
- desired_type, desired_method,
+ file_info = FileInfo(
+ server_name=server_name, file_id=media_info["filesystem_id"],
+ thumbnail=True,
+ thumbnail_width=info["thumbnail_width"],
+ thumbnail_height=info["thumbnail_height"],
+ thumbnail_type=info["thumbnail_type"],
+ thumbnail_method=info["thumbnail_method"],
)
- yield respond_with_file(request, desired_type, file_path)
- return
- logger.debug("We don't have a local thumbnail of that size. Generating")
+ t_type = file_info.thumbnail_type
+ t_length = info["thumbnail_length"]
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ if responder:
+ yield respond_with_responder(request, responder, t_type, t_length)
+ return
+
+ logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = yield self.media_repo.generate_remote_exact_thumbnail(
@@ -213,22 +224,16 @@ class ThumbnailResource(Resource):
if file_path:
yield respond_with_file(request, desired_type, file_path)
else:
- yield self._respond_default_thumbnail(
- request, media_info, desired_width, desired_height,
- desired_method, desired_type,
- )
+ logger.warn("Failed to generate thumbnail")
+ respond_404(request)
@defer.inlineCallbacks
def _respond_remote_thumbnail(self, request, server_name, media_id, width,
height, method, m_type):
# TODO: Don't download the whole remote file
- # We should proxy the thumbnail from the remote server instead.
- media_info = yield self.media_repo.get_remote_media(server_name, media_id)
-
- # if media_info["media_type"] == "image/svg+xml":
- # file_path = self.filepaths.remote_media_filepath(server_name, media_id)
- # yield respond_with_file(request, media_info["media_type"], file_path)
- # return
+ # We should proxy the thumbnail from the remote server instead of
+ # downloading the remote file and generating our own thumbnails.
+ media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
server_name, media_id,
@@ -238,59 +243,23 @@ class ThumbnailResource(Resource):
thumbnail_info = self._select_thumbnail(
width, height, method, m_type, thumbnail_infos
)
- t_width = thumbnail_info["thumbnail_width"]
- t_height = thumbnail_info["thumbnail_height"]
- t_type = thumbnail_info["thumbnail_type"]
- t_method = thumbnail_info["thumbnail_method"]
- file_id = thumbnail_info["filesystem_id"]
+ file_info = FileInfo(
+ server_name=server_name, file_id=media_info["filesystem_id"],
+ thumbnail=True,
+ thumbnail_width=thumbnail_info["thumbnail_width"],
+ thumbnail_height=thumbnail_info["thumbnail_height"],
+ thumbnail_type=thumbnail_info["thumbnail_type"],
+ thumbnail_method=thumbnail_info["thumbnail_method"],
+ )
+
+ t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
- file_path = self.filepaths.remote_media_thumbnail(
- server_name, file_id, t_width, t_height, t_type, t_method,
- )
- yield respond_with_file(request, t_type, file_path, t_length)
+ responder = yield self.media_storage.fetch_media(file_info)
+ yield respond_with_responder(request, responder, t_type, t_length)
else:
- yield self._respond_default_thumbnail(
- request, media_info, width, height, method, m_type,
- )
-
- @defer.inlineCallbacks
- def _respond_default_thumbnail(self, request, media_info, width, height,
- method, m_type):
- # XXX: how is this meant to work? store.get_default_thumbnails
- # appears to always return [] so won't this always 404?
- media_type = media_info["media_type"]
- top_level_type = media_type.split("/")[0]
- sub_type = media_type.split("/")[-1].split(";")[0]
- thumbnail_infos = yield self.store.get_default_thumbnails(
- top_level_type, sub_type,
- )
- if not thumbnail_infos:
- thumbnail_infos = yield self.store.get_default_thumbnails(
- top_level_type, "_default",
- )
- if not thumbnail_infos:
- thumbnail_infos = yield self.store.get_default_thumbnails(
- "_default", "_default",
- )
- if not thumbnail_infos:
+ logger.info("Failed to find any generated thumbnails")
respond_404(request)
- return
-
- thumbnail_info = self._select_thumbnail(
- width, height, "crop", m_type, thumbnail_infos
- )
-
- t_width = thumbnail_info["thumbnail_width"]
- t_height = thumbnail_info["thumbnail_height"]
- t_type = thumbnail_info["thumbnail_type"]
- t_method = thumbnail_info["thumbnail_method"]
- t_length = thumbnail_info["thumbnail_length"]
-
- file_path = self.filepaths.default_thumbnail(
- top_level_type, sub_type, t_width, t_height, t_type, t_method,
- )
- yield respond_with_file(request, t_type, file_path, t_length)
def _select_thumbnail(self, desired_width, desired_height, desired_method,
desired_type, thumbnail_infos):
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 3868d4f65f..a4b26c2587 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -13,10 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import PIL.Image as Image
+import logging
from io import BytesIO
-import logging
+import PIL.Image as Image
logger = logging.getLogger(__name__)
@@ -50,12 +50,16 @@ class Thumbnailer(object):
else:
return ((max_height * self.width) // self.height, max_height)
- def scale(self, output_path, width, height, output_type):
- """Rescales the image to the given dimensions"""
+ def scale(self, width, height, output_type):
+ """Rescales the image to the given dimensions.
+
+ Returns:
+ BytesIO: the bytes of the encoded image ready to be written to disk
+ """
scaled = self.image.resize((width, height), Image.ANTIALIAS)
- return self.save_image(scaled, output_type, output_path)
+ return self._encode_image(scaled, output_type)
- def crop(self, output_path, width, height, output_type):
+ def crop(self, width, height, output_type):
"""Rescales and crops the image to the given dimensions preserving
aspect::
(w_in / h_in) = (w_scaled / h_scaled)
@@ -65,6 +69,9 @@ class Thumbnailer(object):
Args:
width: The largest possible width.
height: The largest possible height.
+
+ Returns:
+ BytesIO: the bytes of the encoded image ready to be written to disk
"""
if width * self.height > height * self.width:
scaled_height = (width * self.height) // self.width
@@ -82,13 +89,9 @@ class Thumbnailer(object):
crop_left = (scaled_width - width) // 2
crop_right = width + crop_left
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
- return self.save_image(cropped, output_type, output_path)
+ return self._encode_image(cropped, output_type)
- def save_image(self, output_image, output_type, output_path):
+ def _encode_image(self, output_image, output_type):
output_bytes_io = BytesIO()
output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80)
- output_bytes = output_bytes_io.getvalue()
- with open(output_path, "wb") as output_file:
- output_file.write(output_bytes)
- logger.info("Stored thumbnail in file %r", output_path)
- return len(output_bytes)
+ return output_bytes_io
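The net effect of the change above is that thumbnail generation now returns encoded bytes instead of writing to a fixed path, so the media storage layer decides where they end up. A standalone sketch of that flow, assuming Pillow is installed (scale_to_bytes is an illustrative name, not a Synapse function):

from io import BytesIO

import PIL.Image as Image


def scale_to_bytes(input_path, width, height, pil_format="JPEG"):
    """Scale an image and return a BytesIO of the encoded result, ready to
    be handed to a storage layer rather than written to a fixed path."""
    image = Image.open(input_path)
    scaled = image.resize((width, height), Image.ANTIALIAS)
    output = BytesIO()
    scaled.save(output, pil_format, quality=80)
    return output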
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 4ab33f73bf..9b22d204a6 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -13,16 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.http.server import respond_with_json, request_handler
-
-from synapse.api.errors import SynapseError
+import logging
-from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
-
from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
-import logging
+from synapse.api.errors import SynapseError
+from synapse.http.server import respond_with_json, wrap_json_request_handler
+from synapse.http.servlet import parse_string
logger = logging.getLogger(__name__)
@@ -40,7 +39,6 @@ class UploadResource(Resource):
self.server_name = hs.hostname
self.auth = hs.get_auth()
self.max_upload_size = hs.config.max_upload_size
- self.version_string = hs.version_string
self.clock = hs.get_clock()
def render_POST(self, request):
@@ -51,7 +49,7 @@ class UploadResource(Resource):
respond_with_json(request, 200, {}, send_cors=True)
return NOT_DONE_YET
- @request_handler()
+ @wrap_json_request_handler
@defer.inlineCallbacks
def _async_render_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
@@ -68,10 +66,10 @@ class UploadResource(Resource):
code=413,
)
- upload_name = request.args.get("filename", None)
+ upload_name = parse_string(request, "filename")
if upload_name:
try:
- upload_name = upload_name[0].decode('UTF-8')
+ upload_name = upload_name.decode('UTF-8')
except UnicodeDecodeError:
raise SynapseError(
msg="Invalid UTF-8 filename parameter: %r" % (upload_name),
@@ -81,19 +79,19 @@ class UploadResource(Resource):
headers = request.requestHeaders
if headers.hasHeader("Content-Type"):
- media_type = headers.getRawHeaders("Content-Type")[0]
+ media_type = headers.getRawHeaders(b"Content-Type")[0]
else:
raise SynapseError(
msg="Upload request missing 'Content-Type'",
code=400,
)
- # if headers.hasHeader("Content-Disposition"):
- # disposition = headers.getRawHeaders("Content-Disposition")[0]
+ # if headers.hasHeader(b"Content-Disposition"):
+ # disposition = headers.getRawHeaders(b"Content-Disposition")[0]
# TODO(markjh): parse content-disposition
content_uri = yield self.media_repo.create_content(
- media_type, upload_name, request.content.read(),
+ media_type, upload_name, request.content,
content_length, requester.user
)
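parse_string and parse_integer replace the raw request.args indexing the old code did by hand. A toy stand-in showing the equivalent behaviour (the real helpers in synapse.http.servlet additionally handle required parameters and error responses):

def parse_string_from_args(args, name, default=None):
    """Toy equivalent of parse_string: args maps names to lists of values,
    as Twisted's request.args does."""
    values = args.get(name)
    if values:
        return values[0]
    return default

# e.g. parse_string_from_args({"filename": [b"cat.png"]}, "filename")
# returns b"cat.png"; a missing key returns the default.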
diff --git a/synapse/secrets.py b/synapse/secrets.py
new file mode 100644
index 0000000000..f397daaa5e
--- /dev/null
+++ b/synapse/secrets.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Injectable secrets module for Synapse.
+
+See https://docs.python.org/3/library/secrets.html#module-secrets for the API
+used in Python 3.6, and the API emulated in Python 2.7.
+"""
+
+import six
+
+if six.PY3:
+ import secrets
+
+ def Secrets():
+ return secrets
+
+
+else:
+
+ import os
+ import binascii
+
+ class Secrets(object):
+ def token_bytes(self, nbytes=32):
+ return os.urandom(nbytes)
+
+ def token_hex(self, nbytes=32):
+ return binascii.hexlify(self.token_bytes(nbytes))
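A short usage sketch: callers obtain the module via the HomeServer (hs.get_secrets(), wired up through the DEPENDENCIES list in synapse/server.py below) rather than importing secrets directly, which keeps the Python 2 fallback and test substitution injectable:

from synapse.secrets import Secrets

secrets = Secrets()              # in real code: hs.get_secrets()
token = secrets.token_hex(16)    # 32 hex characters from 16 random bytes
raw = secrets.token_bytes(32)    # 32 random bytes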
diff --git a/synapse/server.py b/synapse/server.py
index a38e5179e0..140be9ebe8 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -31,32 +31,55 @@ from synapse.appservice.api import ApplicationServiceApi
from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory
-from synapse.federation import initialize_http_replication
+from synapse.events.spamcheck import SpamChecker
+from synapse.federation.federation_client import FederationClient
+from synapse.federation.federation_server import (
+ FederationHandlerRegistry,
+ FederationServer,
+)
from synapse.federation.send_queue import FederationRemoteSendQueue
-from synapse.federation.transport.client import TransportLayerClient
from synapse.federation.transaction_queue import TransactionQueue
+from synapse.federation.transport.client import TransportLayerClient
+from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
+from synapse.groups.groups_server import GroupsServerHandler
from synapse.handlers import Handlers
from synapse.handlers.appservice import ApplicationServicesHandler
-from synapse.handlers.auth import AuthHandler, MacaroonGeneartor
-from synapse.handlers.devicemessage import DeviceMessageHandler
+from synapse.handlers.auth import AuthHandler, MacaroonGenerator
+from synapse.handlers.deactivate_account import DeactivateAccountHandler
from synapse.handlers.device import DeviceHandler
+from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.e2e_keys import E2eKeysHandler
+from synapse.handlers.events import EventHandler, EventStreamHandler
+from synapse.handlers.groups_local import GroupsLocalHandler
+from synapse.handlers.initial_sync import InitialSyncHandler
+from synapse.handlers.message import EventCreationHandler, MessageHandler
+from synapse.handlers.pagination import PaginationHandler
from synapse.handlers.presence import PresenceHandler
+from synapse.handlers.profile import ProfileHandler
+from synapse.handlers.read_marker import ReadMarkerHandler
+from synapse.handlers.receipts import ReceiptsHandler
+from synapse.handlers.room import RoomContextHandler, RoomCreationHandler
from synapse.handlers.room_list import RoomListHandler
+from synapse.handlers.room_member import RoomMemberMasterHandler
+from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
+from synapse.handlers.set_password import SetPasswordHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
-from synapse.handlers.events import EventHandler, EventStreamHandler
-from synapse.handlers.initial_sync import InitialSyncHandler
-from synapse.handlers.receipts import ReceiptsHandler
-from synapse.handlers.read_marker import ReadMarkerHandler
-from synapse.handlers.user_directory import UserDirectoyHandler
-from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
+from synapse.handlers.user_directory import UserDirectoryHandler
+from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpClient
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier
from synapse.push.action_generator import ActionGenerator
from synapse.push.pusherpool import PusherPool
-from synapse.rest.media.v1.media_repository import MediaRepository
-from synapse.state import StateHandler
+from synapse.rest.media.v1.media_repository import (
+ MediaRepository,
+ MediaRepositoryResource,
+)
+from synapse.secrets import Secrets
+from synapse.server_notices.server_notices_manager import ServerNoticesManager
+from synapse.server_notices.server_notices_sender import ServerNoticesSender
+from synapse.server_notices.worker_server_notices_sender import WorkerServerNoticesSender
+from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import DataStore
from synapse.streams.events import EventSources
from synapse.util import Clock
@@ -82,21 +105,21 @@ class HomeServer(object):
which must be implemented by the subclass. This code may call any of the
required "get" methods on the instance to obtain the sub-dependencies that
one requires.
+
+ Attributes:
+ config (synapse.config.homeserver.HomeServerConfig):
"""
DEPENDENCIES = [
- 'config',
- 'clock',
'http_client',
'db_pool',
- 'persistence_service',
- 'replication_layer',
- 'datastore',
+ 'federation_client',
+ 'federation_server',
'handlers',
- 'v1auth',
'auth',
- 'rest_servlet_factory',
+ 'room_creation_handler',
'state_handler',
+ 'state_resolution_handler',
'presence_handler',
'sync_handler',
'typing_handler',
@@ -111,19 +134,12 @@ class HomeServer(object):
'application_service_scheduler',
'application_service_handler',
'device_message_handler',
+ 'profile_handler',
+ 'event_creation_handler',
+ 'deactivate_account_handler',
+ 'set_password_handler',
'notifier',
- 'distributor',
- 'client_resource',
- 'resource_for_federation',
- 'resource_for_static_content',
- 'resource_for_web_client',
- 'resource_for_content_repo',
- 'resource_for_server_key',
- 'resource_for_server_key_v2',
- 'resource_for_media_repository',
- 'resource_for_metrics',
'event_sources',
- 'ratelimiter',
'keyring',
'pusherpool',
'event_builder_factory',
@@ -131,6 +147,7 @@ class HomeServer(object):
'http_client_context_factory',
'simple_http_client',
'media_repository',
+ 'media_repository_resource',
'federation_transport_client',
'federation_sender',
'receipts_handler',
@@ -139,17 +156,34 @@ class HomeServer(object):
'read_marker_handler',
'action_generator',
'user_directory_handler',
+ 'groups_local_handler',
+ 'groups_server_handler',
+ 'groups_attestation_signing',
+ 'groups_attestation_renewer',
+ 'secrets',
+ 'spam_checker',
+ 'room_member_handler',
+ 'federation_registry',
+ 'server_notices_manager',
+ 'server_notices_sender',
+ 'message_handler',
+ 'pagination_handler',
+ 'room_context_handler',
]
- def __init__(self, hostname, **kwargs):
+ def __init__(self, hostname, reactor=None, **kwargs):
"""
Args:
hostname : The hostname for the server.
"""
+ if not reactor:
+ from twisted.internet import reactor
+
+ self._reactor = reactor
self.hostname = hostname
self._building = {}
- self.clock = Clock()
+ self.clock = Clock(reactor)
self.distributor = Distributor()
self.ratelimiter = Ratelimiter()
@@ -162,6 +196,12 @@ class HomeServer(object):
self.datastore = DataStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
+ def get_reactor(self):
+ """
+ Fetch the Twisted reactor in use by this HomeServer.
+ """
+ return self._reactor
+
def get_ip_from_request(self, request):
# X-Forwarded-For is handled by our custom request type.
return request.getClientIP()
@@ -172,8 +212,26 @@ class HomeServer(object):
def is_mine_id(self, string):
return string.split(":", 1)[1] == self.hostname
- def build_replication_layer(self):
- return initialize_http_replication(self)
+ def get_clock(self):
+ return self.clock
+
+ def get_datastore(self):
+ return self.datastore
+
+ def get_config(self):
+ return self.config
+
+ def get_distributor(self):
+ return self.distributor
+
+ def get_ratelimiter(self):
+ return self.ratelimiter
+
+ def build_federation_client(self):
+ return FederationClient(self)
+
+ def build_federation_server(self):
+ return FederationServer(self)
def build_handlers(self):
return Handlers(self)
@@ -194,18 +252,15 @@ class HomeServer(object):
def build_simple_http_client(self):
return SimpleHttpClient(self)
- def build_v1auth(self):
- orf = Auth(self)
- # Matrix spec makes no reference to what HTTP status code is returned,
- # but the V1 API uses 403 where it means 401, and the webclient
- # relies on this behaviour, so V1 gets its own copy of the auth
- # with backwards compat behaviour.
- orf.TOKEN_NOT_FOUND_HTTP_STATUS = 403
- return orf
+ def build_room_creation_handler(self):
+ return RoomCreationHandler(self)
def build_state_handler(self):
return StateHandler(self)
+ def build_state_resolution_handler(self):
+ return StateResolutionHandler(self)
+
def build_presence_handler(self):
return PresenceHandler(self)
@@ -222,7 +277,7 @@ class HomeServer(object):
return AuthHandler(self)
def build_macaroon_generator(self):
- return MacaroonGeneartor(self)
+ return MacaroonGenerator(self)
def build_device_handler(self):
return DeviceHandler(self)
@@ -251,6 +306,18 @@ class HomeServer(object):
def build_initial_sync_handler(self):
return InitialSyncHandler(self)
+ def build_profile_handler(self):
+ return ProfileHandler(self)
+
+ def build_event_creation_handler(self):
+ return EventCreationHandler(self)
+
+ def build_deactivate_account_handler(self):
+ return DeactivateAccountHandler(self)
+
+ def build_set_password_handler(self):
+ return SetPasswordHandler(self)
+
def build_event_sources(self):
return EventSources(self)
@@ -277,9 +344,32 @@ class HomeServer(object):
return adbapi.ConnectionPool(
name,
+ cp_reactor=self.get_reactor(),
**self.db_config.get("args", {})
)
+ def get_db_conn(self, run_new_connection=True):
+ """Makes a new connection to the database, skipping the db pool
+
+ Returns:
+ Connection: a connection object implementing the PEP-249 spec
+ """
+ # Any param beginning with cp_ is a parameter for adbapi, and should
+ # not be passed to the database engine.
+ db_params = {
+ k: v for k, v in self.db_config.get("args", {}).items()
+ if not k.startswith("cp_")
+ }
+ db_conn = self.database_engine.module.connect(**db_params)
+ if run_new_connection:
+ self.database_engine.on_new_connection(db_conn)
+ return db_conn
+
+ def build_media_repository_resource(self):
+ # build the media repo resource. This indirects through the HomeServer
+ # to ensure that we only have a single instance of the media repository.
+ return MediaRepositoryResource(self)
+
def build_media_repository(self):
return MediaRepository(self)
@@ -307,7 +397,52 @@ class HomeServer(object):
return ActionGenerator(self)
def build_user_directory_handler(self):
- return UserDirectoyHandler(self)
+ return UserDirectoryHandler(self)
+
+ def build_groups_local_handler(self):
+ return GroupsLocalHandler(self)
+
+ def build_groups_server_handler(self):
+ return GroupsServerHandler(self)
+
+ def build_groups_attestation_signing(self):
+ return GroupAttestationSigning(self)
+
+ def build_groups_attestation_renewer(self):
+ return GroupAttestionRenewer(self)
+
+ def build_secrets(self):
+ return Secrets()
+
+ def build_spam_checker(self):
+ return SpamChecker(self)
+
+ def build_room_member_handler(self):
+ if self.config.worker_app:
+ return RoomMemberWorkerHandler(self)
+ return RoomMemberMasterHandler(self)
+
+ def build_federation_registry(self):
+ return FederationHandlerRegistry()
+
+ def build_server_notices_manager(self):
+ if self.config.worker_app:
+ raise Exception("Workers cannot send server notices")
+ return ServerNoticesManager(self)
+
+ def build_server_notices_sender(self):
+ if self.config.worker_app:
+ return WorkerServerNoticesSender(self)
+ return ServerNoticesSender(self)
+
+ def build_message_handler(self):
+ return MessageHandler(self)
+
+ def build_pagination_handler(self):
+ return PaginationHandler(self)
+
+ def build_room_context_handler(self):
+ return RoomContextHandler(self)
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
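Everything in DEPENDENCIES is constructed lazily: get_foo() invokes build_foo() once and memoises the result (the _building dict initialised in __init__ appears to guard re-entrant construction). A stripped-down sketch of that pattern, not the actual metaprogramming in this file:

class LazyDeps(object):
    def __init__(self):
        self._cache = {}

    def _get(self, name):
        # build on first access, then reuse the single instance
        if name not in self._cache:
            self._cache[name] = getattr(self, "build_%s" % (name,))()
        return self._cache[name]

    def build_greeter(self):
        return lambda who: "hello %s" % (who,)

    def get_greeter(self):
        return self._get("greeter")


deps = LazyDeps()
assert deps.get_greeter() is deps.get_greeter()  # one instance per container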
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 9570df5537..ce28486233 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -1,12 +1,25 @@
import synapse.api.auth
+import synapse.config.homeserver
+import synapse.federation.transaction_queue
+import synapse.federation.transport.client
import synapse.handlers
import synapse.handlers.auth
+import synapse.handlers.deactivate_account
import synapse.handlers.device
import synapse.handlers.e2e_keys
-import synapse.storage
+import synapse.handlers.set_password
+import synapse.rest.media.v1.media_repository
+import synapse.server_notices.server_notices_manager
+import synapse.server_notices.server_notices_sender
import synapse.state
+import synapse.storage
+
class HomeServer(object):
+ @property
+ def config(self) -> synapse.config.homeserver.HomeServerConfig:
+ pass
+
def get_auth(self) -> synapse.api.auth.Auth:
pass
@@ -27,3 +40,36 @@ class HomeServer(object):
def get_state_handler(self) -> synapse.state.StateHandler:
pass
+
+ def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
+ pass
+
+ def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
+ pass
+
+ def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler:
+ pass
+
+ def get_event_creation_handler(self) -> synapse.handlers.message.EventCreationHandler:
+ pass
+
+ def get_set_password_handler(self) -> synapse.handlers.set_password.SetPasswordHandler:
+ pass
+
+ def get_federation_sender(self) -> synapse.federation.transaction_queue.TransactionQueue:
+ pass
+
+ def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient:
+ pass
+
+ def get_media_repository_resource(self) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource:
+ pass
+
+ def get_media_repository(self) -> synapse.rest.media.v1.media_repository.MediaRepository:
+ pass
+
+ def get_server_notices_manager(self) -> synapse.server_notices.server_notices_manager.ServerNoticesManager:
+ pass
+
+ def get_server_notices_sender(self) -> synapse.server_notices.server_notices_sender.ServerNoticesSender:
+ pass
diff --git a/synapse/server_notices/__init__.py b/synapse/server_notices/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/synapse/server_notices/__init__.py
diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py
new file mode 100644
index 0000000000..5e3044d164
--- /dev/null
+++ b/synapse/server_notices/consent_server_notices.py
@@ -0,0 +1,138 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from six import iteritems, string_types
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.api.urls import ConsentURIBuilder
+from synapse.config import ConfigError
+from synapse.types import get_localpart_from_id
+
+logger = logging.getLogger(__name__)
+
+
+class ConsentServerNotices(object):
+ """Keeps track of whether we need to send users server_notices about
+ privacy policy consent, and sends one if we do.
+ """
+ def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer):
+ """
+ self._server_notices_manager = hs.get_server_notices_manager()
+ self._store = hs.get_datastore()
+
+ self._users_in_progress = set()
+
+ self._current_consent_version = hs.config.user_consent_version
+ self._server_notice_content = hs.config.user_consent_server_notice_content
+ self._send_to_guests = hs.config.user_consent_server_notice_to_guests
+
+ if self._server_notice_content is not None:
+ if not self._server_notices_manager.is_enabled():
+ raise ConfigError(
+ "user_consent configuration requires server notices, but "
+ "server notices are not enabled.",
+ )
+ if 'body' not in self._server_notice_content:
+ raise ConfigError(
+ "user_consent server_notice_consent must contain a 'body' "
+ "key.",
+ )
+
+ self._consent_uri_builder = ConsentURIBuilder(hs.config)
+
+ @defer.inlineCallbacks
+ def maybe_send_server_notice_to_user(self, user_id):
+ """Check if we need to send a notice to this user, and does so if so
+
+ Args:
+ user_id (str): user to check
+
+ Returns:
+ Deferred
+ """
+ if self._server_notice_content is None:
+ # not enabled
+ return
+
+ # make sure we don't send two messages to the same user at once
+ if user_id in self._users_in_progress:
+ return
+ self._users_in_progress.add(user_id)
+ try:
+ u = yield self._store.get_user_by_id(user_id)
+
+ if u["is_guest"] and not self._send_to_guests:
+ # don't send to guests
+ return
+
+ if u["consent_version"] == self._current_consent_version:
+ # user has already consented
+ return
+
+ if u["consent_server_notice_sent"] == self._current_consent_version:
+ # we've already sent a notice to the user
+ return
+
+ # need to send a message.
+ try:
+ consent_uri = self._consent_uri_builder.build_user_consent_uri(
+ get_localpart_from_id(user_id),
+ )
+ content = copy_with_str_subst(
+ self._server_notice_content, {
+ 'consent_uri': consent_uri,
+ },
+ )
+ yield self._server_notices_manager.send_notice(
+ user_id, content,
+ )
+ yield self._store.user_set_consent_server_notice_sent(
+ user_id, self._current_consent_version,
+ )
+ except SynapseError as e:
+ logger.error("Error sending server notice about user consent: %s", e)
+ finally:
+ self._users_in_progress.remove(user_id)
+
+
+def copy_with_str_subst(x, substitutions):
+ """Deep-copy a structure, carrying out string substitions on any strings
+
+ Args:
+ x (object): structure to be copied
+ substitutions (object): substitutions to be made - passed into the
+ string '%' operator
+
+ Returns:
+ copy of x
+ """
+ if isinstance(x, string_types):
+ return x % substitutions
+ if isinstance(x, dict):
+ return {
+ k: copy_with_str_subst(v, substitutions) for (k, v) in iteritems(x)
+ }
+ if isinstance(x, (list, tuple)):
+ return [copy_with_str_subst(y, substitutions) for y in x]
+
+ # assume it's uninteresting and can be shallow-copied.
+ return x
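An example of the substitution this performs (with the list branch passing substitutions through, as fixed above): the configured notice content, which must carry a 'body' key per the ConfigError check earlier, has its %(consent_uri)s placeholder filled in per user. The URI below is illustrative only:

content = copy_with_str_subst(
    {"msgtype": "m.text", "body": "Please review %(consent_uri)s"},
    {"consent_uri": "https://example.com/_matrix/consent?u=alice"},
)
# content["body"] is now
# "Please review https://example.com/_matrix/consent?u=alice"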
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
new file mode 100644
index 0000000000..a26deace53
--- /dev/null
+++ b/synapse/server_notices/server_notices_manager.py
@@ -0,0 +1,146 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
+from synapse.types import create_requester
+from synapse.util.caches.descriptors import cachedInlineCallbacks
+
+logger = logging.getLogger(__name__)
+
+
+class ServerNoticesManager(object):
+ def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer):
+ """
+
+ self._store = hs.get_datastore()
+ self._config = hs.config
+ self._room_creation_handler = hs.get_room_creation_handler()
+ self._event_creation_handler = hs.get_event_creation_handler()
+ self._is_mine_id = hs.is_mine_id
+
+ def is_enabled(self):
+ """Checks if server notices are enabled on this server.
+
+ Returns:
+ bool
+ """
+ return self._config.server_notices_mxid is not None
+
+ @defer.inlineCallbacks
+ def send_notice(self, user_id, event_content):
+ """Send a notice to the given user
+
+ Creates the server notices room, if none exists.
+
+ Args:
+ user_id (str): mxid of user to send event to.
+ event_content (dict): content of event to send
+
+ Returns:
+ Deferred[None]
+ """
+ room_id = yield self.get_notice_room_for_user(user_id)
+
+ system_mxid = self._config.server_notices_mxid
+ requester = create_requester(system_mxid)
+
+ logger.info("Sending server notice to %s", user_id)
+
+ yield self._event_creation_handler.create_and_send_nonmember_event(
+ requester, {
+ "type": EventTypes.Message,
+ "room_id": room_id,
+ "sender": system_mxid,
+ "content": event_content,
+ },
+ ratelimit=False,
+ )
+
+ @cachedInlineCallbacks()
+ def get_notice_room_for_user(self, user_id):
+ """Get the room for notices for a given user
+
+ If we have not yet created a notice room for this user, create it
+
+ Args:
+ user_id (str): complete user id for the user we want a room for
+
+ Returns:
+ str: room id of notice room.
+ """
+ if not self.is_enabled():
+ raise Exception("Server notices not enabled")
+
+ assert self._is_mine_id(user_id), \
+ "Cannot send server notices to remote users"
+
+ rooms = yield self._store.get_rooms_for_user_where_membership_is(
+ user_id, [Membership.INVITE, Membership.JOIN],
+ )
+ system_mxid = self._config.server_notices_mxid
+ for room in rooms:
+ # it's worth noting that there is an asymmetry here in that we
+ # expect the user to be invited or joined, but the system user must
+ # be joined. This is kinda deliberate, in that if somebody somehow
+ # manages to invite the system user to a room, that doesn't make it
+ # the server notices room.
+ user_ids = yield self._store.get_users_in_room(room.room_id)
+ if system_mxid in user_ids:
+ # we found a room which our user shares with the system notice
+ # user
+ logger.info("Using room %s", room.room_id)
+ defer.returnValue(room.room_id)
+
+ # apparently no existing notice room: create a new one
+ logger.info("Creating server notices room for %s", user_id)
+
+ # see if we want to override the profile info for the server user.
+ # note that if we want to override either the display name or the
+ # avatar, we have to use both.
+ join_profile = None
+ if (
+ self._config.server_notices_mxid_display_name is not None or
+ self._config.server_notices_mxid_avatar_url is not None
+ ):
+ join_profile = {
+ "displayname": self._config.server_notices_mxid_display_name,
+ "avatar_url": self._config.server_notices_mxid_avatar_url,
+ }
+
+ requester = create_requester(system_mxid)
+ info = yield self._room_creation_handler.create_room(
+ requester,
+ config={
+ "preset": RoomCreationPreset.PRIVATE_CHAT,
+ "name": self._config.server_notices_room_name,
+ "power_level_content_override": {
+ "users_default": -10,
+ },
+ "invite": (user_id,)
+ },
+ ratelimit=False,
+ creator_join_profile=join_profile,
+ )
+ room_id = info['room_id']
+
+ logger.info("Created server notices room %s for %s", room_id, user_id)
+ defer.returnValue(room_id)
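A usage sketch, assuming hs.get_server_notices_manager() (which exists via the DEPENDENCIES list in synapse/server.py): callers check is_enabled() and then send ordinary m.room.message-style content; the body text here is invented:

from twisted.internet import defer


@defer.inlineCallbacks
def notify_user(hs, user_id):
    manager = hs.get_server_notices_manager()
    if not manager.is_enabled():
        return
    yield manager.send_notice(
        user_id,
        {"msgtype": "m.text", "body": "Scheduled maintenance at 02:00 UTC"},
    )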
diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py
new file mode 100644
index 0000000000..5d23965f34
--- /dev/null
+++ b/synapse/server_notices/server_notices_sender.py
@@ -0,0 +1,58 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.server_notices.consent_server_notices import ConsentServerNotices
+
+
+class ServerNoticesSender(object):
+ """A centralised place which sends server notices automatically when
+ Certain Events take place
+ """
+ def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer):
+ """
+ # todo: it would be nice to make this more dynamic
+ self._consent_server_notices = ConsentServerNotices(hs)
+
+ def on_user_syncing(self, user_id):
+ """Called when the user performs a sync operation.
+
+ Args:
+ user_id (str): mxid of user who synced
+
+ Returns:
+ Deferred
+ """
+ return self._consent_server_notices.maybe_send_server_notice_to_user(
+ user_id,
+ )
+
+ def on_user_ip(self, user_id):
+ """Called on the master when a worker process saw a client request.
+
+ Args:
+ user_id (str): mxid
+
+ Returns:
+ Deferred
+ """
+ # The synchrotrons use a stubbed version of ServerNoticesSender, so
+ # we check for notices to send to the user in on_user_ip as well as
+ # in on_user_syncing
+ return self._consent_server_notices.maybe_send_server_notice_to_user(
+ user_id,
+ )
diff --git a/synapse/server_notices/worker_server_notices_sender.py b/synapse/server_notices/worker_server_notices_sender.py
new file mode 100644
index 0000000000..4a133026c3
--- /dev/null
+++ b/synapse/server_notices/worker_server_notices_sender.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+
+class WorkerServerNoticesSender(object):
+ """Stub impl of ServerNoticesSender which does nothing"""
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer):
+ """
+
+ def on_user_syncing(self, user_id):
+ """Called when the user performs a sync operation.
+
+ Args:
+ user_id (str): mxid of user who synced
+
+ Returns:
+ Deferred
+ """
+ return defer.succeed(None)
+
+ def on_user_ip(self, user_id):
+ """Called on the master when a worker process saw a client request.
+
+ Args:
+ user_id (str): mxid
+
+ Returns:
+ Deferred
+ """
+ raise AssertionError("on_user_ip unexpectedly called on worker")
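
A hedged sketch of how the two implementations above divide up (the make_server_notices_sender factory and is_worker flag are illustrative; the real wiring lives in synapse.server.HomeServer):

from synapse.server_notices.server_notices_sender import ServerNoticesSender
from synapse.server_notices.worker_server_notices_sender import (
    WorkerServerNoticesSender,
)

def make_server_notices_sender(hs, is_worker):
    if is_worker:
        # workers get the stub: on_user_syncing is a no-op Deferred, and
        # on_user_ip asserts, because user-IP events are handled on the master
        return WorkerServerNoticesSender(hs)
    return ServerNoticesSender(hs)
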
diff --git a/synapse/state.py b/synapse/state.py
index 390799fbd5..033f55d967 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -14,23 +14,25 @@
# limitations under the License.
+import hashlib
+import logging
+from collections import namedtuple
+
+from six import iteritems, iterkeys, itervalues
+
+from frozendict import frozendict
+
from twisted.internet import defer
from synapse import event_auth
-from synapse.util.logutils import log_function
-from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.events.snapshot import EventContext
from synapse.util.async import Linearizer
from synapse.util.caches import CACHE_SIZE_FACTOR
-
-from collections import namedtuple
-from frozendict import frozendict
-
-import logging
-import hashlib
+from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.logutils import log_function
+from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -58,7 +60,11 @@ class _StateCacheEntry(object):
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
def __init__(self, state, state_group, prev_group=None, delta_ids=None):
+ # dict[(str, str), str] map from (type, state_key) to event_id
self.state = frozendict(state)
+
+ # the ID of a state group if one and only one is involved;
+ # otherwise, None.
self.state_group = state_group
self.prev_group = prev_group
@@ -81,31 +87,19 @@ class _StateCacheEntry(object):
class StateHandler(object):
- """ Responsible for doing state conflict resolution.
+ """Fetches bits of state from the stores, and does state resolution
+ where necessary
"""
def __init__(self, hs):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.hs = hs
-
- # dict of set of event_ids -> _StateCacheEntry.
- self._state_cache = None
- self.resolve_linearizer = Linearizer(name="state_resolve_lock")
+ self._state_resolution_handler = hs.get_state_resolution_handler()
def start_caching(self):
- logger.debug("start_caching")
-
- self._state_cache = ExpiringCache(
- cache_name="state_cache",
- clock=self.clock,
- max_len=SIZE_OF_CACHE,
- expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
- iterable=True,
- reset_expiry_on_get=True,
- )
-
- self._state_cache.start()
+ # TODO: remove this shim
+ self._state_resolution_handler.start_caching()
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key="",
@@ -127,7 +121,7 @@ class StateHandler(object):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state")
- ret = yield self.resolve_state_groups(room_id, latest_event_ids)
+ ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
if event_type:
@@ -138,27 +132,36 @@ class StateHandler(object):
defer.returnValue(event)
return
- state_map = yield self.store.get_events(state.values(), get_prev_content=False)
+ state_map = yield self.store.get_events(list(state.values()),
+ get_prev_content=False)
state = {
- key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
+ key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map
}
defer.returnValue(state)
@defer.inlineCallbacks
- def get_current_state_ids(self, room_id, event_type=None, state_key="",
- latest_event_ids=None):
+ def get_current_state_ids(self, room_id, latest_event_ids=None):
+ """Get the current state, or the state at a set of events, for a room
+
+ Args:
+ room_id (str):
+
+ latest_event_ids (iterable[str]|None): if given, the forward
+ extremities to resolve. If None, we look them up from the
+ database (via a cache)
+
+ Returns:
+ Deferred[dict[(str, str), str]]: the state dict, mapping from
+ (event_type, state_key) -> event_id
+ """
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state_ids")
- ret = yield self.resolve_state_groups(room_id, latest_event_ids)
+ ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
- if event_type:
- defer.returnValue(state.get((event_type, state_key)))
- return
-
defer.returnValue(state)
@defer.inlineCallbacks
@@ -166,7 +169,7 @@ class StateHandler(object):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_user_in_room")
- entry = yield self.resolve_state_groups(room_id, latest_event_ids)
+ entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
defer.returnValue(joined_users)
@@ -175,7 +178,7 @@ class StateHandler(object):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
- entry = yield self.resolve_state_groups(room_id, latest_event_ids)
+ entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
defer.returnValue(joined_hosts)
@@ -183,8 +186,15 @@ class StateHandler(object):
def compute_event_context(self, event, old_state=None):
"""Build an EventContext structure for the event.
+ This works out what the current state should be for the event, and
+ generates a new state group if necessary.
+
Args:
event (synapse.events.EventBase):
+ old_state (dict|None): The state at the event if it can't be
+ calculated from existing events. This is normally only specified
+ when receiving an event over federation for which we don't have
+ the prev events, e.g. when backfilling.
Returns:
synapse.events.snapshot.EventContext:
"""
@@ -193,113 +203,158 @@ class StateHandler(object):
# If this is an outlier, then we know it shouldn't have any current
# state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group.
- context = EventContext()
if old_state:
- context.prev_state_ids = {
+ prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state
}
if event.is_state():
- context.current_state_ids = dict(context.prev_state_ids)
+ current_state_ids = dict(prev_state_ids)
key = (event.type, event.state_key)
- context.current_state_ids[key] = event.event_id
+ current_state_ids[key] = event.event_id
else:
- context.current_state_ids = context.prev_state_ids
+ current_state_ids = prev_state_ids
else:
- context.current_state_ids = {}
- context.prev_state_ids = {}
- context.prev_state_events = []
- context.state_group = self.store.get_next_state_group()
+ current_state_ids = {}
+ prev_state_ids = {}
+
+ # We don't store state for outliers, so we don't generate a state
+ # group for it.
+ context = EventContext.with_state(
+ state_group=None,
+ current_state_ids=current_state_ids,
+ prev_state_ids=prev_state_ids,
+ )
+
defer.returnValue(context)
if old_state:
- context = EventContext()
- context.prev_state_ids = {
+ # We already have the state, so we don't need to calculate it.
+ # Let's just correctly fill out the context and create a
+ # new state group for it.
+
+ prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state
}
- context.state_group = self.store.get_next_state_group()
if event.is_state():
key = (event.type, event.state_key)
- if key in context.prev_state_ids:
- replaces = context.prev_state_ids[key]
+ if key in prev_state_ids:
+ replaces = prev_state_ids[key]
if replaces != event.event_id: # Paranoia check
event.unsigned["replaces_state"] = replaces
- context.current_state_ids = dict(context.prev_state_ids)
- context.current_state_ids[key] = event.event_id
+ current_state_ids = dict(prev_state_ids)
+ current_state_ids[key] = event.event_id
else:
- context.current_state_ids = context.prev_state_ids
+ current_state_ids = prev_state_ids
+
+ state_group = yield self.store.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=None,
+ delta_ids=None,
+ current_state_ids=current_state_ids,
+ )
+
+ context = EventContext.with_state(
+ state_group=state_group,
+ current_state_ids=current_state_ids,
+ prev_state_ids=prev_state_ids,
+ )
- context.prev_state_events = []
defer.returnValue(context)
logger.debug("calling resolve_state_groups from compute_event_context")
- entry = yield self.resolve_state_groups(
+ entry = yield self.resolve_state_groups_for_events(
event.room_id, [e for e, _ in event.prev_events],
)
- curr_state = entry.state
+ prev_state_ids = entry.state
+ prev_group = None
+ delta_ids = None
- context = EventContext()
- context.prev_state_ids = curr_state
if event.is_state():
- context.state_group = self.store.get_next_state_group()
+ # If this is a state event then we need to create a new state
+ # group for the state after this event.
key = (event.type, event.state_key)
- if key in context.prev_state_ids:
- replaces = context.prev_state_ids[key]
+ if key in prev_state_ids:
+ replaces = prev_state_ids[key]
event.unsigned["replaces_state"] = replaces
- context.current_state_ids = dict(context.prev_state_ids)
- context.current_state_ids[key] = event.event_id
+ current_state_ids = dict(prev_state_ids)
+ current_state_ids[key] = event.event_id
if entry.state_group:
- context.prev_group = entry.state_group
- context.delta_ids = {
+ # If the state at the event has a state group assigned then
+ # we can use that as the prev group
+ prev_group = entry.state_group
+ delta_ids = {
key: event.event_id
}
elif entry.prev_group:
- context.prev_group = entry.prev_group
- context.delta_ids = dict(entry.delta_ids)
- context.delta_ids[key] = event.event_id
+ # If the state at the event only has a prev group, then we can
+ # use that as a prev group too.
+ prev_group = entry.prev_group
+ delta_ids = dict(entry.delta_ids)
+ delta_ids[key] = event.event_id
+
+ state_group = yield self.store.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=prev_group,
+ delta_ids=delta_ids,
+ current_state_ids=current_state_ids,
+ )
else:
+ current_state_ids = prev_state_ids
+ prev_group = entry.prev_group
+ delta_ids = entry.delta_ids
+
if entry.state_group is None:
- entry.state_group = self.store.get_next_state_group()
+ entry.state_group = yield self.store.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=entry.prev_group,
+ delta_ids=entry.delta_ids,
+ current_state_ids=current_state_ids,
+ )
entry.state_id = entry.state_group
- context.state_group = entry.state_group
- context.current_state_ids = context.prev_state_ids
- context.prev_group = entry.prev_group
- context.delta_ids = entry.delta_ids
+ state_group = entry.state_group
+
+ context = EventContext.with_state(
+ state_group=state_group,
+ current_state_ids=current_state_ids,
+ prev_state_ids=prev_state_ids,
+ prev_group=prev_group,
+ delta_ids=delta_ids,
+ )
- context.prev_state_events = []
defer.returnValue(context)
@defer.inlineCallbacks
- @log_function
- def resolve_state_groups(self, room_id, event_ids):
+ def resolve_state_groups_for_events(self, room_id, event_ids):
""" Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
+ Args:
+ room_id (str):
+ event_ids (list[str]):
+
Returns:
- a Deferred tuple of (`state_group`, `state`, `prev_state`).
- `state_group` is the name of a state group if one and only one is
- involved. `state` is a map from (type, state_key) to event, and
- `prev_state` is a list of event ids.
+ Deferred[_StateCacheEntry]: resolved state
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
+ # map from state group id to the state in that state group (where
+ # 'state' is a map from state key to event id)
+ # dict[int, dict[(str, str), str]]
state_groups_ids = yield self.store.get_state_groups_ids(
room_id, event_ids
)
- logger.debug(
- "resolve_state_groups state_groups %s",
- state_groups_ids.keys()
- )
-
- group_names = frozenset(state_groups_ids.keys())
- if len(group_names) == 1:
- name, state_list = state_groups_ids.items().pop()
+ if len(state_groups_ids) == 1:
+ name, state_list = list(state_groups_ids.items()).pop()
prev_group, delta_ids = yield self.store.get_state_group_delta(name)
@@ -310,6 +365,102 @@ class StateHandler(object):
delta_ids=delta_ids,
))
+ result = yield self._state_resolution_handler.resolve_state_groups(
+ room_id, state_groups_ids, None, self._state_map_factory,
+ )
+ defer.returnValue(result)
+
+ def _state_map_factory(self, ev_ids):
+ return self.store.get_events(
+ ev_ids, get_prev_content=False, check_redacted=False,
+ )
+
+ def resolve_events(self, state_sets, event):
+ logger.info(
+ "Resolving state for %s with %d groups", event.room_id, len(state_sets)
+ )
+ state_set_ids = [{
+ (ev.type, ev.state_key): ev.event_id
+ for ev in st
+ } for st in state_sets]
+
+ state_map = {
+ ev.event_id: ev
+ for st in state_sets
+ for ev in st
+ }
+
+ with Measure(self.clock, "state._resolve_events"):
+ new_state = resolve_events_with_state_map(state_set_ids, state_map)
+
+ new_state = {
+ key: state_map[ev_id] for key, ev_id in iteritems(new_state)
+ }
+
+ return new_state
+
+
+class StateResolutionHandler(object):
+ """Responsible for doing state conflict resolution.
+
+ Note that the storage layer depends on this handler, so all functions must
+ be storage-independent.
+ """
+ def __init__(self, hs):
+ self.clock = hs.get_clock()
+
+ # dict of set of event_ids -> _StateCacheEntry.
+ self._state_cache = None
+ self.resolve_linearizer = Linearizer(name="state_resolve_lock")
+
+ def start_caching(self):
+ logger.debug("start_caching")
+
+ self._state_cache = ExpiringCache(
+ cache_name="state_cache",
+ clock=self.clock,
+ max_len=SIZE_OF_CACHE,
+ expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
+ iterable=True,
+ reset_expiry_on_get=True,
+ )
+
+ self._state_cache.start()
+
+ @defer.inlineCallbacks
+ @log_function
+ def resolve_state_groups(
+ self, room_id, state_groups_ids, event_map, state_map_factory,
+ ):
+ """Resolves conflicts between a set of state groups
+
+ Always generates a new state group (unless we hit the cache), so should
+ not be called for a single state group
+
+ Args:
+ room_id (str): room we are resolving for (used for logging)
+ state_groups_ids (dict[int, dict[(str, str), str]]):
+ map from state group id to the state in that state group
+ (where 'state' is a map from state key to event id)
+
+ event_map(dict[str,FrozenEvent]|None):
+ a dict from event_id to event, for any events that we happen to
+ have in flight (eg, those currently being persisted). This will be
+ used as a starting point for finding the state we need; any missing
+ events will be requested via state_map_factory.
+
+ If None, all events will be fetched via state_map_factory.
+
+ Returns:
+ Deferred[_StateCacheEntry]: resolved state
+ """
+ logger.debug(
+ "resolve_state_groups state_groups %s",
+ state_groups_ids.keys()
+ )
+
+ group_names = frozenset(state_groups_ids.keys())
+
with (yield self.resolve_linearizer.queue(group_names)):
if self._state_cache is not None:
cache = self._state_cache.get(group_names, None)
@@ -320,112 +471,128 @@ class StateHandler(object):
"Resolving state for %s with %d groups", room_id, len(state_groups_ids)
)
- state = {}
- for st in state_groups_ids.values():
- for key, e_id in st.items():
- state.setdefault(key, set()).add(e_id)
-
- conflicted_state = {
- k: list(v)
- for k, v in state.items()
- if len(v) > 1
- }
+ # start by assuming we won't have any conflicted state, and build up the new
+ # state map by iterating through the state groups. If we discover a conflict,
+ # we give up and instead use `resolve_events_with_factory`.
+ #
+ # XXX: is this actually worthwhile, or should we just let
+ # resolve_events_with_factory do it?
+ new_state = {}
+ conflicted_state = False
+ for st in itervalues(state_groups_ids):
+ for key, e_id in iteritems(st):
+ if key in new_state:
+ conflicted_state = True
+ break
+ new_state[key] = e_id
+ if conflicted_state:
+ break
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
- new_state = yield resolve_events(
- state_groups_ids.values(),
- state_map_factory=lambda ev_ids: self.store.get_events(
- ev_ids, get_prev_content=False, check_redacted=False,
- ),
+ new_state = yield resolve_events_with_factory(
+ list(itervalues(state_groups_ids)),
+ event_map=event_map,
+ state_map_factory=state_map_factory,
)
- else:
- new_state = {
- key: e_ids.pop() for key, e_ids in state.items()
- }
- state_group = None
- new_state_event_ids = frozenset(new_state.values())
- for sg, events in state_groups_ids.items():
- if new_state_event_ids == frozenset(e_id for e_id in events):
- state_group = sg
- break
+ # if the new state matches any of the input state groups, we can
+ # use that state group again. Otherwise we will generate a state_id
+ # which will be used as a cache key for future resolutions, but
+ # not get persisted.
- # TODO: We want to create a state group for this set of events, to
- # increase cache hits, but we need to make sure that it doesn't
- # end up as a prev_group without being added to the database
-
- prev_group = None
- delta_ids = None
- for old_group, old_ids in state_groups_ids.iteritems():
- if not set(new_state) - set(old_ids):
- n_delta_ids = {
- k: v
- for k, v in new_state.iteritems()
- if old_ids.get(k) != v
- }
- if not delta_ids or len(n_delta_ids) < len(delta_ids):
- prev_group = old_group
- delta_ids = n_delta_ids
-
- cache = _StateCacheEntry(
- state=new_state,
- state_group=state_group,
- prev_group=prev_group,
- delta_ids=delta_ids,
- )
+ with Measure(self.clock, "state.create_group_ids"):
+ cache = _make_state_cache_entry(new_state, state_groups_ids)
if self._state_cache is not None:
self._state_cache[group_names] = cache
defer.returnValue(cache)
- def resolve_events(self, state_sets, event):
- logger.info(
- "Resolving state for %s with %d groups", event.room_id, len(state_sets)
- )
- state_set_ids = [{
- (ev.type, ev.state_key): ev.event_id
- for ev in st
- } for st in state_sets]
- state_map = {
- ev.event_id: ev
- for st in state_sets
- for ev in st
- }
+def _make_state_cache_entry(
+ new_state,
+ state_groups_ids,
+):
+ """Given a resolved state, and a set of input state groups, pick one to base
+ a new state group on (if any), and return an appropriately-constructed
+ _StateCacheEntry.
- with Measure(self.clock, "state._resolve_events"):
- new_state = resolve_events(state_set_ids, state_map)
+ Args:
+ new_state (dict[(str, str), str]): resolved state map (mapping from
+ (type, state_key) to event_id)
- new_state = {
- key: state_map[ev_id] for key, ev_id in new_state.items()
- }
+ state_groups_ids (dict[int, dict[(str, str), str]]):
+ map from state group id to the state in that state group
+ (where 'state' is a map from state key to event id)
- return new_state
+ Returns:
+ _StateCacheEntry
+ """
+ # if the new state matches any of the input state groups, we can
+ # use that state group again. Otherwise we will generate a state_id
+ # which will be used as a cache key for future resolutions, but
+ # not get persisted.
+
+ # first look for exact matches
+ new_state_event_ids = set(itervalues(new_state))
+ for sg, state in iteritems(state_groups_ids):
+ if len(new_state_event_ids) != len(state):
+ continue
+
+ old_state_event_ids = set(itervalues(state))
+ if new_state_event_ids == old_state_event_ids:
+ # got an exact match.
+ return _StateCacheEntry(
+ state=new_state,
+ state_group=sg,
+ )
+
+ # TODO: We want to create a state group for this set of events, to
+ # increase cache hits, but we need to make sure that it doesn't
+ # end up as a prev_group without being added to the database
+
+ # failing that, look for the closest match.
+ prev_group = None
+ delta_ids = None
+
+ for old_group, old_state in iteritems(state_groups_ids):
+ n_delta_ids = {
+ k: v
+ for k, v in iteritems(new_state)
+ if old_state.get(k) != v
+ }
+ if not delta_ids or len(n_delta_ids) < len(delta_ids):
+ prev_group = old_group
+ delta_ids = n_delta_ids
+
+ return _StateCacheEntry(
+ state=new_state,
+ state_group=None,
+ prev_group=prev_group,
+ delta_ids=delta_ids,
+ )
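
A standalone sketch of the closest-match search performed by _make_state_cache_entry (plain dicts, hypothetical ids): compute the delta from each input group to the resolved state and keep the smallest one.

def closest_prev_group(new_state, state_groups_ids):
    prev_group, delta_ids = None, None
    for old_group, old_state in state_groups_ids.items():
        # entries in new_state that this group does not already have
        n_delta_ids = {
            k: v for k, v in new_state.items() if old_state.get(k) != v
        }
        if delta_ids is None or len(n_delta_ids) < len(delta_ids):
            prev_group, delta_ids = old_group, n_delta_ids
    return prev_group, delta_ids

new_state = {("m.room.name", ""): "$n1", ("m.room.topic", ""): "$t1"}
groups = {
    1: {("m.room.name", ""): "$n0", ("m.room.topic", ""): "$t0"},  # delta 2
    2: {("m.room.name", ""): "$n1", ("m.room.topic", ""): "$t0"},  # delta 1
}
prev_group, delta_ids = closest_prev_group(new_state, groups)
assert prev_group == 2 and delta_ids == {("m.room.topic", ""): "$t1"}
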
def _ordered_events(events):
def key_func(e):
- return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
+ return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest()
return sorted(events, key=key_func)
-def resolve_events(state_sets, state_map_factory):
+def resolve_events_with_state_map(state_sets, state_map):
"""
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
- state_map_factory(dict|callable): If callable, then will be called
- with a list of event_ids that are needed, and should return with
- a Deferred of dict of event_id to event. Otherwise, should be
- a dict from event_id to event of all events in state_sets.
+ state_map(dict): a dict from event_id to event, for all events in
+ state_sets.
Returns
- dict[(str, str), synapse.events.FrozenEvent] is a map from
- (type, state_key) to event.
+ dict[(str, str), str]:
+ a map from (type, state_key) to event_id.
"""
if len(state_sets) == 1:
return state_sets[0]
@@ -434,13 +601,6 @@ def resolve_events(state_sets, state_map_factory):
state_sets,
)
- if callable(state_map_factory):
- return _resolve_with_state_fac(
- unconflicted_state, conflicted_state, state_map_factory
- )
-
- state_map = state_map_factory
-
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
@@ -454,12 +614,28 @@ def _seperate(state_sets):
"""Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated
with them in different state sets.
+
+ Args:
+ state_sets(iterable[dict[(str, str), str]]):
+ List of dicts of (type, state_key) -> event_id, which are the
+ different state groups to resolve.
+
+ Returns:
+ (dict[(str, str), str], dict[(str, str), set[str]]):
+ A tuple of (unconflicted_state, conflicted_state), where:
+
+ unconflicted_state is a dict mapping (type, state_key)->event_id
+ for unconflicted state keys.
+
+ conflicted_state is a dict mapping (type, state_key) to a set of
+ event ids for conflicted state keys.
"""
- unconflicted_state = dict(state_sets[0])
+ state_set_iterator = iter(state_sets)
+ unconflicted_state = dict(next(state_set_iterator))
conflicted_state = {}
- for state_set in state_sets[1:]:
- for key, value in state_set.iteritems():
+ for state_set in state_set_iterator:
+ for key, value in iteritems(state_set):
# Check if there is an unconflicted entry for the state key.
unconflicted_value = unconflicted_state.get(key)
if unconflicted_value is None:
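
A simplified, runnable illustration of the split that _seperate performs (it glosses over the exact handling of keys that are missing from some of the state sets): a key is unconflicted if every state set that mentions it agrees on the event id.

def split_state_sets(state_sets):
    values_by_key = {}
    for state_set in state_sets:
        for key, value in state_set.items():
            values_by_key.setdefault(key, set()).add(value)
    unconflicted = {
        key: next(iter(values))
        for key, values in values_by_key.items() if len(values) == 1
    }
    conflicted = {
        key: values
        for key, values in values_by_key.items() if len(values) > 1
    }
    return unconflicted, conflicted

state_sets = [
    {("m.room.name", ""): "$n1", ("m.room.topic", ""): "$t1"},
    {("m.room.name", ""): "$n1", ("m.room.topic", ""): "$t2"},
]
unconflicted, conflicted = split_state_sets(state_sets)
assert unconflicted == {("m.room.name", ""): "$n1"}
assert conflicted == {("m.room.topic", ""): {"$t1", "$t2"}}
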
@@ -484,24 +660,63 @@ def _seperate(state_sets):
@defer.inlineCallbacks
-def _resolve_with_state_fac(unconflicted_state, conflicted_state,
- state_map_factory):
+def resolve_events_with_factory(state_sets, event_map, state_map_factory):
+ """
+ Args:
+ state_sets(list): List of dicts of (type, state_key) -> event_id,
+ which are the different state groups to resolve.
+
+ event_map(dict[str,FrozenEvent]|None):
+ a dict from event_id to event, for any events that we happen to
+ have in flight (eg, those currently being persisted). This will be
+ used as a starting point for finding the state we need; any missing
+ events will be requested via state_map_factory.
+
+ If None, all events will be fetched via state_map_factory.
+
+ state_map_factory(func): will be called
+ with a list of event_ids that are needed, and should return with
+ a Deferred of dict of event_id to event.
+
+ Returns
+ Deferred[dict[(str, str), str]]:
+ a map from (type, state_key) to event_id.
+ """
+ if len(state_sets) == 1:
+ defer.returnValue(state_sets[0])
+
+ unconflicted_state, conflicted_state = _seperate(
+ state_sets,
+ )
+
needed_events = set(
event_id
- for event_ids in conflicted_state.itervalues()
+ for event_ids in itervalues(conflicted_state)
for event_id in event_ids
)
+ if event_map is not None:
+ needed_events -= set(iterkeys(event_map))
logger.info("Asking for %d conflicted events", len(needed_events))
+ # dict[str, FrozenEvent]: a map from state event id to event. Only includes
+ # the state events which are in conflict (and those in event_map)
state_map = yield state_map_factory(needed_events)
+ if event_map is not None:
+ state_map.update(event_map)
+ # get the ids of the auth events which allow us to authenticate the
+ # conflicted state, picking only from the unconflicting state.
+ #
+ # dict[(str, str), str]: a map from state key to event id
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
- new_needed_events = set(auth_events.itervalues())
+ new_needed_events = set(itervalues(auth_events))
new_needed_events -= needed_events
+ if event_map is not None:
+ new_needed_events -= set(iterkeys(event_map))
logger.info("Asking for %d auth events", len(new_needed_events))
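
A small runnable illustration of the event_map short-circuit described above (hypothetical event ids): events already in flight are subtracted from the set handed to state_map_factory.

conflicted_state = {("m.room.topic", ""): {"$t1", "$t2"}}
event_map = {"$t1": object()}            # stand-in for a FrozenEvent

needed_events = set(
    event_id
    for event_ids in conflicted_state.values()
    for event_id in event_ids
)
needed_events -= set(event_map.keys())   # already in flight, skip the fetch
assert needed_events == {"$t2"}
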
@@ -515,7 +730,7 @@ def _resolve_with_state_fac(unconflicted_state, conflicted_state,
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
auth_events = {}
- for event_ids in conflicted_state.itervalues():
+ for event_ids in itervalues(conflicted_state):
for event_id in event_ids:
if event_id in state_map:
keys = event_auth.auth_types_for_event(state_map[event_id])
@@ -527,10 +742,10 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
return auth_events
-def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ids,
+def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids,
state_map):
conflicted_state = {}
- for key, event_ids in conflicted_state_ds.iteritems():
+ for key, event_ids in iteritems(conflicted_state_ids):
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
if len(events) > 1:
conflicted_state[key] = events
@@ -539,7 +754,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
auth_events = {
key: state_map[ev_id]
- for key, ev_id in auth_event_ids.items()
+ for key, ev_id in iteritems(auth_event_ids)
if ev_id in state_map
}
@@ -547,12 +762,12 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
resolved_state = _resolve_state_events(
conflicted_state, auth_events
)
- except:
+ except Exception:
logger.exception("Failed to resolve state")
raise
new_state = unconflicted_state_ids
- for key, event in resolved_state.iteritems():
+ for key, event in iteritems(resolved_state):
new_state[key] = event.event_id
return new_state
@@ -577,7 +792,7 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state)
- for key, events in conflicted_state.items():
+ for key, events in iteritems(conflicted_state):
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = _resolve_auth_events(
@@ -587,7 +802,7 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state)
- for key, events in conflicted_state.items():
+ for key, events in iteritems(conflicted_state):
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = _resolve_auth_events(
@@ -597,7 +812,7 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state)
- for key, events in conflicted_state.items():
+ for key, events in iteritems(conflicted_state):
if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = _resolve_normal_events(
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index b92472df33..ba88a54979 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,53 +14,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+import calendar
+import datetime
+import logging
+import time
+
+from dateutil import tz
+from synapse.api.constants import PresenceState
from synapse.storage.devices import DeviceStore
-from .appservice import (
- ApplicationServiceStore, ApplicationServiceTransactionStore
-)
-from ._base import LoggingTransaction
+from synapse.storage.user_erasure_store import UserErasureStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+from .account_data import AccountDataStore
+from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
+from .client_ips import ClientIpStore
+from .deviceinbox import DeviceInboxStore
from .directory import DirectoryStore
+from .end_to_end_keys import EndToEndKeyStore
+from .engines import PostgresEngine
+from .event_federation import EventFederationStore
+from .event_push_actions import EventPushActionsStore
from .events import EventsStore
+from .filtering import FilteringStore
+from .group_server import GroupServerStore
+from .keys import KeyStore
+from .media_repository import MediaRepositoryStore
+from .openid import OpenIdStore
from .presence import PresenceStore, UserPresenceState
from .profile import ProfileStore
+from .push_rule import PushRuleStore
+from .pusher import PusherStore
+from .receipts import ReceiptsStore
from .registration import RegistrationStore
+from .rejections import RejectionsStore
from .room import RoomStore
from .roommember import RoomMemberStore
-from .stream import StreamStore
-from .transactions import TransactionStore
-from .keys import KeyStore
-from .event_federation import EventFederationStore
-from .pusher import PusherStore
-from .push_rule import PushRuleStore
-from .media_repository import MediaRepositoryStore
-from .rejections import RejectionsStore
-from .event_push_actions import EventPushActionsStore
-from .deviceinbox import DeviceInboxStore
-
-from .state import StateStore
-from .signatures import SignatureStore
-from .filtering import FilteringStore
-from .end_to_end_keys import EndToEndKeyStore
-
-from .receipts import ReceiptsStore
from .search import SearchStore
+from .signatures import SignatureStore
+from .state import StateStore
+from .stream import StreamStore
from .tags import TagsStore
-from .account_data import AccountDataStore
-from .openid import OpenIdStore
-from .client_ips import ClientIpStore
+from .transactions import TransactionStore
from .user_directory import UserDirectoryStore
-
-from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
-from .engines import PostgresEngine
-
-from synapse.api.constants import PresenceState
-from synapse.util.caches.stream_change_cache import StreamChangeCache
-
-
-import logging
-
+from .util.id_generators import ChainedIdGenerator, IdGenerator, StreamIdGenerator
logger = logging.getLogger(__name__)
@@ -88,6 +85,8 @@ class DataStore(RoomMemberStore, RoomStore,
DeviceStore,
DeviceInboxStore,
UserDirectoryStore,
+ GroupServerStore,
+ UserErasureStore,
):
def __init__(self, db_conn, hs):
@@ -103,12 +102,6 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "events", "stream_ordering", step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
)
- self._receipts_id_gen = StreamIdGenerator(
- db_conn, "receipts_linearized", "stream_id"
- )
- self._account_data_id_gen = StreamIdGenerator(
- db_conn, "account_data_max_stream_id", "stream_id"
- )
self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id"
)
@@ -123,7 +116,6 @@ class DataStore(RoomMemberStore, RoomStore,
)
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
- self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
@@ -135,6 +127,9 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "pushers", "id",
extra_tables=[("deleted_pushers", "stream_id")],
)
+ self._group_updates_id_gen = StreamIdGenerator(
+ db_conn, "local_group_updates", "stream_id",
+ )
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = StreamIdGenerator(
@@ -143,27 +138,6 @@ class DataStore(RoomMemberStore, RoomStore,
else:
self._cache_id_gen = None
- events_max = self._stream_id_gen.get_current_token()
- event_cache_prefill, min_event_val = self._get_cache_dict(
- db_conn, "events",
- entity_column="room_id",
- stream_column="stream_ordering",
- max_value=events_max,
- )
- self._events_stream_cache = StreamChangeCache(
- "EventsRoomStreamChangeCache", min_event_val,
- prefilled_cache=event_cache_prefill,
- )
-
- self._membership_stream_cache = StreamChangeCache(
- "MembershipStreamChangeCache", events_max,
- )
-
- account_max = self._account_data_id_gen.get_current_token()
- self._account_data_stream_cache = StreamChangeCache(
- "AccountDataAndTagsChangeCache", account_max,
- )
-
self._presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self._get_cache_dict(
@@ -177,18 +151,6 @@ class DataStore(RoomMemberStore, RoomStore,
prefilled_cache=presence_cache_prefill
)
- push_rules_prefill, push_rules_id = self._get_cache_dict(
- db_conn, "push_rules_stream",
- entity_column="user_id",
- stream_column="stream_id",
- max_value=self._push_rules_stream_id_gen.get_current_token()[0],
- )
-
- self.push_rules_stream_cache = StreamChangeCache(
- "PushRulesStreamChangeCache", push_rules_id,
- prefilled_cache=push_rules_prefill,
- )
-
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
db_conn, "device_inbox",
@@ -223,6 +185,7 @@ class DataStore(RoomMemberStore, RoomStore,
"DeviceListFederationStreamChangeCache", device_list_max,
)
+ events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
db_conn, "current_state_delta_stream",
entity_column="room_id",
@@ -235,24 +198,25 @@ class DataStore(RoomMemberStore, RoomStore,
prefilled_cache=curr_state_delta_prefill,
)
- cur = LoggingTransaction(
- db_conn.cursor(),
- name="_find_stream_orderings_for_times_txn",
- database_engine=self.database_engine,
- after_callbacks=[],
- final_callbacks=[],
+ _group_updates_prefill, min_group_updates_id = self._get_cache_dict(
+ db_conn, "local_group_updates",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=self._group_updates_id_gen.get_current_token(),
+ limit=1000,
)
- self._find_stream_orderings_for_times_txn(cur)
- cur.close()
-
- self.find_stream_orderings_looping_call = self._clock.looping_call(
- self._find_stream_orderings_for_times, 10 * 60 * 1000
+ self._group_updates_stream_cache = StreamChangeCache(
+ "_group_updates_stream_cache", min_group_updates_id,
+ prefilled_cache=_group_updates_prefill,
)
self._stream_order_on_start = self.get_room_max_stream_ordering()
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
- super(DataStore, self).__init__(hs)
+ # Used in _generate_user_daily_visits to keep track of progress
+ self._last_user_visit_update = self._get_start_of_day()
+
+ super(DataStore, self).__init__(db_conn, hs)
def take_presence_startup_info(self):
active_on_startup = self._presence_on_startup
@@ -281,13 +245,12 @@ class DataStore(RoomMemberStore, RoomStore,
return [UserPresenceState(**row) for row in rows]
- @defer.inlineCallbacks
def count_daily_users(self):
"""
Counts the number of users who used this homeserver in the last 24 hours.
"""
def _count_users(txn):
- yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24),
+ yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
sql = """
SELECT COALESCE(count(*), 0) FROM (
@@ -301,8 +264,154 @@ class DataStore(RoomMemberStore, RoomStore,
count, = txn.fetchone()
return count
- ret = yield self.runInteraction("count_users", _count_users)
- defer.returnValue(ret)
+ return self.runInteraction("count_users", _count_users)
+
+ def count_r30_users(self):
+ """
+ Counts the number of 30 day retained users, defined as:
+ * Users who have created their accounts more than 30 days ago
+ * Where last seen at most 30 days ago
+ * Where account creation and last_seen are > 30 days apart
+
+ Returns counts globally, as well as counts broken down
+ by platform
+ """
+ def _count_r30_users(txn):
+ thirty_days_in_secs = 86400 * 30
+ now = int(self._clock.time())
+ thirty_days_ago_in_secs = now - thirty_days_in_secs
+
+ sql = """
+ SELECT platform, COALESCE(count(*), 0) FROM (
+ SELECT
+ users.name, platform, users.creation_ts * 1000,
+ MAX(uip.last_seen)
+ FROM users
+ INNER JOIN (
+ SELECT
+ user_id,
+ last_seen,
+ CASE
+ WHEN user_agent LIKE '%%Android%%' THEN 'android'
+ WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
+ WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
+ WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
+ WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
+ ELSE 'unknown'
+ END
+ AS platform
+ FROM user_ips
+ ) uip
+ ON users.name = uip.user_id
+ AND users.appservice_id is NULL
+ AND users.creation_ts < ?
+ AND uip.last_seen/1000 > ?
+ AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+ GROUP BY users.name, platform, users.creation_ts
+ ) u GROUP BY platform
+ """
+
+ results = {}
+ txn.execute(sql, (thirty_days_ago_in_secs,
+ thirty_days_ago_in_secs))
+
+ for row in txn:
+ if row[0] == 'unknown':
+ continue
+ results[row[0]] = row[1]
+
+ sql = """
+ SELECT COALESCE(count(*), 0) FROM (
+ SELECT users.name, users.creation_ts * 1000,
+ MAX(uip.last_seen)
+ FROM users
+ INNER JOIN (
+ SELECT
+ user_id,
+ last_seen
+ FROM user_ips
+ ) uip
+ ON users.name = uip.user_id
+ AND appservice_id is NULL
+ AND users.creation_ts < ?
+ AND uip.last_seen/1000 > ?
+ AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+ GROUP BY users.name, users.creation_ts
+ ) u
+ """
+
+ txn.execute(sql, (thirty_days_ago_in_secs,
+ thirty_days_ago_in_secs))
+
+ count, = txn.fetchone()
+ results['all'] = count
+
+ return results
+
+ return self.runInteraction("count_r30_users", _count_r30_users)
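
A worked example of the R30 definition above, with hypothetical timestamps in seconds:

DAY = 86400
now = 1600000000
creation_ts = now - 40 * DAY          # account created 40 days ago
last_seen = now - 5 * DAY             # last active 5 days ago
thirty_days_ago = now - 30 * DAY

assert creation_ts < thirty_days_ago          # older than 30 days
assert last_seen > thirty_days_ago            # seen in the last 30 days
assert last_seen - creation_ts > 30 * DAY     # gap of more than 30 days
# all three conditions hold, so this user is counted as retained
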
+
+ def _get_start_of_day(self):
+ """
+ Returns millisecond unixtime for start of UTC day.
+ """
+ now = datetime.datetime.utcnow()
+ today_start = datetime.datetime(now.year, now.month,
+ now.day, tzinfo=tz.tzutc())
+ return int(calendar.timegm(today_start.timetuple())) * 1000
+
+ def generate_user_daily_visits(self):
+ """
+ Generates daily visit data for use in cohort/retention analysis
+ """
+ def _generate_user_daily_visits(txn):
+ logger.info("Calling _generate_user_daily_visits")
+ today_start = self._get_start_of_day()
+ a_day_in_milliseconds = 24 * 60 * 60 * 1000
+ now = self.clock.time_msec()
+
+ sql = """
+ INSERT INTO user_daily_visits (user_id, device_id, timestamp)
+ SELECT u.user_id, u.device_id, ?
+ FROM user_ips AS u
+ LEFT JOIN (
+ SELECT user_id, device_id, timestamp FROM user_daily_visits
+ WHERE timestamp = ?
+ ) udv
+ ON u.user_id = udv.user_id AND u.device_id=udv.device_id
+ INNER JOIN users ON users.name=u.user_id
+ WHERE last_seen > ? AND last_seen <= ?
+ AND udv.timestamp IS NULL AND users.is_guest=0
+ AND users.appservice_id IS NULL
+ GROUP BY u.user_id, u.device_id
+ """
+
+ # This means that the day has rolled over but there could still
+ # be entries from the previous day. There is an edge case
+ # where if the user logs in at 23:59 and overwrites their
+ # last_seen at 00:01 then they will not be counted in the
+ # previous day's stats - it is important that the query is run
+ # often to minimise this case.
+ if today_start > self._last_user_visit_update:
+ yesterday_start = today_start - a_day_in_milliseconds
+ txn.execute(sql, (
+ yesterday_start, yesterday_start,
+ self._last_user_visit_update, today_start
+ ))
+ self._last_user_visit_update = today_start
+
+ txn.execute(sql, (
+ today_start, today_start,
+ self._last_user_visit_update,
+ now
+ ))
+ # Update _last_user_visit_update to now. The reason to do this
+ # rather than just clamping to the beginning of the day is to limit
+ # the size of the join - meaning that the query can be run more
+ # frequently
+ self._last_user_visit_update = now
+
+ return self.runInteraction("generate_user_daily_visits",
+ _generate_user_daily_visits)
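
A sketch of the rollover bookkeeping above, with made-up millisecond timestamps: when the UTC day has ticked over since the last run, one catch-up pass is made for the previous day's window before _last_user_visit_update is advanced.

DAY_MS = 24 * 60 * 60 * 1000
today_start = 1600041600000              # 00:00 UTC today (made up)
last_update = today_start - 3600000      # last pass ran at 23:00 yesterday

if today_start > last_update:
    yesterday_start = today_start - DAY_MS
    # one catch-up pass is executed with timestamp yesterday_start over the
    # window (last_update, today_start], then last_update jumps forward
    last_update = today_start
assert last_update == today_start
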
def get_users(self):
"""Function to reterive a list of users in users table.
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 6f54036d67..1d41d8d445 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -13,36 +13,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import sys
+import threading
+import time
-from synapse.api.errors import StoreError
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
-from synapse.util.caches import CACHE_SIZE_FACTOR
-from synapse.util.caches.dictionary_cache import DictionaryCache
-from synapse.util.caches.descriptors import Cache
-from synapse.storage.engines import PostgresEngine
-import synapse.metrics
+from six import iteritems, iterkeys, itervalues
+from six.moves import intern, range
+from prometheus_client import Histogram
from twisted.internet import defer
-import sys
-import time
-import threading
-
+from synapse.api.errors import StoreError
+from synapse.storage.engines import PostgresEngine
+from synapse.util.caches.descriptors import Cache
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
logger = logging.getLogger(__name__)
+try:
+ MAX_TXN_ID = sys.maxint - 1
+except AttributeError:
+ # python 3 does not have a maximum int value
+ MAX_TXN_ID = 2**63 - 1
+
sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
perf_logger = logging.getLogger("synapse.storage.TIME")
+sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
-metrics = synapse.metrics.get_metrics_for("synapse.storage")
-
-sql_scheduling_timer = metrics.register_distribution("schedule_time")
-
-sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
-sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
+sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
+sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
class LoggingTransaction(object):
@@ -50,16 +52,16 @@ class LoggingTransaction(object):
passed to the constructor. Adds logging and metrics to the .execute()
method."""
__slots__ = [
- "txn", "name", "database_engine", "after_callbacks", "final_callbacks",
+ "txn", "name", "database_engine", "after_callbacks", "exception_callbacks",
]
def __init__(self, txn, name, database_engine, after_callbacks,
- final_callbacks):
+ exception_callbacks):
object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name)
object.__setattr__(self, "database_engine", database_engine)
object.__setattr__(self, "after_callbacks", after_callbacks)
- object.__setattr__(self, "final_callbacks", final_callbacks)
+ object.__setattr__(self, "exception_callbacks", exception_callbacks)
def call_after(self, callback, *args, **kwargs):
"""Call the given callback on the main twisted thread after the
@@ -68,8 +70,8 @@ class LoggingTransaction(object):
"""
self.after_callbacks.append((callback, args, kwargs))
- def call_finally(self, callback, *args, **kwargs):
- self.final_callbacks.append((callback, args, kwargs))
+ def call_on_exception(self, callback, *args, **kwargs):
+ self.exception_callbacks.append((callback, args, kwargs))
def __getattr__(self, name):
return getattr(self.txn, name)
@@ -103,11 +105,11 @@ class LoggingTransaction(object):
"[SQL values] {%s} %r",
self.name, args[0]
)
- except:
+ except Exception:
# Don't let logging failures stop SQL from working
pass
- start = time.time() * 1000
+ start = time.time()
try:
return func(
@@ -117,9 +119,9 @@ class LoggingTransaction(object):
logger.debug("[SQL FAIL] {%s} %s", self.name, e)
raise
finally:
- msecs = (time.time() * 1000) - start
- sql_logger.debug("[SQL time] {%s} %f", self.name, msecs)
- sql_query_timer.inc_by(msecs, sql.split()[0])
+ secs = time.time() - start
+ sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
+ sql_query_timer.labels(sql.split()[0]).observe(secs)
class PerformanceCounters(object):
@@ -129,7 +131,7 @@ class PerformanceCounters(object):
def update(self, key, start_time, end_time=None):
if end_time is None:
- end_time = time.time() * 1000
+ end_time = time.time()
duration = end_time - start_time
count, cum_time = self.current_counters.get(key, (0, 0))
count += 1
@@ -139,7 +141,7 @@ class PerformanceCounters(object):
def interval(self, interval_duration, limit=3):
counters = []
- for name, (count, cum_time) in self.current_counters.iteritems():
+ for name, (count, cum_time) in iteritems(self.current_counters):
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
counters.append((
(cum_time - prev_time) / interval_duration,
@@ -162,7 +164,7 @@ class PerformanceCounters(object):
class SQLBaseStore(object):
_TXN_ID = 0
- def __init__(self, hs):
+ def __init__(self, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
self._db_pool = hs.get_db_pool()
@@ -180,10 +182,6 @@ class SQLBaseStore(object):
self._get_event_cache = Cache("*getEvent*", keylen=3,
max_entries=hs.config.event_cache_size)
- self._state_group_cache = DictionaryCache(
- "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
- )
-
self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
self._event_fetch_ongoing = 0
@@ -221,14 +219,14 @@ class SQLBaseStore(object):
self._clock.looping_call(loop, 10000)
- def _new_transaction(self, conn, desc, after_callbacks, final_callbacks,
- logging_context, func, *args, **kwargs):
- start = time.time() * 1000
+ def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks,
+ func, *args, **kwargs):
+ start = time.time()
txn_id = self._TXN_ID
# We don't really need these to be unique, so lets stop it from
# growing really large.
- self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
+ self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
name = "%s-%x" % (desc, txn_id, )
@@ -242,7 +240,7 @@ class SQLBaseStore(object):
txn = conn.cursor()
txn = LoggingTransaction(
txn, name, self.database_engine, after_callbacks,
- final_callbacks,
+ exception_callbacks,
)
r = func(txn, *args, **kwargs)
conn.commit()
@@ -283,73 +281,85 @@ class SQLBaseStore(object):
logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
finally:
- end = time.time() * 1000
+ end = time.time()
duration = end - start
- if logging_context is not None:
- logging_context.add_database_transaction(duration)
+ LoggingContext.current_context().add_database_transaction(duration)
- transaction_logger.debug("[TXN END] {%s} %f", name, duration)
+ transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
self._current_txn_total_time += duration
self._txn_perf_counters.update(desc, start, end)
- sql_txn_timer.inc_by(duration, desc)
+ sql_txn_timer.labels(desc).observe(duration)
@defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs):
- """Wraps the .runInteraction() method on the underlying db_pool."""
- current_context = LoggingContext.current_context()
-
- start_time = time.time() * 1000
+ """Starts a transaction on the database and runs a given function
- after_callbacks = []
- final_callbacks = []
+ Arguments:
+ desc (str): description of the transaction, for logging and metrics
+ func (func): callback function, which will be called with a
+ database transaction (twisted.enterprise.adbapi.Transaction) as
+ its first argument, followed by `args` and `kwargs`.
- def inner_func(conn, *args, **kwargs):
- with LoggingContext("runInteraction") as context:
- sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+ args (list): positional args to pass to `func`
+ kwargs (dict): named args to pass to `func`
- if self.database_engine.is_connection_closed(conn):
- logger.debug("Reconnecting closed database connection")
- conn.reconnect()
-
- current_context.copy_to(context)
- return self._new_transaction(
- conn, desc, after_callbacks, final_callbacks, current_context,
- func, *args, **kwargs
- )
+ Returns:
+ Deferred: The result of func
+ """
+ after_callbacks = []
+ exception_callbacks = []
try:
- with PreserveLoggingContext():
- result = yield self._db_pool.runWithConnection(
- inner_func, *args, **kwargs
- )
+ result = yield self.runWithConnection(
+ self._new_transaction,
+ desc, after_callbacks, exception_callbacks, func,
+ *args, **kwargs
+ )
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
- finally:
- for after_callback, after_args, after_kwargs in final_callbacks:
+ except: # noqa: E722, as we reraise the exception this is fine.
+ for after_callback, after_args, after_kwargs in exception_callbacks:
after_callback(*after_args, **after_kwargs)
+ raise
defer.returnValue(result)
@defer.inlineCallbacks
def runWithConnection(self, func, *args, **kwargs):
- """Wraps the .runInteraction() method on the underlying db_pool."""
- current_context = LoggingContext.current_context()
+ """Wraps the .runWithConnection() method on the underlying db_pool.
- start_time = time.time() * 1000
+ Arguments:
+ func (func): callback function, which will be called with a
+ database connection (twisted.enterprise.adbapi.Connection) as
+ its first argument, followed by `args` and `kwargs`.
+ args (list): positional args to pass to `func`
+ kwargs (dict): named args to pass to `func`
+
+ Returns:
+ Deferred: The result of func
+ """
+ parent_context = LoggingContext.current_context()
+ if parent_context == LoggingContext.sentinel:
+ logger.warn(
+ "Starting db connection from sentinel context: metrics will be lost",
+ )
+ parent_context = None
+
+ start_time = time.time()
def inner_func(conn, *args, **kwargs):
- with LoggingContext("runWithConnection") as context:
- sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+ with LoggingContext("runWithConnection", parent_context) as context:
+ sched_duration_sec = time.time() - start_time
+ sql_scheduling_timer.observe(sched_duration_sec)
+ context.add_database_scheduled(sched_duration_sec)
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
conn.reconnect()
- current_context.copy_to(context)
-
return func(conn, *args, **kwargs)
with PreserveLoggingContext():
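
A hedged usage sketch for the API documented above (the table name and the store variable are hypothetical): the callback runs on a database thread, must be synchronous, and its return value becomes the result of the Deferred.

def _count_rows_txn(txn, table):
    # runs on a database thread with a live transaction cursor
    txn.execute("SELECT COUNT(*) FROM %s" % (table,))
    count, = txn.fetchone()
    return count

# from a Deferred-aware context:
#   count = yield store.runInteraction("count_rows", _count_rows_txn, "users")
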
@@ -368,7 +378,7 @@ class SQLBaseStore(object):
Returns:
A list of dicts where the key is the column header.
"""
- col_headers = list(intern(column[0]) for column in cursor.description)
+ col_headers = list(intern(str(column[0])) for column in cursor.description)
results = list(
dict(zip(col_headers, row)) for row in cursor
)
@@ -475,23 +485,53 @@ class SQLBaseStore(object):
txn.executemany(sql, vals)
+ @defer.inlineCallbacks
def _simple_upsert(self, table, keyvalues, values,
insertion_values={}, desc="_simple_upsert", lock=True):
"""
+
+ `lock` should generally be set to True (the default), but can be set
+ to False if either of the following are true:
+
+ * there is a UNIQUE INDEX on the key columns. In this case a conflict
+ will cause an IntegrityError in which case this function will retry
+ the update.
+
+ * we somehow know that we are the only thread which will be updating
+ this table.
+
Args:
table (str): The table to upsert into
keyvalues (dict): The unique key tables and their new values
values (dict): The nonunique columns and their new values
- insertion_values (dict): key/values to use when inserting
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ lock (bool): True to lock the table when doing the upsert.
Returns:
Deferred(bool): True if a new entry was created, False if an
existing one was updated.
"""
- return self.runInteraction(
- desc,
- self._simple_upsert_txn, table, keyvalues, values, insertion_values,
- lock
- )
+ attempts = 0
+ while True:
+ try:
+ result = yield self.runInteraction(
+ desc,
+ self._simple_upsert_txn, table, keyvalues, values, insertion_values,
+ lock=lock
+ )
+ defer.returnValue(result)
+ except self.database_engine.module.IntegrityError as e:
+ attempts += 1
+ if attempts >= 5:
+ # don't retry forever, because things other than races
+ # can cause IntegrityErrors
+ raise
+
+ # presumably we raced with another transaction: let's retry.
+ logger.warn(
+ "IntegrityError when upserting into %s; retrying: %s",
+ table, e
+ )
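
A hedged usage sketch for _simple_upsert (the table and columns are hypothetical): with a UNIQUE index on the key columns, lock=False is safe because a racing insert surfaces as an IntegrityError and the loop above retries.

from twisted.internet import defer

@defer.inlineCallbacks
def record_device_seen(store, user_id, device_id, now_ms):
    # assumes a UNIQUE index on (user_id, device_id); otherwise keep the
    # default lock=True
    yield store._simple_upsert(
        table="device_last_seen",                  # hypothetical table
        keyvalues={"user_id": user_id, "device_id": device_id},
        values={"last_seen": now_ms},
        insertion_values={"first_seen": now_ms},   # only used on insert
        desc="record_device_seen",
        lock=False,
    )
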
def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
lock=True):
@@ -499,37 +539,38 @@ class SQLBaseStore(object):
if lock:
self.database_engine.lock_table(txn, table)
- # Try to update
+ # First try to update.
sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k,) for k in values),
" AND ".join("%s = ?" % (k,) for k in keyvalues)
)
- sqlargs = values.values() + keyvalues.values()
+ sqlargs = list(values.values()) + list(keyvalues.values())
txn.execute(sql, sqlargs)
- if txn.rowcount == 0:
- # We didn't update and rows so insert a new one
- allvalues = {}
- allvalues.update(keyvalues)
- allvalues.update(values)
- allvalues.update(insertion_values)
+ if txn.rowcount > 0:
+ # successfully updated at least one row.
+ return False
- sql = "INSERT INTO %s (%s) VALUES (%s)" % (
- table,
- ", ".join(k for k in allvalues),
- ", ".join("?" for _ in allvalues)
- )
- txn.execute(sql, allvalues.values())
+ # We didn't update any rows so insert a new one
+ allvalues = {}
+ allvalues.update(keyvalues)
+ allvalues.update(values)
+ allvalues.update(insertion_values)
- return True
- else:
- return False
+ sql = "INSERT INTO %s (%s) VALUES (%s)" % (
+ table,
+ ", ".join(k for k in allvalues),
+ ", ".join("?" for _ in allvalues)
+ )
+ txn.execute(sql, list(allvalues.values()))
+ # successfully inserted
+ return True
def _simple_select_one(self, table, keyvalues, retcols,
allow_none=False, desc="_simple_select_one"):
"""Executes a SELECT query on the named table, which is expected to
- return a single row, returning a single column from it.
+ return a single row, returning multiple columns from it.
Args:
table : string giving the table name
@@ -582,20 +623,18 @@ class SQLBaseStore(object):
@staticmethod
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
- if keyvalues:
- where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
- else:
- where = ""
-
sql = (
- "SELECT %(retcol)s FROM %(table)s %(where)s"
+ "SELECT %(retcol)s FROM %(table)s"
) % {
"retcol": retcol,
"table": table,
- "where": where,
}
- txn.execute(sql, keyvalues.values())
+ if keyvalues:
+ sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+ txn.execute(sql, list(keyvalues.values()))
+ else:
+ txn.execute(sql)
return [r[0] for r in txn]
@@ -606,7 +645,7 @@ class SQLBaseStore(object):
Args:
table (str): table name
- keyvalues (dict): column names and values to select the rows with
+ keyvalues (dict|None): column names and values to select the rows with
retcol (str): column whose value we wish to retrieve.
Returns:
@@ -657,7 +696,7 @@ class SQLBaseStore(object):
table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
- txn.execute(sql, keyvalues.values())
+ txn.execute(sql, list(keyvalues.values()))
else:
sql = "SELECT %s FROM %s" % (
", ".join(retcols),
@@ -688,9 +727,12 @@ class SQLBaseStore(object):
if not iterable:
defer.returnValue(results)
+ # iterables can not be sliced, so convert it to a list first
+ it_list = list(iterable)
+
chunks = [
- iterable[i:i + batch_size]
- for i in xrange(0, len(iterable), batch_size)
+ it_list[i:i + batch_size]
+ for i in range(0, len(it_list), batch_size)
]
for chunk in chunks:
rows = yield self.runInteraction(
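The chunking above is the usual slice-a-list-into-batches idiom; as a standalone sketch (plain Python, outside Twisted):

    def batch(iterable, batch_size=100):
        # iterables cannot be sliced, so materialise them into a list first
        items = list(iterable)
        return [
            items[i:i + batch_size]
            for i in range(0, len(items), batch_size)
        ]

    assert batch(range(7), 3) == [[0, 1, 2], [3, 4, 5], [6]]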
@@ -730,7 +772,7 @@ class SQLBaseStore(object):
)
values.extend(iterable)
- for key, value in keyvalues.iteritems():
+ for key, value in iteritems(keyvalues):
clauses.append("%s = ?" % (key,))
values.append(value)
@@ -743,6 +785,33 @@ class SQLBaseStore(object):
txn.execute(sql, values)
return cls.cursor_to_dict(txn)
+ def _simple_update(self, table, keyvalues, updatevalues, desc):
+ return self.runInteraction(
+ desc,
+ self._simple_update_txn,
+ table, keyvalues, updatevalues,
+ )
+
+ @staticmethod
+ def _simple_update_txn(txn, table, keyvalues, updatevalues):
+ if keyvalues:
+ where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+ else:
+ where = ""
+
+ update_sql = "UPDATE %s SET %s %s" % (
+ table,
+ ", ".join("%s = ?" % (k,) for k in updatevalues),
+ where,
+ )
+
+ txn.execute(
+ update_sql,
+ list(updatevalues.values()) + list(keyvalues.values())
+ )
+
+ return txn.rowcount
+
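A standalone sketch of the SQL this new helper builds, written as a pure function with an invented `profiles` example call (it mirrors `_simple_update_txn`'s string-building, minus the cursor; not a replacement for it):

    def build_update_sql(table, keyvalues, updatevalues):
        where = ""
        if keyvalues:
            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues)
        sql = "UPDATE %s SET %s %s" % (
            table,
            ", ".join("%s = ?" % k for k in updatevalues),
            where,
        )
        return sql, list(updatevalues.values()) + list(keyvalues.values())

    sql, args = build_update_sql(
        "profiles", {"user_id": "@alice:example.com"}, {"displayname": "Alice"},
    )
    # sql  == "UPDATE profiles SET displayname = ? WHERE user_id = ?"
    # args == ["Alice", "@alice:example.com"]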
def _simple_update_one(self, table, keyvalues, updatevalues,
desc="_simple_update_one"):
"""Executes an UPDATE query on the named table, setting new values for
@@ -768,27 +837,13 @@ class SQLBaseStore(object):
table, keyvalues, updatevalues,
)
- @staticmethod
- def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
- if keyvalues:
- where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
- else:
- where = ""
-
- update_sql = "UPDATE %s SET %s %s" % (
- table,
- ", ".join("%s = ?" % (k,) for k in updatevalues),
- where,
- )
-
- txn.execute(
- update_sql,
- updatevalues.values() + keyvalues.values()
- )
+ @classmethod
+ def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
+ rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues)
- if txn.rowcount == 0:
+ if rowcount == 0:
raise StoreError(404, "No row found")
- if txn.rowcount > 1:
+ if rowcount > 1:
raise StoreError(500, "More than one row matched")
@staticmethod
@@ -800,7 +855,7 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k,) for k in keyvalues)
)
- txn.execute(select_sql, keyvalues.values())
+ txn.execute(select_sql, list(keyvalues.values()))
row = txn.fetchone()
if not row:
@@ -838,7 +893,7 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
- txn.execute(sql, keyvalues.values())
+ txn.execute(sql, list(keyvalues.values()))
if txn.rowcount == 0:
raise StoreError(404, "No row found")
if txn.rowcount > 1:
@@ -856,7 +911,7 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
- return txn.execute(sql, keyvalues.values())
+ return txn.execute(sql, list(keyvalues.values()))
def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
return self.runInteraction(
@@ -888,7 +943,7 @@ class SQLBaseStore(object):
)
values.extend(iterable)
- for key, value in keyvalues.iteritems():
+ for key, value in iteritems(keyvalues):
clauses.append("%s = ?" % (key,))
values.append(value)
@@ -928,7 +983,7 @@ class SQLBaseStore(object):
txn.close()
if cache:
- min_val = min(cache.itervalues())
+ min_val = min(itervalues(cache))
else:
min_val = max_value
@@ -951,7 +1006,8 @@ class SQLBaseStore(object):
# __exit__ called after the transaction finishes.
ctx = self._cache_id_gen.get_next()
stream_id = ctx.__enter__()
- txn.call_finally(ctx.__exit__, None, None, None)
+ txn.call_on_exception(ctx.__exit__, None, None, None)
+ txn.call_after(ctx.__exit__, None, None, None)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
self._simple_insert_txn(
@@ -1042,7 +1098,7 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k,) for k in keyvalues),
" ? ASC LIMIT ? OFFSET ?"
)
- txn.execute(sql, keyvalues.values() + pagevalues)
+ txn.execute(sql, list(keyvalues.values()) + list(pagevalues))
else:
sql = "SELECT %s FROM %s ORDER BY %s" % (
", ".join(retcols),
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index aa84ffc2b0..bbc3355c73 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,18 +14,46 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
-from twisted.internet import defer
+import abc
+import logging
-from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
+from canonicaljson import json
-import ujson as json
-import logging
+from twisted.internet import defer
+
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
-class AccountDataStore(SQLBaseStore):
+class AccountDataWorkerStore(SQLBaseStore):
+ """This is an abstract base class where subclasses must implement
+ `get_max_account_data_stream_id` which can be called in the initializer.
+ """
+
+ # This ABCMeta metaclass ensures that we cannot be instantiated without
+ # the abstract methods being implemented.
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, db_conn, hs):
+ account_max = self.get_max_account_data_stream_id()
+ self._account_data_stream_cache = StreamChangeCache(
+ "AccountDataAndTagsChangeCache", account_max,
+ )
+
+ super(AccountDataWorkerStore, self).__init__(db_conn, hs)
+
+ @abc.abstractmethod
+ def get_max_account_data_stream_id(self):
+ """Get the current max stream ID for account data stream
+
+ Returns:
+ int
+ """
+ raise NotImplementedError()
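A minimal sketch of the pattern the comment describes, in py3 spelling (`abc.ABC` rather than the py2 `__metaclass__` attribute used above):

    import abc

    class WorkerStore(abc.ABC):
        def __init__(self):
            # Safe to call here: a concrete subclass must have implemented
            # the abstract method before it can be instantiated at all.
            self.max_stream_id = self.get_max_account_data_stream_id()

        @abc.abstractmethod
        def get_max_account_data_stream_id(self):
            raise NotImplementedError()

    class MasterStore(WorkerStore):
        def get_max_account_data_stream_id(self):
            return 42

    print(MasterStore().max_stream_id)  # 42
    # WorkerStore() raises TypeError: can't instantiate abstract class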
@cached()
def get_account_data_for_user(self, user_id):
@@ -63,7 +92,7 @@ class AccountDataStore(SQLBaseStore):
"get_account_data_for_user", get_account_data_for_user_txn
)
- @cachedInlineCallbacks(num_args=2)
+ @cachedInlineCallbacks(num_args=2, max_entries=5000)
def get_global_account_data_by_type_for_user(self, data_type, user_id):
"""
Returns:
@@ -85,25 +114,7 @@ class AccountDataStore(SQLBaseStore):
else:
defer.returnValue(None)
- @cachedList(cached_method_name="get_global_account_data_by_type_for_user",
- num_args=2, list_name="user_ids", inlineCallbacks=True)
- def get_global_account_data_by_type_for_users(self, data_type, user_ids):
- rows = yield self._simple_select_many_batch(
- table="account_data",
- column="user_id",
- iterable=user_ids,
- keyvalues={
- "account_data_type": data_type,
- },
- retcols=("user_id", "content",),
- desc="get_global_account_data_by_type_for_users",
- )
-
- defer.returnValue({
- row["user_id"]: json.loads(row["content"]) if row["content"] else None
- for row in rows
- })
-
+ @cached(num_args=2)
def get_account_data_for_room(self, user_id, room_id):
"""Get all the client account_data for a user for a room.
@@ -127,6 +138,38 @@ class AccountDataStore(SQLBaseStore):
"get_account_data_for_room", get_account_data_for_room_txn
)
+ @cached(num_args=3, max_entries=5000)
+ def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
+ """Get the client account_data of given type for a user for a room.
+
+ Args:
+ user_id(str): The user to get the account_data for.
+ room_id(str): The room to get the account_data for.
+ account_data_type (str): The account data type to get.
+ Returns:
+ A deferred of the room account_data for that type, or None if
+ there isn't any set.
+ """
+ def get_account_data_for_room_and_type_txn(txn):
+ content_json = self._simple_select_one_onecol_txn(
+ txn,
+ table="room_account_data",
+ keyvalues={
+ "user_id": user_id,
+ "room_id": room_id,
+ "account_data_type": account_data_type,
+ },
+ retcol="content",
+ allow_none=True
+ )
+
+ return json.loads(content_json) if content_json else None
+
+ return self.runInteraction(
+ "get_account_data_for_room_and_type",
+ get_account_data_for_room_and_type_txn,
+ )
+
def get_all_updated_account_data(self, last_global_id, last_room_id,
current_id, limit):
"""Get all the client account_data that has changed on the server
@@ -209,6 +252,36 @@ class AccountDataStore(SQLBaseStore):
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
+ @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
+ def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
+ ignored_account_data = yield self.get_global_account_data_by_type_for_user(
+ "m.ignored_user_list", ignorer_user_id,
+ on_invalidate=cache_context.invalidate,
+ )
+ if not ignored_account_data:
+ defer.returnValue(False)
+
+ defer.returnValue(
+ ignored_user_id in ignored_account_data.get("ignored_users", {})
+ )
+
+
+class AccountDataStore(AccountDataWorkerStore):
+ def __init__(self, db_conn, hs):
+ self._account_data_id_gen = StreamIdGenerator(
+ db_conn, "account_data_max_stream_id", "stream_id"
+ )
+
+ super(AccountDataStore, self).__init__(db_conn, hs)
+
+ def get_max_account_data_stream_id(self):
+ """Get the current max stream id for the private user data stream
+
+ Returns:
+ A deferred int.
+ """
+ return self._account_data_id_gen.get_current_token()
+
@defer.inlineCallbacks
def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
"""Add some account_data to a room for a user.
@@ -222,9 +295,12 @@ class AccountDataStore(SQLBaseStore):
"""
content_json = json.dumps(content)
- def add_account_data_txn(txn, next_id):
- self._simple_upsert_txn(
- txn,
+ with self._account_data_id_gen.get_next() as next_id:
+ # no need to lock here as room_account_data has a unique constraint
+ # on (user_id, room_id, account_data_type) so _simple_upsert will
+ # retry if there is a conflict.
+ yield self._simple_upsert(
+ desc="add_room_account_data",
table="room_account_data",
keyvalues={
"user_id": user_id,
@@ -234,18 +310,23 @@ class AccountDataStore(SQLBaseStore):
values={
"stream_id": next_id,
"content": content_json,
- }
- )
- txn.call_after(
- self._account_data_stream_cache.entity_has_changed,
- user_id, next_id,
+ },
+ lock=False,
)
- txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
- self._update_max_stream_id(txn, next_id)
- with self._account_data_id_gen.get_next() as next_id:
- yield self.runInteraction(
- "add_room_account_data", add_account_data_txn, next_id
+ # it's theoretically possible for the above to succeed and the
+ # below to fail - in which case we might reuse a stream id on
+ # restart, and the above update might not get propagated. That
+ # doesn't sound any worse than the whole update getting lost,
+ # which is what would happen if we combined the two into one
+ # transaction.
+ yield self._update_max_stream_id(next_id)
+
+ self._account_data_stream_cache.entity_has_changed(user_id, next_id)
+ self.get_account_data_for_user.invalidate((user_id,))
+ self.get_account_data_for_room.invalidate((user_id, room_id,))
+ self.get_account_data_for_room_and_type.prefill(
+ (user_id, room_id, account_data_type,), content,
)
result = self._account_data_id_gen.get_current_token()
@@ -263,9 +344,12 @@ class AccountDataStore(SQLBaseStore):
"""
content_json = json.dumps(content)
- def add_account_data_txn(txn, next_id):
- self._simple_upsert_txn(
- txn,
+ with self._account_data_id_gen.get_next() as next_id:
+ # no need to lock here as account_data has a unique constraint on
+ # (user_id, account_data_type) so _simple_upsert will retry if
+ # there is a conflict.
+ yield self._simple_upsert(
+ desc="add_user_account_data",
table="account_data",
keyvalues={
"user_id": user_id,
@@ -274,37 +358,43 @@ class AccountDataStore(SQLBaseStore):
values={
"stream_id": next_id,
"content": content_json,
- }
+ },
+ lock=False,
)
- txn.call_after(
- self._account_data_stream_cache.entity_has_changed,
+
+ # it's theoretically possible for the above to succeed and the
+ # below to fail - in which case we might reuse a stream id on
+ # restart, and the above update might not get propagated. That
+ # doesn't sound any worse than the whole update getting lost,
+ # which is what would happen if we combined the two into one
+ # transaction.
+ yield self._update_max_stream_id(next_id)
+
+ self._account_data_stream_cache.entity_has_changed(
user_id, next_id,
)
- txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
- txn.call_after(
- self.get_global_account_data_by_type_for_user.invalidate,
+ self.get_account_data_for_user.invalidate((user_id,))
+ self.get_global_account_data_by_type_for_user.invalidate(
(account_data_type, user_id,)
)
- self._update_max_stream_id(txn, next_id)
-
- with self._account_data_id_gen.get_next() as next_id:
- yield self.runInteraction(
- "add_user_account_data", add_account_data_txn, next_id
- )
result = self._account_data_id_gen.get_current_token()
defer.returnValue(result)
- def _update_max_stream_id(self, txn, next_id):
+ def _update_max_stream_id(self, next_id):
"""Update the max stream_id
Args:
- txn: The database cursor
next_id(int): The revision to advance to.
"""
- update_max_id_sql = (
- "UPDATE account_data_max_stream_id"
- " SET stream_id = ?"
- " WHERE stream_id < ?"
+ def _update(txn):
+ update_max_id_sql = (
+ "UPDATE account_data_max_stream_id"
+ " SET stream_id = ?"
+ " WHERE stream_id < ?"
+ )
+ txn.execute(update_max_id_sql, (next_id, next_id))
+ return self.runInteraction(
+ "update_account_data_max_stream_id",
+ _update,
)
- txn.execute(update_max_id_sql, (next_id, next_id))
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index c63935cb07..9f12b360bc 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,15 +15,16 @@
# limitations under the License.
import logging
import re
-import simplejson as json
+
+from canonicaljson import json
+
from twisted.internet import defer
-from synapse.api.constants import Membership
from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
-from synapse.storage.roommember import RoomsForUser
-from ._base import SQLBaseStore
+from synapse.storage.events import EventsWorkerStore
+from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
@@ -46,17 +48,16 @@ def _make_exclusive_regex(services_cache):
return exclusive_user_regex
-class ApplicationServiceStore(SQLBaseStore):
-
- def __init__(self, hs):
- super(ApplicationServiceStore, self).__init__(hs)
- self.hostname = hs.hostname
+class ApplicationServiceWorkerStore(SQLBaseStore):
+ def __init__(self, db_conn, hs):
self.services_cache = load_appservices(
hs.hostname,
hs.config.app_service_config_files
)
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
+ super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs)
+
def get_app_services(self):
return self.services_cache
@@ -99,83 +100,30 @@ class ApplicationServiceStore(SQLBaseStore):
return service
return None
- def get_app_service_rooms(self, service):
- """Get a list of RoomsForUser for this application service.
-
- Application services may be "interested" in lots of rooms depending on
- the room ID, the room aliases, or the members in the room. This function
- takes all of these into account and returns a list of RoomsForUser which
- represent the entire list of room IDs that this application service
- wants to know about.
+ def get_app_service_by_id(self, as_id):
+ """Get the application service with the given appservice ID.
Args:
- service: The application service to get a room list for.
+ as_id (str): The application service ID.
Returns:
- A list of RoomsForUser.
+ synapse.appservice.ApplicationService or None.
"""
- return self.runInteraction(
- "get_app_service_rooms",
- self._get_app_service_rooms_txn,
- service,
- )
-
- def _get_app_service_rooms_txn(self, txn, service):
- # get all rooms matching the room ID regex.
- room_entries = self._simple_select_list_txn(
- txn=txn, table="rooms", keyvalues=None, retcols=["room_id"]
- )
- matching_room_list = set([
- r["room_id"] for r in room_entries if
- service.is_interested_in_room(r["room_id"])
- ])
-
- # resolve room IDs for matching room alias regex.
- room_alias_mappings = self._simple_select_list_txn(
- txn=txn, table="room_aliases", keyvalues=None,
- retcols=["room_id", "room_alias"]
- )
- matching_room_list |= set([
- r["room_id"] for r in room_alias_mappings if
- service.is_interested_in_alias(r["room_alias"])
- ])
-
- # get all rooms for every user for this AS. This is scoped to users on
- # this HS only.
- user_list = self._simple_select_list_txn(
- txn=txn, table="users", keyvalues=None, retcols=["name"]
- )
- user_list = [
- u["name"] for u in user_list if
- service.is_interested_in_user(u["name"])
- ]
- rooms_for_user_matching_user_id = set() # RoomsForUser list
- for user_id in user_list:
- # FIXME: This assumes this store is linked with RoomMemberStore :(
- rooms_for_user = self._get_rooms_for_user_where_membership_is_txn(
- txn=txn,
- user_id=user_id,
- membership_list=[Membership.JOIN]
- )
- rooms_for_user_matching_user_id |= set(rooms_for_user)
-
- # make RoomsForUser tuples for room ids and aliases which are not in the
- # main rooms_for_user_list - e.g. they are rooms which do not have AS
- # registered users in it.
- known_room_ids = [r.room_id for r in rooms_for_user_matching_user_id]
- missing_rooms_for_user = [
- RoomsForUser(r, service.sender, "join") for r in
- matching_room_list if r not in known_room_ids
- ]
- rooms_for_user_matching_user_id |= set(missing_rooms_for_user)
-
- return rooms_for_user_matching_user_id
+ for service in self.services_cache:
+ if service.id == as_id:
+ return service
+ return None
-class ApplicationServiceTransactionStore(SQLBaseStore):
+class ApplicationServiceStore(ApplicationServiceWorkerStore):
+ # This is currently empty because there are no AS storage functions that
+ # cannot be run on the workers. Since this may change in future, and to
+ # keep consistency with the other stores, we keep this empty class for
+ # now.
+ pass
- def __init__(self, hs):
- super(ApplicationServiceTransactionStore, self).__init__(hs)
+class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
+ EventsWorkerStore):
@defer.inlineCallbacks
def get_appservices_by_state(self, state):
"""Get a list of application services based on their state.
@@ -420,3 +368,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
events = yield self._get_events(event_ids)
defer.returnValue((upper_bound, events))
+
+
+class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore):
+ # This is currently empty because there are no AS storage functions that
+ # cannot be run on the workers. Since this may change in future, and to
+ # keep consistency with the other stores, we keep this empty class for
+ # now.
+ pass
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 7157fb1dfb..5fe1ca2de7 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,15 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse.util.async
-from ._base import SQLBaseStore
-from . import engines
+import logging
+
+from canonicaljson import json
from twisted.internet import defer
-import ujson as json
-import logging
+from synapse.metrics.background_process_metrics import run_as_background_process
+
+from . import engines
+from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
@@ -80,25 +82,30 @@ class BackgroundUpdateStore(SQLBaseStore):
BACKGROUND_UPDATE_INTERVAL_MS = 1000
BACKGROUND_UPDATE_DURATION_MS = 100
- def __init__(self, hs):
- super(BackgroundUpdateStore, self).__init__(hs)
+ def __init__(self, db_conn, hs):
+ super(BackgroundUpdateStore, self).__init__(db_conn, hs)
self._background_update_performance = {}
self._background_update_queue = []
self._background_update_handlers = {}
+ self._all_done = False
- @defer.inlineCallbacks
def start_doing_background_updates(self):
- logger.info("Starting background schema updates")
+ run_as_background_process(
+ "background_updates", self._run_background_updates,
+ )
+ @defer.inlineCallbacks
+ def _run_background_updates(self):
+ logger.info("Starting background schema updates")
while True:
- yield synapse.util.async.sleep(
+ yield self.hs.get_clock().sleep(
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)
try:
result = yield self.do_next_background_update(
self.BACKGROUND_UPDATE_DURATION_MS
)
- except:
+ except Exception:
logger.exception("Error doing update")
else:
if result is None:
@@ -106,9 +113,41 @@ class BackgroundUpdateStore(SQLBaseStore):
"No more background updates to do."
" Unscheduling background update task."
)
+ self._all_done = True
defer.returnValue(None)
@defer.inlineCallbacks
+ def has_completed_background_updates(self):
+ """Check if all the background updates have completed
+
+ Returns:
+ Deferred[bool]: True if all background updates have completed
+ """
+ # if we've previously determined that there is nothing left to do, that
+ # is easy
+ if self._all_done:
+ defer.returnValue(True)
+
+ # obviously, if we have things in our queue, we're not done.
+ if self._background_update_queue:
+ defer.returnValue(False)
+
+ # otherwise, check if there are updates to be run. This is important,
+ # as we may be running on a worker which doesn't perform the bg updates
+ # itself, but still wants to wait for them to happen.
+ updates = yield self._simple_select_onecol(
+ "background_updates",
+ keyvalues=None,
+ retcol="1",
+ desc="check_background_updates",
+ )
+ if not updates:
+ self._all_done = True
+ defer.returnValue(True)
+
+ defer.returnValue(False)
+
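As a hedged usage sketch (names invented, not from the codebase): a worker that needs the schema fully migrated can poll the new check like so:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def wait_for_background_updates(store, clock, poll_interval_s=1.0):
        # hypothetical helper: spin until the store reports completion
        while not (yield store.has_completed_background_updates()):
            yield clock.sleep(poll_interval_s)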
+ @defer.inlineCallbacks
def do_next_background_update(self, desired_duration_ms):
"""Does some amount of work on the next queued background update
@@ -209,6 +248,25 @@ class BackgroundUpdateStore(SQLBaseStore):
"""
self._background_update_handlers[update_name] = update_handler
+ def register_noop_background_update(self, update_name):
+ """Register a noop handler for a background update.
+
+ This is useful when we previously did a background update, but no
+ longer wish to do the update. In this case the background update should
+ be removed from the schema delta files, but there may still be some
+ users who have the background update queued, so this method should
+ also be called to clear the update.
+
+ Args:
+ update_name (str): Name of update
+ """
+ @defer.inlineCallbacks
+ def noop_update(progress, batch_size):
+ yield self._end_background_update(update_name)
+ defer.returnValue(1)
+
+ self.register_background_update_handler(update_name, noop_update)
+
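A hedged usage sketch: supposing an update named "refresh_foo_index" were retired (the name is invented for illustration), a store's constructor would clear any queued copies like so:

    # hypothetical, in some store's __init__ alongside the real handlers:
    self.register_noop_background_update("refresh_foo_index")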
def register_background_index_update(self, update_name, index_name,
table, columns, where_clause=None,
unique=False,
@@ -269,7 +327,7 @@ class BackgroundUpdateStore(SQLBaseStore):
# Sqlite doesn't support concurrent creation of indexes.
#
# We don't use partial indices on SQLite as it wasn't introduced
- # until 3.8, and wheezy has 3.7
+ # until 3.8, and wheezy and CentOS 7 have 3.7
#
# We assume that sqlite doesn't give us invalid indices; however
# we may still end up with the index existing but the
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index fc468ea185..77ae10da3d 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -15,13 +15,15 @@
import logging
-from twisted.internet import defer, reactor
+from six import iteritems
-from ._base import Cache
-from . import background_updates
+from twisted.internet import defer
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches import CACHE_SIZE_FACTOR
+from . import background_updates
+from ._base import Cache
logger = logging.getLogger(__name__)
@@ -32,14 +34,14 @@ LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpStore(background_updates.BackgroundUpdateStore):
- def __init__(self, hs):
+ def __init__(self, db_conn, hs):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
keylen=4,
max_entries=50000 * CACHE_SIZE_FACTOR,
)
- super(ClientIpStore, self).__init__(hs)
+ super(ClientIpStore, self).__init__(db_conn, hs)
self.register_background_index_update(
"user_ips_device_index",
@@ -48,17 +50,35 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
columns=["user_id", "device_id", "last_seen"],
)
+ self.register_background_index_update(
+ "user_ips_last_seen_index",
+ index_name="user_ips_last_seen",
+ table="user_ips",
+ columns=["user_id", "last_seen"],
+ )
+
+ self.register_background_index_update(
+ "user_ips_last_seen_only_index",
+ index_name="user_ips_last_seen_only",
+ table="user_ips",
+ columns=["last_seen"],
+ )
+
# (user_id, access_token, ip) -> (user_agent, device_id, last_seen)
self._batch_row_update = {}
self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000
)
- reactor.addSystemEventTrigger("before", "shutdown", self._update_client_ips_batch)
+ self.hs.get_reactor().addSystemEventTrigger(
+ "before", "shutdown", self._update_client_ips_batch
+ )
- def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
- now = int(self._clock.time_msec())
- key = (user.to_string(), access_token, ip)
+ def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id,
+ now=None):
+ if not now:
+ now = int(self._clock.time_msec())
+ key = (user_id, access_token, ip)
try:
last_seen = self.client_ip_last_seen.get(key)
@@ -74,16 +94,22 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
self._batch_row_update[key] = (user_agent, device_id, now)
def _update_client_ips_batch(self):
- to_update = self._batch_row_update
- self._batch_row_update = {}
- return self.runInteraction(
- "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
+ def update():
+ to_update = self._batch_row_update
+ self._batch_row_update = {}
+ return self.runInteraction(
+ "_update_client_ips_batch", self._update_client_ips_batch_txn,
+ to_update,
+ )
+
+ run_as_background_process(
+ "update_client_ips", update,
)
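The shape above (accumulate rows in a dict keyed on the deduplicating tuple, then swap the whole dict out when flushing) is a generic write-behind batch. A standalone sketch with invented names:

    import time

    class BatchedWriter(object):
        """Write-behind batching, as in ClientIpStore above: the newest
        value for each key wins, and flushing swaps the dict out whole."""

        def __init__(self, flush_fn):
            self._pending = {}
            self._flush_fn = flush_fn

        def record(self, key, value):
            # last-writer-wins within one batch window
            self._pending[key] = value

        def flush(self):
            to_update, self._pending = self._pending, {}
            if to_update:
                self._flush_fn(to_update)

    writer = BatchedWriter(lambda batch: print("flushing %d row(s)" % len(batch)))
    now_ms = int(time.time() * 1000)
    writer.record(("@a:hs", "token", "1.2.3.4"), ("UA", "DEV1", now_ms))
    writer.record(("@a:hs", "token", "1.2.3.4"), ("UA", "DEV1", now_ms + 5))
    writer.flush()  # flushing 1 row(s): duplicates collapsed before the write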
def _update_client_ips_batch_txn(self, txn, to_update):
self.database_engine.lock_table(txn, "user_ips")
- for entry in to_update.iteritems():
+ for entry in iteritems(to_update):
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
self._simple_upsert_txn(
@@ -215,5 +241,5 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"user_agent": user_agent,
"last_seen": last_seen,
}
- for (access_token, ip), (user_agent, last_seen) in results.iteritems()
+ for (access_token, ip), (user_agent, last_seen) in iteritems(results)
))
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 0b62b493d5..73646da025 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -14,14 +14,14 @@
# limitations under the License.
import logging
-import ujson
-from twisted.internet import defer
+from canonicaljson import json
-from .background_updates import BackgroundUpdateStore
+from twisted.internet import defer
from synapse.util.caches.expiringcache import ExpiringCache
+from .background_updates import BackgroundUpdateStore
logger = logging.getLogger(__name__)
@@ -29,8 +29,8 @@ logger = logging.getLogger(__name__)
class DeviceInboxStore(BackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
- def __init__(self, hs):
- super(DeviceInboxStore, self).__init__(hs)
+ def __init__(self, db_conn, hs):
+ super(DeviceInboxStore, self).__init__(db_conn, hs)
self.register_background_index_update(
"device_inbox_stream_index",
@@ -85,7 +85,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
)
rows = []
for destination, edu in remote_messages_by_destination.items():
- edu_json = ujson.dumps(edu)
+ edu_json = json.dumps(edu)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
@@ -177,7 +177,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
" WHERE user_id = ?"
)
txn.execute(sql, (user_id,))
- message_json = ujson.dumps(messages_by_device["*"])
+ message_json = json.dumps(messages_by_device["*"])
for row in txn:
# Add the message for all devices for this user on this
# server.
@@ -199,7 +199,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
# Only insert into the local inbox if the device exists on
# this server
device = row[0]
- message_json = ujson.dumps(messages_by_device[device])
+ message_json = json.dumps(messages_by_device[device])
messages_json_for_user[device] = message_json
if messages_json_for_user:
@@ -253,7 +253,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
messages = []
for row in txn:
stream_pos = row[0]
- messages.append(ujson.loads(row[1]))
+ messages.append(json.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
@@ -389,7 +389,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
messages = []
for row in txn:
stream_pos = row[0]
- messages.append(ujson.loads(row[1]))
+ messages.append(json.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index bb27fd1f70..cc3cdf2ebc 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -13,21 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import ujson as json
+
+from six import iteritems, itervalues
+
+from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import StoreError
-from ._base import SQLBaseStore, Cache
-from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from ._base import Cache, SQLBaseStore
logger = logging.getLogger(__name__)
class DeviceStore(SQLBaseStore):
- def __init__(self, hs):
- super(DeviceStore, self).__init__(hs)
+ def __init__(self, db_conn, hs):
+ super(DeviceStore, self).__init__(db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
@@ -245,17 +248,31 @@ class DeviceStore(SQLBaseStore):
def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
content, stream_id):
- self._simple_upsert_txn(
- txn,
- table="device_lists_remote_cache",
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- },
- values={
- "content": json.dumps(content),
- }
- )
+ if content.get("deleted"):
+ self._simple_delete_txn(
+ txn,
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ )
+
+ txn.call_after(
+ self.device_id_exists_cache.invalidate, (user_id, device_id,)
+ )
+ else:
+ self._simple_upsert_txn(
+ txn,
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ values={
+ "content": json.dumps(content),
+ }
+ )
txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,))
txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
@@ -360,10 +377,10 @@ class DeviceStore(SQLBaseStore):
return (now_stream_id, [])
if len(query_map) >= 20:
- now_stream_id = max(stream_id for stream_id in query_map.itervalues())
+ now_stream_id = max(stream_id for stream_id in itervalues(query_map))
devices = self._get_e2e_device_keys_txn(
- txn, query_map.keys(), include_all_devices=True
+ txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
)
prev_sent_id_sql = """
@@ -373,13 +390,13 @@ class DeviceStore(SQLBaseStore):
"""
results = []
- for user_id, user_devices in devices.iteritems():
+ for user_id, user_devices in iteritems(devices):
# The prev_id for the first row is always the last row before
# `from_stream_id`
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
rows = txn.fetchall()
prev_id = rows[0][0]
- for device_id, device in user_devices.iteritems():
+ for device_id, device in iteritems(user_devices):
stream_id = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
@@ -390,12 +407,15 @@ class DeviceStore(SQLBaseStore):
prev_id = stream_id
- key_json = device.get("key_json", None)
- if key_json:
- result["keys"] = json.loads(key_json)
- device_display_name = device.get("device_display_name", None)
- if device_display_name:
- result["device_display_name"] = device_display_name
+ if device is not None:
+ key_json = device.get("key_json", None)
+ if key_json:
+ result["keys"] = json.loads(key_json)
+ device_display_name = device.get("device_display_name", None)
+ if device_display_name:
+ result["device_display_name"] = device_display_name
+ else:
+ result["deleted"] = True
results.append(result)
@@ -483,7 +503,7 @@ class DeviceStore(SQLBaseStore):
if devices:
user_devices = devices[user_id]
results = []
- for device_id, device in user_devices.iteritems():
+ for device_id, device in iteritems(user_devices):
result = {
"device_id": device_id,
}
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 79e7c540ad..808194236a 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -13,15 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached
-
-from synapse.api.errors import SynapseError
+from collections import namedtuple
from twisted.internet import defer
-from collections import namedtuple
+from synapse.api.errors import SynapseError
+from synapse.util.caches.descriptors import cached
+from ._base import SQLBaseStore
RoomAliasMapping = namedtuple(
"RoomAliasMapping",
@@ -29,8 +28,7 @@ RoomAliasMapping = namedtuple(
)
-class DirectoryStore(SQLBaseStore):
-
+class DirectoryWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_association_from_room_alias(self, room_alias):
""" Get's the room_id and server list for a given room_alias
@@ -69,6 +67,28 @@ class DirectoryStore(SQLBaseStore):
RoomAliasMapping(room_id, room_alias.to_string(), servers)
)
+ def get_room_alias_creator(self, room_alias):
+ return self._simple_select_one_onecol(
+ table="room_aliases",
+ keyvalues={
+ "room_alias": room_alias,
+ },
+ retcol="creator",
+ desc="get_room_alias_creator",
+ allow_none=True
+ )
+
+ @cached(max_entries=5000)
+ def get_aliases_for_room(self, room_id):
+ return self._simple_select_onecol(
+ "room_aliases",
+ {"room_id": room_id},
+ "room_alias",
+ desc="get_aliases_for_room",
+ )
+
+
+class DirectoryStore(DirectoryWorkerStore):
@defer.inlineCallbacks
def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
""" Creates an associatin between a room alias and room_id/servers
@@ -116,17 +136,6 @@ class DirectoryStore(SQLBaseStore):
)
defer.returnValue(ret)
- def get_room_alias_creator(self, room_alias):
- return self._simple_select_one_onecol(
- table="room_aliases",
- keyvalues={
- "room_alias": room_alias,
- },
- retcol="creator",
- desc="get_room_alias_creator",
- allow_none=True
- )
-
@defer.inlineCallbacks
def delete_room_alias(self, room_alias):
room_id = yield self.runInteraction(
@@ -135,7 +144,6 @@ class DirectoryStore(SQLBaseStore):
room_alias,
)
- self.get_aliases_for_room.invalidate((room_id,))
defer.returnValue(room_id)
def _delete_room_alias_txn(self, txn, room_alias):
@@ -160,17 +168,12 @@ class DirectoryStore(SQLBaseStore):
(room_alias.to_string(),)
)
- return room_id
-
- @cached(max_entries=5000)
- def get_aliases_for_room(self, room_id):
- return self._simple_select_onecol(
- "room_aliases",
- {"room_id": room_id},
- "room_alias",
- desc="get_aliases_for_room",
+ self._invalidate_cache_and_stream(
+ txn, self.get_aliases_for_room, (room_id,)
)
+ return room_id
+
def update_aliases_for_room(self, old_room_id, new_room_id, creator):
def _update_aliases_for_room_txn(txn):
sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?"
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 2cebb203c6..523b4360c3 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -12,13 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from six import iteritems
+
+from canonicaljson import encode_canonical_json, json
+
from twisted.internet import defer
from synapse.util.caches.descriptors import cached
-from canonicaljson import encode_canonical_json
-import ujson as json
-
from ._base import SQLBaseStore
@@ -63,12 +64,18 @@ class EndToEndKeyStore(SQLBaseStore):
)
@defer.inlineCallbacks
- def get_e2e_device_keys(self, query_list, include_all_devices=False):
+ def get_e2e_device_keys(
+ self, query_list, include_all_devices=False,
+ include_deleted_devices=False,
+ ):
"""Fetch a list of device keys.
Args:
query_list(list): List of pairs of user_ids and device_ids.
include_all_devices (bool): whether to include entries for devices
that don't have device keys
+ include_deleted_devices (bool): whether to include null entries for
+ devices which no longer exist (but were in the query_list).
+ This option only takes effect if include_all_devices is true.
Returns:
Dict mapping from user-id to dict mapping from device_id to
dict containing "key_json", "device_display_name".
@@ -78,19 +85,28 @@ class EndToEndKeyStore(SQLBaseStore):
results = yield self.runInteraction(
"get_e2e_device_keys", self._get_e2e_device_keys_txn,
- query_list, include_all_devices,
+ query_list, include_all_devices, include_deleted_devices,
)
- for user_id, device_keys in results.iteritems():
- for device_id, device_info in device_keys.iteritems():
+ for user_id, device_keys in iteritems(results):
+ for device_id, device_info in iteritems(device_keys):
device_info["keys"] = json.loads(device_info.pop("key_json"))
defer.returnValue(results)
- def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):
+ def _get_e2e_device_keys_txn(
+ self, txn, query_list, include_all_devices=False,
+ include_deleted_devices=False,
+ ):
query_clauses = []
query_params = []
+ if include_all_devices is False:
+ include_deleted_devices = False
+
+ if include_deleted_devices:
+ deleted_devices = set(query_list)
+
for (user_id, device_id) in query_list:
query_clause = "user_id = ?"
query_params.append(user_id)
@@ -118,8 +134,14 @@ class EndToEndKeyStore(SQLBaseStore):
result = {}
for row in rows:
+ if include_deleted_devices:
+ deleted_devices.remove((row["user_id"], row["device_id"]))
result.setdefault(row["user_id"], {})[row["device_id"]] = row
+ if include_deleted_devices:
+ for user_id, device_id in deleted_devices:
+ result.setdefault(user_id, {})[device_id] = None
+
return result
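The `deleted_devices` bookkeeping above is a set difference: seed a set with every queried pair, discard each pair the query actually returns, and whatever remains must have been deleted. As a standalone sketch:

    def find_deleted(query_list, rows_returned):
        # query_list and rows_returned are (user_id, device_id) pairs
        deleted = set(query_list)
        for pair in rows_returned:
            # the real code uses .remove(); rows always come from
            # query_list, so either spelling is safe there
            deleted.discard(pair)
        return deleted

    print(find_deleted(
        [("@a:hs", "D1"), ("@a:hs", "D2")],
        [("@a:hs", "D1")],
    ))  # {('@a:hs', 'D2')}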
@defer.inlineCallbacks
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 338b495611..e2f9de8451 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -13,13 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import importlib
+import platform
+
from ._base import IncorrectDatabaseSetup
from .postgres import PostgresEngine
from .sqlite3 import Sqlite3Engine
-import importlib
-
-
SUPPORTED_MODULE = {
"sqlite3": Sqlite3Engine,
"psycopg2": PostgresEngine,
@@ -31,6 +31,10 @@ def create_engine(database_config):
engine_class = SUPPORTED_MODULE.get(name, None)
if engine_class:
+ # pypy requires psycopg2cffi rather than psycopg2
+ if (name == "psycopg2" and
+ platform.python_implementation() == "PyPy"):
+ name = "psycopg2cffi"
module = importlib.import_module(name)
return engine_class(module, database_config)
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index a6ae79dfad..8a0386c1a4 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -62,3 +62,9 @@ class PostgresEngine(object):
def lock_table(self, txn, table):
txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
+
+ def get_next_state_group_id(self, txn):
+ """Returns an int that can be used as a new state_group ID
+ """
+ txn.execute("SELECT nextval('state_group_id_seq')")
+ return txn.fetchone()[0]
diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py
index 755c9a1f07..19949fc474 100644
--- a/synapse/storage/engines/sqlite3.py
+++ b/synapse/storage/engines/sqlite3.py
@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.prepare_database import prepare_database
-
import struct
+import threading
+
+from synapse.storage.prepare_database import prepare_database
class Sqlite3Engine(object):
@@ -24,6 +25,11 @@ class Sqlite3Engine(object):
def __init__(self, database_module, database_config):
self.module = database_module
+ # The current max state_group, or None if we haven't looked
+ # in the DB yet.
+ self._current_state_group_id = None
+ self._current_state_group_id_lock = threading.Lock()
+
def check_database(self, txn):
pass
@@ -43,6 +49,19 @@ class Sqlite3Engine(object):
def lock_table(self, txn, table):
return
+ def get_next_state_group_id(self, txn):
+ """Returns an int that can be used as a new state_group ID
+ """
+ # We do application-level locking here: if we're using sqlite then
+ # we are a single-process Synapse.
+ with self._current_state_group_id_lock:
+ if self._current_state_group_id is None:
+ txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
+ self._current_state_group_id = txn.fetchone()[0]
+
+ self._current_state_group_id += 1
+ return self._current_state_group_id
+
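A standalone sketch of the lazily-initialised, lock-guarded counter (single-process only, per the comment above):

    import sqlite3
    import threading

    class StateGroupIdAllocator(object):
        def __init__(self):
            self._current = None
            self._lock = threading.Lock()

        def next_id(self, txn):
            with self._lock:
                if self._current is None:
                    # first call: read the current max from the DB
                    txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
                    self._current = txn.fetchone()[0]
                self._current += 1
                return self._current

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE state_groups (id INTEGER PRIMARY KEY)")
    alloc = StateGroupIdAllocator()
    print(alloc.next_id(conn.cursor()))  # 1
    print(alloc.next_id(conn.cursor()))  # 2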
# Following functions taken from: https://github.com/coleifer/peewee
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index e8133de2fa..8d366d1b91 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -12,45 +12,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import random
-from twisted.internet import defer
+from six.moves import range
+from six.moves.queue import Empty, PriorityQueue
-from ._base import SQLBaseStore
-from synapse.api.errors import StoreError
-from synapse.util.caches.descriptors import cached
from unpaddedbase64 import encode_base64
-import logging
-from Queue import PriorityQueue, Empty
+from twisted.internet import defer
+from synapse.api.errors import StoreError
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.events import EventsWorkerStore
+from synapse.storage.signatures import SignatureWorkerStore
+from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
-class EventFederationStore(SQLBaseStore):
- """ Responsible for storing and serving up the various graphs associated
- with an event. Including the main event graph and the auth chains for an
- event.
-
- Also has methods for getting the front (latest) and back (oldest) edges
- of the event graphs. These are used to generate the parents for new events
- and backfilling from another server respectively.
- """
-
- EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
-
- def __init__(self, hs):
- super(EventFederationStore, self).__init__(hs)
-
- self.register_background_update_handler(
- self.EVENT_AUTH_STATE_ONLY,
- self._background_delete_non_state_event_auth,
- )
-
- hs.get_clock().looping_call(
- self._delete_old_forward_extrem_cache, 60 * 60 * 1000
- )
-
+class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
+ SQLBaseStore):
def get_auth_chain(self, event_ids, include_given=False):
"""Get auth events for given event_ids. The events *must* be state events.
@@ -97,7 +79,7 @@ class EventFederationStore(SQLBaseStore):
front_list = list(front)
chunks = [
front_list[x:x + 100]
- for x in xrange(0, len(front), 100)
+ for x in range(0, len(front), 100)
]
for chunk in chunks:
txn.execute(
@@ -152,7 +134,47 @@ class EventFederationStore(SQLBaseStore):
retcol="event_id",
)
+ @defer.inlineCallbacks
+ def get_prev_events_for_room(self, room_id):
+ """
+ Gets a subset of the current forward extremities in the given room.
+
+ Limits the result to 10 extremities, so that we can avoid creating
+ events which refer to hundreds of prev_events.
+
+ Args:
+ room_id (str): room_id
+
+ Returns:
+ Deferred[list[(str, dict[str, str], int)]]
+ for each event, a tuple of (event_id, hashes, depth)
+ where *hashes* is a map from algorithm to hash.
+ """
+ res = yield self.get_latest_event_ids_and_hashes_in_room(room_id)
+ if len(res) > 10:
+ # Sort by reverse depth, so we point to the most recent.
+ res.sort(key=lambda a: -a[2])
+
+ # we use half of the limit for the actual most recent events, and
+ # the other half to randomly point to some of the older events, to
+ # make sure that we don't completely ignore the older events.
+ res = res[0:5] + random.sample(res[5:], 5)
+
+ defer.returnValue(res)
+
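A standalone sketch of the selection rule (sort newest-first by depth, keep the five most recent, and sample five of the remainder):

    import random

    def pick_prev_events(extremities, limit=10):
        # extremities: list of (event_id, hashes, depth) tuples
        if len(extremities) <= limit:
            return extremities
        extremities = sorted(extremities, key=lambda e: -e[2])
        half = limit // 2
        # the newest half verbatim, plus a random sample of the older ones,
        # so that old forward extremities are not ignored forever
        return extremities[:half] + random.sample(extremities[half:], half)

    exts = [("$event%d" % i, {}, i) for i in range(20)]
    assert len(pick_prev_events(exts)) == 10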
def get_latest_event_ids_and_hashes_in_room(self, room_id):
+ """
+ Gets the current forward extremities in the given room
+
+ Args:
+ room_id (str): room_id
+
+ Returns:
+ Deferred[list[(str, dict[str, str], int)]]
+ for each event, a tuple of (event_id, hashes, depth)
+ where *hashes* is a map from algorithm to hash.
+ """
+
return self.runInteraction(
"get_latest_event_ids_and_hashes_in_room",
self._get_latest_event_ids_and_hashes_in_room,
@@ -201,22 +223,6 @@ class EventFederationStore(SQLBaseStore):
room_id,
)
- @defer.inlineCallbacks
- def get_max_depth_of_events(self, event_ids):
- sql = (
- "SELECT MAX(depth) FROM events WHERE event_id IN (%s)"
- ) % (",".join(["?"] * len(event_ids)),)
-
- rows = yield self._execute(
- "get_max_depth_of_events", None,
- sql, *event_ids
- )
-
- if rows:
- defer.returnValue(rows[0][0])
- else:
- defer.returnValue(1)
-
def _get_min_depth_interaction(self, txn, room_id):
min_depth = self._simple_select_one_onecol_txn(
txn,
@@ -228,88 +234,6 @@ class EventFederationStore(SQLBaseStore):
return int(min_depth) if min_depth is not None else None
- def _update_min_depth_for_room_txn(self, txn, room_id, depth):
- min_depth = self._get_min_depth_interaction(txn, room_id)
-
- if min_depth and depth >= min_depth:
- return
-
- self._simple_upsert_txn(
- txn,
- table="room_depth",
- keyvalues={
- "room_id": room_id,
- },
- values={
- "min_depth": depth,
- },
- )
-
- def _handle_mult_prev_events(self, txn, events):
- """
- For the given event, update the event edges table and forward and
- backward extremities tables.
- """
- self._simple_insert_many_txn(
- txn,
- table="event_edges",
- values=[
- {
- "event_id": ev.event_id,
- "prev_event_id": e_id,
- "room_id": ev.room_id,
- "is_state": False,
- }
- for ev in events
- for e_id, _ in ev.prev_events
- ],
- )
-
- self._update_backward_extremeties(txn, events)
-
- def _update_backward_extremeties(self, txn, events):
- """Updates the event_backward_extremities tables based on the new/updated
- events being persisted.
-
- This is called for new events *and* for events that were outliers, but
- are now being persisted as non-outliers.
-
- Forward extremities are handled when we first start persisting the events.
- """
- events_by_room = {}
- for ev in events:
- events_by_room.setdefault(ev.room_id, []).append(ev)
-
- query = (
- "INSERT INTO event_backward_extremities (event_id, room_id)"
- " SELECT ?, ? WHERE NOT EXISTS ("
- " SELECT 1 FROM event_backward_extremities"
- " WHERE event_id = ? AND room_id = ?"
- " )"
- " AND NOT EXISTS ("
- " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
- " AND outlier = ?"
- " )"
- )
-
- txn.executemany(query, [
- (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
- for ev in events for e_id, _ in ev.prev_events
- if not ev.internal_metadata.is_outlier()
- ])
-
- query = (
- "DELETE FROM event_backward_extremities"
- " WHERE event_id = ? AND room_id = ?"
- )
- txn.executemany(
- query,
- [
- (ev.event_id, ev.room_id) for ev in events
- if not ev.internal_metadata.is_outlier()
- ]
- )
-
def get_forward_extremeties_for_room(self, room_id, stream_ordering):
"""For a given room_id and stream_ordering, return the forward
extremities of the room at that point in "time".
@@ -371,28 +295,6 @@ class EventFederationStore(SQLBaseStore):
get_forward_extremeties_for_room_txn
)
- def _delete_old_forward_extrem_cache(self):
- def _delete_old_forward_extrem_cache_txn(txn):
- # Delete entries older than a month, while making sure we don't delete
- # the only entries for a room.
- sql = ("""
- DELETE FROM stream_ordering_to_exterm
- WHERE
- room_id IN (
- SELECT room_id
- FROM stream_ordering_to_exterm
- WHERE stream_ordering > ?
- ) AND stream_ordering < ?
- """)
- txn.execute(
- sql,
- (self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
- )
- return self.runInteraction(
- "_delete_old_forward_extrem_cache",
- _delete_old_forward_extrem_cache_txn
- )
-
def get_backfill_events(self, room_id, event_list, limit):
"""Get a list of Events for a given topic that occurred before (and
including) the events in event_list. Return a list of max size `limit`
@@ -522,6 +424,135 @@ class EventFederationStore(SQLBaseStore):
return event_results
+
+class EventFederationStore(EventFederationWorkerStore):
+ """ Responsible for storing and serving up the various graphs associated
+ with an event. Including the main event graph and the auth chains for an
+ event.
+
+ Also has methods for getting the front (latest) and back (oldest) edges
+ of the event graphs. These are used to generate the parents for new events
+ and backfilling from another server respectively.
+ """
+
+ EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
+
+ def __init__(self, db_conn, hs):
+ super(EventFederationStore, self).__init__(db_conn, hs)
+
+ self.register_background_update_handler(
+ self.EVENT_AUTH_STATE_ONLY,
+ self._background_delete_non_state_event_auth,
+ )
+
+ hs.get_clock().looping_call(
+ self._delete_old_forward_extrem_cache, 60 * 60 * 1000
+ )
+
+ def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+ min_depth = self._get_min_depth_interaction(txn, room_id)
+
+ if min_depth and depth >= min_depth:
+ return
+
+ self._simple_upsert_txn(
+ txn,
+ table="room_depth",
+ keyvalues={
+ "room_id": room_id,
+ },
+ values={
+ "min_depth": depth,
+ },
+ )
+
+ def _handle_mult_prev_events(self, txn, events):
+ """
+ For the given event, update the event edges table and forward and
+ backward extremities tables.
+ """
+ self._simple_insert_many_txn(
+ txn,
+ table="event_edges",
+ values=[
+ {
+ "event_id": ev.event_id,
+ "prev_event_id": e_id,
+ "room_id": ev.room_id,
+ "is_state": False,
+ }
+ for ev in events
+ for e_id, _ in ev.prev_events
+ ],
+ )
+
+ self._update_backward_extremeties(txn, events)
+
+ def _update_backward_extremeties(self, txn, events):
+ """Updates the event_backward_extremities tables based on the new/updated
+ events being persisted.
+
+ This is called for new events *and* for events that were outliers, but
+ are now being persisted as non-outliers.
+
+ Forward extremities are handled when we first start persisting the events.
+ """
+ events_by_room = {}
+ for ev in events:
+ events_by_room.setdefault(ev.room_id, []).append(ev)
+
+ query = (
+ "INSERT INTO event_backward_extremities (event_id, room_id)"
+ " SELECT ?, ? WHERE NOT EXISTS ("
+ " SELECT 1 FROM event_backward_extremities"
+ " WHERE event_id = ? AND room_id = ?"
+ " )"
+ " AND NOT EXISTS ("
+ " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
+ " AND outlier = ?"
+ " )"
+ )
+
+ txn.executemany(query, [
+ (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
+ for ev in events for e_id, _ in ev.prev_events
+ if not ev.internal_metadata.is_outlier()
+ ])
+
+ query = (
+ "DELETE FROM event_backward_extremities"
+ " WHERE event_id = ? AND room_id = ?"
+ )
+ txn.executemany(
+ query,
+ [
+ (ev.event_id, ev.room_id) for ev in events
+ if not ev.internal_metadata.is_outlier()
+ ]
+ )
+
+ def _delete_old_forward_extrem_cache(self):
+ def _delete_old_forward_extrem_cache_txn(txn):
+ # Delete entries older than a month, while making sure we don't delete
+ # the only entries for a room.
+ sql = ("""
+ DELETE FROM stream_ordering_to_exterm
+ WHERE
+ room_id IN (
+ SELECT room_id
+ FROM stream_ordering_to_exterm
+ WHERE stream_ordering > ?
+ ) AND stream_ordering < ?
+ """)
+ txn.execute(
+ sql,
+ (self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
+ )
+ return self.runInteraction(
+ "_delete_old_forward_extrem_cache",
+ _delete_old_forward_extrem_cache_txn
+ )
+
def clean_room_for_join(self, room_id):
return self.runInteraction(
"clean_room_for_join",
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index d6d8723b4a..29b511ae5e 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,15 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
+import logging
+
+from six import iteritems
+
+from canonicaljson import json
+
from twisted.internet import defer
-from synapse.util.async import sleep
-from synapse.util.caches.descriptors import cachedInlineCallbacks
-from synapse.types import RoomStreamToken
-from .stream import lower_bound
-import logging
-import ujson as json
+from synapse.storage._base import LoggingTransaction, SQLBaseStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks
logger = logging.getLogger(__name__)
@@ -62,59 +64,29 @@ def _deserialize_action(actions, is_highlight):
return DEFAULT_NOTIF_ACTION
-class EventPushActionsStore(SQLBaseStore):
- EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
-
- def __init__(self, hs):
- super(EventPushActionsStore, self).__init__(hs)
+class EventPushActionsWorkerStore(SQLBaseStore):
+ def __init__(self, db_conn, hs):
+ super(EventPushActionsWorkerStore, self).__init__(db_conn, hs)
- self.register_background_index_update(
- self.EPA_HIGHLIGHT_INDEX,
- index_name="event_push_actions_u_highlight",
- table="event_push_actions",
- columns=["user_id", "stream_ordering"],
- )
+ # These get correctly set by _find_stream_orderings_for_times_txn
+ self.stream_ordering_month_ago = None
+ self.stream_ordering_day_ago = None
- self.register_background_index_update(
- "event_push_actions_highlights_index",
- index_name="event_push_actions_highlights_index",
- table="event_push_actions",
- columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
- where_clause="highlight=1"
+ cur = LoggingTransaction(
+ db_conn.cursor(),
+ name="_find_stream_orderings_for_times_txn",
+ database_engine=self.database_engine,
+ after_callbacks=[],
+ exception_callbacks=[],
)
+ self._find_stream_orderings_for_times_txn(cur)
+ cur.close()
- self._doing_notif_rotation = False
- self._rotate_notif_loop = self._clock.looping_call(
- self._rotate_notifs, 30 * 60 * 1000
+ self.find_stream_orderings_looping_call = self._clock.looping_call(
+ self._find_stream_orderings_for_times, 10 * 60 * 1000
)
-
- def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
- """
- Args:
- event: the event set actions for
- tuples: list of tuples of (user_id, actions)
- """
- values = []
- for uid, actions in tuples:
- is_highlight = 1 if _action_has_highlight(actions) else 0
-
- values.append({
- 'room_id': event.room_id,
- 'event_id': event.event_id,
- 'user_id': uid,
- 'actions': _serialize_action(actions, is_highlight),
- 'stream_ordering': event.internal_metadata.stream_ordering,
- 'topological_ordering': event.depth,
- 'notif': 1,
- 'highlight': is_highlight,
- })
-
- for uid, __ in tuples:
- txn.call_after(
- self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
- (event.room_id, uid)
- )
- self._simple_insert_many_txn(txn, "event_push_actions", values)
+ self._rotate_delay = 3
+ self._rotate_count = 10000
@cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
def get_unread_event_push_actions_by_room_for_user(
@@ -130,7 +102,7 @@ class EventPushActionsStore(SQLBaseStore):
def _get_unread_counts_by_receipt_txn(self, txn, room_id, user_id,
last_read_event_id):
sql = (
- "SELECT stream_ordering, topological_ordering"
+ "SELECT stream_ordering"
" FROM events"
" WHERE room_id = ? AND event_id = ?"
)
@@ -142,17 +114,12 @@ class EventPushActionsStore(SQLBaseStore):
return {"notify_count": 0, "highlight_count": 0}
stream_ordering = results[0][0]
- topological_ordering = results[0][1]
return self._get_unread_counts_by_pos_txn(
- txn, room_id, user_id, topological_ordering, stream_ordering
+ txn, room_id, user_id, stream_ordering
)
- def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, topological_ordering,
- stream_ordering):
- token = RoomStreamToken(
- topological_ordering, stream_ordering
- )
+ def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
# First get number of notifications.
# We don't need to put a notif=1 clause as all rows always have
@@ -163,10 +130,10 @@ class EventPushActionsStore(SQLBaseStore):
" WHERE"
" user_id = ?"
" AND room_id = ?"
- " AND %s"
- ) % (lower_bound(token, self.database_engine, inclusive=False),)
+ " AND stream_ordering > ?"
+ )
- txn.execute(sql, (user_id, room_id))
+ txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone()
notify_count = row[0] if row else 0
@@ -186,10 +153,10 @@ class EventPushActionsStore(SQLBaseStore):
" highlight = 1"
" AND user_id = ?"
" AND room_id = ?"
- " AND %s"
- ) % (lower_bound(token, self.database_engine, inclusive=False),)
+ " AND stream_ordering > ?"
+ )
- txn.execute(sql, (user_id, room_id))
+ txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone()
highlight_count = row[0] if row else 0
@@ -240,7 +207,6 @@ class EventPushActionsStore(SQLBaseStore):
" ep.highlight "
" FROM ("
" SELECT room_id,"
- " MAX(topological_ordering) as topological_ordering,"
" MAX(stream_ordering) as stream_ordering"
" FROM events"
" INNER JOIN receipts_linearized USING (room_id, event_id)"
@@ -250,13 +216,7 @@ class EventPushActionsStore(SQLBaseStore):
" event_push_actions AS ep"
" WHERE"
" ep.room_id = rl.room_id"
- " AND ("
- " ep.topological_ordering > rl.topological_ordering"
- " OR ("
- " ep.topological_ordering = rl.topological_ordering"
- " AND ep.stream_ordering > rl.stream_ordering"
- " )"
- " )"
+ " AND ep.stream_ordering > rl.stream_ordering"
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
@@ -349,7 +309,6 @@ class EventPushActionsStore(SQLBaseStore):
" ep.highlight, e.received_ts"
" FROM ("
" SELECT room_id,"
- " MAX(topological_ordering) as topological_ordering,"
" MAX(stream_ordering) as stream_ordering"
" FROM events"
" INNER JOIN receipts_linearized USING (room_id, event_id)"
@@ -360,13 +319,7 @@ class EventPushActionsStore(SQLBaseStore):
" INNER JOIN events AS e USING (room_id, event_id)"
" WHERE"
" ep.room_id = rl.room_id"
- " AND ("
- " ep.topological_ordering > rl.topological_ordering"
- " OR ("
- " ep.topological_ordering = rl.topological_ordering"
- " AND ep.stream_ordering > rl.stream_ordering"
- " )"
- " )"
+ " AND ep.stream_ordering > rl.stream_ordering"
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
@@ -432,6 +385,290 @@ class EventPushActionsStore(SQLBaseStore):
# Now return the first `limit`
defer.returnValue(notifs[:limit])
+ def add_push_actions_to_staging(self, event_id, user_id_actions):
+ """Add the push actions for the event to the push action staging area.
+
+ Args:
+ event_id (str)
+            user_id_actions (dict[str, list[dict|str]]): A dictionary mapping
+ user_id to list of push actions, where an action can either be
+ a string or dict.
+
+ Returns:
+ Deferred
+ """
+
+ if not user_id_actions:
+ return
+
+ # This is a helper function for generating the necessary tuple that
+        # can be used to insert into the `event_push_actions_staging` table.
+ def _gen_entry(user_id, actions):
+ is_highlight = 1 if _action_has_highlight(actions) else 0
+ return (
+ event_id, # event_id column
+ user_id, # user_id column
+ _serialize_action(actions, is_highlight), # actions column
+ 1, # notif column
+ is_highlight, # highlight column
+ )
+
+ def _add_push_actions_to_staging_txn(txn):
+ # We don't use _simple_insert_many here to avoid the overhead
+ # of generating lists of dicts.
+
+ sql = """
+ INSERT INTO event_push_actions_staging
+ (event_id, user_id, actions, notif, highlight)
+ VALUES (?, ?, ?, ?, ?)
+ """
+
+ txn.executemany(sql, (
+ _gen_entry(user_id, actions)
+ for user_id, actions in iteritems(user_id_actions)
+ ))
+
+ return self.runInteraction(
+ "add_push_actions_to_staging", _add_push_actions_to_staging_txn
+ )
+
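As a usage sketch (the event and user IDs here are invented), a caller inside an @defer.inlineCallbacks method would stage the actions computed by push rule evaluation like so:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def stage_push_actions(store):
        # Hypothetical caller; IDs are made up for illustration.
        yield store.add_push_actions_to_staging(
            "$1533000000000abcde:example.com",
            {
                "@alice:example.com": ["notify", {"set_tweak": "highlight"}],
                "@bob:example.com": ["notify"],
            },
        )

Note that _add_push_actions_to_staging_txn hands executemany a generator rather than a list, which the underlying drivers accept; no intermediate list of dicts is ever built.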
+ @defer.inlineCallbacks
+ def remove_push_actions_from_staging(self, event_id):
+ """Called if we failed to persist the event to ensure that stale push
+ actions don't build up in the DB
+
+ Args:
+ event_id (str)
+ """
+
+ try:
+ res = yield self._simple_delete(
+ table="event_push_actions_staging",
+ keyvalues={
+ "event_id": event_id,
+ },
+ desc="remove_push_actions_from_staging",
+ )
+ defer.returnValue(res)
+ except Exception:
+ # this method is called from an exception handler, so propagating
+ # another exception here really isn't helpful - there's nothing
+ # the caller can do about it. Just log the exception and move on.
+ logger.exception(
+ "Error removing push actions after event persistence failure",
+ )
+
+ @defer.inlineCallbacks
+ def _find_stream_orderings_for_times(self):
+ yield self.runInteraction(
+ "_find_stream_orderings_for_times",
+ self._find_stream_orderings_for_times_txn
+ )
+
+ def _find_stream_orderings_for_times_txn(self, txn):
+ logger.info("Searching for stream ordering 1 month ago")
+ self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
+ txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
+ )
+ logger.info(
+ "Found stream ordering 1 month ago: it's %d",
+ self.stream_ordering_month_ago
+ )
+ logger.info("Searching for stream ordering 1 day ago")
+ self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn(
+ txn, self._clock.time_msec() - 24 * 60 * 60 * 1000
+ )
+ logger.info(
+ "Found stream ordering 1 day ago: it's %d",
+ self.stream_ordering_day_ago
+ )
+
+ def find_first_stream_ordering_after_ts(self, ts):
+ """Gets the stream ordering corresponding to a given timestamp.
+
+ Specifically, finds the stream_ordering of the first event that was
+ received on or after the timestamp. This is done by a binary search on
+        the events table, since there is no index on received_ts; it is
+        therefore relatively slow.
+
+ Args:
+ ts (int): timestamp in millis
+
+ Returns:
+ Deferred[int]: stream ordering of the first event received on/after
+ the timestamp
+ """
+ return self.runInteraction(
+ "_find_first_stream_ordering_after_ts_txn",
+ self._find_first_stream_ordering_after_ts_txn,
+ ts,
+ )
+
+ @staticmethod
+ def _find_first_stream_ordering_after_ts_txn(txn, ts):
+ """
+ Find the stream_ordering of the first event that was received on or
+ after a given timestamp. This is relatively slow as there is no index
+        on received_ts, but the result can then be used to delete push
+        actions from before that point.
+
+ received_ts must necessarily be in the same order as stream_ordering
+ and stream_ordering is indexed, so we manually binary search using
+ stream_ordering
+
+ Args:
+ txn (twisted.enterprise.adbapi.Transaction):
+ ts (int): timestamp to search for
+
+ Returns:
+ int: stream ordering
+ """
+ txn.execute("SELECT MAX(stream_ordering) FROM events")
+ max_stream_ordering = txn.fetchone()[0]
+
+ if max_stream_ordering is None:
+ return 0
+
+ # We want the first stream_ordering in which received_ts is greater
+ # than or equal to ts. Call this point X.
+ #
+ # We maintain the invariants:
+ #
+ # range_start <= X <= range_end
+ #
+ range_start = 0
+ range_end = max_stream_ordering + 1
+
+ # Given a stream_ordering, look up the timestamp at that
+ # stream_ordering.
+ #
+ # The array may be sparse (we may be missing some stream_orderings).
+ # We treat the gaps as the same as having the same value as the
+ # preceding entry, because we will pick the lowest stream_ordering
+ # which satisfies our requirement of received_ts >= ts.
+ #
+ # For example, if our array of events indexed by stream_ordering is
+ # [10, <none>, 20], we should treat this as being equivalent to
+ # [10, 10, 20].
+ #
+ sql = (
+ "SELECT received_ts FROM events"
+ " WHERE stream_ordering <= ?"
+ " ORDER BY stream_ordering DESC"
+ " LIMIT 1"
+ )
+
+ while range_end - range_start > 0:
+ middle = (range_end + range_start) // 2
+ txn.execute(sql, (middle,))
+ row = txn.fetchone()
+ if row is None:
+ # no rows with stream_ordering<=middle
+ range_start = middle + 1
+ continue
+
+ middle_ts = row[0]
+ if ts > middle_ts:
+ # we got a timestamp lower than the one we were looking for.
+ # definitely need to look higher: X > middle.
+ range_start = middle + 1
+ else:
+ # we got a timestamp higher than (or the same as) the one we
+ # were looking for. We aren't yet sure about the point we
+ # looked up, but we can be sure that X <= middle.
+ range_end = middle
+
+ return range_end
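Because stream orderings can be sparse, the binary search treats a missing ordering as inheriting the timestamp of its predecessor, which is exactly what the "stream_ordering <= ? ORDER BY stream_ordering DESC LIMIT 1" lookup provides. A pure-Python rendering of the same loop, searching an in-memory dict instead of the events table:

    def first_stream_ordering_after_ts(received_ts_by_ordering, ts):
        # Sketch of the binary search above, over a (possibly sparse)
        # dict of stream_ordering -> received_ts.
        if not received_ts_by_ordering:
            return 0
        range_start = 0
        range_end = max(received_ts_by_ordering) + 1
        while range_end - range_start > 0:
            middle = (range_end + range_start) // 2
            # Emulate "SELECT received_ts ... WHERE stream_ordering <= ?
            #          ORDER BY stream_ordering DESC LIMIT 1".
            candidates = [so for so in received_ts_by_ordering if so <= middle]
            if not candidates:
                # no rows with stream_ordering <= middle
                range_start = middle + 1
                continue
            middle_ts = received_ts_by_ordering[max(candidates)]
            if ts > middle_ts:
                range_start = middle + 1
            else:
                range_end = middle
        return range_end

    # Orderings 1 and 3 exist; 2 is a gap and behaves as if it had ts=10.
    assert first_stream_ordering_after_ts({1: 10, 3: 20}, 15) == 3
    assert first_stream_ordering_after_ts({1: 10, 3: 20}, 10) == 1
    assert first_stream_ordering_after_ts({}, 5) == 0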
+
+
+class EventPushActionsStore(EventPushActionsWorkerStore):
+ EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
+
+ def __init__(self, db_conn, hs):
+ super(EventPushActionsStore, self).__init__(db_conn, hs)
+
+ self.register_background_index_update(
+ self.EPA_HIGHLIGHT_INDEX,
+ index_name="event_push_actions_u_highlight",
+ table="event_push_actions",
+ columns=["user_id", "stream_ordering"],
+ )
+
+ self.register_background_index_update(
+ "event_push_actions_highlights_index",
+ index_name="event_push_actions_highlights_index",
+ table="event_push_actions",
+ columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
+ where_clause="highlight=1"
+ )
+
+ self._doing_notif_rotation = False
+ self._rotate_notif_loop = self._clock.looping_call(
+ self._rotate_notifs, 30 * 60 * 1000
+ )
+
+ def _set_push_actions_for_event_and_users_txn(self, txn, events_and_contexts,
+ all_events_and_contexts):
+ """Handles moving push actions from staging table to main
+ event_push_actions table for all events in `events_and_contexts`.
+
+ Also ensures that all events in `all_events_and_contexts` are removed
+ from the push action staging area.
+
+ Args:
+ events_and_contexts (list[(EventBase, EventContext)]): events
+ we are persisting
+ all_events_and_contexts (list[(EventBase, EventContext)]): all
+ events that we were going to persist. This includes events
+ we've already persisted, etc, that wouldn't appear in
+                events_and_contexts.
+ """
+
+ sql = """
+ INSERT INTO event_push_actions (
+ room_id, event_id, user_id, actions, stream_ordering,
+ topological_ordering, notif, highlight
+ )
+ SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
+ FROM event_push_actions_staging
+ WHERE event_id = ?
+ """
+
+ if events_and_contexts:
+ txn.executemany(sql, (
+ (
+ event.room_id, event.internal_metadata.stream_ordering,
+ event.depth, event.event_id,
+ )
+ for event, _ in events_and_contexts
+ ))
+
+ for event, _ in events_and_contexts:
+ user_ids = self._simple_select_onecol_txn(
+ txn,
+ table="event_push_actions_staging",
+ keyvalues={
+ "event_id": event.event_id,
+ },
+ retcol="user_id",
+ )
+
+ for uid in user_ids:
+ txn.call_after(
+ self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+ (event.room_id, uid,)
+ )
+
+ # Now we delete the staging area for *all* events that were being
+ # persisted.
+ txn.executemany(
+ "DELETE FROM event_push_actions_staging WHERE event_id = ?",
+ (
+ (event.event_id,)
+ for event, _ in all_events_and_contexts
+ )
+ )
+
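The INSERT ... SELECT moves each event's staged rows into event_push_actions in one statement, binding room_id, stream_ordering and depth (which are only known at persist time) as constants and copying the remaining columns from the staging row. The trick in isolation, against sqlite3 with a simplified schema:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE event_push_actions_staging "
                 "(event_id TEXT, user_id TEXT, actions TEXT, notif INT, highlight INT)")
    conn.execute("CREATE TABLE event_push_actions "
                 "(room_id TEXT, event_id TEXT, user_id TEXT, actions TEXT,"
                 " stream_ordering INT, topological_ordering INT, notif INT, highlight INT)")
    conn.execute("INSERT INTO event_push_actions_staging VALUES (?, ?, ?, ?, ?)",
                 ("$e", "@u:hs", '["notify"]', 1, 0))

    sql = """
        INSERT INTO event_push_actions (
            room_id, event_id, user_id, actions, stream_ordering,
            topological_ordering, notif, highlight
        )
        SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
        FROM event_push_actions_staging
        WHERE event_id = ?
    """
    # room_id, stream_ordering and depth are bound as constants; the rest is
    # copied from the staged row, which is then cleaned up.
    conn.execute(sql, ("!room", 42, 7, "$e"))
    conn.execute("DELETE FROM event_push_actions_staging WHERE event_id = ?", ("$e",))
    assert conn.execute(
        "SELECT room_id, stream_ordering FROM event_push_actions"
    ).fetchall() == [("!room", 42)]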
@defer.inlineCallbacks
def get_push_actions_for_user(self, user_id, before=None, limit=50,
only_highlight=False):
@@ -509,10 +746,10 @@ class EventPushActionsStore(SQLBaseStore):
)
def _remove_old_push_actions_before_txn(self, txn, room_id, user_id,
- topological_ordering, stream_ordering):
+ stream_ordering):
"""
Purges old push actions for a user and room before a given
- topological_ordering.
+ stream_ordering.
         We however keep a month's worth of highlighted notifications, so that
users can still get a list of recent highlights.
@@ -521,7 +758,7 @@ class EventPushActionsStore(SQLBaseStore):
             txn: The transaction
room_id: Room ID to delete from
user_id: user ID to delete for
- topological_ordering: The lowest topological ordering which will
+ stream_ordering: The lowest stream ordering which will
not be deleted.
"""
txn.call_after(
@@ -540,9 +777,9 @@ class EventPushActionsStore(SQLBaseStore):
txn.execute(
"DELETE FROM event_push_actions "
" WHERE user_id = ? AND room_id = ? AND "
- " topological_ordering <= ?"
+ " stream_ordering <= ?"
" AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
- (user_id, room_id, topological_ordering, self.stream_ordering_month_ago)
+ (user_id, room_id, stream_ordering, self.stream_ordering_month_ago)
)
txn.execute("""
@@ -551,69 +788,6 @@ class EventPushActionsStore(SQLBaseStore):
""", (room_id, user_id, stream_ordering))
@defer.inlineCallbacks
- def _find_stream_orderings_for_times(self):
- yield self.runInteraction(
- "_find_stream_orderings_for_times",
- self._find_stream_orderings_for_times_txn
- )
-
- def _find_stream_orderings_for_times_txn(self, txn):
- logger.info("Searching for stream ordering 1 month ago")
- self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
- txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
- )
- logger.info(
- "Found stream ordering 1 month ago: it's %d",
- self.stream_ordering_month_ago
- )
- logger.info("Searching for stream ordering 1 day ago")
- self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn(
- txn, self._clock.time_msec() - 24 * 60 * 60 * 1000
- )
- logger.info(
- "Found stream ordering 1 day ago: it's %d",
- self.stream_ordering_day_ago
- )
-
- def _find_first_stream_ordering_after_ts_txn(self, txn, ts):
- """
- Find the stream_ordering of the first event that was received after
- a given timestamp. This is relatively slow as there is no index on
- received_ts but we can then use this to delete push actions before
- this.
-
- received_ts must necessarily be in the same order as stream_ordering
- and stream_ordering is indexed, so we manually binary search using
- stream_ordering
- """
- txn.execute("SELECT MAX(stream_ordering) FROM events")
- max_stream_ordering = txn.fetchone()[0]
-
- if max_stream_ordering is None:
- return 0
-
- range_start = 0
- range_end = max_stream_ordering
-
- sql = (
- "SELECT received_ts FROM events"
- " WHERE stream_ordering > ?"
- " ORDER BY stream_ordering"
- " LIMIT 1"
- )
-
- while range_end - range_start > 1:
- middle = int((range_end + range_start) / 2)
- txn.execute(sql, (middle,))
- middle_ts = txn.fetchone()[0]
- if ts > middle_ts:
- range_start = middle
- else:
- range_end = middle
-
- return range_end
-
- @defer.inlineCallbacks
def _rotate_notifs(self):
if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
return
@@ -629,7 +803,7 @@ class EventPushActionsStore(SQLBaseStore):
)
if caught_up:
break
- yield sleep(5)
+ yield self.hs.get_clock().sleep(self._rotate_delay)
finally:
self._doing_notif_rotation = False
@@ -650,8 +824,8 @@ class EventPushActionsStore(SQLBaseStore):
txn.execute("""
SELECT stream_ordering FROM event_push_actions
WHERE stream_ordering > ?
- ORDER BY stream_ordering ASC LIMIT 1 OFFSET 50000
- """, (old_rotate_stream_ordering,))
+ ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
+ """, (old_rotate_stream_ordering, self._rotate_count))
stream_row = txn.fetchone()
if stream_row:
offset_stream_ordering, = stream_row
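The "LIMIT 1 OFFSET ?" query finds the stream ordering _rotate_count rows past the last rotation point; if no such row exists, the rotation has caught up. A pure-Python paraphrase (names invented):

    def next_rotation_boundary(orderings, old_rotate_stream_ordering, rotate_count):
        # Mirrors "SELECT stream_ordering ... WHERE stream_ordering > ?
        #          ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?": the row
        # rotate_count places past the last rotation point, if any, caps
        # this batch; otherwise everything remaining can be rotated.
        pending = sorted(o for o in orderings if o > old_rotate_stream_ordering)
        if len(pending) > rotate_count:
            return pending[rotate_count], False   # more batches to come
        return None, True                         # caught up

    assert next_rotation_boundary([1, 2, 3, 4, 5], 0, 3) == (4, False)
    assert next_rotation_boundary([1, 2, 3], 0, 3) == (None, True)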
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 7002b3752e..906a405031 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,64 +13,59 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
-from twisted.internet import defer, reactor
+import itertools
+import logging
+from collections import OrderedDict, deque, namedtuple
+from functools import wraps
-from synapse.events import FrozenEvent, USE_FROZEN_DICTS
-from synapse.events.utils import prune_event
+from six import iteritems
+from six.moves import range
-from synapse.util.async import ObservableDeferred
-from synapse.util.logcontext import (
- preserve_fn, PreserveLoggingContext, preserve_context_over_deferred
-)
-from synapse.util.logutils import log_function
-from synapse.util.metrics import Measure
-from synapse.api.constants import EventTypes
-from synapse.api.errors import SynapseError
-from synapse.state import resolve_events
-from synapse.util.caches.descriptors import cached
-from synapse.types import get_domain_from_id
+from canonicaljson import json
+from prometheus_client import Counter
-from canonicaljson import encode_canonical_json
-from collections import deque, namedtuple, OrderedDict
-from functools import wraps
+from twisted.internet import defer
import synapse.metrics
-
-import logging
-import ujson as json
-
+from synapse.api.constants import EventTypes
+from synapse.api.errors import SynapseError
# these are only included to make the type annotations work
-from synapse.events import EventBase # noqa: F401
-from synapse.events.snapshot import EventContext # noqa: F401
+from synapse.events import EventBase # noqa: F401
+from synapse.events.snapshot import EventContext # noqa: F401
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.events_worker import EventsWorkerStore
+from synapse.types import RoomStreamToken, get_domain_from_id
+from synapse.util.async import ObservableDeferred
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.frozenutils import frozendict_json_encoder
+from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
+from synapse.util.logutils import log_function
+from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
+persist_event_counter = Counter("synapse_storage_events_persisted_events", "")
+event_counter = Counter("synapse_storage_events_persisted_events_sep", "",
+ ["type", "origin_type", "origin_entity"])
-metrics = synapse.metrics.get_metrics_for(__name__)
-persist_event_counter = metrics.register_counter("persisted_events")
-event_counter = metrics.register_counter(
- "persisted_events_sep", labels=["type", "origin_type", "origin_entity"]
-)
+# The number of times we are recalculating the current state
+state_delta_counter = Counter("synapse_storage_events_state_delta", "")
+# The number of times we are recalculating state when there is only a
+# single forward extremity
+state_delta_single_event_counter = Counter(
+ "synapse_storage_events_state_delta_single_event", "")
-def encode_json(json_object):
- if USE_FROZEN_DICTS:
- # ujson doesn't like frozen_dicts
- return encode_canonical_json(json_object)
- else:
- return json.dumps(json_object, ensure_ascii=False)
+# The number of times we are recalculating state when we could have reasonably
+# calculated the delta when we calculated the state for an event we were
+# persisting.
+state_delta_reuse_delta_counter = Counter(
+ "synapse_storage_events_state_delta_reuse_delta", "")
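With prometheus_client, label values are bound via .labels(...) before incrementing, rather than being passed to inc() as the old synapse.metrics wrappers did. For illustration (metric name and label values invented):

    from prometheus_client import Counter

    demo_counter = Counter(
        "demo_events_persisted", "Example counter",
        ["type", "origin_type", "origin_entity"],
    )

    # Old style was roughly counter.inc(type, origin_type, origin_entity);
    # with prometheus_client the labelled child is resolved first:
    demo_counter.labels("m.room.message", "local", "example-appservice").inc()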
-# These values are used in the `enqueus_event` and `_do_fetch` methods to
-# control how we batch/bulk fetch events from the database.
-# The values are plucked out of thing air to make initial sync run faster
-# on jki.re
-# TODO: Make these configurable.
-EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events
-EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
-EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
+def encode_json(json_object):
+ return frozendict_json_encoder.encode(json_object)
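frozendict_json_encoder is, roughly, a json.JSONEncoder that knows how to turn frozendicts back into plain dicts. A minimal sketch of the idea, using types.MappingProxyType as a stand-in for the real frozendict type:

    import json
    from types import MappingProxyType

    class _FrozenDictEncoder(json.JSONEncoder):
        # Stand-in for synapse.util.frozenutils.frozendict_json_encoder:
        # convert immutable mappings into plain dicts so json can serialize
        # them without needing to thaw the whole event first.
        def default(self, obj):
            if isinstance(obj, MappingProxyType):
                return dict(obj)
            return json.JSONEncoder.default(self, obj)

    frozendict_json_encoder = _FrozenDictEncoder()
    assert frozendict_json_encoder.encode(MappingProxyType({"a": 1})) == '{"a": 1}'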
class _EventPeristenceQueue(object):
@@ -88,19 +84,29 @@ class _EventPeristenceQueue(object):
def add_to_queue(self, room_id, events_and_contexts, backfilled):
"""Add events to the queue, with the given persist_event options.
+ NB: due to the normal usage pattern of this method, it does *not*
+ follow the synapse logcontext rules, and leaves the logcontext in
+ place whether or not the returned deferred is ready.
+
Args:
room_id (str):
events_and_contexts (list[(EventBase, EventContext)]):
backfilled (bool):
+
+ Returns:
+ defer.Deferred: a deferred which will resolve once the events are
+ persisted. Runs its callbacks *without* a logcontext.
"""
queue = self._event_persist_queues.setdefault(room_id, deque())
if queue:
+ # if the last item in the queue has the same `backfilled` setting,
+ # we can just add these new events to that item.
end_item = queue[-1]
if end_item.backfilled == backfilled:
end_item.events_and_contexts.extend(events_and_contexts)
return end_item.deferred.observe()
- deferred = ObservableDeferred(defer.Deferred())
+ deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
queue.append(self._EventPersistQueueItem(
events_and_contexts=events_and_contexts,
@@ -113,11 +119,11 @@ class _EventPeristenceQueue(object):
def handle_queue(self, room_id, per_item_callback):
"""Attempts to handle the queue for a room if not already being handled.
- The given callback will be invoked with for each item in the queue,1
+        The given callback will be invoked for each item in the queue,
of type _EventPersistQueueItem. The per_item_callback will continuously
         be called with new items, unless the queue becomes empty. The return
value of the function will be given to the deferreds waiting on the item,
- exceptions will be passed to the deferres as well.
+ exceptions will be passed to the deferreds as well.
This function should therefore be called whenever anything is added
to the queue.
@@ -136,18 +142,23 @@ class _EventPeristenceQueue(object):
try:
queue = self._get_drainining_queue(room_id)
for item in queue:
+ # handle_queue_loop runs in the sentinel logcontext, so
+ # there is no need to preserve_fn when running the
+ # callbacks on the deferred.
try:
ret = yield per_item_callback(item)
- item.deferred.callback(ret)
- except Exception as e:
- item.deferred.errback(e)
+ with PreserveLoggingContext():
+ item.deferred.callback(ret)
+ except Exception:
+ item.deferred.errback()
finally:
queue = self._event_persist_queues.pop(room_id, None)
if queue:
self._event_persist_queues[room_id] = queue
self._currently_persisting_rooms.discard(room_id)
- preserve_fn(handle_queue_loop)()
+ # set handle_queue_loop off in the background
+ run_as_background_process("persist_events", handle_queue_loop)
def _get_drainining_queue(self, room_id):
queue = self._event_persist_queues.setdefault(room_id, deque())
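The coalescing in add_to_queue means several persist_events calls can share one queue item, and hence one transaction, as long as their backfilled flag matches the item at the tail of the queue. A Twisted-free sketch of just that rule (class and method names invented):

    from collections import deque

    class _Item(object):
        def __init__(self, events, backfilled):
            self.events = list(events)
            self.backfilled = backfilled
            self.waiters = []  # callbacks fired once this item is persisted

    class PersistQueue(object):
        def __init__(self):
            self._queues = {}

        def add(self, room_id, events, backfilled, on_done):
            queue = self._queues.setdefault(room_id, deque())
            if queue and queue[-1].backfilled == backfilled:
                # Same persistence options as the tail item: piggy-back on it.
                queue[-1].events.extend(events)
                queue[-1].waiters.append(on_done)
            else:
                item = _Item(events, backfilled)
                item.waiters.append(on_done)
                queue.append(item)

    q = PersistQueue()
    q.add("!room", ["$a"], False, lambda: None)
    q.add("!room", ["$b"], False, lambda: None)  # coalesced into the first item
    q.add("!room", ["$c"], True, lambda: None)   # different flag: new item
    assert [len(i.events) for i in q._queues["!room"]] == [2, 1]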
@@ -183,13 +194,12 @@ def _retry_on_integrity_error(func):
return f
-class EventsStore(SQLBaseStore):
+class EventsStore(EventsWorkerStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
- def __init__(self, hs):
- super(EventsStore, self).__init__(hs)
- self._clock = hs.get_clock()
+ def __init__(self, db_conn, hs):
+ super(EventsStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
)
@@ -220,6 +230,8 @@ class EventsStore(SQLBaseStore):
self._event_persist_queue = _EventPeristenceQueue()
+ self._state_resolution_handler = hs.get_state_resolution_handler()
+
def persist_events(self, events_and_contexts, backfilled=False):
"""
Write events to the database
@@ -232,8 +244,8 @@ class EventsStore(SQLBaseStore):
partitioned.setdefault(event.room_id, []).append((event, ctx))
deferreds = []
- for room_id, evs_ctxs in partitioned.iteritems():
- d = preserve_fn(self._event_persist_queue.add_to_queue)(
+ for room_id, evs_ctxs in iteritems(partitioned):
+ d = self._event_persist_queue.add_to_queue(
room_id, evs_ctxs,
backfilled=backfilled,
)
@@ -242,7 +254,7 @@ class EventsStore(SQLBaseStore):
for room_id in partitioned:
self._maybe_start_persisting(room_id)
- return preserve_context_over_deferred(
+ return make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
@@ -267,7 +279,7 @@ class EventsStore(SQLBaseStore):
self._maybe_start_persisting(event.room_id)
- yield preserve_context_over_deferred(deferred)
+ yield make_deferred_yieldable(deferred)
max_persisted_id = yield self._stream_id_gen.get_current_token()
defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
@@ -275,10 +287,11 @@ class EventsStore(SQLBaseStore):
def _maybe_start_persisting(self, room_id):
@defer.inlineCallbacks
def persisting_queue(item):
- yield self._persist_events(
- item.events_and_contexts,
- backfilled=item.backfilled,
- )
+ with Measure(self._clock, "persist_events"):
+ yield self._persist_events(
+ item.events_and_contexts,
+ backfilled=item.backfilled,
+ )
self._event_persist_queue.handle_queue(room_id, persisting_queue)
@@ -316,7 +329,7 @@ class EventsStore(SQLBaseStore):
chunks = [
events_and_contexts[x:x + 100]
- for x in xrange(0, len(events_and_contexts), 100)
+ for x in range(0, len(events_and_contexts), 100)
]
for chunk in chunks:
@@ -325,8 +338,23 @@ class EventsStore(SQLBaseStore):
# NB: Assumes that we are only persisting events for one room
# at a time.
+
+ # map room_id->list[event_ids] giving the new forward
+ # extremities in each room
new_forward_extremeties = {}
+
+ # map room_id->(type,state_key)->event_id tracking the full
+ # state in each room after adding these events.
+ # This is simply used to prefill the get_current_state_ids
+ # cache
current_state_for_room = {}
+
+ # map room_id->(to_delete, to_insert) where to_delete is a list
+ # of type/state keys to remove from current state, and to_insert
+ # is a map (type,key)->event_id giving the state delta in each
+ # room
+ state_delta_for_room = {}
+
if not backfilled:
with Measure(self._clock, "_calculate_state_and_extrem"):
# Work out the new "current state" for each room.
@@ -338,7 +366,7 @@ class EventsStore(SQLBaseStore):
(event, context)
)
- for room_id, ev_ctx_rm in events_by_room.iteritems():
+ for room_id, ev_ctx_rm in iteritems(events_by_room):
# Work out new extremities by recursively adding and removing
# the new events.
latest_event_ids = yield self.get_latest_event_ids_in_room(
@@ -348,7 +376,8 @@ class EventsStore(SQLBaseStore):
room_id, ev_ctx_rm, latest_event_ids
)
- if new_latest_event_ids == set(latest_event_ids):
+ latest_event_ids = set(latest_event_ids)
+ if new_latest_event_ids == latest_event_ids:
# No change in extremities, so no change in state
continue
@@ -369,11 +398,63 @@ class EventsStore(SQLBaseStore):
if all_single_prev_not_state:
continue
- state = yield self._calculate_state_delta(
- room_id, ev_ctx_rm, new_latest_event_ids
+ state_delta_counter.inc()
+ if len(new_latest_event_ids) == 1:
+ state_delta_single_event_counter.inc()
+
+ # This is a fairly handwavey check to see if we could
+ # have guessed what the delta would have been when
+ # processing one of these events.
+ # What we're interested in is if the latest extremities
+ # were the same when we created the event as they are
+ # now. When this server creates a new event (as opposed
+ # to receiving it over federation) it will use the
+ # forward extremities as the prev_events, so we can
+ # guess this by looking at the prev_events and checking
+ # if they match the current forward extremities.
+ for ev, _ in ev_ctx_rm:
+ prev_event_ids = set(e for e, _ in ev.prev_events)
+ if latest_event_ids == prev_event_ids:
+ state_delta_reuse_delta_counter.inc()
+ break
+
+ logger.info(
+ "Calculating state delta for room %s", room_id,
)
- if state:
- current_state_for_room[room_id] = state
+ with Measure(
+ self._clock,
+ "persist_events.get_new_state_after_events",
+ ):
+ res = yield self._get_new_state_after_events(
+ room_id,
+ ev_ctx_rm,
+ latest_event_ids,
+ new_latest_event_ids,
+ )
+ current_state, delta_ids = res
+
+ # If either are not None then there has been a change,
+ # and we need to work out the delta (or use that
+ # given)
+ if delta_ids is not None:
+ # If there is a delta we know that we've
+ # only added or replaced state, never
+ # removed keys entirely.
+ state_delta_for_room[room_id] = ([], delta_ids)
+ elif current_state is not None:
+ with Measure(
+ self._clock,
+ "persist_events.calculate_state_delta",
+ ):
+ delta = yield self._calculate_state_delta(
+ room_id, current_state,
+ )
+ state_delta_for_room[room_id] = delta
+
+                    # If we have the current_state then let's prefill
+ # the cache with it.
+ if current_state is not None:
+ current_state_for_room[room_id] = current_state
yield self.runInteraction(
"persist_events",
@@ -381,10 +462,13 @@ class EventsStore(SQLBaseStore):
events_and_contexts=chunk,
backfilled=backfilled,
delete_existing=delete_existing,
- current_state_for_room=current_state_for_room,
+ state_delta_for_room=state_delta_for_room,
new_forward_extremeties=new_forward_extremeties,
)
- persist_event_counter.inc_by(len(chunk))
+ persist_event_counter.inc(len(chunk))
+ synapse.metrics.event_persisted_position.set(
+ chunk[-1][0].internal_metadata.stream_ordering,
+ )
for event, context in chunk:
if context.app_service:
origin_type = "local"
@@ -396,14 +480,14 @@ class EventsStore(SQLBaseStore):
origin_type = "remote"
origin_entity = get_domain_from_id(event.sender)
- event_counter.inc(event.type, origin_type, origin_entity)
+ event_counter.labels(event.type, origin_type, origin_entity).inc()
- for room_id, (_, _, new_state) in current_state_for_room.iteritems():
+ for room_id, new_state in iteritems(current_state_for_room):
self.get_current_state_ids.prefill(
(room_id, ), new_state
)
- for room_id, latest_event_ids in new_forward_extremeties.iteritems():
+ for room_id, latest_event_ids in iteritems(new_forward_extremeties):
self.get_latest_event_ids_in_room.prefill(
(room_id,), list(latest_event_ids)
)
@@ -450,183 +534,187 @@ class EventsStore(SQLBaseStore):
defer.returnValue(new_latest_event_ids)
@defer.inlineCallbacks
- def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
- """Calculate the new state deltas for a room.
+ def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids,
+ new_latest_event_ids):
+ """Calculate the current state dict after adding some new events to
+ a room
- Assumes that we are only persisting events for one room at a time.
+ Args:
+ room_id (str):
+ room to which the events are being added. Used for logging etc
+
+ events_context (list[(EventBase, EventContext)]):
+ events and contexts which are being added to the room
+
+ old_latest_event_ids (iterable[str]):
+ the old forward extremities for the room.
+
+ new_latest_event_ids (iterable[str]):
+ the new forward extremities for the room.
Returns:
- 3-tuple (to_delete, to_insert, new_state) where both are state dicts,
- i.e. (type, state_key) -> event_id. `to_delete` are the entries to
- first be deleted from current_state_events, `to_insert` are entries
- to insert. `new_state` is the full set of state.
- May return None if there are no changes to be applied.
+ Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]:
+ Returns a tuple of two state maps, the first being the full new current
+ state and the second being the delta to the existing current state.
+ If both are None then there has been no change.
+
+            If there has been a change then we only return the delta if it's
+ already been calculated. Conversely if we do know the delta then
+ the new current state is only returned if we've already calculated
+ it.
"""
- # Now we need to work out the different state sets for
- # each state extremities
- state_sets = []
- state_groups = set()
- missing_event_ids = []
- was_updated = False
+
+ if not new_latest_event_ids:
+ return
+
+ # map from state_group to ((type, key) -> event_id) state map
+ state_groups_map = {}
+
+ # Map from (prev state group, new state group) -> delta state dict
+ state_group_deltas = {}
+
+ for ev, ctx in events_context:
+ if ctx.state_group is None:
+ # I don't think this can happen, but let's double-check
+ raise Exception(
+ "Context for new extremity event %s has no state "
+ "group" % (ev.event_id, ),
+ )
+
+ if ctx.state_group in state_groups_map:
+ continue
+
+ # We're only interested in pulling out state that has already
+ # been cached in the context. We'll pull stuff out of the DB later
+ # if necessary.
+ current_state_ids = ctx.get_cached_current_state_ids()
+ if current_state_ids is not None:
+ state_groups_map[ctx.state_group] = current_state_ids
+
+ if ctx.prev_group:
+ state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
+
+ # We need to map the event_ids to their state groups. First, let's
+ # check if the event is one we're persisting, in which case we can
+ # pull the state group from its context.
+ # Otherwise we need to pull the state group from the database.
+
+ # Set of events we need to fetch groups for. (We know none of the old
+ # extremities are going to be in events_context).
+ missing_event_ids = set(old_latest_event_ids)
+
+ event_id_to_state_group = {}
for event_id in new_latest_event_ids:
- # First search in the list of new events we're adding,
- # and then use the current state from that
+ # First search in the list of new events we're adding.
for ev, ctx in events_context:
if event_id == ev.event_id:
- if ctx.current_state_ids is None:
- raise Exception("Unknown current state")
-
- # If we've already seen the state group don't bother adding
- # it to the state sets again
- if ctx.state_group not in state_groups:
- state_sets.append(ctx.current_state_ids)
- if ctx.delta_ids or hasattr(ev, "state_key"):
- was_updated = True
- if ctx.state_group:
- # Add this as a seen state group (if it has a state
- # group)
- state_groups.add(ctx.state_group)
+ event_id_to_state_group[event_id] = ctx.state_group
break
else:
# If we couldn't find it, then we'll need to pull
# the state from the database
- was_updated = True
- missing_event_ids.append(event_id)
+ missing_event_ids.add(event_id)
if missing_event_ids:
- # Now pull out the state for any missing events from DB
+ # Now pull out the state groups for any missing events from DB
event_to_groups = yield self._get_state_group_for_events(
missing_event_ids,
)
+ event_id_to_state_group.update(event_to_groups)
- groups = set(event_to_groups.itervalues()) - state_groups
+ # State groups of old_latest_event_ids
+ old_state_groups = set(
+ event_id_to_state_group[evid] for evid in old_latest_event_ids
+ )
- if groups:
- group_to_state = yield self._get_state_for_groups(groups)
- state_sets.extend(group_to_state.itervalues())
+ # State groups of new_latest_event_ids
+ new_state_groups = set(
+ event_id_to_state_group[evid] for evid in new_latest_event_ids
+ )
- if not new_latest_event_ids:
- current_state = {}
- elif was_updated:
- if len(state_sets) == 1:
- # If there is only one state set, then we know what the current
- # state is.
- current_state = state_sets[0]
- else:
- # We work out the current state by passing the state sets to the
- # state resolution algorithm. It may ask for some events, including
- # the events we have yet to persist, so we need a slightly more
- # complicated event lookup function than simply looking the events
- # up in the db.
- events_map = {ev.event_id: ev for ev, _ in events_context}
-
- @defer.inlineCallbacks
- def get_events(ev_ids):
- # We get the events by first looking at the list of events we
- # are trying to persist, and then fetching the rest from the DB.
- db = []
- to_return = {}
- for ev_id in ev_ids:
- ev = events_map.get(ev_id, None)
- if ev:
- to_return[ev_id] = ev
- else:
- db.append(ev_id)
-
- if db:
- evs = yield self.get_events(
- ev_ids, get_prev_content=False, check_redacted=False,
- )
- to_return.update(evs)
- defer.returnValue(to_return)
-
- current_state = yield resolve_events(
- state_sets,
- state_map_factory=get_events,
- )
- else:
- return
+        # If the old and new groups are the same then we don't need to do
+ # anything.
+ if old_state_groups == new_state_groups:
+ defer.returnValue((None, None))
- existing_state = yield self.get_current_state_ids(room_id)
+ if len(new_state_groups) == 1 and len(old_state_groups) == 1:
+            # If we're going from one state group to another, let's check if
+ # we have a delta for that transition. If we do then we can just
+ # return that.
- existing_events = set(existing_state.itervalues())
- new_events = set(ev_id for ev_id in current_state.itervalues())
- changed_events = existing_events ^ new_events
+ new_state_group = next(iter(new_state_groups))
+ old_state_group = next(iter(old_state_groups))
- if not changed_events:
- return
+ delta_ids = state_group_deltas.get(
+ (old_state_group, new_state_group,), None
+ )
+ if delta_ids is not None:
+ # We have a delta from the existing to new current state,
+                # so let's just return that. If we happen to already have
+                # the current state in memory then let's also return that,
+ # but it doesn't matter if we don't.
+ new_state = state_groups_map.get(new_state_group)
+ defer.returnValue((new_state, delta_ids))
+
+ # Now that we have calculated new_state_groups we need to get
+ # their state IDs so we can resolve to a single state set.
+ missing_state = new_state_groups - set(state_groups_map)
+ if missing_state:
+ group_to_state = yield self._get_state_for_groups(missing_state)
+ state_groups_map.update(group_to_state)
+
+ if len(new_state_groups) == 1:
+ # If there is only one state group, then we know what the current
+ # state is.
+ defer.returnValue((state_groups_map[new_state_groups.pop()], None))
+
+ # Ok, we need to defer to the state handler to resolve our state sets.
+
+ def get_events(ev_ids):
+ return self.get_events(
+ ev_ids, get_prev_content=False, check_redacted=False,
+ )
- to_delete = {
- key: ev_id for key, ev_id in existing_state.iteritems()
- if ev_id in changed_events
+ state_groups = {
+ sg: state_groups_map[sg] for sg in new_state_groups
}
- events_to_insert = (new_events - existing_events)
- to_insert = {
- key: ev_id for key, ev_id in current_state.iteritems()
- if ev_id in events_to_insert
- }
-
- defer.returnValue((to_delete, to_insert, current_state))
-
- @defer.inlineCallbacks
- def get_event(self, event_id, check_redacted=True,
- get_prev_content=False, allow_rejected=False,
- allow_none=False):
- """Get an event from the database by event_id.
-
- Args:
- event_id (str): The event_id of the event to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
- include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
- allow_none (bool): If True, return None if no event found, if
- False throw an exception.
- Returns:
- Deferred : A FrozenEvent.
- """
- events = yield self._get_events(
- [event_id],
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- allow_rejected=allow_rejected,
+ events_map = {ev.event_id: ev for ev, _ in events_context}
+ logger.debug("calling resolve_state_groups from preserve_events")
+ res = yield self._state_resolution_handler.resolve_state_groups(
+ room_id, state_groups, events_map, get_events
)
- if not events and not allow_none:
- raise SynapseError(404, "Could not find event %s" % (event_id,))
-
- defer.returnValue(events[0] if events else None)
+ defer.returnValue((res.state, None))
@defer.inlineCallbacks
- def get_events(self, event_ids, check_redacted=True,
- get_prev_content=False, allow_rejected=False):
- """Get events from the database
+ def _calculate_state_delta(self, room_id, current_state):
+ """Calculate the new state deltas for a room.
- Args:
- event_ids (list): The event_ids of the events to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
- include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
+ Assumes that we are only persisting events for one room at a time.
Returns:
- Deferred : Dict from event_id to event.
+ tuple[list, dict] (to_delete, to_insert): where to_delete are the
+ type/state_keys to remove from current_state_events and `to_insert`
+ are the updates to current_state_events.
"""
- events = yield self._get_events(
- event_ids,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- allow_rejected=allow_rejected,
- )
+ existing_state = yield self.get_current_state_ids(room_id)
+
+ to_delete = [
+ key for key in existing_state
+ if key not in current_state
+ ]
+
+ to_insert = {
+ key: ev_id for key, ev_id in iteritems(current_state)
+ if ev_id != existing_state.get(key)
+ }
- defer.returnValue({e.event_id: e for e in events})
+ defer.returnValue((to_delete, to_insert))
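The new _calculate_state_delta is a plain dict diff between the stored current state and the newly resolved one. Its core can be exercised standalone (event IDs invented):

    def calculate_state_delta(existing_state, current_state):
        # Keys that vanished entirely are deleted; keys that are new or whose
        # event_id changed are (re)inserted.
        to_delete = [key for key in existing_state if key not in current_state]
        to_insert = {
            key: ev_id for key, ev_id in current_state.items()
            if ev_id != existing_state.get(key)
        }
        return to_delete, to_insert

    existing = {("m.room.name", ""): "$old_name", ("m.room.member", "@u:hs"): "$join"}
    current = {("m.room.member", "@u:hs"): "$leave", ("m.room.topic", ""): "$topic"}

    assert calculate_state_delta(existing, current) == (
        [("m.room.name", "")],
        {("m.room.member", "@u:hs"): "$leave", ("m.room.topic", ""): "$topic"},
    )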
@log_function
def _persist_events_txn(self, txn, events_and_contexts, backfilled,
- delete_existing=False, current_state_for_room={},
+ delete_existing=False, state_delta_for_room={},
new_forward_extremeties={}):
"""Insert some number of room events into the necessary database tables.
@@ -642,19 +730,21 @@ class EventsStore(SQLBaseStore):
delete_existing (bool): True to purge existing table rows for the
events from the database. This is useful when retrying due to
IntegrityError.
- current_state_for_room (dict[str, (list[str], list[str])]):
+ state_delta_for_room (dict[str, (list, dict)]):
The current-state delta for each room. For each room, a tuple
- (to_delete, to_insert), being a list of event ids to be removed
- from the current state, and a list of event ids to be added to
+ (to_delete, to_insert), being a list of type/state keys to be
+ removed from the current state, and a state set to be added to
the current state.
new_forward_extremeties (dict[str, list[str]]):
The new forward extremities for each room. For each room, a
list of the event ids which are the forward extremities.
"""
+ all_events_and_contexts = events_and_contexts
+
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
- self._update_current_state_txn(txn, current_state_for_room, max_stream_order)
+ self._update_current_state_txn(txn, state_delta_for_room, max_stream_order)
self._update_forward_extremities_txn(
txn,
@@ -698,9 +788,8 @@ class EventsStore(SQLBaseStore):
events_and_contexts=events_and_contexts,
)
- # Insert into the state_groups, state_groups_state, and
- # event_to_state_groups tables.
- self._store_mult_state_groups_txn(txn, events_and_contexts)
+ # Insert into event_to_state_groups.
+ self._store_event_state_mappings_txn(txn, events_and_contexts)
# _store_rejected_events_txn filters out any events which were
# rejected, and returns the filtered list.
@@ -715,15 +804,53 @@ class EventsStore(SQLBaseStore):
self._update_metadata_tables_txn(
txn,
events_and_contexts=events_and_contexts,
+ all_events_and_contexts=all_events_and_contexts,
backfilled=backfilled,
)
def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
- for room_id, current_state_tuple in state_delta_by_room.iteritems():
- to_delete, to_insert, _ = current_state_tuple
+ for room_id, current_state_tuple in iteritems(state_delta_by_room):
+ to_delete, to_insert = current_state_tuple
+
+ # First we add entries to the current_state_delta_stream. We
+ # do this before updating the current_state_events table so
+ # that we can use it to calculate the `prev_event_id`. (This
+ # allows us to not have to pull out the existing state
+ # unnecessarily).
+ sql = """
+ INSERT INTO current_state_delta_stream
+ (stream_id, room_id, type, state_key, event_id, prev_event_id)
+ SELECT ?, ?, ?, ?, ?, (
+ SELECT event_id FROM current_state_events
+ WHERE room_id = ? AND type = ? AND state_key = ?
+ )
+ """
+ txn.executemany(sql, (
+ (
+ max_stream_order, room_id, etype, state_key, None,
+ room_id, etype, state_key,
+ )
+ for etype, state_key in to_delete
+ # We sanity check that we're deleting rather than updating
+ if (etype, state_key) not in to_insert
+ ))
+ txn.executemany(sql, (
+ (
+ max_stream_order, room_id, etype, state_key, ev_id,
+ room_id, etype, state_key,
+ )
+ for (etype, state_key), ev_id in iteritems(to_insert)
+ ))
+
+ # Now we actually update the current_state_events table
+
txn.executemany(
- "DELETE FROM current_state_events WHERE event_id = ?",
- [(ev_id,) for ev_id in to_delete.itervalues()],
+ "DELETE FROM current_state_events"
+ " WHERE room_id = ? AND type = ? AND state_key = ?",
+ (
+ (room_id, etype, state_key)
+ for etype, state_key in itertools.chain(to_delete, to_insert)
+ ),
)
self._simple_insert_many_txn(
@@ -736,30 +863,12 @@ class EventsStore(SQLBaseStore):
"type": key[0],
"state_key": key[1],
}
- for key, ev_id in to_insert.iteritems()
+ for key, ev_id in iteritems(to_insert)
],
)
- state_deltas = {key: None for key in to_delete}
- state_deltas.update(to_insert)
-
- self._simple_insert_many_txn(
- txn,
- table="current_state_delta_stream",
- values=[
- {
- "stream_id": max_stream_order,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": ev_id,
- "prev_event_id": to_delete.get(key, None),
- }
- for key, ev_id in state_deltas.iteritems()
- ]
- )
-
- self._curr_state_delta_stream_cache.entity_has_changed(
+ txn.call_after(
+ self._curr_state_delta_stream_cache.entity_has_changed,
room_id, max_stream_order,
)
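Note how the delta-stream insert captures prev_event_id with a correlated subselect against current_state_events before that table is updated, avoiding a separate read of the old state. The trick in isolation, with a pared-down sqlite3 schema:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE current_state_events "
                 "(room_id TEXT, type TEXT, state_key TEXT, event_id TEXT)")
    conn.execute("CREATE TABLE current_state_delta_stream "
                 "(stream_id INT, room_id TEXT, type TEXT, state_key TEXT,"
                 " event_id TEXT, prev_event_id TEXT)")
    conn.execute("INSERT INTO current_state_events VALUES ('!r', 'm.room.name', '', '$old')")

    sql = """
        INSERT INTO current_state_delta_stream
        (stream_id, room_id, type, state_key, event_id, prev_event_id)
        SELECT ?, ?, ?, ?, ?, (
            SELECT event_id FROM current_state_events
            WHERE room_id = ? AND type = ? AND state_key = ?
        )
    """
    # The subselect runs before current_state_events is touched, so it
    # captures the event being replaced ($old) as prev_event_id.
    conn.execute(sql, (1, "!r", "m.room.name", "", "$new", "!r", "m.room.name", ""))
    assert conn.execute(
        "SELECT event_id, prev_event_id FROM current_state_delta_stream"
    ).fetchone() == ("$new", "$old")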
@@ -771,19 +880,23 @@ class EventsStore(SQLBaseStore):
             # and which we have added, then we invalidate the caches for all
# those users.
members_changed = set(
- state_key for ev_type, state_key in state_deltas
+ state_key
+ for ev_type, state_key in itertools.chain(to_delete, to_insert)
if ev_type == EventTypes.Member
)
for member in members_changed:
self._invalidate_cache_and_stream(
- txn, self.get_rooms_for_user, (member,)
+ txn, self.get_rooms_for_user_with_stream_ordering, (member,)
)
for host in set(get_domain_from_id(u) for u in members_changed):
self._invalidate_cache_and_stream(
txn, self.is_host_joined, (room_id, host)
)
+ self._invalidate_cache_and_stream(
+ txn, self.was_host_joined, (room_id, host)
+ )
self._invalidate_cache_and_stream(
txn, self.get_users_in_room, (room_id,)
@@ -795,7 +908,7 @@ class EventsStore(SQLBaseStore):
def _update_forward_extremities_txn(self, txn, new_forward_extremities,
max_stream_order):
- for room_id, new_extrem in new_forward_extremities.iteritems():
+ for room_id, new_extrem in iteritems(new_forward_extremities):
self._simple_delete_txn(
txn,
table="event_forward_extremities",
@@ -813,7 +926,7 @@ class EventsStore(SQLBaseStore):
"event_id": ev_id,
"room_id": room_id,
}
- for room_id, new_extrem in new_forward_extremities.iteritems()
+ for room_id, new_extrem in iteritems(new_forward_extremities)
for ev_id in new_extrem
],
)
@@ -830,7 +943,7 @@ class EventsStore(SQLBaseStore):
"event_id": event_id,
"stream_ordering": max_stream_order,
}
- for room_id, new_extrem in new_forward_extremities.iteritems()
+ for room_id, new_extrem in iteritems(new_forward_extremities)
for event_id in new_extrem
]
)
@@ -858,7 +971,7 @@ class EventsStore(SQLBaseStore):
new_events_and_contexts[event.event_id] = (event, context)
else:
new_events_and_contexts[event.event_id] = (event, context)
- return new_events_and_contexts.values()
+ return list(new_events_and_contexts.values())
def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
"""Update min_depth for each room
@@ -884,7 +997,7 @@ class EventsStore(SQLBaseStore):
event.depth, depth_updates.get(event.room_id, event.depth)
)
- for room_id, depth in depth_updates.iteritems():
+ for room_id, depth in iteritems(depth_updates):
self._update_min_depth_for_room_txn(txn, room_id, depth)
def _update_outliers_txn(self, txn, events_and_contexts):
@@ -932,10 +1045,9 @@ class EventsStore(SQLBaseStore):
# an outlier in the database. We now have some state at that
# so we need to update the state_groups table with that state.
- # insert into the state_group, state_groups_state and
- # event_to_state_groups tables.
+ # insert into event_to_state_groups.
try:
- self._store_mult_state_groups_txn(txn, ((event, context),))
+ self._store_event_state_mappings_txn(txn, ((event, context),))
except Exception:
logger.exception("")
raise
@@ -1001,7 +1113,6 @@ class EventsStore(SQLBaseStore):
"event_edge_hashes",
"event_edges",
"event_forward_extremities",
- "event_push_actions",
"event_reference_hashes",
"event_search",
"event_signatures",
@@ -1021,6 +1132,14 @@ class EventsStore(SQLBaseStore):
[(ev.event_id,) for ev, _ in events_and_contexts]
)
+ for table in (
+ "event_push_actions",
+ ):
+ txn.executemany(
+ "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,),
+ [(ev.event_id,) for ev, _ in events_and_contexts]
+ )
+
def _store_event_txn(self, txn, events_and_contexts):
"""Insert new events into the event and event_json tables
@@ -1110,27 +1229,33 @@ class EventsStore(SQLBaseStore):
ec for ec in events_and_contexts if ec[0] not in to_remove
]
- def _update_metadata_tables_txn(self, txn, events_and_contexts, backfilled):
+ def _update_metadata_tables_txn(self, txn, events_and_contexts,
+ all_events_and_contexts, backfilled):
"""Update all the miscellaneous tables for new events
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
events_and_contexts (list[(EventBase, EventContext)]): events
we are persisting
+ all_events_and_contexts (list[(EventBase, EventContext)]): all
+ events that we were going to persist. This includes events
+ we've already persisted, etc, that wouldn't appear in
+                events_and_contexts.
backfilled (bool): True if the events were backfilled
"""
+ # Insert all the push actions into the event_push_actions table.
+ self._set_push_actions_for_event_and_users_txn(
+ txn,
+ events_and_contexts=events_and_contexts,
+ all_events_and_contexts=all_events_and_contexts,
+ )
+
if not events_and_contexts:
# nothing to do here
return
for event, context in events_and_contexts:
- # Insert all the push actions into the event_push_actions table.
- if context.push_actions:
- self._set_push_actions_for_event_and_users_txn(
- txn, event, context.push_actions
- )
-
if event.type == EventTypes.Redaction and event.redacts is not None:
# Remove the entries in the event_push_actions table for the
# redacted event.
@@ -1263,7 +1388,7 @@ class EventsStore(SQLBaseStore):
" WHERE e.event_id IN (%s)"
) % (",".join(["?"] * len(ev_map)),)
- txn.execute(sql, ev_map.keys())
+ txn.execute(sql, list(ev_map))
rows = self.cursor_to_dict(txn)
for row in rows:
event = ev_map[row["event_id"]]
@@ -1302,13 +1427,49 @@ class EventsStore(SQLBaseStore):
defer.returnValue(set(r["event_id"] for r in rows))
- def have_events(self, event_ids):
+ @defer.inlineCallbacks
+ def have_seen_events(self, event_ids):
"""Given a list of event ids, check if we have already processed them.
+ Args:
+ event_ids (iterable[str]):
+
Returns:
- dict: Has an entry for each event id we already have seen. Maps to
- the rejected reason string if we rejected the event, else maps to
- None.
+ Deferred[set[str]]: The events we have already seen.
+ """
+ results = set()
+
+ def have_seen_events_txn(txn, chunk):
+ sql = (
+ "SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
+ % (",".join("?" * len(chunk)), )
+ )
+ txn.execute(sql, chunk)
+ for (event_id, ) in txn:
+ results.add(event_id)
+
+ # break the input up into chunks of 100
+ input_iterator = iter(event_ids)
+ for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
+ []):
+ yield self.runInteraction(
+ "have_seen_events",
+ have_seen_events_txn,
+ chunk,
+ )
+ defer.returnValue(results)
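The chunking idiom in have_seen_events, iter(lambda: list(islice(it, 100)), []), calls the lambda until it returns the sentinel (an empty list), i.e. until the shared iterator is drained. Standalone:

    import itertools

    def chunks(iterable, size):
        it = iter(iterable)
        # iter(callable, sentinel) keeps calling the lambda until it returns
        # the sentinel: here, an empty list once the source is exhausted.
        return iter(lambda: list(itertools.islice(it, size)), [])

    assert list(chunks(range(7), 3)) == [[0, 1, 2], [3, 4, 5], [6]]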
+
+ def get_seen_events_with_rejections(self, event_ids):
+ """Given a list of event ids, check if we rejected them.
+
+ Args:
+ event_ids (list[str])
+
+ Returns:
+            Deferred[dict[str, str|None]]:
+ Has an entry for each event id we already have seen. Maps to
+ the rejected reason string if we rejected the event, else maps
+ to None.
"""
if not event_ids:
return defer.succeed({})
@@ -1330,295 +1491,7 @@ class EventsStore(SQLBaseStore):
return res
- return self.runInteraction(
- "have_events", f,
- )
-
- @defer.inlineCallbacks
- def _get_events(self, event_ids, check_redacted=True,
- get_prev_content=False, allow_rejected=False):
- if not event_ids:
- defer.returnValue([])
-
- event_id_list = event_ids
- event_ids = set(event_ids)
-
- event_entry_map = self._get_events_from_cache(
- event_ids,
- allow_rejected=allow_rejected,
- )
-
- missing_events_ids = [e for e in event_ids if e not in event_entry_map]
-
- if missing_events_ids:
- missing_events = yield self._enqueue_events(
- missing_events_ids,
- check_redacted=check_redacted,
- allow_rejected=allow_rejected,
- )
-
- event_entry_map.update(missing_events)
-
- events = []
- for event_id in event_id_list:
- entry = event_entry_map.get(event_id, None)
- if not entry:
- continue
-
- if allow_rejected or not entry.event.rejected_reason:
- if check_redacted and entry.redacted_event:
- event = entry.redacted_event
- else:
- event = entry.event
-
- events.append(event)
-
- if get_prev_content:
- if "replaces_state" in event.unsigned:
- prev = yield self.get_event(
- event.unsigned["replaces_state"],
- get_prev_content=False,
- allow_none=True,
- )
- if prev:
- event.unsigned = dict(event.unsigned)
- event.unsigned["prev_content"] = prev.content
- event.unsigned["prev_sender"] = prev.sender
-
- defer.returnValue(events)
-
- def _invalidate_get_event_cache(self, event_id):
- self._get_event_cache.invalidate((event_id,))
-
- def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
- """Fetch events from the caches
-
- Args:
- events (list(str)): list of event_ids to fetch
- allow_rejected (bool): Whether to teturn events that were rejected
- update_metrics (bool): Whether to update the cache hit ratio metrics
-
- Returns:
- dict of event_id -> _EventCacheEntry for each event_id in cache. If
- allow_rejected is `False` then there will still be an entry but it
- will be `None`
- """
- event_map = {}
-
- for event_id in events:
- ret = self._get_event_cache.get(
- (event_id,), None,
- update_metrics=update_metrics,
- )
- if not ret:
- continue
-
- if allow_rejected or not ret.event.rejected_reason:
- event_map[event_id] = ret
- else:
- event_map[event_id] = None
-
- return event_map
-
- def _do_fetch(self, conn):
- """Takes a database connection and waits for requests for events from
- the _event_fetch_list queue.
- """
- event_list = []
- i = 0
- while True:
- try:
- with self._event_fetch_lock:
- event_list = self._event_fetch_list
- self._event_fetch_list = []
-
- if not event_list:
- single_threaded = self.database_engine.single_threaded
- if single_threaded or i > EVENT_QUEUE_ITERATIONS:
- self._event_fetch_ongoing -= 1
- return
- else:
- self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
- i += 1
- continue
- i = 0
-
- event_id_lists = zip(*event_list)[0]
- event_ids = [
- item for sublist in event_id_lists for item in sublist
- ]
-
- rows = self._new_transaction(
- conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
- )
-
- row_dict = {
- r["event_id"]: r
- for r in rows
- }
-
- # We only want to resolve deferreds from the main thread
- def fire(lst, res):
- for ids, d in lst:
- if not d.called:
- try:
- with PreserveLoggingContext():
- d.callback([
- res[i]
- for i in ids
- if i in res
- ])
- except:
- logger.exception("Failed to callback")
- with PreserveLoggingContext():
- reactor.callFromThread(fire, event_list, row_dict)
- except Exception as e:
- logger.exception("do_fetch")
-
- # We only want to resolve deferreds from the main thread
- def fire(evs):
- for _, d in evs:
- if not d.called:
- with PreserveLoggingContext():
- d.errback(e)
-
- if event_list:
- with PreserveLoggingContext():
- reactor.callFromThread(fire, event_list)
-
- @defer.inlineCallbacks
- def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
- """Fetches events from the database using the _event_fetch_list. This
- allows batch and bulk fetching of events - it allows us to fetch events
- without having to create a new transaction for each request for events.
- """
- if not events:
- defer.returnValue({})
-
- events_d = defer.Deferred()
- with self._event_fetch_lock:
- self._event_fetch_list.append(
- (events, events_d)
- )
-
- self._event_fetch_lock.notify()
-
- if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
- self._event_fetch_ongoing += 1
- should_start = True
- else:
- should_start = False
-
- if should_start:
- with PreserveLoggingContext():
- self.runWithConnection(
- self._do_fetch
- )
-
- logger.debug("Loading %d events", len(events))
- with PreserveLoggingContext():
- rows = yield events_d
- logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
-
- if not allow_rejected:
- rows[:] = [r for r in rows if not r["rejects"]]
-
- res = yield preserve_context_over_deferred(defer.gatherResults(
- [
- preserve_fn(self._get_event_from_row)(
- row["internal_metadata"], row["json"], row["redacts"],
- rejected_reason=row["rejects"],
- )
- for row in rows
- ],
- consumeErrors=True
- ))
-
- defer.returnValue({
- e.event.event_id: e
- for e in res if e
- })
-
- def _fetch_event_rows(self, txn, events):
- rows = []
- N = 200
- for i in range(1 + len(events) / N):
- evs = events[i * N:(i + 1) * N]
- if not evs:
- break
-
- sql = (
- "SELECT "
- " e.event_id as event_id, "
- " e.internal_metadata,"
- " e.json,"
- " r.redacts as redacts,"
- " rej.event_id as rejects "
- " FROM event_json as e"
- " LEFT JOIN rejections as rej USING (event_id)"
- " LEFT JOIN redactions as r ON e.event_id = r.redacts"
- " WHERE e.event_id IN (%s)"
- ) % (",".join(["?"] * len(evs)),)
-
- txn.execute(sql, evs)
- rows.extend(self.cursor_to_dict(txn))
-
- return rows
-
- @defer.inlineCallbacks
- def _get_event_from_row(self, internal_metadata, js, redacted,
- rejected_reason=None):
- with Measure(self._clock, "_get_event_from_row"):
- d = json.loads(js)
- internal_metadata = json.loads(internal_metadata)
-
- if rejected_reason:
- rejected_reason = yield self._simple_select_one_onecol(
- table="rejections",
- keyvalues={"event_id": rejected_reason},
- retcol="reason",
- desc="_get_event_from_row_rejected_reason",
- )
-
- original_ev = FrozenEvent(
- d,
- internal_metadata_dict=internal_metadata,
- rejected_reason=rejected_reason,
- )
-
- redacted_event = None
- if redacted:
- redacted_event = prune_event(original_ev)
-
- redaction_id = yield self._simple_select_one_onecol(
- table="redactions",
- keyvalues={"redacts": redacted_event.event_id},
- retcol="event_id",
- desc="_get_event_from_row_redactions",
- )
-
- redacted_event.unsigned["redacted_by"] = redaction_id
- # Get the redaction event.
-
- because = yield self.get_event(
- redaction_id,
- check_redacted=False,
- allow_none=True,
- )
-
- if because:
- # It's fine to do add the event directly, since get_pdu_json
- # will serialise this field correctly
- redacted_event.unsigned["redacted_because"] = because
-
- cache_entry = _EventCacheEntry(
- event=original_ev,
- redacted_event=redacted_event,
- )
-
- self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
-
- defer.returnValue(cache_entry)
+ return self.runInteraction("get_rejection_reasons", f)
@defer.inlineCallbacks
def count_daily_messages(self):
@@ -1778,7 +1651,7 @@ class EventsStore(SQLBaseStore):
chunks = [
event_ids[i:i + 100]
- for i in xrange(0, len(event_ids), 100)
+ for i in range(0, len(event_ids), 100)
]
for chunk in chunks:
ev_rows = self._simple_select_many_txn(
@@ -2005,15 +1878,32 @@ class EventsStore(SQLBaseStore):
)
return self.runInteraction("get_all_new_events", get_all_new_events_txn)
- def delete_old_state(self, room_id, topological_ordering):
+ def purge_history(
+ self, room_id, token, delete_local_events,
+ ):
+ """Deletes room history before a certain point
+
+ Args:
+            room_id (str): The room to purge history from
+
+ token (str): A topological token to delete events before
+
+ delete_local_events (bool):
+ if True, we will delete local events as well as remote ones
+ (instead of just marking them as outliers and deleting their
+ state groups).
+ """
+
return self.runInteraction(
- "delete_old_state",
- self._delete_old_state_txn, room_id, topological_ordering
+ "purge_history",
+ self._purge_history_txn, room_id, token,
+ delete_local_events,
)
- def _delete_old_state_txn(self, txn, room_id, topological_ordering):
- """Deletes old room state
- """
+ def _purge_history_txn(
+ self, txn, room_id, token_str, delete_local_events,
+ ):
+ token = RoomStreamToken.parse(token_str)
# Tables that should be pruned:
# event_auth
@@ -2035,6 +1925,37 @@ class EventsStore(SQLBaseStore):
# state_groups
# state_groups_state
+ # we will build a temporary table listing the events so that we don't
+ # have to keep shovelling the list back and forth across the
+ # connection. Annoyingly the python sqlite driver commits the
+ # transaction on CREATE, so let's do this first.
+ #
+ # furthermore, we might already have the table from a previous (failed)
+ # purge attempt, so let's drop the table first.
+
+ txn.execute("DROP TABLE IF EXISTS events_to_purge")
+
+ txn.execute(
+ "CREATE TEMPORARY TABLE events_to_purge ("
+ " event_id TEXT NOT NULL,"
+ " should_delete BOOLEAN NOT NULL"
+ ")"
+ )
+
+ # create an index on should_delete because later we'll be looking for
+ # the should_delete / shouldn't_delete subsets
+ txn.execute(
+ "CREATE INDEX events_to_purge_should_delete"
+ " ON events_to_purge(should_delete)",
+ )
+
+ # We do joins against events_to_purge for e.g. calculating state
+    # groups to purge, etc., so let's make an index.
+ txn.execute(
+ "CREATE INDEX events_to_purge_id"
+ " ON events_to_purge(event_id)",
+ )
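The scratch-table sequence above can be exercised on its own; below is a minimal sqlite3 sketch of the same drop / create / index / populate pattern (the standalone connection and sample rows are invented for illustration, not Synapse's schema setup):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
txn = conn.cursor()

# Tolerate leftovers from an earlier failed run, then create the scratch
# table plus the two indexes the purge joins and filters on.
txn.execute("DROP TABLE IF EXISTS events_to_purge")
txn.execute(
    "CREATE TEMPORARY TABLE events_to_purge ("
    " event_id TEXT NOT NULL,"
    " should_delete BOOLEAN NOT NULL"
    ")"
)
txn.execute(
    "CREATE INDEX events_to_purge_should_delete"
    " ON events_to_purge(should_delete)"
)
txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)")

txn.executemany(
    "INSERT INTO events_to_purge VALUES (?, ?)",
    [("$a:example.com", True), ("$b:example.com", False)],
)
print(txn.execute(
    "SELECT event_id FROM events_to_purge WHERE should_delete"
).fetchall())  # [('$a:example.com',)]
```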
+
    # First ensure that we're not about to delete all the forward extremities
txn.execute(
"SELECT e.event_id, e.depth FROM events as e "
@@ -2047,7 +1968,7 @@ class EventsStore(SQLBaseStore):
rows = txn.fetchall()
    max_depth = max(row[1] for row in rows)  # rows are (event_id, depth)
- if max_depth <= topological_ordering:
+ if max_depth <= token.topological:
        # We need to ensure we don't delete all the events from the database
        # otherwise we wouldn't be able to send any events (due to not
        # having any backwards extremities)
@@ -2055,42 +1976,48 @@ class EventsStore(SQLBaseStore):
400, "topological_ordering is greater than forward extremeties"
)
- logger.debug("[purge] looking for events to delete")
+ logger.info("[purge] looking for events to delete")
+
+ should_delete_expr = "state_key IS NULL"
+ should_delete_params = ()
+ if not delete_local_events:
+ should_delete_expr += " AND event_id NOT LIKE ?"
+ should_delete_params += ("%:" + self.hs.hostname, )
+
+ should_delete_params += (room_id, token.topological)
txn.execute(
- "SELECT event_id, state_key FROM events"
- " LEFT JOIN state_events USING (room_id, event_id)"
- " WHERE room_id = ? AND topological_ordering < ?",
- (room_id, topological_ordering,)
+ "INSERT INTO events_to_purge"
+ " SELECT event_id, %s"
+ " FROM events AS e LEFT JOIN state_events USING (event_id)"
+ " WHERE e.room_id = ? AND topological_ordering < ?" % (
+ should_delete_expr,
+ ),
+ should_delete_params,
+ )
+ txn.execute(
+ "SELECT event_id, should_delete FROM events_to_purge"
)
event_rows = txn.fetchall()
-
- to_delete = [
- (event_id,) for event_id, state_key in event_rows
- if state_key is None and not self.hs.is_mine_id(event_id)
- ]
logger.info(
- "[purge] found %i events before cutoff, of which %i are remote"
- " non-state events to delete", len(event_rows), len(to_delete))
-
- for event_id, state_key in event_rows:
- txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
+ "[purge] found %i events before cutoff, of which %i can be deleted",
+ len(event_rows), sum(1 for e in event_rows if e[1]),
+ )
- logger.debug("[purge] Finding new backward extremities")
+ logger.info("[purge] Finding new backward extremities")
    # We calculate the new entries for the backward extremities by finding
- # all events that point to events that are to be purged
+ # events to be purged that are pointed to by events we're not going to
+ # purge.
txn.execute(
- "SELECT DISTINCT e.event_id FROM events as e"
- " INNER JOIN event_edges as ed ON e.event_id = ed.prev_event_id"
- " INNER JOIN events as e2 ON e2.event_id = ed.event_id"
- " WHERE e.room_id = ? AND e.topological_ordering < ?"
- " AND e2.topological_ordering >= ?",
- (room_id, topological_ordering, topological_ordering)
+ "SELECT DISTINCT e.event_id FROM events_to_purge AS e"
+ " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
+ " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id"
+ " WHERE ep2.event_id IS NULL",
)
new_backwards_extrems = txn.fetchall()
- logger.debug("[purge] replacing backward extremities: %r", new_backwards_extrems)
+ logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
txn.execute(
"DELETE FROM event_backward_extremities WHERE room_id = ?",
@@ -2106,34 +2033,39 @@ class EventsStore(SQLBaseStore):
]
)
- logger.debug("[purge] finding redundant state groups")
+ logger.info("[purge] finding redundant state groups")
# Get all state groups that are only referenced by events that are
# to be deleted.
- txn.execute(
- "SELECT state_group FROM event_to_state_groups"
- " INNER JOIN events USING (event_id)"
- " WHERE state_group IN ("
- " SELECT DISTINCT state_group FROM events"
- " INNER JOIN event_to_state_groups USING (event_id)"
- " WHERE room_id = ? AND topological_ordering < ?"
- " )"
- " GROUP BY state_group HAVING MAX(topological_ordering) < ?",
- (room_id, topological_ordering, topological_ordering)
- )
+ # This works by first getting state groups that we may want to delete,
+ # joining against event_to_state_groups to get events that use that
+ # state group, then left joining against events_to_purge again. Any
+    # state group for which the left join produces *no nulls* is
+    # referenced only by events that are going to be purged.
+ txn.execute("""
+ SELECT state_group FROM
+ (
+ SELECT DISTINCT state_group FROM events_to_purge
+ INNER JOIN event_to_state_groups USING (event_id)
+ ) AS sp
+ INNER JOIN event_to_state_groups USING (state_group)
+ LEFT JOIN events_to_purge AS ep USING (event_id)
+ GROUP BY state_group
+ HAVING SUM(CASE WHEN ep.event_id IS NULL THEN 1 ELSE 0 END) = 0
+ """)
state_rows = txn.fetchall()
- logger.debug("[purge] found %i redundant state groups", len(state_rows))
+ logger.info("[purge] found %i redundant state groups", len(state_rows))
# make a set of the redundant state groups, so that we can look them up
# efficiently
state_groups_to_delete = set([sg for sg, in state_rows])
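The HAVING trick above is worth seeing in isolation: a group survives as soon as the left join produces a single NULL. A self-contained sqlite3 demo with invented tables (`membership`, `to_purge`; not Synapse's schema):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE membership (grp TEXT, member TEXT)")
cur.execute("CREATE TABLE to_purge (member TEXT)")
cur.executemany("INSERT INTO membership VALUES (?, ?)", [
    ("g1", "a"), ("g1", "b"),   # every member of g1 is being purged
    ("g2", "b"), ("g2", "c"),   # g2 keeps "c", so it must survive
])
cur.executemany("INSERT INTO to_purge VALUES (?)", [("a",), ("b",)])

# Unmatched members come back as NULL from the LEFT JOIN; a group is
# fully purgeable only when the join produced no NULLs at all.
cur.execute("""
    SELECT grp FROM membership AS m
    LEFT JOIN to_purge AS p ON p.member = m.member
    GROUP BY grp
    HAVING SUM(CASE WHEN p.member IS NULL THEN 1 ELSE 0 END) = 0
""")
print(cur.fetchall())  # [('g1',)]
```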
# Now we get all the state groups that rely on these state groups
- logger.debug("[purge] finding state groups which depend on redundant"
- " state groups")
+ logger.info("[purge] finding state groups which depend on redundant"
+ " state groups")
remaining_state_groups = []
- for i in xrange(0, len(state_rows), 100):
+ for i in range(0, len(state_rows), 100):
chunk = [sg for sg, in state_rows[i:i + 100]]
# look for state groups whose prev_state_group is one we are about
# to delete
@@ -2156,7 +2088,7 @@ class EventsStore(SQLBaseStore):
# Now we turn the state groups that reference to-be-deleted state
# groups to non delta versions.
for sg in remaining_state_groups:
- logger.debug("[purge] de-delta-ing remaining state group %s", sg)
+ logger.info("[purge] de-delta-ing remaining state group %s", sg)
curr_state = self._get_state_groups_from_groups_txn(
txn, [sg], types=None
)
@@ -2189,11 +2121,11 @@ class EventsStore(SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in curr_state.iteritems()
+ for key, state_id in iteritems(curr_state)
],
)
- logger.debug("[purge] removing redundant state groups")
+ logger.info("[purge] removing redundant state groups")
txn.executemany(
"DELETE FROM state_groups_state WHERE state_group = ?",
state_rows
@@ -2203,18 +2135,15 @@ class EventsStore(SQLBaseStore):
state_rows
)
- # Delete all non-state
- logger.debug("[purge] removing events from event_to_state_groups")
- txn.executemany(
- "DELETE FROM event_to_state_groups WHERE event_id = ?",
- [(event_id,) for event_id, _ in event_rows]
- )
-
- logger.debug("[purge] updating room_depth")
+ logger.info("[purge] removing events from event_to_state_groups")
txn.execute(
- "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
- (topological_ordering, room_id,)
+ "DELETE FROM event_to_state_groups "
+ "WHERE event_id IN (SELECT event_id from events_to_purge)"
)
+ for event_id, _ in event_rows:
+ txn.call_after(self._get_state_group_for_event.invalidate, (
+ event_id,
+ ))
# Delete all remote non-state events
for table in (
@@ -2226,28 +2155,75 @@ class EventsStore(SQLBaseStore):
"event_edge_hashes",
"event_edges",
"event_forward_extremities",
- "event_push_actions",
"event_reference_hashes",
"event_search",
"event_signatures",
"rejections",
):
- logger.debug("[purge] removing remote non-state events from %s", table)
+ logger.info("[purge] removing events from %s", table)
- txn.executemany(
- "DELETE FROM %s WHERE event_id = ?" % (table,),
- to_delete
+ txn.execute(
+ "DELETE FROM %s WHERE event_id IN ("
+ " SELECT event_id FROM events_to_purge WHERE should_delete"
+ ")" % (table,),
+ )
+
+ # event_push_actions lacks an index on event_id, and has one on
+ # (room_id, event_id) instead.
+ for table in (
+ "event_push_actions",
+ ):
+ logger.info("[purge] removing events from %s", table)
+
+ txn.execute(
+ "DELETE FROM %s WHERE room_id = ? AND event_id IN ("
+ " SELECT event_id FROM events_to_purge WHERE should_delete"
+ ")" % (table,),
+ (room_id, )
)
# Mark all state and own events as outliers
- logger.debug("[purge] marking remaining events as outliers")
- txn.executemany(
+ logger.info("[purge] marking remaining events as outliers")
+ txn.execute(
"UPDATE events SET outlier = ?"
- " WHERE event_id = ?",
- [
- (True, event_id,) for event_id, state_key in event_rows
- if state_key is not None or self.hs.is_mine_id(event_id)
- ]
+ " WHERE event_id IN ("
+ " SELECT event_id FROM events_to_purge "
+ " WHERE NOT should_delete"
+ ")",
+ (True,),
+ )
+
+ # synapse tries to take out an exclusive lock on room_depth whenever it
+    # persists events (because of the upsert), and once we run this update, we
+ # will block that for the rest of our transaction.
+ #
+ # So, let's stick it at the end so that we don't block event
+ # persistence.
+ #
+ # We do this by calculating the minimum depth of the backwards
+ # extremities. However, the events in event_backward_extremities
+    # are ones we don't have yet, so we need to look at the events that
+    # point to them via the event_edges table.
+ txn.execute("""
+ SELECT COALESCE(MIN(depth), 0)
+ FROM event_backward_extremities AS eb
+ INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id
+ INNER JOIN events AS e ON e.event_id = eg.event_id
+ WHERE eb.room_id = ?
+ """, (room_id,))
+ min_depth, = txn.fetchone()
+
+ logger.info("[purge] updating room_depth to %d", min_depth)
+
+ txn.execute(
+ "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
+ (min_depth, room_id,)
+ )
+
+ # finally, drop the temp table. this will commit the txn in sqlite,
+ # so make sure to keep this actually last.
+ txn.execute(
+ "DROP TABLE events_to_purge"
)
logger.info("[purge] done")
@@ -2260,7 +2236,7 @@ class EventsStore(SQLBaseStore):
to_2, so_2 = yield self._get_event_ordering(event_id2)
defer.returnValue((to_1, so_1) > (to_2, so_2))
- @defer.inlineCallbacks
+ @cachedInlineCallbacks(max_entries=5000)
def _get_event_ordering(self, event_id):
res = yield self._simple_select_one(
table="events",
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
new file mode 100644
index 0000000000..f28239a808
--- /dev/null
+++ b/synapse/storage/events_worker.py
@@ -0,0 +1,436 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from collections import namedtuple
+
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+# these are only included to make the type annotations work
+from synapse.events import EventBase # noqa: F401
+from synapse.events import FrozenEvent
+from synapse.events.snapshot import EventContext # noqa: F401
+from synapse.events.utils import prune_event
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.logcontext import (
+ LoggingContext,
+ PreserveLoggingContext,
+ make_deferred_yieldable,
+ run_in_background,
+)
+from synapse.util.metrics import Measure
+
+from ._base import SQLBaseStore
+
+logger = logging.getLogger(__name__)
+
+
+# These values are used in the `_enqueue_events` and `_do_fetch` methods to
+# control how we batch/bulk fetch events from the database.
+# The values are plucked out of thin air to make initial sync run faster
+# on jki.re
+# TODO: Make these configurable.
+EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events
+EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
+EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
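To make concrete what these three constants tune, here is a hedged, Twisted-free sketch of the producer/consumer shape implemented by `_enqueue_events` and `_do_fetch` below (the `BatchFetcher` class and all its names are invented for illustration):

```python
import threading

class BatchFetcher(object):
    """Callers queue (ids, callback) pairs; a worker drains the whole
    queue per database round-trip, idling briefly before exiting."""

    def __init__(self):
        self._lock = threading.Condition()
        self._queue = []

    def enqueue(self, ids, on_done):
        with self._lock:
            self._queue.append((ids, on_done))
            self._lock.notify()

    def worker(self, fetch_rows, iterations=3, timeout_s=0.1):
        idle = 0
        while True:
            with self._lock:
                batch, self._queue = self._queue, []
                if not batch:
                    if idle > iterations:
                        return              # idle too long; let the thread die
                    self._lock.wait(timeout_s)
                    idle += 1
                    continue
            idle = 0
            all_ids = [i for ids, _ in batch for i in ids]
            rows = fetch_rows(all_ids)      # one transaction for the batch
            for ids, on_done in batch:
                on_done({i: rows[i] for i in ids if i in rows})
```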
+
+
+_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+
+
+class EventsWorkerStore(SQLBaseStore):
+ def get_received_ts(self, event_id):
+ """Get received_ts (when it was persisted) for the event.
+
+ Raises an exception for unknown events.
+
+ Args:
+ event_id (str)
+
+ Returns:
+ Deferred[int|None]: Timestamp in milliseconds, or None for events
+ that were persisted before received_ts was implemented.
+ """
+ return self._simple_select_one_onecol(
+ table="events",
+ keyvalues={
+ "event_id": event_id,
+ },
+ retcol="received_ts",
+ desc="get_received_ts",
+ )
+
+ @defer.inlineCallbacks
+ def get_event(self, event_id, check_redacted=True,
+ get_prev_content=False, allow_rejected=False,
+ allow_none=False):
+ """Get an event from the database by event_id.
+
+ Args:
+ event_id (str): The event_id of the event to fetch
+ check_redacted (bool): If True, check if event has been redacted
+ and redact it.
+ get_prev_content (bool): If True and event is a state event,
+ include the previous states content in the unsigned field.
+ allow_rejected (bool): If True return rejected events.
+ allow_none (bool): If True, return None if no event found, if
+ False throw an exception.
+
+ Returns:
+            Deferred: A FrozenEvent.
+ """
+ events = yield self._get_events(
+ [event_id],
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
+
+ if not events and not allow_none:
+ raise SynapseError(404, "Could not find event %s" % (event_id,))
+
+ defer.returnValue(events[0] if events else None)
+
+ @defer.inlineCallbacks
+ def get_events(self, event_ids, check_redacted=True,
+ get_prev_content=False, allow_rejected=False):
+ """Get events from the database
+
+ Args:
+ event_ids (list): The event_ids of the events to fetch
+ check_redacted (bool): If True, check if event has been redacted
+ and redact it.
+ get_prev_content (bool): If True and event is a state event,
+ include the previous states content in the unsigned field.
+ allow_rejected (bool): If True return rejected events.
+
+ Returns:
+            Deferred: Dict from event_id to event.
+ """
+ events = yield self._get_events(
+ event_ids,
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
+
+ defer.returnValue({e.event_id: e for e in events})
+
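A hedged usage sketch of these two accessors (the surrounding function is hypothetical), showing the `allow_none` contract and the dict-shaped batch form:

```python
from twisted.internet import defer

@defer.inlineCallbacks
def _example(store, event_id, event_ids):
    # Single fetch: with allow_none=True this returns None instead of
    # raising SynapseError(404) for an unknown event.
    event = yield store.get_event(event_id, allow_none=True)
    if event is not None:
        print(event.event_id)

    # Batch fetch: dict of event_id -> FrozenEvent; unknown ids are
    # simply absent from the result.
    events = yield store.get_events(event_ids)
    missing = [e for e in event_ids if e not in events]
    defer.returnValue(missing)
```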
+ @defer.inlineCallbacks
+ def _get_events(self, event_ids, check_redacted=True,
+ get_prev_content=False, allow_rejected=False):
+ if not event_ids:
+ defer.returnValue([])
+
+ event_id_list = event_ids
+ event_ids = set(event_ids)
+
+ event_entry_map = self._get_events_from_cache(
+ event_ids,
+ allow_rejected=allow_rejected,
+ )
+
+ missing_events_ids = [e for e in event_ids if e not in event_entry_map]
+
+ if missing_events_ids:
+ log_ctx = LoggingContext.current_context()
+ log_ctx.record_event_fetch(len(missing_events_ids))
+
+ missing_events = yield self._enqueue_events(
+ missing_events_ids,
+ check_redacted=check_redacted,
+ allow_rejected=allow_rejected,
+ )
+
+ event_entry_map.update(missing_events)
+
+ events = []
+ for event_id in event_id_list:
+ entry = event_entry_map.get(event_id, None)
+ if not entry:
+ continue
+
+ if allow_rejected or not entry.event.rejected_reason:
+ if check_redacted and entry.redacted_event:
+ event = entry.redacted_event
+ else:
+ event = entry.event
+
+ events.append(event)
+
+ if get_prev_content:
+ if "replaces_state" in event.unsigned:
+ prev = yield self.get_event(
+ event.unsigned["replaces_state"],
+ get_prev_content=False,
+ allow_none=True,
+ )
+ if prev:
+ event.unsigned = dict(event.unsigned)
+ event.unsigned["prev_content"] = prev.content
+ event.unsigned["prev_sender"] = prev.sender
+
+ defer.returnValue(events)
+
+ def _invalidate_get_event_cache(self, event_id):
+ self._get_event_cache.invalidate((event_id,))
+
+ def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
+ """Fetch events from the caches
+
+ Args:
+ events (list(str)): list of event_ids to fetch
+            allow_rejected (bool): Whether to return events that were rejected
+ update_metrics (bool): Whether to update the cache hit ratio metrics
+
+ Returns:
+            dict of event_id -> _EventCacheEntry for each event_id found in the
+            cache. If allow_rejected is `False`, a rejected event will still
+            have an entry, but it will map to `None`.
+ """
+ event_map = {}
+
+ for event_id in events:
+ ret = self._get_event_cache.get(
+ (event_id,), None,
+ update_metrics=update_metrics,
+ )
+ if not ret:
+ continue
+
+ if allow_rejected or not ret.event.rejected_reason:
+ event_map[event_id] = ret
+ else:
+ event_map[event_id] = None
+
+ return event_map
+
+ def _do_fetch(self, conn):
+ """Takes a database connection and waits for requests for events from
+ the _event_fetch_list queue.
+ """
+ i = 0
+ while True:
+ with self._event_fetch_lock:
+ event_list = self._event_fetch_list
+ self._event_fetch_list = []
+
+ if not event_list:
+ single_threaded = self.database_engine.single_threaded
+ if single_threaded or i > EVENT_QUEUE_ITERATIONS:
+ self._event_fetch_ongoing -= 1
+ return
+ else:
+ self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+ i += 1
+ continue
+ i = 0
+
+ self._fetch_event_list(conn, event_list)
+
+ def _fetch_event_list(self, conn, event_list):
+ """Handle a load of requests from the _event_fetch_list queue
+
+ Args:
+ conn (twisted.enterprise.adbapi.Connection): database connection
+
+ event_list (list[Tuple[list[str], Deferred]]):
+ The fetch requests. Each entry consists of a list of event
+ ids to be fetched, and a deferred to be completed once the
+ events have been fetched.
+
+ """
+ with Measure(self._clock, "_fetch_event_list"):
+ try:
+                event_id_lists = list(zip(*event_list))[0]
+ event_ids = [
+ item for sublist in event_id_lists for item in sublist
+ ]
+
+ rows = self._new_transaction(
+ conn, "do_fetch", [], [],
+ self._fetch_event_rows, event_ids,
+ )
+
+ row_dict = {
+ r["event_id"]: r
+ for r in rows
+ }
+
+ # We only want to resolve deferreds from the main thread
+ def fire(lst, res):
+ for ids, d in lst:
+ if not d.called:
+ try:
+ with PreserveLoggingContext():
+ d.callback([
+ res[i]
+ for i in ids
+ if i in res
+ ])
+ except Exception:
+ logger.exception("Failed to callback")
+ with PreserveLoggingContext():
+ self.hs.get_reactor().callFromThread(fire, event_list, row_dict)
+ except Exception as e:
+ logger.exception("do_fetch")
+
+ # We only want to resolve deferreds from the main thread
+ def fire(evs):
+ for _, d in evs:
+ if not d.called:
+ with PreserveLoggingContext():
+ d.errback(e)
+
+ with PreserveLoggingContext():
+ self.hs.get_reactor().callFromThread(fire, event_list)
+
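The thread-hop in both the success and failure paths above is the load-bearing detail: the fetch runs on a database thread, but Deferreds must only be fired from the reactor thread. A minimal sketch of the rule (function names invented):

```python
from twisted.internet import reactor

def deliver_rows(d, rows):
    # Wrong: d.callback(rows) here would run callbacks on the DB thread.
    # Right: bounce the result over to the reactor thread first.
    reactor.callFromThread(d.callback, rows)

def deliver_failure(d, exc):
    reactor.callFromThread(d.errback, exc)
```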
+ @defer.inlineCallbacks
+ def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
+ """Fetches events from the database using the _event_fetch_list. This
+ allows batch and bulk fetching of events - it allows us to fetch events
+ without having to create a new transaction for each request for events.
+ """
+ if not events:
+ defer.returnValue({})
+
+ events_d = defer.Deferred()
+ with self._event_fetch_lock:
+ self._event_fetch_list.append(
+ (events, events_d)
+ )
+
+ self._event_fetch_lock.notify()
+
+ if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
+ self._event_fetch_ongoing += 1
+ should_start = True
+ else:
+ should_start = False
+
+ if should_start:
+ run_as_background_process(
+ "fetch_events",
+ self.runWithConnection,
+ self._do_fetch,
+ )
+
+ logger.debug("Loading %d events", len(events))
+ with PreserveLoggingContext():
+ rows = yield events_d
+ logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
+
+ if not allow_rejected:
+ rows[:] = [r for r in rows if not r["rejects"]]
+
+ res = yield make_deferred_yieldable(defer.gatherResults(
+ [
+ run_in_background(
+ self._get_event_from_row,
+ row["internal_metadata"], row["json"], row["redacts"],
+ rejected_reason=row["rejects"],
+ )
+ for row in rows
+ ],
+ consumeErrors=True
+ ))
+
+ defer.returnValue({
+ e.event.event_id: e
+ for e in res if e
+ })
+
+ def _fetch_event_rows(self, txn, events):
+ rows = []
+ N = 200
+ for i in range(1 + len(events) // N):
+ evs = events[i * N:(i + 1) * N]
+ if not evs:
+ break
+
+ sql = (
+ "SELECT "
+ " e.event_id as event_id, "
+ " e.internal_metadata,"
+ " e.json,"
+ " r.redacts as redacts,"
+ " rej.event_id as rejects "
+ " FROM event_json as e"
+ " LEFT JOIN rejections as rej USING (event_id)"
+ " LEFT JOIN redactions as r ON e.event_id = r.redacts"
+ " WHERE e.event_id IN (%s)"
+ ) % (",".join(["?"] * len(evs)),)
+
+ txn.execute(sql, evs)
+ rows.extend(self.cursor_to_dict(txn))
+
+ return rows
+
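The chunking idiom in `_fetch_event_rows` is reusable on its own: build exactly `len(chunk)` placeholders, pass values as parameters, and cap each `IN (...)` list. A hedged generic sketch (the helper name is invented; `table` and `column` must be trusted identifiers, since only values are parameterised):

```python
def chunked_in_query(txn, table, column, ids, n=200):
    """Fetch all rows whose `column` is in `ids`, at most `n` per statement."""
    rows = []
    for i in range(0, len(ids), n):
        chunk = ids[i:i + n]
        sql = "SELECT * FROM %s WHERE %s IN (%s)" % (
            table, column, ",".join("?" * len(chunk)),
        )
        txn.execute(sql, chunk)
        rows.extend(txn.fetchall())
    return rows
```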
+ @defer.inlineCallbacks
+ def _get_event_from_row(self, internal_metadata, js, redacted,
+ rejected_reason=None):
+ with Measure(self._clock, "_get_event_from_row"):
+ d = json.loads(js)
+ internal_metadata = json.loads(internal_metadata)
+
+ if rejected_reason:
+ rejected_reason = yield self._simple_select_one_onecol(
+ table="rejections",
+ keyvalues={"event_id": rejected_reason},
+ retcol="reason",
+ desc="_get_event_from_row_rejected_reason",
+ )
+
+ original_ev = FrozenEvent(
+ d,
+ internal_metadata_dict=internal_metadata,
+ rejected_reason=rejected_reason,
+ )
+
+ redacted_event = None
+ if redacted:
+ redacted_event = prune_event(original_ev)
+
+ redaction_id = yield self._simple_select_one_onecol(
+ table="redactions",
+ keyvalues={"redacts": redacted_event.event_id},
+ retcol="event_id",
+ desc="_get_event_from_row_redactions",
+ )
+
+ redacted_event.unsigned["redacted_by"] = redaction_id
+ # Get the redaction event.
+
+ because = yield self.get_event(
+ redaction_id,
+ check_redacted=False,
+ allow_none=True,
+ )
+
+ if because:
+                # It's fine to add the event directly, since get_pdu_json
+ # will serialise this field correctly
+ redacted_event.unsigned["redacted_because"] = because
+
+ cache_entry = _EventCacheEntry(
+ event=original_ev,
+ redacted_event=redacted_event,
+ )
+
+ self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
+
+ defer.returnValue(cache_entry)
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index 78b1e30945..2d5896c5b4 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from canonicaljson import encode_canonical_json, json
+
from twisted.internet import defer
-from ._base import SQLBaseStore
-from synapse.api.errors import SynapseError, Codes
+from synapse.api.errors import Codes, SynapseError
from synapse.util.caches.descriptors import cachedInlineCallbacks
-from canonicaljson import encode_canonical_json
-import simplejson as json
+from ._base import SQLBaseStore
class FilteringStore(SQLBaseStore):
@@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore):
desc="get_user_filter",
)
- defer.returnValue(json.loads(str(def_json).decode("utf-8")))
+ defer.returnValue(json.loads(bytes(def_json).decode("utf-8")))
def add_user_filter(self, user_localpart, user_filter):
def_json = encode_canonical_json(user_filter)
diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py
new file mode 100644
index 0000000000..592d1b4c2a
--- /dev/null
+++ b/synapse/storage/group_server.py
@@ -0,0 +1,1252 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+
+from ._base import SQLBaseStore
+
+# The category ID for the "default" category. We don't store as null in the
+# database to avoid the fun of null != null
+_DEFAULT_CATEGORY_ID = ""
+_DEFAULT_ROLE_ID = ""
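The empty-string sentinel sidesteps SQL's three-valued logic: `NULL = NULL` evaluates to unknown rather than true, so a nullable `category_id` would break keyed lookups and uniqueness. A two-query sqlite3 demonstration:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
print(conn.execute("SELECT NULL = NULL").fetchone())  # (None,): unknown, not true
print(conn.execute("SELECT '' = ''").fetchone())      # (1,): plain equality works
```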
+
+
+class GroupServerStore(SQLBaseStore):
+ def set_group_join_policy(self, group_id, join_policy):
+ """Set the join policy of a group.
+
+ join_policy can be one of:
+ * "invite"
+ * "open"
+ """
+ return self._simple_update_one(
+ table="groups",
+ keyvalues={
+ "group_id": group_id,
+ },
+ updatevalues={
+ "join_policy": join_policy,
+ },
+ desc="set_group_join_policy",
+ )
+
+ def get_group(self, group_id):
+ return self._simple_select_one(
+ table="groups",
+ keyvalues={
+ "group_id": group_id,
+ },
+ retcols=(
+ "name", "short_description", "long_description",
+ "avatar_url", "is_public", "join_policy",
+ ),
+ allow_none=True,
+ desc="get_group",
+ )
+
+ def get_users_in_group(self, group_id, include_private=False):
+ # TODO: Pagination
+
+ keyvalues = {
+ "group_id": group_id,
+ }
+ if not include_private:
+ keyvalues["is_public"] = True
+
+ return self._simple_select_list(
+ table="group_users",
+ keyvalues=keyvalues,
+ retcols=("user_id", "is_public", "is_admin",),
+ desc="get_users_in_group",
+ )
+
+ def get_invited_users_in_group(self, group_id):
+ # TODO: Pagination
+
+ return self._simple_select_onecol(
+ table="group_invites",
+ keyvalues={
+ "group_id": group_id,
+ },
+ retcol="user_id",
+ desc="get_invited_users_in_group",
+ )
+
+ def get_rooms_in_group(self, group_id, include_private=False):
+ # TODO: Pagination
+
+ keyvalues = {
+ "group_id": group_id,
+ }
+ if not include_private:
+ keyvalues["is_public"] = True
+
+ return self._simple_select_list(
+ table="group_rooms",
+ keyvalues=keyvalues,
+ retcols=("room_id", "is_public",),
+ desc="get_rooms_in_group",
+ )
+
+ def get_rooms_for_summary_by_category(self, group_id, include_private=False):
+ """Get the rooms and categories that should be included in a summary request
+
+ Returns ([rooms], [categories])
+ """
+ def _get_rooms_for_summary_txn(txn):
+ keyvalues = {
+ "group_id": group_id,
+ }
+ if not include_private:
+ keyvalues["is_public"] = True
+
+ sql = """
+ SELECT room_id, is_public, category_id, room_order
+ FROM group_summary_rooms
+ WHERE group_id = ?
+ """
+
+ if not include_private:
+ sql += " AND is_public = ?"
+ txn.execute(sql, (group_id, True))
+ else:
+ txn.execute(sql, (group_id,))
+
+ rooms = [
+ {
+ "room_id": row[0],
+ "is_public": row[1],
+ "category_id": row[2] if row[2] != _DEFAULT_CATEGORY_ID else None,
+ "order": row[3],
+ }
+ for row in txn
+ ]
+
+ sql = """
+ SELECT category_id, is_public, profile, cat_order
+ FROM group_summary_room_categories
+ INNER JOIN group_room_categories USING (group_id, category_id)
+ WHERE group_id = ?
+ """
+
+ if not include_private:
+ sql += " AND is_public = ?"
+ txn.execute(sql, (group_id, True))
+ else:
+ txn.execute(sql, (group_id,))
+
+ categories = {
+ row[0]: {
+ "is_public": row[1],
+ "profile": json.loads(row[2]),
+ "order": row[3],
+ }
+ for row in txn
+ }
+
+ return rooms, categories
+ return self.runInteraction(
+ "get_rooms_for_summary", _get_rooms_for_summary_txn
+ )
+
+ def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
+ return self.runInteraction(
+ "add_room_to_summary", self._add_room_to_summary_txn,
+ group_id, room_id, category_id, order, is_public,
+ )
+
+ def _add_room_to_summary_txn(self, txn, group_id, room_id, category_id, order,
+ is_public):
+ """Add (or update) room's entry in summary.
+
+ Args:
+ group_id (str)
+ room_id (str)
+ category_id (str): If not None then adds the category to the end of
+                the summary if it's not already there. [Optional]
+ order (int): If not None inserts the room at that position, e.g.
+ an order of 1 will put the room first. Otherwise, the room gets
+ added to the end.
+ """
+ room_in_group = self._simple_select_one_onecol_txn(
+ txn,
+ table="group_rooms",
+ keyvalues={
+ "group_id": group_id,
+ "room_id": room_id,
+ },
+ retcol="room_id",
+ allow_none=True,
+ )
+ if not room_in_group:
+ raise SynapseError(400, "room not in group")
+
+ if category_id is None:
+ category_id = _DEFAULT_CATEGORY_ID
+ else:
+ cat_exists = self._simple_select_one_onecol_txn(
+ txn,
+ table="group_room_categories",
+ keyvalues={
+ "group_id": group_id,
+ "category_id": category_id,
+ },
+ retcol="group_id",
+ allow_none=True,
+ )
+ if not cat_exists:
+ raise SynapseError(400, "Category doesn't exist")
+
+ # TODO: Check category is part of summary already
+ cat_exists = self._simple_select_one_onecol_txn(
+ txn,
+ table="group_summary_room_categories",
+ keyvalues={
+ "group_id": group_id,
+ "category_id": category_id,
+ },
+ retcol="group_id",
+ allow_none=True,
+ )
+ if not cat_exists:
+ # If not, add it with an order larger than all others
+ txn.execute("""
+ INSERT INTO group_summary_room_categories
+ (group_id, category_id, cat_order)
+ SELECT ?, ?, COALESCE(MAX(cat_order), 0) + 1
+ FROM group_summary_room_categories
+ WHERE group_id = ? AND category_id = ?
+ """, (group_id, category_id, group_id, category_id))
+
+ existing = self._simple_select_one_txn(
+ txn,
+ table="group_summary_rooms",
+ keyvalues={
+ "group_id": group_id,
+ "room_id": room_id,
+ "category_id": category_id,
+ },
+ retcols=("room_order", "is_public",),
+ allow_none=True,
+ )
+
+ if order is not None:
+ # Shuffle other room orders that come after the given order
+ sql = """
+ UPDATE group_summary_rooms SET room_order = room_order + 1
+ WHERE group_id = ? AND category_id = ? AND room_order >= ?
+ """
+ txn.execute(sql, (group_id, category_id, order,))
+ elif not existing:
+ sql = """
+ SELECT COALESCE(MAX(room_order), 0) + 1 FROM group_summary_rooms
+ WHERE group_id = ? AND category_id = ?
+ """
+ txn.execute(sql, (group_id, category_id,))
+ order, = txn.fetchone()
+
+ if existing:
+ to_update = {}
+ if order is not None:
+ to_update["room_order"] = order
+ if is_public is not None:
+ to_update["is_public"] = is_public
+ self._simple_update_txn(
+ txn,
+ table="group_summary_rooms",
+ keyvalues={
+ "group_id": group_id,
+ "category_id": category_id,
+ "room_id": room_id,
+ },
+ values=to_update,
+ )
+ else:
+ if is_public is None:
+ is_public = True
+
+ self._simple_insert_txn(
+ txn,
+ table="group_summary_rooms",
+ values={
+ "group_id": group_id,
+ "category_id": category_id,
+ "room_id": room_id,
+ "room_order": order,
+ "is_public": is_public,
+ },
+ )
+
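The `room_order` bookkeeping above boils down to one rule: to insert at position `order`, shift every slot at or past it up by one; to append, take `MAX + 1`. A list-based analogue (purely illustrative; not part of the schema or this change):

```python
def place(orders, order=None):
    """orders: the existing *_order values. Returns (updated, new_slot)."""
    if order is None:                       # append after the current max
        return orders, (max(orders) if orders else 0) + 1
    shifted = [o + 1 if o >= order else o for o in orders]
    return shifted, order

print(place([1, 2, 3]))     # ([1, 2, 3], 4) -- appended at the end
print(place([1, 2, 3], 2))  # ([1, 3, 4], 2) -- inserted at position 2
```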
+ def remove_room_from_summary(self, group_id, room_id, category_id):
+ if category_id is None:
+ category_id = _DEFAULT_CATEGORY_ID
+
+ return self._simple_delete(
+ table="group_summary_rooms",
+ keyvalues={
+ "group_id": group_id,
+ "category_id": category_id,
+ "room_id": room_id,
+ },
+ desc="remove_room_from_summary",
+ )
+
+ @defer.inlineCallbacks
+ def get_group_categories(self, group_id):
+ rows = yield self._simple_select_list(
+ table="group_room_categories",
+ keyvalues={
+ "group_id": group_id,
+ },
+ retcols=("category_id", "is_public", "profile"),
+ desc="get_group_categories",
+ )
+
+ defer.returnValue({
+ row["category_id"]: {
+ "is_public": row["is_public"],
+ "profile": json.loads(row["profile"]),
+ }
+ for row in rows
+ })
+
+ @defer.inlineCallbacks
+ def get_group_category(self, group_id, category_id):
+ category = yield self._simple_select_one(
+ table="group_room_categories",
+ keyvalues={
+ "group_id": group_id,
+ "category_id": category_id,
+ },
+ retcols=("is_public", "profile"),
+ desc="get_group_category",
+ )
+
+ category["profile"] = json.loads(category["profile"])
+
+ defer.returnValue(category)
+
+ def upsert_group_category(self, group_id, category_id, profile, is_public):
+ """Add/update room category for group
+ """
+ insertion_values = {}
+ update_values = {"category_id": category_id} # This cannot be empty
+
+ if profile is None:
+ insertion_values["profile"] = "{}"
+ else:
+ update_values["profile"] = json.dumps(profile)
+
+ if is_public is None:
+ insertion_values["is_public"] = True
+ else:
+ update_values["is_public"] = is_public
+
+ return self._simple_upsert(
+ table="group_room_categories",
+ keyvalues={
+ "group_id": group_id,
+ "category_id": category_id,
+ },
+ values=update_values,
+ insertion_values=insertion_values,
+ desc="upsert_group_category",
+ )
+
+ def remove_group_category(self, group_id, category_id):
+ return self._simple_delete(
+ table="group_room_categories",
+ keyvalues={
+ "group_id": group_id,
+ "category_id": category_id,
+ },
+ desc="remove_group_category",
+ )
+
+ @defer.inlineCallbacks
+ def get_group_roles(self, group_id):
+ rows = yield self._simple_select_list(
+ table="group_roles",
+ keyvalues={
+ "group_id": group_id,
+ },
+ retcols=("role_id", "is_public", "profile"),
+ desc="get_group_roles",
+ )
+
+ defer.returnValue({
+ row["role_id"]: {
+ "is_public": row["is_public"],
+ "profile": json.loads(row["profile"]),
+ }
+ for row in rows
+ })
+
+ @defer.inlineCallbacks
+ def get_group_role(self, group_id, role_id):
+ role = yield self._simple_select_one(
+ table="group_roles",
+ keyvalues={
+ "group_id": group_id,
+ "role_id": role_id,
+ },
+ retcols=("is_public", "profile"),
+ desc="get_group_role",
+ )
+
+ role["profile"] = json.loads(role["profile"])
+
+ defer.returnValue(role)
+
+ def upsert_group_role(self, group_id, role_id, profile, is_public):
+        """Add/update user role for group
+ """
+ insertion_values = {}
+ update_values = {"role_id": role_id} # This cannot be empty
+
+ if profile is None:
+ insertion_values["profile"] = "{}"
+ else:
+ update_values["profile"] = json.dumps(profile)
+
+ if is_public is None:
+ insertion_values["is_public"] = True
+ else:
+ update_values["is_public"] = is_public
+
+ return self._simple_upsert(
+ table="group_roles",
+ keyvalues={
+ "group_id": group_id,
+ "role_id": role_id,
+ },
+ values=update_values,
+ insertion_values=insertion_values,
+ desc="upsert_group_role",
+ )
+
+ def remove_group_role(self, group_id, role_id):
+ return self._simple_delete(
+ table="group_roles",
+ keyvalues={
+ "group_id": group_id,
+ "role_id": role_id,
+ },
+ desc="remove_group_role",
+ )
+
+ def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
+ return self.runInteraction(
+ "add_user_to_summary", self._add_user_to_summary_txn,
+ group_id, user_id, role_id, order, is_public,
+ )
+
+ def _add_user_to_summary_txn(self, txn, group_id, user_id, role_id, order,
+ is_public):
+ """Add (or update) user's entry in summary.
+
+ Args:
+ group_id (str)
+ user_id (str)
+ role_id (str): If not None then adds the role to the end of
+                the summary if it's not already there. [Optional]
+ order (int): If not None inserts the user at that position, e.g.
+ an order of 1 will put the user first. Otherwise, the user gets
+ added to the end.
+ """
+ user_in_group = self._simple_select_one_onecol_txn(
+ txn,
+ table="group_users",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ retcol="user_id",
+ allow_none=True,
+ )
+ if not user_in_group:
+ raise SynapseError(400, "user not in group")
+
+ if role_id is None:
+ role_id = _DEFAULT_ROLE_ID
+ else:
+ role_exists = self._simple_select_one_onecol_txn(
+ txn,
+ table="group_roles",
+ keyvalues={
+ "group_id": group_id,
+ "role_id": role_id,
+ },
+ retcol="group_id",
+ allow_none=True,
+ )
+ if not role_exists:
+ raise SynapseError(400, "Role doesn't exist")
+
+ # TODO: Check role is part of the summary already
+ role_exists = self._simple_select_one_onecol_txn(
+ txn,
+ table="group_summary_roles",
+ keyvalues={
+ "group_id": group_id,
+ "role_id": role_id,
+ },
+ retcol="group_id",
+ allow_none=True,
+ )
+ if not role_exists:
+ # If not, add it with an order larger than all others
+ txn.execute("""
+ INSERT INTO group_summary_roles
+ (group_id, role_id, role_order)
+ SELECT ?, ?, COALESCE(MAX(role_order), 0) + 1
+ FROM group_summary_roles
+ WHERE group_id = ? AND role_id = ?
+ """, (group_id, role_id, group_id, role_id))
+
+ existing = self._simple_select_one_txn(
+ txn,
+ table="group_summary_users",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ "role_id": role_id,
+ },
+ retcols=("user_order", "is_public",),
+ allow_none=True,
+ )
+
+ if order is not None:
+            # Shuffle other users' orders that come after the given order
+ sql = """
+ UPDATE group_summary_users SET user_order = user_order + 1
+ WHERE group_id = ? AND role_id = ? AND user_order >= ?
+ """
+ txn.execute(sql, (group_id, role_id, order,))
+ elif not existing:
+ sql = """
+ SELECT COALESCE(MAX(user_order), 0) + 1 FROM group_summary_users
+ WHERE group_id = ? AND role_id = ?
+ """
+ txn.execute(sql, (group_id, role_id,))
+ order, = txn.fetchone()
+
+ if existing:
+ to_update = {}
+ if order is not None:
+ to_update["user_order"] = order
+ if is_public is not None:
+ to_update["is_public"] = is_public
+ self._simple_update_txn(
+ txn,
+ table="group_summary_users",
+ keyvalues={
+ "group_id": group_id,
+ "role_id": role_id,
+ "user_id": user_id,
+ },
+ values=to_update,
+ )
+ else:
+ if is_public is None:
+ is_public = True
+
+ self._simple_insert_txn(
+ txn,
+ table="group_summary_users",
+ values={
+ "group_id": group_id,
+ "role_id": role_id,
+ "user_id": user_id,
+ "user_order": order,
+ "is_public": is_public,
+ },
+ )
+
+ def remove_user_from_summary(self, group_id, user_id, role_id):
+ if role_id is None:
+ role_id = _DEFAULT_ROLE_ID
+
+ return self._simple_delete(
+ table="group_summary_users",
+ keyvalues={
+ "group_id": group_id,
+ "role_id": role_id,
+ "user_id": user_id,
+ },
+ desc="remove_user_from_summary",
+ )
+
+ def get_users_for_summary_by_role(self, group_id, include_private=False):
+ """Get the users and roles that should be included in a summary request
+
+ Returns ([users], [roles])
+ """
+ def _get_users_for_summary_txn(txn):
+ keyvalues = {
+ "group_id": group_id,
+ }
+ if not include_private:
+ keyvalues["is_public"] = True
+
+ sql = """
+ SELECT user_id, is_public, role_id, user_order
+ FROM group_summary_users
+ WHERE group_id = ?
+ """
+
+ if not include_private:
+ sql += " AND is_public = ?"
+ txn.execute(sql, (group_id, True))
+ else:
+ txn.execute(sql, (group_id,))
+
+ users = [
+ {
+ "user_id": row[0],
+ "is_public": row[1],
+ "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None,
+ "order": row[3],
+ }
+ for row in txn
+ ]
+
+ sql = """
+ SELECT role_id, is_public, profile, role_order
+ FROM group_summary_roles
+ INNER JOIN group_roles USING (group_id, role_id)
+ WHERE group_id = ?
+ """
+
+ if not include_private:
+ sql += " AND is_public = ?"
+ txn.execute(sql, (group_id, True))
+ else:
+ txn.execute(sql, (group_id,))
+
+ roles = {
+ row[0]: {
+ "is_public": row[1],
+ "profile": json.loads(row[2]),
+ "order": row[3],
+ }
+ for row in txn
+ }
+
+ return users, roles
+ return self.runInteraction(
+ "get_users_for_summary_by_role", _get_users_for_summary_txn
+ )
+
+ def is_user_in_group(self, user_id, group_id):
+ return self._simple_select_one_onecol(
+ table="group_users",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ retcol="user_id",
+ allow_none=True,
+ desc="is_user_in_group",
+ ).addCallback(lambda r: bool(r))
+
+ def is_user_admin_in_group(self, group_id, user_id):
+ return self._simple_select_one_onecol(
+ table="group_users",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ retcol="is_admin",
+ allow_none=True,
+ desc="is_user_admin_in_group",
+ )
+
+ def add_group_invite(self, group_id, user_id):
+ """Record that the group server has invited a user
+ """
+ return self._simple_insert(
+ table="group_invites",
+ values={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ desc="add_group_invite",
+ )
+
+ def is_user_invited_to_local_group(self, group_id, user_id):
+ """Has the group server invited a user?
+ """
+ return self._simple_select_one_onecol(
+ table="group_invites",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ retcol="user_id",
+ desc="is_user_invited_to_local_group",
+ allow_none=True,
+ )
+
+ def get_users_membership_info_in_group(self, group_id, user_id):
+ """Get a dict describing the membership of a user in a group.
+
+ Example if joined:
+
+ {
+ "membership": "join",
+ "is_public": True,
+ "is_privileged": False,
+ }
+
+        Returns an empty dict if the user is not joined, invited, etc.
+ """
+ def _get_users_membership_in_group_txn(txn):
+ row = self._simple_select_one_txn(
+ txn,
+ table="group_users",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ retcols=("is_admin", "is_public"),
+ allow_none=True,
+ )
+
+ if row:
+ return {
+ "membership": "join",
+ "is_public": row["is_public"],
+ "is_privileged": row["is_admin"],
+ }
+
+ row = self._simple_select_one_onecol_txn(
+ txn,
+ table="group_invites",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ retcol="user_id",
+ allow_none=True,
+ )
+
+ if row:
+ return {
+ "membership": "invite",
+ }
+
+ return {}
+
+ return self.runInteraction(
+ "get_users_membership_info_in_group", _get_users_membership_in_group_txn,
+ )
+
+ def add_user_to_group(self, group_id, user_id, is_admin=False, is_public=True,
+ local_attestation=None, remote_attestation=None):
+ """Add a user to the group server.
+
+ Args:
+ group_id (str)
+ user_id (str)
+ is_admin (bool)
+ is_public (bool)
+ local_attestation (dict): The attestation the GS created to give
+ to the remote server. Optional if the user and group are on the
+ same server
+ remote_attestation (dict): The attestation given to GS by remote
+ server. Optional if the user and group are on the same server
+ """
+ def _add_user_to_group_txn(txn):
+ self._simple_insert_txn(
+ txn,
+ table="group_users",
+ values={
+ "group_id": group_id,
+ "user_id": user_id,
+ "is_admin": is_admin,
+ "is_public": is_public,
+ },
+ )
+
+ self._simple_delete_txn(
+ txn,
+ table="group_invites",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+
+ if local_attestation:
+ self._simple_insert_txn(
+ txn,
+ table="group_attestations_renewals",
+ values={
+ "group_id": group_id,
+ "user_id": user_id,
+ "valid_until_ms": local_attestation["valid_until_ms"],
+ },
+ )
+ if remote_attestation:
+ self._simple_insert_txn(
+ txn,
+ table="group_attestations_remote",
+ values={
+ "group_id": group_id,
+ "user_id": user_id,
+ "valid_until_ms": remote_attestation["valid_until_ms"],
+ "attestation_json": json.dumps(remote_attestation),
+ },
+ )
+
+ return self.runInteraction(
+ "add_user_to_group", _add_user_to_group_txn
+ )
+
+ def remove_user_from_group(self, group_id, user_id):
+ def _remove_user_from_group_txn(txn):
+ self._simple_delete_txn(
+ txn,
+ table="group_users",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+ self._simple_delete_txn(
+ txn,
+ table="group_invites",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+ self._simple_delete_txn(
+ txn,
+ table="group_attestations_renewals",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+ self._simple_delete_txn(
+ txn,
+ table="group_attestations_remote",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+ self._simple_delete_txn(
+ txn,
+ table="group_summary_users",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+ return self.runInteraction("remove_user_from_group", _remove_user_from_group_txn)
+
+ def add_room_to_group(self, group_id, room_id, is_public):
+ return self._simple_insert(
+ table="group_rooms",
+ values={
+ "group_id": group_id,
+ "room_id": room_id,
+ "is_public": is_public,
+ },
+ desc="add_room_to_group",
+ )
+
+ def update_room_in_group_visibility(self, group_id, room_id, is_public):
+ return self._simple_update(
+ table="group_rooms",
+ keyvalues={
+ "group_id": group_id,
+ "room_id": room_id,
+ },
+ updatevalues={
+ "is_public": is_public,
+ },
+ desc="update_room_in_group_visibility",
+ )
+
+ def remove_room_from_group(self, group_id, room_id):
+ def _remove_room_from_group_txn(txn):
+ self._simple_delete_txn(
+ txn,
+ table="group_rooms",
+ keyvalues={
+ "group_id": group_id,
+ "room_id": room_id,
+ },
+ )
+
+ self._simple_delete_txn(
+ txn,
+ table="group_summary_rooms",
+ keyvalues={
+ "group_id": group_id,
+ "room_id": room_id,
+ },
+ )
+ return self.runInteraction(
+ "remove_room_from_group", _remove_room_from_group_txn,
+ )
+
+ def get_publicised_groups_for_user(self, user_id):
+ """Get all groups a user is publicising
+ """
+ return self._simple_select_onecol(
+ table="local_group_membership",
+ keyvalues={
+ "user_id": user_id,
+ "membership": "join",
+ "is_publicised": True,
+ },
+ retcol="group_id",
+ desc="get_publicised_groups_for_user",
+ )
+
+ def update_group_publicity(self, group_id, user_id, publicise):
+ """Update whether the user is publicising their membership of the group
+ """
+ return self._simple_update_one(
+ table="local_group_membership",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ updatevalues={
+ "is_publicised": publicise,
+ },
+ desc="update_group_publicity"
+ )
+
+ @defer.inlineCallbacks
+ def register_user_group_membership(self, group_id, user_id, membership,
+ is_admin=False, content={},
+ local_attestation=None,
+ remote_attestation=None,
+ is_publicised=False,
+ ):
+ """Registers that a local user is a member of a (local or remote) group.
+
+ Args:
+ group_id (str)
+ user_id (str)
+ membership (str)
+ is_admin (bool)
+ content (dict): Content of the membership, e.g. includes the inviter
+ if the user has been invited.
+ local_attestation (dict): If remote group then store the fact that we
+ have given out an attestation, else None.
+ remote_attestation (dict): If remote group then store the remote
+ attestation from the group, else None.
+ """
+ def _register_user_group_membership_txn(txn, next_id):
+ # TODO: Upsert?
+ self._simple_delete_txn(
+ txn,
+ table="local_group_membership",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+ self._simple_insert_txn(
+ txn,
+ table="local_group_membership",
+ values={
+ "group_id": group_id,
+ "user_id": user_id,
+ "is_admin": is_admin,
+ "membership": membership,
+ "is_publicised": is_publicised,
+ "content": json.dumps(content),
+ },
+ )
+
+ self._simple_insert_txn(
+ txn,
+ table="local_group_updates",
+ values={
+ "stream_id": next_id,
+ "group_id": group_id,
+ "user_id": user_id,
+ "type": "membership",
+ "content": json.dumps({"membership": membership, "content": content}),
+ }
+ )
+ self._group_updates_stream_cache.entity_has_changed(user_id, next_id)
+
+            # TODO: Insert profile to ensure it comes down stream if it's a join.
+
+ if membership == "join":
+ if local_attestation:
+ self._simple_insert_txn(
+ txn,
+ table="group_attestations_renewals",
+ values={
+ "group_id": group_id,
+ "user_id": user_id,
+ "valid_until_ms": local_attestation["valid_until_ms"],
+ }
+ )
+ if remote_attestation:
+ self._simple_insert_txn(
+ txn,
+ table="group_attestations_remote",
+ values={
+ "group_id": group_id,
+ "user_id": user_id,
+ "valid_until_ms": remote_attestation["valid_until_ms"],
+ "attestation_json": json.dumps(remote_attestation),
+ }
+ )
+ else:
+ self._simple_delete_txn(
+ txn,
+ table="group_attestations_renewals",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+ self._simple_delete_txn(
+ txn,
+ table="group_attestations_remote",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+
+ return next_id
+
+ with self._group_updates_id_gen.get_next() as next_id:
+ res = yield self.runInteraction(
+ "register_user_group_membership",
+ _register_user_group_membership_txn, next_id,
+ )
+ defer.returnValue(res)
+
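The shape of `register_user_group_membership` is the standard Synapse streaming pattern: allocate a monotonic stream id, write the update row in the same transaction, and mark the entity changed so pollers know to re-query. A condensed sketch of just that skeleton (a hypothetical `_record_update`, reusing the attribute names above):

```python
from twisted.internet import defer

@defer.inlineCallbacks
def _record_update(self, group_id, user_id, row):
    def _txn(txn, next_id):
        self._simple_insert_txn(
            txn, table="local_group_updates",
            values=dict(row, group_id=group_id, user_id=user_id,
                        stream_id=next_id),
        )
        # Record that this user has an update at this stream position.
        self._group_updates_stream_cache.entity_has_changed(user_id, next_id)
        return next_id

    with self._group_updates_id_gen.get_next() as next_id:
        res = yield self.runInteraction("record_update", _txn, next_id)
    defer.returnValue(res)
```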
+ @defer.inlineCallbacks
+ def create_group(self, group_id, user_id, name, avatar_url, short_description,
+ long_description,):
+ yield self._simple_insert(
+ table="groups",
+ values={
+ "group_id": group_id,
+ "name": name,
+ "avatar_url": avatar_url,
+ "short_description": short_description,
+ "long_description": long_description,
+ "is_public": True,
+ },
+ desc="create_group",
+ )
+
+ @defer.inlineCallbacks
+ def update_group_profile(self, group_id, profile,):
+ yield self._simple_update_one(
+ table="groups",
+ keyvalues={
+ "group_id": group_id,
+ },
+ updatevalues=profile,
+ desc="update_group_profile",
+ )
+
+ def get_attestations_need_renewals(self, valid_until_ms):
+        """Get all attestations that need to be renewed before the given time
+ """
+ def _get_attestations_need_renewals_txn(txn):
+ sql = """
+ SELECT group_id, user_id FROM group_attestations_renewals
+ WHERE valid_until_ms <= ?
+ """
+ txn.execute(sql, (valid_until_ms,))
+ return self.cursor_to_dict(txn)
+ return self.runInteraction(
+ "get_attestations_need_renewals", _get_attestations_need_renewals_txn
+ )
+
+ def update_attestation_renewal(self, group_id, user_id, attestation):
+ """Update an attestation that we have renewed
+ """
+ return self._simple_update_one(
+ table="group_attestations_renewals",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ updatevalues={
+ "valid_until_ms": attestation["valid_until_ms"],
+ },
+ desc="update_attestation_renewal",
+ )
+
+ def update_remote_attestion(self, group_id, user_id, attestation):
+ """Update an attestation that a remote has renewed
+ """
+ return self._simple_update_one(
+ table="group_attestations_remote",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ updatevalues={
+ "valid_until_ms": attestation["valid_until_ms"],
+ "attestation_json": json.dumps(attestation)
+ },
+ desc="update_remote_attestion",
+ )
+
+ def remove_attestation_renewal(self, group_id, user_id):
+ """Remove an attestation that we thought we should renew, but actually
+ shouldn't. Ideally this would never get called as we would never
+ incorrectly try and do attestations for local users on local groups.
+
+ Args:
+ group_id (str)
+ user_id (str)
+ """
+ return self._simple_delete(
+ table="group_attestations_renewals",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ desc="remove_attestation_renewal",
+ )
+
+ @defer.inlineCallbacks
+ def get_remote_attestation(self, group_id, user_id):
+ """Get the attestation that proves the remote agrees that the user is
+ in the group.
+ """
+ row = yield self._simple_select_one(
+ table="group_attestations_remote",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ retcols=("valid_until_ms", "attestation_json"),
+ desc="get_remote_attestation",
+ allow_none=True,
+ )
+
+ now = int(self._clock.time_msec())
+ if row and now < row["valid_until_ms"]:
+ defer.returnValue(json.loads(row["attestation_json"]))
+
+ defer.returnValue(None)
+
+ def get_joined_groups(self, user_id):
+ return self._simple_select_onecol(
+ table="local_group_membership",
+ keyvalues={
+ "user_id": user_id,
+ "membership": "join",
+ },
+ retcol="group_id",
+ desc="get_joined_groups",
+ )
+
+ def get_all_groups_for_user(self, user_id, now_token):
+ def _get_all_groups_for_user_txn(txn):
+ sql = """
+ SELECT group_id, type, membership, u.content
+ FROM local_group_updates AS u
+ INNER JOIN local_group_membership USING (group_id, user_id)
+ WHERE user_id = ? AND membership != 'leave'
+ AND stream_id <= ?
+ """
+ txn.execute(sql, (user_id, now_token,))
+ return [
+ {
+ "group_id": row[0],
+ "type": row[1],
+ "membership": row[2],
+ "content": json.loads(row[3]),
+ }
+ for row in txn
+ ]
+ return self.runInteraction(
+ "get_all_groups_for_user", _get_all_groups_for_user_txn,
+ )
+
+ def get_groups_changes_for_user(self, user_id, from_token, to_token):
+ from_token = int(from_token)
+ has_changed = self._group_updates_stream_cache.has_entity_changed(
+ user_id, from_token,
+ )
+ if not has_changed:
+ return []
+
+ def _get_groups_changes_for_user_txn(txn):
+ sql = """
+ SELECT group_id, membership, type, u.content
+ FROM local_group_updates AS u
+ INNER JOIN local_group_membership USING (group_id, user_id)
+ WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
+ """
+ txn.execute(sql, (user_id, from_token, to_token,))
+ return [{
+ "group_id": group_id,
+ "membership": membership,
+ "type": gtype,
+ "content": json.loads(content_json),
+ } for group_id, membership, gtype, content_json in txn]
+ return self.runInteraction(
+ "get_groups_changes_for_user", _get_groups_changes_for_user_txn,
+ )
+
+ def get_all_groups_changes(self, from_token, to_token, limit):
+ from_token = int(from_token)
+ has_changed = self._group_updates_stream_cache.has_any_entity_changed(
+ from_token,
+ )
+ if not has_changed:
+ return []
+
+ def _get_all_groups_changes_txn(txn):
+ sql = """
+ SELECT stream_id, group_id, user_id, type, content
+ FROM local_group_updates
+ WHERE ? < stream_id AND stream_id <= ?
+ LIMIT ?
+ """
+ txn.execute(sql, (from_token, to_token, limit,))
+ return [(
+ stream_id,
+ group_id,
+ user_id,
+ gtype,
+ json.loads(content_json),
+ ) for stream_id, group_id, user_id, gtype, content_json in txn]
+ return self.runInteraction(
+ "get_all_groups_changes", _get_all_groups_changes_txn,
+ )
+
+ def get_group_stream_token(self):
+ return self._group_updates_id_gen.get_current_token()
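
The attestation helpers above are designed to be driven by a periodic renewal job: fetch the rows whose valid_until_ms falls within the renewal horizon, re-attest, and record the new expiry via update_attestation_renewal. A minimal sketch of such a loop, assuming a store exposing the methods above, a clock with time_msec(), and a hypothetical renew_attestation callback (this is not synapse's actual scheduler wiring):

    from twisted.internet import defer

    RENEW_AHEAD_MS = 30 * 60 * 1000  # renew anything expiring within 30 minutes

    @defer.inlineCallbacks
    def renew_due_attestations(clock, store, renew_attestation):
        now = clock.time_msec()
        # get_attestations_need_renewals returns dicts with group_id / user_id
        rows = yield store.get_attestations_need_renewals(now + RENEW_AHEAD_MS)
        for row in rows:
            yield renew_attestation(row["group_id"], row["user_id"])
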
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 3b5e0a4fb9..f547977600 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -13,19 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+import hashlib
+import logging
-from twisted.internet import defer
+import six
-import OpenSSL
from signedjson.key import decode_verify_key_bytes
-import hashlib
-import logging
+import OpenSSL
+from twisted.internet import defer
+
+from synapse.util.caches.descriptors import cachedInlineCallbacks
+
+from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
+# py2 sqlite has buffer hardcoded as the only binary type, so we must use it,
+# despite it being deprecated and removed in py3 in favor of memoryview
+if six.PY2:
+ db_binary_type = buffer
+else:
+ db_binary_type = memoryview
+
class KeyStore(SQLBaseStore):
"""Persistence for signature verification keys and tls X.509 certificates
@@ -72,7 +82,7 @@ class KeyStore(SQLBaseStore):
values={
"from_server": from_server,
"ts_added_ms": time_now_ms,
- "tls_certificate": buffer(tls_certificate_bytes),
+ "tls_certificate": db_binary_type(tls_certificate_bytes),
},
desc="store_server_certificate",
)
@@ -92,7 +102,7 @@ class KeyStore(SQLBaseStore):
if verify_key_bytes:
defer.returnValue(decode_verify_key_bytes(
- key_id, str(verify_key_bytes)
+ key_id, bytes(verify_key_bytes)
))
@defer.inlineCallbacks
@@ -113,30 +123,37 @@ class KeyStore(SQLBaseStore):
keys[key_id] = key
defer.returnValue(keys)
- @defer.inlineCallbacks
def store_server_verify_key(self, server_name, from_server, time_now_ms,
verify_key):
"""Stores a NACL verification key for the given server.
Args:
server_name (str): The name of the server.
- key_id (str): The version of the key for the server.
from_server (str): Where the verification key was looked up
- ts_now_ms (int): The time now in milliseconds
- verification_key (VerifyKey): The NACL verify key.
+ time_now_ms (int): The time now in milliseconds
+ verify_key (nacl.signing.VerifyKey): The NACL verify key.
"""
- yield self._simple_upsert(
- table="server_signature_keys",
- keyvalues={
- "server_name": server_name,
- "key_id": "%s:%s" % (verify_key.alg, verify_key.version),
- },
- values={
- "from_server": from_server,
- "ts_added_ms": time_now_ms,
- "verify_key": buffer(verify_key.encode()),
- },
- desc="store_server_verify_key",
- )
+ key_id = "%s:%s" % (verify_key.alg, verify_key.version)
+
+ def _txn(txn):
+ self._simple_upsert_txn(
+ txn,
+ table="server_signature_keys",
+ keyvalues={
+ "server_name": server_name,
+ "key_id": key_id,
+ },
+ values={
+ "from_server": from_server,
+ "ts_added_ms": time_now_ms,
+ "verify_key": db_binary_type(verify_key.encode()),
+ },
+ )
+ txn.call_after(
+ self._get_server_verify_key.invalidate,
+ (server_name, key_id)
+ )
+
+ return self.runInteraction("store_server_verify_key", _txn)
def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes):
@@ -165,7 +182,7 @@ class KeyStore(SQLBaseStore):
"from_server": from_server,
"ts_added_ms": ts_now_ms,
"ts_valid_until_ms": ts_expires_ms,
- "key_json": buffer(key_json_bytes),
+ "key_json": db_binary_type(key_json_bytes),
},
desc="store_server_keys_json",
)
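
The db_binary_type shim above exists because py2's sqlite driver only binds buffer objects to BLOB columns, while py3 uses memoryview. A self-contained illustration of the same pattern against plain sqlite3 (an assumption for the demo; synapse itself goes through its engine layer):

    import sqlite3

    import six

    if six.PY2:
        db_binary_type = buffer  # noqa: F821 (py2 builtin)
    else:
        db_binary_type = memoryview

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE server_keys (key_id TEXT, verify_key BLOB)")
    conn.execute(
        "INSERT INTO server_keys VALUES (?, ?)",
        ("ed25519:auto", db_binary_type(b"\x01\x02\x03")),
    )
    (stored,) = conn.execute("SELECT verify_key FROM server_keys").fetchone()
    # bytes() round-trips the stored value on both py2 and py3, which is why
    # the key lookup above now uses bytes(verify_key_bytes) instead of str()
    assert bytes(stored) == b"\x01\x02\x03"
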
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index 82bb61b811..e6cdbb0545 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -12,15 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.storage.background_updates import BackgroundUpdateStore
-from ._base import SQLBaseStore
-
-class MediaRepositoryStore(SQLBaseStore):
+class MediaRepositoryStore(BackgroundUpdateStore):
"""Persistence for attachments and avatars"""
- def get_default_thumbnails(self, top_level_type, sub_type):
- return []
+ def __init__(self, db_conn, hs):
+ super(MediaRepositoryStore, self).__init__(db_conn, hs)
+
+ self.register_background_index_update(
+ update_name='local_media_repository_url_idx',
+ index_name='local_media_repository_url_idx',
+ table='local_media_repository',
+ columns=['created_ts'],
+ where_clause='url_cache IS NOT NULL',
+ )
def get_local_media(self, media_id):
"""Get the metadata for a local piece of media
@@ -62,7 +69,7 @@ class MediaRepositoryStore(SQLBaseStore):
def get_url_cache_txn(txn):
# get the most recently cached result (relative to the given ts)
sql = (
- "SELECT response_code, etag, expires, og, media_id, download_ts"
+ "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
" FROM local_media_repository_url_cache"
" WHERE url = ? AND download_ts <= ?"
" ORDER BY download_ts DESC LIMIT 1"
@@ -74,7 +81,7 @@ class MediaRepositoryStore(SQLBaseStore):
# ...or if we've requested a timestamp older than the oldest
# copy in the cache, return the oldest copy (if any)
sql = (
- "SELECT response_code, etag, expires, og, media_id, download_ts"
+ "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
" FROM local_media_repository_url_cache"
" WHERE url = ? AND download_ts > ?"
" ORDER BY download_ts ASC LIMIT 1"
@@ -86,14 +93,14 @@ class MediaRepositoryStore(SQLBaseStore):
return None
return dict(zip((
- 'response_code', 'etag', 'expires', 'og', 'media_id', 'download_ts'
+ 'response_code', 'etag', 'expires_ts', 'og', 'media_id', 'download_ts'
), row))
return self.runInteraction(
"get_url_cache", get_url_cache_txn
)
- def store_url_cache(self, url, response_code, etag, expires, og, media_id,
+ def store_url_cache(self, url, response_code, etag, expires_ts, og, media_id,
download_ts):
return self._simple_insert(
"local_media_repository_url_cache",
@@ -101,7 +108,7 @@ class MediaRepositoryStore(SQLBaseStore):
"url": url,
"response_code": response_code,
"etag": etag,
- "expires": expires,
+ "expires_ts": expires_ts,
"og": og,
"media_id": media_id,
"download_ts": download_ts,
@@ -166,7 +173,14 @@ class MediaRepositoryStore(SQLBaseStore):
desc="store_cached_remote_media",
)
- def update_cached_last_access_time(self, origin_id_tuples, time_ts):
+ def update_cached_last_access_time(self, local_media, remote_media, time_ms):
+ """Updates the last access time of the given media
+
+ Args:
+ local_media (iterable[str]): Set of media_ids
+ remote_media (iterable[(str, str)]): Set of (server_name, media_id)
+ time_ms (int): Current time in milliseconds
+ """
def update_cache_txn(txn):
sql = (
"UPDATE remote_media_cache SET last_access_ts = ?"
@@ -174,8 +188,18 @@ class MediaRepositoryStore(SQLBaseStore):
)
txn.executemany(sql, (
- (time_ts, media_origin, media_id)
- for media_origin, media_id in origin_id_tuples
+ (time_ms, media_origin, media_id)
+ for media_origin, media_id in remote_media
+ ))
+
+ sql = (
+ "UPDATE local_media_repository SET last_access_ts = ?"
+ " WHERE media_id = ?"
+ )
+
+ txn.executemany(sql, (
+ (time_ms, media_id)
+ for media_id in local_media
))
return self.runInteraction("update_cached_last_access_time", update_cache_txn)
@@ -238,3 +262,70 @@ class MediaRepositoryStore(SQLBaseStore):
},
)
return self.runInteraction("delete_remote_media", delete_remote_media_txn)
+
+ def get_expired_url_cache(self, now_ts):
+ sql = (
+ "SELECT media_id FROM local_media_repository_url_cache"
+ " WHERE expires_ts < ?"
+ " ORDER BY expires_ts ASC"
+ " LIMIT 500"
+ )
+
+ def _get_expired_url_cache_txn(txn):
+ txn.execute(sql, (now_ts,))
+ return [row[0] for row in txn]
+
+ return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
+
+ def delete_url_cache(self, media_ids):
+ if len(media_ids) == 0:
+ return
+
+ sql = (
+ "DELETE FROM local_media_repository_url_cache"
+ " WHERE media_id = ?"
+ )
+
+ def _delete_url_cache_txn(txn):
+ txn.executemany(sql, [(media_id,) for media_id in media_ids])
+
+ return self.runInteraction("delete_url_cache", _delete_url_cache_txn)
+
+ def get_url_cache_media_before(self, before_ts):
+ sql = (
+ "SELECT media_id FROM local_media_repository"
+ " WHERE created_ts < ? AND url_cache IS NOT NULL"
+ " ORDER BY created_ts ASC"
+ " LIMIT 500"
+ )
+
+ def _get_url_cache_media_before_txn(txn):
+ txn.execute(sql, (before_ts,))
+ return [row[0] for row in txn]
+
+ return self.runInteraction(
+ "get_url_cache_media_before", _get_url_cache_media_before_txn,
+ )
+
+ def delete_url_cache_media(self, media_ids):
+ if len(media_ids) == 0:
+ return
+
+ def _delete_url_cache_media_txn(txn):
+ sql = (
+ "DELETE FROM local_media_repository"
+ " WHERE media_id = ?"
+ )
+
+ txn.executemany(sql, [(media_id,) for media_id in media_ids])
+
+ sql = (
+ "DELETE FROM local_media_repository_thumbnails"
+ " WHERE media_id = ?"
+ )
+
+ txn.executemany(sql, [(media_id,) for media_id in media_ids])
+
+ return self.runInteraction(
+ "delete_url_cache_media", _delete_url_cache_media_txn,
+ )
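
get_expired_url_cache and get_url_cache_media_before both cap their results at 500 rows, so callers are expected to loop until a page comes back empty. A sketch of such a sweep, with remove_files_for standing in for the on-disk cleanup (hypothetical; the real repository also removes thumbnails):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def expire_url_cache(clock, store, remove_files_for):
        while True:
            media_ids = yield store.get_expired_url_cache(clock.time_msec())
            if not media_ids:
                break  # all 500-row pages drained
            remove_files_for(media_ids)  # assumed filesystem hook
            yield store.delete_url_cache(media_ids)
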
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 72b670b83b..b290f834b3 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,13 +20,12 @@ import logging
import os
import re
-
logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 43
+SCHEMA_VERSION = 50
dir_path = os.path.abspath(os.path.dirname(__file__))
@@ -44,6 +44,13 @@ def prepare_database(db_conn, database_engine, config):
If `config` is None then prepare_database will assert that no upgrade is
necessary, *or* will create a fresh database if the database is empty.
+
+ Args:
+ db_conn:
+ database_engine:
+ config (synapse.config.homeserver.HomeServerConfig|None):
+ application config, or None if we are connecting to an existing
+ database which we expect to be configured already
"""
try:
cur = db_conn.cursor()
@@ -64,9 +71,13 @@ def prepare_database(db_conn, database_engine, config):
else:
_setup_new_database(cur, database_engine)
+ # check if any of our configured dynamic modules want a database
+ if config is not None:
+ _apply_module_schemas(cur, database_engine, config)
+
cur.close()
db_conn.commit()
- except:
+ except Exception:
db_conn.rollback()
raise
@@ -283,6 +294,65 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
)
+def _apply_module_schemas(txn, database_engine, config):
+ """Apply the module schemas for the dynamic modules, if any
+
+ Args:
+ txn: database cursor
+ database_engine: synapse database engine class
+ config (synapse.config.homeserver.HomeServerConfig):
+ application config
+ """
+ for (mod, _config) in config.password_providers:
+ if not hasattr(mod, 'get_db_schema_files'):
+ continue
+ modname = ".".join((mod.__module__, mod.__name__))
+ _apply_module_schema_files(
+ txn, database_engine, modname, mod.get_db_schema_files(),
+ )
+
+
+def _apply_module_schema_files(cur, database_engine, modname, names_and_streams):
+ """Apply the module schemas for a single module
+
+ Args:
+ cur: database cursor
+ database_engine: synapse database engine class
+ modname (str): fully qualified name of the module
+ names_and_streams (Iterable[(str, file)]): the names and streams of
+ schemas to be applied
+ """
+ cur.execute(
+ database_engine.convert_param_style(
+ "SELECT file FROM applied_module_schemas WHERE module_name = ?"
+ ),
+ (modname,)
+ )
+ applied_deltas = set(d for d, in cur)
+ for (name, stream) in names_and_streams:
+ if name in applied_deltas:
+ continue
+
+ root_name, ext = os.path.splitext(name)
+ if ext != '.sql':
+ raise PrepareDatabaseException(
+ "only .sql files are currently supported for module schemas",
+ )
+
+ logger.info("applying schema %s for %s", name, modname)
+ for statement in get_statements(stream):
+ cur.execute(statement)
+
+ # Mark as done.
+ cur.execute(
+ database_engine.convert_param_style(
+ "INSERT INTO applied_module_schemas (module_name, file)"
+ " VALUES (?,?)",
+ ),
+ (modname, name)
+ )
+
+
def get_statements(f):
statement_buffer = ""
in_comment = False # If we're in a /* ... */ style comment
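
_apply_module_schemas calls get_db_schema_files() on each configured password provider and expects (name, stream) pairs of .sql files, recording what it has already run in applied_module_schemas. A sketch of the provider-side half of that contract (the module layout here is an assumption):

    import os

    class ExamplePasswordProvider(object):
        """Hypothetical provider that ships its own schema files."""

        @classmethod
        def get_db_schema_files(cls):
            schema_dir = os.path.join(os.path.dirname(__file__), "schemas")
            for name in sorted(os.listdir(schema_dir)):
                if name.endswith(".sql"):
                    # yields (filename, readable stream), as consumed above
                    yield name, open(os.path.join(schema_dir, name))
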
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 9e9d3c2591..a0c7a0dc87 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -13,12 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
+from collections import namedtuple
+
+from twisted.internet import defer
+
from synapse.api.constants import PresenceState
+from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
-from collections import namedtuple
-from twisted.internet import defer
+from ._base import SQLBaseStore
class UserPresenceState(namedtuple("UserPresenceState",
@@ -115,11 +118,7 @@ class PresenceStore(SQLBaseStore):
" AND user_id IN (%s)"
)
- batches = (
- presence_states[i:i + 50]
- for i in xrange(0, len(presence_states), 50)
- )
- for states in batches:
+ for states in batch_iter(presence_states, 50):
args = [stream_id]
args.extend(s.user_id for s in states)
txn.execute(
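
The xrange-based slicing is replaced by synapse.util.batch_iter, which also works for arbitrary iterables rather than only sliceable sequences. A sketch of an equivalent helper (the real implementation may differ in detail):

    from itertools import islice

    def batch_iter(iterable, size):
        """Yield tuples of up to `size` items from any iterable."""
        sourceiter = iter(iterable)
        # iter(callable, sentinel) stops once islice yields an empty tuple
        return iter(lambda: tuple(islice(sourceiter, size)), ())

    # e.g. list(batch_iter(range(5), 2)) == [(0, 1), (2, 3), (4,)]
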
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
index 26a40905ae..60295da254 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/profile.py
@@ -13,15 +13,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
+
+from synapse.api.errors import StoreError
+from synapse.storage.roommember import ProfileInfo
+
from ._base import SQLBaseStore
-class ProfileStore(SQLBaseStore):
- def create_profile(self, user_localpart):
- return self._simple_insert(
- table="profiles",
- values={"user_id": user_localpart},
- desc="create_profile",
+class ProfileWorkerStore(SQLBaseStore):
+ @defer.inlineCallbacks
+ def get_profileinfo(self, user_localpart):
+ try:
+ profile = yield self._simple_select_one(
+ table="profiles",
+ keyvalues={"user_id": user_localpart},
+ retcols=("displayname", "avatar_url"),
+ desc="get_profileinfo",
+ )
+ except StoreError as e:
+ if e.code == 404:
+ # no match
+ defer.returnValue(ProfileInfo(None, None))
+ return
+ else:
+ raise
+
+ defer.returnValue(
+ ProfileInfo(
+ avatar_url=profile['avatar_url'],
+ display_name=profile['displayname'],
+ )
)
def get_profile_displayname(self, user_localpart):
@@ -32,14 +54,6 @@ class ProfileStore(SQLBaseStore):
desc="get_profile_displayname",
)
- def set_profile_displayname(self, user_localpart, new_displayname):
- return self._simple_update_one(
- table="profiles",
- keyvalues={"user_id": user_localpart},
- updatevalues={"displayname": new_displayname},
- desc="set_profile_displayname",
- )
-
def get_profile_avatar_url(self, user_localpart):
return self._simple_select_one_onecol(
table="profiles",
@@ -48,6 +62,32 @@ class ProfileStore(SQLBaseStore):
desc="get_profile_avatar_url",
)
+ def get_from_remote_profile_cache(self, user_id):
+ return self._simple_select_one(
+ table="remote_profile_cache",
+ keyvalues={"user_id": user_id},
+ retcols=("displayname", "avatar_url",),
+ allow_none=True,
+ desc="get_from_remote_profile_cache",
+ )
+
+
+class ProfileStore(ProfileWorkerStore):
+ def create_profile(self, user_localpart):
+ return self._simple_insert(
+ table="profiles",
+ values={"user_id": user_localpart},
+ desc="create_profile",
+ )
+
+ def set_profile_displayname(self, user_localpart, new_displayname):
+ return self._simple_update_one(
+ table="profiles",
+ keyvalues={"user_id": user_localpart},
+ updatevalues={"displayname": new_displayname},
+ desc="set_profile_displayname",
+ )
+
def set_profile_avatar_url(self, user_localpart, new_avatar_url):
return self._simple_update_one(
table="profiles",
@@ -55,3 +95,90 @@ class ProfileStore(SQLBaseStore):
updatevalues={"avatar_url": new_avatar_url},
desc="set_profile_avatar_url",
)
+
+ def add_remote_profile_cache(self, user_id, displayname, avatar_url):
+ """Ensure we are caching the remote user's profiles.
+
+ This should only be called when `is_subscribed_remote_profile_for_user`
+ would return true for the user.
+ """
+ return self._simple_upsert(
+ table="remote_profile_cache",
+ keyvalues={"user_id": user_id},
+ values={
+ "displayname": displayname,
+ "avatar_url": avatar_url,
+ "last_check": self._clock.time_msec(),
+ },
+ desc="add_remote_profile_cache",
+ )
+
+ def update_remote_profile_cache(self, user_id, displayname, avatar_url):
+ return self._simple_update(
+ table="remote_profile_cache",
+ keyvalues={"user_id": user_id},
+ values={
+ "displayname": displayname,
+ "avatar_url": avatar_url,
+ "last_check": self._clock.time_msec(),
+ },
+ desc="update_remote_profile_cache",
+ )
+
+ @defer.inlineCallbacks
+ def maybe_delete_remote_profile_cache(self, user_id):
+ """Check if we still care about the remote user's profile, and if we
+ don't then remove their profile from the cache
+ """
+ subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
+ if not subscribed:
+ yield self._simple_delete(
+ table="remote_profile_cache",
+ keyvalues={"user_id": user_id},
+ desc="delete_remote_profile_cache",
+ )
+
+ def get_remote_profile_cache_entries_that_expire(self, last_checked):
+ """Get all users who haven't been checked since `last_checked`
+ """
+ def _get_remote_profile_cache_entries_that_expire_txn(txn):
+ sql = """
+ SELECT user_id, displayname, avatar_url
+ FROM remote_profile_cache
+ WHERE last_check < ?
+ """
+
+ txn.execute(sql, (last_checked,))
+
+ return self.cursor_to_dict(txn)
+
+ return self.runInteraction(
+ "get_remote_profile_cache_entries_that_expire",
+ _get_remote_profile_cache_entries_that_expire_txn,
+ )
+
+ @defer.inlineCallbacks
+ def is_subscribed_remote_profile_for_user(self, user_id):
+ """Check whether we are interested in a remote user's profile.
+ """
+ res = yield self._simple_select_one_onecol(
+ table="group_users",
+ keyvalues={"user_id": user_id},
+ retcol="user_id",
+ allow_none=True,
+ desc="should_update_remote_profile_cache_for_user",
+ )
+
+ if res:
+ defer.returnValue(True)
+
+ res = yield self._simple_select_one_onecol(
+ table="group_invites",
+ keyvalues={"user_id": user_id},
+ retcol="user_id",
+ allow_none=True,
+ desc="should_update_remote_profile_cache_for_user",
+ )
+
+ if res:
+ defer.returnValue(True)
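
The remote profile cache above implies a periodic refresh job: pull entries whose last_check is stale, drop the ones no group still references, and re-fetch the rest. A hedged sketch of that job, where fetch_remote_profile stands in for the federation query:

    from twisted.internet import defer

    MAX_AGE_MS = 24 * 60 * 60 * 1000  # re-check entries older than a day

    @defer.inlineCallbacks
    def refresh_remote_profile_cache(clock, store, fetch_remote_profile):
        last_checked = clock.time_msec() - MAX_AGE_MS
        entries = yield store.get_remote_profile_cache_entries_that_expire(
            last_checked,
        )
        for entry in entries:
            user_id = entry["user_id"]
            if not (yield store.is_subscribed_remote_profile_for_user(user_id)):
                yield store.maybe_delete_remote_profile_cache(user_id)
                continue
            profile = yield fetch_remote_profile(user_id)  # assumed helper
            yield store.update_remote_profile_cache(
                user_id, profile.get("displayname"), profile.get("avatar_url"),
            )
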
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 8758b1c0c7..6a5028961d 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,14 +14,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
-from synapse.push.baserules import list_with_base_rules
-from synapse.api.constants import EventTypes
+import abc
+import logging
+
+from canonicaljson import json
+
from twisted.internet import defer
-import logging
-import simplejson as json
+from synapse.push.baserules import list_with_base_rules
+from synapse.storage.appservice import ApplicationServiceWorkerStore
+from synapse.storage.pusher import PusherWorkerStore
+from synapse.storage.receipts import ReceiptsWorkerStore
+from synapse.storage.roommember import RoomMemberWorkerStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
@@ -48,7 +57,43 @@ def _load_rules(rawrules, enabled_map):
return rules
-class PushRuleStore(SQLBaseStore):
+class PushRulesWorkerStore(ApplicationServiceWorkerStore,
+ ReceiptsWorkerStore,
+ PusherWorkerStore,
+ RoomMemberWorkerStore,
+ SQLBaseStore):
+ """This is an abstract base class where subclasses must implement
+ `get_max_push_rules_stream_id`, which is called from this class's initializer.
+ """
+
+ # This ABCMeta metaclass ensures that we cannot be instantiated without
+ # the abstract methods being implemented.
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, db_conn, hs):
+ super(PushRulesWorkerStore, self).__init__(db_conn, hs)
+
+ push_rules_prefill, push_rules_id = self._get_cache_dict(
+ db_conn, "push_rules_stream",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=self.get_max_push_rules_stream_id(),
+ )
+
+ self.push_rules_stream_cache = StreamChangeCache(
+ "PushRulesStreamChangeCache", push_rules_id,
+ prefilled_cache=push_rules_prefill,
+ )
+
+ @abc.abstractmethod
+ def get_max_push_rules_stream_id(self):
+ """Get the position of the push rules stream.
+
+ Returns:
+ int
+ """
+ raise NotImplementedError()
+
@cachedInlineCallbacks(max_entries=5000)
def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list(
@@ -89,6 +134,22 @@ class PushRuleStore(SQLBaseStore):
r['rule_id']: False if r['enabled'] == 0 else True for r in results
})
+ def have_push_rules_changed_for_user(self, user_id, last_id):
+ if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
+ return defer.succeed(False)
+ else:
+ def have_push_rules_changed_txn(txn):
+ sql = (
+ "SELECT COUNT(stream_id) FROM push_rules_stream"
+ " WHERE user_id = ? AND ? < stream_id"
+ )
+ txn.execute(sql, (user_id, last_id))
+ count, = txn.fetchone()
+ return bool(count)
+ return self.runInteraction(
+ "have_push_rules_changed", have_push_rules_changed_txn
+ )
+
@cachedList(cached_method_name="get_push_rules_for_user",
list_name="user_ids", num_args=1, inlineCallbacks=True)
def bulk_get_push_rules(self, user_ids):
@@ -124,6 +185,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(results)
+ @defer.inlineCallbacks
def bulk_get_push_rules_for_room(self, event, context):
state_group = context.state_group
if not state_group:
@@ -133,9 +195,11 @@ class PushRuleStore(SQLBaseStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
- return self._bulk_get_push_rules_for_room(
- event.room_id, state_group, context.current_state_ids, event=event
+ current_state_ids = yield context.get_current_state_ids(self)
+ result = yield self._bulk_get_push_rules_for_room(
+ event.room_id, state_group, current_state_ids, event=event
)
+ defer.returnValue(result)
@cachedInlineCallbacks(num_args=2, cache_context=True)
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,
@@ -185,18 +249,6 @@ class PushRuleStore(SQLBaseStore):
if uid in local_users_in_room:
user_ids.add(uid)
- forgotten = yield self.who_forgot_in_room(
- event.room_id, on_invalidate=cache_context.invalidate,
- )
-
- for row in forgotten:
- user_id = row["user_id"]
- event_id = row["event_id"]
-
- mem_id = current_state_ids.get((EventTypes.Member, user_id), None)
- if event_id == mem_id:
- user_ids.discard(user_id)
-
rules_by_user = yield self.bulk_get_push_rules(
user_ids, on_invalidate=cache_context.invalidate,
)
@@ -228,6 +280,8 @@ class PushRuleStore(SQLBaseStore):
results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
defer.returnValue(results)
+
+class PushRuleStore(PushRulesWorkerStore):
@defer.inlineCallbacks
def add_push_rule(
self, user_id, rule_id, priority_class, conditions, actions,
@@ -526,21 +580,8 @@ class PushRuleStore(SQLBaseStore):
room stream ordering it corresponds to."""
return self._push_rules_stream_id_gen.get_current_token()
- def have_push_rules_changed_for_user(self, user_id, last_id):
- if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
- return defer.succeed(False)
- else:
- def have_push_rules_changed_txn(txn):
- sql = (
- "SELECT COUNT(stream_id) FROM push_rules_stream"
- " WHERE user_id = ? AND ? < stream_id"
- )
- txn.execute(sql, (user_id, last_id))
- count, = txn.fetchone()
- return bool(count)
- return self.runInteraction(
- "have_push_rules_changed", have_push_rules_changed_txn
- )
+ def get_max_push_rules_stream_id(self):
+ return self.get_push_rules_stream_token()[0]
class RuleNotFoundException(Exception):
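
Both new worker stores use the same trick: __metaclass__ = abc.ABCMeta plus an abstractmethod lets the shared __init__ call get_max_push_rules_stream_id() even though only a concrete subclass provides it. A reduced illustration of the pattern:

    import abc

    class WorkerStore(object):
        # py2-style ABC declaration, mirroring the stores above
        __metaclass__ = abc.ABCMeta

        def __init__(self):
            # Safe to call here: any instantiable subclass must implement it.
            self._stream_position = self.get_max_stream_id()

        @abc.abstractmethod
        def get_max_stream_id(self):
            raise NotImplementedError()

    class MasterStore(WorkerStore):
        def get_max_stream_id(self):
            return 42  # the real store reads its id generator

    store = MasterStore()
    # On py2, WorkerStore() raises TypeError for the unimplemented abstract
    # method; py3 ignores the __metaclass__ attribute, so there the guard is
    # only the NotImplementedError raised from __init__.
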
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 34d2f82b7f..8443bd4c1b 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,21 +14,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
-from twisted.internet import defer
+import logging
+import types
+
+from canonicaljson import encode_canonical_json, json
-from canonicaljson import encode_canonical_json
+from twisted.internet import defer
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
-import logging
-import simplejson as json
-import types
+from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
-class PusherStore(SQLBaseStore):
+class PusherWorkerStore(SQLBaseStore):
def _decode_pushers_rows(self, rows):
for r in rows:
dataJson = r['data']
@@ -102,9 +103,6 @@ class PusherStore(SQLBaseStore):
rows = yield self.runInteraction("get_all_pushers", get_pushers)
defer.returnValue(rows)
- def get_pushers_stream_token(self):
- return self._pushers_id_gen.get_current_token()
-
def get_all_updated_pushers(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed(([], []))
@@ -198,56 +196,74 @@ class PusherStore(SQLBaseStore):
defer.returnValue(result)
+
+class PusherStore(PusherWorkerStore):
+ def get_pushers_stream_token(self):
+ return self._pushers_id_gen.get_current_token()
+
@defer.inlineCallbacks
def add_pusher(self, user_id, access_token, kind, app_id,
app_display_name, device_display_name,
pushkey, pushkey_ts, lang, data, last_stream_ordering,
profile_tag=""):
with self._pushers_id_gen.get_next() as stream_id:
- def f(txn):
- newly_inserted = self._simple_upsert_txn(
- txn,
- "pushers",
- {
- "app_id": app_id,
- "pushkey": pushkey,
- "user_name": user_id,
- },
- {
- "access_token": access_token,
- "kind": kind,
- "app_display_name": app_display_name,
- "device_display_name": device_display_name,
- "ts": pushkey_ts,
- "lang": lang,
- "data": encode_canonical_json(data),
- "last_stream_ordering": last_stream_ordering,
- "profile_tag": profile_tag,
- "id": stream_id,
- },
- )
- if newly_inserted:
- # get_if_user_has_pusher only cares if the user has
- # at least *one* pusher.
- txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))
+ # no need to lock because `pushers` has a unique key on
+ # (app_id, pushkey, user_name) so _simple_upsert will retry
+ newly_inserted = yield self._simple_upsert(
+ table="pushers",
+ keyvalues={
+ "app_id": app_id,
+ "pushkey": pushkey,
+ "user_name": user_id,
+ },
+ values={
+ "access_token": access_token,
+ "kind": kind,
+ "app_display_name": app_display_name,
+ "device_display_name": device_display_name,
+ "ts": pushkey_ts,
+ "lang": lang,
+ "data": encode_canonical_json(data),
+ "last_stream_ordering": last_stream_ordering,
+ "profile_tag": profile_tag,
+ "id": stream_id,
+ },
+ desc="add_pusher",
+ lock=False,
+ )
- yield self.runInteraction("add_pusher", f)
+ if newly_inserted:
+ yield self.runInteraction(
+ "add_pusher",
+ self._invalidate_cache_and_stream,
+ self.get_if_user_has_pusher, (user_id,)
+ )
@defer.inlineCallbacks
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
def delete_pusher_txn(txn, stream_id):
- txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))
+ self._invalidate_cache_and_stream(
+ txn, self.get_if_user_has_pusher, (user_id,)
+ )
self._simple_delete_one_txn(
txn,
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id}
)
- self._simple_upsert_txn(
+
+ # it's possible for us to end up with duplicate rows for
+ # (app_id, pushkey, user_id) at different stream_ids, but that
+ # doesn't really matter.
+ self._simple_insert_txn(
txn,
- "deleted_pushers",
- {"app_id": app_id, "pushkey": pushkey, "user_id": user_id},
- {"stream_id": stream_id},
+ table="deleted_pushers",
+ values={
+ "stream_id": stream_id,
+ "app_id": app_id,
+ "pushkey": pushkey,
+ "user_id": user_id,
+ },
)
with self._pushers_id_gen.get_next() as stream_id:
@@ -310,9 +326,12 @@ class PusherStore(SQLBaseStore):
@defer.inlineCallbacks
def set_throttle_params(self, pusher_id, room_id, params):
+ # no need to lock because `pusher_throttle` has a primary key on
+ # (pusher, room_id) so _simple_upsert will retry
yield self._simple_upsert(
"pusher_throttle",
{"pusher": pusher_id, "room_id": room_id},
params,
- desc="set_throttle_params"
+ desc="set_throttle_params",
+ lock=False,
)
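
The lock=False comments above lean on the tables' unique keys: an emulated upsert can race between its UPDATE and INSERT, but the constraint turns the race into an IntegrityError that can simply be retried. A minimal sketch of that emulation (not synapse's actual _simple_upsert):

    import sqlite3

    def simple_upsert(conn, table, keyvalues, values):
        """Emulated upsert: retry on conflict rather than taking a lock."""
        set_clause = ", ".join("%s = ?" % k for k in values)
        where_clause = " AND ".join("%s = ?" % k for k in keyvalues)
        while True:
            cur = conn.cursor()
            cur.execute(
                "UPDATE %s SET %s WHERE %s" % (table, set_clause, where_clause),
                list(values.values()) + list(keyvalues.values()),
            )
            if cur.rowcount > 0:
                return False  # updated an existing row
            cols = list(keyvalues) + list(values)
            try:
                cur.execute(
                    "INSERT INTO %s (%s) VALUES (%s)"
                    % (table, ", ".join(cols), ", ".join("?" for _ in cols)),
                    list(keyvalues.values()) + list(values.values()),
                )
                return True  # inserted a new row
            except sqlite3.IntegrityError:
                # Lost the race with a concurrent insert: the unique key
                # fired, so loop around and take the UPDATE path instead.
                continue
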
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index f42b8014c7..0ac665e967 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,52 +14,52 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+import abc
+import logging
+
+from canonicaljson import json
from twisted.internet import defer
-import logging
-import ujson as json
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+from ._base import SQLBaseStore
+from .util.id_generators import StreamIdGenerator
logger = logging.getLogger(__name__)
-class ReceiptsStore(SQLBaseStore):
- def __init__(self, hs):
- super(ReceiptsStore, self).__init__(hs)
+class ReceiptsWorkerStore(SQLBaseStore):
+ """This is an abstract base class where subclasses must implement
+ `get_max_receipt_stream_id`, which is called from this class's initializer.
+ """
+
+ # This ABCMeta metaclass ensures that we cannot be instantiated without
+ # the abstract methods being implemented.
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, db_conn, hs):
+ super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
self._receipts_stream_cache = StreamChangeCache(
- "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
+ "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
)
+ @abc.abstractmethod
+ def get_max_receipt_stream_id(self):
+ """Get the current max stream ID for receipts stream
+
+ Returns:
+ int
+ """
+ raise NotImplementedError()
+
@cachedInlineCallbacks()
def get_users_with_read_receipts_in_room(self, room_id):
receipts = yield self.get_receipts_for_room(room_id, "m.read")
defer.returnValue(set(r['user_id'] for r in receipts))
- def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
- user_id):
- if receipt_type != "m.read":
- return
-
- # Returns an ObservableDeferred
- res = self.get_users_with_read_receipts_in_room.cache.get(
- room_id, None, update_metrics=False,
- )
-
- if res:
- if isinstance(res, defer.Deferred) and res.called:
- res = res.result
- if user_id in res:
- # We'd only be adding to the set, so no point invalidating if the
- # user is already there
- return
-
- self.get_users_with_read_receipts_in_room.invalidate((room_id,))
-
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
return self._simple_select_list(
@@ -139,7 +140,9 @@ class ReceiptsStore(SQLBaseStore):
"""
room_ids = set(room_ids)
- if from_key:
+ if from_key is not None:
+ # Only ask the database about rooms where there have been new
+ # receipts added since `from_key`
room_ids = yield self._receipts_stream_cache.get_entities_changed(
room_ids, from_key
)
@@ -150,7 +153,6 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue([ev for res in results.values() for ev in res])
- @cachedInlineCallbacks(num_args=3, tree=True)
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients.
@@ -161,7 +163,19 @@ class ReceiptsStore(SQLBaseStore):
from the start.
Returns:
- list: A list of receipts.
+ Deferred[list]: A list of receipts.
+ """
+ if from_key is not None:
+ # Check the cache first to see if any new receipts have been added
+ # since `from_key`. If not, we can no-op.
+ if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
+ return defer.succeed([])
+
+ return self._get_linearized_receipts_for_room(room_id, to_key, from_key)
+
+ @cachedInlineCallbacks(num_args=3, tree=True)
+ def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+ """See get_linearized_receipts_for_room
"""
def f(txn):
if from_key:
@@ -210,7 +224,7 @@ class ReceiptsStore(SQLBaseStore):
"content": content,
}])
- @cachedList(cached_method_name="get_linearized_receipts_for_room",
+ @cachedList(cached_method_name="_get_linearized_receipts_for_room",
list_name="room_ids", num_args=3, inlineCallbacks=True)
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
@@ -270,11 +284,97 @@ class ReceiptsStore(SQLBaseStore):
}
defer.returnValue(results)
+ def get_all_updated_receipts(self, last_id, current_id, limit=None):
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_updated_receipts_txn(txn):
+ sql = (
+ "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
+ " FROM receipts_linearized"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC"
+ )
+ args = [last_id, current_id]
+ if limit is not None:
+ sql += " LIMIT ?"
+ args.append(limit)
+ txn.execute(sql, args)
+
+ return txn.fetchall()
+ return self.runInteraction(
+ "get_all_updated_receipts", get_all_updated_receipts_txn
+ )
+
+ def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
+ user_id):
+ if receipt_type != "m.read":
+ return
+
+ # Returns either an ObservableDeferred or the raw result
+ res = self.get_users_with_read_receipts_in_room.cache.get(
+ room_id, None, update_metrics=False,
+ )
+
+ # first handle the Deferred case
+ if isinstance(res, defer.Deferred):
+ if res.called:
+ res = res.result
+ else:
+ res = None
+
+ if res and user_id in res:
+ # We'd only be adding to the set, so no point invalidating if the
+ # user is already there
+ return
+
+ self.get_users_with_read_receipts_in_room.invalidate((room_id,))
+
+
+class ReceiptsStore(ReceiptsWorkerStore):
+ def __init__(self, db_conn, hs):
+ # We instantiate this first as the ReceiptsWorkerStore constructor
+ # needs to be able to call get_max_receipt_stream_id
+ self._receipts_id_gen = StreamIdGenerator(
+ db_conn, "receipts_linearized", "stream_id"
+ )
+
+ super(ReceiptsStore, self).__init__(db_conn, hs)
+
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_id, data, stream_id):
+ res = self._simple_select_one_txn(
+ txn,
+ table="events",
+ retcols=["topological_ordering", "stream_ordering"],
+ keyvalues={"event_id": event_id},
+ allow_none=True
+ )
+
+ stream_ordering = int(res["stream_ordering"]) if res else None
+
+ # We don't want to clobber receipts for more recent events, so we
+ # have to compare orderings of existing receipts
+ if stream_ordering is not None:
+ sql = (
+ "SELECT stream_ordering, event_id FROM events"
+ " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
+ " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
+ )
+ txn.execute(sql, (room_id, receipt_type, user_id))
+
+ for so, eid in txn:
+ if int(so) >= stream_ordering:
+ logger.debug(
+ "Ignoring new receipt for %s in favour of existing "
+ "one for later event %s",
+ event_id, eid,
+ )
+ return False
+
txn.call_after(
self.get_receipts_for_room.invalidate, (room_id, receipt_type)
)
@@ -286,7 +386,7 @@ class ReceiptsStore(SQLBaseStore):
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
)
# FIXME: This shouldn't invalidate the whole cache
- txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
+ txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,))
txn.call_after(
self._receipts_stream_cache.entity_has_changed,
@@ -298,34 +398,6 @@ class ReceiptsStore(SQLBaseStore):
(user_id, room_id, receipt_type)
)
- res = self._simple_select_one_txn(
- txn,
- table="events",
- retcols=["topological_ordering", "stream_ordering"],
- keyvalues={"event_id": event_id},
- allow_none=True
- )
-
- topological_ordering = int(res["topological_ordering"]) if res else None
- stream_ordering = int(res["stream_ordering"]) if res else None
-
- # We don't want to clobber receipts for more recent events, so we
- # have to compare orderings of existing receipts
- sql = (
- "SELECT topological_ordering, stream_ordering, event_id FROM events"
- " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
- " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
- )
-
- txn.execute(sql, (room_id, receipt_type, user_id))
-
- if topological_ordering:
- for to, so, _ in txn:
- if int(to) > topological_ordering:
- return False
- elif int(to) == topological_ordering and int(so) >= stream_ordering:
- return False
-
self._simple_delete_txn(
txn,
table="receipts_linearized",
@@ -349,12 +421,11 @@ class ReceiptsStore(SQLBaseStore):
}
)
- if receipt_type == "m.read" and topological_ordering:
+ if receipt_type == "m.read" and stream_ordering is not None:
self._remove_old_push_actions_before_txn(
txn,
room_id=room_id,
user_id=user_id,
- topological_ordering=topological_ordering,
stream_ordering=stream_ordering,
)
@@ -435,7 +506,7 @@ class ReceiptsStore(SQLBaseStore):
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
)
# FIXME: This shouldn't invalidate the whole cache
- txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
+ txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,))
self._simple_delete_txn(
txn,
@@ -457,25 +528,3 @@ class ReceiptsStore(SQLBaseStore):
"data": json.dumps(data),
}
)
-
- def get_all_updated_receipts(self, last_id, current_id, limit=None):
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_updated_receipts_txn(txn):
- sql = (
- "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
- " FROM receipts_linearized"
- " WHERE ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC"
- )
- args = [last_id, current_id]
- if limit is not None:
- sql += " LIMIT ?"
- args.append(limit)
- txn.execute(sql, args)
-
- return txn.fetchall()
- return self.runInteraction(
- "get_all_updated_receipts", get_all_updated_receipts_txn
- )
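
The receipts change splits the cached query in two: the uncached public get_linearized_receipts_for_room does the cheap StreamChangeCache recency check, and only the underscored method carries @cachedInlineCallbacks, so short-circuited calls never populate the per-arguments cache. A stripped-down sketch of that shape:

    class ReceiptsLike(object):
        def __init__(self, stream_cache):
            self._stream_cache = stream_cache
            self._query_cache = {}

        def get_receipts(self, room_id, to_key, from_key=None):
            if from_key is not None:
                # nothing new since from_key: skip both cache and database
                if not self._stream_cache.has_entity_changed(room_id, from_key):
                    return []
            return self._get_receipts_cached(room_id, to_key, from_key)

        def _get_receipts_cached(self, room_id, to_key, from_key):
            key = (room_id, to_key, from_key)
            if key not in self._query_cache:
                self._query_cache[key] = self._query_db(room_id, to_key, from_key)
            return self._query_cache[key]

        def _query_db(self, room_id, to_key, from_key):
            return []  # stand-in for the real SQL
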
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 20acd58fcf..07333f777d 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -15,17 +15,83 @@
import re
+from six.moves import range
+
from twisted.internet import defer
-from synapse.api.errors import StoreError, Codes
+from synapse.api.errors import Codes, StoreError
from synapse.storage import background_updates
+from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-class RegistrationStore(background_updates.BackgroundUpdateStore):
+class RegistrationWorkerStore(SQLBaseStore):
+ @cached()
+ def get_user_by_id(self, user_id):
+ return self._simple_select_one(
+ table="users",
+ keyvalues={
+ "name": user_id,
+ },
+ retcols=[
+ "name", "password_hash", "is_guest",
+ "consent_version", "consent_server_notice_sent",
+ "appservice_id",
+ ],
+ allow_none=True,
+ desc="get_user_by_id",
+ )
+
+ @cached()
+ def get_user_by_access_token(self, token):
+ """Get a user from the given access token.
+
+ Args:
+ token (str): The access token of a user.
+ Returns:
+ defer.Deferred: None, if the token did not match, otherwise dict
+ including the keys `name`, `is_guest`, `device_id`, `token_id`.
+ """
+ return self.runInteraction(
+ "get_user_by_access_token",
+ self._query_for_auth,
+ token
+ )
+
+ @defer.inlineCallbacks
+ def is_server_admin(self, user):
+ res = yield self._simple_select_one_onecol(
+ table="users",
+ keyvalues={"name": user.to_string()},
+ retcol="admin",
+ allow_none=True,
+ desc="is_server_admin",
+ )
+
+ defer.returnValue(res if res else False)
+
+ def _query_for_auth(self, txn, token):
+ sql = (
+ "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
+ " access_tokens.device_id"
+ " FROM users"
+ " INNER JOIN access_tokens on users.name = access_tokens.user_id"
+ " WHERE token = ?"
+ )
+
+ txn.execute(sql, (token,))
+ rows = self.cursor_to_dict(txn)
+ if rows:
+ return rows[0]
+
+ return None
+
+
+class RegistrationStore(RegistrationWorkerStore,
+ background_updates.BackgroundUpdateStore):
- def __init__(self, hs):
- super(RegistrationStore, self).__init__(hs)
+ def __init__(self, db_conn, hs):
+ super(RegistrationStore, self).__init__(db_conn, hs)
self.clock = hs.get_clock()
@@ -37,12 +103,17 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
)
self.register_background_index_update(
- "refresh_tokens_device_index",
- index_name="refresh_tokens_device_id",
- table="refresh_tokens",
- columns=["user_id", "device_id"],
+ "users_creation_ts",
+ index_name="users_creation_ts",
+ table="users",
+ columns=["creation_ts"],
)
+ # we no longer use refresh tokens, but it's possible that some people
+ # might have a background update queued to build this index. Just
+ # clear the background update.
+ self.register_noop_background_update("refresh_tokens_device_index")
+
@defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id=None):
"""Adds an access token for the given user.
@@ -177,9 +248,11 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
)
if create_profile_with_localpart:
+ # set a default displayname serverside to avoid ugly race
+ # between auto-joins and clients trying to set displaynames
txn.execute(
- "INSERT INTO profiles(user_id) VALUES (?)",
- (create_profile_with_localpart,)
+ "INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
+ (create_profile_with_localpart, create_profile_with_localpart)
)
self._invalidate_cache_and_stream(
@@ -187,18 +260,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
)
txn.call_after(self.is_guest.invalidate, (user_id,))
- @cached()
- def get_user_by_id(self, user_id):
- return self._simple_select_one(
- table="users",
- keyvalues={
- "name": user_id,
- },
- retcols=["name", "password_hash", "is_guest"],
- allow_none=True,
- desc="get_user_by_id",
- )
-
def get_users_by_id_case_insensitive(self, user_id):
"""Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash.
@@ -236,12 +297,57 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
"user_set_password_hash", user_set_password_hash_txn
)
- @defer.inlineCallbacks
+ def user_set_consent_version(self, user_id, consent_version):
+ """Updates the user table to record privacy policy consent
+
+ Args:
+ user_id (str): full mxid of the user to update
+ consent_version (str): version of the policy the user has consented
+ to
+
+ Raises:
+ StoreError(404) if user not found
+ """
+ def f(txn):
+ self._simple_update_one_txn(
+ txn,
+ table='users',
+ keyvalues={'name': user_id, },
+ updatevalues={'consent_version': consent_version, },
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_id, (user_id,)
+ )
+ return self.runInteraction("user_set_consent_version", f)
+
+ def user_set_consent_server_notice_sent(self, user_id, consent_version):
+ """Updates the user table to record that we have sent the user a server
+ notice about privacy policy consent
+
+ Args:
+ user_id (str): full mxid of the user to update
+ consent_version (str): version of the policy we have notified the
+ user about
+
+ Raises:
+ StoreError(404) if user not found
+ """
+ def f(txn):
+ self._simple_update_one_txn(
+ txn,
+ table='users',
+ keyvalues={'name': user_id, },
+ updatevalues={'consent_server_notice_sent': consent_version, },
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_id, (user_id,)
+ )
+ return self.runInteraction("user_set_consent_server_notice_sent", f)
+
def user_delete_access_tokens(self, user_id, except_token_id=None,
- device_id=None,
- delete_refresh_tokens=False):
+ device_id=None):
"""
- Invalidate access/refresh tokens belonging to a user
+ Invalidate access tokens belonging to a user
Args:
user_id (str): ID of user the tokens belong to
@@ -250,10 +356,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
device_id (str|None): ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
- delete_refresh_tokens (bool): True to delete refresh tokens as
- well as access tokens.
Returns:
- defer.Deferred:
+ defer.Deferred[list[(str, int, str|None)]]: a list of
+ (token, token id, device id) for each of the deleted tokens
"""
def f(txn):
keyvalues = {
@@ -262,13 +367,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
if device_id is not None:
keyvalues["device_id"] = device_id
- if delete_refresh_tokens:
- self._simple_delete_txn(
- txn,
- table="refresh_tokens",
- keyvalues=keyvalues,
- )
-
items = keyvalues.items()
where_clause = " AND ".join(k + " = ?" for k, _ in items)
values = [v for _, v in items]
@@ -277,14 +375,14 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
values.append(except_token_id)
txn.execute(
- "SELECT token FROM access_tokens WHERE %s" % where_clause,
+ "SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause,
values
)
- rows = self.cursor_to_dict(txn)
+ tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
- for row in rows:
+ for token, _, _ in tokens_and_devices:
self._invalidate_cache_and_stream(
- txn, self.get_user_by_access_token, (row["token"],)
+ txn, self.get_user_by_access_token, (token,)
)
txn.execute(
@@ -292,7 +390,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
values
)
- yield self.runInteraction(
+ return tokens_and_devices
+
+ return self.runInteraction(
"user_delete_access_tokens", f,
)
@@ -312,34 +412,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
return self.runInteraction("delete_access_token", f)
- @cached()
- def get_user_by_access_token(self, token):
- """Get a user from the given access token.
-
- Args:
- token (str): The access token of a user.
- Returns:
- defer.Deferred: None, if the token did not match, otherwise dict
- including the keys `name`, `is_guest`, `device_id`, `token_id`.
- """
- return self.runInteraction(
- "get_user_by_access_token",
- self._query_for_auth,
- token
- )
-
- @defer.inlineCallbacks
- def is_server_admin(self, user):
- res = yield self._simple_select_one_onecol(
- table="users",
- keyvalues={"name": user.to_string()},
- retcol="admin",
- allow_none=True,
- desc="is_server_admin",
- )
-
- defer.returnValue(res if res else False)
-
@cachedInlineCallbacks()
def is_guest(self, user_id):
res = yield self._simple_select_one_onecol(
@@ -352,22 +424,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
defer.returnValue(res if res else False)
- def _query_for_auth(self, txn, token):
- sql = (
- "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
- " access_tokens.device_id"
- " FROM users"
- " INNER JOIN access_tokens on users.name = access_tokens.user_id"
- " WHERE token = ?"
- )
-
- txn.execute(sql, (token,))
- rows = self.cursor_to_dict(txn)
- if rows:
- return rows[0]
-
- return None
-
@defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
yield self._simple_upsert("user_threepids", {
@@ -404,15 +460,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
defer.returnValue(ret['user_id'])
defer.returnValue(None)
- def user_delete_threepids(self, user_id):
- return self._simple_delete(
- "user_threepids",
- keyvalues={
- "user_id": user_id,
- },
- desc="user_delete_threepids",
- )
-
def user_delete_threepid(self, user_id, medium, address):
return self._simple_delete(
"user_threepids",
@@ -437,6 +484,35 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
+ def count_daily_user_type(self):
+ """
+ Counts 1) native non-guest users
+ 2) native guest users
+ 3) bridged users
+ who registered on the homeserver in the past 24 hours
+ """
+ def _count_daily_user_type(txn):
+ yesterday = int(self._clock.time()) - (60 * 60 * 24)
+
+ sql = """
+ SELECT user_type, COALESCE(count(*), 0) AS count FROM (
+ SELECT
+ CASE
+ WHEN is_guest=0 AND appservice_id IS NULL THEN 'native'
+ WHEN is_guest=1 AND appservice_id IS NULL THEN 'guest'
+ WHEN is_guest=0 AND appservice_id IS NOT NULL THEN 'bridged'
+ END AS user_type
+ FROM users
+ WHERE creation_ts > ?
+ ) AS t GROUP BY user_type
+ """
+ results = {'native': 0, 'guest': 0, 'bridged': 0}
+ txn.execute(sql, (yesterday,))
+ for row in txn:
+ results[row[0]] = row[1]
+ return results
+ return self.runInteraction("count_daily_user_type", _count_daily_user_type)
+
@defer.inlineCallbacks
def count_nonbridged_users(self):
def _count_users(txn):
@@ -464,18 +540,16 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
"""
def _find_next_generated_user_id(txn):
txn.execute("SELECT name FROM users")
- rows = self.cursor_to_dict(txn)
regex = re.compile("^@(\d+):")
found = set()
- for r in rows:
- user_id = r["name"]
+ for user_id, in txn:
match = regex.search(user_id)
if match:
found.add(int(match.group(1)))
- for i in xrange(len(found) + 1):
+ for i in range(len(found) + 1):
if i not in found:
return i
@@ -530,3 +604,44 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
except self.database_engine.module.IntegrityError:
ret = yield self.get_3pid_guest_access_token(medium, address)
defer.returnValue(ret)
+
+ def add_user_pending_deactivation(self, user_id):
+ """
+ Adds a user to the table of users who need to be parted from all the rooms they're
+ in
+ """
+ return self._simple_insert(
+ "users_pending_deactivation",
+ values={
+ "user_id": user_id,
+ },
+ desc="add_user_pending_deactivation",
+ )
+
+ def del_user_pending_deactivation(self, user_id):
+ """
+ Removes the given user from the table of users who need to be parted from all the
+ rooms they're in, effectively marking that user as fully deactivated.
+ """
+ # XXX: This should be simple_delete_one but we failed to put a unique index on
+ # the table, so somehow duplicate entries have ended up in it.
+ return self._simple_delete(
+ "users_pending_deactivation",
+ keyvalues={
+ "user_id": user_id,
+ },
+ desc="del_user_pending_deactivation",
+ )
+
+ def get_user_pending_deactivation(self):
+ """
+ Gets one user from the table of users waiting to be parted from all the rooms
+ they're in.
+ """
+ return self._simple_select_one_onecol(
+ "users_pending_deactivation",
+ keyvalues={},
+ retcol="user_id",
+ allow_none=True,
+ desc="get_users_pending_deactivation",
+ )
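
add_user_pending_deactivation, get_user_pending_deactivation and del_user_pending_deactivation together form a small persistent work queue. A sketch of the consumer that the queue implies, with part_user_from_rooms as an assumed room-parting helper:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def part_deactivated_users(store, part_user_from_rooms):
        """Drain the users_pending_deactivation queue one user at a time."""
        while True:
            user_id = yield store.get_user_pending_deactivation()
            if user_id is None:
                break  # queue is empty
            yield part_user_from_rooms(user_id)  # assumed helper
            yield store.del_user_pending_deactivation(user_id)
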
diff --git a/synapse/storage/rejections.py b/synapse/storage/rejections.py
index 40acb5c4ed..880f047adb 100644
--- a/synapse/storage/rejections.py
+++ b/synapse/storage/rejections.py
@@ -13,10 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
-
import logging
+from ._base import SQLBaseStore
+
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 23688430b7..3147fb6827 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -13,19 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import collections
+import logging
+import re
+
+from canonicaljson import json
+
from twisted.internet import defer
from synapse.api.errors import StoreError
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.search import SearchStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-from ._base import SQLBaseStore
-from .engines import PostgresEngine, Sqlite3Engine
-
-import collections
-import logging
-import ujson as json
-import re
-
logger = logging.getLogger(__name__)
@@ -40,7 +40,138 @@ RatelimitOverride = collections.namedtuple(
)
-class RoomStore(SQLBaseStore):
+class RoomWorkerStore(SQLBaseStore):
+ def get_public_room_ids(self):
+ return self._simple_select_onecol(
+ table="rooms",
+ keyvalues={
+ "is_public": True,
+ },
+ retcol="room_id",
+ desc="get_public_room_ids",
+ )
+
+ @cached(num_args=2, max_entries=100)
+ def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
+ """Get pulbic rooms for a particular list, or across all lists.
+
+ Args:
+ stream_id (int)
+ network_tuple (ThirdPartyInstanceID): The list to use. (None, None)
+ means the main list; None means all lists.
+ """
+ return self.runInteraction(
+ "get_public_room_ids_at_stream_id",
+ self.get_public_room_ids_at_stream_id_txn,
+ stream_id, network_tuple=network_tuple
+ )
+
+ def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
+ network_tuple):
+ return {
+ rm
+ for rm, vis in self.get_published_at_stream_id_txn(
+ txn, stream_id, network_tuple=network_tuple
+ ).items()
+ if vis
+ }
+
+ def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
+ if network_tuple:
+ # We want to get from a particular list. No aggregation required.
+
+ sql = ("""
+ SELECT room_id, visibility FROM public_room_list_stream
+ INNER JOIN (
+ SELECT room_id, max(stream_id) AS stream_id
+ FROM public_room_list_stream
+ WHERE stream_id <= ? %s
+ GROUP BY room_id
+ ) grouped USING (room_id, stream_id)
+ """)
+
+ if network_tuple.appservice_id is not None:
+ txn.execute(
+ sql % ("AND appservice_id = ? AND network_id = ?",),
+ (stream_id, network_tuple.appservice_id, network_tuple.network_id,)
+ )
+ else:
+ txn.execute(
+ sql % ("AND appservice_id IS NULL",),
+ (stream_id,)
+ )
+ return dict(txn)
+ else:
+ # We want to get from all lists, so we need to aggregate the results
+
+ logger.info("Executing full list")
+
+ sql = ("""
+ SELECT room_id, visibility
+ FROM public_room_list_stream
+ INNER JOIN (
+ SELECT
+ room_id, max(stream_id) AS stream_id, appservice_id,
+ network_id
+ FROM public_room_list_stream
+ WHERE stream_id <= ?
+ GROUP BY room_id, appservice_id, network_id
+ ) grouped USING (room_id, stream_id)
+ """)
+
+ txn.execute(
+ sql,
+ (stream_id,)
+ )
+
+ results = {}
+            # A room is visible if it's visible on any list.
+ for room_id, visibility in txn:
+ results[room_id] = bool(visibility) or results.get(room_id, False)
+
+ return results
+
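The full-list branch above keeps, for each (room, appservice, network) triple, only the row with the highest stream_id, then ORs visibility per room across lists. A hedged pure-Python rendering of that aggregation, with illustrative rows:

    rows = [
        # (room_id, stream_id, appservice_id, network_id, visibility)
        ("!a:hs", 3, None, None, True),
        ("!a:hs", 7, "as1", "net1", False),
        ("!b:hs", 5, None, None, False),
    ]

    # latest visibility row per (room, appservice, network), like the GROUP BY
    latest = {}
    for room_id, stream_id, as_id, net_id, vis in rows:
        key = (room_id, as_id, net_id)
        if key not in latest or stream_id > latest[key][0]:
            latest[key] = (stream_id, vis)

    # a room is visible if it's visible on any list
    results = {}
    for (room_id, _as, _net), (_stream, vis) in latest.items():
        results[room_id] = results.get(room_id, False) or bool(vis)

    assert results == {"!a:hs": True, "!b:hs": False}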
+ def get_public_room_changes(self, prev_stream_id, new_stream_id,
+ network_tuple):
+ def get_public_room_changes_txn(txn):
+ then_rooms = self.get_public_room_ids_at_stream_id_txn(
+ txn, prev_stream_id, network_tuple
+ )
+
+ now_rooms_dict = self.get_published_at_stream_id_txn(
+ txn, new_stream_id, network_tuple
+ )
+
+ now_rooms_visible = set(
+ rm for rm, vis in now_rooms_dict.items() if vis
+ )
+ now_rooms_not_visible = set(
+ rm for rm, vis in now_rooms_dict.items() if not vis
+ )
+
+ newly_visible = now_rooms_visible - then_rooms
+ newly_unpublished = now_rooms_not_visible & then_rooms
+
+ return newly_visible, newly_unpublished
+
+ return self.runInteraction(
+ "get_public_room_changes", get_public_room_changes_txn
+ )
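get_public_room_changes then reduces two snapshots to set algebra: rooms visible now but absent then are newly published, and rooms present then but invisible now are newly unpublished. For example:

    then_rooms = {"!a:hs", "!b:hs"}
    now_rooms_dict = {"!a:hs": True, "!b:hs": False, "!c:hs": True}

    now_visible = {rm for rm, vis in now_rooms_dict.items() if vis}
    now_not_visible = {rm for rm, vis in now_rooms_dict.items() if not vis}

    newly_visible = now_visible - then_rooms          # {"!c:hs"}
    newly_unpublished = now_not_visible & then_rooms  # {"!b:hs"}
    assert newly_visible == {"!c:hs"} and newly_unpublished == {"!b:hs"}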
+
+ @cached(max_entries=10000)
+ def is_room_blocked(self, room_id):
+ return self._simple_select_one_onecol(
+ table="blocked_rooms",
+ keyvalues={
+ "room_id": room_id,
+ },
+ retcol="1",
+ allow_none=True,
+ desc="is_room_blocked",
+ )
+
+
+class RoomStore(RoomWorkerStore, SearchStore):
@defer.inlineCallbacks
def store_room(self, room_id, room_creator_user_id, is_public):
@@ -227,16 +358,6 @@ class RoomStore(SQLBaseStore):
)
self.hs.get_notifier().on_new_replication_data()
- def get_public_room_ids(self):
- return self._simple_select_onecol(
- table="rooms",
- keyvalues={
- "is_public": True,
- },
- retcol="room_id",
- desc="get_public_room_ids",
- )
-
def get_room_count(self):
"""Retrieve a list of all rooms
"""
@@ -263,8 +384,8 @@ class RoomStore(SQLBaseStore):
},
)
- self._store_event_search_txn(
- txn, event, "content.topic", event.content["topic"]
+ self.store_event_search_txn(
+ txn, event, "content.topic", event.content["topic"],
)
def _store_room_name_txn(self, txn, event):
@@ -279,14 +400,14 @@ class RoomStore(SQLBaseStore):
}
)
- self._store_event_search_txn(
- txn, event, "content.name", event.content["name"]
+ self.store_event_search_txn(
+ txn, event, "content.name", event.content["name"],
)
def _store_room_message_txn(self, txn, event):
if hasattr(event, "content") and "body" in event.content:
- self._store_event_search_txn(
- txn, event, "content.body", event.content["body"]
+ self.store_event_search_txn(
+ txn, event, "content.body", event.content["body"],
)
def _store_history_visibility_txn(self, txn, event):
@@ -308,31 +429,6 @@ class RoomStore(SQLBaseStore):
event.content[key]
))
- def _store_event_search_txn(self, txn, event, key, value):
- if isinstance(self.database_engine, PostgresEngine):
- sql = (
- "INSERT INTO event_search"
- " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
- " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
- )
- txn.execute(
- sql,
- (
- event.event_id, event.room_id, key, value,
- event.internal_metadata.stream_ordering,
- event.origin_server_ts,
- )
- )
- elif isinstance(self.database_engine, Sqlite3Engine):
- sql = (
- "INSERT INTO event_search (event_id, room_id, key, value)"
- " VALUES (?,?,?,?)"
- )
- txn.execute(sql, (event.event_id, event.room_id, key, value,))
- else:
- # This should be unreachable.
- raise Exception("Unrecognized database engine")
-
def add_event_report(self, room_id, event_id, user_id, reason, content,
received_ts):
next_id = self._event_reports_id_gen.get_next()
@@ -353,113 +449,6 @@ class RoomStore(SQLBaseStore):
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
- @cached(num_args=2, max_entries=100)
- def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
- """Get pulbic rooms for a particular list, or across all lists.
-
- Args:
- stream_id (int)
- network_tuple (ThirdPartyInstanceID): The list to use (None, None)
- means the main list, None means all lsits.
- """
- return self.runInteraction(
- "get_public_room_ids_at_stream_id",
- self.get_public_room_ids_at_stream_id_txn,
- stream_id, network_tuple=network_tuple
- )
-
- def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
- network_tuple):
- return {
- rm
- for rm, vis in self.get_published_at_stream_id_txn(
- txn, stream_id, network_tuple=network_tuple
- ).items()
- if vis
- }
-
- def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
- if network_tuple:
- # We want to get from a particular list. No aggregation required.
-
- sql = ("""
- SELECT room_id, visibility FROM public_room_list_stream
- INNER JOIN (
- SELECT room_id, max(stream_id) AS stream_id
- FROM public_room_list_stream
- WHERE stream_id <= ? %s
- GROUP BY room_id
- ) grouped USING (room_id, stream_id)
- """)
-
- if network_tuple.appservice_id is not None:
- txn.execute(
- sql % ("AND appservice_id = ? AND network_id = ?",),
- (stream_id, network_tuple.appservice_id, network_tuple.network_id,)
- )
- else:
- txn.execute(
- sql % ("AND appservice_id IS NULL",),
- (stream_id,)
- )
- return dict(txn)
- else:
- # We want to get from all lists, so we need to aggregate the results
-
- logger.info("Executing full list")
-
- sql = ("""
- SELECT room_id, visibility
- FROM public_room_list_stream
- INNER JOIN (
- SELECT
- room_id, max(stream_id) AS stream_id, appservice_id,
- network_id
- FROM public_room_list_stream
- WHERE stream_id <= ?
- GROUP BY room_id, appservice_id, network_id
- ) grouped USING (room_id, stream_id)
- """)
-
- txn.execute(
- sql,
- (stream_id,)
- )
-
- results = {}
- # A room is visible if its visible on any list.
- for room_id, visibility in txn:
- results[room_id] = bool(visibility) or results.get(room_id, False)
-
- return results
-
- def get_public_room_changes(self, prev_stream_id, new_stream_id,
- network_tuple):
- def get_public_room_changes_txn(txn):
- then_rooms = self.get_public_room_ids_at_stream_id_txn(
- txn, prev_stream_id, network_tuple
- )
-
- now_rooms_dict = self.get_published_at_stream_id_txn(
- txn, new_stream_id, network_tuple
- )
-
- now_rooms_visible = set(
- rm for rm, vis in now_rooms_dict.items() if vis
- )
- now_rooms_not_visible = set(
- rm for rm, vis in now_rooms_dict.items() if not vis
- )
-
- newly_visible = now_rooms_visible - then_rooms
- newly_unpublished = now_rooms_not_visible & then_rooms
-
- return newly_visible, newly_unpublished
-
- return self.runInteraction(
- "get_public_room_changes", get_public_room_changes_txn
- )
-
def get_all_new_public_rooms(self, prev_id, current_id, limit):
def get_all_new_public_rooms(txn):
sql = ("""
@@ -509,18 +498,6 @@ class RoomStore(SQLBaseStore):
else:
defer.returnValue(None)
- @cached(max_entries=10000)
- def is_room_blocked(self, room_id):
- return self._simple_select_one_onecol(
- table="blocked_rooms",
- keyvalues={
- "room_id": room_id,
- },
- retcol="1",
- allow_none=True,
- desc="is_room_blocked",
- )
-
@defer.inlineCallbacks
def block_room(self, room_id, user_id):
yield self._simple_insert(
@@ -531,75 +508,121 @@ class RoomStore(SQLBaseStore):
},
desc="block_room",
)
- self.is_room_blocked.invalidate((room_id,))
+ yield self.runInteraction(
+ "block_room_invalidation",
+ self._invalidate_cache_and_stream,
+ self.is_room_blocked, (room_id,),
+ )
+
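block_room now invalidates is_room_blocked via _invalidate_cache_and_stream rather than calling invalidate directly, so worker processes also drop their cached entry. A rough toy model of the idea, using illustrative classes rather than Synapse's replication machinery:

    class CacheInvalidationStream:
        """Records invalidations for worker processes to replay."""
        def __init__(self):
            self.entries = []

        def append(self, cache_name, key):
            self.entries.append((cache_name, key))

    def invalidate_cache_and_stream(stream, cache, cache_name, key):
        cache.pop(key, None)            # drop the local entry
        stream.append(cache_name, key)  # and tell the workers to do the same

    stream = CacheInvalidationStream()
    is_room_blocked_cache = {("!r:hs",): True}
    invalidate_cache_and_stream(stream, is_room_blocked_cache, "is_room_blocked", ("!r:hs",))
    assert ("!r:hs",) not in is_room_blocked_cache
    assert stream.entries == [("is_room_blocked", ("!r:hs",))]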
+ def get_media_mxcs_in_room(self, room_id):
+ """Retrieves all the local and remote media MXC URIs in a given room
+
+ Args:
+ room_id (str)
+
+ Returns:
+            A tuple of two lists of MXC URIs: the URIs of media uploaded to
+            the local server, and those of media hosted on remote servers.
+ """
+ def _get_media_mxcs_in_room_txn(txn):
+ local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+ local_media_mxcs = []
+ remote_media_mxcs = []
+
+ # Convert the IDs to MXC URIs
+ for media_id in local_mxcs:
+ local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id))
+ for hostname, media_id in remote_mxcs:
+ remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
+
+ return local_media_mxcs, remote_media_mxcs
+ return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
"""For a room loops through all events with media and quarantines
the associated media
"""
- def _get_media_ids_in_room(txn):
- mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+ def _quarantine_media_in_room_txn(txn):
+ local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+ total_media_quarantined = 0
- next_token = self.get_current_events_token() + 1
+ # Now update all the tables to set the quarantined_by flag
- total_media_quarantined = 0
+ txn.executemany("""
+ UPDATE local_media_repository
+ SET quarantined_by = ?
+ WHERE media_id = ?
+ """, ((quarantined_by, media_id) for media_id in local_mxcs))
- while next_token:
- sql = """
- SELECT stream_ordering, content FROM events
- WHERE room_id = ?
- AND stream_ordering < ?
- AND contains_url = ? AND outlier = ?
- ORDER BY stream_ordering DESC
- LIMIT ?
+ txn.executemany(
"""
- txn.execute(sql, (room_id, next_token, True, False, 100))
-
- next_token = None
- local_media_mxcs = []
- remote_media_mxcs = []
- for stream_ordering, content_json in txn:
- next_token = stream_ordering
- content = json.loads(content_json)
-
- content_url = content.get("url")
- thumbnail_url = content.get("info", {}).get("thumbnail_url")
-
- for url in (content_url, thumbnail_url):
- if not url:
- continue
- matches = mxc_re.match(url)
- if matches:
- hostname = matches.group(1)
- media_id = matches.group(2)
- if hostname == self.hostname:
- local_media_mxcs.append(media_id)
- else:
- remote_media_mxcs.append((hostname, media_id))
-
- # Now update all the tables to set the quarantined_by flag
-
- txn.executemany("""
- UPDATE local_media_repository
+ UPDATE remote_media_cache
SET quarantined_by = ?
- WHERE media_id = ?
- """, ((quarantined_by, media_id) for media_id in local_media_mxcs))
-
- txn.executemany(
- """
- UPDATE remote_media_cache
- SET quarantined_by = ?
- WHERE media_origin AND media_id = ?
- """,
- (
- (quarantined_by, origin, media_id)
- for origin, media_id in remote_media_mxcs
- )
+ WHERE media_origin = ? AND media_id = ?
+ """,
+ (
+ (quarantined_by, origin, media_id)
+ for origin, media_id in remote_mxcs
)
+ )
- total_media_quarantined += len(local_media_mxcs)
- total_media_quarantined += len(remote_media_mxcs)
+ total_media_quarantined += len(local_mxcs)
+ total_media_quarantined += len(remote_mxcs)
return total_media_quarantined
- return self.runInteraction("get_media_ids_in_room", _get_media_ids_in_room)
+ return self.runInteraction(
+ "quarantine_media_in_room",
+ _quarantine_media_in_room_txn,
+ )
+
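The rewritten transaction drives each UPDATE from a generator of parameter tuples via executemany. A small sqlite3 sketch of the same pattern, with an illustrative mini-schema:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE local_media_repository (media_id TEXT, quarantined_by TEXT)")
    conn.executemany(
        "INSERT INTO local_media_repository (media_id) VALUES (?)",
        [("m1",), ("m2",)],
    )

    local_mxcs = ["m1", "m2"]
    quarantined_by = "@admin:example.com"
    conn.executemany(
        "UPDATE local_media_repository SET quarantined_by = ? WHERE media_id = ?",
        ((quarantined_by, media_id) for media_id in local_mxcs),  # generator, not a list
    )
    assert conn.execute(
        "SELECT count(*) FROM local_media_repository WHERE quarantined_by IS NOT NULL"
    ).fetchone() == (2,)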
+ def _get_media_mxcs_in_room_txn(self, txn, room_id):
+ """Retrieves all the local and remote media MXC URIs in a given room
+
+ Args:
+ txn (cursor)
+ room_id (str)
+
+ Returns:
+            A tuple of two lists: the local media IDs, and the remote media as
+            (hostname, media ID) tuples.
+ """
+ mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+
+ next_token = self.get_current_events_token() + 1
+ local_media_mxcs = []
+ remote_media_mxcs = []
+
+ while next_token:
+ sql = """
+ SELECT stream_ordering, json FROM events
+ JOIN event_json USING (room_id, event_id)
+ WHERE room_id = ?
+ AND stream_ordering < ?
+ AND contains_url = ? AND outlier = ?
+ ORDER BY stream_ordering DESC
+ LIMIT ?
+ """
+ txn.execute(sql, (room_id, next_token, True, False, 100))
+
+ next_token = None
+ for stream_ordering, content_json in txn:
+ next_token = stream_ordering
+ event_json = json.loads(content_json)
+ content = event_json["content"]
+ content_url = content.get("url")
+ thumbnail_url = content.get("info", {}).get("thumbnail_url")
+
+ for url in (content_url, thumbnail_url):
+ if not url:
+ continue
+ matches = mxc_re.match(url)
+ if matches:
+ hostname = matches.group(1)
+ media_id = matches.group(2)
+ if hostname == self.hs.hostname:
+ local_media_mxcs.append(media_id)
+ else:
+ remote_media_mxcs.append((hostname, media_id))
+
+ return local_media_mxcs, remote_media_mxcs
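The loop above pages backwards through the room's events in descending stream_ordering, 100 at a time, using the lowest ordering seen as the next upper bound; next_token stays None when a page comes back empty, ending the loop. A sketch of just the paging scheme, where fetch_page is a hypothetical stand-in for the SQL query:

    def walk_media_events(fetch_page, max_token):
        """Page backwards through events, 100 at a time."""
        next_token = max_token + 1
        while next_token is not None:
            page = fetch_page(upper=next_token, limit=100)  # ordered by stream DESC
            next_token = None                  # an empty page ends the loop
            for stream_ordering, event in page:
                next_token = stream_ordering   # lowest ordering becomes the next bound
                yield event

    pages = {11: [(10, "e10"), (9, "e9")], 9: []}
    assert list(walk_media_events(lambda upper, limit: pages[upper], 10)) == ["e10", "e9"]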
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 457ca288d0..01697ab2c9 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,22 +14,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
+import logging
from collections import namedtuple
-from ._base import SQLBaseStore
+from six import iteritems, itervalues
+
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.storage.events import EventsWorkerStore
+from synapse.types import get_domain_from_id
from synapse.util.async import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.stringutils import to_ascii
-from synapse.api.constants import Membership, EventTypes
-from synapse.types import get_domain_from_id
-
-import logging
-import ujson as json
-
logger = logging.getLogger(__name__)
@@ -37,6 +39,11 @@ RoomsForUser = namedtuple(
("room_id", "sender", "membership", "event_id", "stream_ordering")
)
+GetRoomsForUserWithStreamOrdering = namedtuple(
+ "_GetRoomsForUserWithStreamOrdering",
+ ("room_id", "stream_ordering",)
+)
+
# We store this using a namedtuple so that we save about 3x space over using a
# dict.
@@ -48,97 +55,7 @@ ProfileInfo = namedtuple(
_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
-class RoomMemberStore(SQLBaseStore):
- def __init__(self, hs):
- super(RoomMemberStore, self).__init__(hs)
- self.register_background_update_handler(
- _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
- )
-
- def _store_room_members_txn(self, txn, events, backfilled):
- """Store a room member in the database.
- """
- self._simple_insert_many_txn(
- txn,
- table="room_memberships",
- values=[
- {
- "event_id": event.event_id,
- "user_id": event.state_key,
- "sender": event.user_id,
- "room_id": event.room_id,
- "membership": event.membership,
- "display_name": event.content.get("displayname", None),
- "avatar_url": event.content.get("avatar_url", None),
- }
- for event in events
- ]
- )
-
- for event in events:
- txn.call_after(
- self._membership_stream_cache.entity_has_changed,
- event.state_key, event.internal_metadata.stream_ordering
- )
- txn.call_after(
- self.get_invited_rooms_for_user.invalidate, (event.state_key,)
- )
-
- # We update the local_invites table only if the event is "current",
- # i.e., its something that has just happened.
- # The only current event that can also be an outlier is if its an
- # invite that has come in across federation.
- is_new_state = not backfilled and (
- not event.internal_metadata.is_outlier()
- or event.internal_metadata.is_invite_from_remote()
- )
- is_mine = self.hs.is_mine_id(event.state_key)
- if is_new_state and is_mine:
- if event.membership == Membership.INVITE:
- self._simple_insert_txn(
- txn,
- table="local_invites",
- values={
- "event_id": event.event_id,
- "invitee": event.state_key,
- "inviter": event.sender,
- "room_id": event.room_id,
- "stream_id": event.internal_metadata.stream_ordering,
- }
- )
- else:
- sql = (
- "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
- " room_id = ? AND invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
-
- txn.execute(sql, (
- event.internal_metadata.stream_ordering,
- event.event_id,
- event.room_id,
- event.state_key,
- ))
-
- @defer.inlineCallbacks
- def locally_reject_invite(self, user_id, room_id):
- sql = (
- "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
- " room_id = ? AND invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
-
- def f(txn, stream_ordering):
- txn.execute(sql, (
- stream_ordering,
- True,
- room_id,
- user_id,
- ))
-
- with self._stream_id_gen.get_next() as stream_ordering:
- yield self.runInteraction("locally_reject_invite", f, stream_ordering)
-
+class RoomMemberWorkerStore(EventsWorkerStore):
@cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
def get_hosts_in_room(self, room_id, cache_context):
"""Returns the set of all hosts currently in the room
@@ -270,12 +187,32 @@ class RoomMemberStore(SQLBaseStore):
return results
@cachedInlineCallbacks(max_entries=500000, iterable=True)
- def get_rooms_for_user(self, user_id):
+ def get_rooms_for_user_with_stream_ordering(self, user_id):
"""Returns a set of room_ids the user is currently joined to
+
+ Args:
+ user_id (str)
+
+ Returns:
+ Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
+ the rooms the user is in currently, along with the stream ordering
+ of the most recent join for that user and room.
"""
rooms = yield self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN],
)
+ defer.returnValue(frozenset(
+ GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
+ for r in rooms
+ ))
+
+ @defer.inlineCallbacks
+ def get_rooms_for_user(self, user_id, on_invalidate=None):
+ """Returns a set of room_ids the user is currently joined to
+ """
+ rooms = yield self.get_rooms_for_user_with_stream_ordering(
+ user_id, on_invalidate=on_invalidate,
+ )
defer.returnValue(frozenset(r.room_id for r in rooms))
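get_rooms_for_user is now a thin uncached wrapper, so it must forward the caller's on_invalidate callback down to the cached inner function; otherwise cache-context invalidation would stop at the wrapper. A toy model of the pattern, not Synapse's cache descriptors:

    class Cached:
        def __init__(self, fn):
            self.fn = fn
            self.cache = {}      # key -> value
            self.listeners = {}  # key -> callbacks to fire on invalidation

        def __call__(self, key, on_invalidate=None):
            if on_invalidate is not None:
                self.listeners.setdefault(key, []).append(on_invalidate)
            if key not in self.cache:
                self.cache[key] = self.fn(key)
            return self.cache[key]

        def invalidate(self, key):
            self.cache.pop(key, None)
            for cb in self.listeners.pop(key, []):
                cb()

    rooms_with_ordering = Cached(lambda user_id: frozenset({("!a:hs", 1)}))

    def get_rooms_for_user(user_id, on_invalidate=None):
        # uncached wrapper: pass the caller's invalidation interest downstream
        rooms = rooms_with_ordering(user_id, on_invalidate=on_invalidate)
        return frozenset(room_id for room_id, _ in rooms)

    invalidated = []
    assert get_rooms_for_user("@alice:hs", on_invalidate=lambda: invalidated.append(1)) \
        == frozenset({"!a:hs"})
    rooms_with_ordering.invalidate("@alice:hs")
    assert invalidated == [1]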
@cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
@@ -295,89 +232,7 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(user_who_share_room)
- def forget(self, user_id, room_id):
- """Indicate that user_id wishes to discard history for room_id."""
- def f(txn):
- sql = (
- "UPDATE"
- " room_memberships"
- " SET"
- " forgotten = 1"
- " WHERE"
- " user_id = ?"
- " AND"
- " room_id = ?"
- )
- txn.execute(sql, (user_id, room_id))
-
- txn.call_after(self.was_forgotten_at.invalidate_all)
- txn.call_after(self.did_forget.invalidate, (user_id, room_id))
- self._invalidate_cache_and_stream(
- txn, self.who_forgot_in_room, (room_id,)
- )
- return self.runInteraction("forget_membership", f)
-
- @cachedInlineCallbacks(num_args=2)
- def did_forget(self, user_id, room_id):
- """Returns whether user_id has elected to discard history for room_id.
-
- Returns False if they have since re-joined."""
- def f(txn):
- sql = (
- "SELECT"
- " COUNT(*)"
- " FROM"
- " room_memberships"
- " WHERE"
- " user_id = ?"
- " AND"
- " room_id = ?"
- " AND"
- " forgotten = 0"
- )
- txn.execute(sql, (user_id, room_id))
- rows = txn.fetchall()
- return rows[0][0]
- count = yield self.runInteraction("did_forget_membership", f)
- defer.returnValue(count == 0)
-
- @cachedInlineCallbacks(num_args=3)
- def was_forgotten_at(self, user_id, room_id, event_id):
- """Returns whether user_id has elected to discard history for room_id at
- event_id.
-
- event_id must be a membership event."""
- def f(txn):
- sql = (
- "SELECT"
- " forgotten"
- " FROM"
- " room_memberships"
- " WHERE"
- " user_id = ?"
- " AND"
- " room_id = ?"
- " AND"
- " event_id = ?"
- )
- txn.execute(sql, (user_id, room_id, event_id))
- rows = txn.fetchall()
- return rows[0][0]
- forgot = yield self.runInteraction("did_forget_membership_at", f)
- defer.returnValue(forgot == 1)
-
- @cached()
- def who_forgot_in_room(self, room_id):
- return self._simple_select_list(
- table="room_memberships",
- retcols=("user_id", "event_id"),
- keyvalues={
- "room_id": room_id,
- "forgotten": 1,
- },
- desc="who_forgot"
- )
-
+ @defer.inlineCallbacks
def get_joined_users_from_context(self, event, context):
state_group = context.state_group
if not state_group:
@@ -387,11 +242,13 @@ class RoomMemberStore(SQLBaseStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
- return self._get_joined_users_from_context(
- event.room_id, state_group, context.current_state_ids,
+ current_state_ids = yield context.get_current_state_ids(self)
+ result = yield self._get_joined_users_from_context(
+ event.room_id, state_group, current_state_ids,
event=event,
context=context,
)
+ defer.returnValue(result)
def get_joined_users_from_state(self, room_id, state_entry):
state_group = state_entry.state_group
@@ -419,7 +276,7 @@ class RoomMemberStore(SQLBaseStore):
users_in_room = {}
member_event_ids = [
e_id
- for key, e_id in current_state_ids.iteritems()
+ for key, e_id in iteritems(current_state_ids)
if key[0] == EventTypes.Member
]
@@ -436,7 +293,7 @@ class RoomMemberStore(SQLBaseStore):
users_in_room = dict(prev_res)
member_event_ids = [
e_id
- for key, e_id in context.delta_ids.iteritems()
+ for key, e_id in iteritems(context.delta_ids)
if key[0] == EventTypes.Member
]
for etype, state_key in context.delta_ids:
@@ -533,6 +390,46 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(True)
+ @cachedInlineCallbacks()
+ def was_host_joined(self, room_id, host):
+ """Check whether the server is or ever was in the room.
+
+ Args:
+ room_id (str)
+ host (str)
+
+ Returns:
+ Deferred: Resolves to True if the host is/was in the room, otherwise
+ False.
+ """
+ if '%' in host or '_' in host:
+ raise Exception("Invalid host name")
+
+ sql = """
+ SELECT user_id FROM room_memberships
+ WHERE room_id = ?
+ AND user_id LIKE ?
+ AND membership = 'join'
+ LIMIT 1
+ """
+
+        # We do need to be careful to ensure that host doesn't have any wildcards
+ # in it, but we checked above for known ones and we'll check below that
+ # the returned user actually has the correct domain.
+ like_clause = "%:" + host
+
+ rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause)
+
+ if not rows:
+ defer.returnValue(False)
+
+ user_id = rows[0][0]
+ if get_domain_from_id(user_id) != host:
+ # This can only happen if the host name has something funky in it
+ raise Exception("Invalid host name")
+
+ defer.returnValue(True)
+
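The host check above relies on two guards around the LIKE query: reject hosts containing SQL wildcard characters up front, and confirm afterwards that the matched user really belongs to that domain. A sketch of both guards with illustrative helper names:

    def like_clause_for_host(host):
        # '%' and '_' are SQL LIKE wildcards; refuse them outright
        if "%" in host or "_" in host:
            raise Exception("Invalid host name")
        return "%:" + host

    def domain_of(user_id):
        return user_id.split(":", 1)[1]

    assert like_clause_for_host("example.com") == "%:example.com"
    # post-check: a row matched by LIKE must genuinely carry the requested domain
    assert domain_of("@bob:example.com") == "example.com"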
def get_joined_hosts(self, room_id, state_entry):
state_group = state_entry.state_group
if not state_group:
@@ -560,6 +457,144 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(joined_hosts)
+ @cached(max_entries=10000)
+ def _get_joined_hosts_cache(self, room_id):
+ return _JoinedHostsCache(self, room_id)
+
+
+class RoomMemberStore(RoomMemberWorkerStore):
+ def __init__(self, db_conn, hs):
+ super(RoomMemberStore, self).__init__(db_conn, hs)
+ self.register_background_update_handler(
+ _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
+ )
+
+ def _store_room_members_txn(self, txn, events, backfilled):
+ """Store a room member in the database.
+ """
+ self._simple_insert_many_txn(
+ txn,
+ table="room_memberships",
+ values=[
+ {
+ "event_id": event.event_id,
+ "user_id": event.state_key,
+ "sender": event.user_id,
+ "room_id": event.room_id,
+ "membership": event.membership,
+ "display_name": event.content.get("displayname", None),
+ "avatar_url": event.content.get("avatar_url", None),
+ }
+ for event in events
+ ]
+ )
+
+ for event in events:
+ txn.call_after(
+ self._membership_stream_cache.entity_has_changed,
+ event.state_key, event.internal_metadata.stream_ordering
+ )
+ txn.call_after(
+ self.get_invited_rooms_for_user.invalidate, (event.state_key,)
+ )
+
+ # We update the local_invites table only if the event is "current",
+        # i.e., it's something that has just happened.
+        # The only current event that can also be an outlier is an invite
+        # that has come in across federation.
+ is_new_state = not backfilled and (
+ not event.internal_metadata.is_outlier()
+ or event.internal_metadata.is_invite_from_remote()
+ )
+ is_mine = self.hs.is_mine_id(event.state_key)
+ if is_new_state and is_mine:
+ if event.membership == Membership.INVITE:
+ self._simple_insert_txn(
+ txn,
+ table="local_invites",
+ values={
+ "event_id": event.event_id,
+ "invitee": event.state_key,
+ "inviter": event.sender,
+ "room_id": event.room_id,
+ "stream_id": event.internal_metadata.stream_ordering,
+ }
+ )
+ else:
+ sql = (
+ "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
+ " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+ " AND replaced_by is NULL"
+ )
+
+ txn.execute(sql, (
+ event.internal_metadata.stream_ordering,
+ event.event_id,
+ event.room_id,
+ event.state_key,
+ ))
+
+ @defer.inlineCallbacks
+ def locally_reject_invite(self, user_id, room_id):
+ sql = (
+ "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
+ " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+ " AND replaced_by is NULL"
+ )
+
+ def f(txn, stream_ordering):
+ txn.execute(sql, (
+ stream_ordering,
+ True,
+ room_id,
+ user_id,
+ ))
+
+ with self._stream_id_gen.get_next() as stream_ordering:
+ yield self.runInteraction("locally_reject_invite", f, stream_ordering)
+
+ def forget(self, user_id, room_id):
+ """Indicate that user_id wishes to discard history for room_id."""
+ def f(txn):
+ sql = (
+ "UPDATE"
+ " room_memberships"
+ " SET"
+ " forgotten = 1"
+ " WHERE"
+ " user_id = ?"
+ " AND"
+ " room_id = ?"
+ )
+ txn.execute(sql, (user_id, room_id))
+
+ txn.call_after(self.did_forget.invalidate, (user_id, room_id))
+ return self.runInteraction("forget_membership", f)
+
+ @cachedInlineCallbacks(num_args=2)
+ def did_forget(self, user_id, room_id):
+ """Returns whether user_id has elected to discard history for room_id.
+
+ Returns False if they have since re-joined."""
+ def f(txn):
+ sql = (
+ "SELECT"
+ " COUNT(*)"
+ " FROM"
+ " room_memberships"
+ " WHERE"
+ " user_id = ?"
+ " AND"
+ " room_id = ?"
+ " AND"
+ " forgotten = 0"
+ )
+ txn.execute(sql, (user_id, room_id))
+ rows = txn.fetchall()
+ return rows[0][0]
+ count = yield self.runInteraction("did_forget_membership", f)
+ defer.returnValue(count == 0)
+
@defer.inlineCallbacks
def _background_add_membership_profile(self, progress, batch_size):
target_min_stream_id = progress.get(
@@ -573,8 +608,9 @@ class RoomMemberStore(SQLBaseStore):
def add_membership_profile_txn(txn):
sql = ("""
- SELECT stream_ordering, event_id, events.room_id, content
+ SELECT stream_ordering, event_id, events.room_id, event_json.json
FROM events
+ INNER JOIN event_json USING (event_id)
INNER JOIN room_memberships USING (event_id)
WHERE ? <= stream_ordering AND stream_ordering < ?
AND type = 'm.room.member'
@@ -595,8 +631,9 @@ class RoomMemberStore(SQLBaseStore):
event_id = row["event_id"]
room_id = row["room_id"]
try:
- content = json.loads(row["content"])
- except:
+ event_json = json.loads(row["json"])
+ content = event_json['content']
+ except Exception:
continue
display_name = content.get("displayname", None)
@@ -635,10 +672,6 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(result)
- @cached(max_entries=10000, iterable=True)
- def _get_joined_hosts_cache(self, room_id):
- return _JoinedHostsCache(self, room_id)
-
class _JoinedHostsCache(object):
"""Cache for joined hosts in a room that is optimised to handle updates
@@ -671,7 +704,7 @@ class _JoinedHostsCache(object):
if state_entry.state_group == self.state_group:
pass
elif state_entry.prev_group == self.state_group:
- for (typ, state_key), event_id in state_entry.delta_ids.iteritems():
+ for (typ, state_key), event_id in iteritems(state_entry.delta_ids):
if typ != EventTypes.Member:
continue
@@ -701,7 +734,7 @@ class _JoinedHostsCache(object):
self.state_group = state_entry.state_group
else:
self.state_group = object()
- self._len = sum(len(v) for v in self.hosts_to_joined_users.itervalues())
+ self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users))
defer.returnValue(frozenset(self.hosts_to_joined_users))
def __len__(self):
diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
index 8755bb2e49..4d725b92fe 100644
--- a/synapse/storage/schema/delta/14/upgrade_appservice_db.py
+++ b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import logging
+import simplejson as json
+
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py
index 4269ac69ad..4b2ffd35fd 100644
--- a/synapse/storage/schema/delta/25/fts.py
+++ b/synapse/storage/schema/delta/25/fts.py
@@ -14,10 +14,10 @@
import logging
-from synapse.storage.prepare_database import get_statements
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+import simplejson
-import ujson
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.prepare_database import get_statements
logger = logging.getLogger(__name__)
@@ -66,7 +66,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
- progress_json = ujson.dumps(progress)
+ progress_json = simplejson.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/schema/delta/27/ts.py
index 71b12a2731..414f9f5aa0 100644
--- a/synapse/storage/schema/delta/27/ts.py
+++ b/synapse/storage/schema/delta/27/ts.py
@@ -14,9 +14,9 @@
import logging
-from synapse.storage.prepare_database import get_statements
+import simplejson
-import ujson
+from synapse.storage.prepare_database import get_statements
logger = logging.getLogger(__name__)
@@ -45,7 +45,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
- progress_json = ujson.dumps(progress)
+ progress_json = simplejson.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py
index 5b7d8d1ab5..ef7ec34346 100644
--- a/synapse/storage/schema/delta/30/as_users.py
+++ b/synapse/storage/schema/delta/30/as_users.py
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from synapse.config.appservice import load_appservices
+from six.moves import range
+
+from synapse.config.appservice import load_appservices
logger = logging.getLogger(__name__)
@@ -22,7 +24,7 @@ def run_create(cur, database_engine, *args, **kwargs):
# NULL indicates user was not registered by an appservice.
try:
cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT")
- except:
+ except Exception:
# Maybe we already added the column? Hope so...
pass
@@ -58,7 +60,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
for as_id, user_ids in owned.items():
n = 100
- user_chunks = (user_ids[i:i + 100] for i in xrange(0, len(user_ids), n))
+ user_chunks = (user_ids[i:i + 100] for i in range(0, len(user_ids), n))
for chunk in user_chunks:
cur.execute(
database_engine.convert_param_style(
diff --git a/synapse/storage/schema/delta/31/search_update.py b/synapse/storage/schema/delta/31/search_update.py
index 470ae0c005..7d8ca5f93f 100644
--- a/synapse/storage/schema/delta/31/search_update.py
+++ b/synapse/storage/schema/delta/31/search_update.py
@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
+import simplejson
+
from synapse.storage.engines import PostgresEngine
from synapse.storage.prepare_database import get_statements
-import logging
-import ujson
-
logger = logging.getLogger(__name__)
@@ -49,7 +50,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"rows_inserted": 0,
"have_added_indexes": False,
}
- progress_json = ujson.dumps(progress)
+ progress_json = simplejson.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/schema/delta/33/event_fields.py b/synapse/storage/schema/delta/33/event_fields.py
index 83066cccc9..bff1256a7b 100644
--- a/synapse/storage/schema/delta/33/event_fields.py
+++ b/synapse/storage/schema/delta/33/event_fields.py
@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.prepare_database import get_statements
-
import logging
-import ujson
+
+import simplejson
+
+from synapse.storage.prepare_database import get_statements
logger = logging.getLogger(__name__)
@@ -44,7 +45,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
- progress_json = ujson.dumps(progress)
+ progress_json = simplejson.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/schema/delta/33/remote_media_ts.py b/synapse/storage/schema/delta/33/remote_media_ts.py
index 55ae43f395..9754d3ccfb 100644
--- a/synapse/storage/schema/delta/33/remote_media_ts.py
+++ b/synapse/storage/schema/delta/33/remote_media_ts.py
@@ -14,7 +14,6 @@
import time
-
ALTER_TABLE = "ALTER TABLE remote_media_cache ADD COLUMN last_access_ts BIGINT"
diff --git a/synapse/storage/schema/delta/34/cache_stream.py b/synapse/storage/schema/delta/34/cache_stream.py
index 3b63a1562d..cf09e43e2b 100644
--- a/synapse/storage/schema/delta/34/cache_stream.py
+++ b/synapse/storage/schema/delta/34/cache_stream.py
@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.prepare_database import get_statements
-from synapse.storage.engines import PostgresEngine
-
import logging
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.prepare_database import get_statements
+
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/schema/delta/34/received_txn_purge.py b/synapse/storage/schema/delta/34/received_txn_purge.py
index 033144341c..67d505e68b 100644
--- a/synapse/storage/schema/delta/34/received_txn_purge.py
+++ b/synapse/storage/schema/delta/34/received_txn_purge.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.engines import PostgresEngine
-
import logging
+from synapse.storage.engines import PostgresEngine
+
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/schema/delta/34/sent_txn_purge.py b/synapse/storage/schema/delta/34/sent_txn_purge.py
index 81948e3431..0ffab10b6f 100644
--- a/synapse/storage/schema/delta/34/sent_txn_purge.py
+++ b/synapse/storage/schema/delta/34/sent_txn_purge.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.engines import PostgresEngine
-
import logging
+from synapse.storage.engines import PostgresEngine
+
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/schema/delta/37/remove_auth_idx.py b/synapse/storage/schema/delta/37/remove_auth_idx.py
index 20ad8bd5a6..a377884169 100644
--- a/synapse/storage/schema/delta/37/remove_auth_idx.py
+++ b/synapse/storage/schema/delta/37/remove_auth_idx.py
@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.prepare_database import get_statements
-from synapse.storage.engines import PostgresEngine
-
import logging
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.prepare_database import get_statements
+
logger = logging.getLogger(__name__)
DROP_INDICES = """
diff --git a/synapse/storage/schema/delta/38/postgres_fts_gist.sql b/synapse/storage/schema/delta/38/postgres_fts_gist.sql
index f090a7b75a..515e6b8e84 100644
--- a/synapse/storage/schema/delta/38/postgres_fts_gist.sql
+++ b/synapse/storage/schema/delta/38/postgres_fts_gist.sql
@@ -13,5 +13,7 @@
* limitations under the License.
*/
- INSERT into background_updates (update_name, progress_json)
- VALUES ('event_search_postgres_gist', '{}');
+-- We no longer do this given we back it out again in schema 47
+
+-- INSERT into background_updates (update_name, progress_json)
+-- VALUES ('event_search_postgres_gist', '{}');
diff --git a/synapse/storage/schema/delta/42/user_dir.py b/synapse/storage/schema/delta/42/user_dir.py
index ea6a18196d..506f326f4d 100644
--- a/synapse/storage/schema/delta/42/user_dir.py
+++ b/synapse/storage/schema/delta/42/user_dir.py
@@ -14,8 +14,8 @@
import logging
-from synapse.storage.prepare_database import get_statements
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.prepare_database import get_statements
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/schema/delta/43/user_share.sql b/synapse/storage/schema/delta/43/user_share.sql
index 4501d90cbb..ee7062abe4 100644
--- a/synapse/storage/schema/delta/43/user_share.sql
+++ b/synapse/storage/schema/delta/43/user_share.sql
@@ -29,5 +29,5 @@ CREATE INDEX users_who_share_rooms_r_idx ON users_who_share_rooms(room_id);
CREATE INDEX users_who_share_rooms_o_idx ON users_who_share_rooms(other_user_id);
--- Make sure that we popualte the table initially
+-- Make sure that we populate the table initially
UPDATE user_directory_stream_pos SET stream_id = NULL;
diff --git a/synapse/storage/schema/delta/44/expire_url_cache.sql b/synapse/storage/schema/delta/44/expire_url_cache.sql
new file mode 100644
index 0000000000..b12f9b2ebf
--- /dev/null
+++ b/synapse/storage/schema/delta/44/expire_url_cache.sql
@@ -0,0 +1,41 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- this didn't work on SQLite 3.7 (because of lack of partial indexes), so was
+-- removed and replaced with 46/local_media_repository_url_idx.sql.
+--
+-- CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL;
+
+-- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support
+-- indices on expressions until 3.9.
+CREATE TABLE local_media_repository_url_cache_new(
+ url TEXT,
+ response_code INTEGER,
+ etag TEXT,
+ expires_ts BIGINT,
+ og TEXT,
+ media_id TEXT,
+ download_ts BIGINT
+);
+
+INSERT INTO local_media_repository_url_cache_new
+ SELECT url, response_code, etag, expires + download_ts, og, media_id, download_ts FROM local_media_repository_url_cache;
+
+DROP TABLE local_media_repository_url_cache;
+ALTER TABLE local_media_repository_url_cache_new RENAME TO local_media_repository_url_cache;
+
+CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts);
+CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts);
+CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id);
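This migration uses the standard SQLite rebuild-and-rename pattern (SQLite has no ALTER COLUMN): create the new table, copy rows across with the derived expression, drop the old table, rename, then index. A runnable sketch of the same steps on a simplified schema:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE cache (url TEXT, expires BIGINT, download_ts BIGINT);
        INSERT INTO cache VALUES ('http://e.com', 60000, 1000);

        -- rebuild with expires_ts computed from the old columns
        CREATE TABLE cache_new (url TEXT, expires_ts BIGINT, download_ts BIGINT);
        INSERT INTO cache_new SELECT url, expires + download_ts, download_ts FROM cache;
        DROP TABLE cache;
        ALTER TABLE cache_new RENAME TO cache;
        CREATE INDEX cache_expires_idx ON cache(expires_ts);
    """)
    assert conn.execute("SELECT expires_ts FROM cache").fetchone() == (61000,)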
diff --git a/synapse/storage/schema/delta/45/group_server.sql b/synapse/storage/schema/delta/45/group_server.sql
new file mode 100644
index 0000000000..b2333848a0
--- /dev/null
+++ b/synapse/storage/schema/delta/45/group_server.sql
@@ -0,0 +1,167 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE groups (
+ group_id TEXT NOT NULL,
+    name TEXT, -- the display name of the group
+ avatar_url TEXT,
+ short_description TEXT,
+ long_description TEXT
+);
+
+CREATE UNIQUE INDEX groups_idx ON groups(group_id);
+
+
+-- list of users the group server thinks are joined
+CREATE TABLE group_users (
+ group_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ is_admin BOOLEAN NOT NULL,
+    is_public BOOLEAN NOT NULL -- whether the user's membership can be seen by everyone
+);
+
+
+CREATE INDEX groups_users_g_idx ON group_users(group_id, user_id);
+CREATE INDEX groups_users_u_idx ON group_users(user_id);
+
+-- list of users the group server thinks are invited
+CREATE TABLE group_invites (
+ group_id TEXT NOT NULL,
+ user_id TEXT NOT NULL
+);
+
+CREATE INDEX groups_invites_g_idx ON group_invites(group_id, user_id);
+CREATE INDEX groups_invites_u_idx ON group_invites(user_id);
+
+
+CREATE TABLE group_rooms (
+ group_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ is_public BOOLEAN NOT NULL -- whether the room can be seen by everyone
+);
+
+CREATE UNIQUE INDEX groups_rooms_g_idx ON group_rooms(group_id, room_id);
+CREATE INDEX groups_rooms_r_idx ON group_rooms(room_id);
+
+
+-- Rooms to include in the summary
+CREATE TABLE group_summary_rooms (
+ group_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ category_id TEXT NOT NULL,
+ room_order BIGINT NOT NULL,
+    is_public BOOLEAN NOT NULL, -- whether the room should be shown to everyone
+ UNIQUE (group_id, category_id, room_id, room_order),
+ CHECK (room_order > 0)
+);
+
+CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms(group_id, room_id, category_id);
+
+
+-- Categories to include in the summary
+CREATE TABLE group_summary_room_categories (
+ group_id TEXT NOT NULL,
+ category_id TEXT NOT NULL,
+ cat_order BIGINT NOT NULL,
+ UNIQUE (group_id, category_id, cat_order),
+ CHECK (cat_order > 0)
+);
+
+-- The categories in the group
+CREATE TABLE group_room_categories (
+ group_id TEXT NOT NULL,
+ category_id TEXT NOT NULL,
+ profile TEXT NOT NULL,
+    is_public BOOLEAN NOT NULL, -- whether the category should be shown to everyone
+ UNIQUE (group_id, category_id)
+);
+
+-- The users to include in the group summary
+CREATE TABLE group_summary_users (
+ group_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ role_id TEXT NOT NULL,
+ user_order BIGINT NOT NULL,
+    is_public BOOLEAN NOT NULL -- whether the user should be shown to everyone
+);
+
+CREATE INDEX group_summary_users_g_idx ON group_summary_users(group_id);
+
+-- The roles to include in the group summary
+CREATE TABLE group_summary_roles (
+ group_id TEXT NOT NULL,
+ role_id TEXT NOT NULL,
+ role_order BIGINT NOT NULL,
+ UNIQUE (group_id, role_id, role_order),
+ CHECK (role_order > 0)
+);
+
+
+-- The roles in a group
+CREATE TABLE group_roles (
+ group_id TEXT NOT NULL,
+ role_id TEXT NOT NULL,
+ profile TEXT NOT NULL,
+    is_public BOOLEAN NOT NULL, -- whether the role should be shown to everyone
+ UNIQUE (group_id, role_id)
+);
+
+
+-- List of attestations we've given out and need to renew
+CREATE TABLE group_attestations_renewals (
+ group_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ valid_until_ms BIGINT NOT NULL
+);
+
+CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals(group_id, user_id);
+CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals(user_id);
+CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals(valid_until_ms);
+
+
+-- List of attestations we've received from remotes and are interested in.
+CREATE TABLE group_attestations_remote (
+ group_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ valid_until_ms BIGINT NOT NULL,
+ attestation_json TEXT NOT NULL
+);
+
+CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote(group_id, user_id);
+CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote(user_id);
+CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote(valid_until_ms);
+
+
+-- The group membership for the HS's users
+CREATE TABLE local_group_membership (
+ group_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ is_admin BOOLEAN NOT NULL,
+ membership TEXT NOT NULL,
+ is_publicised BOOLEAN NOT NULL, -- if the user is publicising their membership
+ content TEXT NOT NULL
+);
+
+CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id);
+CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id);
+
+
+CREATE TABLE local_group_updates (
+ stream_id BIGINT NOT NULL,
+ group_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ content TEXT NOT NULL
+);
diff --git a/synapse/storage/schema/delta/45/profile_cache.sql b/synapse/storage/schema/delta/45/profile_cache.sql
new file mode 100644
index 0000000000..e5ddc84df0
--- /dev/null
+++ b/synapse/storage/schema/delta/45/profile_cache.sql
@@ -0,0 +1,28 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- A subset of remote users whose profiles we have cached.
+-- Whether a user is in this table or not is defined by the storage function
+-- `is_subscribed_remote_profile_for_user`
+CREATE TABLE remote_profile_cache (
+ user_id TEXT NOT NULL,
+ displayname TEXT,
+ avatar_url TEXT,
+ last_check BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache(user_id);
+CREATE INDEX remote_profile_cache_time ON remote_profile_cache(last_check);
diff --git a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql b/synapse/storage/schema/delta/46/drop_refresh_tokens.sql
index bb225dafbf..68c48a89a9 100644
--- a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql
+++ b/synapse/storage/schema/delta/46/drop_refresh_tokens.sql
@@ -1,4 +1,4 @@
-/* Copyright 2016 OpenMarket Ltd
+/* Copyright 2017 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,5 +13,5 @@
* limitations under the License.
*/
-INSERT INTO background_updates (update_name, progress_json) VALUES
- ('refresh_tokens_device_index', '{}');
+/* we no longer use (or create) the refresh_tokens table */
+DROP TABLE IF EXISTS refresh_tokens;
diff --git a/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql b/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql
new file mode 100644
index 0000000000..bb307889c1
--- /dev/null
+++ b/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql
@@ -0,0 +1,35 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- drop the unique constraint on deleted_pushers so that we can just insert
+-- into it rather than upserting.
+
+CREATE TABLE deleted_pushers2 (
+ stream_id BIGINT NOT NULL,
+ app_id TEXT NOT NULL,
+ pushkey TEXT NOT NULL,
+ user_id TEXT NOT NULL
+);
+
+INSERT INTO deleted_pushers2 (stream_id, app_id, pushkey, user_id)
+ SELECT stream_id, app_id, pushkey, user_id from deleted_pushers;
+
+DROP TABLE deleted_pushers;
+ALTER TABLE deleted_pushers2 RENAME TO deleted_pushers;
+
+-- create the index after doing the inserts because that's more efficient.
+-- it also means we can give it the same name as the old one without renaming.
+CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id);
+
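Dropping the unique constraint turns deleted_pushers into an append-only log: recording the same deleted pusher twice simply inserts another row, with stream_id ordering the entries, instead of requiring an upsert. A small sqlite3 illustration:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE deleted_pushers (stream_id BIGINT, app_id TEXT, pushkey TEXT, user_id TEXT)"
    )
    # with no unique constraint, deleting the same pusher twice just appends
    for stream_id in (1, 2):
        conn.execute(
            "INSERT INTO deleted_pushers VALUES (?, ?, ?, ?)",
            (stream_id, "m.http", "key1", "@alice:example.com"),
        )
    assert conn.execute("SELECT count(*) FROM deleted_pushers").fetchone() == (2,)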
diff --git a/synapse/storage/schema/delta/46/group_server.sql b/synapse/storage/schema/delta/46/group_server.sql
new file mode 100644
index 0000000000..097679bc9a
--- /dev/null
+++ b/synapse/storage/schema/delta/46/group_server.sql
@@ -0,0 +1,32 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE groups_new (
+ group_id TEXT NOT NULL,
+    name TEXT, -- the display name of the group
+ avatar_url TEXT,
+ short_description TEXT,
+ long_description TEXT,
+ is_public BOOL NOT NULL -- whether non-members can access group APIs
+);
+
+-- NB: awful hack to get the default to be true on postgres and 1 on sqlite
+INSERT INTO groups_new
+ SELECT group_id, name, avatar_url, short_description, long_description, (1=1) FROM groups;
+
+DROP TABLE groups;
+ALTER TABLE groups_new RENAME TO groups;
+
+CREATE UNIQUE INDEX groups_idx ON groups(group_id);
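The "(1=1)" trick works because each engine evaluates the comparison to its native boolean, so one INSERT ... SELECT backfills TRUE on Postgres and 1 on SQLite. A quick sqlite3 check of that behaviour:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    # SQLite evaluates the comparison to 1; Postgres would yield TRUE
    assert conn.execute("SELECT (1=1)").fetchone() == (1,)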
diff --git a/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql b/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql
new file mode 100644
index 0000000000..bbfc7f5d1a
--- /dev/null
+++ b/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql
@@ -0,0 +1,24 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- register a background update which will recreate the
+-- local_media_repository_url_idx index.
+--
+-- We do this as a bg update not because it is a particularly onerous
+-- operation, but because we'd like it to be a partial index if possible, and
+-- the background_index_update code will understand whether we are on
+-- postgres or sqlite and behave accordingly.
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('local_media_repository_url_idx', '{}');
diff --git a/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql
new file mode 100644
index 0000000000..cb0d5a2576
--- /dev/null
+++ b/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql
@@ -0,0 +1,35 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- change the user_directory table to also cover the global profiles of local
+-- users, rather than just their profiles within specific rooms.
+
+CREATE TABLE user_directory2 (
+ user_id TEXT NOT NULL,
+ room_id TEXT,
+ display_name TEXT,
+ avatar_url TEXT
+);
+
+INSERT INTO user_directory2(user_id, room_id, display_name, avatar_url)
+ SELECT user_id, room_id, display_name, avatar_url from user_directory;
+
+DROP TABLE user_directory;
+ALTER TABLE user_directory2 RENAME TO user_directory;
+
+-- create indexes after doing the inserts because that's more efficient.
+-- it also means we can give it the same name as the old one without renaming.
+CREATE INDEX user_directory_room_idx ON user_directory(room_id);
+CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id);
diff --git a/synapse/storage/schema/delta/46/user_dir_typos.sql b/synapse/storage/schema/delta/46/user_dir_typos.sql
new file mode 100644
index 0000000000..d9505f8da1
--- /dev/null
+++ b/synapse/storage/schema/delta/46/user_dir_typos.sql
@@ -0,0 +1,24 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- this is just embarrassing :|
+ALTER TABLE users_in_pubic_room RENAME TO users_in_public_rooms;
+
+-- this is only 300K rows on matrix.org and takes ~3s to generate the index,
+-- so is hopefully not going to block anyone else for that long...
+CREATE INDEX users_in_public_rooms_room_idx ON users_in_public_rooms(room_id);
+CREATE UNIQUE INDEX users_in_public_rooms_user_idx ON users_in_public_rooms(user_id);
+DROP INDEX users_in_pubic_room_room_idx;
+DROP INDEX users_in_pubic_room_user_idx;
diff --git a/synapse/storage/schema/delta/33/refreshtoken_device.sql b/synapse/storage/schema/delta/47/last_access_media.sql
index 290bd6da86..f505fb22b5 100644
--- a/synapse/storage/schema/delta/33/refreshtoken_device.sql
+++ b/synapse/storage/schema/delta/47/last_access_media.sql
@@ -1,4 +1,4 @@
-/* Copyright 2016 OpenMarket Ltd
+/* Copyright 2018 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,4 +13,4 @@
* limitations under the License.
*/
-ALTER TABLE refresh_tokens ADD COLUMN device_id TEXT;
+ALTER TABLE local_media_repository ADD COLUMN last_access_ts BIGINT;
diff --git a/synapse/storage/schema/delta/23/refresh_tokens.sql b/synapse/storage/schema/delta/47/postgres_fts_gin.sql
index 34db0cf12b..31d7a817eb 100644
--- a/synapse/storage/schema/delta/23/refresh_tokens.sql
+++ b/synapse/storage/schema/delta/47/postgres_fts_gin.sql
@@ -1,4 +1,4 @@
-/* Copyright 2015, 2016 OpenMarket Ltd
+/* Copyright 2018 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,9 +13,5 @@
* limitations under the License.
*/
-CREATE TABLE IF NOT EXISTS refresh_tokens(
- id INTEGER PRIMARY KEY,
- token TEXT NOT NULL,
- user_id TEXT NOT NULL,
- UNIQUE (token)
-);
+INSERT into background_updates (update_name, progress_json)
+ VALUES ('event_search_postgres_gin', '{}');
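Delta files like this one don't do the work themselves: the row in `background_updates` is a marker that a store class later claims by registering a handler under the same name (as `SearchStore` does for `event_search_postgres_gin` further down this diff). A rough sketch of the shape of such a pairing; everything other than `register_background_update_handler` and `_end_background_update`, which both appear elsewhere in this diff, is illustrative:

```python
from twisted.internet import defer

from synapse.storage.background_updates import BackgroundUpdateStore


class ExampleStore(BackgroundUpdateStore):
    EXAMPLE_UPDATE_NAME = "event_search_postgres_gin"

    def __init__(self, db_conn, hs):
        super(ExampleStore, self).__init__(db_conn, hs)
        # ties the marker row inserted by the delta file to a handler
        self.register_background_update_handler(
            self.EXAMPLE_UPDATE_NAME, self._do_example_update,
        )

    @defer.inlineCallbacks
    def _do_example_update(self, progress, batch_size):
        # ... do up to batch_size units of work, persisting progress ...
        yield self._end_background_update(self.EXAMPLE_UPDATE_NAME)
        defer.returnValue(1)
```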
diff --git a/synapse/storage/schema/delta/47/push_actions_staging.sql b/synapse/storage/schema/delta/47/push_actions_staging.sql
new file mode 100644
index 0000000000..edccf4a96f
--- /dev/null
+++ b/synapse/storage/schema/delta/47/push_actions_staging.sql
@@ -0,0 +1,28 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Temporary staging area for push actions that have been calculated for an
+-- event, but the event hasn't yet been persisted.
+-- When the event is persisted, the rows are moved over to the
+-- event_push_actions table.
+CREATE TABLE event_push_actions_staging (
+ event_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ actions TEXT NOT NULL,
+ notif SMALLINT NOT NULL,
+ highlight SMALLINT NOT NULL
+);
+
+CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging(event_id);
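The move described in the comment can be pictured as an INSERT ... SELECT plus a DELETE inside the event-persistence transaction. The sketch below is illustrative only: the column list on `event_push_actions` is assumed, and the real store code is not part of this diff:

```python
def _move_push_actions_from_staging_txn(txn, event_id, room_id):
    # hypothetical: copy the staged rows over, filling in the room_id,
    # which the staging table deliberately doesn't carry
    txn.execute(
        "INSERT INTO event_push_actions"
        " (room_id, event_id, user_id, actions, notif, highlight)"
        " SELECT ?, event_id, user_id, actions, notif, highlight"
        " FROM event_push_actions_staging WHERE event_id = ?",
        (room_id, event_id),
    )
    # the staging rows have served their purpose once copied
    txn.execute(
        "DELETE FROM event_push_actions_staging WHERE event_id = ?",
        (event_id,),
    )
```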
diff --git a/synapse/storage/schema/delta/47/state_group_seq.py b/synapse/storage/schema/delta/47/state_group_seq.py
new file mode 100644
index 0000000000..f6766501d2
--- /dev/null
+++ b/synapse/storage/schema/delta/47/state_group_seq.py
@@ -0,0 +1,37 @@
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.engines import PostgresEngine
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ if isinstance(database_engine, PostgresEngine):
+ # if we already have some state groups, we want to start making new
+ # ones with a higher id.
+ cur.execute("SELECT max(id) FROM state_groups")
+ row = cur.fetchone()
+
+ if row[0] is None:
+ start_val = 1
+ else:
+ start_val = row[0] + 1
+
+ cur.execute(
+ "CREATE SEQUENCE state_group_id_seq START WITH %s",
+ (start_val, ),
+ )
+
+
+def run_upgrade(*args, **kwargs):
+ pass
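The sequence created here is presumably what `get_next_state_group_id` (called by `store_state_group` later in this diff) draws from on postgres. A plausible sketch, assuming the engine method is a thin wrapper around `nextval`:

```python
def get_next_state_group_id(self, txn):
    """Allocate a fresh state group id (postgres variant, sketch only)."""
    # nextval() is atomic across connections, so two concurrent writers
    # can never be handed the same id
    txn.execute("SELECT nextval('state_group_id_seq')")
    return txn.fetchone()[0]
```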
diff --git a/synapse/storage/schema/delta/48/add_user_consent.sql b/synapse/storage/schema/delta/48/add_user_consent.sql
new file mode 100644
index 0000000000..5237491506
--- /dev/null
+++ b/synapse/storage/schema/delta/48/add_user_consent.sql
@@ -0,0 +1,18 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* record the version of the privacy policy the user has consented to
+ */
+ALTER TABLE users ADD COLUMN consent_version TEXT;
diff --git a/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql b/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql
new file mode 100644
index 0000000000..9248b0b24a
--- /dev/null
+++ b/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT into background_updates (update_name, progress_json)
+ VALUES ('user_ips_last_seen_index', '{}');
diff --git a/synapse/storage/schema/delta/48/deactivated_users.sql b/synapse/storage/schema/delta/48/deactivated_users.sql
new file mode 100644
index 0000000000..e9013a6969
--- /dev/null
+++ b/synapse/storage/schema/delta/48/deactivated_users.sql
@@ -0,0 +1,25 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Store any accounts that have been requested to be deactivated.
+ * We part the account from all the rooms it's in when it's
+ * deactivated. This can take some time and synapse may be restarted
+ * before it completes, so store the user IDs here until the process
+ * is complete.
+ */
+CREATE TABLE users_pending_deactivation (
+ user_id TEXT NOT NULL
+);
diff --git a/synapse/storage/schema/delta/48/group_unique_indexes.py b/synapse/storage/schema/delta/48/group_unique_indexes.py
new file mode 100644
index 0000000000..2233af87d7
--- /dev/null
+++ b/synapse/storage/schema/delta/48/group_unique_indexes.py
@@ -0,0 +1,57 @@
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.prepare_database import get_statements
+
+FIX_INDEXES = """
+-- rebuild indexes as uniques
+DROP INDEX groups_invites_g_idx;
+CREATE UNIQUE INDEX group_invites_g_idx ON group_invites(group_id, user_id);
+DROP INDEX groups_users_g_idx;
+CREATE UNIQUE INDEX group_users_g_idx ON group_users(group_id, user_id);
+
+-- rename other indexes to actually match their table names..
+DROP INDEX groups_users_u_idx;
+CREATE INDEX group_users_u_idx ON group_users(user_id);
+DROP INDEX groups_invites_u_idx;
+CREATE INDEX group_invites_u_idx ON group_invites(user_id);
+DROP INDEX groups_rooms_g_idx;
+CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id);
+DROP INDEX groups_rooms_r_idx;
+CREATE INDEX group_rooms_r_idx ON group_rooms(room_id);
+"""
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid"
+
+ # remove duplicates from group_users & group_invites tables
+ cur.execute("""
+ DELETE FROM group_users WHERE %s NOT IN (
+ SELECT min(%s) FROM group_users GROUP BY group_id, user_id
+ );
+ """ % (rowid, rowid))
+ cur.execute("""
+ DELETE FROM group_invites WHERE %s NOT IN (
+ SELECT min(%s) FROM group_invites GROUP BY group_id, user_id
+ );
+ """ % (rowid, rowid))
+
+ for statement in get_statements(FIX_INDEXES.splitlines()):
+ cur.execute(statement)
+
+
+def run_upgrade(*args, **kwargs):
+ pass
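The `min(ctid)`/`min(rowid)` trick in `run_create` is the standard way to dedupe before adding a unique index: it keeps exactly one row per `(group_id, user_id)` bucket and deletes the rest. A toy reproduction with made-up data (sqlite):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE group_users (group_id TEXT, user_id TEXT)")
cur.executemany(
    "INSERT INTO group_users VALUES (?, ?)",
    [("+g:x", "@a:x"), ("+g:x", "@a:x"), ("+g:x", "@b:x")],  # one dupe
)

# keep only the row with the smallest rowid in each (group_id, user_id)
# bucket; every later duplicate gets deleted
cur.execute("""
    DELETE FROM group_users WHERE rowid NOT IN (
        SELECT min(rowid) FROM group_users GROUP BY group_id, user_id
    )
""")
cur.execute(
    "CREATE UNIQUE INDEX group_users_g_idx ON group_users(group_id, user_id)"
)
print(cur.execute("SELECT count(*) FROM group_users").fetchone())  # (2,)
```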
diff --git a/synapse/storage/schema/delta/48/groups_joinable.sql b/synapse/storage/schema/delta/48/groups_joinable.sql
new file mode 100644
index 0000000000..ce26eaf0c9
--- /dev/null
+++ b/synapse/storage/schema/delta/48/groups_joinable.sql
@@ -0,0 +1,22 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * This isn't a real ENUM because sqlite doesn't support them.
+ * We use a default of NULL for inserted rows, and interpret
+ * NULL at the python store level as necessary, so that existing
+ * rows are given the correct default policy.
+ */
+ALTER TABLE groups ADD COLUMN join_policy TEXT NOT NULL DEFAULT 'invite';
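A hedged sketch of what "interpret NULL at the python store level" might look like on the read path, following the comment above; the function and row shape are illustrative, not the real store code:

```python
def _get_join_policy(row):
    # rows written before this delta ran may have join_policy = NULL;
    # treat that the same as the default applied to new rows
    return row.get("join_policy") or "invite"
```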
diff --git a/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql b/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql
new file mode 100644
index 0000000000..14dcf18d73
--- /dev/null
+++ b/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql
@@ -0,0 +1,20 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* record whether we have sent a server notice about consenting to the
+ * privacy policy. Specifically records the version of the policy we sent
+ * a message about.
+ */
+ALTER TABLE users ADD COLUMN consent_server_notice_sent TEXT;
diff --git a/synapse/storage/schema/delta/49/add_user_daily_visits.sql b/synapse/storage/schema/delta/49/add_user_daily_visits.sql
new file mode 100644
index 0000000000..3dd478196f
--- /dev/null
+++ b/synapse/storage/schema/delta/49/add_user_daily_visits.sql
@@ -0,0 +1,21 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE user_daily_visits (
+    user_id TEXT NOT NULL,
+    device_id TEXT,
+    timestamp BIGINT NOT NULL
+);
+CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp);
+CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp);
diff --git a/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql b/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql
new file mode 100644
index 0000000000..3a4ed59b5b
--- /dev/null
+++ b/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT into background_updates (update_name, progress_json)
+ VALUES ('user_ips_last_seen_only_index', '{}');
diff --git a/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql b/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql
new file mode 100644
index 0000000000..c93ae47532
--- /dev/null
+++ b/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql
@@ -0,0 +1,19 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+INSERT into background_updates (update_name, progress_json)
+ VALUES ('users_creation_ts', '{}');
diff --git a/synapse/storage/schema/delta/50/erasure_store.sql b/synapse/storage/schema/delta/50/erasure_store.sql
new file mode 100644
index 0000000000..5d8641a9ab
--- /dev/null
+++ b/synapse/storage/schema/delta/50/erasure_store.sql
@@ -0,0 +1,21 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- a table of users who have requested that their details be erased
+CREATE TABLE erased_users (
+ user_id TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX erased_users_user ON erased_users(user_id);
diff --git a/synapse/storage/schema/schema_version.sql b/synapse/storage/schema/schema_version.sql
index a7ade69986..42e5cb6df5 100644
--- a/synapse/storage/schema/schema_version.sql
+++ b/synapse/storage/schema/schema_version.sql
@@ -25,3 +25,10 @@ CREATE TABLE IF NOT EXISTS applied_schema_deltas(
file TEXT NOT NULL,
UNIQUE(version, file)
);
+
+-- a list of schema files we have loaded on behalf of dynamic modules
+CREATE TABLE IF NOT EXISTS applied_module_schemas(
+ module_name TEXT NOT NULL,
+ file TEXT NOT NULL,
+ UNIQUE(module_name, file)
+);
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index 8f2b3c4435..d5b5df93e6 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -13,28 +13,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import re
+from collections import namedtuple
+
+from six import string_types
+
+from canonicaljson import json
+
from twisted.internet import defer
-from .background_updates import BackgroundUpdateStore
from synapse.api.errors import SynapseError
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-import logging
-import re
-import ujson as json
-
+from .background_updates import BackgroundUpdateStore
logger = logging.getLogger(__name__)
+SearchEntry = namedtuple('SearchEntry', [
+ 'key', 'value', 'event_id', 'room_id', 'stream_ordering',
+ 'origin_server_ts',
+])
+
class SearchStore(BackgroundUpdateStore):
EVENT_SEARCH_UPDATE_NAME = "event_search"
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
+ EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
- def __init__(self, hs):
- super(SearchStore, self).__init__(hs)
+ def __init__(self, db_conn, hs):
+ super(SearchStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
)
@@ -42,23 +52,35 @@ class SearchStore(BackgroundUpdateStore):
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
self._background_reindex_search_order
)
- self.register_background_update_handler(
+
+ # we used to have a background update to turn the GIN index into a
+ # GIST one; we no longer do that (obviously) because we actually want
+ # a GIN index. However, it's possible that some people might still have
+ # the background update queued, so we register a handler to clear the
+ # background update.
+ self.register_noop_background_update(
self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME,
- self._background_reindex_gist_search
+ )
+
+ self.register_background_update_handler(
+ self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME,
+ self._background_reindex_gin_search
)
@defer.inlineCallbacks
def _background_reindex_search(self, progress, batch_size):
+ # we work through the events table from highest stream id to lowest
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
- INSERT_CLUMP_SIZE = 1000
TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
def reindex_search_txn(txn):
sql = (
- "SELECT stream_ordering, event_id, room_id, type, content FROM events"
+ "SELECT stream_ordering, event_id, room_id, type, json, "
+ " origin_server_ts FROM events"
+ " JOIN event_json USING (room_id, event_id)"
" WHERE ? <= stream_ordering AND stream_ordering < ?"
" AND (%s)"
" ORDER BY stream_ordering DESC"
@@ -67,6 +89,10 @@ class SearchStore(BackgroundUpdateStore):
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
+ # we could stream straight from the results into
+ # store_search_entries_txn with a generator function, but that
+ # would mean having two cursors open on the database at once.
+ # Instead we just build a list of results.
rows = self.cursor_to_dict(txn)
if not rows:
return 0
@@ -79,9 +105,12 @@ class SearchStore(BackgroundUpdateStore):
event_id = row["event_id"]
room_id = row["room_id"]
etype = row["type"]
+ stream_ordering = row["stream_ordering"]
+ origin_server_ts = row["origin_server_ts"]
try:
- content = json.loads(row["content"])
- except:
+ event_json = json.loads(row["json"])
+ content = event_json["content"]
+ except Exception:
continue
if etype == "m.room.message":
@@ -93,35 +122,28 @@ class SearchStore(BackgroundUpdateStore):
elif etype == "m.room.name":
key = "content.name"
value = content["name"]
+ else:
+ raise Exception("unexpected event type %s" % etype)
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
continue
- if not isinstance(value, basestring):
+ if not isinstance(value, string_types):
# If the event body, name or topic isn't a string
# then skip over it
continue
- event_search_rows.append((event_id, room_id, key, value))
-
- if isinstance(self.database_engine, PostgresEngine):
- sql = (
- "INSERT INTO event_search (event_id, room_id, key, vector)"
- " VALUES (?,?,?,to_tsvector('english', ?))"
- )
- elif isinstance(self.database_engine, Sqlite3Engine):
- sql = (
- "INSERT INTO event_search (event_id, room_id, key, value)"
- " VALUES (?,?,?,?)"
- )
- else:
- # This should be unreachable.
- raise Exception("Unrecognized database engine")
+ event_search_rows.append(SearchEntry(
+ key=key,
+ value=value,
+ event_id=event_id,
+ room_id=room_id,
+ stream_ordering=stream_ordering,
+ origin_server_ts=origin_server_ts,
+ ))
- for index in range(0, len(event_search_rows), INSERT_CLUMP_SIZE):
- clump = event_search_rows[index:index + INSERT_CLUMP_SIZE]
- txn.executemany(sql, clump)
+ self.store_search_entries_txn(txn, event_search_rows)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
@@ -145,25 +167,48 @@ class SearchStore(BackgroundUpdateStore):
defer.returnValue(result)
@defer.inlineCallbacks
- def _background_reindex_gist_search(self, progress, batch_size):
+ def _background_reindex_gin_search(self, progress, batch_size):
+        """This handles old synapses which used GIST indexes: if any exist,
+        it converts them back to GIN, as per the actual schema.
+ """
+
def create_index(conn):
conn.rollback()
- conn.set_session(autocommit=True)
- c = conn.cursor()
- c.execute(
- "CREATE INDEX CONCURRENTLY event_search_fts_idx_gist"
- " ON event_search USING GIST (vector)"
- )
+ # we have to set autocommit, because postgres refuses to
+ # CREATE INDEX CONCURRENTLY without it.
+ conn.set_session(autocommit=True)
- c.execute("DROP INDEX event_search_fts_idx")
+ try:
+ c = conn.cursor()
- conn.set_session(autocommit=False)
+ # if we skipped the conversion to GIST, we may already/still
+ # have an event_search_fts_idx; unfortunately postgres 9.4
+                # doesn't support CREATE INDEX IF NOT EXISTS, so we just
+                # catch the exception and ignore it.
+ import psycopg2
+ try:
+ c.execute(
+ "CREATE INDEX CONCURRENTLY event_search_fts_idx"
+ " ON event_search USING GIN (vector)"
+ )
+ except psycopg2.ProgrammingError as e:
+ logger.warn(
+ "Ignoring error %r when trying to switch from GIST to GIN",
+ e
+ )
+
+ # we should now be able to delete the GIST index.
+ c.execute(
+ "DROP INDEX IF EXISTS event_search_fts_idx_gist"
+ )
+ finally:
+ conn.set_session(autocommit=False)
if isinstance(self.database_engine, PostgresEngine):
yield self.runWithConnection(create_index)
- yield self._end_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME)
+ yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME)
defer.returnValue(1)
@defer.inlineCallbacks
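The autocommit dance in `create_index` exists because postgres refuses to run `CREATE INDEX CONCURRENTLY` inside a transaction block, and psycopg2 implicitly opens one on the first `execute()`. A standalone illustration (the DSN, index and table names are placeholders):

```python
import psycopg2

conn = psycopg2.connect("dbname=synapse")  # placeholder DSN

# without this, postgres aborts with "CREATE INDEX CONCURRENTLY cannot
# run inside a transaction block"
conn.set_session(autocommit=True)
try:
    cur = conn.cursor()
    cur.execute(
        "CREATE INDEX CONCURRENTLY example_idx ON example_table (col)"
    )
finally:
    conn.set_session(autocommit=False)
```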
@@ -242,6 +287,85 @@ class SearchStore(BackgroundUpdateStore):
defer.returnValue(num_rows)
+ def store_event_search_txn(self, txn, event, key, value):
+ """Add event to the search table
+
+ Args:
+ txn (cursor):
+ event (EventBase):
+ key (str):
+ value (str):
+ """
+ self.store_search_entries_txn(
+ txn,
+ (SearchEntry(
+ key=key,
+ value=value,
+ event_id=event.event_id,
+ room_id=event.room_id,
+ stream_ordering=event.internal_metadata.stream_ordering,
+ origin_server_ts=event.origin_server_ts,
+ ),),
+ )
+
+ def store_search_entries_txn(self, txn, entries):
+ """Add entries to the search table
+
+ Args:
+ txn (cursor):
+ entries (iterable[SearchEntry]):
+ entries to be added to the table
+ """
+ if isinstance(self.database_engine, PostgresEngine):
+ sql = (
+ "INSERT INTO event_search"
+ " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
+ " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
+ )
+
+ args = ((
+ entry.event_id, entry.room_id, entry.key, entry.value,
+ entry.stream_ordering, entry.origin_server_ts,
+ ) for entry in entries)
+
+ # inserts to a GIN index are normally batched up into a pending
+ # list, and then all committed together once the list gets to a
+ # certain size. The trouble with that is that postgres (pre-9.5)
+ # uses work_mem to determine the length of the list, and work_mem
+ # is typically very large.
+ #
+ # We therefore reduce work_mem while we do the insert.
+ #
+ # (postgres 9.5 uses the separate gin_pending_list_limit setting,
+ # so doesn't suffer the same problem, but changing work_mem will
+ # be harmless)
+ #
+ # Note that we don't need to worry about restoring it on
+ # exception, because exceptions will cause the transaction to be
+ # rolled back, including the effects of the SET command.
+ #
+ # Also: we use SET rather than SET LOCAL because there's lots of
+            # other stuff going on in this transaction, which wants to have
+            # the normal work_mem setting.
+
+ txn.execute("SET work_mem='256kB'")
+ txn.executemany(sql, args)
+ txn.execute("RESET work_mem")
+
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ sql = (
+ "INSERT INTO event_search (event_id, room_id, key, value)"
+ " VALUES (?,?,?,?)"
+ )
+ args = ((
+ entry.event_id, entry.room_id, entry.key, entry.value,
+ ) for entry in entries)
+
+ txn.executemany(sql, args)
+ else:
+ # This should be unreachable.
+ raise Exception("Unrecognized database engine")
+
@defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.
@@ -326,7 +450,7 @@ class SearchStore(BackgroundUpdateStore):
"search_msgs", self.cursor_to_dict, sql, *args
)
- results = filter(lambda row: row["room_id"] in room_ids, results)
+ results = list(filter(lambda row: row["room_id"] in room_ids, results))
events = yield self._get_events([r["event_id"] for r in results])
@@ -407,7 +531,7 @@ class SearchStore(BackgroundUpdateStore):
origin_server_ts, stream = pagination_token.split(",")
origin_server_ts = int(origin_server_ts)
stream = int(stream)
- except:
+ except Exception:
raise SynapseError(400, "Invalid pagination token")
clauses.append(
@@ -481,7 +605,7 @@ class SearchStore(BackgroundUpdateStore):
"search_rooms", self.cursor_to_dict, sql, *args
)
- results = filter(lambda row: row["room_id"] in room_ids, results)
+ results = list(filter(lambda row: row["room_id"] in room_ids, results))
events = yield self._get_events([r["event_id"] for r in results])
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index 67d5d9969a..470212aa2a 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -13,21 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
-from ._base import SQLBaseStore
+import six
from unpaddedbase64 import encode_base64
+
+from twisted.internet import defer
+
from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.util.caches.descriptors import cached, cachedList
+from ._base import SQLBaseStore
+
+# py2 sqlite has buffer hardcoded as its only binary type, so we must use it,
+# despite it being deprecated and removed in favor of memoryview.
+if six.PY2:
+ db_binary_type = buffer
+else:
+ db_binary_type = memoryview
-class SignatureStore(SQLBaseStore):
- """Persistence for event signatures and hashes"""
+class SignatureWorkerStore(SQLBaseStore):
@cached()
def get_event_reference_hash(self, event_id):
- return self._get_event_reference_hashes_txn(event_id)
+ # This is a dummy function to allow get_event_reference_hashes
+ # to use its cache
+ raise NotImplementedError()
@cachedList(cached_method_name="get_event_reference_hash",
list_name="event_ids", num_args=1)
@@ -56,7 +66,7 @@ class SignatureStore(SQLBaseStore):
for e_id, h in hashes.items()
}
- defer.returnValue(hashes.items())
+ defer.returnValue(list(hashes.items()))
def _get_event_reference_hashes_txn(self, txn, event_id):
"""Get all the hashes for a given PDU.
@@ -74,6 +84,10 @@ class SignatureStore(SQLBaseStore):
txn.execute(query, (event_id, ))
return {k: v for k, v in txn}
+
+class SignatureStore(SignatureWorkerStore):
+ """Persistence for event signatures and hashes"""
+
def _store_event_reference_hashes_txn(self, txn, events):
"""Store a hash for a PDU
Args:
@@ -87,7 +101,7 @@ class SignatureStore(SQLBaseStore):
vals.append({
"event_id": event.event_id,
"algorithm": ref_alg,
- "hash": buffer(ref_hash_bytes),
+ "hash": db_binary_type(ref_hash_bytes),
})
self._simple_insert_many_txn(
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 5673e4aa96..89a05c4618 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -13,16 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached, cachedList
-from synapse.util.caches import intern_string
-from synapse.util.stringutils import to_ascii
-from synapse.storage.engines import PostgresEngine
+import logging
+from collections import namedtuple
+
+from six import iteritems, itervalues
+from six.moves import range
from twisted.internet import defer
-from collections import namedtuple
-import logging
+from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage.engines import PostgresEngine
+from synapse.util.caches import get_cache_factor_for, intern_string
+from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.caches.dictionary_cache import DictionaryCache
+from synapse.util.stringutils import to_ascii
+
+from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
@@ -40,45 +46,19 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
return len(self.delta_ids) if self.delta_ids else 0
-class StateStore(SQLBaseStore):
- """ Keeps track of the state at a given event.
-
- This is done by the concept of `state groups`. Every event is a assigned
- a state group (identified by an arbitrary string), which references a
- collection of state events. The current state of an event is then the
- collection of state events referenced by the event's state group.
-
- Hence, every change in the current state causes a new state group to be
- generated. However, if no change happens (e.g., if we get a message event
- with only one parent it inherits the state group from its parent.)
-
- There are three tables:
- * `state_groups`: Stores group name, first event with in the group and
- room id.
- * `event_to_state_groups`: Maps events to state groups.
- * `state_groups_state`: Maps state group to state events.
+class StateGroupWorkerStore(SQLBaseStore):
+ """The parts of StateGroupStore that can be called from workers.
"""
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
- def __init__(self, hs):
- super(StateStore, self).__init__(hs)
- self.register_background_update_handler(
- self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
- self._background_deduplicate_state,
- )
- self.register_background_update_handler(
- self.STATE_GROUP_INDEX_UPDATE_NAME,
- self._background_index_state,
- )
- self.register_background_index_update(
- self.CURRENT_STATE_INDEX_UPDATE_NAME,
- index_name="current_state_events_member_index",
- table="current_state_events",
- columns=["state_key"],
- where_clause="type='m.room.member'",
+ def __init__(self, db_conn, hs):
+ super(StateGroupWorkerStore, self).__init__(db_conn, hs)
+
+ self._state_group_cache = DictionaryCache(
+ "*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache")
)
@cached(max_entries=100000, iterable=True)
@@ -158,12 +138,26 @@ class StateStore(SQLBaseStore):
event_ids,
)
- groups = set(event_to_groups.itervalues())
+ groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups)
defer.returnValue(group_to_state)
@defer.inlineCallbacks
+ def get_state_ids_for_group(self, state_group):
+ """Get the state IDs for the given state group
+
+ Args:
+ state_group (int)
+
+ Returns:
+ Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
+ """
+ group_to_state = yield self._get_state_for_groups((state_group,))
+
+ defer.returnValue(group_to_state[state_group])
+
+ @defer.inlineCallbacks
def get_state_groups(self, room_id, event_ids):
""" Get the state groups for the given list of event_ids
@@ -176,199 +170,27 @@ class StateStore(SQLBaseStore):
state_event_map = yield self.get_events(
[
- ev_id for group_ids in group_to_ids.itervalues()
- for ev_id in group_ids.itervalues()
+ ev_id for group_ids in itervalues(group_to_ids)
+ for ev_id in itervalues(group_ids)
],
get_prev_content=False
)
defer.returnValue({
group: [
- state_event_map[v] for v in event_id_map.itervalues()
+ state_event_map[v] for v in itervalues(event_id_map)
if v in state_event_map
]
- for group, event_id_map in group_to_ids.iteritems()
+ for group, event_id_map in iteritems(group_to_ids)
})
- def _have_persisted_state_group_txn(self, txn, state_group):
- txn.execute(
- "SELECT count(*) FROM state_groups WHERE id = ?",
- (state_group,)
- )
- row = txn.fetchone()
- return row and row[0]
-
- def _store_mult_state_groups_txn(self, txn, events_and_contexts):
- state_groups = {}
- for event, context in events_and_contexts:
- if event.internal_metadata.is_outlier():
- continue
-
- if context.current_state_ids is None:
- # AFAIK, this can never happen
- logger.error(
- "Non-outlier event %s had current_state_ids==None",
- event.event_id)
- continue
-
- # if the event was rejected, just give it the same state as its
- # predecessor.
- if context.rejected:
- state_groups[event.event_id] = context.prev_group
- continue
-
- state_groups[event.event_id] = context.state_group
-
- if self._have_persisted_state_group_txn(txn, context.state_group):
- continue
-
- self._simple_insert_txn(
- txn,
- table="state_groups",
- values={
- "id": context.state_group,
- "room_id": event.room_id,
- "event_id": event.event_id,
- },
- )
-
- # We persist as a delta if we can, while also ensuring the chain
- # of deltas isn't tooo long, as otherwise read performance degrades.
- if context.prev_group:
- is_in_db = self._simple_select_one_onecol_txn(
- txn,
- table="state_groups",
- keyvalues={"id": context.prev_group},
- retcol="id",
- allow_none=True,
- )
- if not is_in_db:
- raise Exception(
- "Trying to persist state with unpersisted prev_group: %r"
- % (context.prev_group,)
- )
-
- potential_hops = self._count_state_group_hops_txn(
- txn, context.prev_group
- )
- if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
- self._simple_insert_txn(
- txn,
- table="state_group_edges",
- values={
- "state_group": context.state_group,
- "prev_state_group": context.prev_group,
- },
- )
-
- self._simple_insert_many_txn(
- txn,
- table="state_groups_state",
- values=[
- {
- "state_group": context.state_group,
- "room_id": event.room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
- for key, state_id in context.delta_ids.iteritems()
- ],
- )
- else:
- self._simple_insert_many_txn(
- txn,
- table="state_groups_state",
- values=[
- {
- "state_group": context.state_group,
- "room_id": event.room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
- for key, state_id in context.current_state_ids.iteritems()
- ],
- )
-
- # Prefill the state group cache with this group.
- # It's fine to use the sequence like this as the state group map
- # is immutable. (If the map wasn't immutable then this prefill could
- # race with another update)
- txn.call_after(
- self._state_group_cache.update,
- self._state_group_cache.sequence,
- key=context.state_group,
- value=dict(context.current_state_ids),
- full=True,
- )
-
- self._simple_insert_many_txn(
- txn,
- table="event_to_state_groups",
- values=[
- {
- "state_group": state_group_id,
- "event_id": event_id,
- }
- for event_id, state_group_id in state_groups.iteritems()
- ],
- )
-
- for event_id, state_group_id in state_groups.iteritems():
- txn.call_after(
- self._get_state_group_for_event.prefill,
- (event_id,), state_group_id
- )
-
- def _count_state_group_hops_txn(self, txn, state_group):
- """Given a state group, count how many hops there are in the tree.
-
- This is used to ensure the delta chains don't get too long.
- """
- if isinstance(self.database_engine, PostgresEngine):
- sql = ("""
- WITH RECURSIVE state(state_group) AS (
- VALUES(?::bigint)
- UNION ALL
- SELECT prev_state_group FROM state_group_edges e, state s
- WHERE s.state_group = e.state_group
- )
- SELECT count(*) FROM state;
- """)
-
- txn.execute(sql, (state_group,))
- row = txn.fetchone()
- if row and row[0]:
- return row[0]
- else:
- return 0
- else:
- # We don't use WITH RECURSIVE on sqlite3 as there are distributions
- # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
- next_group = state_group
- count = 0
-
- while next_group:
- next_group = self._simple_select_one_onecol_txn(
- txn,
- table="state_group_edges",
- keyvalues={"state_group": next_group},
- retcol="prev_state_group",
- allow_none=True,
- )
- if next_group:
- count += 1
-
- return count
-
@defer.inlineCallbacks
def _get_state_groups_from_groups(self, groups, types):
"""Returns dictionary state_group -> (dict of (type, state_key) -> event id)
"""
results = {}
- chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)]
+ chunks = [groups[i:i + 100] for i in range(0, len(groups), 100)]
for chunk in chunks:
res = yield self.runInteraction(
"_get_state_groups_from_groups",
@@ -422,6 +244,9 @@ class StateStore(SQLBaseStore):
(
"AND type = ? AND state_key = ?",
(etype, state_key)
+ ) if state_key is not None else (
+ "AND type = ?",
+ (etype,)
)
for etype, state_key in types
]
@@ -441,10 +266,19 @@ class StateStore(SQLBaseStore):
key = (typ, state_key)
results[group][key] = event_id
else:
+ where_args = []
+ where_clauses = []
+ wildcard_types = False
if types is not None:
- where_clause = "AND (%s)" % (
- " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
- )
+ for typ in types:
+ if typ[1] is None:
+ where_clauses.append("(type = ?)")
+ where_args.append(typ[0])
+ wildcard_types = True
+ else:
+ where_clauses.append("(type = ? AND state_key = ?)")
+ where_args.extend([typ[0], typ[1]])
+ where_clause = "AND (%s)" % (" OR ".join(where_clauses))
else:
where_clause = ""
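To make the clause-building concrete, a types filter mixing a wildcard and an exact key would come out of the loop above roughly as follows (the values are illustrative):

```python
types = [("m.room.member", None), ("m.room.name", "")]

# what the loop above would build:
where_clause = "AND ((type = ?) OR (type = ? AND state_key = ?))"
where_args = ["m.room.member", "m.room.name", ""]

# wildcard_types is True here, so the early-exit optimisation further
# down is disabled and the whole delta chain gets walked
```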
@@ -461,7 +295,7 @@ class StateStore(SQLBaseStore):
# after we finish deduping state, which requires this func)
args = [next_group]
if types:
- args.extend(i for typ in types for i in typ)
+ args.extend(where_args)
txn.execute(
"SELECT type, state_key, event_id FROM state_groups_state"
@@ -474,9 +308,17 @@ class StateStore(SQLBaseStore):
if (typ, state_key) not in results[group]
)
- # If the lengths match then we must have all the types,
- # so no need to go walk further down the tree.
- if types is not None and len(results[group]) == len(types):
+            # If the number of entries in the (type,state_key)->event_id dict
+            # matches the number of (type, state_key) pairs we were searching
+            # for, then we must have found them all, so there is no need to
+            # walk further down the tree... UNLESS our types filter contained
+            # wildcards (i.e. Nones), in which case we have to do an
+            # exhaustive search.
+ if (
+ types is not None and
+ not wildcard_types and
+ len(results[group]) == len(types)
+ ):
break
next_group = self._simple_select_one_onecol_txn(
@@ -509,21 +351,21 @@ class StateStore(SQLBaseStore):
event_ids,
)
- groups = set(event_to_groups.itervalues())
+ groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups, types)
state_event_map = yield self.get_events(
- [ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()],
+ [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
get_prev_content=False
)
event_to_state = {
event_id: {
k: state_event_map[v]
- for k, v in group_to_state[group].iteritems()
+ for k, v in iteritems(group_to_state[group])
if v in state_event_map
}
- for event_id, group in event_to_groups.iteritems()
+ for event_id, group in iteritems(event_to_groups)
}
defer.returnValue({event: event_to_state[event] for event in event_ids})
@@ -546,12 +388,12 @@ class StateStore(SQLBaseStore):
event_ids,
)
- groups = set(event_to_groups.itervalues())
+ groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups, types)
event_to_state = {
event_id: group_to_state[group]
- for event_id, group in event_to_groups.iteritems()
+ for event_id, group in iteritems(event_to_groups)
}
defer.returnValue({event: event_to_state[event] for event in event_ids})
@@ -665,7 +507,7 @@ class StateStore(SQLBaseStore):
got_all = is_all or not missing_types
return {
- k: v for k, v in state_dict_ids.iteritems()
+ k: v for k, v in iteritems(state_dict_ids)
if include(k[0], k[1])
}, missing_types, got_all
@@ -685,10 +527,23 @@ class StateStore(SQLBaseStore):
@defer.inlineCallbacks
def _get_state_for_groups(self, groups, types=None):
- """Given list of groups returns dict of group -> list of state events
- with matching types. `types` is a list of `(type, state_key)`, where
- a `state_key` of None matches all state_keys. If `types` is None then
- all events are returned.
+ """Gets the state at each of a list of state groups, optionally
+ filtering by type/state_key
+
+ Args:
+ groups (iterable[int]): list of state groups for which we want
+ to get the state.
+ types (None|iterable[(str, None|str)]):
+ indicates the state type/keys required. If None, the whole
+ state is fetched and returned.
+
+ Otherwise, each entry should be a `(type, state_key)` tuple to
+ include in the response. A `state_key` of None is a wildcard
+ meaning that we require all state with that type.
+
+ Returns:
+ Deferred[dict[int, dict[(type, state_key), EventBase]]]
+ a dictionary mapping from state group to state dictionary.
"""
if types:
types = frozenset(types)
@@ -697,7 +552,7 @@ class StateStore(SQLBaseStore):
if types is not None:
for group in set(groups):
state_dict_ids, _, got_all = self._get_some_state_from_cache(
- group, types
+ group, types,
)
results[group] = state_dict_ids
@@ -718,32 +573,266 @@ class StateStore(SQLBaseStore):
# Okay, so we have some missing_types, lets fetch them.
cache_seq_num = self._state_group_cache.sequence
+ # the DictionaryCache knows if it has *all* the state, but
+ # does not know if it has all of the keys of a particular type,
+ # which makes wildcard lookups expensive unless we have a complete
+ # cache. Hence, if we are doing a wildcard lookup, populate the
+ # cache fully so that we can do an efficient lookup next time.
+
+ if types and any(k is None for (t, k) in types):
+ types_to_fetch = None
+ else:
+ types_to_fetch = types
+
group_to_state_dict = yield self._get_state_groups_from_groups(
- missing_groups, types
+ missing_groups, types_to_fetch,
)
- # Now we want to update the cache with all the things we fetched
- # from the database.
- for group, group_state_dict in group_to_state_dict.iteritems():
+ for group, group_state_dict in iteritems(group_to_state_dict):
state_dict = results[group]
- state_dict.update(
- ((intern_string(k[0]), intern_string(k[1])), to_ascii(v))
- for k, v in group_state_dict.iteritems()
- )
-
+ # update the result, filtering by `types`.
+ if types:
+ for k, v in iteritems(group_state_dict):
+ (typ, _) = k
+ if k in types or (typ, None) in types:
+ state_dict[k] = v
+ else:
+ state_dict.update(group_state_dict)
+
+ # update the cache with all the things we fetched from the
+ # database.
self._state_group_cache.update(
cache_seq_num,
key=group,
- value=state_dict,
- full=(types is None),
- known_absent=types,
+ value=group_state_dict,
+ fetched_keys=types_to_fetch,
)
defer.returnValue(results)
- def get_next_state_group(self):
- return self._state_groups_id_gen.get_next()
+ def store_state_group(self, event_id, room_id, prev_group, delta_ids,
+ current_state_ids):
+ """Store a new set of state, returning a newly assigned state group.
+
+ Args:
+ event_id (str): The event ID for which the state was calculated
+ room_id (str)
+ prev_group (int|None): A previous state group for the room, optional.
+ delta_ids (dict|None): The delta between state at `prev_group` and
+ `current_state_ids`, if `prev_group` was given. Same format as
+ `current_state_ids`.
+ current_state_ids (dict): The state to store. Map of (type, state_key)
+ to event_id.
+
+ Returns:
+ Deferred[int]: The state group ID
+ """
+ def _store_state_group_txn(txn):
+ if current_state_ids is None:
+ # AFAIK, this can never happen
+ raise Exception("current_state_ids cannot be None")
+
+ state_group = self.database_engine.get_next_state_group_id(txn)
+
+ self._simple_insert_txn(
+ txn,
+ table="state_groups",
+ values={
+ "id": state_group,
+ "room_id": room_id,
+ "event_id": event_id,
+ },
+ )
+
+ # We persist as a delta if we can, while also ensuring the chain
+            # of deltas isn't too long, as otherwise read performance degrades.
+ if prev_group:
+ is_in_db = self._simple_select_one_onecol_txn(
+ txn,
+ table="state_groups",
+ keyvalues={"id": prev_group},
+ retcol="id",
+ allow_none=True,
+ )
+ if not is_in_db:
+ raise Exception(
+ "Trying to persist state with unpersisted prev_group: %r"
+ % (prev_group,)
+ )
+
+ potential_hops = self._count_state_group_hops_txn(
+ txn, prev_group
+ )
+ if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
+ self._simple_insert_txn(
+ txn,
+ table="state_group_edges",
+ values={
+ "state_group": state_group,
+ "prev_state_group": prev_group,
+ },
+ )
+
+ self._simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ values=[
+ {
+ "state_group": state_group,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
+ }
+ for key, state_id in iteritems(delta_ids)
+ ],
+ )
+ else:
+ self._simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ values=[
+ {
+ "state_group": state_group,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
+ }
+ for key, state_id in iteritems(current_state_ids)
+ ],
+ )
+
+ # Prefill the state group cache with this group.
+ # It's fine to use the sequence like this as the state group map
+ # is immutable. (If the map wasn't immutable then this prefill could
+ # race with another update)
+ txn.call_after(
+ self._state_group_cache.update,
+ self._state_group_cache.sequence,
+ key=state_group,
+ value=dict(current_state_ids),
+ )
+
+ return state_group
+
+ return self.runInteraction("store_state_group", _store_state_group_txn)
+
+ def _count_state_group_hops_txn(self, txn, state_group):
+ """Given a state group, count how many hops there are in the tree.
+
+ This is used to ensure the delta chains don't get too long.
+ """
+ if isinstance(self.database_engine, PostgresEngine):
+ sql = ("""
+ WITH RECURSIVE state(state_group) AS (
+ VALUES(?::bigint)
+ UNION ALL
+ SELECT prev_state_group FROM state_group_edges e, state s
+ WHERE s.state_group = e.state_group
+ )
+ SELECT count(*) FROM state;
+ """)
+
+ txn.execute(sql, (state_group,))
+ row = txn.fetchone()
+ if row and row[0]:
+ return row[0]
+ else:
+ return 0
+ else:
+ # We don't use WITH RECURSIVE on sqlite3 as there are distributions
+ # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
+ next_group = state_group
+ count = 0
+
+ while next_group:
+ next_group = self._simple_select_one_onecol_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={"state_group": next_group},
+ retcol="prev_state_group",
+ allow_none=True,
+ )
+ if next_group:
+ count += 1
+
+ return count
+
+
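The recursive CTE in `_count_state_group_hops_txn` walks the `prev_state_group` edges until the chain runs out. A toy, runnable equivalent with made-up data; modern sqlite supports `WITH RECURSIVE` too (the code above avoids it only for the sake of ancient distro builds):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute(
    "CREATE TABLE state_group_edges"
    " (state_group BIGINT, prev_state_group BIGINT)"
)
# a delta chain 4 -> 3 -> 2 -> 1
cur.executemany(
    "INSERT INTO state_group_edges VALUES (?, ?)",
    [(4, 3), (3, 2), (2, 1)],
)

cur.execute("""
    WITH RECURSIVE state(state_group) AS (
        VALUES(4)
        UNION ALL
        SELECT prev_state_group FROM state_group_edges e, state s
        WHERE s.state_group = e.state_group
    )
    SELECT count(*) FROM state
""")
print(cur.fetchone())  # (4,) -- the starting group plus three hops
```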
+class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
+ """ Keeps track of the state at a given event.
+
+    This is done by the concept of `state groups`. Every event is assigned
+ a state group (identified by an arbitrary string), which references a
+ collection of state events. The current state of an event is then the
+ collection of state events referenced by the event's state group.
+
+ Hence, every change in the current state causes a new state group to be
+    generated. However, if no change happens (e.g., if we get a message event
+    with only one parent), it inherits the state group from its parent.
+
+ There are three tables:
+    * `state_groups`: Stores group name, first event within the group and
+ room id.
+ * `event_to_state_groups`: Maps events to state groups.
+ * `state_groups_state`: Maps state group to state events.
+ """
+
+ STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
+ STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
+ CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
+
+ def __init__(self, db_conn, hs):
+ super(StateStore, self).__init__(db_conn, hs)
+ self.register_background_update_handler(
+ self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
+ self._background_deduplicate_state,
+ )
+ self.register_background_update_handler(
+ self.STATE_GROUP_INDEX_UPDATE_NAME,
+ self._background_index_state,
+ )
+ self.register_background_index_update(
+ self.CURRENT_STATE_INDEX_UPDATE_NAME,
+ index_name="current_state_events_member_index",
+ table="current_state_events",
+ columns=["state_key"],
+ where_clause="type='m.room.member'",
+ )
+
+ def _store_event_state_mappings_txn(self, txn, events_and_contexts):
+ state_groups = {}
+ for event, context in events_and_contexts:
+ if event.internal_metadata.is_outlier():
+ continue
+
+ # if the event was rejected, just give it the same state as its
+ # predecessor.
+ if context.rejected:
+ state_groups[event.event_id] = context.prev_group
+ continue
+
+ state_groups[event.event_id] = context.state_group
+
+ self._simple_insert_many_txn(
+ txn,
+ table="event_to_state_groups",
+ values=[
+ {
+ "state_group": state_group_id,
+ "event_id": event_id,
+ }
+ for event_id, state_group_id in iteritems(state_groups)
+ ],
+ )
+
+ for event_id, state_group_id in iteritems(state_groups):
+ txn.call_after(
+ self._get_state_group_for_event.prefill,
+ (event_id,), state_group_id
+ )
@defer.inlineCallbacks
def _background_deduplicate_state(self, progress, batch_size):
@@ -767,7 +856,7 @@ class StateStore(SQLBaseStore):
def reindex_txn(txn):
new_last_state_group = last_state_group
- for count in xrange(batch_size):
+ for count in range(batch_size):
txn.execute(
"SELECT id, room_id FROM state_groups"
" WHERE ? < id AND id <= ?"
@@ -825,7 +914,7 @@ class StateStore(SQLBaseStore):
# of keys
delta_state = {
- key: value for key, value in curr_state.iteritems()
+ key: value for key, value in iteritems(curr_state)
if prev_state.get(key, None) != value
}
@@ -865,7 +954,7 @@ class StateStore(SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in delta_state.iteritems()
+ for key, state_id in iteritems(delta_state)
],
)
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index dddd5fc0e7..66856342f0 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -33,17 +33,20 @@ what sort order was used:
and stream ordering columns respectively.
"""
-from twisted.internet import defer
+import abc
+import logging
+from collections import namedtuple
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached
-from synapse.api.constants import EventTypes
-from synapse.types import RoomStreamToken
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from six.moves import range
-import logging
+from twisted.internet import defer
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.events import EventsWorkerStore
+from synapse.types import RoomStreamToken
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
@@ -55,6 +58,12 @@ _STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological"
+# Used as return values for pagination APIs
+_EventDictReturn = namedtuple("_EventDictReturn", (
+ "event_id", "topological_ordering", "stream_ordering",
+))
+
+
def lower_bound(token, engine, inclusive=False):
inclusive = "=" if inclusive else ""
if token.topological is None:
@@ -143,81 +152,41 @@ def filter_to_clause(event_filter):
return " AND ".join(clauses), args
-class StreamStore(SQLBaseStore):
- @defer.inlineCallbacks
- def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
- # NB this lives here instead of appservice.py so we can reuse the
- # 'private' StreamToken class in this file.
- if limit:
- limit = max(limit, MAX_STREAM_SIZE)
- else:
- limit = MAX_STREAM_SIZE
-
- # From and to keys should be integers from ordering.
- from_id = RoomStreamToken.parse_stream_token(from_key)
- to_id = RoomStreamToken.parse_stream_token(to_key)
-
- if from_key == to_key:
- defer.returnValue(([], to_key))
- return
+class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
+ """This is an abstract base class where subclasses must implement
+    `get_room_max_stream_ordering` and `get_room_min_stream_ordering`,
+    which can be called in the initializer.
+ """
- # select all the events between from/to with a sensible limit
- sql = (
- "SELECT e.event_id, e.room_id, e.type, s.state_key, "
- "e.stream_ordering FROM events AS e "
- "LEFT JOIN state_events as s ON "
- "e.event_id = s.event_id "
- "WHERE e.stream_ordering > ? AND e.stream_ordering <= ? "
- "ORDER BY stream_ordering ASC LIMIT %(limit)d "
- ) % {
- "limit": limit
- }
+ __metaclass__ = abc.ABCMeta
- def f(txn):
- # pull out all the events between the tokens
- txn.execute(sql, (from_id.stream, to_id.stream,))
- rows = self.cursor_to_dict(txn)
-
- # Logic:
- # - We want ALL events which match the AS room_id regex
- # - We want ALL events which match the rooms represented by the AS
- # room_alias regex
- # - We want ALL events for rooms that AS users have joined.
- # This is currently supported via get_app_service_rooms (which is
- # used for the Notifier listener rooms). We can't reasonably make a
- # SQL query for these room IDs, so we'll pull all the events between
- # from/to and filter in python.
- rooms_for_as = self._get_app_service_rooms_txn(txn, service)
- room_ids_for_as = [r.room_id for r in rooms_for_as]
-
- def app_service_interested(row):
- if row["room_id"] in room_ids_for_as:
- return True
-
- if row["type"] == EventTypes.Member:
- if service.is_interested_in_user(row.get("state_key")):
- return True
- return False
-
- return [r for r in rows if app_service_interested(r)]
-
- rows = yield self.runInteraction("get_appservice_room_stream", f)
+ def __init__(self, db_conn, hs):
+ super(StreamWorkerStore, self).__init__(db_conn, hs)
- ret = yield self._get_events(
- [r["event_id"] for r in rows],
- get_prev_content=True
+ events_max = self.get_room_max_stream_ordering()
+ event_cache_prefill, min_event_val = self._get_cache_dict(
+ db_conn, "events",
+ entity_column="room_id",
+ stream_column="stream_ordering",
+ max_value=events_max,
+ )
+ self._events_stream_cache = StreamChangeCache(
+ "EventsRoomStreamChangeCache", min_event_val,
+ prefilled_cache=event_cache_prefill,
+ )
+ self._membership_stream_cache = StreamChangeCache(
+ "MembershipStreamChangeCache", events_max,
)
- self._set_before_and_after(ret, rows, topo_order=from_id is None)
+ self._stream_order_on_start = self.get_room_max_stream_ordering()
- if rows:
- key = "s%d" % max(r["stream_ordering"] for r in rows)
- else:
- # Assume we didn't get anything because there was nothing to
- # get.
- key = to_key
+ @abc.abstractmethod
+ def get_room_max_stream_ordering(self):
+ raise NotImplementedError()
- defer.returnValue((ret, key))
+ @abc.abstractmethod
+ def get_room_min_stream_ordering(self):
+ raise NotImplementedError()
@defer.inlineCallbacks
def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0,
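For completeness, a concrete subclass would satisfy the two abstract methods with something like the following; the id-generator attributes named here are assumptions about the writer-side store, not something shown in this hunk:

```python
class StreamStore(StreamWorkerStore):
    """Sketch of a writer-side store satisfying the abstract methods."""

    def get_room_max_stream_ordering(self):
        # assumed: the master store owns the forward stream id generator
        return self._stream_id_gen.get_current_token()

    def get_room_min_stream_ordering(self):
        # assumed: backfilled events draw (negative) ids from a separate
        # generator, so the minimum comes from there
        return self._backfill_id_gen.get_current_token()
```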
@@ -233,13 +202,14 @@ class StreamStore(SQLBaseStore):
results = {}
room_ids = list(room_ids)
- for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
- res = yield preserve_context_over_deferred(defer.gatherResults([
- preserve_fn(self.get_room_events_stream_for_room)(
+ for rm_ids in (room_ids[i:i + 20] for i in range(0, len(room_ids), 20)):
+ res = yield make_deferred_yieldable(defer.gatherResults([
+ run_in_background(
+ self.get_room_events_stream_for_room,
room_id, from_key, to_key, limit, order=order,
)
for room_id in rm_ids
- ]))
+ ], consumeErrors=True))
results.update(dict(zip(rm_ids, res)))
defer.returnValue(results)
@@ -261,54 +231,55 @@ class StreamStore(SQLBaseStore):
@defer.inlineCallbacks
def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0,
order='DESC'):
- # Note: If from_key is None then we return in topological order. This
- # is because in that case we're using this as a "get the last few messages
- # in a room" function, rather than "get new messages since last sync"
- if from_key is not None:
- from_id = RoomStreamToken.parse_stream_token(from_key).stream
- else:
- from_id = None
- to_id = RoomStreamToken.parse_stream_token(to_key).stream
+ """Get new room events in stream ordering since `from_key`.
+
+ Args:
+ room_id (str)
+ from_key (str): The token to stream from; no events at or before
+ it are returned.
+ to_key (str): The token to stream up to; no events after it are
+ returned. (This is typically the current stream token.)
+ limit (int): Maximum number of events to return
+ order (str): Either "DESC" or "ASC". Determines which events are
+ returned when the result is limited. If "DESC" then the most
+ recent `limit` events are returned, otherwise returns the
+ oldest `limit` events.
+
+ Returns:
+ Deferred[tuple[list[FrozenEvent], str]]: Returns the list of
+ events (in ascending order) and the token from the start of
+ the chunk of events returned.
+ """
if from_key == to_key:
defer.returnValue(([], from_key))
- if from_id:
- has_changed = yield self._events_stream_cache.has_entity_changed(
- room_id, from_id
- )
-
- if not has_changed:
- defer.returnValue(([], from_key))
+ from_id = RoomStreamToken.parse_stream_token(from_key).stream
+ to_id = RoomStreamToken.parse_stream_token(to_key).stream
- def f(txn):
- if from_id is not None:
- sql = (
- "SELECT event_id, stream_ordering FROM events WHERE"
- " room_id = ?"
- " AND not outlier"
- " AND stream_ordering > ? AND stream_ordering <= ?"
- " ORDER BY stream_ordering %s LIMIT ?"
- ) % (order,)
- txn.execute(sql, (room_id, from_id, to_id, limit))
- else:
- sql = (
- "SELECT event_id, stream_ordering FROM events WHERE"
- " room_id = ?"
- " AND not outlier"
- " AND stream_ordering <= ?"
- " ORDER BY topological_ordering %s, stream_ordering %s LIMIT ?"
- ) % (order, order,)
- txn.execute(sql, (room_id, to_id, limit))
+ has_changed = yield self._events_stream_cache.has_entity_changed(
+ room_id, from_id
+ )
- rows = self.cursor_to_dict(txn)
+ if not has_changed:
+ defer.returnValue(([], from_key))
+ def f(txn):
+ sql = (
+ "SELECT event_id, stream_ordering FROM events WHERE"
+ " room_id = ?"
+ " AND not outlier"
+ " AND stream_ordering > ? AND stream_ordering <= ?"
+ " ORDER BY stream_ordering %s LIMIT ?"
+ ) % (order,)
+ txn.execute(sql, (room_id, from_id, to_id, limit))
+
+ rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
return rows
rows = yield self.runInteraction("get_room_events_stream_for_room", f)
ret = yield self._get_events(
- [r["event_id"] for r in rows],
+ [r.event_id for r in rows],
get_prev_content=True
)
@@ -318,7 +289,7 @@ class StreamStore(SQLBaseStore):
ret.reverse()
if rows:
- key = "s%d" % min(r["stream_ordering"] for r in rows)
+ key = "s%d" % min(r.stream_ordering for r in rows)
else:
# Assume we didn't get anything because there was nothing to
# get.
@@ -328,10 +299,7 @@ class StreamStore(SQLBaseStore):
@defer.inlineCallbacks
def get_membership_changes_for_user(self, user_id, from_key, to_key):
- if from_key is not None:
- from_id = RoomStreamToken.parse_stream_token(from_key).stream
- else:
- from_id = None
+ from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
if from_key == to_key:
@@ -345,34 +313,24 @@ class StreamStore(SQLBaseStore):
defer.returnValue([])
def f(txn):
- if from_id is not None:
- sql = (
- "SELECT m.event_id, stream_ordering FROM events AS e,"
- " room_memberships AS m"
- " WHERE e.event_id = m.event_id"
- " AND m.user_id = ?"
- " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
- " ORDER BY e.stream_ordering ASC"
- )
- txn.execute(sql, (user_id, from_id, to_id,))
- else:
- sql = (
- "SELECT m.event_id, stream_ordering FROM events AS e,"
- " room_memberships AS m"
- " WHERE e.event_id = m.event_id"
- " AND m.user_id = ?"
- " AND stream_ordering <= ?"
- " ORDER BY stream_ordering ASC"
- )
- txn.execute(sql, (user_id, to_id,))
- rows = self.cursor_to_dict(txn)
+ sql = (
+ "SELECT m.event_id, stream_ordering FROM events AS e,"
+ " room_memberships AS m"
+ " WHERE e.event_id = m.event_id"
+ " AND m.user_id = ?"
+ " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
+ " ORDER BY e.stream_ordering ASC"
+ )
+ txn.execute(sql, (user_id, from_id, to_id,))
+
+ rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
return rows
rows = yield self.runInteraction("get_membership_changes_for_user", f)
ret = yield self._get_events(
- [r["event_id"] for r in rows],
+ [r.event_id for r in rows],
get_prev_content=True
)
@@ -381,96 +339,28 @@ class StreamStore(SQLBaseStore):
defer.returnValue(ret)
@defer.inlineCallbacks
- def paginate_room_events(self, room_id, from_key, to_key=None,
- direction='b', limit=-1, event_filter=None):
- # Tokens really represent positions between elements, but we use
- # the convention of pointing to the event before the gap. Hence
- # we have a bit of asymmetry when it comes to equalities.
- args = [False, room_id]
- if direction == 'b':
- order = "DESC"
- bounds = upper_bound(
- RoomStreamToken.parse(from_key), self.database_engine
- )
- if to_key:
- bounds = "%s AND %s" % (bounds, lower_bound(
- RoomStreamToken.parse(to_key), self.database_engine
- ))
- else:
- order = "ASC"
- bounds = lower_bound(
- RoomStreamToken.parse(from_key), self.database_engine
- )
- if to_key:
- bounds = "%s AND %s" % (bounds, upper_bound(
- RoomStreamToken.parse(to_key), self.database_engine
- ))
-
- filter_clause, filter_args = filter_to_clause(event_filter)
-
- if filter_clause:
- bounds += " AND " + filter_clause
- args.extend(filter_args)
-
- if int(limit) > 0:
- args.append(int(limit))
- limit_str = " LIMIT ?"
- else:
- limit_str = ""
-
- sql = (
- "SELECT * FROM events"
- " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
- " ORDER BY topological_ordering %(order)s,"
- " stream_ordering %(order)s %(limit)s"
- ) % {
- "bounds": bounds,
- "order": order,
- "limit": limit_str
- }
-
- def f(txn):
- txn.execute(sql, args)
-
- rows = self.cursor_to_dict(txn)
-
- if rows:
- topo = rows[-1]["topological_ordering"]
- toke = rows[-1]["stream_ordering"]
- if direction == 'b':
- # Tokens are positions between events.
- # This token points *after* the last event in the chunk.
- # We need it to point to the event before it in the chunk
- # when we are going backwards so we subtract one from the
- # stream part.
- toke -= 1
- next_token = str(RoomStreamToken(topo, toke))
- else:
- # TODO (erikj): We should work out what to do here instead.
- next_token = to_key if to_key else from_key
+ def get_recent_events_for_room(self, room_id, limit, end_token):
+ """Get the most recent events in the room in topological ordering.
- return rows, next_token,
-
- rows, token = yield self.runInteraction("paginate_room_events", f)
-
- events = yield self._get_events(
- [r["event_id"] for r in rows],
- get_prev_content=True
- )
-
- self._set_before_and_after(events, rows)
+ Args:
+ room_id (str)
+ limit (int)
+ end_token (str): The stream token representing now.
- defer.returnValue((events, token))
+ Returns:
+ Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
+ events and a token pointing to the start of the returned
+ events.
+ The events returned are in ascending order.
+ """
- @defer.inlineCallbacks
- def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None):
rows, token = yield self.get_recent_event_ids_for_room(
- room_id, limit, end_token, from_token
+ room_id, limit, end_token,
)
logger.debug("stream before")
events = yield self._get_events(
- [r["event_id"] for r in rows],
+ [r.event_id for r in rows],
get_prev_content=True
)
logger.debug("stream after")
@@ -479,59 +369,62 @@ class StreamStore(SQLBaseStore):
defer.returnValue((events, token))
- @cached(num_args=4)
- def get_recent_event_ids_for_room(self, room_id, limit, end_token, from_token=None):
- end_token = RoomStreamToken.parse_stream_token(end_token)
+ @defer.inlineCallbacks
+ def get_recent_event_ids_for_room(self, room_id, limit, end_token):
+ """Get the most recent events in the room in topological ordering.
- if from_token is None:
- sql = (
- "SELECT stream_ordering, topological_ordering, event_id"
- " FROM events"
- " WHERE room_id = ? AND stream_ordering <= ? AND outlier = ?"
- " ORDER BY topological_ordering DESC, stream_ordering DESC"
- " LIMIT ?"
- )
- else:
- from_token = RoomStreamToken.parse_stream_token(from_token)
- sql = (
- "SELECT stream_ordering, topological_ordering, event_id"
- " FROM events"
- " WHERE room_id = ? AND stream_ordering > ?"
- " AND stream_ordering <= ? AND outlier = ?"
- " ORDER BY topological_ordering DESC, stream_ordering DESC"
- " LIMIT ?"
- )
+ Args:
+ room_id (str)
+ limit (int)
+ end_token (str): The stream token representing now.
- def get_recent_events_for_room_txn(txn):
- if from_token is None:
- txn.execute(sql, (room_id, end_token.stream, False, limit,))
- else:
- txn.execute(sql, (
- room_id, from_token.stream, end_token.stream, False, limit
- ))
+ Returns:
+ Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
+ _EventDictReturn and a token pointing to the start of the returned
+ events.
+ The events returned are in ascending order.
+ """
+ # Allow a zero limit here, and no-op.
+ if limit == 0:
+ defer.returnValue(([], end_token))
- rows = self.cursor_to_dict(txn)
+ end_token = RoomStreamToken.parse(end_token)
- rows.reverse() # As we selected with reverse ordering
+ rows, token = yield self.runInteraction(
+ "get_recent_event_ids_for_room", self._paginate_room_events_txn,
+ room_id, from_token=end_token, limit=limit,
+ )
- if rows:
- # Tokens are positions between events.
- # This token points *after* the last event in the chunk.
- # We need it to point to the event before it in the chunk
- # since we are going backwards so we subtract one from the
- # stream part.
- topo = rows[0]["topological_ordering"]
- toke = rows[0]["stream_ordering"] - 1
- start_token = str(RoomStreamToken(topo, toke))
+ # We want to return the results in ascending order.
+ rows.reverse()
- token = (start_token, str(end_token))
- else:
- token = (str(end_token), str(end_token))
+ defer.returnValue((rows, token))
+
+ def get_room_event_after_stream_ordering(self, room_id, stream_ordering):
+ """Gets details of the first event in a room at or after a stream ordering
+
+ Args:
+ room_id (str):
+ stream_ordering (int):
- return rows, token
+ Returns:
+ Deferred[(int, int, str)]:
+ (stream ordering, topological ordering, event_id)
+ """
+ def _f(txn):
+ sql = (
+ "SELECT stream_ordering, topological_ordering, event_id"
+ " FROM events"
+ " WHERE room_id = ? AND stream_ordering >= ?"
+ " AND NOT outlier"
+ " ORDER BY stream_ordering"
+ " LIMIT 1"
+ )
+ txn.execute(sql, (room_id, stream_ordering, ))
+ return txn.fetchone()
return self.runInteraction(
- "get_recent_events_for_room", get_recent_events_for_room_txn
+ "get_room_event_after_stream_ordering", _f,
)
@defer.inlineCallbacks
@@ -542,7 +435,7 @@ class StreamStore(SQLBaseStore):
`room_id` causes it to return the current room specific topological
token.
"""
- token = yield self._stream_id_gen.get_current_token()
+ token = yield self.get_room_max_stream_ordering()
if room_id is None:
defer.returnValue("s%d" % (token,))
else:
@@ -552,12 +445,6 @@ class StreamStore(SQLBaseStore):
)
defer.returnValue("t%d-%d" % (topo, token))
- def get_room_max_stream_ordering(self):
- return self._stream_id_gen.get_current_token()
-
- def get_room_min_stream_ordering(self):
- return self._backfill_id_gen.get_current_token()
-
def get_stream_token_for_event(self, event_id):
"""The stream token for an event
Args:
@@ -615,10 +502,20 @@ class StreamStore(SQLBaseStore):
@staticmethod
def _set_before_and_after(events, rows, topo_order=True):
+ """Inserts ordering information to events' internal metadata from
+ the DB rows.
+
+ Args:
+ events (list[FrozenEvent])
+ rows (list[_EventDictReturn])
+ topo_order (bool): Whether the events were ordered topologically
+ or by stream ordering. If true then all rows should have a
+ non-null topological_ordering.
+ """
for event, row in zip(events, rows):
- stream = row["stream_ordering"]
- if topo_order:
- topo = event.depth
+ stream = row.stream_ordering
+ if topo_order and row.topological_ordering:
+ topo = row.topological_ordering
else:
topo = None
internal = event.internal_metadata
@@ -690,87 +587,27 @@ class StreamStore(SQLBaseStore):
retcols=["stream_ordering", "topological_ordering"],
)
- token = RoomStreamToken(
- results["topological_ordering"],
+ # Paginating backwards includes the event at the token, but paginating
+ # forward doesn't.
+ before_token = RoomStreamToken(
+ results["topological_ordering"] - 1,
results["stream_ordering"],
)
- if isinstance(self.database_engine, Sqlite3Engine):
- # SQLite3 doesn't optimise ``(x < a) OR (x = a AND y < b)``
- # So we give pass it to SQLite3 as the UNION ALL of the two queries.
-
- query_before = (
- "SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND topological_ordering < ?"
- " UNION ALL"
- " SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering < ?"
- " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
- )
- before_args = (
- room_id, token.topological,
- room_id, token.topological, token.stream,
- before_limit,
- )
-
- query_after = (
- "SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND topological_ordering > ?"
- " UNION ALL"
- " SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering > ?"
- " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
- )
- after_args = (
- room_id, token.topological,
- room_id, token.topological, token.stream,
- after_limit,
- )
- else:
- query_before = (
- "SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND %s"
- " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
- ) % (upper_bound(token, self.database_engine, inclusive=False),)
-
- before_args = (room_id, before_limit)
-
- query_after = (
- "SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND %s"
- " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
- ) % (lower_bound(token, self.database_engine, inclusive=False),)
-
- after_args = (room_id, after_limit)
-
- txn.execute(query_before, before_args)
-
- rows = self.cursor_to_dict(txn)
- events_before = [r["event_id"] for r in rows]
-
- if rows:
- start_token = str(RoomStreamToken(
- rows[0]["topological_ordering"],
- rows[0]["stream_ordering"] - 1,
- ))
- else:
- start_token = str(RoomStreamToken(
- token.topological,
- token.stream - 1,
- ))
-
- txn.execute(query_after, after_args)
+ after_token = RoomStreamToken(
+ results["topological_ordering"],
+ results["stream_ordering"],
+ )
- rows = self.cursor_to_dict(txn)
- events_after = [r["event_id"] for r in rows]
+ rows, start_token = self._paginate_room_events_txn(
+ txn, room_id, before_token, direction='b', limit=before_limit,
+ )
+ events_before = [r.event_id for r in rows]
- if rows:
- end_token = str(RoomStreamToken(
- rows[-1]["topological_ordering"],
- rows[-1]["stream_ordering"],
- ))
- else:
- end_token = str(token)
+ rows, end_token = self._paginate_room_events_txn(
+ txn, room_id, after_token, direction='f', limit=after_limit,
+ )
+ events_after = [r.event_id for r in rows]
return {
"before": {
@@ -832,3 +669,139 @@ class StreamStore(SQLBaseStore):
def has_room_changed_since(self, room_id, stream_id):
return self._events_stream_cache.has_entity_changed(room_id, stream_id)
+
+ def _paginate_room_events_txn(self, txn, room_id, from_token, to_token=None,
+ direction='b', limit=-1, event_filter=None):
+ """Returns list of events before or after a given token.
+
+ Args:
+ txn
+ room_id (str)
+ from_token (RoomStreamToken): The token used to stream from
+ to_token (RoomStreamToken|None): A token which, if given, limits
+ the results to only those before it.
+ direction (char): Either 'b' or 'f' to indicate whether we are
+ paginating forwards or backwards from `from_token`.
+ limit (int): The maximum number of events to return.
+ event_filter (Filter|None): If provided filters the events to
+ those that match the filter.
+
+ Returns:
+ tuple[list[_EventDictReturn], str]: Returns the results as a list
+ of _EventDictReturn and a token that points to the end of the
+ result set.
+ """
+
+ assert int(limit) >= 0
+
+ # Tokens really represent positions between elements, but we use
+ # the convention of pointing to the event before the gap. Hence
+ # we have a bit of asymmetry when it comes to equalities.
+ args = [False, room_id]
+ if direction == 'b':
+ order = "DESC"
+ bounds = upper_bound(
+ from_token, self.database_engine
+ )
+ if to_token:
+ bounds = "%s AND %s" % (bounds, lower_bound(
+ to_token, self.database_engine
+ ))
+ else:
+ order = "ASC"
+ bounds = lower_bound(
+ from_token, self.database_engine
+ )
+ if to_token:
+ bounds = "%s AND %s" % (bounds, upper_bound(
+ to_token, self.database_engine
+ ))
+
+ filter_clause, filter_args = filter_to_clause(event_filter)
+
+ if filter_clause:
+ bounds += " AND " + filter_clause
+ args.extend(filter_args)
+
+ args.append(int(limit))
+
+ sql = (
+ "SELECT event_id, topological_ordering, stream_ordering"
+ " FROM events"
+ " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
+ " ORDER BY topological_ordering %(order)s,"
+ " stream_ordering %(order)s LIMIT ?"
+ ) % {
+ "bounds": bounds,
+ "order": order,
+ }
+
+ txn.execute(sql, args)
+
+ rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
+
+ if rows:
+ topo = rows[-1].topological_ordering
+ toke = rows[-1].stream_ordering
+ if direction == 'b':
+ # Tokens are positions between events.
+ # This token points *after* the last event in the chunk.
+ # We need it to point to the event before it in the chunk
+ # when we are going backwards so we subtract one from the
+ # stream part.
+ toke -= 1
+ next_token = RoomStreamToken(topo, toke)
+ else:
+ # TODO (erikj): We should work out what to do here instead.
+ next_token = to_token if to_token else from_token
+
+ return rows, str(next_token),
+
+ @defer.inlineCallbacks
+ def paginate_room_events(self, room_id, from_key, to_key=None,
+ direction='b', limit=-1, event_filter=None):
+ """Returns list of events before or after a given token.
+
+ Args:
+ room_id (str)
+ from_key (str): The token used to stream from
+ to_key (str|None): A token which, if given, limits the results to
+ only those before it.
+ direction (char): Either 'b' or 'f' to indicate whether we are
+ paginating forwards or backwards from `from_key`.
+ limit (int): The maximum number of events to return. Zero or less
+ means no limit.
+ event_filter (Filter|None): If provided filters the events to
+ those that match the filter.
+
+ Returns:
+ Deferred[tuple[list[FrozenEvent], str]]: Returns the results as a
+ list of events and a token that points to the end of the result
+ set.
+ """
+
+ from_key = RoomStreamToken.parse(from_key)
+ if to_key:
+ to_key = RoomStreamToken.parse(to_key)
+
+ rows, token = yield self.runInteraction(
+ "paginate_room_events", self._paginate_room_events_txn,
+ room_id, from_key, to_key, direction, limit, event_filter,
+ )
+
+ events = yield self._get_events(
+ [r.event_id for r in rows],
+ get_prev_content=True
+ )
+
+ self._set_before_and_after(events, rows)
+
+ defer.returnValue((events, token))
+
+
+class StreamStore(StreamWorkerStore):
+ def get_room_max_stream_ordering(self):
+ return self._stream_id_gen.get_current_token()
+
+ def get_room_min_stream_ordering(self):
+ return self._backfill_id_gen.get_current_token()
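
A note on the token arithmetic in `_paginate_room_events_txn` above: tokens are positions *between* events, pointing at the event before the gap, which is why the stream part is decremented when paginating backwards. A minimal standalone sketch of that convention, using a namedtuple stand-in for the real `_EventDictReturn` type:

    from collections import namedtuple

    # Stand-in for the real Synapse type; only the fields used here are modelled.
    _EventDictReturn = namedtuple(
        "_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering")
    )

    def next_token_for(rows, direction):
        # When going backwards the token must point *before* the last event
        # returned, so the stream part is decremented by one.
        last = rows[-1]
        topo, stream = last.topological_ordering, last.stream_ordering
        if direction == 'b':
            stream -= 1
        return "t%d-%d" % (topo, stream)

    rows = [
        _EventDictReturn("$a:hs", 5, 103),
        _EventDictReturn("$b:hs", 5, 101),
    ]
    assert next_token_for(rows, 'b') == "t5-100"
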
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index bff73f3f04..0f657b2bd3 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,25 +14,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached
-from twisted.internet import defer
-
-import ujson as json
import logging
-logger = logging.getLogger(__name__)
+from six.moves import range
+from canonicaljson import json
-class TagsStore(SQLBaseStore):
- def get_max_account_data_stream_id(self):
- """Get the current max stream id for the private user data stream
+from twisted.internet import defer
+
+from synapse.storage.account_data import AccountDataWorkerStore
+from synapse.util.caches.descriptors import cached
+
+logger = logging.getLogger(__name__)
- Returns:
- A deferred int.
- """
- return self._account_data_id_gen.get_current_token()
+class TagsWorkerStore(AccountDataWorkerStore):
@cached()
def get_tags_for_user(self, user_id):
"""Get all the tags for a user.
@@ -104,7 +101,7 @@ class TagsStore(SQLBaseStore):
batch_size = 50
results = []
- for i in xrange(0, len(tag_ids), batch_size):
+ for i in range(0, len(tag_ids), batch_size):
tags = yield self.runInteraction(
"get_all_updated_tag_content",
get_tag_content,
@@ -170,6 +167,8 @@ class TagsStore(SQLBaseStore):
row["tag"]: json.loads(row["content"]) for row in rows
})
+
+class TagsStore(TagsWorkerStore):
@defer.inlineCallbacks
def add_tag_to_room(self, user_id, room_id, tag, content):
"""Add a tag to a room for a user.
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 809fdd311f..c3bc94f56d 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -13,17 +13,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached
+import logging
+from collections import namedtuple
+
+import six
+
+from canonicaljson import encode_canonical_json, json
from twisted.internet import defer
-from canonicaljson import encode_canonical_json
+from synapse.util.caches.descriptors import cached
-from collections import namedtuple
+from ._base import SQLBaseStore
-import logging
-import ujson as json
+# py2 sqlite has buffer hardcoded as its only binary type, so we must use
+# it, despite it being deprecated and removed in py3 in favour of memoryview
+if six.PY2:
+ db_binary_type = buffer
+else:
+ db_binary_type = memoryview
logger = logging.getLogger(__name__)
@@ -46,8 +54,8 @@ class TransactionStore(SQLBaseStore):
"""A collection of queries for handling PDUs.
"""
- def __init__(self, hs):
- super(TransactionStore, self).__init__(hs)
+ def __init__(self, db_conn, hs):
+ super(TransactionStore, self).__init__(db_conn, hs)
self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
@@ -110,7 +118,7 @@ class TransactionStore(SQLBaseStore):
"transaction_id": transaction_id,
"origin": origin,
"response_code": code,
- "response_json": buffer(encode_canonical_json(response_dict)),
+ "response_json": db_binary_type(encode_canonical_json(response_dict)),
"ts": self._clock.time_msec(),
},
or_ignore=True,
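
The db_binary_type shim above is needed because sqlite under Python 2 accepts only buffer for BLOB columns, while Python 3 uses memoryview; both wrap bytes without copying. A small sketch of the round trip (assuming canonicaljson is installed, as this codebase requires):

    import six
    from canonicaljson import encode_canonical_json

    if six.PY2:
        db_binary_type = buffer  # noqa: F821 - py2-only builtin
    else:
        db_binary_type = memoryview

    blob = db_binary_type(encode_canonical_json({"ok": True}))
    # Both buffer and memoryview hand back the underlying bytes unchanged.
    assert bytes(blob) == b'{"ok":true}'
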
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index 2a4db3f03c..a8781b0e5d 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -13,17 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+import logging
+import re
-from ._base import SQLBaseStore
+from six import iteritems
+
+from twisted.internet import defer
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import get_domain_from_id, get_localpart_from_id
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-import re
-import logging
+from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
@@ -63,7 +65,7 @@ class UserDirectoryStore(SQLBaseStore):
user_ids (list(str)): Users to add
"""
yield self._simple_insert_many(
- table="users_in_pubic_room",
+ table="users_in_public_rooms",
values=[
{
"user_id": user_id,
@@ -100,7 +102,7 @@ class UserDirectoryStore(SQLBaseStore):
user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id),
profile.display_name,
)
- for user_id, profile in users_with_profile.iteritems()
+ for user_id, profile in iteritems(users_with_profile)
)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = """
@@ -112,7 +114,7 @@ class UserDirectoryStore(SQLBaseStore):
user_id,
"%s %s" % (user_id, p.display_name,) if p.display_name else user_id
)
- for user_id, p in users_with_profile.iteritems()
+ for user_id, p in iteritems(users_with_profile)
)
else:
# This should be unreachable.
@@ -130,7 +132,7 @@ class UserDirectoryStore(SQLBaseStore):
"display_name": profile.display_name,
"avatar_url": profile.avatar_url,
}
- for user_id, profile in users_with_profile.iteritems()
+ for user_id, profile in iteritems(users_with_profile)
]
)
for user_id in users_with_profile:
@@ -164,7 +166,7 @@ class UserDirectoryStore(SQLBaseStore):
)
if isinstance(self.database_engine, PostgresEngine):
- # We weight the loclpart most highly, then display name and finally
+ # We weight the localpart most highly, then display name and finally
# server name
if new_entry:
sql = """
@@ -219,7 +221,7 @@ class UserDirectoryStore(SQLBaseStore):
@defer.inlineCallbacks
def update_user_in_public_user_list(self, user_id, room_id):
yield self._simple_update_one(
- table="users_in_pubic_room",
+ table="users_in_public_rooms",
keyvalues={"user_id": user_id},
updatevalues={"room_id": room_id},
desc="update_user_in_public_user_list",
@@ -240,7 +242,7 @@ class UserDirectoryStore(SQLBaseStore):
)
self._simple_delete_txn(
txn,
- table="users_in_pubic_room",
+ table="users_in_public_rooms",
keyvalues={"user_id": user_id},
)
txn.call_after(
@@ -256,18 +258,18 @@ class UserDirectoryStore(SQLBaseStore):
@defer.inlineCallbacks
def remove_from_user_in_public_room(self, user_id):
yield self._simple_delete(
- table="users_in_pubic_room",
+ table="users_in_public_rooms",
keyvalues={"user_id": user_id},
desc="remove_from_user_in_public_room",
)
self.get_user_in_public_room.invalidate((user_id,))
def get_users_in_public_due_to_room(self, room_id):
- """Get all user_ids that are in the room directory becuase they're
+ """Get all user_ids that are in the room directory because they're
in the given room_id
"""
return self._simple_select_onecol(
- table="users_in_pubic_room",
+ table="users_in_public_rooms",
keyvalues={"room_id": room_id},
retcol="user_id",
desc="get_users_in_public_due_to_room",
@@ -275,7 +277,7 @@ class UserDirectoryStore(SQLBaseStore):
@defer.inlineCallbacks
def get_users_in_dir_due_to_room(self, room_id):
- """Get all user_ids that are in the room directory becuase they're
+ """Get all user_ids that are in the room directory because they're
in the given room_id
"""
user_ids_dir = yield self._simple_select_onecol(
@@ -286,7 +288,7 @@ class UserDirectoryStore(SQLBaseStore):
)
user_ids_pub = yield self._simple_select_onecol(
- table="users_in_pubic_room",
+ table="users_in_public_rooms",
keyvalues={"room_id": room_id},
retcol="user_id",
desc="get_users_in_dir_due_to_room",
@@ -317,6 +319,16 @@ class UserDirectoryStore(SQLBaseStore):
rows = yield self._execute("get_all_rooms", None, sql)
defer.returnValue([room_id for room_id, in rows])
+ @defer.inlineCallbacks
+ def get_all_local_users(self):
+ """Get all local users
+ """
+ sql = """
+ SELECT name FROM users
+ """
+ rows = yield self._execute("get_all_local_users", None, sql)
+ defer.returnValue([name for name, in rows])
+
def add_users_who_share_room(self, room_id, share_private, user_id_tuples):
"""Insert entries into the users_who_share_rooms table. The first
user should be a local user.
@@ -514,7 +526,7 @@ class UserDirectoryStore(SQLBaseStore):
def _delete_all_from_user_dir_txn(txn):
txn.execute("DELETE FROM user_directory")
txn.execute("DELETE FROM user_directory_search")
- txn.execute("DELETE FROM users_in_pubic_room")
+ txn.execute("DELETE FROM users_in_public_rooms")
txn.execute("DELETE FROM users_who_share_rooms")
txn.call_after(self.get_user_in_directory.invalidate_all)
txn.call_after(self.get_user_in_public_room.invalidate_all)
@@ -537,7 +549,7 @@ class UserDirectoryStore(SQLBaseStore):
@cached()
def get_user_in_public_room(self, user_id):
return self._simple_select_one(
- table="users_in_pubic_room",
+ table="users_in_public_rooms",
keyvalues={"user_id": user_id},
retcols=("room_id",),
allow_none=True,
@@ -629,6 +641,25 @@ class UserDirectoryStore(SQLBaseStore):
]
}
"""
+
+ if self.hs.config.user_directory_search_all_users:
+ # make s.user_id null to keep the ordering algorithm happy
+ join_clause = """
+ CROSS JOIN (SELECT NULL as user_id) AS s
+ """
+ join_args = ()
+ where_clause = "1=1"
+ else:
+ join_clause = """
+ LEFT JOIN users_in_public_rooms AS p USING (user_id)
+ LEFT JOIN (
+ SELECT other_user_id AS user_id FROM users_who_share_rooms
+ WHERE user_id = ? AND share_private
+ ) AS s USING (user_id)
+ """
+ join_args = (user_id,)
+ where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
+
if isinstance(self.database_engine, PostgresEngine):
full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
@@ -638,16 +669,12 @@ class UserDirectoryStore(SQLBaseStore):
# The array of numbers are the weights for the various part of the
# search: (domain, _, display name, localpart)
sql = """
- SELECT d.user_id, display_name, avatar_url
+ SELECT d.user_id AS user_id, display_name, avatar_url
FROM user_directory_search
INNER JOIN user_directory AS d USING (user_id)
- LEFT JOIN users_in_pubic_room AS p USING (user_id)
- LEFT JOIN (
- SELECT other_user_id AS user_id FROM users_who_share_rooms
- WHERE user_id = ? AND share_private
- ) AS s USING (user_id)
+ %s
WHERE
- (s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
+ %s
AND vector @@ to_tsquery('english', ?)
ORDER BY
(CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
@@ -671,30 +698,26 @@ class UserDirectoryStore(SQLBaseStore):
display_name IS NULL,
avatar_url IS NULL
LIMIT ?
- """
- args = (user_id, full_query, exact_query, prefix_query, limit + 1,)
+ """ % (join_clause, where_clause)
+ args = join_args + (full_query, exact_query, prefix_query, limit + 1,)
elif isinstance(self.database_engine, Sqlite3Engine):
search_query = _parse_query_sqlite(search_term)
sql = """
- SELECT d.user_id, display_name, avatar_url
+ SELECT d.user_id AS user_id, display_name, avatar_url
FROM user_directory_search
INNER JOIN user_directory AS d USING (user_id)
- LEFT JOIN users_in_pubic_room AS p USING (user_id)
- LEFT JOIN (
- SELECT other_user_id AS user_id FROM users_who_share_rooms
- WHERE user_id = ? AND share_private
- ) AS s USING (user_id)
+ %s
WHERE
- (s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
+ %s
AND value MATCH ?
ORDER BY
rank(matchinfo(user_directory_search)) DESC,
display_name IS NULL,
avatar_url IS NULL
LIMIT ?
- """
- args = (user_id, search_query, limit + 1)
+ """ % (join_clause, where_clause)
+ args = join_args + (search_query, limit + 1)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
@@ -723,7 +746,7 @@ def _parse_query_sqlite(search_term):
# Pull out the individual words, discarding any non-word characters.
results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
- return " & ".join("(%s* | %s)" % (result, result,) for result in results)
+ return " & ".join("(%s* OR %s)" % (result, result,) for result in results)
def _parse_query_postgres(search_term):
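
The `|` to `OR` fix above matters because sqlite's FTS MATCH syntax uses the OR keyword for disjunction, not a pipe. A standalone sketch of the expression the function now builds, mirroring the code in this hunk:

    import re

    def _parse_query_sqlite(search_term):
        # Pull out the individual words, discarding any non-word characters,
        # and match each as either a prefix or an exact token.
        results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
        return " & ".join("(%s* OR %s)" % (result, result) for result in results)

    assert _parse_query_sqlite("alice bob") == "(alice* OR alice) & (bob* OR bob)"
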
diff --git a/synapse/storage/user_erasure_store.py b/synapse/storage/user_erasure_store.py
new file mode 100644
index 0000000000..be013f4427
--- /dev/null
+++ b/synapse/storage/user_erasure_store.py
@@ -0,0 +1,103 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import operator
+
+from twisted.internet import defer
+
+from synapse.storage._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached, cachedList
+
+
+class UserErasureWorkerStore(SQLBaseStore):
+ @cached()
+ def is_user_erased(self, user_id):
+ """
+ Check if the given user id has requested erasure
+
+ Args:
+ user_id (str): full user id to check
+
+ Returns:
+ Deferred[bool]: True if the user has requested erasure
+ """
+ return self._simple_select_onecol(
+ table="erased_users",
+ keyvalues={"user_id": user_id},
+ retcol="1",
+ desc="is_user_erased",
+ ).addCallback(operator.truth)
+
+ @cachedList(
+ cached_method_name="is_user_erased",
+ list_name="user_ids",
+ inlineCallbacks=True,
+ )
+ def are_users_erased(self, user_ids):
+ """
+ Checks which users in a list have requested erasure
+
+ Args:
+ user_ids (iterable[str]): full user id to check
+
+ Returns:
+ Deferred[dict[str, bool]]:
+ for each user, whether the user has requested erasure.
+ """
+ # this serves the dual purpose of (a) making sure we can do len and
+ # iterate it multiple times, and (b) avoiding duplicates.
+ user_ids = tuple(set(user_ids))
+
+ def _get_erased_users(txn):
+ txn.execute(
+ "SELECT user_id FROM erased_users WHERE user_id IN (%s)" % (
+ ",".join("?" * len(user_ids))
+ ),
+ user_ids,
+ )
+ return set(r[0] for r in txn)
+
+ erased_users = yield self.runInteraction(
+ "are_users_erased", _get_erased_users,
+ )
+ res = dict((u, u in erased_users) for u in user_ids)
+ defer.returnValue(res)
+
+
+class UserErasureStore(UserErasureWorkerStore):
+ def mark_user_erased(self, user_id):
+ """Indicate that user_id wishes their message history to be erased.
+
+ Args:
+ user_id (str): full user_id to be erased
+ """
+ def f(txn):
+ # first check if they are already in the list
+ txn.execute(
+ "SELECT 1 FROM erased_users WHERE user_id = ?",
+ (user_id, )
+ )
+ if txn.fetchone():
+ return
+
+ # they are not already there: do the insert.
+ txn.execute(
+ "INSERT INTO erased_users (user_id) VALUES (?)",
+ (user_id, )
+ )
+
+ self._invalidate_cache_and_stream(
+ txn, self.is_user_erased, (user_id,)
+ )
+ return self.runInteraction("mark_user_erased", f)
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 95031dc9ec..d6160d5e4d 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -13,9 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from collections import deque
import contextlib
import threading
+from collections import deque
class IdGenerator(object):
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index 4f089bfb94..451e4fa441 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.api.errors import SynapseError
-from synapse.types import StreamToken
-
import logging
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import parse_integer, parse_string
+from synapse.types import StreamToken
logger = logging.getLogger(__name__)
@@ -57,48 +57,33 @@ class PaginationConfig(object):
@classmethod
def from_request(cls, request, raise_invalid_params=True,
default_limit=None):
- def get_param(name, default=None):
- lst = request.args.get(name, [])
- if len(lst) > 1:
- raise SynapseError(
- 400, "%s must be specified only once" % (name,)
- )
- elif len(lst) == 1:
- return lst[0]
- else:
- return default
-
- direction = get_param("dir", 'f')
- if direction not in ['f', 'b']:
- raise SynapseError(400, "'dir' parameter is invalid.")
-
- from_tok = get_param("from")
- to_tok = get_param("to")
+ direction = parse_string(request, "dir", default='f', allowed_values=['f', 'b'])
+
+ from_tok = parse_string(request, "from")
+ to_tok = parse_string(request, "to")
try:
if from_tok == "END":
from_tok = None # For backwards compat.
elif from_tok:
from_tok = StreamToken.from_string(from_tok)
- except:
+ except Exception:
raise SynapseError(400, "'from' paramater is invalid")
try:
if to_tok:
to_tok = StreamToken.from_string(to_tok)
- except:
+ except Exception:
raise SynapseError(400, "'to' paramater is invalid")
- limit = get_param("limit", None)
- if limit is not None and not limit.isdigit():
- raise SynapseError(400, "'limit' parameter must be an integer.")
+ limit = parse_integer(request, "limit", default=default_limit)
- if limit is None:
- limit = default_limit
+ if limit and limit < 0:
+ raise SynapseError(400, "Limit must be 0 or above")
try:
return PaginationConfig(from_tok, to_tok, direction, limit)
- except:
+ except Exception:
logger.exception("Failed to create pagination config")
raise SynapseError(400, "Invalid request.")
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 91a59b0bae..e5220132a3 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -15,13 +15,12 @@
from twisted.internet import defer
-from synapse.types import StreamToken
-
+from synapse.handlers.account_data import AccountDataEventSource
from synapse.handlers.presence import PresenceEventSource
+from synapse.handlers.receipts import ReceiptEventSource
from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource
-from synapse.handlers.receipts import ReceiptEventSource
-from synapse.handlers.account_data import AccountDataEventSource
+from synapse.types import StreamToken
class EventSources(object):
@@ -45,6 +44,7 @@ class EventSources(object):
push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token()
+ groups_key = self.store.get_group_stream_token()
token = StreamToken(
room_key=(
@@ -65,6 +65,7 @@ class EventSources(object):
push_rules_key=push_rules_key,
to_device_key=to_device_key,
device_list_key=device_list_key,
+ groups_key=groups_key,
)
defer.returnValue(token)
@@ -73,6 +74,7 @@ class EventSources(object):
push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token()
+ groups_key = self.store.get_group_stream_token()
token = StreamToken(
room_key=(
@@ -93,5 +95,6 @@ class EventSources(object):
push_rules_key=push_rules_key,
to_device_key=to_device_key,
device_list_key=device_list_key,
+ groups_key=groups_key,
)
defer.returnValue(token)
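
Adding groups_key grows the serialised StreamToken by one field, which is why StreamToken.from_string (in the types.py hunk below) pads a token that is one part short with "0". A standalone sketch of that backwards-compat parsing, with a namedtuple stand-in for the real class:

    from collections import namedtuple

    StreamToken = namedtuple("StreamToken", (
        "room_key", "presence_key", "typing_key", "receipt_key",
        "account_data_key", "push_rules_key", "to_device_key",
        "device_list_key", "groups_key",
    ))

    def from_string(string):
        keys = string.split("_")
        if len(keys) == len(StreamToken._fields) - 1:
            # old token from before the newest field: pad it out
            keys.append("0")
        return StreamToken(*keys)

    tok = from_string("s1_2_3_4_5_6_7_8")  # minted before groups_key existed
    assert tok.groups_key == "0"
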
diff --git a/synapse/types.py b/synapse/types.py
index 111948540d..08f058f714 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -12,26 +12,65 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import string
+from collections import namedtuple
from synapse.api.errors import SynapseError
-from collections import namedtuple
-
-Requester = namedtuple("Requester", [
+class Requester(namedtuple("Requester", [
"user", "access_token_id", "is_guest", "device_id", "app_service",
-])
-"""
-Represents the user making a request
+])):
+ """
+ Represents the user making a request
-Attributes:
- user (UserID): id of the user making the request
- access_token_id (int|None): *ID* of the access token used for this
- request, or None if it came via the appservice API or similar
- is_guest (bool): True if the user making this request is a guest user
- device_id (str|None): device_id which was set at authentication time
- app_service (ApplicationService|None): the AS requesting on behalf of the user
-"""
+ Attributes:
+ user (UserID): id of the user making the request
+ access_token_id (int|None): *ID* of the access token used for this
+ request, or None if it came via the appservice API or similar
+ is_guest (bool): True if the user making this request is a guest user
+ device_id (str|None): device_id which was set at authentication time
+ app_service (ApplicationService|None): the AS requesting on behalf of the user
+ """
+
+ def serialize(self):
+ """Converts self to a type that can be serialized as JSON, and then
+ deserialized by `deserialize`
+
+ Returns:
+ dict
+ """
+ return {
+ "user_id": self.user.to_string(),
+ "access_token_id": self.access_token_id,
+ "is_guest": self.is_guest,
+ "device_id": self.device_id,
+ "app_server_id": self.app_service.id if self.app_service else None,
+ }
+
+ @staticmethod
+ def deserialize(store, input):
+ """Converts a dict that was produced by `serialize` back into a
+ Requester.
+
+ Args:
+ store (DataStore): Used to convert AS ID to AS object
+ input (dict): A dict produced by `serialize`
+
+ Returns:
+ Requester
+ """
+ appservice = None
+ if input["app_server_id"]:
+ appservice = store.get_app_service_by_id(input["app_server_id"])
+
+ return Requester(
+ user=UserID.from_string(input["user_id"]),
+ access_token_id=input["access_token_id"],
+ is_guest=input["is_guest"],
+ device_id=input["device_id"],
+ app_service=appservice,
+ )
def create_requester(user_id, access_token_id=None, is_guest=False,
@@ -126,14 +165,10 @@ class DomainSpecificString(
try:
cls.from_string(s)
return True
- except:
+ except Exception:
return False
- __str__ = to_string
-
- @classmethod
- def create(cls, localpart, domain,):
- return cls(localpart=localpart, domain=domain)
+ __repr__ = to_string
class UserID(DomainSpecificString):
@@ -156,6 +191,43 @@ class EventID(DomainSpecificString):
SIGIL = "$"
+class GroupID(DomainSpecificString):
+ """Structure representing a group ID."""
+ SIGIL = "+"
+
+ @classmethod
+ def from_string(cls, s):
+ group_id = super(GroupID, cls).from_string(s)
+ if not group_id.localpart:
+ raise SynapseError(
+ 400,
+ "Group ID cannot be empty",
+ )
+
+ if contains_invalid_mxid_characters(group_id.localpart):
+ raise SynapseError(
+ 400,
+ "Group ID can only contain characters a-z, 0-9, or '=_-./'",
+ )
+
+ return group_id
+
+
+mxid_localpart_allowed_characters = set("_-./=" + string.ascii_lowercase + string.digits)
+
+
+def contains_invalid_mxid_characters(localpart):
+ """Check for characters not allowed in an mxid or groupid localpart
+
+ Args:
+ localpart (basestring): the localpart to be checked
+
+ Returns:
+ bool: True if there are any naughty characters
+ """
+ return any(c not in mxid_localpart_allowed_characters for c in localpart)
+
+
class StreamToken(
namedtuple("Token", (
"room_key",
@@ -166,6 +238,7 @@ class StreamToken(
"push_rules_key",
"to_device_key",
"device_list_key",
+ "groups_key",
))
):
_SEPARATOR = "_"
@@ -178,7 +251,7 @@ class StreamToken(
# i.e. old token from before receipt_key
keys.append("0")
return cls(*keys)
- except:
+ except Exception:
raise SynapseError(400, "Invalid Token")
def to_string(self):
@@ -204,6 +277,7 @@ class StreamToken(
or (int(other.push_rules_key) < int(self.push_rules_key))
or (int(other.to_device_key) < int(self.to_device_key))
or (int(other.device_list_key) < int(self.device_list_key))
+ or (int(other.groups_key) < int(self.groups_key))
)
def copy_and_advance(self, key, new_value):
@@ -263,7 +337,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
if string[0] == 't':
parts = string[1:].split('-', 1)
return cls(topological=int(parts[0]), stream=int(parts[1]))
- except:
+ except Exception:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
@@ -272,7 +346,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
try:
if string[0] == 's':
return cls(topological=None, stream=int(string[1:]))
- except:
+ except Exception:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 2a2360ab5d..680ea928c7 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,20 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.api.errors import SynapseError
-from synapse.util.logcontext import PreserveLoggingContext
-
-from twisted.internet import defer, reactor, task
-
-import time
import logging
+from itertools import islice
-logger = logging.getLogger(__name__)
+import attr
+from twisted.internet import defer, task
-class DeferredTimedOutError(SynapseError):
- def __init__(self):
- super(DeferredTimedOutError, self).__init__(504, "Timed out")
+from synapse.util.logcontext import PreserveLoggingContext
+
+logger = logging.getLogger(__name__)
def unwrapFirstError(failure):
@@ -35,16 +31,27 @@ def unwrapFirstError(failure):
return failure.value.subFailure
+@attr.s
class Clock(object):
- """A small utility that obtains current time-of-day so that time may be
- mocked during unit-tests.
+ """
+ A Clock wraps a Twisted reactor and provides utilities on top of it.
- TODO(paul): Also move the sleep() functionality into it
+ Args:
+ reactor: The Twisted reactor to use.
"""
+ _reactor = attr.ib()
+
+ @defer.inlineCallbacks
+ def sleep(self, seconds):
+ d = defer.Deferred()
+ with PreserveLoggingContext():
+ self._reactor.callLater(seconds, d.callback, seconds)
+ res = yield d
+ defer.returnValue(res)
def time(self):
"""Returns the current system time in seconds since epoch."""
- return time.time()
+ return self._reactor.seconds()
def time_msec(self):
"""Returns the current system time in miliseconds since epoch."""
@@ -59,9 +66,10 @@ class Clock(object):
f(function): The function to call repeatedly.
msec(float): How long to wait between calls in milliseconds.
"""
- l = task.LoopingCall(f)
- l.start(msec / 1000.0, now=False)
- return l
+ call = task.LoopingCall(f)
+ call.clock = self._reactor
+ call.start(msec / 1000.0, now=False)
+ return call
def call_later(self, delay, callback, *args, **kwargs):
"""Call something later
@@ -77,61 +85,27 @@ class Clock(object):
callback(*args, **kwargs)
with PreserveLoggingContext():
- return reactor.callLater(delay, wrapped_callback, *args, **kwargs)
+ return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs)
def cancel_call_later(self, timer, ignore_errs=False):
try:
timer.cancel()
- except:
+ except Exception:
if not ignore_errs:
raise
- def time_bound_deferred(self, given_deferred, time_out):
- if given_deferred.called:
- return given_deferred
-
- ret_deferred = defer.Deferred()
- def timed_out_fn():
- e = DeferredTimedOutError()
+def batch_iter(iterable, size):
+ """batch an iterable up into tuples with a maximum size
- try:
- ret_deferred.errback(e)
- except:
- pass
+ Args:
+ iterable (iterable): the iterable to slice
+ size (int): the maximum batch size
- try:
- given_deferred.cancel()
- except:
- pass
-
- timer = None
-
- def cancel(res):
- try:
- self.cancel_call_later(timer)
- except:
- pass
- return res
-
- ret_deferred.addBoth(cancel)
-
- def success(res):
- try:
- ret_deferred.callback(res)
- except:
- pass
-
- return res
-
- def err(res):
- try:
- ret_deferred.errback(res)
- except:
- pass
-
- given_deferred.addCallbacks(callback=success, errback=err)
-
- timer = self.call_later(time_out, timed_out_fn)
-
- return ret_deferred
+ Returns:
+ an iterator over the chunks
+ """
+ # make sure we can deal with iterables like lists too
+ sourceiter = iter(iterable)
+ # call islice until it returns an empty tuple
+ return iter(lambda: tuple(islice(sourceiter, size)), ())
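
A usage sketch for batch_iter: the two-argument form of iter() keeps calling the lambda until it returns the empty-tuple sentinel, yielding fixed-size chunks with a short final batch.

    from itertools import islice

    def batch_iter(iterable, size):
        sourceiter = iter(iterable)
        return iter(lambda: tuple(islice(sourceiter, size)), ())

    assert list(batch_iter(range(5), 2)) == [(0, 1), (2, 3), (4,)]
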
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 1453faf0ef..a7094e2fb4 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,38 +13,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import collections
+import logging
+from contextlib import contextmanager
+
+from six.moves import range
+from twisted.internet import defer
+from twisted.internet.defer import CancelledError
+from twisted.python import failure
-from twisted.internet import defer, reactor
+from synapse.util import Clock, logcontext, unwrapFirstError
from .logcontext import (
- PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
+ PreserveLoggingContext,
+ make_deferred_yieldable,
+ run_in_background,
)
-from synapse.util import unwrapFirstError
-
-from contextlib import contextmanager
-
-import logging
logger = logging.getLogger(__name__)
-@defer.inlineCallbacks
-def sleep(seconds):
- d = defer.Deferred()
- with PreserveLoggingContext():
- reactor.callLater(seconds, d.callback, seconds)
- res = yield d
- defer.returnValue(res)
-
-
-def run_on_reactor():
- """ This will cause the rest of the function to be invoked upon the next
- iteration of the main loop
- """
- return sleep(0)
-
-
class ObservableDeferred(object):
"""Wraps a deferred object so that we can add observer deferreds. These
observer deferreds do not affect the callback chain of the original
@@ -53,6 +43,11 @@ class ObservableDeferred(object):
Cancelling or otherwise resolving an observer will not affect the original
ObservableDeferred.
+
+ NB that it does not attempt to do anything with logcontexts; in general
+ you should probably make_deferred_yieldable the deferreds
+ returned by `observe`, and ensure that the original deferred runs its
+ callbacks in the sentinel logcontext.
"""
__slots__ = ["_deferred", "_observers", "_result"]
@@ -68,7 +63,7 @@ class ObservableDeferred(object):
try:
# TODO: Handle errors here.
self._observers.pop().callback(r)
- except:
+ except Exception:
pass
return r
@@ -78,7 +73,7 @@ class ObservableDeferred(object):
try:
# TODO: Handle errors here.
self._observers.pop().errback(f)
- except:
+ except Exception:
pass
if consumeErrors:
@@ -151,77 +146,19 @@ def concurrently_execute(func, args, limit):
def _concurrently_execute_inner():
try:
while True:
- yield func(it.next())
+ yield func(next(it))
except StopIteration:
pass
- return preserve_context_over_deferred(defer.gatherResults([
- preserve_fn(_concurrently_execute_inner)()
- for _ in xrange(limit)
+ return logcontext.make_deferred_yieldable(defer.gatherResults([
+ run_in_background(_concurrently_execute_inner)
+ for _ in range(limit)
], consumeErrors=True)).addErrback(unwrapFirstError)
class Linearizer(object):
- """Linearizes access to resources based on a key. Useful to ensure only one
- thing is happening at a time on a given resource.
-
- Example:
-
- with (yield linearizer.queue("test_key")):
- # do some work.
-
- """
- def __init__(self, name=None):
- if name is None:
- self.name = id(self)
- else:
- self.name = name
- self.key_to_defer = {}
-
- @defer.inlineCallbacks
- def queue(self, key):
- # If there is already a deferred in the queue, we pull it out so that
- # we can wait on it later.
- # Then we replace it with a deferred that we resolve *after* the
- # context manager has exited.
- # We only return the context manager after the previous deferred has
- # resolved.
- # This all has the net effect of creating a chain of deferreds that
- # wait for the previous deferred before starting their work.
- current_defer = self.key_to_defer.get(key)
-
- new_defer = defer.Deferred()
- self.key_to_defer[key] = new_defer
-
- if current_defer:
- logger.info(
- "Waiting to acquire linearizer lock %r for key %r", self.name, key
- )
- try:
- with PreserveLoggingContext():
- yield current_defer
- except:
- logger.exception("Unexpected exception in Linearizer")
-
- logger.info("Acquired linearizer lock %r for key %r", self.name, key)
-
- @contextmanager
- def _ctx_manager():
- try:
- yield
- finally:
- logger.info("Releasing linearizer lock %r for key %r", self.name, key)
- new_defer.callback(None)
- current_d = self.key_to_defer.get(key)
- if current_d is new_defer:
- self.key_to_defer.pop(key, None)
-
- defer.returnValue(_ctx_manager())
-
-
-class Limiter(object):
"""Limits concurrent access to resources based on a key. Useful to ensure
- only a few thing happen at a time on a given resource.
+ only a few things happen at a time on a given resource.
Example:
@@ -229,22 +166,31 @@ class Limiter(object):
# do some work.
"""
- def __init__(self, max_count):
+ def __init__(self, name=None, max_count=1, clock=None):
"""
Args:
- max_count(int): The maximum number of concurrent access
+ max_count(int): The maximum number of concurrent accesses
"""
+ if name is None:
+ self.name = id(self)
+ else:
+ self.name = name
+
+ if not clock:
+ from twisted.internet import reactor
+ clock = Clock(reactor)
+ self._clock = clock
self.max_count = max_count
# key_to_defer is a map from the key to a 2 element list where
- # the first element is the number of things executing
- # the second element is a list of deferreds for the things blocked from
- # executing.
+ # the first element is the number of things executing, and
+ # the second element is an OrderedDict, where the keys are deferreds for the
+ # things blocked from executing.
self.key_to_defer = {}
@defer.inlineCallbacks
def queue(self, key):
- entry = self.key_to_defer.setdefault(key, [0, []])
+ entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()])
# If the number of things executing is greater than the maximum
# then add a deferred to the list of blocked items
@@ -252,27 +198,71 @@ class Limiter(object):
# this item so that it can continue executing.
if entry[0] >= self.max_count:
new_defer = defer.Deferred()
- entry[1].append(new_defer)
- with PreserveLoggingContext():
- yield new_defer
+ entry[1][new_defer] = 1
+
+ logger.info(
+ "Waiting to acquire linearizer lock %r for key %r", self.name, key,
+ )
+ try:
+ yield make_deferred_yieldable(new_defer)
+ except Exception as e:
+ if isinstance(e, CancelledError):
+ logger.info(
+ "Cancelling wait for linearizer lock %r for key %r",
+ self.name, key,
+ )
+ else:
+ logger.warn(
+ "Unexpected exception waiting for linearizer lock %r for key %r",
+ self.name, key,
+ )
+
+ # we just have to take ourselves back out of the queue.
+ del entry[1][new_defer]
+ raise
+
+ logger.info("Acquired linearizer lock %r for key %r", self.name, key)
+ entry[0] += 1
+
+ # if the code holding the lock completes synchronously, then it
+ # will recursively run the next claimant on the list. That can
+ # relatively rapidly lead to stack exhaustion. This is essentially
+ # the same problem as http://twistedmatrix.com/trac/ticket/9304.
+ #
+ # In order to break the cycle, we add a cheeky sleep(0) here to
+ # ensure that we fall back to the reactor between each iteration.
+ #
+ # (This needs to happen while we hold the lock, and the context manager's exit
+ # code must be synchronous, so this is the only sensible place.)
+ yield self._clock.sleep(0)
- entry[0] += 1
+ else:
+ logger.info(
+ "Acquired uncontended linearizer lock %r for key %r", self.name, key,
+ )
+ entry[0] += 1
@contextmanager
def _ctx_manager():
try:
yield
finally:
+ logger.info("Releasing linearizer lock %r for key %r", self.name, key)
+
# We've finished executing so check if there are any things
# blocked waiting to execute and start one of them
entry[0] -= 1
- try:
- entry[1].pop(0).callback(None)
- except IndexError:
- # If nothing else is executing for this key then remove it
- # from the map
- if entry[0] == 0:
- self.key_to_defer.pop(key, None)
+
+ if entry[1]:
+ (next_def, _) = entry[1].popitem(last=False)
+
+ # we need to run the next thing in the sentinel context.
+ with PreserveLoggingContext():
+ next_def.callback(None)
+ elif entry[0] == 0:
+ # We were the last thing for this key: remove it from the
+ # map.
+ del self.key_to_defer[key]
defer.returnValue(_ctx_manager())
@@ -316,7 +306,7 @@ class ReadWriteLock(object):
# We wait for the latest writer to finish writing. We can safely ignore
# any existing readers... as they're readers.
- yield curr_writer
+ yield make_deferred_yieldable(curr_writer)
@contextmanager
def _ctx_manager():
@@ -345,7 +335,7 @@ class ReadWriteLock(object):
curr_readers.clear()
self.key_to_current_writer[key] = new_defer
- yield preserve_context_over_deferred(defer.gatherResults(to_wait_on))
+ yield make_deferred_yieldable(defer.gatherResults(to_wait_on))
@contextmanager
def _ctx_manager():
@@ -357,3 +347,69 @@ class ReadWriteLock(object):
self.key_to_current_writer.pop(key)
defer.returnValue(_ctx_manager())
+
+
+class DeferredTimeoutError(Exception):
+ """
+ This error is raised by default when a L{Deferred} times out.
+ """
+
+
+def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
+ """
+ Add a timeout to a deferred by scheduling it to be cancelled after
+ timeout seconds.
+
+ This is essentially a backport of deferred.addTimeout, which was introduced
+ in twisted 16.5.
+
+ If the deferred gets timed out, it errbacks with a DeferredTimeoutError,
+ unless a custom canceller function was passed when the deferred was
+ created, or a different on_timeout_cancel callable is provided.
+
+ Args:
+ deferred (defer.Deferred): deferred to be timed out
+ timeout (Number): seconds to time out after
+ reactor (twisted.internet.reactor): the Twisted reactor to use
+
+ on_timeout_cancel (callable): A callable which is called immediately
+ after the deferred times out, and not if this deferred is
+ otherwise cancelled before the timeout.
+
+ It takes an arbitrary value, which is the value of the deferred at
+ that exact point in time (probably a CancelledError Failure), and
+ the timeout.
+
+ The default callable (if none is provided) will translate a
+ CancelledError Failure into a DeferredTimeoutError.
+ """
+ timed_out = [False]
+
+ def time_it_out():
+ timed_out[0] = True
+ deferred.cancel()
+
+ delayed_call = reactor.callLater(timeout, time_it_out)
+
+ def convert_cancelled(value):
+ if timed_out[0]:
+ to_call = on_timeout_cancel or _cancelled_to_timed_out_error
+ return to_call(value, timeout)
+ return value
+
+ deferred.addBoth(convert_cancelled)
+
+ def cancel_timeout(result):
+ # stop the pending call to cancel the deferred if it's been fired
+ if delayed_call.active():
+ delayed_call.cancel()
+ return result
+
+ deferred.addBoth(cancel_timeout)
+
+
+def _cancelled_to_timed_out_error(value, timeout):
+ if isinstance(value, failure.Failure):
+ value.trap(CancelledError)
+ raise DeferredTimeoutError(timeout, "Deferred")
+ return value
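[editor's note: a sketch of how this helper might be used; the bare Deferred stands in for a slow remote call.]

from twisted.internet import defer, reactor

def lookup_with_timeout():
    d = defer.Deferred()  # stand-in for a slow remote request
    add_timeout_to_deferred(d, 10, reactor)

    def on_timeout(failure):
        # anything other than DeferredTimeoutError is re-raised
        failure.trap(DeferredTimeoutError)
        return None  # treat a timeout as "no result"

    d.addErrback(on_timeout)
    return d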
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 4adae96681..7b065b195e 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -13,28 +13,87 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse.metrics
import os
+import six
+from six.moves import intern
+
+from prometheus_client.core import REGISTRY, Gauge, GaugeMetricFamily
+
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5))
-metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
+
+def get_cache_factor_for(cache_name):
+ env_var = "SYNAPSE_CACHE_FACTOR_" + cache_name.upper()
+ factor = os.environ.get(env_var)
+ if factor:
+ return float(factor)
+
+ return CACHE_SIZE_FACTOR
+
caches_by_name = {}
-# cache_counter = metrics.register_cache(
-# "cache",
-# lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
-# labels=["name"],
-# )
-
-
-def register_cache(name, cache):
- caches_by_name[name] = cache
- return metrics.register_cache(
- "cache",
- lambda: len(cache),
- name,
- )
+collectors_by_name = {}
+
+cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
+cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
+cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name"])
+cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"])
+
+response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"])
+response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"])
+response_cache_evicted = Gauge(
+ "synapse_util_caches_response_cache:evicted_size", "", ["name"]
+)
+response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"])
+
+
+def register_cache(cache_type, cache_name, cache):
+
+ # Check if the metric is already registered. Unregister it, if so.
+ # This usually happens during tests, as at runtime these caches are
+ # effectively singletons.
+ metric_name = "cache_%s_%s" % (cache_type, cache_name)
+ if metric_name in collectors_by_name:
+ REGISTRY.unregister(collectors_by_name[metric_name])
+
+ class CacheMetric(object):
+
+ hits = 0
+ misses = 0
+ evicted_size = 0
+
+ def inc_hits(self):
+ self.hits += 1
+
+ def inc_misses(self):
+ self.misses += 1
+
+ def inc_evictions(self, size=1):
+ self.evicted_size += size
+
+ def describe(self):
+ return []
+
+ def collect(self):
+ if cache_type == "response_cache":
+ response_cache_size.labels(cache_name).set(len(cache))
+ response_cache_hits.labels(cache_name).set(self.hits)
+ response_cache_evicted.labels(cache_name).set(self.evicted_size)
+ response_cache_total.labels(cache_name).set(self.hits + self.misses)
+ else:
+ cache_size.labels(cache_name).set(len(cache))
+ cache_hits.labels(cache_name).set(self.hits)
+ cache_evicted.labels(cache_name).set(self.evicted_size)
+ cache_total.labels(cache_name).set(self.hits + self.misses)
+
+ yield GaugeMetricFamily("__unused", "")
+
+ metric = CacheMetric()
+ REGISTRY.register(metric)
+ caches_by_name[cache_name] = cache
+ collectors_by_name[metric_name] = metric
+ return metric
KNOWN_KEYS = {
@@ -66,7 +125,9 @@ def intern_string(string):
return None
try:
- string = string.encode("ascii")
+ if six.PY2:
+ string = string.encode("ascii")
+
return intern(string)
except UnicodeEncodeError:
return string
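[editor's note: to illustrate the per-cache tuning and the new three-argument register_cache signature, a hedged sketch; the cache name "get_users" is invented.]

# SYNAPSE_CACHE_FACTOR_GET_USERS=2.0 would override the global
# SYNAPSE_CACHE_FACTOR for this one cache.
factor = get_cache_factor_for("get_users")

backing = {}  # anything supporting len() can be tracked
metrics = register_cache("cache", "get_users", backing)
metrics.inc_hits()
metrics.inc_misses()
metrics.inc_evictions(3)  # exported as synapse_util_caches_cache:evicted_size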
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index af65bfe7b8..f8a07df6b8 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,25 +13,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import functools
+import inspect
import logging
+import threading
+from collections import namedtuple
+
+import six
+from six import itervalues, string_types
+
+from twisted.internet import defer
+from synapse.util import logcontext, unwrapFirstError
from synapse.util.async import ObservableDeferred
-from synapse.util import unwrapFirstError, logcontext
-from synapse.util.caches import CACHE_SIZE_FACTOR
+from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from synapse.util.stringutils import to_ascii
from . import register_cache
-from twisted.internet import defer
-from collections import namedtuple
-
-import functools
-import inspect
-import threading
-
-
logger = logging.getLogger(__name__)
@@ -39,12 +41,11 @@ _CacheSentinel = object()
class CacheEntry(object):
__slots__ = [
- "deferred", "sequence", "callbacks", "invalidated"
+ "deferred", "callbacks", "invalidated"
]
- def __init__(self, deferred, sequence, callbacks):
+ def __init__(self, deferred, callbacks):
self.deferred = deferred
- self.sequence = sequence
self.callbacks = set(callbacks)
self.invalidated = False
@@ -62,7 +63,6 @@ class Cache(object):
"max_entries",
"name",
"keylen",
- "sequence",
"thread",
"metrics",
"_pending_deferred_cache",
@@ -75,13 +75,16 @@ class Cache(object):
self.cache = LruCache(
max_size=max_entries, keylen=keylen, cache_type=cache_type,
size_callback=(lambda d: len(d)) if iterable else None,
+ evicted_callback=self._on_evicted,
)
self.name = name
self.keylen = keylen
- self.sequence = 0
self.thread = None
- self.metrics = register_cache(name, self.cache)
+ self.metrics = register_cache("cache", name, self.cache)
+
+ def _on_evicted(self, evicted_count):
+ self.metrics.inc_evictions(evicted_count)
def check_thread(self):
expected_thread = self.thread
@@ -109,11 +112,10 @@ class Cache(object):
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
- if val.sequence == self.sequence:
- val.callbacks.update(callbacks)
- if update_metrics:
- self.metrics.inc_hits()
- return val.deferred
+ val.callbacks.update(callbacks)
+ if update_metrics:
+ self.metrics.inc_hits()
+ return val.deferred
val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
if val is not _CacheSentinel:
@@ -133,12 +135,9 @@ class Cache(object):
self.check_thread()
entry = CacheEntry(
deferred=value,
- sequence=self.sequence,
callbacks=callbacks,
)
- entry.callbacks.update(callbacks)
-
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()
@@ -146,13 +145,25 @@ class Cache(object):
self._pending_deferred_cache[key] = entry
def shuffle(result):
- if self.sequence == entry.sequence:
- existing_entry = self._pending_deferred_cache.pop(key, None)
- if existing_entry is entry:
- self.cache.set(key, result, entry.callbacks)
- else:
- entry.invalidate()
+ existing_entry = self._pending_deferred_cache.pop(key, None)
+ if existing_entry is entry:
+ self.cache.set(key, result, entry.callbacks)
else:
+ # oops, the _pending_deferred_cache has been updated since
+ # we started our query, so we are out of date.
+ #
+ # Better put back whatever we took out. (We do it this way
+ # round, rather than peeking into the _pending_deferred_cache
+ # and then removing on a match, to make the common case faster)
+ if existing_entry is not None:
+ self._pending_deferred_cache[key] = existing_entry
+
+ # we're not going to put this entry into the cache, so need
+ # to make sure that the invalidation callbacks are called.
+ # That was probably done when _pending_deferred_cache was
+ # updated, but it's possible that `set` was called without
+ # `invalidate` being previously called, in which case it may
+ # not have been. Either way, let's double-check now.
entry.invalidate()
return result
@@ -164,25 +175,29 @@ class Cache(object):
def invalidate(self, key):
self.check_thread()
+ self.cache.pop(key, None)
- # Increment the sequence number so that any SELECT statements that
- # raced with the INSERT don't update the cache (SYN-369)
- self.sequence += 1
+ # if we have a pending lookup for this key, remove it from the
+ # _pending_deferred_cache, which will (a) stop it being returned
+ # for future queries and (b) stop it being persisted as a proper entry
+ # in self.cache.
entry = self._pending_deferred_cache.pop(key, None)
+
+ # run the invalidation callbacks now, rather than waiting for the
+ # deferred to resolve.
if entry:
entry.invalidate()
- self.cache.pop(key, None)
-
def invalidate_many(self, key):
self.check_thread()
if not isinstance(key, tuple):
raise TypeError(
"The cache key must be a tuple not %r" % (type(key),)
)
- self.sequence += 1
self.cache.del_multi(key)
+ # if we have a pending lookup for this key, remove it from the
+ # _pending_deferred_cache, as above
entry_dict = self._pending_deferred_cache.pop(key, None)
if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict):
@@ -190,8 +205,10 @@ class Cache(object):
def invalidate_all(self):
self.check_thread()
- self.sequence += 1
self.cache.clear()
+ for entry in itervalues(self._pending_deferred_cache):
+ entry.invalidate()
+ self._pending_deferred_cache.clear()
class _CacheDescriptorBase(object):
@@ -294,7 +311,7 @@ class CacheDescriptor(_CacheDescriptorBase):
orig, num_args=num_args, inlineCallbacks=inlineCallbacks,
cache_context=cache_context)
- max_entries = int(max_entries * CACHE_SIZE_FACTOR)
+ max_entries = int(max_entries * get_cache_factor_for(orig.__name__))
self.max_entries = max_entries
self.tree = tree
@@ -376,9 +393,10 @@ class CacheDescriptor(_CacheDescriptorBase):
ret.addErrback(onErr)
- # If our cache_key is a string, try to convert to ascii to save
- # a bit of space in large caches
- if isinstance(cache_key, basestring):
+ # If our cache_key is a string on py2, try to convert to ascii
+ # to save a bit of space in large caches. Py3 does this
+ # internally automatically.
+ if six.PY2 and isinstance(cache_key, string_types):
cache_key = to_ascii(cache_key)
result_d = ObservableDeferred(ret, consumeErrors=True)
@@ -549,7 +567,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
return results
return logcontext.make_deferred_yieldable(defer.gatherResults(
- cached_defers.values(),
+ list(cached_defers.values()),
consumeErrors=True,
).addCallback(update_results_dict).addErrback(
unwrapFirstError
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index d4105822b3..6c0b5a4094 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.caches.lrucache import LruCache
-from collections import namedtuple
-from . import register_cache
-import threading
import logging
+import threading
+from collections import namedtuple
+from synapse.util.caches.lrucache import LruCache
+
+from . import register_cache
logger = logging.getLogger(__name__)
@@ -55,7 +56,7 @@ class DictionaryCache(object):
__slots__ = []
self.sentinel = Sentinel()
- self.metrics = register_cache(name, self.cache)
+ self.metrics = register_cache("dictionary", name, self.cache)
def check_thread(self):
expected_thread = self.thread
@@ -107,34 +108,37 @@ class DictionaryCache(object):
self.sequence += 1
self.cache.clear()
- def update(self, sequence, key, value, full=False, known_absent=None):
+ def update(self, sequence, key, value, fetched_keys=None):
"""Updates the entry in the cache
Args:
sequence
- key
- value (dict): The value to update the cache with.
- full (bool): Whether the given value is the full dict, or just a
- partial subset there of. If not full then any existing entries
- for the key will be updated.
- known_absent (set): Set of keys that we know don't exist in the full
- dict.
+ key (K)
+ value (dict[X,Y]): The value to update the cache with.
+ fetched_keys (None|set[X]): All of the dictionary keys which were
+ fetched from the database.
+
+ If None, this is the complete value for key K. Otherwise, it
+ is used to infer a list of keys which we know don't exist in
+ the full dict.
"""
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
- if known_absent is None:
- known_absent = set()
- if full:
- self._insert(key, value, known_absent)
+ if fetched_keys is None:
+ self._insert(key, value, set())
else:
- self._update_or_insert(key, value, known_absent)
+ self._update_or_insert(key, value, fetched_keys)
def _update_or_insert(self, key, value, known_absent):
- entry = self.cache.setdefault(key, DictionaryEntry(False, set(), {}))
+ # We pop and reinsert as we need to tell the cache the size may have
+ # changed
+
+ entry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
entry.value.update(value)
entry.known_absent.update(known_absent)
+ self.cache[key] = entry
def _insert(self, key, value, known_absent):
self.cache[key] = DictionaryEntry(True, known_absent, value)
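[editor's note: a sketch of the new fetched_keys contract; the key and dict keys are illustrative. The sequence is snapshotted before the database read, per the SYN-369 comment above.]

cache = DictionaryCache("state_cache")

seq = cache.sequence          # snapshot before hitting the database
# ... SELECT returns values for "k1" and "k2"; "k3" was requested
# but has no row ...
cache.update(
    seq,
    key="!room:example.com",
    value={"k1": "v1", "k2": "v2"},
    fetched_keys={"k1", "k2", "k3"},   # everything we asked the db for
)
# "k3" is recorded in known_absent, so callers can tell it needn't
# be re-queried.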
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 6ad53a6390..465adc54a8 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.caches import register_cache
-
-from collections import OrderedDict
import logging
+from collections import OrderedDict
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.caches import register_cache
logger = logging.getLogger(__name__)
@@ -52,19 +52,22 @@ class ExpiringCache(object):
self._cache = OrderedDict()
- self.metrics = register_cache(cache_name, self)
-
self.iterable = iterable
self._size_estimate = 0
+ self.metrics = register_cache("expiring", cache_name, self)
+
def start(self):
if not self._expiry_ms:
# Don't bother starting the loop if things never expire
return
def f():
- self._prune_cache()
+ run_as_background_process(
+ "prune_cache_%s" % self._cache_name,
+ self._prune_cache,
+ )
self._clock.looping_call(f, self._expiry_ms / 2)
@@ -79,7 +82,11 @@ class ExpiringCache(object):
while self._max_len and len(self) > self._max_len:
_key, value = self._cache.popitem(last=False)
if self.iterable:
- self._size_estimate -= len(value.value)
+ removed_len = len(value.value)
+ self.metrics.inc_evictions(removed_len)
+ self._size_estimate -= removed_len
+ else:
+ self.metrics.inc_evictions()
def __getitem__(self, key):
try:
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index cf5fbb679c..b684f24e7b 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -14,8 +14,8 @@
# limitations under the License.
-from functools import wraps
import threading
+from functools import wraps
from synapse.util.caches.treecache import TreeCache
@@ -49,7 +49,24 @@ class LruCache(object):
Can also set callbacks on objects when getting/setting which are fired
when that key gets invalidated/evicted.
"""
- def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None):
+ def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None,
+ evicted_callback=None):
+ """
+ Args:
+ max_size (int):
+
+ keylen (int):
+
+ cache_type (type):
+ type of underlying cache to be used. Typically one of dict
+ or TreeCache.
+
+ size_callback (func(V) -> int | None):
+
+ evicted_callback (func(int)|None):
+ if not None, called on eviction with the size of the evicted
+ entry
+ """
cache = cache_type()
self.cache = cache # Used for introspection.
list_root = _Node(None, None, None, None)
@@ -61,8 +78,10 @@ class LruCache(object):
def evict():
while cache_len() > max_size:
todelete = list_root.prev_node
- delete_node(todelete)
+ evicted_len = delete_node(todelete)
cache.pop(todelete.key, None)
+ if evicted_callback:
+ evicted_callback(evicted_len)
def synchronized(f):
@wraps(f)
@@ -111,12 +130,15 @@ class LruCache(object):
prev_node.next_node = next_node
next_node.prev_node = prev_node
+ deleted_len = 1
if size_callback:
- cached_cache_len[0] -= size_callback(node.value)
+ deleted_len = size_callback(node.value)
+ cached_cache_len[0] -= deleted_len
for cb in node.callbacks:
cb()
node.callbacks.clear()
+ return deleted_len
@synchronized
def cache_get(key, default=None, callbacks=[]):
@@ -132,14 +154,21 @@ class LruCache(object):
def cache_set(key, value, callbacks=[]):
node = cache.get(key, None)
if node is not None:
- if value != node.value:
+ # We sometimes store large objects, e.g. dicts, which cause
+ # the inequality check to take a long time. So let's only do
+ # the check if we have some callbacks to call.
+ if node.callbacks and value != node.value:
for cb in node.callbacks:
cb()
node.callbacks.clear()
- if size_callback:
- cached_cache_len[0] -= size_callback(node.value)
- cached_cache_len[0] += size_callback(value)
+ # We don't bother to protect this by value != node.value as
+ # generally size_callback will be cheap compared with equality
+ # checks. (For example, taking the size of two dicts is quicker
+ # than comparing them for equality.)
+ if size_callback:
+ cached_cache_len[0] -= size_callback(node.value)
+ cached_cache_len[0] += size_callback(value)
node.callbacks.update(callbacks)
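[editor's note: a sketch of the new eviction hook, assuming entries whose size is their length.]

evicted_total = [0]

def on_evicted(size):
    # receives the size of each entry as it is evicted
    evicted_total[0] += size

cache = LruCache(
    max_size=100,
    size_callback=len,           # an entry's size is len(value)
    evicted_callback=on_evicted,
)
cache.set("some_key", [1, 2, 3])  # contributes 3 towards max_size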
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 00af539880..a8491b42d5 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -12,8 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
+from twisted.internet import defer
from synapse.util.async import ObservableDeferred
+from synapse.util.caches import register_cache
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+
+logger = logging.getLogger(__name__)
class ResponseCache(object):
@@ -24,20 +31,69 @@ class ResponseCache(object):
used rather than trying to compute a new response.
"""
- def __init__(self, hs, timeout_ms=0):
+ def __init__(self, hs, name, timeout_ms=0):
self.pending_result_cache = {} # Requests that haven't finished yet.
self.clock = hs.get_clock()
self.timeout_sec = timeout_ms / 1000.
+ self._name = name
+ self._metrics = register_cache(
+ "response_cache", name, self
+ )
+
+ def size(self):
+ return len(self.pending_result_cache)
+
+ def __len__(self):
+ return self.size()
+
def get(self, key):
+ """Look up the given key.
+
+ Can return either a new Deferred (which doesn't follow the synapse
+ logcontext rules), or, if the request has completed, the actual
+ result. You will probably want to make_deferred_yieldable the result.
+
+ If there is no entry for the key, returns None. It is worth noting that
+ this means there is no way to distinguish a completed result of None
+ from an absent cache entry.
+
+ Args:
+ key (hashable):
+
+ Returns:
+ twisted.internet.defer.Deferred|None|E: None if there is no entry
+ for this key; otherwise either a deferred result or the result
+ itself.
+ """
result = self.pending_result_cache.get(key)
if result is not None:
+ self._metrics.inc_hits()
return result.observe()
else:
+ self._metrics.inc_misses()
return None
def set(self, key, deferred):
+ """Set the entry for the given key to the given deferred.
+
+ *deferred* should run its callbacks in the sentinel logcontext (ie,
+ you should wrap normal synapse deferreds with
+ logcontext.run_in_background).
+
+ Can return either a new Deferred (which also doesn't follow the synapse
+ logcontext rules), or, if *deferred* was already complete, the actual
+ result. You will probably want to make_deferred_yieldable the result.
+
+ Args:
+ key (hashable):
+ deferred (twisted.internet.defer.Deferred[T]):
+
+ Returns:
+ twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual
+ result.
+ """
result = ObservableDeferred(deferred, consumeErrors=True)
self.pending_result_cache[key] = result
@@ -53,3 +109,52 @@ class ResponseCache(object):
result.addBoth(remove)
return result.observe()
+
+ def wrap(self, key, callback, *args, **kwargs):
+ """Wrap together a *get* and *set* call, taking care of logcontexts
+
+ First looks up the key in the cache, and if it is present makes it
+ follow the synapse logcontext rules and returns it.
+
+ Otherwise, makes a call to *callback(*args, **kwargs)*, which should
+ follow the synapse logcontext rules, and adds the result to the cache.
+
+ Example usage:
+
+ @defer.inlineCallbacks
+ def handle_request(request):
+ # etc
+ defer.returnValue(result)
+
+ result = yield response_cache.wrap(
+ key,
+ handle_request,
+ request,
+ )
+
+ Args:
+ key (hashable): key to get/set in the cache
+
+ callback (callable): function to call if the key is not found in
+ the cache
+
+ *args: positional parameters to pass to the callback, if it is used
+
+ **kwargs: named parameters to pass to the callback, if it is used
+
+ Returns:
+ twisted.internet.defer.Deferred: yieldable result
+ """
+ result = self.get(key)
+ if not result:
+ logger.info("[%s]: no cached result for [%s], calculating new one",
+ self._name, key)
+ d = run_in_background(callback, *args, **kwargs)
+ result = self.set(key, d)
+ elif not isinstance(result, defer.Deferred) or result.called:
+ logger.info("[%s]: using completed cached result for [%s]",
+ self._name, key)
+ else:
+ logger.info("[%s]: using incomplete cached result for [%s]",
+ self._name, key)
+ return make_deferred_yieldable(result)
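[editor's note: wrap() bundles the get/set dance that callers previously wrote by hand; roughly this pattern (a sketch, not a literal call site), using the run_in_background and make_deferred_yieldable helpers imported at the top of the file.]

result = self.response_cache.get(key)
if result is None:
    # no-one is computing this yet: kick it off ourselves
    d = run_in_background(handle_request, request)
    result = self.response_cache.set(key, d)
return make_deferred_yieldable(result)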
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 941d873ab8..f2bde74dc5 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -13,12 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.caches import register_cache, CACHE_SIZE_FACTOR
-
-
-from blist import sorteddict
import logging
+from sortedcontainers import SortedDict
+
+from synapse.util import caches
logger = logging.getLogger(__name__)
@@ -32,16 +31,18 @@ class StreamChangeCache(object):
entities that may have changed since that position. If position key is too
old then the cache will simply return all given entities.
"""
- def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache={}):
- self._max_size = int(max_size * CACHE_SIZE_FACTOR)
+
+ def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache=None):
+ self._max_size = int(max_size * caches.CACHE_SIZE_FACTOR)
self._entity_to_key = {}
- self._cache = sorteddict()
+ self._cache = SortedDict()
self._earliest_known_stream_pos = current_stream_pos
self.name = name
- self.metrics = register_cache(self.name, self._cache)
+ self.metrics = caches.register_cache("cache", self.name, self._cache)
- for entity, stream_pos in prefilled_cache.items():
- self.entity_has_changed(entity, stream_pos)
+ if prefilled_cache:
+ for entity, stream_pos in prefilled_cache.items():
+ self.entity_has_changed(entity, stream_pos)
def has_entity_changed(self, entity, stream_pos):
"""Returns True if the entity may have been updated since stream_pos
@@ -65,22 +66,25 @@ class StreamChangeCache(object):
return False
def get_entities_changed(self, entities, stream_pos):
- """Returns subset of entities that have had new things since the
- given position. If the position is too old it will just return the given list.
+ """
+ Returns the subset of the given entities that have had new things
+ since the given position; entities unknown to the cache are not
+ returned. If the position is too old it will just return the given
+ list.
"""
assert type(stream_pos) is int
if stream_pos >= self._earliest_known_stream_pos:
- keys = self._cache.keys()
- i = keys.bisect_right(stream_pos)
+ changed_entities = {
+ self._cache[k] for k in self._cache.islice(
+ start=self._cache.bisect_right(stream_pos),
+ )
+ }
- result = set(
- self._cache[k] for k in keys[i:]
- ).intersection(entities)
+ result = changed_entities.intersection(entities)
self.metrics.inc_hits()
else:
- result = entities
+ result = set(entities)
self.metrics.inc_misses()
return result
@@ -90,12 +94,13 @@ class StreamChangeCache(object):
"""
assert type(stream_pos) is int
+ if not self._cache:
+ # If we have no cache, nothing can have changed.
+ return False
+
if stream_pos >= self._earliest_known_stream_pos:
self.metrics.inc_hits()
- keys = self._cache.keys()
- i = keys.bisect_right(stream_pos)
-
- return i < len(keys)
+ return self._cache.bisect_right(stream_pos) < len(self._cache)
else:
self.metrics.inc_misses()
return True
@@ -107,10 +112,8 @@ class StreamChangeCache(object):
assert type(stream_pos) is int
if stream_pos >= self._earliest_known_stream_pos:
- keys = self._cache.keys()
- i = keys.bisect_right(stream_pos)
-
- return [self._cache[k] for k in keys[i:]]
+ return [self._cache[k] for k in self._cache.islice(
+ start=self._cache.bisect_right(stream_pos))]
else:
return None
@@ -129,8 +132,10 @@ class StreamChangeCache(object):
self._entity_to_key[entity] = stream_pos
while len(self._cache) > self._max_size:
- k, r = self._cache.popitem()
- self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
+ k, r = self._cache.popitem(0)
+ self._earliest_known_stream_pos = max(
+ k, self._earliest_known_stream_pos,
+ )
self._entity_to_key.pop(r, None)
def get_max_pos_of_last_change(self, entity):
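[editor's note: a worked example of the SortedDict-backed behaviour; entity names are invented.]

cache = StreamChangeCache("ExampleCache", current_stream_pos=100)
cache.entity_has_changed("@alice:example.com", 105)
cache.entity_has_changed("@bob:example.com", 110)

# only bob has changed strictly after position 106
cache.get_entities_changed(
    ["@alice:example.com", "@bob:example.com"], stream_pos=106,
)  # -> {"@bob:example.com"}

# position 99 predates the cache, so everything given is returned
cache.get_entities_changed(["@alice:example.com"], stream_pos=99)
# -> {"@alice:example.com"}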
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index fcc341a6b7..dd4c9e6067 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -1,3 +1,5 @@
+from six import itervalues
+
SENTINEL = object()
@@ -49,7 +51,7 @@ class TreeCache(object):
if popped is SENTINEL:
return default
- node_and_keys = zip(nodes, key)
+ node_and_keys = list(zip(nodes, key))
node_and_keys.reverse()
node_and_keys.append((self.root, None))
@@ -76,7 +78,7 @@ def iterate_tree_cache_entry(d):
can contain dicts.
"""
if isinstance(d, dict):
- for value_d in d.itervalues():
+ for value_d in itervalues(d):
for value in iterate_tree_cache_entry(value_d):
yield value
else:
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index e68f94ce77..194da87639 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -13,32 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
-from synapse.util.logcontext import (
- PreserveLoggingContext, preserve_context_over_fn
-)
-
-from synapse.util import unwrapFirstError
-
import logging
+from twisted.internet import defer
+
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
def user_left_room(distributor, user, room_id):
- return preserve_context_over_fn(
- distributor.fire,
- "user_left_room", user=user, room_id=room_id
- )
+ distributor.fire("user_left_room", user=user, room_id=room_id)
def user_joined_room(distributor, user, room_id):
- return preserve_context_over_fn(
- distributor.fire,
- "user_joined_room", user=user, room_id=room_id
- )
+ distributor.fire("user_joined_room", user=user, room_id=room_id)
class Distributor(object):
@@ -52,9 +42,7 @@ class Distributor(object):
model will do for today.
"""
- def __init__(self, suppress_failures=True):
- self.suppress_failures = suppress_failures
-
+ def __init__(self):
self.signals = {}
self.pre_registration = {}
@@ -64,7 +52,6 @@ class Distributor(object):
self.signals[name] = Signal(
name,
- suppress_failures=self.suppress_failures,
)
if name in self.pre_registration:
@@ -83,10 +70,18 @@ class Distributor(object):
self.pre_registration[name].append(observer)
def fire(self, name, *args, **kwargs):
+ """Dispatches the given signal to the registered observers.
+
+ Runs the observers as a background process. Does not return a deferred.
+ """
if name not in self.signals:
raise KeyError("%r does not have a signal named %s" % (self, name))
- return self.signals[name].fire(*args, **kwargs)
+ run_as_background_process(
+ name,
+ self.signals[name].fire,
+ *args, **kwargs
+ )
class Signal(object):
@@ -99,9 +94,8 @@ class Signal(object):
method into all of the observers.
"""
- def __init__(self, name, suppress_failures):
+ def __init__(self, name):
self.name = name
- self.suppress_failures = suppress_failures
self.observers = []
def observe(self, observer):
@@ -111,7 +105,6 @@ class Signal(object):
Each observer callable may return a Deferred."""
self.observers.append(observer)
- @defer.inlineCallbacks
def fire(self, *args, **kwargs):
"""Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is
@@ -129,22 +122,17 @@ class Signal(object):
failure.type,
failure.value,
failure.getTracebackObject()))
- if not self.suppress_failures:
- return failure
return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
- with PreserveLoggingContext():
- deferreds = [
- do(observer)
- for observer in self.observers
- ]
-
- res = yield defer.gatherResults(
- deferreds, consumeErrors=True
- ).addErrback(unwrapFirstError)
+ deferreds = [
+ run_in_background(do, o)
+ for o in self.observers
+ ]
- defer.returnValue(res)
+ return make_deferred_yieldable(defer.gatherResults(
+ deferreds, consumeErrors=True,
+ ))
def __repr__(self):
return "<Signal name=%r>" % (self.name,)
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
new file mode 100644
index 0000000000..629ed44149
--- /dev/null
+++ b/synapse/util/file_consumer.py
@@ -0,0 +1,147 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from six.moves import queue
+
+from twisted.internet import threads
+
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+
+
+class BackgroundFileConsumer(object):
+ """A consumer that writes to a file like object. Supports both push
+ and pull producers
+
+ Args:
+ file_obj (file): The file like object to write to. Closed when
+ finished.
+ reactor (twisted.internet.reactor): the Twisted reactor to use
+ """
+
+ # For PushProducers pause if we have this many unwritten slices
+ _PAUSE_ON_QUEUE_SIZE = 5
+ # And resume once the size of the queue is less than this
+ _RESUME_ON_QUEUE_SIZE = 2
+
+ def __init__(self, file_obj, reactor):
+ self._file_obj = file_obj
+
+ self._reactor = reactor
+
+ # Producer we're registered with
+ self._producer = None
+
+ # True if PushProducer, false if PullProducer
+ self.streaming = False
+
+ # For PushProducers, indicates whether we've paused the producer and
+ # need to call resumeProducing before we get more data.
+ self._paused_producer = False
+
+ # Queue of slices of bytes to be written. When the producer calls
+ # unregisterProducer, a final None is enqueued as a sentinel.
+ self._bytes_queue = queue.Queue()
+
+ # Deferred that is resolved when finished writing
+ self._finished_deferred = None
+
+ # If the _writer thread throws an exception it gets stored here.
+ self._write_exception = None
+
+ def registerProducer(self, producer, streaming):
+ """Part of IConsumer interface
+
+ Args:
+ producer (IProducer)
+ streaming (bool): True if push based producer, False if pull
+ based.
+ """
+ if self._producer:
+ raise Exception("registerProducer called twice")
+
+ self._producer = producer
+ self.streaming = streaming
+ self._finished_deferred = run_in_background(
+ threads.deferToThreadPool,
+ self._reactor,
+ self._reactor.getThreadPool(),
+ self._writer,
+ )
+ if not streaming:
+ self._producer.resumeProducing()
+
+ def unregisterProducer(self):
+ """Part of IProducer interface
+ """
+ self._producer = None
+ if not self._finished_deferred.called:
+ self._bytes_queue.put_nowait(None)
+
+ def write(self, bytes):
+ """Part of IProducer interface
+ """
+ if self._write_exception:
+ raise self._write_exception
+
+ if self._finished_deferred.called:
+ raise Exception("consumer has closed")
+
+ self._bytes_queue.put_nowait(bytes)
+
+ # If this is a PushProducer and the queue is getting behind
+ # then we pause the producer.
+ if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE:
+ self._paused_producer = True
+ self._producer.pauseProducing()
+
+ def _writer(self):
+ """This is run in a background thread to write to the file.
+ """
+ try:
+ while self._producer or not self._bytes_queue.empty():
+ # If we've paused the producer check if we should resume the
+ # producer.
+ if self._producer and self._paused_producer:
+ if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE:
+ self._reactor.callFromThread(self._resume_paused_producer)
+
+ bytes = self._bytes_queue.get()
+
+ # If we get None (the sentinel) or any other falsy value, skip
+ # the write; the loop condition above decides whether to stop.
+ if bytes:
+ self._file_obj.write(bytes)
+
+ # If it's a pull producer then we need to explicitly ask for
+ # more stuff.
+ if not self.streaming and self._producer:
+ self._reactor.callFromThread(self._producer.resumeProducing)
+ except Exception as e:
+ self._write_exception = e
+ raise
+ finally:
+ self._file_obj.close()
+
+ def wait(self):
+ """Returns a deferred that resolves when finished writing to file
+ """
+ return make_deferred_yieldable(self._finished_deferred)
+
+ def _resume_paused_producer(self):
+ """Gets called if we should resume producing after being paused
+ """
+ if self._paused_producer and self._producer:
+ self._paused_producer = False
+ self._producer.resumeProducing()
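[editor's note: a usage sketch for the new consumer; start_producing is a hypothetical function, assumed to hook an IPushProducer up to the consumer via consumer.registerProducer(producer, True).]

from twisted.internet import defer, reactor

@defer.inlineCallbacks
def stream_to_disk(start_producing, file_obj):
    consumer = BackgroundFileConsumer(file_obj, reactor)
    # the producer then calls consumer.write(data) as data arrives,
    # and consumer.unregisterProducer() when the stream ends.
    start_producing(consumer)
    yield consumer.wait()  # resolves once everything is flushed to disk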
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 6322f0f55c..581c6052ac 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -13,18 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from six import string_types
+
+from canonicaljson import json
from frozendict import frozendict
def freeze(o):
- t = type(o)
- if t is dict:
+ if isinstance(o, dict):
return frozendict({k: freeze(v) for k, v in o.items()})
- if t is frozendict:
+ if isinstance(o, frozendict):
return o
- if t is str or t is unicode:
+ if isinstance(o, string_types):
return o
try:
@@ -36,11 +38,10 @@ def freeze(o):
def unfreeze(o):
- t = type(o)
- if t is dict or t is frozendict:
+ if isinstance(o, (dict, frozendict)):
return dict({k: unfreeze(v) for k, v in o.items()})
- if t is str or t is unicode:
+ if isinstance(o, string_types):
return o
try:
@@ -49,3 +50,21 @@ def unfreeze(o):
pass
return o
+
+
+def _handle_frozendict(obj):
+ """Helper for EventEncoder. Makes frozendicts serializable by returning
+ the underlying dict
+ """
+ if type(obj) is frozendict:
+ # fishing the protected dict out of the object is a bit nasty,
+ # but we don't really want the overhead of copying the dict.
+ return obj._dict
+ raise TypeError('Object of type %s is not JSON serializable' %
+ obj.__class__.__name__)
+
+
+# A JSONEncoder which is capable of encoding frozendicts without barfing
+frozendict_json_encoder = json.JSONEncoder(
+ default=_handle_frozendict,
+)
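[editor's note: a quick round-trip sketch of freeze/unfreeze and the new encoder.]

from frozendict import frozendict

content = freeze({"body": "hello", "info": {"h": 32}})
assert isinstance(content, frozendict)

# the stock JSONEncoder would raise TypeError on a frozendict
serialized = frozendict_json_encoder.encode(content)

assert unfreeze(content) == {"body": "hello", "info": {"h": 32}}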
diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py
index 45be47159a..2d7ddc1cbe 100644
--- a/synapse/util/httpresourcetree.py
+++ b/synapse/util/httpresourcetree.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.web.resource import Resource
-
import logging
+from twisted.web.resource import NoResource
+
logger = logging.getLogger(__name__)
@@ -40,12 +40,15 @@ def create_resource_tree(desired_tree, root_resource):
# extra resources to existing nodes. See self._resource_id for the key.
resource_mappings = {}
for full_path, res in desired_tree.items():
+ # twisted requires all resources to be bytes
+ full_path = full_path.encode("utf-8")
+
logger.info("Attaching %s to path %s", res, full_path)
last_resource = root_resource
- for path_seg in full_path.split('/')[1:-1]:
+ for path_seg in full_path.split(b'/')[1:-1]:
if path_seg not in last_resource.listNames():
# resource doesn't exist, so make a "dummy resource"
- child_resource = Resource()
+ child_resource = NoResource()
last_resource.putChild(path_seg, child_resource)
res_id = _resource_id(last_resource, path_seg)
resource_mappings[res_id] = child_resource
@@ -57,7 +60,7 @@ def create_resource_tree(desired_tree, root_resource):
# ===========================
# now attach the actual desired resource
- last_path_seg = full_path.split('/')[-1]
+ last_path_seg = full_path.split(b'/')[-1]
# if there is already a resource here, thieve its children and
# replace it
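[editor's note: for illustration, a sketch of building a tree; the attached leaf resources are placeholders.]

from twisted.web.resource import Resource

root = Resource()
create_resource_tree(
    {
        "/_matrix/client/r0": Resource(),  # placeholder leaf resources
        "/_matrix/media/r0": Resource(),
    },
    root,
)
# intermediate segments (_matrix, client, ...) become NoResource
# placeholders which simply 404; the leaves serve the attached
# resources. Paths are encoded to bytes internally, as twisted
# requires.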
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 990216145e..8dcae50b39 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -22,10 +22,10 @@ them.
See doc/log_contexts.rst for details on how this works.
"""
-from twisted.internet import defer
-
-import threading
import logging
+import threading
+
+from twisted.internet import defer
logger = logging.getLogger(__name__)
@@ -42,23 +42,128 @@ try:
def get_thread_resource_usage():
return resource.getrusage(RUSAGE_THREAD)
-except:
+except Exception:
# If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we
# won't track resource usage by returning None.
def get_thread_resource_usage():
return None
+class ContextResourceUsage(object):
+ """Object for tracking the resources used by a log context
+
+ Attributes:
+ ru_utime (float): user CPU time (in seconds)
+ ru_stime (float): system CPU time (in seconds)
+ db_txn_count (int): number of database transactions done
+ db_sched_duration_sec (float): amount of time spent waiting for a
+ database connection
+ db_txn_duration_sec (float): amount of time spent doing database
+ transactions (excluding scheduling time)
+ evt_db_fetch_count (int): number of events requested from the database
+ """
+
+ __slots__ = [
+ "ru_stime", "ru_utime",
+ "db_txn_count", "db_txn_duration_sec", "db_sched_duration_sec",
+ "evt_db_fetch_count",
+ ]
+
+ def __init__(self, copy_from=None):
+ """Create a new ContextResourceUsage
+
+ Args:
+ copy_from (ContextResourceUsage|None): if not None, an object to
+ copy stats from
+ """
+ if copy_from is None:
+ self.reset()
+ else:
+ self.ru_utime = copy_from.ru_utime
+ self.ru_stime = copy_from.ru_stime
+ self.db_txn_count = copy_from.db_txn_count
+
+ self.db_txn_duration_sec = copy_from.db_txn_duration_sec
+ self.db_sched_duration_sec = copy_from.db_sched_duration_sec
+ self.evt_db_fetch_count = copy_from.evt_db_fetch_count
+
+ def copy(self):
+ return ContextResourceUsage(copy_from=self)
+
+ def reset(self):
+ self.ru_stime = 0.
+ self.ru_utime = 0.
+ self.db_txn_count = 0
+
+ self.db_txn_duration_sec = 0
+ self.db_sched_duration_sec = 0
+ self.evt_db_fetch_count = 0
+
+ def __repr__(self):
+ return ("<ContextResourceUsage ru_stime='%r', ru_utime='%r', "
+ "db_txn_count='%r', db_txn_duration_sec='%r', "
+ "db_sched_duration_sec='%r', evt_db_fetch_count='%r'>") % (
+ self.ru_stime,
+ self.ru_utime,
+ self.db_txn_count,
+ self.db_txn_duration_sec,
+ self.db_sched_duration_sec,
+ self.evt_db_fetch_count,)
+
+ def __iadd__(self, other):
+ """Add another ContextResourceUsage's stats to this one's.
+
+ Args:
+ other (ContextResourceUsage): the other resource usage object
+ """
+ self.ru_utime += other.ru_utime
+ self.ru_stime += other.ru_stime
+ self.db_txn_count += other.db_txn_count
+ self.db_txn_duration_sec += other.db_txn_duration_sec
+ self.db_sched_duration_sec += other.db_sched_duration_sec
+ self.evt_db_fetch_count += other.evt_db_fetch_count
+ return self
+
+ def __isub__(self, other):
+ self.ru_utime -= other.ru_utime
+ self.ru_stime -= other.ru_stime
+ self.db_txn_count -= other.db_txn_count
+ self.db_txn_duration_sec -= other.db_txn_duration_sec
+ self.db_sched_duration_sec -= other.db_sched_duration_sec
+ self.evt_db_fetch_count -= other.evt_db_fetch_count
+ return self
+
+ def __add__(self, other):
+ res = ContextResourceUsage(copy_from=self)
+ res += other
+ return res
+
+ def __sub__(self, other):
+ res = ContextResourceUsage(copy_from=self)
+ res -= other
+ return res
+
+
class LoggingContext(object):
"""Additional context for log formatting. Contexts are scoped within a
"with" block.
+
+ If a parent is given when creating a new context, then:
+ - logging fields are copied from the parent to the new context on entry
+ - when the new context exits, the cpu usage stats are copied from the
+ child to the parent
+
Args:
name (str): Name for the context for debugging.
+ parent_context (LoggingContext|None): The parent of the new context
"""
__slots__ = [
- "previous_context", "name", "usage_start", "usage_end", "main_thread",
- "__dict__", "tag", "alive",
+ "previous_context", "name", "parent_context",
+ "_resource_usage",
+ "usage_start",
+ "main_thread", "alive",
+ "request", "tag",
]
thread_local = threading.local()
@@ -80,32 +185,49 @@ class LoggingContext(object):
def stop(self):
pass
- def add_database_transaction(self, duration_ms):
+ def add_database_transaction(self, duration_sec):
+ pass
+
+ def add_database_scheduled(self, sched_sec):
+ pass
+
+ def record_event_fetch(self, event_count):
pass
def __nonzero__(self):
return False
+ __bool__ = __nonzero__ # python3
sentinel = Sentinel()
- def __init__(self, name=None):
+ def __init__(self, name=None, parent_context=None):
self.previous_context = LoggingContext.current_context()
self.name = name
- self.ru_stime = 0.
- self.ru_utime = 0.
- self.db_txn_count = 0
- self.db_txn_duration = 0.
+
+ # track the resources used by this context so far
+ self._resource_usage = ContextResourceUsage()
+
+ # If alive has the thread resource usage when the logcontext last
+ # became active.
self.usage_start = None
+
self.main_thread = threading.current_thread()
+ self.request = None
self.tag = ""
self.alive = True
+ self.parent_context = parent_context
+
def __str__(self):
return "%s@%x" % (self.name, id(self))
@classmethod
def current_context(cls):
- """Get the current logging context from thread local storage"""
+ """Get the current logging context from thread local storage
+
+ Returns:
+ LoggingContext: the current logging context
+ """
return getattr(cls.thread_local, "current_context", cls.sentinel)
@classmethod
@@ -133,18 +255,22 @@ class LoggingContext(object):
self.previous_context, old_context
)
self.alive = True
+
+ if self.parent_context is not None:
+ self.parent_context.copy_to(self)
+
return self
def __exit__(self, type, value, traceback):
"""Restore the logging context in thread local storage to the state it
was before this context was entered.
Returns:
- None to avoid suppressing any exeptions that were thrown.
+ None to avoid suppressing any exceptions that were thrown.
"""
current = self.set_current_context(self.previous_context)
if current is not self:
if current is self.sentinel:
- logger.debug("Expected logging context %s has been lost", self)
+ logger.warn("Expected logging context %s has been lost", self)
else:
logger.warn(
"Current logging context %s is not expected context %s",
@@ -154,47 +280,91 @@ class LoggingContext(object):
self.previous_context = None
self.alive = False
+ # if we have a parent, pass our CPU usage stats on
+ if self.parent_context is not None:
+ self.parent_context._resource_usage += self._resource_usage
+
+ # reset them in case we get entered again
+ self._resource_usage.reset()
+
def copy_to(self, record):
- """Copy fields from this context to the record"""
- for key, value in self.__dict__.items():
- setattr(record, key, value)
+ """Copy logging fields from this context to a log record or
+ another LoggingContext
+ """
- record.ru_utime, record.ru_stime = self.get_resource_usage()
+ # 'request' is the only field we currently use in the logger, so that's
+ # all we need to copy
+ record.request = self.request
def start(self):
if threading.current_thread() is not self.main_thread:
+ logger.warning("Started logcontext %s on different thread", self)
return
- if self.usage_start and self.usage_end:
- self.ru_utime += self.usage_end.ru_utime - self.usage_start.ru_utime
- self.ru_stime += self.usage_end.ru_stime - self.usage_start.ru_stime
- self.usage_start = None
- self.usage_end = None
-
+ # If we haven't already started record the thread resource usage so
+ # far
if not self.usage_start:
self.usage_start = get_thread_resource_usage()
def stop(self):
if threading.current_thread() is not self.main_thread:
+ logger.warning("Stopped logcontext %s on different thread", self)
+ return
+
+ # When we stop, let's record the cpu used since we started
+ if not self.usage_start:
+ logger.warning(
+ "Called stop on logcontext %s without calling start", self,
+ )
return
- if self.usage_start:
- self.usage_end = get_thread_resource_usage()
+ usage_end = get_thread_resource_usage()
+
+ self._resource_usage.ru_utime += usage_end.ru_utime - self.usage_start.ru_utime
+ self._resource_usage.ru_stime += usage_end.ru_stime - self.usage_start.ru_stime
+
+ self.usage_start = None
def get_resource_usage(self):
- ru_utime = self.ru_utime
- ru_stime = self.ru_stime
+ """Get resources used by this logcontext so far.
- if self.usage_start and threading.current_thread() is self.main_thread:
+ Returns:
+ ContextResourceUsage: a *copy* of the object tracking resource
+ usage so far
+ """
+ # we always return a copy, for consistency
+ res = self._resource_usage.copy()
+
+ # If we are on the correct thread and we're currently running then we
+ # can include resource usage so far.
+ is_main_thread = threading.current_thread() is self.main_thread
+ if self.alive and self.usage_start and is_main_thread:
current = get_thread_resource_usage()
- ru_utime += current.ru_utime - self.usage_start.ru_utime
- ru_stime += current.ru_stime - self.usage_start.ru_stime
+ res.ru_utime += current.ru_utime - self.usage_start.ru_utime
+ res.ru_stime += current.ru_stime - self.usage_start.ru_stime
- return ru_utime, ru_stime
+ return res
- def add_database_transaction(self, duration_ms):
- self.db_txn_count += 1
- self.db_txn_duration += duration_ms / 1000.
+ def add_database_transaction(self, duration_sec):
+ self._resource_usage.db_txn_count += 1
+ self._resource_usage.db_txn_duration_sec += duration_sec
+
+ def add_database_scheduled(self, sched_sec):
+ """Record a use of the database pool
+
+ Args:
+ sched_sec (float): number of seconds it took us to get a
+ connection
+ """
+ self._resource_usage.db_sched_duration_sec += sched_sec
+
+ def record_event_fetch(self, event_count):
+ """Record a number of events being fetched from the db
+
+ Args:
+ event_count (int): number of events being fetched
+ """
+ self._resource_usage.evt_db_fetch_count += event_count
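[editor's note: since get_resource_usage() returns a copy and ContextResourceUsage supports subtraction, measuring a span of work becomes a simple delta. A sketch, assuming we are inside a real LoggingContext; do_expensive_work is hypothetical.]

context = LoggingContext.current_context()

before = context.get_resource_usage()   # a copy of the stats so far
do_expensive_work()                     # hypothetical synchronous work
delta = context.get_resource_usage() - before

logger.info(
    "used %f s user CPU, %d db transactions",
    delta.ru_utime, delta.db_txn_count,
)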
class LoggingContextFilter(logging.Filter):
@@ -248,7 +418,7 @@ class PreserveLoggingContext(object):
context = LoggingContext.set_current_context(self.current_context)
if context != self.new_context:
- logger.debug(
+ logger.warn(
"Unexpected logging context: %s is not %s",
context, self.new_context,
)
@@ -261,105 +431,62 @@ class PreserveLoggingContext(object):
)
-class _PreservingContextDeferred(defer.Deferred):
- """A deferred that ensures that all callbacks and errbacks are called with
- the given logging context.
- """
- def __init__(self, context):
- self._log_context = context
- defer.Deferred.__init__(self)
-
- def addCallbacks(self, callback, errback=None,
- callbackArgs=None, callbackKeywords=None,
- errbackArgs=None, errbackKeywords=None):
- callback = self._wrap_callback(callback)
- errback = self._wrap_callback(errback)
- return defer.Deferred.addCallbacks(
- self, callback,
- errback=errback,
- callbackArgs=callbackArgs,
- callbackKeywords=callbackKeywords,
- errbackArgs=errbackArgs,
- errbackKeywords=errbackKeywords,
- )
+def preserve_fn(f):
+ """Function decorator which wraps the function with run_in_background"""
+ def g(*args, **kwargs):
+ return run_in_background(f, *args, **kwargs)
+ return g
- def _wrap_callback(self, f):
- def g(res, *args, **kwargs):
- with PreserveLoggingContext(self._log_context):
- res = f(res, *args, **kwargs)
- return res
- return g
+def run_in_background(f, *args, **kwargs):
+ """Calls a function, ensuring that the current context is restored after
+ return from the function, and that the sentinel context is set once the
+ deferred returned by the function completes.
-def preserve_context_over_fn(fn, *args, **kwargs):
- """Takes a function and invokes it with the given arguments, but removes
- and restores the current logging context while doing so.
+ Useful for wrapping functions that return a deferred which you don't yield
+ on (for instance because you want to pass it to deferred.gatherResults()).
- If the result is a deferred, call preserve_context_over_deferred before
- returning it.
+ Note that if you completely discard the result, you should make sure that
+ `f` doesn't raise any deferred exceptions, otherwise a scary-looking
+ CRITICAL error about an unhandled error will be logged without much
+ indication about where it came from.
"""
- with PreserveLoggingContext():
- res = fn(*args, **kwargs)
+ current = LoggingContext.current_context()
+ try:
+ res = f(*args, **kwargs)
+ except: # noqa: E722
+ # the assumption here is that the caller doesn't want to be disturbed
+ # by synchronous exceptions, so let's turn them into Failures.
+ return defer.fail()
- if isinstance(res, defer.Deferred):
- return preserve_context_over_deferred(res)
- else:
+ if not isinstance(res, defer.Deferred):
return res
-
-def preserve_context_over_deferred(deferred, context=None):
- """Given a deferred wrap it such that any callbacks added later to it will
- be invoked with the current context.
-
- Deprecated: this almost certainly doesn't do want you want, ie make
- the deferred follow the synapse logcontext rules: try
- ``make_deferred_yieldable`` instead.
- """
- if context is None:
- context = LoggingContext.current_context()
- d = _PreservingContextDeferred(context)
- deferred.chainDeferred(d)
- return d
-
-
-def preserve_fn(f):
- """Wraps a function, to ensure that the current context is restored after
- return from the function, and that the sentinel context is set once the
- deferred returned by the funtion completes.
-
- Useful for wrapping functions that return a deferred which you don't yield
- on.
- """
- def reset_context(result):
- LoggingContext.set_current_context(LoggingContext.sentinel)
- return result
-
- def g(*args, **kwargs):
- current = LoggingContext.current_context()
- res = f(*args, **kwargs)
- if isinstance(res, defer.Deferred) and not res.called:
- # The function will have reset the context before returning, so
- # we need to restore it now.
- LoggingContext.set_current_context(current)
-
- # The original context will be restored when the deferred
- # completes, but there is nothing waiting for it, so it will
- # get leaked into the reactor or some other function which
- # wasn't expecting it. We therefore need to reset the context
- # here.
- #
- # (If this feels asymmetric, consider it this way: we are
- # effectively forking a new thread of execution. We are
- # probably currently within a ``with LoggingContext()`` block,
- # which is supposed to have a single entry and exit point. But
- # by spawning off another deferred, we are effectively
- # adding a new exit point.)
- res.addBoth(reset_context)
+ if res.called and not res.paused:
+ # The function should have maintained the logcontext, so we can
+ # optimise out the messing about
return res
- return g
+
+ # The function may have reset the context before returning, so
+ # we need to restore it now.
+ ctx = LoggingContext.set_current_context(current)
+
+ # The original context will be restored when the deferred
+ # completes, but there is nothing waiting for it, so it will
+ # get leaked into the reactor or some other function which
+ # wasn't expecting it. We therefore need to reset the context
+ # here.
+ #
+ # (If this feels asymmetric, consider it this way: we are
+ # effectively forking a new thread of execution. We are
+ # probably currently within a ``with LoggingContext()`` block,
+ # which is supposed to have a single entry and exit point. But
+ # by spawning off another deferred, we are effectively
+ # adding a new exit point.)
+ res.addBoth(_set_context_cb, ctx)
+ return res
-@defer.inlineCallbacks
def make_deferred_yieldable(deferred):
"""Given a deferred, make it follow the Synapse logcontext rules:
@@ -371,11 +498,27 @@ def make_deferred_yieldable(deferred):
returning a deferred. Then, when the deferred completes, restores the
current logcontext before running callbacks/errbacks.
- (This is more-or-less the opposite operation to preserve_fn.)
+ (This is more-or-less the opposite operation to run_in_background.)
"""
- with PreserveLoggingContext():
- r = yield deferred
- defer.returnValue(r)
+ if not isinstance(deferred, defer.Deferred):
+ return deferred
+
+ if deferred.called and not deferred.paused:
+ # it looks like this deferred is ready to run any callbacks we give it
+ # immediately. We may as well optimise out the logcontext faffery.
+ return deferred
+
+ # ok, we can't be sure that a yield won't block, so let's reset the
+ # logcontext, and add a callback to the deferred to restore it.
+ prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
+ deferred.addBoth(_set_context_cb, prev_context)
+ return deferred
+
+
+def _set_context_cb(result, context):
+ """A callback function which just sets the logging context"""
+ LoggingContext.set_current_context(context)
+ return result
# modules to ignore in `logcontext_tracer`
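[editor's note: the canonical pairing of the two helpers, for reference; lookup is a hypothetical deferred-returning function.]

from twisted.internet import defer

@defer.inlineCallbacks
def do_lookups(keys):
    # fire the lookups off in parallel; run_in_background restores our
    # logcontext immediately and resets to the sentinel when each
    # deferred later completes.
    deferreds = [run_in_background(lookup, k) for k in keys]

    # make_deferred_yieldable lets us yield on the combined result
    # without leaking our logcontext into the reactor.
    results = yield make_deferred_yieldable(
        defer.gatherResults(deferreds, consumeErrors=True),
    )
    defer.returnValue(results)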
diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py
new file mode 100644
index 0000000000..a46bc47ce3
--- /dev/null
+++ b/synapse/util/logformatter.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import traceback
+
+from six import StringIO
+
+
+class LogFormatter(logging.Formatter):
+ """Log formatter which gives more detail for exceptions
+
+ This is the same as the standard log formatter, except that when logging
+ exceptions [typically via log.foo("msg", exc_info=1)], it prints the
+ sequence that led up to the point at which the exception was caught.
+ (Normally only stack frames between the point the exception was raised and
+ where it was caught are logged).
+ """
+ def __init__(self, *args, **kwargs):
+ super(LogFormatter, self).__init__(*args, **kwargs)
+
+ def formatException(self, ei):
+ sio = StringIO()
+ (typ, val, tb) = ei
+
+ # log the stack above the exception capture point if possible, but
+ # check that we actually have an f_back attribute to work around
+ # https://twistedmatrix.com/trac/ticket/9305
+
+ if tb and hasattr(tb.tb_frame, 'f_back'):
+ sio.write("Capture point (most recent call last):\n")
+ traceback.print_stack(tb.tb_frame.f_back, None, sio)
+
+ traceback.print_exception(typ, val, tb, None, sio)
+ s = sio.getvalue()
+ sio.close()
+ if s[-1:] == "\n":
+ s = s[:-1]
+ return s
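For illustration, the formatter drops into the standard logging machinery like any other `logging.Formatter`; a sketch (the handler choice and format string are arbitrary):

    import logging

    from synapse.util.logformatter import LogFormatter

    handler = logging.StreamHandler()
    handler.setFormatter(LogFormatter("%(asctime)s - %(levelname)s - %(message)s"))
    logging.getLogger().addHandler(handler)

    def inner():
        raise ValueError("boom")

    try:
        inner()
    except ValueError:
        # in addition to the usual raise-point-to-catch-point traceback, the
        # output gains a "Capture point (most recent call last):" section
        # listing the frames above this except block.
        logging.exception("request failed")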
diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py
index 3a83828d25..62a00189cc 100644
--- a/synapse/util/logutils.py
+++ b/synapse/util/logutils.py
@@ -14,13 +14,11 @@
# limitations under the License.
-from inspect import getcallargs
-from functools import wraps
-
-import logging
import inspect
+import logging
import time
-
+from functools import wraps
+from inspect import getcallargs
_TIME_FUNC_ID = 0
@@ -96,7 +94,7 @@ def time_function(f):
id = _TIME_FUNC_ID
_TIME_FUNC_ID += 1
- start = time.clock() * 1000
+ start = time.clock()
try:
_log_debug_as_f(
@@ -107,10 +105,10 @@ def time_function(f):
r = f(*args, **kwargs)
finally:
- end = time.clock() * 1000
+ end = time.clock()
_log_debug_as_f(
f,
- "[FUNC END] {%s-%d} %f",
+ "[FUNC END] {%s-%d} %.3f sec",
(func_name, id, end - start,),
)
diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py
index 97e0f00b67..14be3c7396 100644
--- a/synapse/util/manhole.py
+++ b/synapse/util/manhole.py
@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.conch.manhole import ColoredManhole
-from twisted.conch.insults import insults
from twisted.conch import manhole_ssh
-from twisted.cred import checkers, portal
+from twisted.conch.insults import insults
+from twisted.conch.manhole import ColoredManhole
from twisted.conch.ssh.keys import Key
+from twisted.cred import checkers, portal
PUBLIC_KEY = (
"ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAGEArzJx8OYOnJmzf4tfBEvLi8DVPrJ3/c9k2I/Az"
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 4ea930d3e8..97f1267380 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -13,40 +13,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+import logging
+from functools import wraps
-from synapse.util.logcontext import LoggingContext
-import synapse.metrics
+from prometheus_client import Counter
-from functools import wraps
-import logging
+from twisted.internet import defer
+from synapse.util.logcontext import LoggingContext
logger = logging.getLogger(__name__)
+block_counter = Counter("synapse_util_metrics_block_count", "", ["block_name"])
-metrics = synapse.metrics.get_metrics_for(__name__)
+block_timer = Counter("synapse_util_metrics_block_time_seconds", "", ["block_name"])
-block_timer = metrics.register_distribution(
- "block_timer",
- labels=["block_name"]
-)
+block_ru_utime = Counter(
+ "synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"])
-block_ru_utime = metrics.register_distribution(
- "block_ru_utime", labels=["block_name"]
-)
+block_ru_stime = Counter(
+ "synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"])
-block_ru_stime = metrics.register_distribution(
- "block_ru_stime", labels=["block_name"]
-)
+block_db_txn_count = Counter(
+ "synapse_util_metrics_block_db_txn_count", "", ["block_name"])
-block_db_txn_count = metrics.register_distribution(
- "block_db_txn_count", labels=["block_name"]
-)
+# seconds spent waiting for db txns, excluding scheduling time, in this block
+block_db_txn_duration = Counter(
+ "synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"])
-block_db_txn_duration = metrics.register_distribution(
- "block_db_txn_duration", labels=["block_name"]
-)
+# seconds spent waiting for a db connection, in this block
+block_db_sched_duration = Counter(
+ "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"])
def measure_func(name):
@@ -63,8 +60,9 @@ def measure_func(name):
class Measure(object):
__slots__ = [
- "clock", "name", "start_context", "start", "new_context", "ru_utime",
- "ru_stime", "db_txn_count", "db_txn_duration", "created_context"
+ "clock", "name", "start_context", "start",
+ "created_context",
+ "start_usage",
]
def __init__(self, clock, name):
@@ -75,23 +73,23 @@ class Measure(object):
self.created_context = False
def __enter__(self):
- self.start = self.clock.time_msec()
+ self.start = self.clock.time()
self.start_context = LoggingContext.current_context()
if not self.start_context:
self.start_context = LoggingContext("Measure")
self.start_context.__enter__()
self.created_context = True
- self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
- self.db_txn_count = self.start_context.db_txn_count
- self.db_txn_duration = self.start_context.db_txn_duration
+ self.start_usage = self.start_context.get_resource_usage()
def __exit__(self, exc_type, exc_val, exc_tb):
if isinstance(exc_type, Exception) or not self.start_context:
return
- duration = self.clock.time_msec() - self.start
- block_timer.inc_by(duration, self.name)
+ duration = self.clock.time() - self.start
+
+ block_counter.labels(self.name).inc()
+ block_timer.labels(self.name).inc(duration)
context = LoggingContext.current_context()
@@ -106,16 +104,19 @@ class Measure(object):
logger.warn("Expected context. (%r)", self.name)
return
- ru_utime, ru_stime = context.get_resource_usage()
-
- block_ru_utime.inc_by(ru_utime - self.ru_utime, self.name)
- block_ru_stime.inc_by(ru_stime - self.ru_stime, self.name)
- block_db_txn_count.inc_by(
- context.db_txn_count - self.db_txn_count, self.name
- )
- block_db_txn_duration.inc_by(
- context.db_txn_duration - self.db_txn_duration, self.name
- )
+ current = context.get_resource_usage()
+ usage = current - self.start_usage
+ try:
+ block_ru_utime.labels(self.name).inc(usage.ru_utime)
+ block_ru_stime.labels(self.name).inc(usage.ru_stime)
+ block_db_txn_count.labels(self.name).inc(usage.db_txn_count)
+ block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec)
+ block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec)
+ except ValueError:
+ logger.warn(
+ "Failed to save metrics! OLD: %r, NEW: %r",
+ self.start_usage, current
+ )
if self.created_context:
self.start_context.__exit__(exc_type, exc_val, exc_tb)
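The underlying API change here is from the old distribution objects' `inc_by(delta, label)` to prometheus_client's `labels(label).inc(delta)`. A standalone sketch of the new pattern (the metric name is illustrative, not one of the ones above):

    from prometheus_client import Counter

    # Counter(name, documentation, labelnames)
    block_timer = Counter(
        "example_block_time_seconds", "time spent in blocks", ["block_name"],
    )

    # old style:  block_timer.inc_by(duration, block_name)
    # new style:
    block_timer.labels("my_block").inc(0.25)  # add 0.25s to the labelled counter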
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
new file mode 100644
index 0000000000..4288312b8a
--- /dev/null
+++ b/synapse/util/module_loader.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+
+from synapse.config._base import ConfigError
+
+
+def load_module(provider):
+ """ Loads a module with its config
+ Takes a dict with keys 'module' (the module name) and 'config'
+ (the config dict).
+
+ Returns:
+ Tuple of (provider class, parsed config object)
+ """
+ # We need to import the module, and then pick the class out of
+ # that, so we split based on the last dot.
+ module, clz = provider['module'].rsplit(".", 1)
+ module = importlib.import_module(module)
+ provider_class = getattr(module, clz)
+
+ try:
+ provider_config = provider_class.parse_config(provider["config"])
+ except Exception as e:
+ raise ConfigError(
+ "Failed to parse config for %r: %r" % (provider['module'], e)
+ )
+
+ return provider_class, provider_config
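A usage sketch (the module path and config are purely illustrative; any class exposing a `parse_config` method satisfies the contract, and how the class is then constructed is up to the caller):

    from synapse.util.module_loader import load_module

    provider = {
        "module": "my_package.providers.ExampleProvider",  # hypothetical path
        "config": {"endpoint": "https://example.com"},
    }

    provider_class, provider_config = load_module(provider)
    instance = provider_class(config=provider_config)  # constructor is caller-defined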
diff --git a/synapse/util/msisdn.py b/synapse/util/msisdn.py
index 607161e7f0..a6c30e5265 100644
--- a/synapse/util/msisdn.py
+++ b/synapse/util/msisdn.py
@@ -14,6 +14,7 @@
# limitations under the License.
import phonenumbers
+
from synapse.api.errors import SynapseError
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 1101881a2d..7deb38f2a7 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -13,17 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
-from synapse.api.errors import LimitExceededError
-
-from synapse.util.async import sleep
-from synapse.util.logcontext import preserve_fn
-
import collections
import contextlib
import logging
+from twisted.internet import defer
+
+from synapse.api.errors import LimitExceededError
+from synapse.util.logcontext import (
+ PreserveLoggingContext,
+ make_deferred_yieldable,
+ run_in_background,
+)
logger = logging.getLogger(__name__)
@@ -91,13 +92,22 @@ class _PerHostRatelimiter(object):
self.window_size = window_size
self.sleep_limit = sleep_limit
- self.sleep_msec = sleep_msec
+ self.sleep_sec = sleep_msec / 1000.0
self.reject_limit = reject_limit
self.concurrent_requests = concurrent_requests
+ # request_id objects for requests which have been slept
self.sleeping_requests = set()
+
+ # map from request_id object to Deferred for requests which are ready
+ # for processing but have been queued
self.ready_request_queue = collections.OrderedDict()
+
+ # request id objects for requests which are in progress
self.current_processing = set()
+
+ # times at which we have recently (within the last window_size ms)
+ # received requests.
self.request_times = []
@contextlib.contextmanager
@@ -116,11 +126,15 @@ class _PerHostRatelimiter(object):
def _on_enter(self, request_id):
time_now = self.clock.time_msec()
+
+ # remove any entries from request_times which aren't within the window
self.request_times[:] = [
r for r in self.request_times
if time_now - r < self.window_size
]
+ # reject the request if we already have too many queued up (either
+ # sleeping or in the ready queue).
queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
if queue_size > self.reject_limit:
raise LimitExceededError(
@@ -133,9 +147,13 @@ class _PerHostRatelimiter(object):
def queue_request():
if len(self.current_processing) > self.concurrent_requests:
- logger.debug("Ratelimit [%s]: Queue req", id(request_id))
queue_defer = defer.Deferred()
self.ready_request_queue[request_id] = queue_defer
+ logger.info(
+ "Ratelimiter: queueing request (queue now %i items)",
+ len(self.ready_request_queue),
+ )
+
return queue_defer
else:
return defer.succeed(None)
@@ -147,10 +165,9 @@ class _PerHostRatelimiter(object):
if len(self.request_times) > self.sleep_limit:
logger.debug(
- "Ratelimit [%s]: sleeping req",
- id(request_id),
+ "Ratelimiter: sleeping request for %f sec", self.sleep_sec,
)
- ret_defer = preserve_fn(sleep)(self.sleep_msec / 1000.0)
+ ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)
self.sleeping_requests.add(request_id)
@@ -176,6 +193,9 @@ class _PerHostRatelimiter(object):
return r
def on_err(r):
+ # XXX: why is this necessary? this is called before we start
+ # processing the request so why would the request be in
+ # current_processing?
self.current_processing.discard(request_id)
return r
@@ -187,7 +207,7 @@ class _PerHostRatelimiter(object):
ret_defer.addCallbacks(on_start, on_err)
ret_defer.addBoth(on_both)
- return ret_defer
+ return make_deferred_yieldable(ret_defer)
def _on_exit(self, request_id):
logger.debug(
@@ -196,8 +216,10 @@ class _PerHostRatelimiter(object):
)
self.current_processing.discard(request_id)
try:
- request_id, deferred = self.ready_request_queue.popitem()
- self.current_processing.add(request_id)
- deferred.callback(None)
+ # start processing the next item on the queue.
+ _, deferred = self.ready_request_queue.popitem(last=False)
+
+ with PreserveLoggingContext():
+ deferred.callback(None)
except KeyError:
pass
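The sliding-window bookkeeping in `_on_enter` is worth spelling out on its own. A standalone sketch of the same check (times in milliseconds; thresholds illustrative):

    import time

    window_size = 1000   # ms
    sleep_limit = 10     # requests per window before we start sleeping them
    request_times = []   # timestamps (ms) of recent requests

    def should_throttle():
        now = int(time.time() * 1000)
        # drop timestamps that have fallen out of the window
        request_times[:] = [t for t in request_times if now - t < window_size]
        request_times.append(now)
        # more than sleep_limit requests in the last window_size ms -> throttle
        return len(request_times) > sleep_limit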
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 4fa9d1a03c..8a3a06fd74 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -12,20 +12,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse.util.logcontext
-from twisted.internet import defer
-
-from synapse.api.errors import CodeMessageException
-
import logging
import random
+from twisted.internet import defer
+
+import synapse.util.logcontext
+from synapse.api.errors import CodeMessageException
logger = logging.getLogger(__name__)
class NotRetryingDestination(Exception):
def __init__(self, retry_last_ts, retry_interval, destination):
+ """Raised by the limiter (and federation client) to indicate that we are
+ are deliberately not attempting to contact a given server.
+
+ Args:
+ retry_last_ts (int): the unix ts in milliseconds of our last attempt
+ to contact the server. 0 indicates that the last attempt was
+ successful or that we've never actually attempted to connect.
+ retry_interval (int): the time in milliseconds to wait until the next
+ attempt.
+ destination (str): the domain in question
+ """
+
msg = "Not retrying server %s." % (destination,)
super(NotRetryingDestination, self).__init__(msg)
@@ -189,10 +200,10 @@ class RetryDestinationLimiter(object):
yield self.store.set_destination_retry_timings(
self.destination, retry_last_ts, self.retry_interval
)
- except:
+ except Exception:
logger.exception(
- "Failed to store set_destination_retry_timings",
+ "Failed to store destination_retry_timings",
)
# we deliberately do this in the background.
- synapse.util.logcontext.preserve_fn(store_retry_timings)()
+ synapse.util.logcontext.run_in_background(store_retry_timings)
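A sketch of how callers typically react to the exception (assuming, as the docstring implies, that the constructor also stores its arguments as attributes; `client.make_request` is a stand-in for any federation call):

    import logging

    from twisted.internet import defer

    from synapse.util.retryutils import NotRetryingDestination

    logger = logging.getLogger(__name__)

    @defer.inlineCallbacks
    def send(client, destination):
        try:
            response = yield client.make_request(destination)
            defer.returnValue(response)
        except NotRetryingDestination as e:
            # the limiter is deliberately refusing this server for now;
            # e.retry_interval (ms) says how long until the next attempt
            logger.debug("Not retrying %s for %dms", destination, e.retry_interval)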
diff --git a/synapse/util/rlimit.py b/synapse/util/rlimit.py
index f4a9abf83f..6c0f2bb0cf 100644
--- a/synapse/util/rlimit.py
+++ b/synapse/util/rlimit.py
@@ -13,9 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import resource
import logging
-
+import resource
logger = logging.getLogger("synapse.app.homeserver")
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 95a6168e16..43d9db67ec 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -16,18 +16,20 @@
import random
import string
+from six.moves import range
+
_string_with_symbols = (
string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
)
def random_string(length):
- return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))
+ return ''.join(random.choice(string.ascii_letters) for _ in range(length))
def random_string_with_symbols(length):
return ''.join(
- random.choice(_string_with_symbols) for _ in xrange(length)
+ random.choice(_string_with_symbols) for _ in range(length)
)
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
new file mode 100644
index 0000000000..75efa0117b
--- /dev/null
+++ b/synapse/util/threepids.py
@@ -0,0 +1,48 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+
+logger = logging.getLogger(__name__)
+
+
+def check_3pid_allowed(hs, medium, address):
+ """Checks whether a given format of 3PID is allowed to be used on this HS
+
+ Args:
+ hs (synapse.server.HomeServer): server
+ medium (str): 3pid medium - e.g. email, msisdn
+ address (str): address within that medium (e.g. "wotan@matrix.org")
+ msisdns need to first have been canonicalised
+ Returns:
+ bool: whether the 3PID medium/address is allowed to be added to this HS
+ """
+
+ if hs.config.allowed_local_3pids:
+ for constraint in hs.config.allowed_local_3pids:
+ logger.debug(
+ "Checking 3PID %s (%s) against %s (%s)",
+ address, medium, constraint['pattern'], constraint['medium'],
+ )
+ if (
+ medium == constraint['medium'] and
+ re.match(constraint['pattern'], address)
+ ):
+ return True
+ else:
+ return True
+
+ return False
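For example, with configuration along these lines (values illustrative), only matching 3PIDs are accepted; note that `re.match` anchors at the start of the address, so patterns are effectively prefix matches unless they end in `$`:

    # allowed_local_3pids:
    #   - medium: email
    #     pattern: '.*@vector\.im'
    #   - medium: msisdn
    #     pattern: '\+44'

    check_3pid_allowed(hs, "email", "alice@vector.im")    # True: pattern matches
    check_3pid_allowed(hs, "msisdn", "+15551234567")      # False: wrong prefix
    check_3pid_allowed(hs, "email", "alice@example.com")  # False: no constraint matches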
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index 52086df465..1fbcd41115 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -14,9 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import subprocess
-import os
import logging
+import os
+import subprocess
logger = logging.getLogger(__name__)
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index 7412fc57a4..7a9e45aca9 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from six.moves import range
+
class _Entry(object):
__slots__ = ["end_key", "queue"]
@@ -68,7 +70,7 @@ class WheelTimer(object):
# Add empty entries between the end of the current list and when we want
# to insert. This ensures there are no gaps.
self.entries.extend(
- _Entry(key) for key in xrange(last_key, then_key + 1)
+ _Entry(key) for key in range(last_key, then_key + 1)
)
self.entries[-1].queue.append(obj)
@@ -91,7 +93,4 @@ class WheelTimer(object):
return ret
def __len__(self):
- l = 0
- for entry in self.entries:
- l += len(entry.queue)
- return l
+ return sum(len(entry.queue) for entry in self.entries)
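The gap-filling in `insert` keeps the bucket list contiguous, so `fetch` can simply pop buckets off the front. A usage sketch (deadlines are rounded up to the next bucket boundary):

    from synapse.util.wheel_timer import WheelTimer

    timer = WheelTimer(bucket_size=100)
    timer.insert(now=0, obj="a", then=150)  # due by t=150
    timer.insert(now=0, obj="b", then=250)  # due by t=250

    due = timer.fetch(now=300)  # -> ["a", "b"]: both deadline buckets have passed
    len(timer)                  # -> 0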
diff --git a/synapse/visibility.py b/synapse/visibility.py
index c4dd9ae2c7..d4680863d3 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -13,14 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
-from synapse.api.constants import Membership, EventTypes
+import logging
+import operator
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from six import iteritems, itervalues
+from six.moves import map
-import logging
+from twisted.internet import defer
+from synapse.api.constants import EventTypes, Membership
+from synapse.events.utils import prune_event
+from synapse.types import get_domain_from_id
logger = logging.getLogger(__name__)
@@ -43,53 +46,66 @@ MEMBERSHIP_PRIORITY = (
@defer.inlineCallbacks
-def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
- """ Returns dict of user_id -> list of events that user is allowed to
- see.
+def filter_events_for_client(store, user_id, events, is_peeking=False,
+ always_include_ids=frozenset()):
+ """
+ Check which events a user is allowed to see
Args:
- user_tuples (str, bool): (user id, is_peeking) for each user to be
- checked. is_peeking should be true if:
- * the user is not currently a member of the room, and:
- * the user has not been a member of the room since the
- given events
- events ([synapse.events.EventBase]): list of events to filter
- """
- forgotten = yield preserve_context_over_deferred(defer.gatherResults([
- defer.maybeDeferred(
- preserve_fn(store.who_forgot_in_room),
- room_id,
- )
- for room_id in frozenset(e.room_id for e in events)
- ], consumeErrors=True))
+ store (synapse.storage.DataStore): our datastore (can also be a worker
+ store)
+ user_id(str): user id to be checked
+ events(list[synapse.events.EventBase]): sequence of events to be checked
+ is_peeking(bool): should be True if:
+ * the user is not currently a member of the room, and:
+ * the user has not been a member of the room since the given
+ events
+ always_include_ids (set(event_id)): set of event ids to specifically
+ include (unless sender is ignored)
- # Set of membership event_ids that have been forgotten
- event_id_forgotten = frozenset(
- row["event_id"] for rows in forgotten for row in rows
+ Returns:
+ Deferred[list[synapse.events.EventBase]]
+ """
+ types = (
+ (EventTypes.RoomHistoryVisibility, ""),
+ (EventTypes.Member, user_id),
+ )
+ event_id_to_state = yield store.get_state_for_events(
+ frozenset(e.event_id for e in events),
+ types=types,
)
- ignore_dict_content = yield store.get_global_account_data_by_type_for_users(
- "m.ignored_user_list", user_ids=[user_id for user_id, _ in user_tuples]
+ ignore_dict_content = yield store.get_global_account_data_by_type_for_user(
+ "m.ignored_user_list", user_id,
)
# FIXME: This will explode if people upload something incorrect.
- ignore_dict = {
- user_id: frozenset(
- content.get("ignored_users", {}).keys() if content else []
- )
- for user_id, content in ignore_dict_content.items()
- }
+ ignore_list = frozenset(
+ ignore_dict_content.get("ignored_users", {}).keys()
+ if ignore_dict_content else []
+ )
+
+ erased_senders = yield store.are_users_erased((e.sender for e in events))
- def allowed(event, user_id, is_peeking, ignore_list):
+ def allowed(event):
"""
Args:
event (synapse.events.EventBase): event to check
- user_id (str)
- is_peeking (bool)
- ignore_list (list): list of users to ignore
+
+ Returns:
+ None|EventBase:
+ None if the user cannot see this event at all
+
+ a redacted copy of the event if they can only see a redacted
+ version
+
+ the original event if they can see it as normal.
"""
if not event.is_state() and event.sender in ignore_list:
- return False
+ return None
+
+ if event.event_id in always_include_ids:
+ return event
state = event_id_to_state[event.event_id]
@@ -103,10 +119,6 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
if visibility not in VISIBILITY_PRIORITY:
visibility = "shared"
- # if it was world_readable, it's easy: everyone can read it
- if visibility == "world_readable":
- return True
-
# Always allow history visibility events on boundaries. This is done
# by setting the effective visibility to the least restrictive
# of the old vs new.
@@ -140,7 +152,7 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
if membership == "leave" and (
prev_membership == "join" or prev_membership == "invite"
):
- return True
+ return event
new_priority = MEMBERSHIP_PRIORITY.index(membership)
old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
@@ -151,87 +163,203 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
if membership is None:
membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event:
- if membership_event.event_id not in event_id_forgotten:
- membership = membership_event.membership
+ membership = membership_event.membership
# if the user was a member of the room at the time of the event,
# they can see it.
if membership == Membership.JOIN:
- return True
+ return event
+
+ # otherwise, it depends on the room visibility.
if visibility == "joined":
# we weren't a member at the time of the event, so we can't
# see this event.
- return False
+ return None
elif visibility == "invited":
# user can also see the event if they were *invited* at the time
# of the event.
- return membership == Membership.INVITE
-
- else:
- # visibility is shared: user can also see the event if they have
- # become a member since the event
+ return (
+ event if membership == Membership.INVITE else None
+ )
+
+ elif visibility == "shared" and is_peeking:
+ # if the visibility is shared, users cannot see the event unless
+            # they have *subsequently* joined the room (or were members at the
+ # time, of course)
#
# XXX: if the user has subsequently joined and then left again,
# ideally we would share history up to the point they left. But
- # we don't know when they left.
- return not is_peeking
+ # we don't know when they left. We just treat it as though they
+ # never joined, and restrict access.
+ return None
- defer.returnValue({
- user_id: [
- event
- for event in events
- if allowed(event, user_id, is_peeking, ignore_dict.get(user_id, []))
- ]
- for user_id, is_peeking in user_tuples
- })
+ # the visibility is either shared or world_readable, and the user was
+ # not a member at the time. We allow it, provided the original sender
+ # has not requested their data to be erased, in which case, we return
+ # a redacted version.
+ if erased_senders[event.sender]:
+ return prune_event(event)
+ return event
-@defer.inlineCallbacks
-def filter_events_for_clients_context(store, user_tuples, events, event_id_to_context):
- user_ids = set(u[0] for u in user_tuples)
- event_id_to_state = {}
- for event_id, context in event_id_to_context.items():
- state = yield store.get_events([
- e_id
- for key, e_id in context.current_state_ids.iteritems()
- if key == (EventTypes.RoomHistoryVisibility, "")
- or (key[0] == EventTypes.Member and key[1] in user_ids)
- ])
- event_id_to_state[event_id] = state
-
- res = yield filter_events_for_clients(
- store, user_tuples, events, event_id_to_state
- )
- defer.returnValue(res)
+ # check each event: gives an iterable[None|EventBase]
+ filtered_events = map(allowed, events)
+ # remove the None entries
+ filtered_events = filter(operator.truth, filtered_events)
-@defer.inlineCallbacks
-def filter_events_for_client(store, user_id, events, is_peeking=False):
- """
- Check which events a user is allowed to see
+ # we turn it into a list before returning it.
+ defer.returnValue(list(filtered_events))
- Args:
- user_id(str): user id to be checked
- events([synapse.events.EventBase]): list of events to be checked
- is_peeking(bool): should be True if:
- * the user is not currently a member of the room, and:
- * the user has not been a member of the room since the given
- events
- Returns:
- [synapse.events.EventBase]
- """
- types = (
- (EventTypes.RoomHistoryVisibility, ""),
- (EventTypes.Member, user_id),
+@defer.inlineCallbacks
+def filter_events_for_server(store, server_name, events):
+ # Whatever else we do, we need to check for senders which have requested
+ # erasure of their data.
+ erased_senders = yield store.are_users_erased(
+ e.sender for e in events,
)
- event_id_to_state = yield store.get_state_for_events(
+
+ def redact_disallowed(event, state):
+ # if the sender has been gdpr17ed, always return a redacted
+ # copy of the event.
+ if erased_senders[event.sender]:
+ logger.info(
+ "Sender of %s has been erased, redacting",
+ event.event_id,
+ )
+ return prune_event(event)
+
+ # state will be None if we decided we didn't need to filter by
+ # room membership.
+ if not state:
+ return event
+
+ history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
+ if history:
+ visibility = history.content.get("history_visibility", "shared")
+ if visibility in ["invited", "joined"]:
+ # We now loop through all state events looking for
+ # membership states for the requesting server to determine
+ # if the server is either in the room or has been invited
+ # into the room.
+ for ev in itervalues(state):
+ if ev.type != EventTypes.Member:
+ continue
+ try:
+ domain = get_domain_from_id(ev.state_key)
+ except Exception:
+ continue
+
+ if domain != server_name:
+ continue
+
+ memtype = ev.membership
+ if memtype == Membership.JOIN:
+ return event
+ elif memtype == Membership.INVITE:
+ if visibility == "invited":
+ return event
+ else:
+ # server has no users in the room: redact
+ return prune_event(event)
+
+ return event
+
+    # Next, let's check to see if all the events have a history visibility
+    # of "shared" or "world_readable". If that's the case then we don't
+ # need to check membership (as we know the server is in the room).
+ event_to_state_ids = yield store.get_state_ids_for_events(
frozenset(e.event_id for e in events),
- types=types
+ types=(
+ (EventTypes.RoomHistoryVisibility, ""),
+ )
)
- res = yield filter_events_for_clients(
- store, [(user_id, is_peeking)], events, event_id_to_state
+
+ visibility_ids = set()
+ for sids in itervalues(event_to_state_ids):
+ hist = sids.get((EventTypes.RoomHistoryVisibility, ""))
+ if hist:
+ visibility_ids.add(hist)
+
+ # If we failed to find any history visibility events then the default
+ # is "shared" visiblity.
+ if not visibility_ids:
+ all_open = True
+ else:
+ event_map = yield store.get_events(visibility_ids)
+ all_open = all(
+ e.content.get("history_visibility") in (None, "shared", "world_readable")
+ for e in itervalues(event_map)
+ )
+
+ if all_open:
+ # all the history_visibility state affecting these events is open, so
+ # we don't need to filter by membership state. We *do* need to check
+ # for user erasure, though.
+ if erased_senders:
+ events = [
+ redact_disallowed(e, None)
+ for e in events
+ ]
+
+ defer.returnValue(events)
+
+ # Ok, so we're dealing with events that have non-trivial visibility
+ # rules, so we need to also get the memberships of the room.
+
+ # first, for each event we're wanting to return, get the event_ids
+ # of the history vis and membership state at those events.
+ event_to_state_ids = yield store.get_state_ids_for_events(
+ frozenset(e.event_id for e in events),
+ types=(
+ (EventTypes.RoomHistoryVisibility, ""),
+ (EventTypes.Member, None),
+ )
)
- defer.returnValue(res.get(user_id, []))
+
+ # We only want to pull out member events that correspond to the
+ # server's domain.
+ #
+ # event_to_state_ids contains lots of duplicates, so it turns out to be
+ # cheaper to build a complete set of unique
+ # ((type, state_key), event_id) tuples, and then filter out the ones we
+ # don't want.
+ #
+ state_key_to_event_id_set = {
+ e
+ for key_to_eid in itervalues(event_to_state_ids)
+ for e in key_to_eid.items()
+ }
+
+ def include(typ, state_key):
+ if typ != EventTypes.Member:
+ return True
+
+ # we avoid using get_domain_from_id here for efficiency.
+ idx = state_key.find(":")
+ if idx == -1:
+ return False
+ return state_key[idx + 1:] == server_name
+
+ event_map = yield store.get_events([
+ e_id
+ for key, e_id in state_key_to_event_id_set
+ if include(key[0], key[1])
+ ])
+
+ event_to_state = {
+ e_id: {
+ key: event_map[inner_e_id]
+ for key, inner_e_id in iteritems(key_to_eid)
+ if inner_e_id in event_map
+ }
+ for e_id, key_to_eid in iteritems(event_to_state_ids)
+ }
+
+ defer.returnValue([
+ redact_disallowed(e, event_to_state[e.event_id])
+ for e in events
+ ])
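Taken together, the two filters now share the erasure handling but differ in disposition: the client-side filter drops events outright, while the federation filter only ever redacts. A sketch of the calling convention (inside a `defer.inlineCallbacks` function, with `store` and `events` already populated; the server name is illustrative):

    # client side: returns only the events this user may see; events from
    # erased senders that are visible only via shared/world_readable history
    # come back as redacted copies
    events = yield filter_events_for_client(
        store, user_id, events, is_peeking=False,
    )

    # federation side: never drops an event, but replaces anything the
    # destination server may not see with a pruned (redacted) copy
    events = yield filter_events_for_server(store, "remote.example.com", events)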