diff options
Diffstat (limited to 'synapse')
288 files changed, 30549 insertions, 9750 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index da8ef90a77..f31cb9a3cb 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.19.1" +__version__ = "0.28.1" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 03a215ab1b..f17fda6315 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -23,7 +23,8 @@ from synapse import event_auth from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import AuthError, Codes from synapse.types import UserID -from synapse.util.logcontext import preserve_context_over_fn +from synapse.util.caches import register_cache, CACHE_SIZE_FACTOR +from synapse.util.caches.lrucache import LruCache from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -39,6 +40,10 @@ AuthEventTypes = ( GUEST_DEVICE_ID = "guest_device" +class _InvalidMacaroonException(Exception): + pass + + class Auth(object): """ FIXME: This class contains a mix of functions for authenticating users @@ -51,6 +56,9 @@ class Auth(object): self.state = hs.get_state_handler() self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 + self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000) + register_cache("token_cache", self.token_cache) + @defer.inlineCallbacks def check_from_context(self, event, context, do_sig_check=True): auth_events_ids = yield self.compute_auth_events( @@ -144,17 +152,8 @@ class Auth(object): @defer.inlineCallbacks def check_host_in_room(self, room_id, host): with Measure(self.clock, "check_host_in_room"): - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - - logger.debug("calling resolve_state_groups from check_host_in_room") - entry = yield self.state.resolve_state_groups( - room_id, latest_event_ids - ) - - ret = yield self.store.is_host_joined( - room_id, host, entry.state_group, entry.state - ) - defer.returnValue(ret) + latest_event_ids = yield 
self.store.is_host_joined(room_id, host) + defer.returnValue(latest_event_ids) def _check_joined_room(self, member, user_id, room_id): if not member or member.membership != Membership.JOIN: @@ -205,13 +204,12 @@ class Auth(object): ip_addr = self.hs.get_ip_from_request(request) user_agent = request.requestHeaders.getRawHeaders( - "User-Agent", - default=[""] + b"User-Agent", + default=[b""] )[0] if user and access_token and ip_addr: - preserve_context_over_fn( - self.store.insert_client_ip, - user=user, + self.store.insert_client_ip( + user_id=user.to_string(), access_token=access_token, ip=ip_addr, user_agent=user_agent, @@ -272,13 +270,17 @@ class Auth(object): rights (str): The operation being performed; the access token must allow this. Returns: - dict : dict that includes the user and the ID of their access token. + Deferred[dict]: dict that includes: + `user` (UserID) + `is_guest` (bool) + `token_id` (int|None): access token id. May be None if guest + `device_id` (str|None): device corresponding to access token Raises: AuthError if no user by that token exists or the token is invalid. """ try: - macaroon = pymacaroons.Macaroon.deserialize(token) - except Exception: # deserialize can throw more-or-less anything + user_id, guest = self._parse_and_validate_macaroon(token, rights) + except _InvalidMacaroonException: # doesn't look like a macaroon: treat it as an opaque token which # must be in the database. # TODO: it would be nice to get rid of this, but apparently some @@ -287,19 +289,8 @@ class Auth(object): defer.returnValue(r) try: - user_id = self.get_user_id_from_macaroon(macaroon) user = UserID.from_string(user_id) - self.validate_macaroon( - macaroon, rights, self.hs.config.expire_access_token, - user_id=user_id, - ) - - guest = False - for caveat in macaroon.caveats: - if caveat.caveat_id == "guest = true": - guest = True - if guest: # Guest access tokens are not stored in the database (there can # only be one access token per guest, anyway). 
@@ -371,6 +362,55 @@ class Auth(object): errcode=Codes.UNKNOWN_TOKEN ) + def _parse_and_validate_macaroon(self, token, rights="access"): + """Takes a macaroon and tries to parse and validate it. This is cached + if and only if rights == access and there isn't an expiry. + + On invalid macaroon raises _InvalidMacaroonException + + Returns: + (user_id, is_guest) + """ + if rights == "access": + cached = self.token_cache.get(token, None) + if cached: + return cached + + try: + macaroon = pymacaroons.Macaroon.deserialize(token) + except Exception: # deserialize can throw more-or-less anything + # doesn't look like a macaroon: treat it as an opaque token which + # must be in the database. + # TODO: it would be nice to get rid of this, but apparently some + # people use access tokens which aren't macaroons + raise _InvalidMacaroonException() + + try: + user_id = self.get_user_id_from_macaroon(macaroon) + + has_expiry = False + guest = False + for caveat in macaroon.caveats: + if caveat.caveat_id.startswith("time "): + has_expiry = True + elif caveat.caveat_id == "guest = true": + guest = True + + self.validate_macaroon( + macaroon, rights, self.hs.config.expire_access_token, + user_id=user_id, + ) + except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.", + errcode=Codes.UNKNOWN_TOKEN + ) + + if not has_expiry and rights == "access": + self.token_cache[token] = (user_id, guest) + + return user_id, guest + def get_user_id_from_macaroon(self, macaroon): """Retrieve the user_id given by the caveats on the macaroon. @@ -483,6 +523,14 @@ class Auth(object): ) def is_server_admin(self, user): + """ Check if the given user is a local server admin. 
+ + Args: + user (str): mxid of user to check + + Returns: + bool: True if the user is an admin + """ return self.store.is_server_admin(user) @defer.inlineCallbacks @@ -624,7 +672,7 @@ def has_access_token(request): bool: False if no access_token was given, True otherwise. """ query_params = request.args.get("access_token") - auth_headers = request.requestHeaders.getRawHeaders("Authorization") + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") return bool(query_params) or bool(auth_headers) @@ -644,8 +692,8 @@ def get_access_token_from_request(request, token_not_found_http_status=401): AuthError: If there isn't an access_token in the request. """ - auth_headers = request.requestHeaders.getRawHeaders("Authorization") - query_params = request.args.get("access_token") + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") + query_params = request.args.get(b"access_token") if auth_headers: # Try the get the access_token from a "Authorization: Bearer" # header diff --git a/synapse/api/constants.py b/synapse/api/constants.py index ca23c9c460..5baba43966 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -15,6 +16,9 @@ """Contains constants from the specification.""" +# the "depth" field on events is limited to 2**63 - 1 +MAX_DEPTH = 2**63 - 1 + class Membership(object): @@ -44,6 +48,7 @@ class JoinRules(object): class LoginType(object): PASSWORD = u"m.login.password" EMAIL_IDENTITY = u"m.login.email.identity" + MSISDN = u"m.login.msisdn" RECAPTCHA = u"m.login.recaptcha" DUMMY = u"m.login.dummy" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 921c457738..a9ff5576f3 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -17,6 +17,9 @@ import logging +import simplejson as json +from six import iteritems + logger = logging.getLogger(__name__) @@ -45,32 +48,52 @@ class Codes(object): THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED" THREEPID_IN_USE = "M_THREEPID_IN_USE" THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND" + THREEPID_DENIED = "M_THREEPID_DENIED" INVALID_USERNAME = "M_INVALID_USERNAME" SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED" class CodeMessageException(RuntimeError): - """An exception with integer code and message string attributes.""" + """An exception with integer code and message string attributes. + Attributes: + code (int): HTTP error code + msg (str): string describing the error + """ def __init__(self, code, msg): super(CodeMessageException, self).__init__("%d: %s" % (code, msg)) self.code = code self.msg = msg - self.response_code_message = None def error_dict(self): return cs_error(self.msg) +class MatrixCodeMessageException(CodeMessageException): + """An error from a general matrix endpoint, eg. from a proxied Matrix API call. 
+ + Attributes: + errcode (str): Matrix error code e.g 'M_FORBIDDEN' + """ + def __init__(self, code, msg, errcode=Codes.UNKNOWN): + super(MatrixCodeMessageException, self).__init__(code, msg) + self.errcode = errcode + + class SynapseError(CodeMessageException): - """A base error which can be caught for all synapse events.""" + """A base exception type for matrix errors which have an errcode and error + message (as well as an HTTP status code). + + Attributes: + errcode (str): Matrix error code e.g 'M_FORBIDDEN' + """ def __init__(self, code, msg, errcode=Codes.UNKNOWN): """Constructs a synapse error. Args: code (int): The integer error code (an HTTP response code) msg (str): The human-readable error message. - err (str): The error code e.g 'M_FORBIDDEN' + errcode (str): The matrix error code e.g 'M_FORBIDDEN' """ super(SynapseError, self).__init__(code, msg) self.errcode = errcode @@ -81,12 +104,87 @@ class SynapseError(CodeMessageException): self.errcode, ) + @classmethod + def from_http_response_exception(cls, err): + """Make a SynapseError based on an HTTPResponseException + + This is useful when a proxied request has failed, and we need to + decide how to map the failure onto a matrix error to send back to the + client. + + An attempt is made to parse the body of the http response as a matrix + error. If that succeeds, the errcode and error message from the body + are used as the errcode and error message in the new synapse error. + + Otherwise, the errcode is set to M_UNKNOWN, and the error message is + set to the reason code from the HTTP response. 
+ + Args: + err (HttpResponseException): + + Returns: + SynapseError: + """ + # try to parse the body as json, to get better errcode/msg, but + # default to M_UNKNOWN with the HTTP status as the error text + try: + j = json.loads(err.response) + except ValueError: + j = {} + errcode = j.get('errcode', Codes.UNKNOWN) + errmsg = j.get('error', err.msg) + + res = SynapseError(err.code, errmsg, errcode) + return res + class RegistrationError(SynapseError): """An error raised when a registration event fails.""" pass +class FederationDeniedError(SynapseError): + """An error raised when the server tries to federate with a server which + is not on its federation whitelist. + + Attributes: + destination (str): The destination which has been denied + """ + + def __init__(self, destination): + """Raised by federation client or server to indicate that we are + are deliberately not attempting to contact a given server because it is + not on our federation whitelist. + + Args: + destination (str): the domain in question + """ + + self.destination = destination + + super(FederationDeniedError, self).__init__( + code=403, + msg="Federation denied with %s." 
% (self.destination,), + errcode=Codes.FORBIDDEN, + ) + + +class InteractiveAuthIncompleteError(Exception): + """An error raised when UI auth is not yet complete + + (This indicates we should return a 401 with 'result' as the body) + + Attributes: + result (dict): the server response to the request, which should be + passed back to the client + """ + def __init__(self, result): + super(InteractiveAuthIncompleteError, self).__init__( + "Interactive auth not yet complete", + ) + self.result = result + + class UnrecognizedRequestError(SynapseError): """An error indicating we don't understand the request you're trying to make""" def __init__(self, *args, **kwargs): @@ -106,13 +204,11 @@ class UnrecognizedRequestError(SynapseError): class NotFoundError(SynapseError): """An error indicating we can't find the thing you asked for""" - def __init__(self, *args, **kwargs): - if "errcode" not in kwargs: - kwargs["errcode"] = Codes.NOT_FOUND + def __init__(self, msg="Not found", errcode=Codes.NOT_FOUND): super(NotFoundError, self).__init__( 404, - "Not found", - **kwargs + msg, + errcode=errcode ) @@ -173,7 +269,6 @@ class LimitExceededError(SynapseError): errcode=Codes.LIMIT_EXCEEDED): super(LimitExceededError, self).__init__(code, msg, errcode) self.retry_after_ms = retry_after_ms - self.response_code_message = "Too Many Requests" def error_dict(self): return cs_error( @@ -203,7 +298,7 @@ def cs_error(msg, code=Codes.UNKNOWN, **kwargs): A dict representing the error response JSON. 
""" err = {"error": msg, "errcode": code} - for key, value in kwargs.iteritems(): + for key, value in iteritems(kwargs): err[key] = value return err @@ -243,6 +338,19 @@ class FederationError(RuntimeError): class HttpResponseException(CodeMessageException): + """ + Represents an HTTP-level failure of an outbound request + + Attributes: + response (str): body of response + """ def __init__(self, code, msg, response): - self.response = response + """ + + Args: + code (int): HTTP status code + msg (str): reason phrase from HTTP response status line + response (str): body of response + """ super(HttpResponseException, self).__init__(code, msg) + self.response = response diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index fb291d7fb9..db43219d24 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -13,11 +13,174 @@ # See the License for the specific language governing permissions and # limitations under the License. from synapse.api.errors import SynapseError +from synapse.storage.presence import UserPresenceState from synapse.types import UserID, RoomID - from twisted.internet import defer -import ujson as json +import simplejson as json +import jsonschema +from jsonschema import FormatChecker + +FILTER_SCHEMA = { + "additionalProperties": False, + "type": "object", + "properties": { + "limit": { + "type": "number" + }, + "senders": { + "$ref": "#/definitions/user_id_array" + }, + "not_senders": { + "$ref": "#/definitions/user_id_array" + }, + # TODO: We don't limit event type values but we probably should... 
+ # check types are valid event types + "types": { + "type": "array", + "items": { + "type": "string" + } + }, + "not_types": { + "type": "array", + "items": { + "type": "string" + } + } + } +} + +ROOM_FILTER_SCHEMA = { + "additionalProperties": False, + "type": "object", + "properties": { + "not_rooms": { + "$ref": "#/definitions/room_id_array" + }, + "rooms": { + "$ref": "#/definitions/room_id_array" + }, + "ephemeral": { + "$ref": "#/definitions/room_event_filter" + }, + "include_leave": { + "type": "boolean" + }, + "state": { + "$ref": "#/definitions/room_event_filter" + }, + "timeline": { + "$ref": "#/definitions/room_event_filter" + }, + "account_data": { + "$ref": "#/definitions/room_event_filter" + }, + } +} + +ROOM_EVENT_FILTER_SCHEMA = { + "additionalProperties": False, + "type": "object", + "properties": { + "limit": { + "type": "number" + }, + "senders": { + "$ref": "#/definitions/user_id_array" + }, + "not_senders": { + "$ref": "#/definitions/user_id_array" + }, + "types": { + "type": "array", + "items": { + "type": "string" + } + }, + "not_types": { + "type": "array", + "items": { + "type": "string" + } + }, + "rooms": { + "$ref": "#/definitions/room_id_array" + }, + "not_rooms": { + "$ref": "#/definitions/room_id_array" + }, + "contains_url": { + "type": "boolean" + } + } +} + +USER_ID_ARRAY_SCHEMA = { + "type": "array", + "items": { + "type": "string", + "format": "matrix_user_id" + } +} + +ROOM_ID_ARRAY_SCHEMA = { + "type": "array", + "items": { + "type": "string", + "format": "matrix_room_id" + } +} + +USER_FILTER_SCHEMA = { + "$schema": "http://json-schema.org/draft-04/schema#", + "description": "schema for a Sync filter", + "type": "object", + "definitions": { + "room_id_array": ROOM_ID_ARRAY_SCHEMA, + "user_id_array": USER_ID_ARRAY_SCHEMA, + "filter": FILTER_SCHEMA, + "room_filter": ROOM_FILTER_SCHEMA, + "room_event_filter": ROOM_EVENT_FILTER_SCHEMA + }, + "properties": { + "presence": { + "$ref": "#/definitions/filter" + }, + "account_data": { 
+ "$ref": "#/definitions/filter" + }, + "room": { + "$ref": "#/definitions/room_filter" + }, + "event_format": { + "type": "string", + "enum": ["client", "federation"] + }, + "event_fields": { + "type": "array", + "items": { + "type": "string", + # Don't allow '\\' in event field filters. This makes matching + # events a lot easier as we can then use a negative lookbehind + # assertion to split '\.' If we allowed \\ then it would + # incorrectly split '\\.' See synapse.events.utils.serialize_event + "pattern": "^((?!\\\).)*$" + } + } + }, + "additionalProperties": False +} + + +@FormatChecker.cls_checks('matrix_room_id') +def matrix_room_id_validator(room_id_str): + return RoomID.from_string(room_id_str) + + +@FormatChecker.cls_checks('matrix_user_id') +def matrix_user_id_validator(user_id_str): + return UserID.from_string(user_id_str) class Filtering(object): @@ -52,98 +215,11 @@ class Filtering(object): # NB: Filters are the complete json blobs. "Definitions" are an # individual top-level key e.g. public_user_data. Filters are made of # many definitions. - - top_level_definitions = [ - "presence", "account_data" - ] - - room_level_definitions = [ - "state", "timeline", "ephemeral", "account_data" - ] - - for key in top_level_definitions: - if key in user_filter_json: - self._check_definition(user_filter_json[key]) - - if "room" in user_filter_json: - self._check_definition_room_lists(user_filter_json["room"]) - for key in room_level_definitions: - if key in user_filter_json["room"]: - self._check_definition(user_filter_json["room"][key]) - - if "event_fields" in user_filter_json: - if type(user_filter_json["event_fields"]) != list: - raise SynapseError(400, "event_fields must be a list of strings") - for field in user_filter_json["event_fields"]: - if not isinstance(field, basestring): - raise SynapseError(400, "Event field must be a string") - # Don't allow '\\' in event field filters. 
This makes matching - # events a lot easier as we can then use a negative lookbehind - # assertion to split '\.' If we allowed \\ then it would - # incorrectly split '\\.' See synapse.events.utils.serialize_event - if r'\\' in field: - raise SynapseError( - 400, r'The escape character \ cannot itself be escaped' - ) - - def _check_definition_room_lists(self, definition): - """Check that "rooms" and "not_rooms" are lists of room ids if they - are present - - Args: - definition(dict): The filter definition - Raises: - SynapseError: If there was a problem with this definition. - """ - # check rooms are valid room IDs - room_id_keys = ["rooms", "not_rooms"] - for key in room_id_keys: - if key in definition: - if type(definition[key]) != list: - raise SynapseError(400, "Expected %s to be a list." % key) - for room_id in definition[key]: - RoomID.from_string(room_id) - - def _check_definition(self, definition): - """Check if the provided definition is valid. - - This inspects not only the types but also the values to make sure they - make sense. - - Args: - definition(dict): The filter definition - Raises: - SynapseError: If there was a problem with this definition. - """ - # NB: Filters are the complete json blobs. "Definitions" are an - # individual top-level key e.g. public_user_data. Filters are made of - # many definitions. - if type(definition) != dict: - raise SynapseError( - 400, "Expected JSON object, not %s" % (definition,) - ) - - self._check_definition_room_lists(definition) - - # check senders are valid user IDs - user_id_keys = ["senders", "not_senders"] - for key in user_id_keys: - if key in definition: - if type(definition[key]) != list: - raise SynapseError(400, "Expected %s to be a list." % key) - for user_id in definition[key]: - UserID.from_string(user_id) - - # TODO: We don't limit event type values but we probably should... 
- # check types are valid event types - event_keys = ["types", "not_types"] - for key in event_keys: - if key in definition: - if type(definition[key]) != list: - raise SynapseError(400, "Expected %s to be a list." % key) - for event_type in definition[key]: - if not isinstance(event_type, basestring): - raise SynapseError(400, "Event type should be a string") + try: + jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA, + format_checker=FormatChecker()) + except jsonschema.ValidationError as e: + raise SynapseError(400, e.message) class FilterCollection(object): @@ -253,19 +329,35 @@ class Filter(object): Returns: bool: True if the event matches """ - sender = event.get("sender", None) - if not sender: - # Presence events have their 'sender' in content.user_id - content = event.get("content") - # account_data has been allowed to have non-dict content, so check type first - if isinstance(content, dict): - sender = content.get("user_id") + # We usually get the full "events" as dictionaries coming through, + # except for presence which actually gets passed around as its own + # namedtuple type. + if isinstance(event, UserPresenceState): + sender = event.user_id + room_id = None + ev_type = "m.presence" + is_url = False + else: + sender = event.get("sender", None) + if not sender: + # Presence events had their 'sender' in content.user_id, but are + # now handled above. We don't know if anything else uses this + # form. TODO: Check this and probably remove it. 
+ content = event.get("content") + # account_data has been allowed to have non-dict content, so + # check type first + if isinstance(content, dict): + sender = content.get("user_id") + + room_id = event.get("room_id", None) + ev_type = event.get("type", None) + is_url = "url" in event.get("content", {}) return self.check_fields( - event.get("room_id", None), + room_id, sender, - event.get("type", None), - "url" in event.get("content", {}) + ev_type, + is_url, ) def check_fields(self, room_id, sender, event_type, contains_url): diff --git a/synapse/app/_base.py b/synapse/app/_base.py new file mode 100644 index 0000000000..e4318cdfc3 --- /dev/null +++ b/synapse/app/_base.py @@ -0,0 +1,178 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import logging +import sys + +try: + import affinity +except Exception: + affinity = None + +from daemonize import Daemonize +from synapse.util import PreserveLoggingContext +from synapse.util.rlimit import change_resource_limit +from twisted.internet import error, reactor + +logger = logging.getLogger(__name__) + + +def start_worker_reactor(appname, config): + """ Run the reactor in the main process + + Daemonizes if necessary, and then configures some resources, before starting + the reactor. Pulls configuration from the 'worker' settings in 'config'. 
+ + Args: + appname (str): application name which will be sent to syslog + config (synapse.config.Config): config object + """ + + logger = logging.getLogger(config.worker_app) + + start_reactor( + appname, + config.soft_file_limit, + config.gc_thresholds, + config.worker_pid_file, + config.worker_daemonize, + config.worker_cpu_affinity, + logger, + ) + + +def start_reactor( + appname, + soft_file_limit, + gc_thresholds, + pid_file, + daemonize, + cpu_affinity, + logger, +): + """ Run the reactor in the main process + + Daemonizes if necessary, and then configures some resources, before starting + the reactor + + Args: + appname (str): application name which will be sent to syslog + soft_file_limit (int): + gc_thresholds: + pid_file (str): name of pid file to write to if daemonize is True + daemonize (bool): true to run the reactor in a background process + cpu_affinity (int|None): cpu affinity mask + logger (logging.Logger): logger instance to pass to Daemonize + """ + + def run(): + # make sure that we run the reactor with the sentinel log context, + # otherwise other PreserveLoggingContext instances will get confused + # and complain when they see the logcontext arbitrarily swapping + # between the sentinel and `run` logcontexts. 
+ with PreserveLoggingContext(): + logger.info("Running") + if cpu_affinity is not None: + if not affinity: + quit_with_error( + "Missing package 'affinity' required for cpu_affinity\n" + "option\n\n" + "Install by running:\n\n" + " pip install affinity\n\n" + ) + logger.info("Setting CPU affinity to %s" % cpu_affinity) + affinity.set_process_affinity_mask(0, cpu_affinity) + change_resource_limit(soft_file_limit) + if gc_thresholds: + gc.set_threshold(*gc_thresholds) + reactor.run() + + if daemonize: + daemon = Daemonize( + app=appname, + pid=pid_file, + action=run, + auto_close_fds=False, + verbose=True, + logger=logger, + ) + daemon.start() + else: + run() + + +def quit_with_error(error_string): + message_lines = error_string.split("\n") + line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2 + sys.stderr.write("*" * line_length + '\n') + for line in message_lines: + sys.stderr.write(" %s\n" % (line.rstrip(),)) + sys.stderr.write("*" * line_length + '\n') + sys.exit(1) + + +def listen_tcp(bind_addresses, port, factory, backlog=50): + """ + Create a TCP socket for a port and several addresses + """ + for address in bind_addresses: + try: + reactor.listenTCP( + port, + factory, + backlog, + address + ) + except error.CannotListenError as e: + check_bind_error(e, address, bind_addresses) + + +def listen_ssl(bind_addresses, port, factory, context_factory, backlog=50): + """ + Create an SSL socket for a port and several addresses + """ + for address in bind_addresses: + try: + reactor.listenSSL( + port, + factory, + context_factory, + backlog, + address + ) + except error.CannotListenError as e: + check_bind_error(e, address, bind_addresses) + + +def check_bind_error(e, address, bind_addresses): + """ + This method checks an exception occurred while binding on 0.0.0.0. + If :: is specified in the bind addresses a warning is shown. + The exception is still raised otherwise. 
+ + Binding on both 0.0.0.0 and :: causes an exception on Linux and macOS + because :: binds on both IPv4 and IPv6 (as per RFC 3493). + When binding on 0.0.0.0 after :: this can safely be ignored. + + Args: + e (Exception): Exception that was caught. + address (str): Address on which binding was attempted. + bind_addresses (list): Addresses on which the service listens. + """ + if address == '0.0.0.0' and '::' in bind_addresses: + logger.warn('Failed to listen on 0.0.0.0, continuing because listening on [::]') + else: + raise e diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 1900930053..58f2c9d68c 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -13,37 +13,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +import sys import synapse - -from synapse.server import HomeServer +from synapse import events +from synapse.app import _base from synapse.config._base import ConfigError -from synapse.config.logger import setup_logging from synapse.config.homeserver import HomeServerConfig +from synapse.config.logger import setup_logging from synapse.http.site import SynapseSite -from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.server import HomeServer from synapse.storage.engines import create_engine -from 
synapse.util.async import sleep from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext +from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole -from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string - -from synapse import events - from twisted.internet import reactor, defer -from twisted.web.resource import Resource - -from daemonize import Daemonize - -import sys -import logging -import gc +from twisted.web.resource import NoResource logger = logging.getLogger("synapse.app.appservice") @@ -56,19 +49,6 @@ class AppserviceSlaveStore( class AppserviceServer(HomeServer): - def get_db_conn(self, run_new_connection=True): - # Any param beginning with cp_ is a parameter for adbapi, and should - # not be passed to the database engine. - db_params = { - k: v for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.database_engine.module.connect(**db_params) - - if run_new_connection: - self.database_engine.on_new_connection(db_conn) - return db_conn - def setup(self): logger.info("Setting up.") self.datastore = AppserviceSlaveStore(self.get_db_conn(), self) @@ -84,19 +64,18 @@ class AppserviceServer(HomeServer): if name == "metrics": resources[METRICS_PREFIX] = MetricsResource(self) - root_resource = create_resource_tree(resources, Resource()) - - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + root_resource = create_resource_tree(resources, NoResource()) + + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse appservice now listening on port %d", port) @@ -105,45 +84,42 @@ class 
AppserviceServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) + self.get_tcp_replication().start_replication(self) + + def build_tcp_replication(self): + return ASReplicationHandler(self) + + +class ASReplicationHandler(ReplicationClientHandler): + def __init__(self, hs): + super(ASReplicationHandler, self).__init__(hs.get_datastore()) + self.appservice_handler = hs.get_application_service_handler() + + def on_rdata(self, stream_name, token, rows): + super(ASReplicationHandler, self).on_rdata(stream_name, token, rows) + + if stream_name == "events": + max_stream_id = self.store.get_room_max_stream_ordering() + run_in_background(self._notify_app_services, max_stream_id) + @defer.inlineCallbacks - def replicate(self): - http_client = self.get_simple_http_client() - store = self.get_datastore() - replication_url = self.config.worker_replication_url - appservice_handler = self.get_application_service_handler() - - @defer.inlineCallbacks - def replicate(results): - stream = results.get("events") - if stream: - max_stream_id = stream["position"] - yield appservice_handler.notify_interested_services(max_stream_id) - - while True: - try: - args = store.stream_positions() - args["timeout"] = 30000 - result = yield http_client.get_json(replication_url, args=args) - yield store.process_replication(result) - replicate(result) - except: - logger.exception("Error replicating from %r", replication_url) - yield sleep(30) + def _notify_app_services(self, room_stream_id): + 
try: + yield self.appservice_handler.notify_interested_services(room_stream_id) + except Exception: + logger.exception("Error notifying application services of event") def start(config_options): @@ -157,7 +133,7 @@ def start(config_options): assert config.worker_app == "synapse.app.appservice" - setup_logging(config.worker_log_config, config.worker_log_file) + setup_logging(config, use_worker_options=True) events.USE_FROZEN_DICTS = config.use_frozen_dicts @@ -186,33 +162,13 @@ def start(config_options): ps.setup() ps.start_listening(config.worker_listeners) - def run(): - with LoggingContext("run"): - logger.info("Running") - change_resource_limit(config.soft_file_limit) - if config.gc_thresholds: - gc.set_threshold(*config.gc_thresholds) - reactor.run() - def start(): - ps.replicate() ps.get_datastore().start_profiling() ps.get_state_handler().start_caching() reactor.callWhenRunning(start) - if config.worker_daemonize: - daemon = Daemonize( - app="synapse-appservice", - pid=config.worker_pid_file, - action=run, - auto_close_fds=False, - verbose=True, - logger=logger, - ) - daemon.start() - else: - run() + _base.start_worker_reactor("synapse-appservice", config) if __name__ == '__main__': diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index 4d081eccd1..267d34c881 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -13,45 +13,38 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import logging +import sys import synapse - +from synapse import events +from synapse.app import _base from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging -from synapse.http.site import SynapseSite +from synapse.crypto import context_factory from synapse.http.server import JsonResource -from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.http.site import SynapseSite +from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore +from synapse.replication.slave.storage.client_ips import SlavedClientIpStore +from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.keys import SlavedKeyStore -from synapse.replication.slave.storage.room import RoomStore -from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.replication.slave.storage.room import RoomStore +from synapse.replication.slave.storage.transactions import TransactionStore +from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.client.v1.room import PublicRoomListRestServlet from synapse.server import HomeServer -from synapse.storage.client_ips import ClientIpStore from synapse.storage.engines import create_engine -from synapse.util.async import sleep from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole -from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string -from synapse.crypto import context_factory - -from 
synapse import events - - -from twisted.internet import reactor, defer -from twisted.web.resource import Resource - -from daemonize import Daemonize - -import sys -import logging -import gc +from twisted.internet import reactor +from twisted.web.resource import NoResource logger = logging.getLogger("synapse.app.client_reader") @@ -63,26 +56,14 @@ class ClientReaderSlavedStore( DirectoryStore, SlavedApplicationServiceStore, SlavedRegistrationStore, + TransactionStore, + SlavedClientIpStore, BaseSlavedStore, - ClientIpStore, # After BaseSlavedStore because the constructor is different ): pass class ClientReaderServer(HomeServer): - def get_db_conn(self, run_new_connection=True): - # Any param beginning with cp_ is a parameter for adbapi, and should - # not be passed to the database engine. - db_params = { - k: v for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.database_engine.module.connect(**db_params) - - if run_new_connection: - self.database_engine.on_new_connection(db_conn) - return db_conn - def setup(self): logger.info("Setting up.") self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self) @@ -107,19 +88,18 @@ class ClientReaderServer(HomeServer): "/_matrix/client/api/v1": resource, }) - root_resource = create_resource_tree(resources, Resource()) - - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + root_resource = create_resource_tree(resources, NoResource()) + + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse client reader now listening on port %d", port) @@ -128,36 +108,23 @@ class ClientReaderServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = 
listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) + else: logger.warn("Unrecognized listener type: %s", listener["type"]) - @defer.inlineCallbacks - def replicate(self): - http_client = self.get_simple_http_client() - store = self.get_datastore() - replication_url = self.config.worker_replication_url + self.get_tcp_replication().start_replication(self) - while True: - try: - args = store.stream_positions() - args["timeout"] = 30000 - result = yield http_client.get_json(replication_url, args=args) - yield store.process_replication(result) - except: - logger.exception("Error replicating from %r", replication_url) - yield sleep(5) + def build_tcp_replication(self): + return ReplicationClientHandler(self.get_datastore()) def start(config_options): @@ -171,7 +138,7 @@ def start(config_options): assert config.worker_app == "synapse.app.client_reader" - setup_logging(config.worker_log_config, config.worker_log_file) + setup_logging(config, use_worker_options=True) events.USE_FROZEN_DICTS = config.use_frozen_dicts @@ -189,36 +156,15 @@ def start(config_options): ) ss.setup() - ss.get_handlers() ss.start_listening(config.worker_listeners) - def run(): - with LoggingContext("run"): - logger.info("Running") - change_resource_limit(config.soft_file_limit) - if config.gc_thresholds: - gc.set_threshold(*config.gc_thresholds) - reactor.run() - def start(): ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() - ss.replicate() reactor.callWhenRunning(start) - if config.worker_daemonize: - daemon = Daemonize( - app="synapse-client-reader", - pid=config.worker_pid_file, - action=run, - auto_close_fds=False, - verbose=True, - logger=logger, - ) - 
daemon.start() - else: - run() + _base.start_worker_reactor("synapse-client-reader", config) if __name__ == '__main__': diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py new file mode 100644 index 0000000000..b915d12d53 --- /dev/null +++ b/synapse/app/event_creator.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import sys + +import synapse +from synapse import events +from synapse.app import _base +from synapse.config._base import ConfigError +from synapse.config.homeserver import HomeServerConfig +from synapse.config.logger import setup_logging +from synapse.crypto import context_factory +from synapse.http.server import JsonResource +from synapse.http.site import SynapseSite +from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage.account_data import SlavedAccountDataStore +from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore +from synapse.replication.slave.storage.client_ips import SlavedClientIpStore +from synapse.replication.slave.storage.devices import SlavedDeviceStore +from synapse.replication.slave.storage.directory import DirectoryStore +from synapse.replication.slave.storage.events import SlavedEventStore +from synapse.replication.slave.storage.profile import 
SlavedProfileStore +from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore +from synapse.replication.slave.storage.pushers import SlavedPusherStore +from synapse.replication.slave.storage.receipts import SlavedReceiptsStore +from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.replication.slave.storage.room import RoomStore +from synapse.replication.slave.storage.transactions import TransactionStore +from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.rest.client.v1.room import ( + RoomSendEventRestServlet, RoomMembershipRestServlet, RoomStateEventRestServlet, + JoinRoomAliasServlet, +) +from synapse.server import HomeServer +from synapse.storage.engines import create_engine +from synapse.util.httpresourcetree import create_resource_tree +from synapse.util.logcontext import LoggingContext +from synapse.util.manhole import manhole +from synapse.util.versionstring import get_version_string +from twisted.internet import reactor +from twisted.web.resource import NoResource + +logger = logging.getLogger("synapse.app.event_creator") + + +class EventCreatorSlavedStore( + DirectoryStore, + TransactionStore, + SlavedProfileStore, + SlavedAccountDataStore, + SlavedPusherStore, + SlavedReceiptsStore, + SlavedPushRuleStore, + SlavedDeviceStore, + SlavedClientIpStore, + SlavedApplicationServiceStore, + SlavedEventStore, + SlavedRegistrationStore, + RoomStore, + BaseSlavedStore, +): + pass + + +class EventCreatorServer(HomeServer): + def setup(self): + logger.info("Setting up.") + self.datastore = EventCreatorSlavedStore(self.get_db_conn(), self) + logger.info("Finished setting up.") + + def _listen_http(self, listener_config): + port = listener_config["port"] + bind_addresses = listener_config["bind_addresses"] + site_tag = listener_config.get("tag", port) + resources = {} + for res in listener_config["resources"]: + for name in res["names"]: + if name == "metrics": + 
resources[METRICS_PREFIX] = MetricsResource(self) + elif name == "client": + resource = JsonResource(self, canonical_json=False) + RoomSendEventRestServlet(self).register(resource) + RoomMembershipRestServlet(self).register(resource) + RoomStateEventRestServlet(self).register(resource) + JoinRoomAliasServlet(self).register(resource) + resources.update({ + "/_matrix/client/r0": resource, + "/_matrix/client/unstable": resource, + "/_matrix/client/v2_alpha": resource, + "/_matrix/client/api/v1": resource, + }) + + root_resource = create_resource_tree(resources, NoResource()) + + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ) + ) + + logger.info("Synapse event creator now listening on port %d", port) + + def start_listening(self, listeners): + for listener in listeners: + if listener["type"] == "http": + self._listen_http(listener) + elif listener["type"] == "manhole": + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ) + ) + else: + logger.warn("Unrecognized listener type: %s", listener["type"]) + + self.get_tcp_replication().start_replication(self) + + def build_tcp_replication(self): + return ReplicationClientHandler(self.get_datastore()) + + +def start(config_options): + try: + config = HomeServerConfig.load_config( + "Synapse event creator", config_options + ) + except ConfigError as e: + sys.stderr.write("\n" + e.message + "\n") + sys.exit(1) + + assert config.worker_app == "synapse.app.event_creator" + + assert config.worker_replication_http_port is not None + + setup_logging(config, use_worker_options=True) + + events.USE_FROZEN_DICTS = config.use_frozen_dicts + + database_engine = create_engine(config.database_config) + + tls_server_context_factory = context_factory.ServerContextFactory(config) + + ss = EventCreatorServer( + config.server_name, + 
db_config=config.database_config, + tls_server_context_factory=tls_server_context_factory, + config=config, + version_string="Synapse/" + get_version_string(synapse), + database_engine=database_engine, + ) + + ss.setup() + ss.start_listening(config.worker_listeners) + + def start(): + ss.get_state_handler().start_caching() + ss.get_datastore().start_profiling() + + reactor.callWhenRunning(start) + + _base.start_worker_reactor("synapse-event-creator", config) + + +if __name__ == '__main__': + with LoggingContext("main"): + start(sys.argv[1:]) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 90a4816753..c1dc66dd17 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -13,43 +13,35 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +import sys import synapse - +from synapse import events +from synapse.api.urls import FEDERATION_PREFIX +from synapse.app import _base from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging +from synapse.crypto import context_factory +from synapse.federation.transport.server import TransportLayerServer from synapse.http.site import SynapseSite -from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.keys import SlavedKeyStore from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import TransactionStore -from synapse.replication.slave.storage.directory 
import DirectoryStore +from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer from synapse.storage.engines import create_engine -from synapse.util.async import sleep from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole -from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string -from synapse.api.urls import FEDERATION_PREFIX -from synapse.federation.transport.server import TransportLayerServer -from synapse.crypto import context_factory - -from synapse import events - - -from twisted.internet import reactor, defer -from twisted.web.resource import Resource - -from daemonize import Daemonize - -import sys -import logging -import gc +from twisted.internet import reactor +from twisted.web.resource import NoResource logger = logging.getLogger("synapse.app.federation_reader") @@ -66,19 +58,6 @@ class FederationReaderSlavedStore( class FederationReaderServer(HomeServer): - def get_db_conn(self, run_new_connection=True): - # Any param beginning with cp_ is a parameter for adbapi, and should - # not be passed to the database engine. 
- db_params = { - k: v for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.database_engine.module.connect(**db_params) - - if run_new_connection: - self.database_engine.on_new_connection(db_conn) - return db_conn - def setup(self): logger.info("Setting up.") self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self) @@ -98,19 +77,18 @@ class FederationReaderServer(HomeServer): FEDERATION_PREFIX: TransportLayerServer(self), }) - root_resource = create_resource_tree(resources, Resource()) - - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + root_resource = create_resource_tree(resources, NoResource()) + + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse federation reader now listening on port %d", port) @@ -119,36 +97,22 @@ class FederationReaderServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) - @defer.inlineCallbacks - def replicate(self): - http_client = self.get_simple_http_client() - store = self.get_datastore() - replication_url = self.config.worker_replication_url + self.get_tcp_replication().start_replication(self) - while True: - try: - args = store.stream_positions() - args["timeout"] = 30000 - result = yield 
http_client.get_json(replication_url, args=args) - yield store.process_replication(result) - except: - logger.exception("Error replicating from %r", replication_url) - yield sleep(5) + def build_tcp_replication(self): + return ReplicationClientHandler(self.get_datastore()) def start(config_options): @@ -162,7 +126,7 @@ def start(config_options): assert config.worker_app == "synapse.app.federation_reader" - setup_logging(config.worker_log_config, config.worker_log_file) + setup_logging(config, use_worker_options=True) events.USE_FROZEN_DICTS = config.use_frozen_dicts @@ -180,36 +144,15 @@ def start(config_options): ) ss.setup() - ss.get_handlers() ss.start_listening(config.worker_listeners) - def run(): - with LoggingContext("run"): - logger.info("Running") - change_resource_limit(config.soft_file_limit) - if config.gc_thresholds: - gc.set_threshold(*config.gc_thresholds) - reactor.run() - def start(): ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() - ss.replicate() reactor.callWhenRunning(start) - if config.worker_daemonize: - daemon = Daemonize( - app="synapse-federation-reader", - pid=config.worker_pid_file, - action=run, - auto_close_fds=False, - verbose=True, - logger=logger, - ) - daemon.start() - else: - run() + _base.start_worker_reactor("synapse-federation-reader", config) if __name__ == '__main__': diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 411e47d98d..a08af83a4c 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -13,69 +13,69 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import logging +import sys import synapse - -from synapse.server import HomeServer +from synapse import events +from synapse.app import _base from synapse.config._base import ConfigError -from synapse.config.logger import setup_logging from synapse.config.homeserver import HomeServerConfig +from synapse.config.logger import setup_logging from synapse.crypto import context_factory -from synapse.http.site import SynapseSite from synapse.federation import send_queue -from synapse.federation.units import Edu -from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.http.site import SynapseSite +from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore +from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.events import SlavedEventStore +from synapse.replication.slave.storage.presence import SlavedPresenceStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.transactions import TransactionStore -from synapse.replication.slave.storage.devices import SlavedDeviceStore +from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.server import HomeServer from synapse.storage.engines import create_engine -from synapse.storage.presence import UserPresenceState -from synapse.util.async import sleep +from synapse.util.async import Linearizer from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext +from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole -from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string +from twisted.internet import defer, reactor +from 
twisted.web.resource import NoResource -from synapse import events - -from twisted.internet import reactor, defer -from twisted.web.resource import Resource - -from daemonize import Daemonize - -import sys -import logging -import gc -import ujson as json - -logger = logging.getLogger("synapse.app.appservice") +logger = logging.getLogger("synapse.app.federation_sender") class FederationSenderSlaveStore( SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore, - SlavedRegistrationStore, SlavedDeviceStore, + SlavedRegistrationStore, SlavedDeviceStore, SlavedPresenceStore, ): - pass + def __init__(self, db_conn, hs): + super(FederationSenderSlaveStore, self).__init__(db_conn, hs) + + # We pull out the current federation stream position now so that we + # always have a known value for the federation position in memory so + # that we don't have to bounce via a deferred once when we start the + # replication streams. + self.federation_out_pos_startup = self._get_federation_out_pos(db_conn) + + def _get_federation_out_pos(self, db_conn): + sql = ( + "SELECT stream_id FROM federation_stream_position" + " WHERE type = ?" + ) + sql = self.database_engine.convert_param_style(sql) + txn = db_conn.cursor() + txn.execute(sql, ("federation",)) + rows = txn.fetchall() + txn.close() + + return rows[0][0] if rows else -1 -class FederationSenderServer(HomeServer): - def get_db_conn(self, run_new_connection=True): - # Any param beginning with cp_ is a parameter for adbapi, and should - # not be passed to the database engine. 
- db_params = { - k: v for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.database_engine.module.connect(**db_params) - - if run_new_connection: - self.database_engine.on_new_connection(db_conn) - return db_conn +class FederationSenderServer(HomeServer): def setup(self): logger.info("Setting up.") self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self) @@ -91,19 +91,18 @@ class FederationSenderServer(HomeServer): if name == "metrics": resources[METRICS_PREFIX] = MetricsResource(self) - root_resource = create_resource_tree(resources, Resource()) - - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + root_resource = create_resource_tree(resources, NoResource()) + + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse federation_sender now listening on port %d", port) @@ -112,41 +111,39 @@ class FederationSenderServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) - @defer.inlineCallbacks - def replicate(self): - http_client = self.get_simple_http_client() - store = self.get_datastore() - replication_url = self.config.worker_replication_url - send_handler = FederationSenderHandler(self) - - send_handler.on_start() - - while True: - try: - 
args = store.stream_positions() - args.update((yield send_handler.stream_positions())) - args["timeout"] = 30000 - result = yield http_client.get_json(replication_url, args=args) - yield store.process_replication(result) - yield send_handler.process_replication(result) - except: - logger.exception("Error replicating from %r", replication_url) - yield sleep(30) + self.get_tcp_replication().start_replication(self) + + def build_tcp_replication(self): + return FederationSenderReplicationHandler(self) + + +class FederationSenderReplicationHandler(ReplicationClientHandler): + def __init__(self, hs): + super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore()) + self.send_handler = FederationSenderHandler(hs, self) + + def on_rdata(self, stream_name, token, rows): + super(FederationSenderReplicationHandler, self).on_rdata( + stream_name, token, rows + ) + self.send_handler.process_replication_rows(stream_name, token, rows) + + def get_streams_to_replicate(self): + args = super(FederationSenderReplicationHandler, self).get_streams_to_replicate() + args.update(self.send_handler.stream_positions()) + return args def start(config_options): @@ -160,7 +157,7 @@ def start(config_options): assert config.worker_app == "synapse.app.federation_sender" - setup_logging(config.worker_log_config, config.worker_log_file) + setup_logging(config, use_worker_options=True) events.USE_FROZEN_DICTS = config.use_frozen_dicts @@ -192,42 +189,27 @@ def start(config_options): ps.setup() ps.start_listening(config.worker_listeners) - def run(): - with LoggingContext("run"): - logger.info("Running") - change_resource_limit(config.soft_file_limit) - if config.gc_thresholds: - gc.set_threshold(*config.gc_thresholds) - reactor.run() - def start(): - ps.replicate() ps.get_datastore().start_profiling() ps.get_state_handler().start_caching() reactor.callWhenRunning(start) - - if config.worker_daemonize: - daemon = Daemonize( - app="synapse-federation-sender", - 
pid=config.worker_pid_file, - action=run, - auto_close_fds=False, - verbose=True, - logger=logger, - ) - daemon.start() - else: - run() + _base.start_worker_reactor("synapse-federation-sender", config) class FederationSenderHandler(object): """Processes the replication stream and forwards the appropriate entries to the federation sender. """ - def __init__(self, hs): + def __init__(self, hs, replication_client): self.store = hs.get_datastore() self.federation_sender = hs.get_federation_sender() + self.replication_client = replication_client + + self.federation_position = self.store.federation_out_pos_startup + self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer") + + self._last_ack = self.federation_position self._room_serials = {} self._room_typing = {} @@ -239,98 +221,38 @@ class FederationSenderHandler(object): self.store.get_room_max_stream_ordering() ) - @defer.inlineCallbacks def stream_positions(self): - stream_id = yield self.store.get_federation_out_pos("federation") - defer.returnValue({ - "federation": stream_id, - - # Ack stuff we've "processed", this should only be called from - # one process. - "federation_ack": stream_id, - }) + return {"federation": self.federation_position} - @defer.inlineCallbacks - def process_replication(self, result): + def process_replication_rows(self, stream_name, token, rows): # The federation stream contains things that we want to send out, e.g. # presence, typing, etc. - fed_stream = result.get("federation") - if fed_stream: - latest_id = int(fed_stream["position"]) - - # The federation stream containis a bunch of different types of - # rows that need to be handled differently. We parse the rows, put - # them into the appropriate collection and then send them off. 
- presence_to_send = {} - keyed_edus = {} - edus = {} - failures = {} - device_destinations = set() - - # Parse the rows in the stream - for row in fed_stream["rows"]: - position, typ, content_js = row - content = json.loads(content_js) - - if typ == send_queue.PRESENCE_TYPE: - destination = content["destination"] - state = UserPresenceState.from_dict(content["state"]) - - presence_to_send.setdefault(destination, []).append(state) - elif typ == send_queue.KEYED_EDU_TYPE: - key = content["key"] - edu = Edu(**content["edu"]) - - keyed_edus.setdefault( - edu.destination, {} - )[(edu.destination, tuple(key))] = edu - elif typ == send_queue.EDU_TYPE: - edu = Edu(**content) - - edus.setdefault(edu.destination, []).append(edu) - elif typ == send_queue.FAILURE_TYPE: - destination = content["destination"] - failure = content["failure"] - - failures.setdefault(destination, []).append(failure) - elif typ == send_queue.DEVICE_MESSAGE_TYPE: - device_destinations.add(content["destination"]) - else: - raise Exception("Unrecognised federation type: %r", typ) - - # We've finished collecting, send everything off - for destination, states in presence_to_send.items(): - self.federation_sender.send_presence(destination, states) - - for destination, edu_map in keyed_edus.items(): - for key, edu in edu_map.items(): - self.federation_sender.send_edu( - edu.destination, edu.edu_type, edu.content, key=key, - ) - - for destination, edu_list in edus.items(): - for edu in edu_list: - self.federation_sender.send_edu( - edu.destination, edu.edu_type, edu.content, key=None, - ) + if stream_name == "federation": + send_queue.process_rows_for_federation(self.federation_sender, rows) + run_in_background(self.update_token, token) - for destination, failure_list in failures.items(): - for failure in failure_list: - self.federation_sender.send_failure(destination, failure) - - for destination in device_destinations: - self.federation_sender.send_device_messages(destination) + # We also need to poke the 
federation sender when new events happen + elif stream_name == "events": + self.federation_sender.notify_new_events(token) - # Record where we are in the stream. - yield self.store.update_federation_out_pos( - "federation", latest_id - ) + @defer.inlineCallbacks + def update_token(self, token): + try: + self.federation_position = token + + # We linearize here to ensure we don't have races updating the token + with (yield self._fed_position_linearizer.queue(None)): + if self._last_ack < self.federation_position: + yield self.store.update_federation_out_pos( + "federation", self.federation_position + ) - # We also need to poke the federation sender when new events happen - event_stream = result.get("events") - if event_stream: - latest_pos = event_stream["position"] - self.federation_sender.notify_new_events(latest_pos) + # We ACK this token over replication so that the master can drop + # its in memory queues + self.replication_client.send_federation_ack(self.federation_position) + self._last_ack = self.federation_position + except Exception: + logger.exception("Error updating federation stream position") if __name__ == '__main__': diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py new file mode 100644 index 0000000000..b349e3e3ce --- /dev/null +++ b/synapse/app/frontend_proxy.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +import sys + +import synapse +from synapse import events +from synapse.api.errors import SynapseError +from synapse.app import _base +from synapse.config._base import ConfigError +from synapse.config.homeserver import HomeServerConfig +from synapse.config.logger import setup_logging +from synapse.crypto import context_factory +from synapse.http.server import JsonResource +from synapse.http.servlet import ( + RestServlet, parse_json_object_from_request, +) +from synapse.http.site import SynapseSite +from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore +from synapse.replication.slave.storage.client_ips import SlavedClientIpStore +from synapse.replication.slave.storage.devices import SlavedDeviceStore +from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.rest.client.v2_alpha._base import client_v2_patterns +from synapse.server import HomeServer +from synapse.storage.engines import create_engine +from synapse.util.httpresourcetree import create_resource_tree +from synapse.util.logcontext import LoggingContext +from synapse.util.manhole import manhole +from synapse.util.versionstring import get_version_string +from twisted.internet import defer, reactor +from twisted.web.resource import NoResource + +logger = logging.getLogger("synapse.app.frontend_proxy") + + +class KeyUploadServlet(RestServlet): + PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(KeyUploadServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.http_client = hs.get_simple_http_client() + self.main_uri = hs.config.worker_main_http_uri + + 
@defer.inlineCallbacks + def on_POST(self, request, device_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + user_id = requester.user.to_string() + body = parse_json_object_from_request(request) + + if device_id is not None: + # passing the device_id here is deprecated; however, we allow it + # for now for compatibility with older clients. + if (requester.device_id is not None and + device_id != requester.device_id): + logger.warning("Client uploading keys for a different device " + "(logged in as %s, uploading for %s)", + requester.device_id, device_id) + else: + device_id = requester.device_id + + if device_id is None: + raise SynapseError( + 400, + "To upload keys, you must pass device_id when authenticating" + ) + + if body: + # They're actually trying to upload something, proxy to main synapse. + # Pass through the auth headers, if any, in case the access token + # is there. + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", []) + headers = { + "Authorization": auth_headers, + } + result = yield self.http_client.post_json_get_json( + self.main_uri + request.uri, + body, + headers=headers, + ) + + defer.returnValue((200, result)) + else: + # Just interested in counts. 
+ result = yield self.store.count_e2e_one_time_keys(user_id, device_id) + defer.returnValue((200, {"one_time_key_counts": result})) + + +class FrontendProxySlavedStore( + SlavedDeviceStore, + SlavedClientIpStore, + SlavedApplicationServiceStore, + SlavedRegistrationStore, + BaseSlavedStore, +): + pass + + +class FrontendProxyServer(HomeServer): + def setup(self): + logger.info("Setting up.") + self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self) + logger.info("Finished setting up.") + + def _listen_http(self, listener_config): + port = listener_config["port"] + bind_addresses = listener_config["bind_addresses"] + site_tag = listener_config.get("tag", port) + resources = {} + for res in listener_config["resources"]: + for name in res["names"]: + if name == "metrics": + resources[METRICS_PREFIX] = MetricsResource(self) + elif name == "client": + resource = JsonResource(self, canonical_json=False) + KeyUploadServlet(self).register(resource) + resources.update({ + "/_matrix/client/r0": resource, + "/_matrix/client/unstable": resource, + "/_matrix/client/v2_alpha": resource, + "/_matrix/client/api/v1": resource, + }) + + root_resource = create_resource_tree(resources, NoResource()) + + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ) + ) + + logger.info("Synapse client reader now listening on port %d", port) + + def start_listening(self, listeners): + for listener in listeners: + if listener["type"] == "http": + self._listen_http(listener) + elif listener["type"] == "manhole": + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ) + ) + else: + logger.warn("Unrecognized listener type: %s", listener["type"]) + + self.get_tcp_replication().start_replication(self) + + def build_tcp_replication(self): + return ReplicationClientHandler(self.get_datastore()) + 
+ +def start(config_options): + try: + config = HomeServerConfig.load_config( + "Synapse frontend proxy", config_options + ) + except ConfigError as e: + sys.stderr.write("\n" + e.message + "\n") + sys.exit(1) + + assert config.worker_app == "synapse.app.frontend_proxy" + + assert config.worker_main_http_uri is not None + + setup_logging(config, use_worker_options=True) + + events.USE_FROZEN_DICTS = config.use_frozen_dicts + + database_engine = create_engine(config.database_config) + + tls_server_context_factory = context_factory.ServerContextFactory(config) + + ss = FrontendProxyServer( + config.server_name, + db_config=config.database_config, + tls_server_context_factory=tls_server_context_factory, + config=config, + version_string="Synapse/" + get_version_string(synapse), + database_engine=database_engine, + ) + + ss.setup() + ss.start_listening(config.worker_listeners) + + def start(): + ss.get_state_handler().start_caching() + ss.get_datastore().start_profiling() + + reactor.callWhenRunning(start) + + _base.start_worker_reactor("synapse-frontend-proxy", config) + + +if __name__ == '__main__': + with LoggingContext("main"): + start(sys.argv[1:]) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index e0b87468fe..a0e465d644 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -13,59 +13,53 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -import synapse - import gc import logging import os import sys -from synapse.config._base import ConfigError - -from synapse.python_dependencies import ( - check_requirements, DEPENDENCY_LINKS -) - -from synapse.rest import ClientRestResource -from synapse.storage.engines import create_engine, IncorrectDatabaseSetup -from synapse.storage import are_all_users_on_domain -from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database - -from synapse.server import HomeServer -from twisted.internet import reactor, task, defer -from twisted.application import service -from twisted.web.resource import Resource, EncodingResourceWrapper -from twisted.web.static import File -from twisted.web.server import GzipEncoderFactory +import synapse +import synapse.config.logger +from synapse import events +from synapse.api.urls import CONTENT_REPO_PREFIX, FEDERATION_PREFIX, \ + LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, SERVER_KEY_PREFIX, SERVER_KEY_V2_PREFIX, \ + STATIC_PREFIX, WEB_CLIENT_PREFIX +from synapse.app import _base +from synapse.app._base import quit_with_error, listen_ssl, listen_tcp +from synapse.config._base import ConfigError +from synapse.config.homeserver import HomeServerConfig +from synapse.crypto import context_factory +from synapse.federation.transport.server import TransportLayerServer +from synapse.module_api import ModuleApi +from synapse.http.additional_resource import AdditionalResource from synapse.http.server import RootRedirect -from synapse.rest.media.v0.content_repository import ContentRepoResource -from synapse.rest.media.v1.media_repository import MediaRepositoryResource +from synapse.http.site import SynapseSite +from synapse.metrics import register_memory_metrics +from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.python_dependencies import CONDITIONAL_REQUIREMENTS, \ + check_requirements +from synapse.replication.http import ReplicationRestResource, REPLICATION_PREFIX +from 
synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from synapse.rest import ClientRestResource from synapse.rest.key.v1.server_key_resource import LocalKey from synapse.rest.key.v2 import KeyApiV2Resource -from synapse.api.urls import ( - FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX, - SERVER_KEY_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, STATIC_PREFIX, - SERVER_KEY_V2_PREFIX, -) -from synapse.config.homeserver import HomeServerConfig -from synapse.crypto import context_factory +from synapse.rest.media.v0.content_repository import ContentRepoResource +from synapse.server import HomeServer +from synapse.storage import are_all_users_on_domain +from synapse.storage.engines import IncorrectDatabaseSetup, create_engine +from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database +from synapse.util.caches import CACHE_SIZE_FACTOR +from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext -from synapse.metrics import register_memory_metrics, get_metrics_for -from synapse.metrics.resource import MetricsResource, METRICS_PREFIX -from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX -from synapse.federation.transport.server import TransportLayerServer - +from synapse.util.manhole import manhole +from synapse.util.module_loader import load_module from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string -from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.manhole import manhole - -from synapse.http.site import SynapseSite - -from synapse import events - -from daemonize import Daemonize +from twisted.application import service +from twisted.internet import defer, reactor +from twisted.web.resource import EncodingResourceWrapper, NoResource +from twisted.web.server import GzipEncoderFactory +from twisted.web.static import File logger = 
logging.getLogger("synapse.app.homeserver") @@ -90,7 +84,7 @@ def build_resource_for_web_client(hs): "\n" "You can also disable hosting of the webclient via the\n" "configuration option `web_client`\n" - % {"dep": DEPENDENCY_LINKS["matrix-angular-sdk"]} + % {"dep": CONDITIONAL_REQUIREMENTS["web_client"].keys()[0]} ) syweb_path = os.path.dirname(syweb.__file__) webclient_path = os.path.join(syweb_path, "webclient") @@ -117,90 +111,121 @@ class SynapseHomeServer(HomeServer): resources = {} for res in listener_config["resources"]: for name in res["names"]: - if name == "client": - client_resource = ClientRestResource(self) - if res["compress"]: - client_resource = gz_wrap(client_resource) - - resources.update({ - "/_matrix/client/api/v1": client_resource, - "/_matrix/client/r0": client_resource, - "/_matrix/client/unstable": client_resource, - "/_matrix/client/v2_alpha": client_resource, - "/_matrix/client/versions": client_resource, - }) - - if name == "federation": - resources.update({ - FEDERATION_PREFIX: TransportLayerServer(self), - }) - - if name in ["static", "client"]: - resources.update({ - STATIC_PREFIX: File( - os.path.join(os.path.dirname(synapse.__file__), "static") - ), - }) - - if name in ["media", "federation", "client"]: - media_repo = MediaRepositoryResource(self) - resources.update({ - MEDIA_PREFIX: media_repo, - LEGACY_MEDIA_PREFIX: media_repo, - CONTENT_REPO_PREFIX: ContentRepoResource( - self, self.config.uploads_path - ), - }) - - if name in ["keys", "federation"]: - resources.update({ - SERVER_KEY_PREFIX: LocalKey(self), - SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self), - }) - - if name == "webclient": - resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self) - - if name == "metrics" and self.get_config().enable_metrics: - resources[METRICS_PREFIX] = MetricsResource(self) - - if name == "replication": - resources[REPLICATION_PREFIX] = ReplicationResource(self) + resources.update(self._configure_named_resource( + name, 
res.get("compress", False), + )) + + additional_resources = listener_config.get("additional_resources", {}) + logger.debug("Configuring additional resources: %r", + additional_resources) + module_api = ModuleApi(self, self.get_auth_handler()) + for path, resmodule in additional_resources.items(): + handler_cls, config = load_module(resmodule) + handler = handler_cls(config, module_api) + resources[path] = AdditionalResource(self, handler.handle_request) if WEB_CLIENT_PREFIX in resources: root_resource = RootRedirect(WEB_CLIENT_PREFIX) else: - root_resource = Resource() + root_resource = NoResource() root_resource = create_resource_tree(resources, root_resource) if tls: - for address in bind_addresses: - reactor.listenSSL( - port, - SynapseSite( - "synapse.access.https.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - self.tls_server_context_factory, - interface=address - ) + listen_ssl( + bind_addresses, + port, + SynapseSite( + "synapse.access.https.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ), + self.tls_server_context_factory, + ) + else: - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse now listening on port %d", port) + def _configure_named_resource(self, name, compress=False): + """Build a resource map for a named resource + + Args: + name (str): named resource: one of "client", "federation", etc + compress (bool): whether to enable gzip compression for this + resource + + Returns: + dict[str, Resource]: map from path to HTTP resource + """ + resources = {} + if name == "client": + client_resource = ClientRestResource(self) + if compress: + client_resource = gz_wrap(client_resource) + + 
resources.update({ + "/_matrix/client/api/v1": client_resource, + "/_matrix/client/r0": client_resource, + "/_matrix/client/unstable": client_resource, + "/_matrix/client/v2_alpha": client_resource, + "/_matrix/client/versions": client_resource, + }) + + if name == "federation": + resources.update({ + FEDERATION_PREFIX: TransportLayerServer(self), + }) + + if name in ["static", "client"]: + resources.update({ + STATIC_PREFIX: File( + os.path.join(os.path.dirname(synapse.__file__), "static") + ), + }) + + if name in ["media", "federation", "client"]: + if self.get_config().enable_media_repo: + media_repo = self.get_media_repository_resource() + resources.update({ + MEDIA_PREFIX: media_repo, + LEGACY_MEDIA_PREFIX: media_repo, + CONTENT_REPO_PREFIX: ContentRepoResource( + self, self.config.uploads_path + ), + }) + elif name == "media": + raise ConfigError( + "'media' resource conflicts with enable_media_repo=False", + ) + + if name in ["keys", "federation"]: + resources.update({ + SERVER_KEY_PREFIX: LocalKey(self), + SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self), + }) + + if name == "webclient": + resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self) + + if name == "metrics" and self.get_config().enable_metrics: + resources[METRICS_PREFIX] = MetricsResource(self) + + if name == "replication": + resources[REPLICATION_PREFIX] = ReplicationRestResource(self) + + return resources + def start_listening(self): config = self.get_config() @@ -208,17 +233,24 @@ class SynapseHomeServer(HomeServer): if listener["type"] == "http": self._listener_http(config, listener) elif listener["type"] == "manhole": + listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ) + ) + elif listener["type"] == "replication": bind_addresses = listener["bind_addresses"] - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - 
globals={"hs": self}, - ), - interface=address + factory = ReplicationStreamProtocolFactory(self) + server_listener = reactor.listenTCP( + listener["port"], factory, interface=address + ) + reactor.addSystemEventTrigger( + "before", "shutdown", server_listener.stopListening, ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -239,29 +271,6 @@ class SynapseHomeServer(HomeServer): except IncorrectDatabaseSetup as e: quit_with_error(e.message) - def get_db_conn(self, run_new_connection=True): - # Any param beginning with cp_ is a parameter for adbapi, and should - # not be passed to the database engine. - db_params = { - k: v for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.database_engine.module.connect(**db_params) - - if run_new_connection: - self.database_engine.on_new_connection(db_conn) - return db_conn - - -def quit_with_error(error_string): - message_lines = error_string.split("\n") - line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2 - sys.stderr.write("*" * line_length + '\n') - for line in message_lines: - sys.stderr.write(" %s\n" % (line.rstrip(),)) - sys.stderr.write("*" * line_length + '\n') - sys.exit(1) - def setup(config_options): """ @@ -286,7 +295,7 @@ def setup(config_options): # generating config files and shouldn't try to continue. 
sys.exit(0) - config.setup_logging() + synapse.config.logger.setup_logging(config, use_worker_options=False) # check any extra requirements we have now we have a config check_requirements(config) @@ -340,7 +349,7 @@ def setup(config_options): hs.get_state_handler().start_caching() hs.get_datastore().start_profiling() hs.get_datastore().start_doing_background_updates() - hs.get_replication_layer().start_get_pdu_cache() + hs.get_federation_client().start_get_pdu_cache() register_memory_metrics(hs) @@ -389,10 +398,15 @@ def run(hs): ThreadPool._worker = profile(ThreadPool._worker) reactor.run = profile(reactor.run) - start_time = hs.get_clock().time() + clock = hs.get_clock() + start_time = clock.time() stats = {} + # Contains the list of processes we will be monitoring + # currently either 0 or 1 + stats_process = [] + @defer.inlineCallbacks def phone_stats_home(): logger.info("Gathering stats for reporting") @@ -401,41 +415,36 @@ def run(hs): if uptime < 0: uptime = 0 - # If the stats directory is empty then this is the first time we've - # reported stats. - first_time = not stats - stats["homeserver"] = hs.config.server_name stats["timestamp"] = now stats["uptime_seconds"] = uptime stats["total_users"] = yield hs.get_datastore().count_all_users() + total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users() + stats["total_nonbridged_users"] = total_nonbridged_users + room_count = yield hs.get_datastore().get_room_count() stats["total_room_count"] = room_count stats["daily_active_users"] = yield hs.get_datastore().count_daily_users() - daily_messages = yield hs.get_datastore().count_daily_messages() - if daily_messages is not None: - stats["daily_messages"] = daily_messages - else: - stats.pop("daily_messages", None) - - if first_time: - # Add callbacks to report the synapse stats as metrics whenever - # prometheus requests them, typically every 30s. 
- # As some of the stats are expensive to calculate we only update - # them when synapse phones home to matrix.org every 24 hours. - metrics = get_metrics_for("synapse.usage") - metrics.add_callback("timestamp", lambda: stats["timestamp"]) - metrics.add_callback("uptime_seconds", lambda: stats["uptime_seconds"]) - metrics.add_callback("total_users", lambda: stats["total_users"]) - metrics.add_callback("total_room_count", lambda: stats["total_room_count"]) - metrics.add_callback( - "daily_active_users", lambda: stats["daily_active_users"] - ) - metrics.add_callback( - "daily_messages", lambda: stats.get("daily_messages", 0) - ) + stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms() + stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() + + r30_results = yield hs.get_datastore().count_r30_users() + for name, count in r30_results.iteritems(): + stats["r30_users_" + name] = count + + daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() + stats["daily_sent_messages"] = daily_sent_messages + stats["cache_factor"] = CACHE_SIZE_FACTOR + stats["event_cache_size"] = hs.config.event_cache_size + + if len(stats_process) > 0: + stats["memory_rss"] = 0 + stats["cpu_average"] = 0 + for process in stats_process: + stats["memory_rss"] += process.memory_info().rss + stats["cpu_average"] += int(process.cpu_percent(interval=None)) logger.info("Reporting stats to matrix.org: %s" % (stats,)) try: @@ -446,37 +455,48 @@ def run(hs): except Exception as e: logger.warn("Error reporting stats: %s", e) - if hs.config.report_stats: - phone_home_task = task.LoopingCall(phone_stats_home) - logger.info("Scheduling stats reporting for 24 hour intervals") - phone_home_task.start(60 * 60 * 24, now=False) - - def in_thread(): - # Uncomment to enable tracing of log context changes. 
- # sys.settrace(logcontext_tracer) - with LoggingContext("run"): - change_resource_limit(hs.config.soft_file_limit) - if hs.config.gc_thresholds: - gc.set_threshold(*hs.config.gc_thresholds) - reactor.run() - - if hs.config.daemonize: - - if hs.config.print_pidfile: - print (hs.config.pid_file) - - daemon = Daemonize( - app="synapse-homeserver", - pid=hs.config.pid_file, - action=lambda: in_thread(), - auto_close_fds=False, - verbose=True, - logger=logger, - ) + def performance_stats_init(): + try: + import psutil + process = psutil.Process() + # Ensure we can fetch both, and make the initial request for cpu_percent + # so the next request will use this as the initial point. + process.memory_info().rss + process.cpu_percent(interval=None) + logger.info("report_stats can use psutil") + stats_process.append(process) + except (ImportError, AttributeError): + logger.warn( + "report_stats enabled but psutil is not installed or incorrect version." + " Disabling reporting of memory/cpu stats." + " Ensuring psutil is available will help matrix.org track performance" + " changes across releases." 
+ ) - daemon.start() - else: - in_thread() + if hs.config.report_stats: + logger.info("Scheduling stats reporting for 3 hour intervals") + clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000) + + # We need to defer this init for the cases that we daemonize + # otherwise the process ID we get is that of the non-daemon process + clock.call_later(0, performance_stats_init) + + # We wait 5 minutes to send the first set of stats as the server can + # be quite busy the first few minutes + clock.call_later(5 * 60, phone_stats_home) + + if hs.config.daemonize and hs.config.print_pidfile: + print (hs.config.pid_file) + + _base.start_reactor( + "synapse-homeserver", + hs.config.soft_file_limit, + hs.config.gc_thresholds, + hs.config.pid_file, + hs.config.daemonize, + hs.config.cpu_affinity, + logger, + ) def main(): diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index ef17b158a5..fc8282bbc1 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -13,45 +13,37 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import logging +import sys import synapse - +from synapse import events +from synapse.api.urls import ( + CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX +) +from synapse.app import _base from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging +from synapse.crypto import context_factory from synapse.http.site import SynapseSite -from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore +from synapse.replication.slave.storage.client_ips import SlavedClientIpStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.replication.slave.storage.transactions import TransactionStore +from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.media.v0.content_repository import ContentRepoResource -from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.server import HomeServer -from synapse.storage.client_ips import ClientIpStore from synapse.storage.engines import create_engine from synapse.storage.media_repository import MediaRepositoryStore -from synapse.util.async import sleep from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole -from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string -from synapse.api.urls import ( - CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX -) -from synapse.crypto import context_factory - -from synapse import events - - -from twisted.internet import reactor, defer -from twisted.web.resource import Resource - -from daemonize import Daemonize - -import 
sys -import logging -import gc +from twisted.internet import reactor +from twisted.web.resource import NoResource logger = logging.getLogger("synapse.app.media_repository") @@ -59,27 +51,15 @@ logger = logging.getLogger("synapse.app.media_repository") class MediaRepositorySlavedStore( SlavedApplicationServiceStore, SlavedRegistrationStore, + SlavedClientIpStore, + TransactionStore, BaseSlavedStore, MediaRepositoryStore, - ClientIpStore, ): pass class MediaRepositoryServer(HomeServer): - def get_db_conn(self, run_new_connection=True): - # Any param beginning with cp_ is a parameter for adbapi, and should - # not be passed to the database engine. - db_params = { - k: v for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.database_engine.module.connect(**db_params) - - if run_new_connection: - self.database_engine.on_new_connection(db_conn) - return db_conn - def setup(self): logger.info("Setting up.") self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self) @@ -95,7 +75,7 @@ class MediaRepositoryServer(HomeServer): if name == "metrics": resources[METRICS_PREFIX] = MetricsResource(self) elif name == "media": - media_repo = MediaRepositoryResource(self) + media_repo = self.get_media_repository_resource() resources.update({ MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo, @@ -104,19 +84,18 @@ class MediaRepositoryServer(HomeServer): ), }) - root_resource = create_resource_tree(resources, Resource()) - - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + root_resource = create_resource_tree(resources, NoResource()) + + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse media repository now listening on port %d", port) @@ -125,36 +104,22 @@ 
class MediaRepositoryServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) - @defer.inlineCallbacks - def replicate(self): - http_client = self.get_simple_http_client() - store = self.get_datastore() - replication_url = self.config.worker_replication_url + self.get_tcp_replication().start_replication(self) - while True: - try: - args = store.stream_positions() - args["timeout"] = 30000 - result = yield http_client.get_json(replication_url, args=args) - yield store.process_replication(result) - except: - logger.exception("Error replicating from %r", replication_url) - yield sleep(5) + def build_tcp_replication(self): + return ReplicationClientHandler(self.get_datastore()) def start(config_options): @@ -168,7 +133,14 @@ def start(config_options): assert config.worker_app == "synapse.app.media_repository" - setup_logging(config.worker_log_config, config.worker_log_file) + if config.enable_media_repo: + _base.quit_with_error( + "enable_media_repo must be disabled in the main synapse process\n" + "before the media repo can be run in a separate worker.\n" + "Please add ``enable_media_repo: false`` to the main config\n" + ) + + setup_logging(config, use_worker_options=True) events.USE_FROZEN_DICTS = config.use_frozen_dicts @@ -186,36 +158,15 @@ def start(config_options): ) ss.setup() - ss.get_handlers() ss.start_listening(config.worker_listeners) - def run(): - with LoggingContext("run"): - logger.info("Running") - change_resource_limit(config.soft_file_limit) - if 
config.gc_thresholds: - gc.set_threshold(*config.gc_thresholds) - reactor.run() - def start(): ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() - ss.replicate() reactor.callWhenRunning(start) - if config.worker_daemonize: - daemon = Daemonize( - app="synapse-media-repository", - pid=config.worker_pid_file, - action=run, - auto_close_fds=False, - verbose=True, - logger=logger, - ) - daemon.start() - else: - run() + _base.start_worker_reactor("synapse-media-repository", config) if __name__ == '__main__': diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index 073f2c2489..26930d1b3b 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -13,39 +13,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +import sys import synapse - -from synapse.server import HomeServer +from synapse import events +from synapse.app import _base from synapse.config._base import ConfigError -from synapse.config.logger import setup_logging from synapse.config.homeserver import HomeServerConfig +from synapse.config.logger import setup_logging from synapse.http.site import SynapseSite -from synapse.metrics.resource import MetricsResource, METRICS_PREFIX -from synapse.storage.roommember import RoomMemberStore +from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.pushers import SlavedPusherStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore -from synapse.replication.slave.storage.account_data import SlavedAccountDataStore -from synapse.storage.engines import create_engine +from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.server import 
HomeServer from synapse.storage import DataStore -from synapse.util.async import sleep +from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext, preserve_fn +from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole -from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string - -from synapse import events - -from twisted.internet import reactor, defer -from twisted.web.resource import Resource - -from daemonize import Daemonize - -import sys -import logging -import gc +from twisted.internet import defer, reactor +from twisted.web.resource import NoResource logger = logging.getLogger("synapse.app.pusher") @@ -82,42 +74,15 @@ class PusherSlaveStore( DataStore.get_profile_displayname.__func__ ) - who_forgot_in_room = ( - RoomMemberStore.__dict__["who_forgot_in_room"] - ) - class PusherServer(HomeServer): - - def get_db_conn(self, run_new_connection=True): - # Any param beginning with cp_ is a parameter for adbapi, and should - # not be passed to the database engine. 
- db_params = { - k: v for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.database_engine.module.connect(**db_params) - - if run_new_connection: - self.database_engine.on_new_connection(db_conn) - return db_conn - def setup(self): logger.info("Setting up.") self.datastore = PusherSlaveStore(self.get_db_conn(), self) logger.info("Finished setting up.") def remove_pusher(self, app_id, push_key, user_id): - http_client = self.get_simple_http_client() - replication_url = self.config.worker_replication_url - url = replication_url + "/remove_pushers" - return http_client.post_json_get_json(url, { - "remove": [{ - "app_id": app_id, - "push_key": push_key, - "user_id": user_id, - }] - }) + self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id) def _listen_http(self, listener_config): port = listener_config["port"] @@ -129,19 +94,18 @@ class PusherServer(HomeServer): if name == "metrics": resources[METRICS_PREFIX] = MetricsResource(self) - root_resource = create_resource_tree(resources, Resource()) - - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + root_resource = create_resource_tree(resources, NoResource()) + + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse pusher now listening on port %d", port) @@ -150,88 +114,67 @@ class PusherServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + 
manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) + self.get_tcp_replication().start_replication(self) + + def build_tcp_replication(self): + return PusherReplicationHandler(self) + + +class PusherReplicationHandler(ReplicationClientHandler): + def __init__(self, hs): + super(PusherReplicationHandler, self).__init__(hs.get_datastore()) + + self.pusher_pool = hs.get_pusherpool() + + def on_rdata(self, stream_name, token, rows): + super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows) + run_in_background(self.poke_pushers, stream_name, token, rows) + @defer.inlineCallbacks - def replicate(self): - http_client = self.get_simple_http_client() - store = self.get_datastore() - replication_url = self.config.worker_replication_url - pusher_pool = self.get_pusherpool() - - def stop_pusher(user_id, app_id, pushkey): - key = "%s:%s" % (app_id, pushkey) - pushers_for_user = pusher_pool.pushers.get(user_id, {}) - pusher = pushers_for_user.pop(key, None) - if pusher is None: - return - logger.info("Stopping pusher %r / %r", user_id, key) - pusher.on_stop() - - def start_pusher(user_id, app_id, pushkey): - key = "%s:%s" % (app_id, pushkey) - logger.info("Starting pusher %r / %r", user_id, key) - return pusher_pool._refresh_pusher(app_id, pushkey, user_id) - - @defer.inlineCallbacks - def poke_pushers(results): - pushers_rows = set( - map(tuple, results.get("pushers", {}).get("rows", [])) - ) - deleted_pushers_rows = set( - map(tuple, results.get("deleted_pushers", {}).get("rows", [])) - ) - for row in sorted(pushers_rows | deleted_pushers_rows): - if row in deleted_pushers_rows: - user_id, app_id, pushkey = row[1:4] - stop_pusher(user_id, app_id, pushkey) - elif row in pushers_rows: - user_id = row[1] - app_id = row[5] - pushkey = row[8] - yield start_pusher(user_id, app_id, pushkey) - - stream = results.get("events") - if stream and stream["rows"]: - 
min_stream_id = stream["rows"][0][0] - max_stream_id = stream["position"] - preserve_fn(pusher_pool.on_new_notifications)( - min_stream_id, max_stream_id + def poke_pushers(self, stream_name, token, rows): + try: + if stream_name == "pushers": + for row in rows: + if row.deleted: + yield self.stop_pusher(row.user_id, row.app_id, row.pushkey) + else: + yield self.start_pusher(row.user_id, row.app_id, row.pushkey) + elif stream_name == "events": + yield self.pusher_pool.on_new_notifications( + token, token, ) - - stream = results.get("receipts") - if stream and stream["rows"]: - rows = stream["rows"] - affected_room_ids = set(row[1] for row in rows) - min_stream_id = rows[0][0] - max_stream_id = stream["position"] - preserve_fn(pusher_pool.on_new_receipts)( - min_stream_id, max_stream_id, affected_room_ids + elif stream_name == "receipts": + yield self.pusher_pool.on_new_receipts( + token, token, set(row.room_id for row in rows) ) + except Exception: + logger.exception("Error poking pushers") + + def stop_pusher(self, user_id, app_id, pushkey): + key = "%s:%s" % (app_id, pushkey) + pushers_for_user = self.pusher_pool.pushers.get(user_id, {}) + pusher = pushers_for_user.pop(key, None) + if pusher is None: + return + logger.info("Stopping pusher %r / %r", user_id, key) + pusher.on_stop() - while True: - try: - args = store.stream_positions() - args["timeout"] = 30000 - result = yield http_client.get_json(replication_url, args=args) - yield store.process_replication(result) - poke_pushers(result) - except: - logger.exception("Error replicating from %r", replication_url) - yield sleep(30) + def start_pusher(self, user_id, app_id, pushkey): + key = "%s:%s" % (app_id, pushkey) + logger.info("Starting pusher %r / %r", user_id, key) + return self.pusher_pool._refresh_pusher(app_id, pushkey, user_id) def start(config_options): @@ -245,7 +188,7 @@ def start(config_options): assert config.worker_app == "synapse.app.pusher" - setup_logging(config.worker_log_config, 
config.worker_log_file) + setup_logging(config, use_worker_options=True) events.USE_FROZEN_DICTS = config.use_frozen_dicts @@ -274,34 +217,14 @@ def start(config_options): ps.setup() ps.start_listening(config.worker_listeners) - def run(): - with LoggingContext("run"): - logger.info("Running") - change_resource_limit(config.soft_file_limit) - if config.gc_thresholds: - gc.set_threshold(*config.gc_thresholds) - reactor.run() - def start(): - ps.replicate() ps.get_pusherpool().start() ps.get_datastore().start_profiling() ps.get_state_handler().start_caching() reactor.callWhenRunning(start) - if config.worker_daemonize: - daemon = Daemonize( - app="synapse-pusher", - pid=config.worker_pid_file, - action=run, - auto_close_fds=False, - verbose=True, - logger=logger, - ) - daemon.start() - else: - run() + _base.start_worker_reactor("synapse-pusher", config) if __name__ == '__main__': diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 3f29595256..7152b1deb4 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -13,105 +13,87 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import contextlib +import logging +import sys import synapse - -from synapse.api.constants import EventTypes, PresenceState +from synapse.api.constants import EventTypes +from synapse.app import _base from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging -from synapse.events import FrozenEvent -from synapse.handlers.presence import PresenceHandler -from synapse.http.site import SynapseSite +from synapse.handlers.presence import PresenceHandler, get_interested_parties from synapse.http.server import JsonResource -from synapse.metrics.resource import MetricsResource, METRICS_PREFIX -from synapse.rest.client.v2_alpha import sync -from synapse.rest.client.v1 import events -from synapse.rest.client.v1.room import RoomInitialSyncRestServlet -from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet +from synapse.http.site import SynapseSite +from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore -from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.slave.storage.filtering import SlavedFilteringStore -from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore -from synapse.replication.slave.storage.presence import SlavedPresenceStore +from synapse.replication.slave.storage.client_ips import SlavedClientIpStore from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore from synapse.replication.slave.storage.devices import SlavedDeviceStore +from 
synapse.replication.slave.storage.events import SlavedEventStore +from synapse.replication.slave.storage.filtering import SlavedFilteringStore +from synapse.replication.slave.storage.presence import SlavedPresenceStore +from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore +from synapse.replication.slave.storage.receipts import SlavedReceiptsStore +from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.room import RoomStore +from synapse.replication.slave.storage.groups import SlavedGroupServerStore +from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.rest.client.v1 import events +from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet +from synapse.rest.client.v1.room import RoomInitialSyncRestServlet +from synapse.rest.client.v2_alpha import sync from synapse.server import HomeServer -from synapse.storage.client_ips import ClientIpStore from synapse.storage.engines import create_engine -from synapse.storage.presence import PresenceStore, UserPresenceState +from synapse.storage.presence import UserPresenceState from synapse.storage.roommember import RoomMemberStore -from synapse.util.async import sleep from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext, preserve_fn +from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole -from synapse.util.rlimit import change_resource_limit from synapse.util.stringutils import random_string from synapse.util.versionstring import get_version_string +from twisted.internet import defer, reactor +from twisted.web.resource import NoResource -from twisted.internet import reactor, defer -from twisted.web.resource import Resource - -from daemonize import Daemonize - -import sys -import logging -import contextlib -import gc -import ujson as json +from six import iteritems logger = 
logging.getLogger("synapse.app.synchrotron") class SynchrotronSlavedStore( - SlavedPushRuleStore, - SlavedEventStore, SlavedReceiptsStore, SlavedAccountDataStore, SlavedApplicationServiceStore, SlavedRegistrationStore, SlavedFilteringStore, SlavedPresenceStore, + SlavedGroupServerStore, SlavedDeviceInboxStore, SlavedDeviceStore, + SlavedPushRuleStore, + SlavedEventStore, + SlavedClientIpStore, RoomStore, BaseSlavedStore, - ClientIpStore, # After BaseSlavedStore because the constructor is different ): - who_forgot_in_room = ( - RoomMemberStore.__dict__["who_forgot_in_room"] - ) - did_forget = ( RoomMemberStore.__dict__["did_forget"] ) - # XXX: This is a bit broken because we don't persist the accepted list in a - # way that can be replicated. This means that we don't have a way to - # invalidate the cache correctly. - get_presence_list_accepted = PresenceStore.__dict__[ - "get_presence_list_accepted" - ] - get_presence_list_observers_accepted = PresenceStore.__dict__[ - "get_presence_list_observers_accepted" - ] - UPDATE_SYNCING_USERS_MS = 10 * 1000 class SynchrotronPresence(object): def __init__(self, hs): + self.hs = hs self.is_mine_id = hs.is_mine_id self.http_client = hs.get_simple_http_client() self.store = hs.get_datastore() self.user_to_num_current_syncs = {} - self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users" self.clock = hs.get_clock() self.notifier = hs.get_notifier() @@ -121,17 +103,52 @@ class SynchrotronPresence(object): for state in active_presence } - self.process_id = random_string(16) - logger.info("Presence process_id is %r", self.process_id) + # user_id -> last_sync_ms. 
Lists the users that have stopped syncing + # but we haven't notified the master of that yet + self.users_going_offline = {} - self._sending_sync = False - self._need_to_send_sync = False - self.clock.looping_call( - self._send_syncing_users_regularly, - UPDATE_SYNCING_USERS_MS, + self._send_stop_syncing_loop = self.clock.looping_call( + self.send_stop_syncing, 10 * 1000 ) - reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown) + self.process_id = random_string(16) + logger.info("Presence process_id is %r", self.process_id) + + def send_user_sync(self, user_id, is_syncing, last_sync_ms): + self.hs.get_tcp_replication().send_user_sync(user_id, is_syncing, last_sync_ms) + + def mark_as_coming_online(self, user_id): + """A user has started syncing. Send a UserSync to the master, unless they + had recently stopped syncing. + + Args: + user_id (str) + """ + going_offline = self.users_going_offline.pop(user_id, None) + if not going_offline: + # Safe to skip because we haven't yet told the master they were offline + self.send_user_sync(user_id, True, self.clock.time_msec()) + + def mark_as_going_offline(self, user_id): + """A user has stopped syncing. We wait before notifying the master as + its likely they'll come back soon. This allows us to avoid sending + a stopped syncing immediately followed by a started syncing notification + to the master + + Args: + user_id (str) + """ + self.users_going_offline[user_id] = self.clock.time_msec() + + def send_stop_syncing(self): + """Check if there are any users who have stopped syncing a while ago + and haven't come back yet. If there are poke the master about them. + """ + now = self.clock.time_msec() + for user_id, last_sync_ms in self.users_going_offline.items(): + if now - last_sync_ms > 10 * 1000: + self.users_going_offline.pop(user_id, None) + self.send_user_sync(user_id, False, last_sync_ms) def set_state(self, user, state, ignore_status_msg=False): # TODO Hows this supposed to work? 
@@ -139,18 +156,16 @@ class SynchrotronPresence(object): get_states = PresenceHandler.get_states.__func__ get_state = PresenceHandler.get_state.__func__ - _get_interested_parties = PresenceHandler._get_interested_parties.__func__ current_state_for_users = PresenceHandler.current_state_for_users.__func__ - @defer.inlineCallbacks def user_syncing(self, user_id, affect_presence): if affect_presence: curr_sync = self.user_to_num_current_syncs.get(user_id, 0) self.user_to_num_current_syncs[user_id] = curr_sync + 1 - prev_states = yield self.current_state_for_users([user_id]) - if prev_states[user_id].state == PresenceState.OFFLINE: - # TODO: Don't block the sync request on this HTTP hit. - yield self._send_syncing_users_now() + + # If we went from no in flight sync to some, notify replication + if self.user_to_num_current_syncs[user_id] == 1: + self.mark_as_coming_online(user_id) def _end(): # We check that the user_id is in user_to_num_current_syncs because @@ -159,6 +174,10 @@ class SynchrotronPresence(object): if affect_presence and user_id in self.user_to_num_current_syncs: self.user_to_num_current_syncs[user_id] -= 1 + # If we went from one in flight sync to non, notify replication + if self.user_to_num_current_syncs[user_id] == 0: + self.mark_as_going_offline(user_id) + @contextlib.contextmanager def _user_syncing(): try: @@ -166,56 +185,12 @@ class SynchrotronPresence(object): finally: _end() - defer.returnValue(_user_syncing()) - - @defer.inlineCallbacks - def _on_shutdown(self): - # When the synchrotron is shutdown tell the master to clear the in - # progress syncs for this process - self.user_to_num_current_syncs.clear() - yield self._send_syncing_users_now() - - def _send_syncing_users_regularly(self): - # Only send an update if we aren't in the middle of sending one. 
- if not self._sending_sync: - preserve_fn(self._send_syncing_users_now)() - - @defer.inlineCallbacks - def _send_syncing_users_now(self): - if self._sending_sync: - # We don't want to race with sending another update. - # Instead we wait for that update to finish and send another - # update afterwards. - self._need_to_send_sync = True - return - - # Flag that we are sending an update. - self._sending_sync = True - - yield self.http_client.post_json_get_json(self.syncing_users_url, { - "process_id": self.process_id, - "syncing_users": [ - user_id for user_id, count in self.user_to_num_current_syncs.items() - if count > 0 - ], - }) - - # Unset the flag as we are no longer sending an update. - self._sending_sync = False - if self._need_to_send_sync: - # If something happened while we were sending the update then - # we might need to send another update. - # TODO: Check if the update that was sent matches the current state - # as we only need to send an update if they are different. - self._need_to_send_sync = False - yield self._send_syncing_users_now() + return defer.succeed(_user_syncing()) @defer.inlineCallbacks def notify_from_replication(self, states, stream_id): - parties = yield self._get_interested_parties( - states, calculate_remote_hosts=False - ) - room_ids_to_states, users_to_states, _ = parties + parties = yield get_interested_parties(self.store, states) + room_ids_to_states, users_to_states = parties self.notifier.on_new_event( "presence_key", stream_id, rooms=room_ids_to_states.keys(), @@ -223,26 +198,24 @@ class SynchrotronPresence(object): ) @defer.inlineCallbacks - def process_replication(self, result): - stream = result.get("presence", {"rows": []}) - states = [] - for row in stream["rows"]: - ( - position, user_id, state, last_active_ts, - last_federation_update_ts, last_user_sync_ts, status_msg, - currently_active - ) = row - state = UserPresenceState( - user_id, state, last_active_ts, - last_federation_update_ts, last_user_sync_ts, status_msg, - 
currently_active - ) - self.user_to_current_state[user_id] = state - states.append(state) + def process_replication_rows(self, token, rows): + states = [UserPresenceState( + row.user_id, row.state, row.last_active_ts, + row.last_federation_update_ts, row.last_user_sync_ts, row.status_msg, + row.currently_active + ) for row in rows] + + for state in states: + self.user_to_current_state[row.user_id] = state + + stream_id = token + yield self.notify_from_replication(states, stream_id) - if states and "position" in stream: - stream_id = int(stream["position"]) - yield self.notify_from_replication(states, stream_id) + def get_currently_syncing_users(self): + return [ + user_id for user_id, count in iteritems(self.user_to_num_current_syncs) + if count > 0 + ] class SynchrotronTyping(object): @@ -257,16 +230,12 @@ class SynchrotronTyping(object): # value which we *must* use for the next replication request. return {"typing": self._latest_room_serial} - def process_replication(self, result): - stream = result.get("typing") - if stream: - self._latest_room_serial = int(stream["position"]) + def process_replication_rows(self, token, rows): + self._latest_room_serial = token - for row in stream["rows"]: - position, room_id, typing_json = row - typing = json.loads(typing_json) - self._room_serials[room_id] = position - self._room_typing[room_id] = typing + for row in rows: + self._room_serials[row.room_id] = token + self._room_typing[row.room_id] = row.user_ids class SynchrotronApplicationService(object): @@ -275,19 +244,6 @@ class SynchrotronApplicationService(object): class SynchrotronServer(HomeServer): - def get_db_conn(self, run_new_connection=True): - # Any param beginning with cp_ is a parameter for adbapi, and should - # not be passed to the database engine. 
- db_params = { - k: v for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.database_engine.module.connect(**db_params) - - if run_new_connection: - self.database_engine.on_new_connection(db_conn) - return db_conn - def setup(self): logger.info("Setting up.") self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self) @@ -315,19 +271,18 @@ class SynchrotronServer(HomeServer): "/_matrix/client/api/v1": resource, }) - root_resource = create_resource_tree(resources, Resource()) - - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + root_resource = create_resource_tree(resources, NoResource()) + + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, ) + ) logger.info("Synapse synchrotron now listening on port %d", port) @@ -336,135 +291,106 @@ class SynchrotronServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) - @defer.inlineCallbacks - def replicate(self): - http_client = self.get_simple_http_client() - store = self.get_datastore() - replication_url = self.config.worker_replication_url - notifier = self.get_notifier() - presence_handler = self.get_presence_handler() - typing_handler = self.get_typing_handler() - - def notify_from_stream( - result, stream_name, stream_key, 
room=None, user=None - ): - stream = result.get(stream_name) - if stream: - position_index = stream["field_names"].index("position") - if room: - room_index = stream["field_names"].index(room) - if user: - user_index = stream["field_names"].index(user) - - users = () - rooms = () - for row in stream["rows"]: - position = row[position_index] - - if user: - users = (row[user_index],) - - if room: - rooms = (row[room_index],) - - notifier.on_new_event( - stream_key, position, users=users, rooms=rooms - ) + self.get_tcp_replication().start_replication(self) - @defer.inlineCallbacks - def notify_device_list_update(result): - stream = result.get("device_lists") - if not stream: - return + def build_tcp_replication(self): + return SyncReplicationHandler(self) - position_index = stream["field_names"].index("position") - user_index = stream["field_names"].index("user_id") + def build_presence_handler(self): + return SynchrotronPresence(self) - for row in stream["rows"]: - position = row[position_index] - user_id = row[user_index] + def build_typing_handler(self): + return SynchrotronTyping(self) - rooms = yield store.get_rooms_for_user(user_id) - room_ids = [r.room_id for r in rooms] - notifier.on_new_event( - "device_list_key", position, rooms=room_ids, - ) +class SyncReplicationHandler(ReplicationClientHandler): + def __init__(self, hs): + super(SyncReplicationHandler, self).__init__(hs.get_datastore()) - @defer.inlineCallbacks - def notify(result): - stream = result.get("events") - if stream: - max_position = stream["position"] - for row in stream["rows"]: - position = row[0] - internal = json.loads(row[1]) - event_json = json.loads(row[2]) - event = FrozenEvent(event_json, internal_metadata_dict=internal) - extra_users = () - if event.type == EventTypes.Member: - extra_users = (event.state_key,) - notifier.on_new_room_event( - event, position, max_position, extra_users - ) + self.store = hs.get_datastore() + self.typing_handler = hs.get_typing_handler() + # NB this is a 
SynchrotronPresence, not a normal PresenceHandler + self.presence_handler = hs.get_presence_handler() + self.notifier = hs.get_notifier() - notify_from_stream( - result, "push_rules", "push_rules_key", user="user_id" - ) - notify_from_stream( - result, "user_account_data", "account_data_key", user="user_id" - ) - notify_from_stream( - result, "room_account_data", "account_data_key", user="user_id" - ) - notify_from_stream( - result, "tag_account_data", "account_data_key", user="user_id" - ) - notify_from_stream( - result, "receipts", "receipt_key", room="room_id" - ) - notify_from_stream( - result, "typing", "typing_key", room="room_id" - ) - notify_from_stream( - result, "to_device", "to_device_key", user="user_id" - ) - yield notify_device_list_update(result) + def on_rdata(self, stream_name, token, rows): + super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) + run_in_background(self.process_and_notify, stream_name, token, rows) - while True: - try: - args = store.stream_positions() - args.update(typing_handler.stream_positions()) - args["timeout"] = 30000 - result = yield http_client.get_json(replication_url, args=args) - yield store.process_replication(result) - typing_handler.process_replication(result) - yield presence_handler.process_replication(result) - yield notify(result) - except: - logger.exception("Error replicating from %r", replication_url) - yield sleep(5) + def get_streams_to_replicate(self): + args = super(SyncReplicationHandler, self).get_streams_to_replicate() + args.update(self.typing_handler.stream_positions()) + return args - def build_presence_handler(self): - return SynchrotronPresence(self) + def get_currently_syncing_users(self): + return self.presence_handler.get_currently_syncing_users() - def build_typing_handler(self): - return SynchrotronTyping(self) + @defer.inlineCallbacks + def process_and_notify(self, stream_name, token, rows): + try: + if stream_name == "events": + # We shouldn't get multiple rows per token 
for events stream, so + # we don't need to optimise this for multiple rows. + for row in rows: + event = yield self.store.get_event(row.event_id) + extra_users = () + if event.type == EventTypes.Member: + extra_users = (event.state_key,) + max_token = self.store.get_room_max_stream_ordering() + self.notifier.on_new_room_event( + event, token, max_token, extra_users + ) + elif stream_name == "push_rules": + self.notifier.on_new_event( + "push_rules_key", token, users=[row.user_id for row in rows], + ) + elif stream_name in ("account_data", "tag_account_data",): + self.notifier.on_new_event( + "account_data_key", token, users=[row.user_id for row in rows], + ) + elif stream_name == "receipts": + self.notifier.on_new_event( + "receipt_key", token, rooms=[row.room_id for row in rows], + ) + elif stream_name == "typing": + self.typing_handler.process_replication_rows(token, rows) + self.notifier.on_new_event( + "typing_key", token, rooms=[row.room_id for row in rows], + ) + elif stream_name == "to_device": + entities = [row.entity for row in rows if row.entity.startswith("@")] + if entities: + self.notifier.on_new_event( + "to_device_key", token, users=entities, + ) + elif stream_name == "device_lists": + all_room_ids = set() + for row in rows: + room_ids = yield self.store.get_rooms_for_user(row.user_id) + all_room_ids.update(room_ids) + self.notifier.on_new_event( + "device_list_key", token, rooms=all_room_ids, + ) + elif stream_name == "presence": + yield self.presence_handler.process_replication_rows(token, rows) + elif stream_name == "receipts": + self.notifier.on_new_event( + "groups_key", token, users=[row.user_id for row in rows], + ) + except Exception: + logger.exception("Error processing replication") def start(config_options): @@ -478,7 +404,7 @@ def start(config_options): assert config.worker_app == "synapse.app.synchrotron" - setup_logging(config.worker_log_config, config.worker_log_file) + setup_logging(config, use_worker_options=True) 
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts @@ -496,33 +422,13 @@ def start(config_options): ss.setup() ss.start_listening(config.worker_listeners) - def run(): - with LoggingContext("run"): - logger.info("Running") - change_resource_limit(config.soft_file_limit) - if config.gc_thresholds: - gc.set_threshold(*config.gc_thresholds) - reactor.run() - def start(): ss.get_datastore().start_profiling() - ss.replicate() ss.get_state_handler().start_caching() reactor.callWhenRunning(start) - if config.worker_daemonize: - daemon = Daemonize( - app="synapse-synchrotron", - pid=config.worker_pid_file, - action=run, - auto_close_fds=False, - verbose=True, - logger=logger, - ) - daemon.start() - else: - run() + _base.start_worker_reactor("synapse-synchrotron", config) if __name__ == '__main__': diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py index c045588866..712dfa870e 100755 --- a/synapse/app/synctl.py +++ b/synapse/app/synctl.py @@ -23,14 +23,27 @@ import signal import subprocess import sys import yaml +import errno +import time SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"] GREEN = "\x1b[1;32m" +YELLOW = "\x1b[1;33m" RED = "\x1b[1;31m" NORMAL = "\x1b[m" +def pid_running(pid): + try: + os.kill(pid, 0) + return True + except OSError as err: + if err.errno == errno.EPERM: + return True + return False + + def write(message, colour=NORMAL, stream=sys.stdout): if colour == NORMAL: stream.write(message + "\n") @@ -38,6 +51,11 @@ def write(message, colour=NORMAL, stream=sys.stdout): stream.write(colour + message + NORMAL + "\n") +def abort(message, colour=RED, stream=sys.stderr): + write(message, colour, stream) + sys.exit(1) + + def start(configfile): write("Starting ...") args = SYNAPSE @@ -45,7 +63,8 @@ def start(configfile): try: subprocess.check_call(args) - write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN) + write("started synapse.app.homeserver(%r)" % + (configfile,), colour=GREEN) except 
subprocess.CalledProcessError as e: write( "error starting (exit code: %d); see above for logs" % e.returncode, @@ -76,8 +95,16 @@ def start_worker(app, configfile, worker_configfile): def stop(pidfile, app): if os.path.exists(pidfile): pid = int(open(pidfile).read()) - os.kill(pid, signal.SIGTERM) - write("stopped %s" % (app,), colour=GREEN) + try: + os.kill(pid, signal.SIGTERM) + write("stopped %s" % (app,), colour=GREEN) + except OSError as err: + if err.errno == errno.ESRCH: + write("%s not running" % (app,), colour=YELLOW) + elif err.errno == errno.EPERM: + abort("Cannot stop %s: Operation not permitted" % (app,)) + else: + abort("Cannot stop %s: Unknown error" % (app,)) Worker = collections.namedtuple("Worker", [ @@ -98,7 +125,7 @@ def main(): "configfile", nargs="?", default="homeserver.yaml", - help="the homeserver config file, defaults to homserver.yaml", + help="the homeserver config file, defaults to homeserver.yaml", ) parser.add_argument( "-w", "--worker", @@ -157,6 +184,9 @@ def main(): worker_configfiles.append(worker_configfile) if options.all_processes: + # To start the main synapse with -a you need to add a worker file + # with worker_app == "synapse.app.homeserver" + start_stop_synapse = False worker_configdir = options.all_processes if not os.path.isdir(worker_configdir): write( @@ -173,10 +203,29 @@ def main(): with open(worker_configfile) as stream: worker_config = yaml.load(stream) worker_app = worker_config["worker_app"] - worker_pidfile = worker_config["worker_pid_file"] - worker_daemonize = worker_config["worker_daemonize"] - assert worker_daemonize # TODO print something more user friendly - worker_cache_factor = worker_config.get("synctl_cache_factor") + if worker_app == "synapse.app.homeserver": + # We need to special case all of this to pick up options that may + # be set in the main config file or in this worker config file. 
+ worker_pidfile = ( + worker_config.get("pid_file") + or pidfile + ) + worker_cache_factor = worker_config.get("synctl_cache_factor") or cache_factor + daemonize = worker_config.get("daemonize") or config.get("daemonize") + assert daemonize, "Main process must have daemonize set to true" + + # The master process doesn't support using worker_* config. + for key in worker_config: + if key == "worker_app": # But we allow worker_app + continue + assert not key.startswith("worker_"), \ + "Main process cannot use worker_* config" + else: + worker_pidfile = worker_config["worker_pid_file"] + worker_daemonize = worker_config["worker_daemonize"] + assert worker_daemonize, "In config %r: expected '%s' to be True" % ( + worker_configfile, "worker_daemonize") + worker_cache_factor = worker_config.get("synctl_cache_factor") workers.append(Worker( worker_app, worker_configfile, worker_pidfile, worker_cache_factor, )) @@ -190,10 +239,26 @@ def main(): if start_stop_synapse: stop(pidfile, "synapse.app.homeserver") - # TODO: Wait for synapse to actually shutdown before starting it again + # Wait for synapse to actually shutdown before starting it again + if action == "restart": + running_pids = [] + if start_stop_synapse and os.path.exists(pidfile): + running_pids.append(int(open(pidfile).read())) + for worker in workers: + if os.path.exists(worker.pidfile): + running_pids.append(int(open(worker.pidfile).read())) + if len(running_pids) > 0: + write("Waiting for process to exit before restarting...") + for running_pid in running_pids: + while pid_running(running_pid): + time.sleep(0.2) + write("All processes exited; now restarting...") if action == "start" or action == "restart": if start_stop_synapse: + # Check if synapse is already running + if os.path.exists(pidfile) and pid_running(int(open(pidfile).read())): + abort("synapse.app.homeserver already running") start(configfile) for worker in workers: diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py new file mode 
100644 index 0000000000..5ba7e9b416 --- /dev/null +++ b/synapse/app/user_dir.py @@ -0,0 +1,231 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import sys + +import synapse +from synapse import events +from synapse.app import _base +from synapse.config._base import ConfigError +from synapse.config.homeserver import HomeServerConfig +from synapse.config.logger import setup_logging +from synapse.crypto import context_factory +from synapse.http.server import JsonResource +from synapse.http.site import SynapseSite +from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore +from synapse.replication.slave.storage.client_ips import SlavedClientIpStore +from synapse.replication.slave.storage.events import SlavedEventStore +from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.rest.client.v2_alpha import user_directory +from synapse.server import HomeServer +from synapse.storage.engines import create_engine +from synapse.storage.user_directory import UserDirectoryStore +from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.util.httpresourcetree import 
create_resource_tree +from synapse.util.logcontext import LoggingContext, run_in_background +from synapse.util.manhole import manhole +from synapse.util.versionstring import get_version_string +from twisted.internet import reactor, defer +from twisted.web.resource import NoResource + +logger = logging.getLogger("synapse.app.user_dir") + + +class UserDirectorySlaveStore( + SlavedEventStore, + SlavedApplicationServiceStore, + SlavedRegistrationStore, + SlavedClientIpStore, + UserDirectoryStore, + BaseSlavedStore, +): + def __init__(self, db_conn, hs): + super(UserDirectorySlaveStore, self).__init__(db_conn, hs) + + events_max = self._stream_id_gen.get_current_token() + curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( + db_conn, "current_state_delta_stream", + entity_column="room_id", + stream_column="stream_id", + max_value=events_max, # As we share the stream id with events token + limit=1000, + ) + self._curr_state_delta_stream_cache = StreamChangeCache( + "_curr_state_delta_stream_cache", min_curr_state_delta_id, + prefilled_cache=curr_state_delta_prefill, + ) + + self._current_state_delta_pos = events_max + + def stream_positions(self): + result = super(UserDirectorySlaveStore, self).stream_positions() + result["current_state_deltas"] = self._current_state_delta_pos + return result + + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "current_state_deltas": + self._current_state_delta_pos = token + for row in rows: + self._curr_state_delta_stream_cache.entity_has_changed( + row.room_id, token + ) + return super(UserDirectorySlaveStore, self).process_replication_rows( + stream_name, token, rows + ) + + +class UserDirectoryServer(HomeServer): + def setup(self): + logger.info("Setting up.") + self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self) + logger.info("Finished setting up.") + + def _listen_http(self, listener_config): + port = listener_config["port"] + bind_addresses = 
listener_config["bind_addresses"] + site_tag = listener_config.get("tag", port) + resources = {} + for res in listener_config["resources"]: + for name in res["names"]: + if name == "metrics": + resources[METRICS_PREFIX] = MetricsResource(self) + elif name == "client": + resource = JsonResource(self, canonical_json=False) + user_directory.register_servlets(self, resource) + resources.update({ + "/_matrix/client/r0": resource, + "/_matrix/client/unstable": resource, + "/_matrix/client/v2_alpha": resource, + "/_matrix/client/api/v1": resource, + }) + + root_resource = create_resource_tree(resources, NoResource()) + + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ) + ) + + logger.info("Synapse user_dir now listening on port %d", port) + + def start_listening(self, listeners): + for listener in listeners: + if listener["type"] == "http": + self._listen_http(listener) + elif listener["type"] == "manhole": + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ) + ) + else: + logger.warn("Unrecognized listener type: %s", listener["type"]) + + self.get_tcp_replication().start_replication(self) + + def build_tcp_replication(self): + return UserDirectoryReplicationHandler(self) + + +class UserDirectoryReplicationHandler(ReplicationClientHandler): + def __init__(self, hs): + super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore()) + self.user_directory = hs.get_user_directory_handler() + + def on_rdata(self, stream_name, token, rows): + super(UserDirectoryReplicationHandler, self).on_rdata( + stream_name, token, rows + ) + if stream_name == "current_state_deltas": + run_in_background(self._notify_directory) + + @defer.inlineCallbacks + def _notify_directory(self): + try: + yield self.user_directory.notify_new_event() + except Exception: + 
logger.exception("Error notifiying user directory of state update") + + +def start(config_options): + try: + config = HomeServerConfig.load_config( + "Synapse user directory", config_options + ) + except ConfigError as e: + sys.stderr.write("\n" + e.message + "\n") + sys.exit(1) + + assert config.worker_app == "synapse.app.user_dir" + + setup_logging(config, use_worker_options=True) + + events.USE_FROZEN_DICTS = config.use_frozen_dicts + + database_engine = create_engine(config.database_config) + + if config.update_user_directory: + sys.stderr.write( + "\nThe update_user_directory must be disabled in the main synapse process" + "\nbefore they can be run in a separate worker." + "\nPlease add ``update_user_directory: false`` to the main config" + "\n" + ) + sys.exit(1) + + # Force the pushers to start since they will be disabled in the main config + config.update_user_directory = True + + tls_server_context_factory = context_factory.ServerContextFactory(config) + + ps = UserDirectoryServer( + config.server_name, + db_config=config.database_config, + tls_server_context_factory=tls_server_context_factory, + config=config, + version_string="Synapse/" + get_version_string(synapse), + database_engine=database_engine, + ) + + ps.setup() + ps.start_listening(config.worker_listeners) + + def start(): + ps.get_datastore().start_profiling() + ps.get_state_handler().start_caching() + + reactor.callWhenRunning(start) + + _base.start_worker_reactor("synapse-user-dir", config) + + +if __name__ == '__main__': + with LoggingContext("main"): + start(sys.argv[1:]) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index b0106a3597..5fdb579723 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -13,12 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from synapse.api.constants import EventTypes +from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.types import GroupID, get_domain_from_id from twisted.internet import defer import logging import re +from six import string_types + logger = logging.getLogger(__name__) @@ -80,12 +84,13 @@ class ApplicationService(object): # values. NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS] - def __init__(self, token, url=None, namespaces=None, hs_token=None, + def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None, sender=None, id=None, protocols=None, rate_limited=True): self.token = token self.url = url self.hs_token = hs_token self.sender = sender + self.server_name = hostname self.namespaces = self._check_namespaces(namespaces) self.id = id @@ -124,29 +129,41 @@ class ApplicationService(object): raise ValueError( "Expected bool for 'exclusive' in ns '%s'" % ns ) - if not isinstance(regex_obj.get("regex"), basestring): + group_id = regex_obj.get("group_id") + if group_id: + if not isinstance(group_id, str): + raise ValueError( + "Expected string for 'group_id' in ns '%s'" % ns + ) + try: + GroupID.from_string(group_id) + except Exception: + raise ValueError( + "Expected valid group ID for 'group_id' in ns '%s'" % ns + ) + + if get_domain_from_id(group_id) != self.server_name: + raise ValueError( + "Expected 'group_id' to be this host in ns '%s'" % ns + ) + + regex = regex_obj.get("regex") + if isinstance(regex, string_types): + regex_obj["regex"] = re.compile(regex) # Pre-compile regex + else: raise ValueError( "Expected string for 'regex' in ns '%s'" % ns ) return namespaces - def _matches_regex(self, test_string, namespace_key, return_obj=False): - if not isinstance(test_string, basestring): - logger.error( - "Expected a string to test regex against, but got %s", - test_string - ) - return False - + def _matches_regex(self, test_string, namespace_key): for regex_obj in self.namespaces[namespace_key]: - if 
re.match(regex_obj["regex"], test_string): - if return_obj: - return regex_obj - return True - return False + if regex_obj["regex"].match(test_string): + return regex_obj + return None def _is_exclusive(self, ns_key, test_string): - regex_obj = self._matches_regex(test_string, ns_key, return_obj=True) + regex_obj = self._matches_regex(test_string, ns_key) if regex_obj: return regex_obj["exclusive"] return False @@ -166,7 +183,14 @@ class ApplicationService(object): if not store: defer.returnValue(False) - member_list = yield store.get_users_in_room(event.room_id) + does_match = yield self._matches_user_in_member_list(event.room_id, store) + defer.returnValue(does_match) + + @cachedInlineCallbacks(num_args=1, cache_context=True) + def _matches_user_in_member_list(self, room_id, store, cache_context): + member_list = yield store.get_users_in_room( + room_id, on_invalidate=cache_context.invalidate + ) # check joined member events for user_id in member_list: @@ -219,10 +243,10 @@ class ApplicationService(object): ) def is_interested_in_alias(self, alias): - return self._matches_regex(alias, ApplicationService.NS_ALIASES) + return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES)) def is_interested_in_room(self, room_id): - return self._matches_regex(room_id, ApplicationService.NS_ROOMS) + return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS)) def is_exclusive_user(self, user_id): return ( @@ -239,6 +263,31 @@ class ApplicationService(object): def is_exclusive_room(self, room_id): return self._is_exclusive(ApplicationService.NS_ROOMS, room_id) + def get_exlusive_user_regexes(self): + """Get the list of regexes used to determine if a user is exclusively + registered by the AS + """ + return [ + regex_obj["regex"] + for regex_obj in self.namespaces[ApplicationService.NS_USERS] + if regex_obj["exclusive"] + ] + + def get_groups_for_user(self, user_id): + """Get the groups that this user is associated with by this AS + + Args: + user_id (str): 
The ID of the user. + + Returns: + iterable[str]: an iterable that yields group_id strings. + """ + return ( + regex_obj["group_id"] + for regex_obj in self.namespaces[ApplicationService.NS_USERS] + if "group_id" in regex_obj and regex_obj["regex"].match(user_id) + ) + def is_rate_limited(self): return self.rate_limited diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 6893610e71..00efff1464 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -72,7 +72,8 @@ class ApplicationServiceApi(SimpleHttpClient): super(ApplicationServiceApi, self).__init__(hs) self.clock = hs.get_clock() - self.protocol_meta_cache = ResponseCache(hs, timeout_ms=HOUR_IN_MS) + self.protocol_meta_cache = ResponseCache(hs, "as_protocol_meta", + timeout_ms=HOUR_IN_MS) @defer.inlineCallbacks def query_user(self, service, user_id): @@ -192,9 +193,7 @@ class ApplicationServiceApi(SimpleHttpClient): defer.returnValue(None) key = (service.id, protocol) - return self.protocol_meta_cache.get(key) or ( - self.protocol_meta_cache.set(key, _get()) - ) + return self.protocol_meta_cache.wrap(key, _get) @defer.inlineCallbacks def push_bulk(self, service, events, txn_id=None): diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 68a9de17b8..6eddbc0828 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -51,7 +51,7 @@ components. 
from twisted.internet import defer from synapse.appservice import ApplicationServiceState -from synapse.util.logcontext import preserve_fn +from synapse.util.logcontext import run_in_background from synapse.util.metrics import Measure import logging @@ -106,7 +106,7 @@ class _ServiceQueuer(object): def enqueue(self, service, event): # if this service isn't being sent something self.queued_events.setdefault(service.id, []).append(event) - preserve_fn(self._send_request)(service) + run_in_background(self._send_request, service) @defer.inlineCallbacks def _send_request(self, service): @@ -123,7 +123,7 @@ class _ServiceQueuer(object): with Measure(self.clock, "servicequeuer.send"): try: yield self.txn_ctrl.send(service, events) - except: + except Exception: logger.exception("AS request failed") finally: self.requests_in_flight.discard(service.id) @@ -152,10 +152,10 @@ class _TransactionController(object): if sent: yield txn.complete(self.store) else: - preserve_fn(self._start_recoverer)(service) - except Exception as e: - logger.exception(e) - preserve_fn(self._start_recoverer)(service) + run_in_background(self._start_recoverer, service) + except Exception: + logger.exception("Error creating appservice transaction") + run_in_background(self._start_recoverer, service) @defer.inlineCallbacks def on_recovered(self, recoverer): @@ -176,17 +176,20 @@ class _TransactionController(object): @defer.inlineCallbacks def _start_recoverer(self, service): - yield self.store.set_appservice_state( - service, - ApplicationServiceState.DOWN - ) - logger.info( - "Application service falling behind. Starting recoverer. AS ID %s", - service.id - ) - recoverer = self.recoverer_fn(service, self.on_recovered) - self.add_recoverers([recoverer]) - recoverer.recover() + try: + yield self.store.set_appservice_state( + service, + ApplicationServiceState.DOWN + ) + logger.info( + "Application service falling behind. Starting recoverer. 
AS ID %s", + service.id + ) + recoverer = self.recoverer_fn(service, self.on_recovered) + self.add_recoverers([recoverer]) + recoverer.recover() + except Exception: + logger.exception("Error starting AS recoverer") @defer.inlineCallbacks def _is_service_up(self, service): diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 1ab5593c6e..b748ed2b0a 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -19,6 +19,8 @@ import os import yaml from textwrap import dedent +from six import integer_types + class ConfigError(Exception): pass @@ -49,7 +51,7 @@ Missing mandatory `server_name` config option. class Config(object): @staticmethod def parse_size(value): - if isinstance(value, int) or isinstance(value, long): + if isinstance(value, integer_types): return value sizes = {"K": 1024, "M": 1024 * 1024} size = 1 @@ -61,7 +63,7 @@ class Config(object): @staticmethod def parse_duration(value): - if isinstance(value, int) or isinstance(value, long): + if isinstance(value, integer_types): return value second = 1000 minute = 60 * second @@ -82,21 +84,37 @@ class Config(object): return os.path.abspath(file_path) if file_path else file_path @classmethod + def path_exists(cls, file_path): + """Check if a file exists + + Unlike os.path.exists, this throws an exception if there is an error + checking if the file exists (for example, if there is a perms error on + the parent dir). + + Returns: + bool: True if the file exists; False if not. + """ + try: + os.stat(file_path) + return True + except OSError as e: + if e.errno != errno.ENOENT: + raise e + return False + + @classmethod def check_file(cls, file_path, config_name): if file_path is None: raise ConfigError( "Missing config for %s." - " You must specify a path for the config file. You can " - "do this with the -c or --config-path option. " - "Adding --generate-config along with --server-name " - "<server name> will generate a config file at the given path." 
% (config_name,) ) - if not os.path.exists(file_path): + try: + os.stat(file_path) + except OSError as e: raise ConfigError( - "File %s config for %s doesn't exist." - " Try running again with --generate-config" - % (file_path, config_name,) + "Error accessing file '%s' (config for %s): %s" + % (file_path, config_name, e.strerror) ) return cls.abspath(file_path) @@ -248,7 +266,7 @@ class Config(object): " -c CONFIG-FILE\"" ) (config_path,) = config_files - if not os.path.exists(config_path): + if not cls.path_exists(config_path): if config_args.keys_directory: config_dir_path = config_args.keys_directory else: @@ -261,33 +279,33 @@ class Config(object): "Must specify a server_name to a generate config for." " Pass -H server.name." ) - if not os.path.exists(config_dir_path): + if not cls.path_exists(config_dir_path): os.makedirs(config_dir_path) - with open(config_path, "wb") as config_file: - config_bytes, config = obj.generate_config( + with open(config_path, "w") as config_file: + config_str, config = obj.generate_config( config_dir_path=config_dir_path, server_name=server_name, report_stats=(config_args.report_stats == "yes"), is_generating_file=True ) obj.invoke_all("generate_files", config) - config_file.write(config_bytes) - print ( + config_file.write(config_str) + print(( "A config file has been generated in %r for server name" " %r with corresponding SSL keys and self-signed" " certificates. Please review this file and customise it" " to your needs." - ) % (config_path, server_name) - print ( + ) % (config_path, server_name)) + print( "If this server name is incorrect, you will need to" " regenerate the SSL certificates" ) return else: - print ( + print(( "Config file %r already exists. Generating any missing key" " files." 
- ) % (config_path,) + ) % (config_path,)) generate_keys = True parser = argparse.ArgumentParser( diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 82c50b8240..277305e184 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -17,10 +17,12 @@ from ._base import Config, ConfigError from synapse.appservice import ApplicationService from synapse.types import UserID -import urllib import yaml import logging +from six import string_types +from six.moves.urllib import parse as urlparse + logger = logging.getLogger(__name__) @@ -89,21 +91,21 @@ def _load_appservice(hostname, as_info, config_filename): "id", "as_token", "hs_token", "sender_localpart" ] for field in required_string_fields: - if not isinstance(as_info.get(field), basestring): + if not isinstance(as_info.get(field), string_types): raise KeyError("Required string field: '%s' (%s)" % ( field, config_filename, )) # 'url' must either be a string or explicitly null, not missing # to avoid accidentally turning off push for ASes. - if (not isinstance(as_info.get("url"), basestring) and + if (not isinstance(as_info.get("url"), string_types) and as_info.get("url", "") is not None): raise KeyError( "Required string field or explicit null: 'url' (%s)" % (config_filename,) ) localpart = as_info["sender_localpart"] - if urllib.quote(localpart) != localpart: + if urlparse.quote(localpart) != localpart: raise ValueError( "sender_localpart needs characters which are not URL encoded." 
) @@ -128,7 +130,7 @@ def _load_appservice(hostname, as_info, config_filename): "Expected namespace entry in %s to be an object," " but got %s", ns, regex_obj ) - if not isinstance(regex_obj.get("regex"), basestring): + if not isinstance(regex_obj.get("regex"), string_types): raise ValueError( "Missing/bad type 'regex' key in %s", regex_obj ) @@ -154,6 +156,7 @@ def _load_appservice(hostname, as_info, config_filename): ) return ApplicationService( token=as_info["as_token"], + hostname=hostname, url=as_info["url"], namespaces=as_info["namespaces"], hs_token=as_info["hs_token"], diff --git a/synapse/config/cas.py b/synapse/config/cas.py index 938f6f25f8..8109e5f95e 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -41,7 +41,7 @@ class CasConfig(Config): #cas_config: # enabled: true # server_url: "https://cas-server.com" - # service_url: "https://homesever.domain.com:8448" + # service_url: "https://homeserver.domain.com:8448" # #required_attributes: # # name: value """ diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 0030b5db1e..fe156b6930 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -71,6 +71,15 @@ class EmailConfig(Config): self.email_riot_base_url = email_config.get( "riot_base_url", None ) + self.email_smtp_user = email_config.get( + "smtp_user", None + ) + self.email_smtp_pass = email_config.get( + "smtp_pass", None + ) + self.require_transport_security = email_config.get( + "require_transport_security", False + ) if "app_name" in email_config: self.email_app_name = email_config["app_name"] else: @@ -91,10 +100,17 @@ class EmailConfig(Config): # Defining a custom URL for Riot is only needed if email notifications # should contain links to a self-hosted installation of Riot; when set # the "app_name" setting is ignored. 
+ # + # If your SMTP server requires authentication, the optional smtp_user & + # smtp_pass variables should be used + # #email: # enable_notifs: false # smtp_host: "localhost" # smtp_port: 25 + # smtp_user: "exampleusername" + # smtp_pass: "examplepassword" + # require_transport_security: False # notif_from: "Your Friendly %(app)s Home Server <noreply@example.com>" # app_name: Matrix # template_dir: res/templates diff --git a/synapse/config/groups.py b/synapse/config/groups.py new file mode 100644 index 0000000000..997fa2881f --- /dev/null +++ b/synapse/config/groups.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ._base import Config + + +class GroupsConfig(Config): + def read_config(self, config): + self.enable_group_creation = config.get("enable_group_creation", False) + self.group_creation_prefix = config.get("group_creation_prefix", "") + + def default_config(self, **kwargs): + return """\ + # Whether to allow non server admins to create groups on this server + enable_group_creation: false + + # If enabled, non server admins can only create groups with local parts + # starting with this prefix + # group_creation_prefix: "unofficial/" + """ diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 0f890fc04a..bf19cfee29 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -33,6 +33,10 @@ from .jwt import JWTConfig from .password_auth_providers import PasswordAuthProviderConfig from .emailconfig import EmailConfig from .workers import WorkerConfig +from .push import PushConfig +from .spam_checker import SpamCheckerConfig +from .groups import GroupsConfig +from .user_directory import UserDirectoryConfig class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, @@ -40,7 +44,8 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig, JWTConfig, PasswordConfig, EmailConfig, - WorkerConfig, PasswordAuthProviderConfig,): + WorkerConfig, PasswordAuthProviderConfig, PushConfig, + SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,): pass diff --git a/synapse/config/key.py b/synapse/config/key.py index 6ee643793e..4b8fc063d0 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -118,10 +118,9 @@ class KeyConfig(Config): signing_keys = self.read_file(signing_key_path, "signing_key") try: return read_signing_keys(signing_keys.splitlines(True)) - except Exception: + except Exception as e: raise ConfigError( - "Error reading signing_key." 
- " Try running again with --generate-config" + "Error reading signing_key: %s" % (str(e)) ) def read_old_signing_keys(self, old_signing_keys): @@ -141,7 +140,8 @@ class KeyConfig(Config): def generate_files(self, config): signing_key_path = config["signing_key_path"] - if not os.path.exists(signing_key_path): + + if not self.path_exists(signing_key_path): with open(signing_key_path, "w") as signing_key_file: key_id = "a_" + random_string(4) write_signing_keys( diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 77ded0ad25..6a7228dc2f 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -28,34 +28,35 @@ DEFAULT_LOG_CONFIG = Template(""" version: 1 formatters: - precise: - format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s\ -- %(message)s' + precise: + format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - \ +%(request)s - %(message)s' filters: - context: - (): synapse.util.logcontext.LoggingContextFilter - request: "" + context: + (): synapse.util.logcontext.LoggingContextFilter + request: "" handlers: - file: - class: logging.handlers.RotatingFileHandler - formatter: precise - filename: ${log_file} - maxBytes: 104857600 - backupCount: 10 - filters: [context] - level: INFO - console: - class: logging.StreamHandler - formatter: precise - filters: [context] + file: + class: logging.handlers.RotatingFileHandler + formatter: precise + filename: ${log_file} + maxBytes: 104857600 + backupCount: 10 + filters: [context] + console: + class: logging.StreamHandler + formatter: precise + filters: [context] loggers: synapse: level: INFO synapse.storage.SQL: + # beware: increasing this to DEBUG will make synapse log sensitive + # information such as access tokens. 
level: INFO root: @@ -68,21 +69,15 @@ class LoggingConfig(Config): def read_config(self, config): self.verbosity = config.get("verbose", 0) + self.no_redirect_stdio = config.get("no_redirect_stdio", False) self.log_config = self.abspath(config.get("log_config")) self.log_file = self.abspath(config.get("log_file")) def default_config(self, config_dir_path, server_name, **kwargs): - log_file = self.abspath("homeserver.log") log_config = self.abspath( os.path.join(config_dir_path, server_name + ".log.config") ) return """ - # Logging verbosity level. - verbose: 0 - - # File to write logging to - log_file: "%(log_file)s" - # A yaml python logging config file log_config: "%(log_config)s" """ % locals() @@ -90,6 +85,8 @@ class LoggingConfig(Config): def read_arguments(self, args): if args.verbose is not None: self.verbosity = args.verbose + if args.no_redirect_stdio is not None: + self.no_redirect_stdio = args.no_redirect_stdio if args.log_config is not None: self.log_config = args.log_config if args.log_file is not None: @@ -99,48 +96,68 @@ class LoggingConfig(Config): logging_group = parser.add_argument_group("logging") logging_group.add_argument( '-v', '--verbose', dest="verbose", action='count', - help="The verbosity level." + help="The verbosity level. Specify multiple times to increase " + "verbosity. (Ignored if --log-config is specified.)" ) logging_group.add_argument( '-f', '--log-file', dest="log_file", - help="File to log to." + help="File to log to. 
(Ignored if --log-config is specified.)" ) logging_group.add_argument( '--log-config', dest="log_config", default=None, help="Python logging config file" ) + logging_group.add_argument( + '-n', '--no-redirect-stdio', + action='store_true', default=None, + help="Do not redirect stdout/stderr to the log" + ) def generate_files(self, config): log_config = config.get("log_config") if log_config and not os.path.exists(log_config): - with open(log_config, "wb") as log_config_file: + log_file = self.abspath("homeserver.log") + with open(log_config, "w") as log_config_file: log_config_file.write( - DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"]) + DEFAULT_LOG_CONFIG.substitute(log_file=log_file) ) - def setup_logging(self): - setup_logging(self.log_config, self.log_file, self.verbosity) +def setup_logging(config, use_worker_options=False): + """ Set up python logging + + Args: + config (LoggingConfig | synapse.config.workers.WorkerConfig): + configuration data + + use_worker_options (bool): True to use 'worker_log_config' and + 'worker_log_file' options instead of 'log_config' and 'log_file'. + """ + log_config = (config.worker_log_config if use_worker_options + else config.log_config) + log_file = (config.worker_log_file if use_worker_options + else config.log_file) -def setup_logging(log_config=None, log_file=None, verbosity=None): log_format = ( "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s" " - %(message)s" ) - if log_config is None: + if log_config is None: + # We don't have a logfile, so fall back to the 'verbosity' param from + # the config or cmdline. 
(Note that we generate a log config for new + # installs, so this will be an unusual case) level = logging.INFO level_for_storage = logging.INFO - if verbosity: + if config.verbosity: level = logging.DEBUG - if verbosity > 1: + if config.verbosity > 1: level_for_storage = logging.DEBUG - # FIXME: we need a logging.WARN for a -q quiet option logger = logging.getLogger('') logger.setLevel(level) - logging.getLogger('synapse.storage').setLevel(level_for_storage) + logging.getLogger('synapse.storage.SQL').setLevel(level_for_storage) formatter = logging.Formatter(log_format) if log_file: @@ -153,24 +170,37 @@ def setup_logging(log_config=None, log_file=None, verbosity=None): logger.info("Closing log file due to SIGHUP") handler.doRollover() logger.info("Opened new log file due to SIGHUP") - - # TODO(paul): obviously this is a terrible mechanism for - # stealing SIGHUP, because it means no other part of synapse - # can use it instead. If we want to catch SIGHUP anywhere - # else as well, I'd suggest we find a nicer way to broadcast - # it around. - if getattr(signal, "SIGHUP"): - signal.signal(signal.SIGHUP, sighup) else: handler = logging.StreamHandler() + + def sighup(signum, stack): + pass + handler.setFormatter(formatter) handler.addFilter(LoggingContextFilter(request="")) logger.addHandler(handler) else: - with open(log_config, 'r') as f: - logging.config.dictConfig(yaml.load(f)) + def load_log_config(): + with open(log_config, 'r') as f: + logging.config.dictConfig(yaml.load(f)) + + def sighup(signum, stack): + # it might be better to use a file watcher or something for this. + logging.info("Reloading log config from %s due to SIGHUP", + log_config) + load_log_config() + + load_log_config() + + # TODO(paul): obviously this is a terrible mechanism for + # stealing SIGHUP, because it means no other part of synapse + # can use it instead. If we want to catch SIGHUP anywhere + # else as well, I'd suggest we find a nicer way to broadcast + # it around. 
+ if getattr(signal, "SIGHUP"): + signal.signal(signal.SIGHUP, sighup) # It's critical to point twisted's internal logging somewhere, otherwise it # stacks up and leaks kup to 64K object; @@ -183,4 +213,7 @@ def setup_logging(log_config=None, log_file=None, verbosity=None): # # However this may not be too much of a problem if we are just writing to a file. observer = STDLibLogObserver() - globalLogBeginner.beginLoggingTo([observer]) + globalLogBeginner.beginLoggingTo( + [observer], + redirectStandardIO=not config.no_redirect_stdio, + ) diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index 83762d089a..6602c5b4c7 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -13,44 +13,41 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config, ConfigError +from ._base import Config -import importlib +from synapse.util.module_loader import load_module + +LDAP_PROVIDER = 'ldap_auth_provider.LdapAuthProvider' class PasswordAuthProviderConfig(Config): def read_config(self, config): self.password_providers = [] + providers = [] # We want to be backwards compatible with the old `ldap_config` # param. ldap_config = config.get("ldap_config", {}) - self.ldap_enabled = ldap_config.get("enabled", False) - if self.ldap_enabled: - from ldap_auth_provider import LdapAuthProvider - parsed_config = LdapAuthProvider.parse_config(ldap_config) - self.password_providers.append((LdapAuthProvider, parsed_config)) + if ldap_config.get("enabled", False): + providers.append({ + 'module': LDAP_PROVIDER, + 'config': ldap_config, + }) - providers = config.get("password_providers", []) + providers.extend(config.get("password_providers", [])) for provider in providers: + mod_name = provider['module'] + # This is for backwards compat when the ldap auth provider resided # in this package. 
- if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider": - from ldap_auth_provider import LdapAuthProvider - provider_class = LdapAuthProvider - else: - # We need to import the module, and then pick the class out of - # that, so we split based on the last dot. - module, clz = provider['module'].rsplit(".", 1) - module = importlib.import_module(module) - provider_class = getattr(module, clz) + if mod_name == "synapse.util.ldap_auth_provider.LdapAuthProvider": + mod_name = LDAP_PROVIDER + + (provider_class, provider_config) = load_module({ + "module": mod_name, + "config": provider['config'], + }) - try: - provider_config = provider_class.parse_config(provider["config"]) - except Exception as e: - raise ConfigError( - "Failed to parse config for %r: %r" % (provider['module'], e) - ) self.password_providers.append((provider_class, provider_config)) def default_config(self, **kwargs): diff --git a/synapse/config/push.py b/synapse/config/push.py new file mode 100644 index 0000000000..b7e0d46afa --- /dev/null +++ b/synapse/config/push.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ._base import Config + + +class PushConfig(Config): + def read_config(self, config): + push_config = config.get("push", {}) + self.push_include_content = push_config.get("include_content", True) + + # There was a a 'redact_content' setting but mistakenly read from the + # 'email'section'. Check for the flag in the 'push' section, and log, + # but do not honour it to avoid nasty surprises when people upgrade. + if push_config.get("redact_content") is not None: + print( + "The push.redact_content content option has never worked. " + "Please set push.include_content if you want this behaviour" + ) + + # Now check for the one in the 'email' section and honour it, + # with a warning. + push_config = config.get("email", {}) + redact_content = push_config.get("redact_content") + if redact_content is not None: + print( + "The 'email.redact_content' option is deprecated: " + "please set push.include_content instead" + ) + self.push_include_content = not redact_content + + def default_config(self, config_dir_path, server_name, **kwargs): + return """ + # Clients requesting push notifications can either have the body of + # the message sent in the notification poke along with other details + # like the sender, or just the event ID and room ID (`event_id_only`). + # If clients choose the former, this option controls whether the + # notification request includes the content of the event (other details + # like the sender are still included). For `event_id_only` push, it + # has no effect. + + # For modern android devices the notification content will still appear + # because it is loaded by the app. iPhone, however will send a + # notification saying only that a message arrived and who it came from. 
+ # + #push: + # include_content: true + """ diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 87e500c97a..c5384b3ad4 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -31,6 +31,8 @@ class RegistrationConfig(Config): strtobool(str(config["disable_registration"])) ) + self.registrations_require_3pid = config.get("registrations_require_3pid", []) + self.allowed_local_3pids = config.get("allowed_local_3pids", []) self.registration_shared_secret = config.get("registration_shared_secret") self.bcrypt_rounds = config.get("bcrypt_rounds", 12) @@ -41,6 +43,8 @@ class RegistrationConfig(Config): self.allow_guest_access and config.get("invite_3pid_guest", False) ) + self.auto_join_rooms = config.get("auto_join_rooms", []) + def default_config(self, **kwargs): registration_shared_secret = random_string_with_symbols(50) @@ -50,13 +54,32 @@ class RegistrationConfig(Config): # Enable registration for new users. enable_registration: False + # The user must provide all of the below types of 3PID when registering. + # + # registrations_require_3pid: + # - email + # - msisdn + + # Mandate that users are only allowed to associate certain formats of + # 3PIDs with accounts on this server. + # + # allowed_local_3pids: + # - medium: email + # pattern: ".*@matrix\\.org" + # - medium: email + # pattern: ".*@vector\\.im" + # - medium: msisdn + # pattern: "\\+44" + # If set, allows registration by anyone who also has the shared # secret, even if registration is otherwise disabled. registration_shared_secret: "%(registration_shared_secret)s" # Set the number of bcrypt rounds used to generate password hash. # Larger numbers increase the work factor needed to generate the hash. - # The default number of rounds is 12. + # The default number is 12 (which equates to 2^12 rounds). + # N.B. that increasing this will exponentially increase the time required + # to register or login - e.g. 24 => 2^24 rounds which will take >20 mins. 
bcrypt_rounds: 12 # Allows users to register as guests without a password/email/etc, and @@ -69,6 +92,12 @@ class RegistrationConfig(Config): trusted_third_party_id_servers: - matrix.org - vector.im + - riot.im + + # Users who register on this homeserver will automatically be joined + # to these rooms + #auto_join_rooms: + # - "#example:example.com" """ % locals() def add_arguments(self, parser): diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 2c6f57168e..25ea77738a 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -16,6 +16,8 @@ from ._base import Config, ConfigError from collections import namedtuple +from synapse.util.module_loader import load_module + MISSING_NETADDR = ( "Missing netaddr library. This is required for URL preview API." @@ -36,6 +38,14 @@ ThumbnailRequirement = namedtuple( "ThumbnailRequirement", ["width", "height", "method", "media_type"] ) +MediaStorageProviderConfig = namedtuple( + "MediaStorageProviderConfig", ( + "store_local", # Whether to store newly uploaded local files + "store_remote", # Whether to store newly downloaded remote files + "store_synchronous", # Whether to wait for successful storage for local uploads + ), +) + def parse_thumbnail_requirements(thumbnail_sizes): """ Takes a list of dictionaries with "width", "height", and "method" keys @@ -70,7 +80,64 @@ class ContentRepositoryConfig(Config): self.max_upload_size = self.parse_size(config["max_upload_size"]) self.max_image_pixels = self.parse_size(config["max_image_pixels"]) self.max_spider_size = self.parse_size(config["max_spider_size"]) + self.media_store_path = self.ensure_directory(config["media_store_path"]) + + backup_media_store_path = config.get("backup_media_store_path") + + synchronous_backup_media_store = config.get( + "synchronous_backup_media_store", False + ) + + storage_providers = config.get("media_storage_providers", []) + + if backup_media_store_path: + if storage_providers: + raise ConfigError( + 
"Cannot use both 'backup_media_store_path' and 'storage_providers'" + ) + + storage_providers = [{ + "module": "file_system", + "store_local": True, + "store_synchronous": synchronous_backup_media_store, + "store_remote": True, + "config": { + "directory": backup_media_store_path, + } + }] + + # This is a list of config that can be used to create the storage + # providers. The entries are tuples of (Class, class_config, + # MediaStorageProviderConfig), where Class is the class of the provider, + # the class_config the config to pass to it, and + # MediaStorageProviderConfig are options for StorageProviderWrapper. + # + # We don't create the storage providers here as not all workers need + # them to be started. + self.media_storage_providers = [] + + for provider_config in storage_providers: + # We special case the module "file_system" so as not to need to + # expose FileStorageProviderBackend + if provider_config["module"] == "file_system": + provider_config["module"] = ( + "synapse.rest.media.v1.storage_provider" + ".FileStorageProviderBackend" + ) + + provider_class, parsed_config = load_module(provider_config) + + wrapper_config = MediaStorageProviderConfig( + provider_config.get("store_local", False), + provider_config.get("store_remote", False), + provider_config.get("store_synchronous", False), + ) + + self.media_storage_providers.append( + (provider_class, parsed_config, wrapper_config,) + ) + self.uploads_path = self.ensure_directory(config["uploads_path"]) self.dynamic_thumbnails = config["dynamic_thumbnails"] self.thumbnail_requirements = parse_thumbnail_requirements( @@ -115,6 +182,20 @@ class ContentRepositoryConfig(Config): # Directory where uploaded images and attachments are stored. media_store_path: "%(media_store)s" + # Media storage providers allow media to be stored in different + # locations. + # media_storage_providers: + # - module: file_system + # # Whether to write new local files. 
+ # store_local: false + # # Whether to write new remote media + # store_remote: false + # # Whether to block upload requests waiting for write to this + # # provider to complete + # store_synchronous: false + # config: + # directory: /mnt/some/other/directory + # Directory where in-progress uploads are stored. uploads_path: "%(uploads_path)s" diff --git a/synapse/config/server.py b/synapse/config/server.py index 1f9999d57a..8f0b6d1f28 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,12 +30,42 @@ class ServerConfig(Config): self.user_agent_suffix = config.get("user_agent_suffix") self.use_frozen_dicts = config.get("use_frozen_dicts", False) self.public_baseurl = config.get("public_baseurl") + self.cpu_affinity = config.get("cpu_affinity") # Whether to send federation traffic out in this process. This only # applies to some federation traffic, and so shouldn't be used to # "disable" federation self.send_federation = config.get("send_federation", True) + # Whether to update the user directory or not. This should be set to + # false only if we are updating the user directory in a worker + self.update_user_directory = config.get("update_user_directory", True) + + # whether to enable the media repository endpoints. This should be set + # to false if the media repository is running as a separate endpoint; + # doing so ensures that we will not run cache cleanup jobs on the + # master, potentially causing inconsistency. 
+ self.enable_media_repo = config.get("enable_media_repo", True) + + self.filter_timeline_limit = config.get("filter_timeline_limit", -1) + + # Whether we should block invites sent to users on this server + # (other than those sent by local server admins) + self.block_non_admin_invites = config.get( + "block_non_admin_invites", False, + ) + + # FIXME: federation_domain_whitelist needs sytests + self.federation_domain_whitelist = None + federation_domain_whitelist = config.get( + "federation_domain_whitelist", None + ) + # turn the whitelist into a hash for speed of lookup + if federation_domain_whitelist is not None: + self.federation_domain_whitelist = {} + for domain in federation_domain_whitelist: + self.federation_domain_whitelist[domain] = True + if self.public_baseurl is not None: if self.public_baseurl[-1] != '/': self.public_baseurl += '/' @@ -141,9 +172,36 @@ class ServerConfig(Config): # When running as a daemon, the file to store the pid in pid_file: %(pid_file)s + # CPU affinity mask. Setting this restricts the CPUs on which the + # process will be scheduled. It is represented as a bitmask, with the + # lowest order bit corresponding to the first logical CPU and the + # highest order bit corresponding to the last logical CPU. Not all CPUs + # may exist on a given system but a mask may specify more CPUs than are + # present. + # + # For example: + # 0x00000001 is processor #0, + # 0x00000003 is processors #0 and #1, + # 0xFFFFFFFF is all processors (#0 through #31). + # + # Pinning a Python process to a single CPU is desirable, because Python + # is inherently single-threaded due to the GIL, and can suffer a + # 30-40%% slowdown due to cache blow-out and thread context switching + # if the scheduler happens to schedule the underlying threads across + # different cores. See + # https://www.mirantis.com/blog/improve-performance-python-programs-restricting-single-cpu/. 
+ # + # cpu_affinity: 0xFFFFFFFF + # Whether to serve a web client from the HTTP/HTTPS root resource. web_client: True + # The root directory to server for the above web client. + # If left undefined, synapse will serve the matrix-angular-sdk web client. + # Make sure matrix-angular-sdk is installed with pip if web_client is True + # and web_client_location is undefined + # web_client_location: "/path/to/web/root" + # The public-facing base URL for the client API (not including _matrix/...) # public_baseurl: https://example.com:8448/ @@ -155,6 +213,25 @@ class ServerConfig(Config): # The GC threshold parameters to pass to `gc.set_threshold`, if defined # gc_thresholds: [700, 10, 10] + # Set the limit on the returned events in the timeline in the get + # and sync operations. The default value is -1, means no upper limit. + # filter_timeline_limit: 5000 + + # Whether room invites to users on this server should be blocked + # (except those sent by local server admins). The default is False. + # block_non_admin_invites: True + + # Restrict federation to the following whitelist of domains. + # N.B. we recommend also firewalling your federation listener to limit + # inbound federation traffic as early as possible, rather than relying + # purely on this application-layer restriction. If not specified, the + # default is to whitelist everything. + # + # federation_domain_whitelist: + # - lon.example.com + # - nyc.example.com + # - syd.example.com + # List of ports that Synapse should listen on, their purpose and their # configuration. listeners: @@ -165,13 +242,12 @@ class ServerConfig(Config): port: %(bind_port)s # Local addresses to listen on. - # This will listen on all IPv4 addresses by default. + # On Linux and Mac OS, `::` will listen on all IPv4 and IPv6 + # addresses by default. For most other OSes, this will only listen + # on IPv6. 
bind_addresses: + - '::' - '0.0.0.0' - # Uncomment to listen on all IPv6 interfaces - # N.B: On at least Linux this will also listen on all IPv4 - # addresses, so you will need to comment out the line above. - # - '::' # This is a 'http' listener, allows us to specify 'resources'. type: http @@ -198,11 +274,18 @@ class ServerConfig(Config): - names: [federation] # Federation APIs compress: false + # optional list of additional endpoints which can be loaded via + # dynamic modules + # additional_resources: + # "/_matrix/my/custom/endpoint": + # module: my_module.CustomRequestHandler + # config: {} + # Unsecure HTTP listener, # For when matrix traffic passes through loadbalancer that unwraps TLS. - port: %(unsecure_port)s tls: false - bind_addresses: ['0.0.0.0'] + bind_addresses: ['::', '0.0.0.0'] type: http x_forwarded: false @@ -216,7 +299,7 @@ class ServerConfig(Config): # Turn on the twisted ssh manhole service on localhost on the given # port. # - port: 9000 - # bind_address: 127.0.0.1 + # bind_addresses: ['::1', '127.0.0.1'] # type: manhole """ % locals() @@ -254,7 +337,7 @@ def read_gc_thresholds(thresholds): return ( int(thresholds[0]), int(thresholds[1]), int(thresholds[2]), ) - except: + except Exception: raise ConfigError( "Value of `gc_threshold` must be a list of three integers if set" ) diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py new file mode 100644 index 0000000000..3fec42bdb0 --- /dev/null +++ b/synapse/config/spam_checker.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.util.module_loader import load_module + +from ._base import Config + + +class SpamCheckerConfig(Config): + def read_config(self, config): + self.spam_checker = None + + provider = config.get("spam_checker", None) + if provider is not None: + self.spam_checker = load_module(provider) + + def default_config(self, **kwargs): + return """\ + # spam_checker: + # module: "my_custom_project.SuperSpamChecker" + # config: + # example_option: 'things' + """ diff --git a/synapse/config/tls.py b/synapse/config/tls.py index e081840a83..b66154bc7c 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -96,7 +96,7 @@ class TlsConfig(Config): # certificates returned by this server match one of the fingerprints. # # Synapse automatically adds the fingerprint of its own certificate - # to the list. So if federation traffic is handle directly by synapse + # to the list. So if federation traffic is handled directly by synapse # then no modification to the list is required. # # If synapse is run behind a load balancer that handles the TLS then it @@ -109,6 +109,12 @@ class TlsConfig(Config): # key. It may be necessary to publish the fingerprints of a new # certificate and wait until the "valid_until_ts" of the previous key # responses have passed before deploying it. 
+ # + # You can calculate a fingerprint from a given TLS listener via: + # openssl s_client -connect $host:$port < /dev/null 2> /dev/null | + # openssl x509 -outform DER | openssl sha256 -binary | base64 | tr -d '=' + # or by checking matrix.org/federationtester/api/report?server_name=$host + # tls_fingerprints: [] # tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}] """ % locals() @@ -126,8 +132,8 @@ class TlsConfig(Config): tls_private_key_path = config["tls_private_key_path"] tls_dh_params_path = config["tls_dh_params_path"] - if not os.path.exists(tls_private_key_path): - with open(tls_private_key_path, "w") as private_key_file: + if not self.path_exists(tls_private_key_path): + with open(tls_private_key_path, "wb") as private_key_file: tls_private_key = crypto.PKey() tls_private_key.generate_key(crypto.TYPE_RSA, 2048) private_key_pem = crypto.dump_privatekey( @@ -141,8 +147,8 @@ class TlsConfig(Config): crypto.FILETYPE_PEM, private_key_pem ) - if not os.path.exists(tls_certificate_path): - with open(tls_certificate_path, "w") as certificate_file: + if not self.path_exists(tls_certificate_path): + with open(tls_certificate_path, "wb") as certificate_file: cert = crypto.X509() subject = cert.get_subject() subject.CN = config["server_name"] @@ -159,7 +165,7 @@ class TlsConfig(Config): certificate_file.write(cert_pem) - if not os.path.exists(tls_dh_params_path): + if not self.path_exists(tls_dh_params_path): if GENERATE_DH_PARAMS: subprocess.check_call([ "openssl", "dhparam", diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py new file mode 100644 index 0000000000..38e8947843 --- /dev/null +++ b/synapse/config/user_directory.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + + +class UserDirectoryConfig(Config): + """User Directory Configuration + Configuration for the behaviour of the /user_directory API + """ + + def read_config(self, config): + self.user_directory_search_all_users = False + user_directory_config = config.get("user_directory", None) + if user_directory_config: + self.user_directory_search_all_users = ( + user_directory_config.get("search_all_users", False) + ) + + def default_config(self, config_dir_path, server_name, **kwargs): + return """ + # User Directory configuration + # + # 'search_all_users' defines whether to search all users visible to your HS + # when searching the user directory, rather than limiting to users visible + # in public rooms. Defaults to false. If you set it True, you'll have to run + # UPDATE user_directory_stream_pos SET stream_id = NULL; + # on your database to tell it to rebuild the user_directory search indexes. 
+ # + #user_directory: + # search_all_users: false + """ diff --git a/synapse/config/voip.py b/synapse/config/voip.py index eeb693027b..3a4e16fa96 100644 --- a/synapse/config/voip.py +++ b/synapse/config/voip.py @@ -23,6 +23,7 @@ class VoipConfig(Config): self.turn_username = config.get("turn_username") self.turn_password = config.get("turn_password") self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"]) + self.turn_allow_guests = config.get("turn_allow_guests", True) def default_config(self, **kwargs): return """\ @@ -41,4 +42,11 @@ class VoipConfig(Config): # How long generated TURN credentials last turn_user_lifetime: "1h" + + # Whether guests should be allowed to use the TURN server. + # This defaults to True, otherwise VoIP will be unreliable for guests. + # However, it does introduce a slight security risk as it allows users to + # connect to arbitrary endpoints without having first signed up for a + # valid account (e.g. by passing a CAPTCHA). + turn_allow_guests: True """ diff --git a/synapse/config/workers.py b/synapse/config/workers.py index b165c67ee7..80baf0ce0e 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -23,12 +23,30 @@ class WorkerConfig(Config): def read_config(self, config): self.worker_app = config.get("worker_app") + + # Canonicalise worker_app so that master always has None + if self.worker_app == "synapse.app.homeserver": + self.worker_app = None + self.worker_listeners = config.get("worker_listeners") self.worker_daemonize = config.get("worker_daemonize") self.worker_pid_file = config.get("worker_pid_file") self.worker_log_file = config.get("worker_log_file") self.worker_log_config = config.get("worker_log_config") - self.worker_replication_url = config.get("worker_replication_url") + + # The host used to connect to the main synapse + self.worker_replication_host = config.get("worker_replication_host", None) + + # The port on the main synapse for TCP replication + self.worker_replication_port 
= config.get("worker_replication_port", None) + + # The port on the main synapse for HTTP replication endpoint + self.worker_replication_http_port = config.get("worker_replication_http_port") + + self.worker_name = config.get("worker_name", self.worker_app) + + self.worker_main_http_uri = config.get("worker_main_http_uri", None) + self.worker_cpu_affinity = config.get("worker_cpu_affinity") if self.worker_listeners: for listener in self.worker_listeners: diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index aad4752fe7..0397f73ab4 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -13,8 +13,8 @@ # limitations under the License. from twisted.internet import ssl -from OpenSSL import SSL -from twisted.internet._sslverify import _OpenSSLECCurve, _defaultCurveName +from OpenSSL import SSL, crypto +from twisted.internet._sslverify import _defaultCurveName import logging @@ -32,9 +32,10 @@ class ServerContextFactory(ssl.ContextFactory): @staticmethod def configure_context(context, config): try: - _ecCurve = _OpenSSLECCurve(_defaultCurveName) - _ecCurve.addECKeyToContext(context) - except: + _ecCurve = crypto.get_elliptic_curve(_defaultCurveName) + context.set_tmp_ecdh(_ecCurve) + + except Exception: logger.exception("Failed to enable elliptic curve for TLS") context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3) context.use_certificate_chain_file(config.tls_certificate_file) diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index ec7711ba7d..aaa3efaca3 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -32,18 +32,25 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256): """Check whether the hash for this PDU matches the contents""" name, expected_hash = compute_content_hash(event, hash_algorithm) logger.debug("Expecting hash: %s", encode_base64(expected_hash)) - if name not in event.hashes: + + # some malformed events 
lack a 'hashes'. Protect against it being missing + # or a weird type by basically treating it the same as an unhashed event. + hashes = event.get("hashes") + if not isinstance(hashes, dict): + raise SynapseError(400, "Malformed 'hashes'", Codes.UNAUTHORIZED) + + if name not in hashes: raise SynapseError( 400, "Algorithm %s not in hashes %s" % ( - name, list(event.hashes), + name, list(hashes), ), Codes.UNAUTHORIZED, ) - message_hash_base64 = event.hashes[name] + message_hash_base64 = hashes[name] try: message_hash_bytes = decode_base64(message_hash_base64) - except: + except Exception: raise SynapseError( 400, "Invalid base64: %s" % (message_hash_base64,), diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py index c2bd64d6c2..f1fd488b90 100644 --- a/synapse/crypto/keyclient.py +++ b/synapse/crypto/keyclient.py @@ -13,14 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. - +from synapse.util import logcontext from twisted.web.http import HTTPClient from twisted.internet.protocol import Factory from twisted.internet import defer, reactor from synapse.http.endpoint import matrix_federation_endpoint -from synapse.util.logcontext import ( - preserve_context_over_fn, preserve_context_over_deferred -) import simplejson as json import logging @@ -43,14 +40,10 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1): for i in range(5): try: - protocol = yield preserve_context_over_fn( - endpoint.connect, factory - ) - server_response, server_certificate = yield preserve_context_over_deferred( - protocol.remote_key - ) - defer.returnValue((server_response, server_certificate)) - return + with logcontext.PreserveLoggingContext(): + protocol = yield endpoint.connect(factory) + server_response, server_certificate = yield protocol.remote_key + defer.returnValue((server_response, server_certificate)) except SynapseKeyClientError as e: logger.exception("Error getting key for %r" % 
(server_name,)) if e.status.startswith("4"): diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index d7211ee9b3..22ee0fc93f 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 New Vector Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,12 +16,11 @@ from synapse.crypto.keyclient import fetch_server_key from synapse.api.errors import SynapseError, Codes -from synapse.util.retryutils import get_retry_limiter -from synapse.util import unwrapFirstError -from synapse.util.async import ObservableDeferred +from synapse.util import unwrapFirstError, logcontext from synapse.util.logcontext import ( - preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext, - preserve_fn + PreserveLoggingContext, + preserve_fn, + run_in_background, ) from synapse.util.metrics import Measure @@ -58,7 +58,8 @@ Attributes: json_object(dict): The JSON object to verify. deferred(twisted.internet.defer.Deferred): A deferred (server_name, key_id, verify_key) tuple that resolves when - a verify key has been fetched + a verify key has been fetched. The deferreds' callbacks are run with no + logcontext. """ @@ -75,31 +76,41 @@ class Keyring(object): self.perspective_servers = self.config.perspectives self.hs = hs + # map from server name to Deferred. Has an entry for each server with + # an ongoing key download; the Deferred completes once the download + # completes. + # + # These are regular, logcontext-agnostic Deferreds. 
self.key_downloads = {} def verify_json_for_server(self, server_name, json_object): - return self.verify_json_objects_for_server( - [(server_name, json_object)] - )[0] + return logcontext.make_deferred_yieldable( + self.verify_json_objects_for_server( + [(server_name, json_object)] + )[0] + ) def verify_json_objects_for_server(self, server_and_json): - """Bulk verfies signatures of json objects, bulk fetching keys as + """Bulk verifies signatures of json objects, bulk fetching keys as necessary. Args: server_and_json (list): List of pairs of (server_name, json_object) Returns: - list of deferreds indicating success or failure to verify each - json object's signature for the given server_name. + List<Deferred>: for each input pair, a deferred indicating success + or failure to verify each json object's signature for the given + server_name. The deferreds run their callbacks in the sentinel + logcontext. """ verify_requests = [] for server_name, json_object in server_and_json: - logger.debug("Verifying for %s", server_name) key_ids = signature_ids(json_object, server_name) if not key_ids: + logger.warn("Request from %s: no supported signature keys", + server_name) deferred = defer.fail(SynapseError( 400, "Not signed with a supported algorithm", @@ -108,76 +119,69 @@ class Keyring(object): else: deferred = defer.Deferred() + logger.debug("Verifying for %s with key_ids %s", + server_name, key_ids) + verify_request = VerifyKeyRequest( server_name, key_ids, json_object, deferred ) verify_requests.append(verify_request) - @defer.inlineCallbacks - def handle_key_deferred(verify_request): - server_name = verify_request.server_name - try: - _, key_id, verify_key = yield verify_request.deferred - except IOError as e: - logger.warn( - "Got IOError when downloading keys for %s: %s %s", - server_name, type(e).__name__, str(e.message), - ) - raise SynapseError( - 502, - "Error downloading keys for %s" % (server_name,), - Codes.UNAUTHORIZED, - ) - except Exception as e: - 
logger.exception( - "Got Exception when downloading keys for %s: %s %s", - server_name, type(e).__name__, str(e.message), - ) - raise SynapseError( - 401, - "No key for %s with id %s" % (server_name, key_ids), - Codes.UNAUTHORIZED, - ) + run_in_background(self._start_key_lookups, verify_requests) - json_object = verify_request.json_object + # Pass those keys to handle_key_deferred so that the json object + # signatures can be verified + handle = preserve_fn(_handle_key_deferred) + return [ + handle(rq) for rq in verify_requests + ] - try: - verify_signed_json(json_object, server_name, verify_key) - except: - raise SynapseError( - 401, - "Invalid signature for server %s with key %s:%s" % ( - server_name, verify_key.alg, verify_key.version - ), - Codes.UNAUTHORIZED, - ) + @defer.inlineCallbacks + def _start_key_lookups(self, verify_requests): + """Sets off the key fetches for each verify request - server_to_deferred = { - server_name: defer.Deferred() - for server_name, _ in server_and_json - } + Once each fetch completes, verify_request.deferred will be resolved. - with PreserveLoggingContext(): + Args: + verify_requests (List[VerifyKeyRequest]): + """ + + try: + # create a deferred for each server we're going to look up the keys + # for; we'll resolve them once we have completed our lookups. + # These will be passed into wait_for_previous_lookups to block + # any other lookups until we have finished. + # The deferreds are called with no logcontext. + server_to_deferred = { + rq.server_name: defer.Deferred() + for rq in verify_requests + } # We want to wait for any previous lookups to complete before # proceeding. - wait_on_deferred = self.wait_for_previous_lookups( - [server_name for server_name, _ in server_and_json], + yield self.wait_for_previous_lookups( + [rq.server_name for rq in verify_requests], server_to_deferred, ) # Actually start fetching keys. 
- wait_on_deferred.addBoth( - lambda _: self.get_server_verify_keys(verify_requests) - ) + self._get_server_verify_keys(verify_requests) # When we've finished fetching all the keys for a given server_name, # resolve the deferred passed to `wait_for_previous_lookups` so that # any lookups waiting will proceed. + # + # map from server name to a set of request ids server_to_request_ids = {} - def remove_deferreds(res, server_name, verify_request): + for verify_request in verify_requests: + server_name = verify_request.server_name + request_id = id(verify_request) + server_to_request_ids.setdefault(server_name, set()).add(request_id) + + def remove_deferreds(res, verify_request): + server_name = verify_request.server_name request_id = id(verify_request) server_to_request_ids[server_name].discard(request_id) if not server_to_request_ids[server_name]: @@ -187,17 +191,11 @@ class Keyring(object): return res for verify_request in verify_requests: - server_name = verify_request.server_name - request_id = id(verify_request) - server_to_request_ids.setdefault(server_name, set()).add(request_id) - deferred.addBoth(remove_deferreds, server_name, verify_request) - - # Pass those keys to handle_key_deferred so that the json object - # signatures can be verified - return [ - preserve_context_over_fn(handle_key_deferred, verify_request) - for verify_request in verify_requests - ] + verify_request.deferred.addBoth( + remove_deferreds, verify_request, + ) + except Exception: + logger.exception("Error starting key lookups") @defer.inlineCallbacks def wait_for_previous_lookups(self, server_names, server_to_deferred): @@ -206,7 +204,13 @@ class Keyring(object): Args: server_names (list): list of server_names we want to lookup server_to_deferred (dict): server_name to deferred which gets - resolved once we've finished looking up keys for that server + resolved once we've finished looking up keys for that server. 
+ The Deferreds should be regular twisted ones which call their + callbacks with no logcontext. + + Returns: a Deferred which resolves once all key lookups for the given + servers have completed. Follows the synapse rules of logcontext + preservation. """ while True: wait_on = [ @@ -220,19 +224,23 @@ class Keyring(object): else: break + def rm(r, server_name_): + self.key_downloads.pop(server_name_, None) + return r + for server_name, deferred in server_to_deferred.items(): - d = ObservableDeferred(preserve_context_over_deferred(deferred)) - self.key_downloads[server_name] = d + self.key_downloads[server_name] = deferred + deferred.addBoth(rm, server_name) - def rm(r, server_name): - self.key_downloads.pop(server_name, None) - return r + def _get_server_verify_keys(self, verify_requests): + """Tries to find at least one key for each verify request - d.addBoth(rm, server_name) + For each verify_request, verify_request.deferred is called back with + params (server_name, key_id, VerifyKey) if a key is found, or errbacked + with a SynapseError if none of the keys are found. - def get_server_verify_keys(self, verify_requests): - """Takes a dict of KeyGroups and tries to find at least one key for - each group. + Args: + verify_requests (list[VerifyKeyRequest]): list of verify requests """ # These are functions that produce keys given a list of key ids @@ -245,8 +253,11 @@ class Keyring(object): @defer.inlineCallbacks def do_iterations(): with Measure(self.clock, "get_server_verify_keys"): + # dict[str, dict[str, VerifyKey]]: results so far. 
+ # map server_name -> key_id -> VerifyKey merged_results = {} + # dict[str, set(str)]: keys to fetch for each server missing_keys = {} for verify_request in verify_requests: missing_keys.setdefault(verify_request.server_name, set()).update( @@ -290,33 +301,46 @@ class Keyring(object): if not missing_keys: break - for verify_request in requests_missing_keys.values(): - verify_request.deferred.errback(SynapseError( - 401, - "No key for %s with id %s" % ( - verify_request.server_name, verify_request.key_ids, - ), - Codes.UNAUTHORIZED, - )) + with PreserveLoggingContext(): + for verify_request in requests_missing_keys: + verify_request.deferred.errback(SynapseError( + 401, + "No key for %s with id %s" % ( + verify_request.server_name, verify_request.key_ids, + ), + Codes.UNAUTHORIZED, + )) def on_err(err): - for verify_request in verify_requests: - if not verify_request.deferred.called: - verify_request.deferred.errback(err) + with PreserveLoggingContext(): + for verify_request in verify_requests: + if not verify_request.deferred.called: + verify_request.deferred.errback(err) - do_iterations().addErrback(on_err) + run_in_background(do_iterations).addErrback(on_err) @defer.inlineCallbacks def get_keys_from_store(self, server_name_and_key_ids): - res = yield preserve_context_over_deferred(defer.gatherResults( + """ + + Args: + server_name_and_key_ids (list[(str, iterable[str])]): + list of (server_name, iterable[key_id]) tuples to fetch keys for + + Returns: + Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from + server_name -> key_id -> VerifyKey + """ + res = yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(self.store.get_server_verify_keys)( - server_name, key_ids + run_in_background( + self.store.get_server_verify_keys, + server_name, key_ids, ).addCallback(lambda ks, server: (server, ks), server_name) for server_name, key_ids in server_name_and_key_ids ], consumeErrors=True, - )).addErrback(unwrapFirstError) + 
).addErrback(unwrapFirstError)) defer.returnValue(dict(res)) @@ -333,17 +357,17 @@ class Keyring(object): logger.exception( "Unable to get key from %r: %s %s", perspective_name, - type(e).__name__, str(e.message), + type(e).__name__, str(e), ) defer.returnValue({}) - results = yield preserve_context_over_deferred(defer.gatherResults( + results = yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(get_key)(p_name, p_keys) + run_in_background(get_key, p_name, p_keys) for p_name, p_keys in self.perspective_servers.items() ], consumeErrors=True, - )).addErrback(unwrapFirstError) + ).addErrback(unwrapFirstError)) union_of_keys = {} for result in results: @@ -356,40 +380,34 @@ class Keyring(object): def get_keys_from_server(self, server_name_and_key_ids): @defer.inlineCallbacks def get_key(server_name, key_ids): - limiter = yield get_retry_limiter( - server_name, - self.clock, - self.store, - ) - with limiter: - keys = None - try: - keys = yield self.get_server_verify_key_v2_direct( - server_name, key_ids - ) - except Exception as e: - logger.info( - "Unable to get key %r for %r directly: %s %s", - key_ids, server_name, - type(e).__name__, str(e.message), - ) + keys = None + try: + keys = yield self.get_server_verify_key_v2_direct( + server_name, key_ids + ) + except Exception as e: + logger.info( + "Unable to get key %r for %r directly: %s %s", + key_ids, server_name, + type(e).__name__, str(e), + ) - if not keys: - keys = yield self.get_server_verify_key_v1_direct( - server_name, key_ids - ) + if not keys: + keys = yield self.get_server_verify_key_v1_direct( + server_name, key_ids + ) - keys = {server_name: keys} + keys = {server_name: keys} defer.returnValue(keys) - results = yield preserve_context_over_deferred(defer.gatherResults( + results = yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(get_key)(server_name, key_ids) + run_in_background(get_key, server_name, key_ids) for server_name, key_ids in 
server_name_and_key_ids ], consumeErrors=True, - )).addErrback(unwrapFirstError) + ).addErrback(unwrapFirstError)) merged = {} for result in results: @@ -466,9 +484,10 @@ class Keyring(object): for server_name, response_keys in processed_response.items(): keys.setdefault(server_name, {}).update(response_keys) - yield preserve_context_over_deferred(defer.gatherResults( + yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(self.store_keys)( + run_in_background( + self.store_keys, server_name=server_name, from_server=perspective_name, verify_keys=response_keys, @@ -476,7 +495,7 @@ class Keyring(object): for server_name, response_keys in keys.items() ], consumeErrors=True - )).addErrback(unwrapFirstError) + ).addErrback(unwrapFirstError)) defer.returnValue(keys) @@ -524,9 +543,10 @@ class Keyring(object): keys.update(response_keys) - yield preserve_context_over_deferred(defer.gatherResults( + yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(self.store_keys)( + run_in_background( + self.store_keys, server_name=key_server_name, from_server=server_name, verify_keys=verify_keys, @@ -534,7 +554,7 @@ class Keyring(object): for key_server_name, verify_keys in keys.items() ], consumeErrors=True - )).addErrback(unwrapFirstError) + ).addErrback(unwrapFirstError)) defer.returnValue(keys) @@ -600,9 +620,10 @@ class Keyring(object): response_keys.update(verify_keys) response_keys.update(old_verify_keys) - yield preserve_context_over_deferred(defer.gatherResults( + yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(self.store.store_server_keys_json)( + run_in_background( + self.store.store_server_keys_json, server_name=server_name, key_id=key_id, from_server=server_name, @@ -613,7 +634,7 @@ class Keyring(object): for key_id in updated_key_ids ], consumeErrors=True, - )).addErrback(unwrapFirstError) + ).addErrback(unwrapFirstError)) results[server_name] = response_keys @@ -691,7 +712,6 @@ class 
Keyring(object): defer.returnValue(verify_keys) - @defer.inlineCallbacks def store_keys(self, server_name, from_server, verify_keys): """Store a collection of verify keys for a given server Args: @@ -702,12 +722,57 @@ class Keyring(object): A deferred that completes when the keys are stored. """ # TODO(markjh): Store whether the keys have expired. - yield preserve_context_over_deferred(defer.gatherResults( + return logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(self.store.store_server_verify_key)( + run_in_background( + self.store.store_server_verify_key, server_name, server_name, key.time_added, key ) for key_id, key in verify_keys.items() ], consumeErrors=True, - )).addErrback(unwrapFirstError) + ).addErrback(unwrapFirstError)) + + +@defer.inlineCallbacks +def _handle_key_deferred(verify_request): + server_name = verify_request.server_name + try: + with PreserveLoggingContext(): + _, key_id, verify_key = yield verify_request.deferred + except IOError as e: + logger.warn( + "Got IOError when downloading keys for %s: %s %s", + server_name, type(e).__name__, str(e), + ) + raise SynapseError( + 502, + "Error downloading keys for %s" % (server_name,), + Codes.UNAUTHORIZED, + ) + except Exception as e: + logger.exception( + "Got Exception when downloading keys for %s: %s %s", + server_name, type(e).__name__, str(e), + ) + raise SynapseError( + 401, + "No key for %s with id %s" % (server_name, verify_request.key_ids), + Codes.UNAUTHORIZED, + ) + + json_object = verify_request.json_object + + logger.debug("Got key %s %s:%s for server %s, verifying" % ( + key_id, verify_key.alg, verify_key.version, server_name, + )) + try: + verify_signed_json(json_object, server_name, verify_key) + except Exception: + raise SynapseError( + 401, + "Invalid signature for server %s with key %s:%s" % ( + server_name, verify_key.alg, verify_key.version + ), + Codes.UNAUTHORIZED, + ) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 
4096c606f1..cd5627e36a 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -319,7 +319,7 @@ def _is_membership_change_allowed(event, auth_events): # TODO (erikj): Implement kicks. if target_banned and user_level < ban_level: raise AuthError( - 403, "You cannot unban user &s." % (target_user_id,) + 403, "You cannot unban user %s." % (target_user_id,) ) elif target_user_id != event.user_id: kick_level = _get_named_level(auth_events, "kick", 50) @@ -443,12 +443,12 @@ def _check_power_levels(event, auth_events): for k, v in user_list.items(): try: UserID.from_string(k) - except: + except Exception: raise SynapseError(400, "Not a valid user_id: %s" % (k,)) try: int(v) - except: + except Exception: raise SynapseError(400, "Not a valid power level: %s" % (v,)) key = (event.type, event.state_key, ) @@ -470,14 +470,14 @@ def _check_power_levels(event, auth_events): ("invite", None), ] - old_list = current_state.content.get("users") + old_list = current_state.content.get("users", {}) for user in set(old_list.keys() + user_list.keys()): levels_to_check.append( (user, "users") ) - old_list = current_state.content.get("events") - new_list = event.content.get("events") + old_list = current_state.content.get("events", {}) + new_list = event.content.get("events", {}) for ev_id in set(old_list.keys() + new_list.keys()): levels_to_check.append( (ev_id, "events") diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index e673e96cc0..c3ff85c49a 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -47,14 +47,26 @@ class _EventInternalMetadata(object): def _event_dict_property(key): + # We want to be able to use hasattr with the event dict properties. + # However, (on python3) hasattr expects AttributeError to be raised. 
Hence, + # we need to transform the KeyError into an AttributeError def getter(self): - return self._event_dict[key] + try: + return self._event_dict[key] + except KeyError: + raise AttributeError(key) def setter(self, v): - self._event_dict[key] = v + try: + self._event_dict[key] = v + except KeyError: + raise AttributeError(key) def delete(self): - del self._event_dict[key] + try: + del self._event_dict[key] + except KeyError: + raise AttributeError(key) return property( getter, diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 365fd96bd2..13fbba68c0 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -55,7 +55,7 @@ class EventBuilderFactory(object): local_part = str(int(self.clock.time())) + i + random_string(5) - e_id = EventID.create(local_part, self.hostname) + e_id = EventID(local_part, self.hostname) return e_id.to_string() diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 11605b34a3..8e684d91b5 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -13,17 +13,49 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + +from frozendict import frozendict + class EventContext(object): + """ + Attributes: + current_state_ids (dict[(str, str), str]): + The current state map including the current event. + (type, state_key) -> event_id + + prev_state_ids (dict[(str, str), str]): + The current state map excluding the current event. + (type, state_key) -> event_id + + state_group (int|None): state group id, if the state has been stored + as a state group. This is usually only None if e.g. the event is + an outlier. + rejected (bool|str): A rejection reason if the event was rejected, else + False + + push_actions (list[(str, list[object])]): list of (user_id, actions) + tuples + + prev_group (int): Previously persisted state group. ``None`` for an + outlier. 
+ delta_ids (dict[(str, str), str]): Delta from ``prev_group``. + (type, state_key) -> event_id. ``None`` for an outlier. + + prev_state_events (?): XXX: is this ever set to anything other than + the empty list? + """ + __slots__ = [ "current_state_ids", "prev_state_ids", "state_group", "rejected", - "push_actions", "prev_group", "delta_ids", "prev_state_events", + "app_service", ] def __init__(self): @@ -34,7 +66,6 @@ class EventContext(object): self.state_group = None self.rejected = False - self.push_actions = [] # A previously persisted state group and a delta between that # and this state. @@ -42,3 +73,100 @@ class EventContext(object): self.delta_ids = None self.prev_state_events = None + + self.app_service = None + + def serialize(self, event): + """Converts self to a type that can be serialized as JSON, and then + deserialized by `deserialize` + + Args: + event (FrozenEvent): The event that this context relates to + + Returns: + dict + """ + + # We don't serialize the full state dicts, instead they get pulled out + # of the DB on the other side. However, the other side can't figure out + # the prev_state_ids, so if we're a state event we include the event + # id that we replaced in the state. + if event.is_state(): + prev_state_id = self.prev_state_ids.get((event.type, event.state_key)) + else: + prev_state_id = None + + return { + "prev_state_id": prev_state_id, + "event_type": event.type, + "event_state_key": event.state_key if event.is_state() else None, + "state_group": self.state_group, + "rejected": self.rejected, + "prev_group": self.prev_group, + "delta_ids": _encode_state_dict(self.delta_ids), + "prev_state_events": self.prev_state_events, + "app_service_id": self.app_service.id if self.app_service else None + } + + @staticmethod + @defer.inlineCallbacks + def deserialize(store, input): + """Converts a dict that was produced by `serialize` back into a + EventContext. 
+ + Args: + store (DataStore): Used to convert AS ID to AS object + input (dict): A dict produced by `serialize` + + Returns: + EventContext + """ + context = EventContext() + context.state_group = input["state_group"] + context.rejected = input["rejected"] + context.prev_group = input["prev_group"] + context.delta_ids = _decode_state_dict(input["delta_ids"]) + context.prev_state_events = input["prev_state_events"] + + # We use the state_group and prev_state_id stuff to pull the + # current_state_ids out of the DB and construct prev_state_ids. + prev_state_id = input["prev_state_id"] + event_type = input["event_type"] + event_state_key = input["event_state_key"] + + context.current_state_ids = yield store.get_state_ids_for_group( + context.state_group, + ) + if prev_state_id and event_state_key: + context.prev_state_ids = dict(context.current_state_ids) + context.prev_state_ids[(event_type, event_state_key)] = prev_state_id + else: + context.prev_state_ids = context.current_state_ids + + app_service_id = input["app_service_id"] + if app_service_id: + context.app_service = store.get_app_service_by_id(app_service_id) + + defer.returnValue(context) + + +def _encode_state_dict(state_dict): + """Since dicts of (type, state_key) -> event_id cannot be serialized in + JSON we need to convert them to a form that can. + """ + if state_dict is None: + return None + + return [ + (etype, state_key, v) + for (etype, state_key), v in state_dict.iteritems() + ] + + +def _decode_state_dict(input): + """Decodes a state dict encoded using `_encode_state_dict` above + """ + if input is None: + return None + + return frozendict({(etype, state_key,): v for etype, state_key, v in input}) diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py new file mode 100644 index 0000000000..633e068eb8 --- /dev/null +++ b/synapse/events/spamcheck.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class SpamChecker(object): + def __init__(self, hs): + self.spam_checker = None + + module = None + config = None + try: + module, config = hs.config.spam_checker + except Exception: + pass + + if module is not None: + self.spam_checker = module(config=config) + + def check_event_for_spam(self, event): + """Checks if a given event is considered "spammy" by this server. + + If the server considers an event spammy, then it will be rejected if + sent by a local user. If it is sent by a user on another server, then + users receive a blank event. + + Args: + event (synapse.events.EventBase): the event to be checked + + Returns: + bool: True if the event is spammy. + """ + if self.spam_checker is None: + return False + + return self.spam_checker.check_event_for_spam(event) + + def user_may_invite(self, inviter_userid, invitee_userid, room_id): + """Checks if a given user may send an invite + + If this method returns false, the invite will be rejected. + + Args: + userid (string): The sender's user ID + + Returns: + bool: True if the user may send an invite, otherwise False + """ + if self.spam_checker is None: + return True + + return self.spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id) + + def user_may_create_room(self, userid): + """Checks if a given user may create a room + + If this method returns false, the creation request will be rejected. 
+ + Args: + userid (string): The sender's user ID + + Returns: + bool: True if the user may create a room, otherwise False + """ + if self.spam_checker is None: + return True + + return self.spam_checker.user_may_create_room(userid) + + def user_may_create_room_alias(self, userid, room_alias): + """Checks if a given user may create a room alias + + If this method returns false, the association request will be rejected. + + Args: + userid (string): The sender's user ID + room_alias (string): The alias to be created + + Returns: + bool: True if the user may create a room alias, otherwise False + """ + if self.spam_checker is None: + return True + + return self.spam_checker.user_may_create_room_alias(userid, room_alias) + + def user_may_publish_room(self, userid, room_id): + """Checks if a given user may publish a room to the directory + + If this method returns false, the publish request will be rejected. + + Args: + userid (string): The sender's user ID + room_id (string): The ID of the room that would be published + + Returns: + bool: True if the user may publish the room, otherwise False + """ + if self.spam_checker is None: + return True + + return self.spam_checker.user_may_publish_room(userid, room_id) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 5bbaef8187..824f4a42e3 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -225,7 +225,22 @@ def format_event_for_client_v2_without_room_id(d): def serialize_event(e, time_now_ms, as_client_event=True, event_format=format_event_for_client_v1, - token_id=None, only_event_fields=None): + token_id=None, only_event_fields=None, is_invite=False): + """Serialize event for clients + + Args: + e (EventBase) + time_now_ms (int) + as_client_event (bool) + event_format + token_id + only_event_fields + is_invite (bool): Whether this is an invite that is being sent to the + invitee + + Returns: + dict + """ # FIXME(erikj): To handle the case of presence events and the like if not 
isinstance(e, EventBase): return e @@ -251,6 +266,12 @@ def serialize_event(e, time_now_ms, as_client_event=True, if txn_id is not None: d["unsigned"]["transaction_id"] = txn_id + # If this is an invite for somebody else, then we don't care about the + # invite_room_state as that's meant solely for the invitee. Other clients + # will already have the state since they're in the room. + if not is_invite: + d["unsigned"].pop("invite_room_state", None) + if as_client_event: d = event_format(d) diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py index 2e32d245ba..f5f0bdfca3 100644 --- a/synapse/federation/__init__.py +++ b/synapse/federation/__init__.py @@ -15,11 +15,3 @@ """ This package includes all the federation specific logic. """ - -from .replication import ReplicationLayer - - -def initialize_http_replication(hs): - transport = hs.get_federation_transport_client() - - return ReplicationLayer(hs, transport) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 2339cc9034..4cc98a3fe8 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -12,28 +12,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import logging +import six -from twisted.internet import defer - -from synapse.events.utils import prune_event - +from synapse.api.constants import MAX_DEPTH +from synapse.api.errors import SynapseError, Codes from synapse.crypto.event_signing import check_event_content_hash - -from synapse.api.errors import SynapseError - -from synapse.util import unwrapFirstError -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred - -import logging - +from synapse.events import FrozenEvent +from synapse.events.utils import prune_event +from synapse.http.servlet import assert_params_in_request +from synapse.util import unwrapFirstError, logcontext +from twisted.internet import defer logger = logging.getLogger(__name__) class FederationBase(object): def __init__(self, hs): - pass + self.hs = hs + + self.server_name = hs.hostname + self.keyring = hs.get_keyring() + self.spam_checker = hs.get_spam_checker() + self.store = hs.get_datastore() + self._clock = hs.get_clock() @defer.inlineCallbacks def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False, @@ -57,56 +60,52 @@ class FederationBase(object): """ deferreds = self._check_sigs_and_hashes(pdus) - def callback(pdu): - return pdu - - def errback(failure, pdu): - failure.trap(SynapseError) - return None + @defer.inlineCallbacks + def handle_check_result(pdu, deferred): + try: + res = yield logcontext.make_deferred_yieldable(deferred) + except SynapseError: + res = None - def try_local_db(res, pdu): if not res: # Check local db. 
- return self.store.get_event( + res = yield self.store.get_event( pdu.event_id, allow_rejected=True, allow_none=True, ) - return res - def try_remote(res, pdu): if not res and pdu.origin != origin: - return self.get_pdu( - destinations=[pdu.origin], - event_id=pdu.event_id, - outlier=outlier, - timeout=10000, - ).addErrback(lambda e: None) - return res - - def warn(res, pdu): + try: + res = yield self.get_pdu( + destinations=[pdu.origin], + event_id=pdu.event_id, + outlier=outlier, + timeout=10000, + ) + except SynapseError: + pass + if not res: logger.warn( "Failed to find copy of %s with valid signature", pdu.event_id, ) - return res - for pdu, deferred in zip(pdus, deferreds): - deferred.addCallbacks( - callback, errback, errbackArgs=[pdu] - ).addCallback( - try_local_db, pdu - ).addCallback( - try_remote, pdu - ).addCallback( - warn, pdu - ) + defer.returnValue(res) - valid_pdus = yield preserve_context_over_deferred(defer.gatherResults( - deferreds, - consumeErrors=True - )).addErrback(unwrapFirstError) + handle = logcontext.preserve_fn(handle_check_result) + deferreds2 = [ + handle(pdu, deferred) + for pdu, deferred in zip(pdus, deferreds) + ] + + valid_pdus = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + deferreds2, + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) if include_none: defer.returnValue(valid_pdus) @@ -114,15 +113,24 @@ class FederationBase(object): defer.returnValue([p for p in valid_pdus if p]) def _check_sigs_and_hash(self, pdu): - return self._check_sigs_and_hashes([pdu])[0] + return logcontext.make_deferred_yieldable( + self._check_sigs_and_hashes([pdu])[0], + ) def _check_sigs_and_hashes(self, pdus): - """Throws a SynapseError if a PDU does not have the correct - signatures. + """Checks that each of the received events is correctly signed by the + sending server. 
+ + Args: + pdus (list[FrozenEvent]): the events to be checked Returns: - FrozenEvent: Either the given event or it redacted if it failed the - content hash check. + list[Deferred]: for each input event, a deferred which: + * returns the original event if the checks pass + * returns a redacted version of the event (if the signature + matched but the hash did not) + * throws a SynapseError if the signature check failed. + The deferreds run their callbacks in the sentinel logcontext. """ redacted_pdus = [ @@ -130,26 +138,38 @@ class FederationBase(object): for pdu in pdus ] - deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([ + deferreds = self.keyring.verify_json_objects_for_server([ (p.origin, p.get_pdu_json()) for p in redacted_pdus ]) + ctx = logcontext.LoggingContext.current_context() + def callback(_, pdu, redacted): - if not check_event_content_hash(pdu): - logger.warn( - "Event content has been tampered, redacting %s: %s", - pdu.event_id, pdu.get_pdu_json() - ) - return redacted - return pdu + with logcontext.PreserveLoggingContext(ctx): + if not check_event_content_hash(pdu): + logger.warn( + "Event content has been tampered, redacting %s: %s", + pdu.event_id, pdu.get_pdu_json() + ) + return redacted + + if self.spam_checker.check_event_for_spam(pdu): + logger.warn( + "Event contains spam, redacting %s: %s", + pdu.event_id, pdu.get_pdu_json() + ) + return redacted + + return pdu def errback(failure, pdu): failure.trap(SynapseError) - logger.warn( - "Signature check failed for %s", - pdu.event_id, - ) + with logcontext.PreserveLoggingContext(ctx): + logger.warn( + "Signature check failed for %s", + pdu.event_id, + ) return failure for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus): @@ -160,3 +180,40 @@ class FederationBase(object): ) return deferreds + + +def event_from_pdu_json(pdu_json, outlier=False): + """Construct a FrozenEvent from an event json received over federation + + Args: + pdu_json (object): pdu as received 
over federation + outlier (bool): True to mark this event as an outlier + + Returns: + FrozenEvent + + Raises: + SynapseError: if the pdu is missing required fields or is otherwise + not a valid matrix event + """ + # we could probably enforce a bunch of other fields here (room_id, sender, + # origin, etc etc) + assert_params_in_request(pdu_json, ('event_id', 'type', 'depth')) + + depth = pdu_json['depth'] + if not isinstance(depth, six.integer_types): + raise SynapseError(400, "Depth %r not an intger" % (depth, ), + Codes.BAD_JSON) + + if depth < 0: + raise SynapseError(400, "Depth too small", Codes.BAD_JSON) + elif depth > MAX_DEPTH: + raise SynapseError(400, "Depth too large", Codes.BAD_JSON) + + event = FrozenEvent( + pdu_json + ) + + event.internal_metadata.outlier = outlier + + return event diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index b5bcfd705a..6163f7c466 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -14,28 +14,30 @@ # limitations under the License. 
+import copy +import itertools +import logging +import random + +from six.moves import range + from twisted.internet import defer -from .federation_base import FederationBase from synapse.api.constants import Membership - from synapse.api.errors import ( - CodeMessageException, HttpResponseException, SynapseError, + CodeMessageException, HttpResponseException, SynapseError, FederationDeniedError ) -from synapse.util import unwrapFirstError +from synapse.events import builder +from synapse.federation.federation_base import ( + FederationBase, + event_from_pdu_json, +) +import synapse.metrics +from synapse.util import logcontext, unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.logutils import log_function -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred -from synapse.events import FrozenEvent, builder -import synapse.metrics - -from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination - -import copy -import itertools -import logging -import random - +from synapse.util.retryutils import NotRetryingDestination logger = logging.getLogger(__name__) @@ -58,6 +60,7 @@ class FederationClient(FederationBase): self._clear_tried_cache, 60 * 1000, ) self.state = hs.get_state_handler() + self.transport_layer = hs.get_federation_transport_client() def _clear_tried_cache(self): """Clear pdu_destination_tried cache""" @@ -88,7 +91,7 @@ class FederationClient(FederationBase): @log_function def make_query(self, destination, query_type, args, - retry_on_dns_fail=False): + retry_on_dns_fail=False, ignore_backoff=False): """Sends a federation Query to a remote homeserver of the given type and arguments. @@ -98,6 +101,8 @@ class FederationClient(FederationBase): handler name used in register_query_handler(). args (dict): Mapping of strings to strings containing the details of the query request. 
+ ignore_backoff (bool): true to ignore the historical backoff data + and try the request anyway. Returns: a Deferred which will eventually yield a JSON object from the @@ -106,7 +111,8 @@ class FederationClient(FederationBase): sent_queries_counter.inc(query_type) return self.transport_layer.make_query( - destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail + destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail, + ignore_backoff=ignore_backoff, ) @log_function @@ -181,15 +187,15 @@ class FederationClient(FederationBase): logger.debug("backfill transaction_data=%s", repr(transaction_data)) pdus = [ - self.event_from_pdu_json(p, outlier=False) + event_from_pdu_json(p, outlier=False) for p in transaction_data["pdus"] ] # FIXME: We should handle signature failures more gracefully. - pdus[:] = yield preserve_context_over_deferred(defer.gatherResults( + pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults( self._check_sigs_and_hashes(pdus), consumeErrors=True, - )).addErrback(unwrapFirstError) + ).addErrback(unwrapFirstError)) defer.returnValue(pdus) @@ -206,8 +212,7 @@ class FederationClient(FederationBase): Args: destinations (list): Which home servers to query - pdu_origin (str): The home server that originally sent the pdu. - event_id (str) + event_id (str): event to fetch outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if it's from an arbitary point in the context as opposed to part of the current block of PDUs. 
Defaults to `False` @@ -235,31 +240,24 @@ class FederationClient(FederationBase): continue try: - limiter = yield get_retry_limiter( - destination, - self._clock, - self.store, + transaction_data = yield self.transport_layer.get_event( + destination, event_id, timeout=timeout, ) - with limiter: - transaction_data = yield self.transport_layer.get_event( - destination, event_id, timeout=timeout, - ) - - logger.debug("transaction_data %r", transaction_data) + logger.debug("transaction_data %r", transaction_data) - pdu_list = [ - self.event_from_pdu_json(p, outlier=outlier) - for p in transaction_data["pdus"] - ] + pdu_list = [ + event_from_pdu_json(p, outlier=outlier) + for p in transaction_data["pdus"] + ] - if pdu_list and pdu_list[0]: - pdu = pdu_list[0] + if pdu_list and pdu_list[0]: + pdu = pdu_list[0] - # Check signatures are correct. - signed_pdu = yield self._check_sigs_and_hashes([pdu])[0] + # Check signatures are correct. + signed_pdu = yield self._check_sigs_and_hash(pdu) - break + break pdu_attempts[destination] = now @@ -271,6 +269,9 @@ class FederationClient(FederationBase): except NotRetryingDestination as e: logger.info(e.message) continue + except FederationDeniedError as e: + logger.info(e.message) + continue except Exception as e: pdu_attempts[destination] = now @@ -341,11 +342,11 @@ class FederationClient(FederationBase): ) pdus = [ - self.event_from_pdu_json(p, outlier=True) for p in result["pdus"] + event_from_pdu_json(p, outlier=True) for p in result["pdus"] ] auth_chain = [ - self.event_from_pdu_json(p, outlier=True) + event_from_pdu_json(p, outlier=True) for p in result.get("auth_chain", []) ] @@ -395,7 +396,7 @@ class FederationClient(FederationBase): seen_events = yield self.store.get_events(event_ids, allow_rejected=True) signed_events = seen_events.values() else: - seen_events = yield self.store.have_events(event_ids) + seen_events = yield self.store.have_seen_events(event_ids) signed_events = [] failed_to_fetch = set() @@ -414,18 +415,19 
@@ class FederationClient(FederationBase): batch_size = 20 missing_events = list(missing_events) - for i in xrange(0, len(missing_events), batch_size): + for i in range(0, len(missing_events), batch_size): batch = set(missing_events[i:i + batch_size]) deferreds = [ - preserve_fn(self.get_pdu)( + run_in_background( + self.get_pdu, destinations=random_server_list(), event_id=e_id, ) for e_id in batch ] - res = yield preserve_context_over_deferred( + res = yield make_deferred_yieldable( defer.DeferredList(deferreds, consumeErrors=True) ) for success, result in res: @@ -446,7 +448,7 @@ class FederationClient(FederationBase): ) auth_chain = [ - self.event_from_pdu_json(p, outlier=True) + event_from_pdu_json(p, outlier=True) for p in res["auth_chain"] ] @@ -479,8 +481,13 @@ class FederationClient(FederationBase): content (object): Any additional data to put into the content field of the event. Return: - A tuple of (origin (str), event (object)) where origin is the remote - homeserver which generated the event. + Deferred: resolves to a tuple of (origin (str), event (object)) + where origin is the remote homeserver which generated the event. + + Fails with a ``CodeMessageException`` if the chosen remote server + returns a 300/400 code. + + Fails with a ``RuntimeError`` if no servers were reachable. """ valid_memberships = {Membership.JOIN, Membership.LEAVE} if membership not in valid_memberships: @@ -533,6 +540,27 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def send_join(self, destinations, pdu): + """Sends a join event to one of a list of homeservers. + + Doing so will cause the remote server to add the event to the graph, + and send the event out to the rest of the federation. + + Args: + destinations (str): Candidate homeservers which are probably + participating in the room. + pdu (BaseEvent): event to be sent + + Return: + Deferred: resolves to a dict with members ``origin`` (a string + giving the serer the event was sent to, ``state`` (?) 
and + ``auth_chain``. + + Fails with a ``CodeMessageException`` if the chosen remote server + returns a 300/400 code. + + Fails with a ``RuntimeError`` if no servers were reachable. + """ + for destination in destinations: if destination == self.server_name: continue @@ -549,12 +577,12 @@ class FederationClient(FederationBase): logger.debug("Got content: %s", content) state = [ - self.event_from_pdu_json(p, outlier=True) + event_from_pdu_json(p, outlier=True) for p in content.get("state", []) ] auth_chain = [ - self.event_from_pdu_json(p, outlier=True) + event_from_pdu_json(p, outlier=True) for p in content.get("auth_chain", []) ] @@ -629,7 +657,7 @@ class FederationClient(FederationBase): logger.debug("Got response to send_invite: %s", pdu_dict) - pdu = self.event_from_pdu_json(pdu_dict) + pdu = event_from_pdu_json(pdu_dict) # Check signatures are correct. pdu = yield self._check_sigs_and_hash(pdu) @@ -640,6 +668,26 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def send_leave(self, destinations, pdu): + """Sends a leave event to one of a list of homeservers. + + Doing so will cause the remote server to add the event to the graph, + and send the event out to the rest of the federation. + + This is mostly useful to reject received invites. + + Args: + destinations (str): Candidate homeservers which are probably + participating in the room. + pdu (BaseEvent): event to be sent + + Return: + Deferred: resolves to None. + + Fails with a ``CodeMessageException`` if the chosen remote server + returns a non-200 code. + + Fails with a ``RuntimeError`` if no servers were reachable. 
+ """ for destination in destinations: if destination == self.server_name: continue @@ -699,7 +747,7 @@ class FederationClient(FederationBase): ) auth_chain = [ - self.event_from_pdu_json(e) + event_from_pdu_json(e) for e in content["auth_chain"] ] @@ -747,7 +795,7 @@ class FederationClient(FederationBase): ) events = [ - self.event_from_pdu_json(e) + event_from_pdu_json(e) for e in content.get("events", []) ] @@ -764,15 +812,6 @@ class FederationClient(FederationBase): defer.returnValue(signed_events) - def event_from_pdu_json(self, pdu_json, outlier=False): - event = FrozenEvent( - pdu_json - ) - - event.internal_metadata.outlier = outlier - - return event - @defer.inlineCallbacks def forward_third_party_invite(self, destinations, room_id, event_dict): for destination in destinations: diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index e922b7ff4a..247ddc89d5 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,27 +13,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import logging - +import simplejson as json from twisted.internet import defer -from .federation_base import FederationBase -from .units import Transaction, Edu +from synapse.api.errors import AuthError, FederationError, SynapseError, NotFoundError +from synapse.crypto.event_signing import compute_event_signature +from synapse.federation.federation_base import ( + FederationBase, + event_from_pdu_json, +) -from synapse.util.async import Linearizer -from synapse.util.logutils import log_function -from synapse.util.caches.response_cache import ResponseCache -from synapse.events import FrozenEvent -from synapse.types import get_domain_from_id +from synapse.federation.persistence import TransactionActions +from synapse.federation.units import Edu, Transaction import synapse.metrics +from synapse.types import get_domain_from_id +from synapse.util import async +from synapse.util.caches.response_cache import ResponseCache +from synapse.util.logutils import log_function -from synapse.api.errors import AuthError, FederationError, SynapseError - -from synapse.crypto.event_signing import compute_event_signature - -import simplejson as json -import logging +from six import iteritems +# when processing incoming transactions, we try to handle multiple rooms in +# parallel, up to this limit. +TRANSACTION_CONCURRENCY_LIMIT = 10 logger = logging.getLogger(__name__) @@ -51,49 +56,18 @@ class FederationServer(FederationBase): super(FederationServer, self).__init__(hs) self.auth = hs.get_auth() + self.handler = hs.get_handlers().federation_handler - self._room_pdu_linearizer = Linearizer("fed_room_pdu") - self._server_linearizer = Linearizer("fed_server") - - # We cache responses to state queries, as they take a while and often - # come in waves. - self._state_resp_cache = ResponseCache(hs, timeout_ms=30000) - - def set_handler(self, handler): - """Sets the handler that the replication layer will use to communicate - receipt of new PDUs from other home servers. 
The required methods are - documented on :py:class:`.ReplicationHandler`. - """ - self.handler = handler - - def register_edu_handler(self, edu_type, handler): - if edu_type in self.edu_handlers: - raise KeyError("Already have an EDU handler for %s" % (edu_type,)) - - self.edu_handlers[edu_type] = handler - - def register_query_handler(self, query_type, handler): - """Sets the handler callable that will be used to handle an incoming - federation Query of the given type. - - Args: - query_type (str): Category name of the query, which should match - the string used by make_query. - handler (callable): Invoked to handle incoming queries of this type + self._server_linearizer = async.Linearizer("fed_server") + self._transaction_linearizer = async.Linearizer("fed_txn_handler") - handler is invoked as: - result = handler(args) + self.transaction_actions = TransactionActions(self.store) - where 'args' is a dict mapping strings to strings of the query - arguments. It should return a Deferred that will eventually yield an - object to encode as JSON. - """ - if query_type in self.query_handlers: - raise KeyError( - "Already have a Query handler for %s" % (query_type,) - ) + self.registry = hs.get_federation_registry() - self.query_handlers[query_type] = handler + # We cache responses to state queries, as they take a while and often + # come in waves. + self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000) @defer.inlineCallbacks @log_function @@ -110,25 +84,41 @@ class FederationServer(FederationBase): @defer.inlineCallbacks @log_function def on_incoming_transaction(self, transaction_data): + # keep this as early as possible to make the calculated origin ts as + # accurate as possible. 
+ request_time = self._clock.time_msec() + transaction = Transaction(**transaction_data) - received_pdus_counter.inc_by(len(transaction.pdus)) + if not transaction.transaction_id: + raise Exception("Transaction missing transaction_id") + if not transaction.origin: + raise Exception("Transaction missing origin") - for p in transaction.pdus: - if "unsigned" in p: - unsigned = p["unsigned"] - if "age" in unsigned: - p["age"] = unsigned["age"] - if "age" in p: - p["age_ts"] = int(self._clock.time_msec()) - int(p["age"]) - del p["age"] + logger.debug("[%s] Got transaction", transaction.transaction_id) - pdu_list = [ - self.event_from_pdu_json(p) for p in transaction.pdus - ] + # use a linearizer to ensure that we don't process the same transaction + # multiple times in parallel. + with (yield self._transaction_linearizer.queue( + (transaction.origin, transaction.transaction_id), + )): + result = yield self._handle_incoming_transaction( + transaction, request_time, + ) - logger.debug("[%s] Got transaction", transaction.transaction_id) + defer.returnValue(result) + + @defer.inlineCallbacks + def _handle_incoming_transaction(self, transaction, request_time): + """ Process an incoming transaction and return the HTTP response + Args: + transaction (Transaction): incoming transaction + request_time (int): timestamp that the HTTP request arrived at + + Returns: + Deferred[(int, object)]: http response code and body + """ response = yield self.transaction_actions.have_responded(transaction) if response: @@ -141,38 +131,49 @@ class FederationServer(FederationBase): logger.debug("[%s] Transaction is new", transaction.transaction_id) - results = [] - - for pdu in pdu_list: - # check that it's actually being sent from a valid destination to - # workaround bug #1753 in 0.18.5 and 0.18.6 - if transaction.origin != get_domain_from_id(pdu.event_id): - if not ( - pdu.type == 'm.room.member' and - pdu.content and - pdu.content.get("membership", None) == 'join' and - 
self.hs.is_mine_id(pdu.state_key) - ): - logger.info( - "Discarding PDU %s from invalid origin %s", - pdu.event_id, transaction.origin - ) - continue - else: - logger.info( - "Accepting join PDU %s from %s", - pdu.event_id, transaction.origin - ) + received_pdus_counter.inc_by(len(transaction.pdus)) + + pdus_by_room = {} + + for p in transaction.pdus: + if "unsigned" in p: + unsigned = p["unsigned"] + if "age" in unsigned: + p["age"] = unsigned["age"] + if "age" in p: + p["age_ts"] = request_time - int(p["age"]) + del p["age"] - try: - yield self._handle_new_pdu(transaction.origin, pdu) - results.append({}) - except FederationError as e: - self.send_failure(e, transaction.origin) - results.append({"error": str(e)}) - except Exception as e: - results.append({"error": str(e)}) - logger.exception("Failed to handle PDU") + event = event_from_pdu_json(p) + room_id = event.room_id + pdus_by_room.setdefault(room_id, []).append(event) + + pdu_results = {} + + # we can process different rooms in parallel (which is useful if they + # require callouts to other servers to fetch missing events), but + # impose a limit to avoid going too crazy with ram/cpu. 
+ @defer.inlineCallbacks + def process_pdus_for_room(room_id): + logger.debug("Processing PDUs for %s", room_id) + for pdu in pdus_by_room[room_id]: + event_id = pdu.event_id + try: + yield self._handle_received_pdu( + transaction.origin, pdu + ) + pdu_results[event_id] = {} + except FederationError as e: + logger.warn("Error handling PDU %s: %s", event_id, e) + pdu_results[event_id] = {"error": str(e)} + except Exception as e: + pdu_results[event_id] = {"error": str(e)} + logger.exception("Failed to handle PDU %s", event_id) + + yield async.concurrently_execute( + process_pdus_for_room, pdus_by_room.keys(), + TRANSACTION_CONCURRENCY_LIMIT, + ) if hasattr(transaction, "edus"): for edu in (Edu(**x) for x in transaction.edus): @@ -182,17 +183,16 @@ class FederationServer(FederationBase): edu.content ) - for failure in getattr(transaction, "pdu_failures", []): - logger.info("Got failure %r", failure) - - logger.debug("Returning: %s", str(results)) + pdu_failures = getattr(transaction, "pdu_failures", []) + for failure in pdu_failures: + logger.info("Got failure %r", failure) response = { - "pdus": dict(zip( - (p.event_id for p in pdu_list), results - )), + "pdus": pdu_results, } + logger.debug("Returning: %s", str(response)) + yield self.transaction_actions.set_response( transaction, 200, response @@ -202,16 +202,7 @@ class FederationServer(FederationBase): @defer.inlineCallbacks def received_edu(self, origin, edu_type, content): received_edus_counter.inc() - - if edu_type in self.edu_handlers: - try: - yield self.edu_handlers[edu_type](origin, content) - except SynapseError as e: - logger.info("Failed to handle edu %r: %r", edu_type, e) - except Exception as e: - logger.exception("Failed to handle edu %r", edu_type) - else: - logger.warn("Received EDU of type %s with no handler", edu_type) + yield self.registry.on_edu(edu_type, origin, content) @defer.inlineCallbacks @log_function @@ -223,15 +214,17 @@ class FederationServer(FederationBase): if not in_room: raise 
AuthError(403, "Host not in room.") - result = self._state_resp_cache.get((room_id, event_id)) - if not result: - with (yield self._server_linearizer.queue((origin, room_id))): - resp = yield self._state_resp_cache.set( - (room_id, event_id), - self._on_context_state_request_compute(room_id, event_id) - ) - else: - resp = yield result + # we grab the linearizer to protect ourselves from servers which hammer + # us. In theory we might already have the response to this query + # in the cache so we could return it without waiting for the linearizer + # - but that's non-trivial to get right, and anyway somewhat defeats + # the point of the linearizer. + with (yield self._server_linearizer.queue((origin, room_id))): + resp = yield self._state_resp_cache.wrap( + (room_id, event_id), + self._on_context_state_request_compute, + room_id, event_id, + ) defer.returnValue((200, resp)) @@ -300,14 +293,8 @@ class FederationServer(FederationBase): @defer.inlineCallbacks def on_query_request(self, query_type, args): received_queries_counter.inc(query_type) - - if query_type in self.query_handlers: - response = yield self.query_handlers[query_type](args) - defer.returnValue((200, response)) - else: - defer.returnValue( - (404, "No handler for Query type '%s'" % (query_type,)) - ) + resp = yield self.registry.on_query(query_type, args) + defer.returnValue((200, resp)) @defer.inlineCallbacks def on_make_join_request(self, room_id, user_id): @@ -317,7 +304,7 @@ class FederationServer(FederationBase): @defer.inlineCallbacks def on_invite_request(self, origin, content): - pdu = self.event_from_pdu_json(content) + pdu = event_from_pdu_json(content) ret_pdu = yield self.handler.on_invite_request(origin, pdu) time_now = self._clock.time_msec() defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)})) @@ -325,7 +312,7 @@ class FederationServer(FederationBase): @defer.inlineCallbacks def on_send_join_request(self, origin, content): logger.debug("on_send_join_request: content: %s", 
content) - pdu = self.event_from_pdu_json(content) + pdu = event_from_pdu_json(content) logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) res_pdus = yield self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() @@ -345,7 +332,7 @@ class FederationServer(FederationBase): @defer.inlineCallbacks def on_send_leave_request(self, origin, content): logger.debug("on_send_leave_request: content: %s", content) - pdu = self.event_from_pdu_json(content) + pdu = event_from_pdu_json(content) logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) yield self.handler.on_send_leave_request(origin, pdu) defer.returnValue((200, {})) @@ -382,7 +369,7 @@ class FederationServer(FederationBase): """ with (yield self._server_linearizer.queue((origin, room_id))): auth_chain = [ - self.event_from_pdu_json(e) + event_from_pdu_json(e) for e in content["auth_chain"] ] @@ -437,6 +424,16 @@ class FederationServer(FederationBase): key_id: json.loads(json_bytes) } + logger.info( + "Claimed one-time-keys: %s", + ",".join(( + "%s for %s:%s" % (key_id, user_id, device_id) + for user_id, user_keys in iteritems(json_result) + for device_id, device_keys in iteritems(user_keys) + for key_id, _ in iteritems(device_keys) + )), + ) + defer.returnValue({"one_time_keys": json_result}) @defer.inlineCallbacks @@ -497,26 +494,59 @@ class FederationServer(FederationBase): ) @defer.inlineCallbacks - @log_function - def _handle_new_pdu(self, origin, pdu, get_missing=True): + def _handle_received_pdu(self, origin, pdu): + """ Process a PDU received in a federation /send/ transaction. + + If the event is invalid, then this method throws a FederationError. + (The error will then be logged and sent back to the sender (which + probably won't do anything with it), and other events in the + transaction will be processed as normal). + + It is likely that we'll then receive other events which refer to + this rejected_event in their prev_events, etc. 
When that happens, + we'll attempt to fetch the rejected event again, which will presumably + fail, so those second-generation events will also get rejected. + + Eventually, we get to the point where there are more than 10 events + between any new events and the original rejected event. Since we + only try to backfill 10 events deep on received pdu, we then accept the + new event, possibly introducing a discontinuity in the DAG, with new + forward extremities, so normal service is approximately returned, + until we try to backfill across the discontinuity. - # We reprocess pdus when we have seen them only as outliers - existing = yield self._get_persisted_pdu( - origin, pdu.event_id, do_auth=False - ) + Args: + origin (str): server which sent the pdu + pdu (FrozenEvent): received pdu - # FIXME: Currently we fetch an event again when we already have it - # if it has been marked as an outlier. + Returns (Deferred): completes with None - already_seen = ( - existing and ( - not existing.internal_metadata.is_outlier() - or pdu.internal_metadata.is_outlier() - ) - ) - if already_seen: - logger.debug("Already seen pdu %s", pdu.event_id) - return + Raises: FederationError if the signatures / hash do not match, or + if the event was unacceptable for any other reason (eg, too large, + too many prev_events, couldn't find the prev_events) + """ + # check that it's actually being sent from a valid destination to + # workaround bug #1753 in 0.18.5 and 0.18.6 + if origin != get_domain_from_id(pdu.event_id): + # We continue to accept join events from any server; this is + # necessary for the federation join dance to work correctly. + # (When we join over federation, the "helper" server is + # responsible for sending out the join event, rather than the + # origin. See bug #1893). 
+ if not ( + pdu.type == 'm.room.member' and + pdu.content and + pdu.content.get("membership", None) == 'join' + ): + logger.info( + "Discarding PDU %s from invalid origin %s", + pdu.event_id, origin + ) + return + else: + logger.info( + "Accepting join PDU %s from %s", + pdu.event_id, origin + ) # Check signature. try: @@ -529,156 +559,11 @@ class FederationServer(FederationBase): affected=pdu.event_id, ) - state = None - - auth_chain = [] - - have_seen = yield self.store.have_events( - [ev for ev, _ in pdu.prev_events] - ) - - fetch_state = False - - # Get missing pdus if necessary. - if not pdu.internal_metadata.is_outlier(): - # We only backfill backwards to the min depth. - min_depth = yield self.handler.get_min_depth_for_context( - pdu.room_id - ) - - logger.debug( - "_handle_new_pdu min_depth for %s: %d", - pdu.room_id, min_depth - ) - - prevs = {e_id for e_id, _ in pdu.prev_events} - seen = set(have_seen.keys()) - - if min_depth and pdu.depth < min_depth: - # This is so that we don't notify the user about this - # message, to work around the fact that some events will - # reference really really old events we really don't want to - # send to the clients. - pdu.internal_metadata.outlier = True - elif min_depth and pdu.depth > min_depth: - if get_missing and prevs - seen: - # If we're missing stuff, ensure we only fetch stuff one - # at a time. - logger.info( - "Acquiring lock for room %r to fetch %d missing events: %r...", - pdu.room_id, len(prevs - seen), list(prevs - seen)[:5], - ) - with (yield self._room_pdu_linearizer.queue(pdu.room_id)): - logger.info( - "Acquired lock for room %r to fetch %d missing events", - pdu.room_id, len(prevs - seen), - ) - - # We recalculate seen, since it may have changed. 
- have_seen = yield self.store.have_events(prevs) - seen = set(have_seen.keys()) - - if prevs - seen: - latest = yield self.store.get_latest_event_ids_in_room( - pdu.room_id - ) - - # We add the prev events that we have seen to the latest - # list to ensure the remote server doesn't give them to us - latest = set(latest) - latest |= seen - - logger.info( - "Missing %d events for room %r: %r...", - len(prevs - seen), pdu.room_id, list(prevs - seen)[:5] - ) - - # XXX: we set timeout to 10s to help workaround - # https://github.com/matrix-org/synapse/issues/1733. - # The reason is to avoid holding the linearizer lock - # whilst processing inbound /send transactions, causing - # FDs to stack up and block other inbound transactions - # which empirically can currently take up to 30 minutes. - # - # N.B. this explicitly disables retry attempts. - # - # N.B. this also increases our chances of falling back to - # fetching fresh state for the room if the missing event - # can't be found, which slightly reduces our security. - # it may also increase our DAG extremity count for the room, - # causing additional state resolution? See #1760. - # However, fetching state doesn't hold the linearizer lock - # apparently. - # - # see https://github.com/matrix-org/synapse/pull/1744 - - missing_events = yield self.get_missing_events( - origin, - pdu.room_id, - earliest_events_ids=list(latest), - latest_events=[pdu], - limit=10, - min_depth=min_depth, - timeout=10000, - ) - - # We want to sort these by depth so we process them and - # tell clients about them in order. 
- missing_events.sort(key=lambda x: x.depth) - - for e in missing_events: - yield self._handle_new_pdu( - origin, - e, - get_missing=False - ) - - have_seen = yield self.store.have_events( - [ev for ev, _ in pdu.prev_events] - ) - - prevs = {e_id for e_id, _ in pdu.prev_events} - seen = set(have_seen.keys()) - if prevs - seen: - logger.info( - "Still missing %d events for room %r: %r...", - len(prevs - seen), pdu.room_id, list(prevs - seen)[:5] - ) - fetch_state = True - - if fetch_state: - # We need to get the state at this event, since we haven't - # processed all the prev events. - logger.debug( - "_handle_new_pdu getting state for %s", - pdu.room_id - ) - try: - state, auth_chain = yield self.get_state_for_room( - origin, pdu.room_id, pdu.event_id, - ) - except: - logger.exception("Failed to get state for event: %s", pdu.event_id) - - yield self.handler.on_receive_pdu( - origin, - pdu, - state=state, - auth_chain=auth_chain, - ) + yield self.handler.on_receive_pdu(origin, pdu, get_missing=True) def __str__(self): return "<ReplicationLayer(%s)>" % self.server_name - def event_from_pdu_json(self, pdu_json, outlier=False): - event = FrozenEvent( - pdu_json - ) - - event.internal_metadata.outlier = outlier - - return event - @defer.inlineCallbacks def exchange_third_party_invite( self, @@ -701,3 +586,66 @@ class FederationServer(FederationBase): origin, room_id, event_dict ) defer.returnValue(ret) + + +class FederationHandlerRegistry(object): + """Allows classes to register themselves as handlers for a given EDU or + query type for incoming federation traffic. + """ + def __init__(self): + self.edu_handlers = {} + self.query_handlers = {} + + def register_edu_handler(self, edu_type, handler): + """Sets the handler callable that will be used to handle an incoming + federation EDU of the given type. 
+ + Args: + edu_type (str): The type of the incoming EDU to register handler for + handler (Callable[[str, dict]]): A callable invoked on incoming EDU + of the given type. The arguments are the origin server name and + the EDU contents. + """ + if edu_type in self.edu_handlers: + raise KeyError("Already have an EDU handler for %s" % (edu_type,)) + + self.edu_handlers[edu_type] = handler + + def register_query_handler(self, query_type, handler): + """Sets the handler callable that will be used to handle an incoming + federation query of the given type. + + Args: + query_type (str): Category name of the query, which should match + the string used by make_query. + handler (Callable[[dict], Deferred[dict]]): Invoked to handle + incoming queries of this type. The return will be yielded + on and the result used as the response to the query request. + """ + if query_type in self.query_handlers: + raise KeyError( + "Already have a Query handler for %s" % (query_type,) + ) + + self.query_handlers[query_type] = handler + + @defer.inlineCallbacks + def on_edu(self, edu_type, origin, content): + handler = self.edu_handlers.get(edu_type) + if not handler: + logger.warn("No handler registered for EDU type %s", edu_type) + + try: + yield handler(origin, content) + except SynapseError as e: + logger.info("Failed to handle edu %r: %r", edu_type, e) + except Exception as e: + logger.exception("Failed to handle edu %r", edu_type) + + def on_query(self, query_type, args): + handler = self.query_handlers.get(query_type) + if not handler: + logger.warn("No handler registered for query type %s", query_type) + raise NotFoundError("No handler for Query type '%s'" % (query_type,)) + + return handler(args) diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py deleted file mode 100644 index 62d865ec4b..0000000000 --- a/synapse/federation/replication.py +++ /dev/null @@ -1,73 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed 
under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This layer is responsible for replicating with remote home servers using -a given transport. -""" - -from .federation_client import FederationClient -from .federation_server import FederationServer - -from .persistence import TransactionActions - -import logging - - -logger = logging.getLogger(__name__) - - -class ReplicationLayer(FederationClient, FederationServer): - """This layer is responsible for replicating with remote home servers over - the given transport. I.e., does the sending and receiving of PDUs to - remote home servers. - - The layer communicates with the rest of the server via a registered - ReplicationHandler. - - In more detail, the layer: - * Receives incoming data and processes it into transactions and pdus. - * Fetches any PDUs it thinks it might have missed. - * Keeps the current state for contexts up to date by applying the - suitable conflict resolution. - * Sends outgoing pdus wrapped in transactions. - * Fills out the references to previous pdus/transactions appropriately - for outgoing data. 
- """ - - def __init__(self, hs, transport_layer): - self.server_name = hs.hostname - - self.keyring = hs.get_keyring() - - self.transport_layer = transport_layer - - self.federation_client = self - - self.store = hs.get_datastore() - - self.handler = None - self.edu_handlers = {} - self.query_handlers = {} - - self._clock = hs.get_clock() - - self.transaction_actions = TransactionActions(self.store) - - self.hs = hs - - super(ReplicationLayer, self).__init__(hs) - - def __str__(self): - return "<ReplicationLayer(%s)>" % self.server_name diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 5c9f7a86f0..0f0c687b37 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -31,21 +31,21 @@ Events are replicated via a separate events stream. from .units import Edu +from synapse.storage.presence import UserPresenceState from synapse.util.metrics import Measure import synapse.metrics from blist import sorteddict -import ujson +from collections import namedtuple +import logging -metrics = synapse.metrics.get_metrics_for(__name__) +from six import itervalues, iteritems + +logger = logging.getLogger(__name__) -PRESENCE_TYPE = "p" -KEYED_EDU_TYPE = "k" -EDU_TYPE = "e" -FAILURE_TYPE = "f" -DEVICE_MESSAGE_TYPE = "d" +metrics = synapse.metrics.get_metrics_for(__name__) class FederationRemoteSendQueue(object): @@ -54,18 +54,20 @@ class FederationRemoteSendQueue(object): def __init__(self, hs): self.server_name = hs.hostname self.clock = hs.get_clock() + self.notifier = hs.get_notifier() + self.is_mine_id = hs.is_mine_id - self.presence_map = {} - self.presence_changed = sorteddict() + self.presence_map = {} # Pending presence map user_id -> UserPresenceState + self.presence_changed = sorteddict() # Stream position -> user_id - self.keyed_edu = {} - self.keyed_edu_changed = sorteddict() + self.keyed_edu = {} # (destination, key) -> EDU + self.keyed_edu_changed = sorteddict() # stream position -> (destination, key) - 
self.edus = sorteddict() + self.edus = sorteddict() # stream position -> Edu - self.failures = sorteddict() + self.failures = sorteddict() # stream position -> (destination, Failure) - self.device_messages = sorteddict() + self.device_messages = sorteddict() # stream position -> destination self.pos = 1 self.pos_time = sorteddict() @@ -121,7 +123,9 @@ class FederationRemoteSendQueue(object): del self.presence_changed[key] user_ids = set( - user_id for uids in self.presence_changed.values() for _, user_id in uids + user_id + for uids in itervalues(self.presence_changed) + for user_id in uids ) to_del = [ @@ -186,37 +190,50 @@ class FederationRemoteSendQueue(object): else: self.edus[pos] = edu - def send_presence(self, destination, states): - """As per TransactionQueue""" + self.notifier.on_new_replication_data() + + def send_presence(self, states): + """As per TransactionQueue + + Args: + states (list(UserPresenceState)) + """ pos = self._next_pos() - self.presence_map.update({ - state.user_id: state - for state in states - }) + # We only want to send presence for our own users, so lets always just + # filter here just in case. 
+ local_states = filter(lambda s: self.is_mine_id(s.user_id), states) - self.presence_changed[pos] = [ - (destination, state.user_id) for state in states - ] + self.presence_map.update({state.user_id: state for state in local_states}) + self.presence_changed[pos] = [state.user_id for state in local_states] + + self.notifier.on_new_replication_data() def send_failure(self, failure, destination): """As per TransactionQueue""" pos = self._next_pos() self.failures[pos] = (destination, str(failure)) + self.notifier.on_new_replication_data() def send_device_messages(self, destination): """As per TransactionQueue""" pos = self._next_pos() self.device_messages[pos] = destination + self.notifier.on_new_replication_data() def get_current_token(self): return self.pos - 1 - def get_replication_rows(self, token, limit, federation_ack=None): - """ + def federation_ack(self, token): + self._clear_queue_before_pos(token) + + def get_replication_rows(self, from_token, to_token, limit, federation_ack=None): + """Get rows to be sent over federation between the two tokens + Args: - token (int) + from_token (int) + to_token(int) limit (int) federation_ack (int): Optional. The position where the worker is explicitly acknowledged it has handled. Allows us to drop @@ -225,9 +242,11 @@ class FederationRemoteSendQueue(object): # TODO: Handle limit. # To handle restarts where we wrap around - if token > self.pos: - token = -1 + if from_token > self.pos: + from_token = -1 + # list of tuple(int, BaseFederationRow), where the first is the position + # of the federation stream. 
rows = [] # There should be only one reader, so lets delete everything its @@ -237,62 +256,295 @@ class FederationRemoteSendQueue(object): # Fetch changed presence keys = self.presence_changed.keys() - i = keys.bisect_right(token) - dest_user_ids = set( - (pos, dest_user_id) - for pos in keys[i:] - for dest_user_id in self.presence_changed[pos] - ) + i = keys.bisect_right(from_token) + j = keys.bisect_right(to_token) + 1 + dest_user_ids = [ + (pos, user_id) + for pos in keys[i:j] + for user_id in self.presence_changed[pos] + ] - for (key, (dest, user_id)) in dest_user_ids: - rows.append((key, PRESENCE_TYPE, ujson.dumps({ - "destination": dest, - "state": self.presence_map[user_id].as_dict(), - }))) + for (key, user_id) in dest_user_ids: + rows.append((key, PresenceRow( + state=self.presence_map[user_id], + ))) # Fetch changes keyed edus keys = self.keyed_edu_changed.keys() - i = keys.bisect_right(token) - keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:]) - - for (pos, (destination, edu_key)) in keyed_edus: - rows.append( - (pos, KEYED_EDU_TYPE, ujson.dumps({ - "key": edu_key, - "edu": self.keyed_edu[(destination, edu_key)].get_internal_dict(), - })) - ) + i = keys.bisect_right(from_token) + j = keys.bisect_right(to_token) + 1 + # We purposefully clobber based on the key here, python dict comprehensions + # always use the last value, so this will correctly point to the last + # stream position. 
+ keyed_edus = {self.keyed_edu_changed[k]: k for k in keys[i:j]} + + for ((destination, edu_key), pos) in iteritems(keyed_edus): + rows.append((pos, KeyedEduRow( + key=edu_key, + edu=self.keyed_edu[(destination, edu_key)], + ))) # Fetch changed edus keys = self.edus.keys() - i = keys.bisect_right(token) - edus = set((k, self.edus[k]) for k in keys[i:]) + i = keys.bisect_right(from_token) + j = keys.bisect_right(to_token) + 1 + edus = ((k, self.edus[k]) for k in keys[i:j]) for (pos, edu) in edus: - rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_internal_dict()))) + rows.append((pos, EduRow(edu))) # Fetch changed failures keys = self.failures.keys() - i = keys.bisect_right(token) - failures = set((k, self.failures[k]) for k in keys[i:]) + i = keys.bisect_right(from_token) + j = keys.bisect_right(to_token) + 1 + failures = ((k, self.failures[k]) for k in keys[i:j]) for (pos, (destination, failure)) in failures: - rows.append((pos, FAILURE_TYPE, ujson.dumps({ - "destination": destination, - "failure": failure, - }))) + rows.append((pos, FailureRow( + destination=destination, + failure=failure, + ))) # Fetch changed device messages keys = self.device_messages.keys() - i = keys.bisect_right(token) - device_messages = set((k, self.device_messages[k]) for k in keys[i:]) + i = keys.bisect_right(from_token) + j = keys.bisect_right(to_token) + 1 + device_messages = {self.device_messages[k]: k for k in keys[i:j]} - for (pos, destination) in device_messages: - rows.append((pos, DEVICE_MESSAGE_TYPE, ujson.dumps({ - "destination": destination, - }))) + for (destination, pos) in iteritems(device_messages): + rows.append((pos, DeviceRow( + destination=destination, + ))) # Sort rows based on pos rows.sort() - return rows + return [(pos, row.TypeId, row.to_data()) for pos, row in rows] + + +class BaseFederationRow(object): + """Base class for rows to be sent in the federation stream. + + Specifies how to identify, serialize and deserialize the different types. 
+ """ + + TypeId = None # Unique string that ids the type. Must be overriden in sub classes. + + @staticmethod + def from_data(data): + """Parse the data from the federation stream into a row. + + Args: + data: The value of ``data`` from FederationStreamRow.data, type + depends on the type of stream + """ + raise NotImplementedError() + + def to_data(self): + """Serialize this row to be sent over the federation stream. + + Returns: + The value to be sent in FederationStreamRow.data. The type depends + on the type of stream. + """ + raise NotImplementedError() + + def add_to_buffer(self, buff): + """Add this row to the appropriate field in the buffer ready for this + to be sent over federation. + + We use a buffer so that we can batch up events that have come in at + the same time and send them all at once. + + Args: + buff (BufferedToSend) + """ + raise NotImplementedError() + + +class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", ( + "state", # UserPresenceState +))): + TypeId = "p" + + @staticmethod + def from_data(data): + return PresenceRow( + state=UserPresenceState.from_dict(data) + ) + + def to_data(self): + return self.state.as_dict() + + def add_to_buffer(self, buff): + buff.presence.append(self.state) + + +class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", ( + "key", # tuple(str) - the edu key passed to send_edu + "edu", # Edu +))): + """Streams EDUs that have an associated key that is ued to clobber. For example, + typing EDUs clobber based on room_id. + """ + + TypeId = "k" + + @staticmethod + def from_data(data): + return KeyedEduRow( + key=tuple(data["key"]), + edu=Edu(**data["edu"]), + ) + + def to_data(self): + return { + "key": self.key, + "edu": self.edu.get_internal_dict(), + } + + def add_to_buffer(self, buff): + buff.keyed_edus.setdefault( + self.edu.destination, {} + )[self.key] = self.edu + + +class EduRow(BaseFederationRow, namedtuple("EduRow", ( + "edu", # Edu +))): + """Streams EDUs that don't have keys. 
See KeyedEduRow + """ + TypeId = "e" + + @staticmethod + def from_data(data): + return EduRow(Edu(**data)) + + def to_data(self): + return self.edu.get_internal_dict() + + def add_to_buffer(self, buff): + buff.edus.setdefault(self.edu.destination, []).append(self.edu) + + +class FailureRow(BaseFederationRow, namedtuple("FailureRow", ( + "destination", # str + "failure", +))): + """Streams failures to a remote server. Failures are issued when there was + something wrong with a transaction the remote sent us, e.g. it included + an event that was invalid. + """ + + TypeId = "f" + + @staticmethod + def from_data(data): + return FailureRow( + destination=data["destination"], + failure=data["failure"], + ) + + def to_data(self): + return { + "destination": self.destination, + "failure": self.failure, + } + + def add_to_buffer(self, buff): + buff.failures.setdefault(self.destination, []).append(self.failure) + + +class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ( + "destination", # str +))): + """Streams the fact that either a) there is pending to device messages for + users on the remote, or b) a local users device has changed and needs to + be sent to the remote. 
+ """ + TypeId = "d" + + @staticmethod + def from_data(data): + return DeviceRow(destination=data["destination"]) + + def to_data(self): + return {"destination": self.destination} + + def add_to_buffer(self, buff): + buff.device_destinations.add(self.destination) + + +TypeToRow = { + Row.TypeId: Row + for Row in ( + PresenceRow, + KeyedEduRow, + EduRow, + FailureRow, + DeviceRow, + ) +} + + +ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", ( + "presence", # list(UserPresenceState) + "keyed_edus", # dict of destination -> { key -> Edu } + "edus", # dict of destination -> [Edu] + "failures", # dict of destination -> [failures] + "device_destinations", # set of destinations +)) + + +def process_rows_for_federation(transaction_queue, rows): + """Parse a list of rows from the federation stream and put them in the + transaction queue ready for sending to the relevant homeservers. + + Args: + transaction_queue (TransactionQueue) + rows (list(synapse.replication.tcp.streams.FederationStreamRow)) + """ + + # The federation stream contains a bunch of different types of + # rows that need to be handled differently. We parse the rows, put + # them into the appropriate collection and then send them off. 
+ + buff = ParsedFederationStreamData( + presence=[], + keyed_edus={}, + edus={}, + failures={}, + device_destinations=set(), + ) + + # Parse the rows in the stream and add to the buffer + for row in rows: + if row.type not in TypeToRow: + logger.error("Unrecognized federation row type %r", row.type) + continue + + RowType = TypeToRow[row.type] + parsed_row = RowType.from_data(row.data) + parsed_row.add_to_buffer(buff) + + if buff.presence: + transaction_queue.send_presence(buff.presence) + + for destination, edu_map in iteritems(buff.keyed_edus): + for key, edu in edu_map.items(): + transaction_queue.send_edu( + edu.destination, edu.edu_type, edu.content, key=key, + ) + + for destination, edu_list in iteritems(buff.edus): + for edu in edu_list: + transaction_queue.send_edu( + edu.destination, edu.edu_type, edu.content, key=None, + ) + + for destination, failure_list in iteritems(buff.failures): + for failure in failure_list: + transaction_queue.send_failure(destination, failure) + + for destination in buff.device_destinations: + transaction_queue.send_device_messages(destination) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index bb3d9258a6..ded2b1871a 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -12,22 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +import datetime from twisted.internet import defer from .persistence import TransactionActions from .units import Transaction, Edu -from synapse.api.errors import HttpResponseException +from synapse.api.errors import HttpResponseException, FederationDeniedError +from synapse.util import logcontext, PreserveLoggingContext from synapse.util.async import run_on_reactor -from synapse.util.logcontext import preserve_context_over_fn -from synapse.util.retryutils import ( - get_retry_limiter, NotRetryingDestination, -) +from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from synapse.util.metrics import measure_func -from synapse.types import get_domain_from_id -from synapse.handlers.presence import format_user_presence_state +from synapse.handlers.presence import format_user_presence_state, get_interested_remotes import synapse.metrics import logging @@ -43,6 +40,10 @@ sent_pdus_destination_dist = client_metrics.register_distribution( ) sent_edus_counter = client_metrics.register_counter("sent_edus") +sent_transactions_counter = client_metrics.register_counter("sent_transactions") + +events_processed_counter = client_metrics.register_counter("events_processed") + class TransactionQueue(object): """This class makes sure we only have one transaction in flight at @@ -79,8 +80,18 @@ class TransactionQueue(object): # destination -> list of tuple(edu, deferred) self.pending_edus_by_dest = edus = {} - # Presence needs to be separate as we send single aggragate EDUs + # Map of user_id -> UserPresenceState for all the pending presence + # to be sent out by user_id. Entries here get processed and put in + # pending_presence_by_dest + self.pending_presence = {} + + # Map of destination -> user_id -> UserPresenceState of pending presence + # to be sent to each destinations self.pending_presence_by_dest = presence = {} + + # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered + # based on their key (e.g. 
typing events by room_id) + # Map of destination -> (edu_type, key) -> Edu self.pending_edus_keyed_by_dest = edus_keyed = {} metrics.register_callback( @@ -99,7 +110,12 @@ class TransactionQueue(object): # destination -> list of tuple(failure, deferred) self.pending_failures_by_dest = {} + # destination -> stream_id of last successfully sent to-device message. + # NB: may be a long or an int. self.last_device_stream_id_by_dest = {} + + # destination -> stream_id of last successfully sent device list + # update. self.last_device_list_stream_id_by_dest = {} # HACK to get unique tx id @@ -110,6 +126,8 @@ class TransactionQueue(object): self._is_processing = False self._last_poked_id = -1 + self._processing_pending_presence = False + def can_send_to(self, destination): """Can we send messages to the given server? @@ -130,7 +148,6 @@ class TransactionQueue(object): else: return not destination.startswith("localhost") - @defer.inlineCallbacks def notify_new_events(self, current_id): """This gets called when we have some new events we might want to send out to other servers. @@ -140,12 +157,19 @@ class TransactionQueue(object): if self._is_processing: return + # fire off a processing loop in the background. It's likely it will + # outlast the current request, so run it in the sentinel logcontext. + with PreserveLoggingContext(): + self._process_event_queue_loop() + + @defer.inlineCallbacks + def _process_event_queue_loop(self): try: self._is_processing = True while True: last_token = yield self.store.get_federation_out_pos("events") next_token, events = yield self.store.get_all_new_events_stream( - last_token, self._last_poked_id, limit=20, + last_token, self._last_poked_id, limit=100, ) logger.debug("Handling %s -> %s", last_token, next_token) @@ -153,28 +177,35 @@ class TransactionQueue(object): if not events and next_token >= self._last_poked_id: break - for event in events: + @defer.inlineCallbacks + def handle_event(event): # Only send events for this server. 
send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of() is_mine = self.is_mine_id(event.event_id) if not is_mine and send_on_behalf_of is None: - continue - - # Get the state from before the event. - # We need to make sure that this is the state from before - # the event and not from after it. - # Otherwise if the last member on a server in a room is - # banned then it won't receive the event because it won't - # be in the room after the ban. - users_in_room = yield self.state.get_current_user_in_room( - event.room_id, latest_event_ids=[ - prev_id for prev_id, _ in event.prev_events - ], - ) + return + + try: + # Get the state from before the event. + # We need to make sure that this is the state from before + # the event and not from after it. + # Otherwise if the last member on a server in a room is + # banned then it won't receive the event because it won't + # be in the room after the ban. + destinations = yield self.state.get_current_hosts_in_room( + event.room_id, latest_event_ids=[ + prev_id for prev_id, _ in event.prev_events + ], + ) + except Exception: + logger.exception( + "Failed to calculate hosts in room for event: %s", + event.event_id, + ) + return + + destinations = set(destinations) - destinations = set( - get_domain_from_id(user_id) for user_id in users_in_room - ) if send_on_behalf_of is not None: # If we are sending the event on behalf of another server # then it already has the event and there is no reason to @@ -185,10 +216,44 @@ class TransactionQueue(object): self._send_pdu(event, destinations) + @defer.inlineCallbacks + def handle_room_events(events): + for event in events: + yield handle_event(event) + + events_by_room = {} + for event in events: + events_by_room.setdefault(event.room_id, []).append(event) + + yield logcontext.make_deferred_yieldable(defer.gatherResults( + [ + logcontext.run_in_background(handle_room_events, evs) + for evs in events_by_room.itervalues() + ], + consumeErrors=True + )) + yield 
self.store.update_federation_out_pos( "events", next_token ) + if events: + now = self.clock.time_msec() + ts = yield self.store.get_received_ts(events[-1].event_id) + + synapse.metrics.event_processing_lag.set( + now - ts, "federation_sender", + ) + synapse.metrics.event_processing_last_ts.set( + ts, "federation_sender", + ) + + events_processed_counter.inc_by(len(events)) + + synapse.metrics.event_processing_positions.set( + next_token, "federation_sender", + ) + finally: self._is_processing = False @@ -217,21 +282,75 @@ class TransactionQueue(object): (pdu, order) ) - preserve_context_over_fn( - self._attempt_new_transaction, destination - ) + self._attempt_new_transaction(destination) - def send_presence(self, destination, states): - if not self.can_send_to(destination): - return + @logcontext.preserve_fn # the caller should not yield on this + @defer.inlineCallbacks + def send_presence(self, states): + """Send the new presence states to the appropriate destinations. + + This actually queues up the presence states ready for sending and + triggers a background task to process them and send out the transactions. + + Args: + states (list(UserPresenceState)) + """ - self.pending_presence_by_dest.setdefault(destination, {}).update({ + # First we queue up the new presence by user ID, so multiple presence + # updates in quick successtion are correctly handled + # We only want to send presence for our own users, so lets always just + # filter here just in case. + self.pending_presence.update({ state.user_id: state for state in states + if self.is_mine_id(state.user_id) }) - preserve_context_over_fn( - self._attempt_new_transaction, destination - ) + # We then handle the new pending presence in batches, first figuring + # out the destinations we need to send each state to and then poking it + # to attempt a new transaction. 
We linearize this so that we don't + # accidentally mess up the ordering and send multiple presence updates + # in the wrong order + if self._processing_pending_presence: + return + + self._processing_pending_presence = True + try: + while True: + states_map = self.pending_presence + self.pending_presence = {} + + if not states_map: + break + + yield self._process_presence_inner(states_map.values()) + except Exception: + logger.exception("Error sending presence states to servers") + finally: + self._processing_pending_presence = False + + @measure_func("txnqueue._process_presence") + @defer.inlineCallbacks + def _process_presence_inner(self, states): + """Given a list of states populate self.pending_presence_by_dest and + poke to send a new transaction to each destination + + Args: + states (list(UserPresenceState)) + """ + hosts_and_states = yield get_interested_remotes(self.store, states, self.state) + + for destinations, states in hosts_and_states: + for destination in destinations: + if not self.can_send_to(destination): + continue + + self.pending_presence_by_dest.setdefault( + destination, {} + ).update({ + state.user_id: state for state in states + }) + + self._attempt_new_transaction(destination) def send_edu(self, destination, edu_type, content, key=None): edu = Edu( @@ -253,9 +372,7 @@ class TransactionQueue(object): else: self.pending_edus_by_dest.setdefault(destination, []).append(edu) - preserve_context_over_fn( - self._attempt_new_transaction, destination - ) + self._attempt_new_transaction(destination) def send_failure(self, failure, destination): if destination == self.server_name or destination == "localhost": @@ -268,9 +385,7 @@ class TransactionQueue(object): destination, [] ).append(failure) - preserve_context_over_fn( - self._attempt_new_transaction, destination - ) + self._attempt_new_transaction(destination) def send_device_messages(self, destination): if destination == self.server_name or destination == "localhost": @@ -279,15 +394,24 @@ 
class TransactionQueue(object): if not self.can_send_to(destination): return - preserve_context_over_fn( - self._attempt_new_transaction, destination - ) + self._attempt_new_transaction(destination) def get_current_token(self): return 0 - @defer.inlineCallbacks def _attempt_new_transaction(self, destination): + """Try to start a new transaction to this destination + + If there is already a transaction in progress to this destination, + returns immediately. Otherwise kicks off the process of sending a + transaction in the background. + + Args: + destination (str): + + Returns: + None + """ # list of (pending_pdu, deferred, order) if destination in self.pending_transactions: # XXX: pending_transactions can get stuck on by a never-ending @@ -300,12 +424,46 @@ class TransactionQueue(object): ) return + logger.debug("TX [%s] Starting transaction loop", destination) + + # Drop the logcontext before starting the transaction. It doesn't + # really make sense to log all the outbound transactions against + # whatever path led us to this point: that's pretty arbitrary really. + # + # (this also means we can fire off _perform_transaction without + # yielding) + with logcontext.PreserveLoggingContext(): + self._transaction_transmission_loop(destination) + + @defer.inlineCallbacks + def _transaction_transmission_loop(self, destination): + pending_pdus = [] try: self.pending_transactions[destination] = 1 + # This will throw if we wouldn't retry. We do this here so we fail + # quickly, but we will later check this again in the http client, + # hence why we throw the result away. + yield get_retry_limiter(destination, self.clock, self.store) + + # XXX: what's this for? 
yield run_on_reactor() + pending_pdus = [] while True: + device_message_edus, device_stream_id, dev_list_id = ( + yield self._get_new_device_messages(destination) + ) + + # BEGIN CRITICAL SECTION + # + # In order to avoid a race condition, we need to make sure that + # the following code (from popping the queues up to the point + # where we decide if we actually have any pending messages) is + # atomic - otherwise new PDUs or EDUs might arrive in the + # meantime, but not get sent because we hold the + # pending_transactions flag. + pending_pdus = self.pending_pdus_by_dest.pop(destination, []) pending_edus = self.pending_edus_by_dest.pop(destination, []) pending_presence = self.pending_presence_by_dest.pop(destination, {}) @@ -315,17 +473,6 @@ class TransactionQueue(object): self.pending_edus_keyed_by_dest.pop(destination, {}).values() ) - limiter = yield get_retry_limiter( - destination, - self.clock, - self.store, - backoff_on_404=True, # If we get a 404 the other side has gone - ) - - device_message_edus, device_stream_id, dev_list_id = ( - yield self._get_new_device_messages(destination) - ) - pending_edus.extend(device_message_edus) if pending_presence: pending_edus.append( @@ -355,11 +502,13 @@ class TransactionQueue(object): ) return + # END CRITICAL SECTION + success = yield self._send_new_transaction( destination, pending_pdus, pending_edus, pending_failures, - limiter=limiter, ) if success: + sent_transactions_counter.inc() # Remove the acknowledged device messages from the database # Only bother if we actually sent some device messages if device_message_edus: @@ -375,12 +524,26 @@ class TransactionQueue(object): self.last_device_list_stream_id_by_dest[destination] = dev_list_id else: break - except NotRetryingDestination: + except NotRetryingDestination as e: logger.debug( - "TX [%s] not ready for retry yet - " + "TX [%s] not ready for retry yet (next retry at %s) - " "dropping transaction for now", destination, + datetime.datetime.fromtimestamp( + 
(e.retry_last_ts + e.retry_interval) / 1000.0 + ), ) + except FederationDeniedError as e: + logger.info(e) + except Exception as e: + logger.warn( + "TX [%s] Failed to send transaction: %s", + destination, + e, + ) + for p, _ in pending_pdus: + logger.info("Failed to send event %s to %s", p.event_id, + destination) finally: # We want to be *very* sure we delete this after we stop processing self.pending_transactions.pop(destination, None) @@ -420,7 +583,7 @@ class TransactionQueue(object): @measure_func("_send_new_transaction") @defer.inlineCallbacks def _send_new_transaction(self, destination, pending_pdus, pending_edus, - pending_failures, limiter): + pending_failures): # Sort based on the order field pending_pdus.sort(key=lambda t: t[1]) @@ -430,132 +593,104 @@ class TransactionQueue(object): success = True - try: - logger.debug("TX [%s] _attempt_new_transaction", destination) + logger.debug("TX [%s] _attempt_new_transaction", destination) - txn_id = str(self._next_txn_id) + txn_id = str(self._next_txn_id) - logger.debug( - "TX [%s] {%s} Attempting new transaction" - " (pdus: %d, edus: %d, failures: %d)", - destination, txn_id, - len(pdus), - len(edus), - len(failures) - ) + logger.debug( + "TX [%s] {%s} Attempting new transaction" + " (pdus: %d, edus: %d, failures: %d)", + destination, txn_id, + len(pdus), + len(edus), + len(failures) + ) - logger.debug("TX [%s] Persisting transaction...", destination) + logger.debug("TX [%s] Persisting transaction...", destination) - transaction = Transaction.create_new( - origin_server_ts=int(self.clock.time_msec()), - transaction_id=txn_id, - origin=self.server_name, - destination=destination, - pdus=pdus, - edus=edus, - pdu_failures=failures, - ) + transaction = Transaction.create_new( + origin_server_ts=int(self.clock.time_msec()), + transaction_id=txn_id, + origin=self.server_name, + destination=destination, + pdus=pdus, + edus=edus, + pdu_failures=failures, + ) - self._next_txn_id += 1 + self._next_txn_id += 1 - yield 
self.transaction_actions.prepare_to_send(transaction) + yield self.transaction_actions.prepare_to_send(transaction) - logger.debug("TX [%s] Persisted transaction", destination) - logger.info( - "TX [%s] {%s} Sending transaction [%s]," - " (PDUs: %d, EDUs: %d, failures: %d)", - destination, txn_id, - transaction.transaction_id, - len(pdus), - len(edus), - len(failures), - ) + logger.debug("TX [%s] Persisted transaction", destination) + logger.info( + "TX [%s] {%s} Sending transaction [%s]," + " (PDUs: %d, EDUs: %d, failures: %d)", + destination, txn_id, + transaction.transaction_id, + len(pdus), + len(edus), + len(failures), + ) - with limiter: - # Actually send the transaction - - # FIXME (erikj): This is a bit of a hack to make the Pdu age - # keys work - def json_data_cb(): - data = transaction.get_dict() - now = int(self.clock.time_msec()) - if "pdus" in data: - for p in data["pdus"]: - if "age_ts" in p: - unsigned = p.setdefault("unsigned", {}) - unsigned["age"] = now - int(p["age_ts"]) - del p["age_ts"] - return data - - try: - response = yield self.transport_layer.send_transaction( - transaction, json_data_cb - ) - code = 200 - - if response: - for e_id, r in response.get("pdus", {}).items(): - if "error" in r: - logger.warn( - "Transaction returned error for %s: %s", - e_id, r, - ) - except HttpResponseException as e: - code = e.code - response = e.response - - if e.code in (401, 404, 429) or 500 <= e.code: - logger.info( - "TX [%s] {%s} got %d response", - destination, txn_id, code + # Actually send the transaction + + # FIXME (erikj): This is a bit of a hack to make the Pdu age + # keys work + def json_data_cb(): + data = transaction.get_dict() + now = int(self.clock.time_msec()) + if "pdus" in data: + for p in data["pdus"]: + if "age_ts" in p: + unsigned = p.setdefault("unsigned", {}) + unsigned["age"] = now - int(p["age_ts"]) + del p["age_ts"] + return data + + try: + response = yield self.transport_layer.send_transaction( + transaction, json_data_cb + ) 
+ code = 200 + + if response: + for e_id, r in response.get("pdus", {}).items(): + if "error" in r: + logger.warn( + "Transaction returned error for %s: %s", + e_id, r, ) - raise e + except HttpResponseException as e: + code = e.code + response = e.response + if e.code in (401, 404, 429) or 500 <= e.code: logger.info( "TX [%s] {%s} got %d response", destination, txn_id, code ) + raise e - logger.debug("TX [%s] Sent transaction", destination) - logger.debug("TX [%s] Marking as delivered...", destination) - - yield self.transaction_actions.delivered( - transaction, code, response - ) + logger.info( + "TX [%s] {%s} got %d response", + destination, txn_id, code + ) - logger.debug("TX [%s] Marked as delivered", destination) + logger.debug("TX [%s] Sent transaction", destination) + logger.debug("TX [%s] Marking as delivered...", destination) - if code != 200: - for p in pdus: - logger.info( - "Failed to send event %s to %s", p.event_id, destination - ) - success = False - except RuntimeError as e: - # We capture this here as there as nothing actually listens - # for this finishing functions deferred. - logger.warn( - "TX [%s] Problem in _attempt_transaction: %s", - destination, - e, - ) + yield self.transaction_actions.delivered( + transaction, code, response + ) - success = False + logger.debug("TX [%s] Marked as delivered", destination) + if code != 200: for p in pdus: - logger.info("Failed to send event %s to %s", p.event_id, destination) - except Exception as e: - # We capture this here as there as nothing actually listens - # for this finishing functions deferred. 
- logger.warn( - "TX [%s] Problem in _attempt_transaction: %s", - destination, - e, - ) - + logger.info( + "Failed to send event %s to %s", p.event_id, destination + ) success = False - for p in pdus: - logger.info("Failed to send event %s to %s", p.event_id, destination) - defer.returnValue(success) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index f49e8a2cc4..6db8efa6dd 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,6 +21,7 @@ from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.util.logutils import log_function import logging +import urllib logger = logging.getLogger(__name__) @@ -49,7 +51,7 @@ class TransportLayerClient(object): logger.debug("get_room_state dest=%s, room=%s", destination, room_id) - path = PREFIX + "/state/%s/" % room_id + path = _create_path(PREFIX, "/state/%s/", room_id) return self.client.get_json( destination, path=path, args={"event_id": event_id}, ) @@ -71,7 +73,7 @@ class TransportLayerClient(object): logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id) - path = PREFIX + "/state_ids/%s/" % room_id + path = _create_path(PREFIX, "/state_ids/%s/", room_id) return self.client.get_json( destination, path=path, args={"event_id": event_id}, ) @@ -93,7 +95,7 @@ class TransportLayerClient(object): logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id) - path = PREFIX + "/event/%s/" % (event_id, ) + path = _create_path(PREFIX, "/event/%s/", event_id) return self.client.get_json(destination, path=path, timeout=timeout) @log_function @@ -119,7 +121,7 @@ class TransportLayerClient(object): # TODO: raise? 
return - path = PREFIX + "/backfill/%s/" % (room_id,) + path = _create_path(PREFIX, "/backfill/%s/", room_id) args = { "v": event_tuples, @@ -157,12 +159,15 @@ class TransportLayerClient(object): # generated by the json_data_callback. json_data = transaction.get_dict() + path = _create_path(PREFIX, "/send/%s/", transaction.transaction_id) + response = yield self.client.put_json( transaction.destination, - path=PREFIX + "/send/%s/" % transaction.transaction_id, + path=path, data=json_data, json_data_callback=json_data_callback, long_retries=True, + backoff_on_404=True, # If we get a 404 the other side has gone ) logger.debug( @@ -174,8 +179,9 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function - def make_query(self, destination, query_type, args, retry_on_dns_fail): - path = PREFIX + "/query/%s" % query_type + def make_query(self, destination, query_type, args, retry_on_dns_fail, + ignore_backoff=False): + path = _create_path(PREFIX, "/query/%s", query_type) content = yield self.client.get_json( destination=destination, @@ -183,6 +189,7 @@ class TransportLayerClient(object): args=args, retry_on_dns_fail=retry_on_dns_fail, timeout=10000, + ignore_backoff=ignore_backoff, ) defer.returnValue(content) @@ -190,19 +197,54 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def make_membership_event(self, destination, room_id, user_id, membership): + """Asks a remote server to build and sign us a membership event + + Note that this does not append any events to any graphs. + + Args: + destination (str): address of remote homeserver + room_id (str): room to join/leave + user_id (str): user to be joined/left + membership (str): one of join/leave + + Returns: + Deferred: Succeeds when we get a 2xx HTTP response. The result + will be the decoded JSON body (ie, the new event). + + Fails with ``HTTPRequestException`` if we get an HTTP response + code >= 300. 
+ + Fails with ``NotRetryingDestination`` if we are not yet ready + to retry this server. + + Fails with ``FederationDeniedError`` if the remote destination + is not in our federation whitelist + """ valid_memberships = {Membership.JOIN, Membership.LEAVE} if membership not in valid_memberships: raise RuntimeError( "make_membership_event called with membership='%s', must be one of %s" % (membership, ",".join(valid_memberships)) ) - path = PREFIX + "/make_%s/%s/%s" % (membership, room_id, user_id) + path = _create_path(PREFIX, "/make_%s/%s/%s", membership, room_id, user_id) + + ignore_backoff = False + retry_on_dns_fail = False + + if membership == Membership.LEAVE: + # we particularly want to do our best to send leave events. The + # problem is that if it fails, we won't retry it later, so if the + # remote server was just having a momentary blip, the room will be + # out of sync. + ignore_backoff = True + retry_on_dns_fail = True content = yield self.client.get_json( destination=destination, path=path, - retry_on_dns_fail=False, + retry_on_dns_fail=retry_on_dns_fail, timeout=20000, + ignore_backoff=ignore_backoff, ) defer.returnValue(content) @@ -210,7 +252,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def send_join(self, destination, room_id, event_id, content): - path = PREFIX + "/send_join/%s/%s" % (room_id, event_id) + path = _create_path(PREFIX, "/send_join/%s/%s", room_id, event_id) response = yield self.client.put_json( destination=destination, @@ -223,12 +265,18 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def send_leave(self, destination, room_id, event_id, content): - path = PREFIX + "/send_leave/%s/%s" % (room_id, event_id) + path = _create_path(PREFIX, "/send_leave/%s/%s", room_id, event_id) response = yield self.client.put_json( destination=destination, path=path, data=content, + + # we want to do our best to send this through. 
The problem is + # that if it fails, we won't retry it later, so if the remote + # server was just having a momentary blip, the room will be out of + # sync. + ignore_backoff=True, ) defer.returnValue(response) @@ -236,12 +284,13 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def send_invite(self, destination, room_id, event_id, content): - path = PREFIX + "/invite/%s/%s" % (room_id, event_id) + path = _create_path(PREFIX, "/invite/%s/%s", room_id, event_id) response = yield self.client.put_json( destination=destination, path=path, data=content, + ignore_backoff=True, ) defer.returnValue(response) @@ -269,6 +318,7 @@ class TransportLayerClient(object): destination=remote_server, path=path, args=args, + ignore_backoff=True, ) defer.returnValue(response) @@ -276,7 +326,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def exchange_third_party_invite(self, destination, room_id, event_dict): - path = PREFIX + "/exchange_third_party_invite/%s" % (room_id,) + path = _create_path(PREFIX, "/exchange_third_party_invite/%s", room_id,) response = yield self.client.put_json( destination=destination, @@ -289,7 +339,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def get_event_auth(self, destination, room_id, event_id): - path = PREFIX + "/event_auth/%s/%s" % (room_id, event_id) + path = _create_path(PREFIX, "/event_auth/%s/%s", room_id, event_id) content = yield self.client.get_json( destination=destination, @@ -301,7 +351,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def send_query_auth(self, destination, room_id, event_id, content): - path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id) + path = _create_path(PREFIX, "/query_auth/%s/%s", room_id, event_id) content = yield self.client.post_json( destination=destination, @@ -363,7 +413,7 @@ class TransportLayerClient(object): Returns: A dict containg the device keys. 
""" - path = PREFIX + "/user/devices/" + user_id + path = _create_path(PREFIX, "/user/devices/%s", user_id) content = yield self.client.get_json( destination=destination, @@ -413,7 +463,7 @@ class TransportLayerClient(object): @log_function def get_missing_events(self, destination, room_id, earliest_events, latest_events, limit, min_depth, timeout): - path = PREFIX + "/get_missing_events/%s" % (room_id,) + path = _create_path(PREFIX, "/get_missing_events/%s", room_id,) content = yield self.client.post_json( destination=destination, @@ -428,3 +478,475 @@ class TransportLayerClient(object): ) defer.returnValue(content) + + @log_function + def get_group_profile(self, destination, group_id, requester_user_id): + """Get a group profile + """ + path = _create_path(PREFIX, "/groups/%s/profile", group_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def update_group_profile(self, destination, group_id, requester_user_id, content): + """Update a remote group profile + + Args: + destination (str) + group_id (str) + requester_user_id (str) + content (dict): The new profile of the group + """ + path = _create_path(PREFIX, "/groups/%s/profile", group_id,) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + @log_function + def get_group_summary(self, destination, group_id, requester_user_id): + """Get a group summary + """ + path = _create_path(PREFIX, "/groups/%s/summary", group_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def get_rooms_in_group(self, destination, group_id, requester_user_id): + """Get all rooms in a group + """ + path = _create_path(PREFIX, "/groups/%s/rooms", group_id,) + + return 
self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + def add_room_to_group(self, destination, group_id, requester_user_id, room_id, + content): + """Add a room to a group + """ + path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + def update_room_in_group(self, destination, group_id, requester_user_id, room_id, + config_key, content): + """Update room in group + """ + path = _create_path( + PREFIX, "/groups/%s/room/%s/config/%s", + group_id, room_id, config_key, + ) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + def remove_room_from_group(self, destination, group_id, requester_user_id, room_id): + """Remove a room from a group + """ + path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,) + + return self.client.delete_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def get_users_in_group(self, destination, group_id, requester_user_id): + """Get users in a group + """ + path = _create_path(PREFIX, "/groups/%s/users", group_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def get_invited_users_in_group(self, destination, group_id, requester_user_id): + """Get users that have been invited to a group + """ + path = _create_path(PREFIX, "/groups/%s/invited_users", group_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def 
accept_group_invite(self, destination, group_id, user_id, content): + """Accept a group invite + """ + path = _create_path( + PREFIX, "/groups/%s/users/%s/accept_invite", + group_id, user_id, + ) + + return self.client.post_json( + destination=destination, + path=path, + data=content, + ignore_backoff=True, + ) + + @log_function + def join_group(self, destination, group_id, user_id, content): + """Attempts to join a group + """ + path = _create_path(PREFIX, "/groups/%s/users/%s/join", group_id, user_id) + + return self.client.post_json( + destination=destination, + path=path, + data=content, + ignore_backoff=True, + ) + + @log_function + def invite_to_group(self, destination, group_id, user_id, requester_user_id, content): + """Invite a user to a group + """ + path = _create_path(PREFIX, "/groups/%s/users/%s/invite", group_id, user_id) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + @log_function + def invite_to_group_notification(self, destination, group_id, user_id, content): + """Sent by group server to inform a user's server that they have been + invited. 
+ """ + + path = _create_path(PREFIX, "/groups/local/%s/users/%s/invite", group_id, user_id) + + return self.client.post_json( + destination=destination, + path=path, + data=content, + ignore_backoff=True, + ) + + @log_function + def remove_user_from_group(self, destination, group_id, requester_user_id, + user_id, content): + """Remove a user from a group + """ + path = _create_path(PREFIX, "/groups/%s/users/%s/remove", group_id, user_id) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + @log_function + def remove_user_from_group_notification(self, destination, group_id, user_id, + content): + """Sent by group server to inform a user's server that they have been + kicked from the group. + """ + + path = _create_path(PREFIX, "/groups/local/%s/users/%s/remove", group_id, user_id) + + return self.client.post_json( + destination=destination, + path=path, + data=content, + ignore_backoff=True, + ) + + @log_function + def renew_group_attestation(self, destination, group_id, user_id, content): + """Sent by either a group server or a user's server to periodically update + the attestations + """ + + path = _create_path(PREFIX, "/groups/%s/renew_attestation/%s", group_id, user_id) + + return self.client.post_json( + destination=destination, + path=path, + data=content, + ignore_backoff=True, + ) + + @log_function + def update_group_summary_room(self, destination, group_id, user_id, room_id, + category_id, content): + """Update a room entry in a group summary + """ + if category_id: + path = _create_path( + PREFIX, "/groups/%s/summary/categories/%s/rooms/%s", + group_id, category_id, room_id, + ) + else: + path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": user_id}, + data=content, + ignore_backoff=True, + ) + + @log_function +
def delete_group_summary_room(self, destination, group_id, user_id, room_id, + category_id): + """Delete a room entry in a group summary + """ + if category_id: + path = _create_path( + PREFIX + "/groups/%s/summary/categories/%s/rooms/%s", + group_id, category_id, room_id, + ) + else: + path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,) + + return self.client.delete_json( + destination=destination, + path=path, + args={"requester_user_id": user_id}, + ignore_backoff=True, + ) + + @log_function + def get_group_categories(self, destination, group_id, requester_user_id): + """Get all categories in a group + """ + path = _create_path(PREFIX, "/groups/%s/categories", group_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def get_group_category(self, destination, group_id, requester_user_id, category_id): + """Get category info in a group + """ + path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def update_group_category(self, destination, group_id, requester_user_id, category_id, + content): + """Update a category in a group + """ + path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + @log_function + def delete_group_category(self, destination, group_id, requester_user_id, + category_id): + """Delete a category in a group + """ + path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,) + + return self.client.delete_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, 
+ ignore_backoff=True, + ) + + @log_function + def get_group_roles(self, destination, group_id, requester_user_id): + """Get all roles in a group + """ + path = _create_path(PREFIX, "/groups/%s/roles", group_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def get_group_role(self, destination, group_id, requester_user_id, role_id): + """Get a role's info + """ + path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def update_group_role(self, destination, group_id, requester_user_id, role_id, + content): + """Update a role in a group + """ + path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + @log_function + def delete_group_role(self, destination, group_id, requester_user_id, role_id): + """Delete a role in a group + """ + path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,) + + return self.client.delete_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def update_group_summary_user(self, destination, group_id, requester_user_id, + user_id, role_id, content): + """Update a user's entry in a group + """ + if role_id: + path = _create_path( + PREFIX, "/groups/%s/summary/roles/%s/users/%s", + group_id, role_id, user_id, + ) + else: + path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True,
+ ) + + @log_function + def set_group_join_policy(self, destination, group_id, requester_user_id, + content): + """Sets the join policy for a group + """ + path = _create_path(PREFIX, "/groups/%s/settings/m.join_policy", group_id,) + + return self.client.put_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + @log_function + def delete_group_summary_user(self, destination, group_id, requester_user_id, + user_id, role_id): + """Delete a user's entry in a group + """ + if role_id: + path = _create_path( + PREFIX, "/groups/%s/summary/roles/%s/users/%s", + group_id, role_id, user_id, + ) + else: + path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,) + + return self.client.delete_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + def bulk_get_publicised_groups(self, destination, user_ids): + """Get the groups a list of users are publicising + """ + + path = PREFIX + "/get_groups_publicised" + + content = {"user_ids": user_ids} + + return self.client.post_json( + destination=destination, + path=path, + data=content, + ignore_backoff=True, + ) + + +def _create_path(prefix, path, *args): + """Creates a path from the prefix, path template and args. Ensures that + all args are url encoded. + + Example: + + _create_path(PREFIX, "/event/%s/", event_id) + + Args: + prefix (str) + path (str): String template for the path + args: ([str]): Args to insert into path.
Each arg will be url encoded + + Returns: + str + """ + return prefix + path % tuple(urllib.quote(arg, "") for arg in args) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index c840da834c..19d09f5422 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,7 +17,7 @@ from twisted.internet import defer from synapse.api.urls import FEDERATION_PREFIX as PREFIX -from synapse.api.errors import Codes, SynapseError +from synapse.api.errors import Codes, SynapseError, FederationDeniedError from synapse.http.server import JsonResource from synapse.http.servlet import ( parse_json_object_from_request, parse_integer_from_args, parse_string_from_args, @@ -24,7 +25,8 @@ from synapse.http.servlet import ( ) from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.versionstring import get_version_string -from synapse.types import ThirdPartyInstanceID +from synapse.util.logcontext import run_in_background +from synapse.types import ThirdPartyInstanceID, get_domain_from_id import functools import logging @@ -79,6 +81,8 @@ class Authenticator(object): def __init__(self, hs): self.keyring = hs.get_keyring() self.server_name = hs.hostname + self.store = hs.get_datastore() + self.federation_domain_whitelist = hs.config.federation_domain_whitelist # A method just so we can pass 'self' as the authenticator to the Servlets @defer.inlineCallbacks @@ -110,7 +114,7 @@ class Authenticator(object): key = strip_quotes(param_dict["key"]) sig = strip_quotes(param_dict["sig"]) return (origin, key, sig) - except: + except Exception: raise AuthenticationError( 400, "Malformed Authorization header", Codes.UNAUTHORIZED ) @@ -128,6 +132,12 @@ class 
Authenticator(object): json_request["origin"] = origin json_request["signatures"].setdefault(origin, {})[key] = sig + if ( + self.federation_domain_whitelist is not None and + origin not in self.federation_domain_whitelist + ): + raise FederationDeniedError(origin) + if not json_request["signatures"]: raise NoAuthenticationError( 401, "Missing Authorization headers", Codes.UNAUTHORIZED, @@ -138,18 +148,30 @@ class Authenticator(object): logger.info("Request from %s", origin) request.authenticated_entity = origin + # If we get a valid signed request from the other side, its probably + # alive + retry_timings = yield self.store.get_destination_retry_timings(origin) + if retry_timings and retry_timings["retry_last_ts"]: + run_in_background(self._reset_retry_timings, origin) + defer.returnValue(origin) + @defer.inlineCallbacks + def _reset_retry_timings(self, origin): + try: + logger.info("Marking origin %r as up", origin) + yield self.store.set_destination_retry_timings(origin, 0, 0) + except Exception: + logger.exception("Error resetting retry timings on %s", origin) + class BaseFederationServlet(object): REQUIRE_AUTH = True - def __init__(self, handler, authenticator, ratelimiter, server_name, - room_list_handler): + def __init__(self, handler, authenticator, ratelimiter, server_name): self.handler = handler self.authenticator = authenticator self.ratelimiter = ratelimiter - self.room_list_handler = room_list_handler def _wrap(self, func): authenticator = self.authenticator @@ -170,7 +192,7 @@ class BaseFederationServlet(object): if self.REQUIRE_AUTH: logger.exception("authenticate_request failed") raise - except: + except Exception: logger.exception("authenticate_request failed") raise @@ -263,7 +285,7 @@ class FederationSendServlet(BaseFederationServlet): code, response = yield self.handler.on_incoming_transaction( transaction_data ) - except: + except Exception: logger.exception("on_incoming_transaction failed") raise @@ -581,7 +603,7 @@ class 
PublicRoomList(BaseFederationServlet): else: network_tuple = ThirdPartyInstanceID(None, None) - data = yield self.room_list_handler.get_local_public_room_list( + data = yield self.handler.get_local_public_room_list( limit, since_token, network_tuple=network_tuple ) @@ -602,7 +624,550 @@ class FederationVersionServlet(BaseFederationServlet): })) -SERVLET_CLASSES = ( +class FederationGroupsProfileServlet(BaseFederationServlet): + """Get/set the basic profile of a group on behalf of a user + """ + PATH = "/groups/(?P<group_id>[^/]*)/profile$" + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.get_group_profile( + group_id, requester_user_id + ) + + defer.returnValue((200, new_content)) + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.update_group_profile( + group_id, requester_user_id, content + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsSummaryServlet(BaseFederationServlet): + PATH = "/groups/(?P<group_id>[^/]*)/summary$" + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.get_group_summary( + group_id, requester_user_id + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsRoomsServlet(BaseFederationServlet): + """Get the rooms in a group on 
behalf of a user + """ + PATH = "/groups/(?P<group_id>[^/]*)/rooms$" + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.get_rooms_in_group( + group_id, requester_user_id + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsAddRoomsServlet(BaseFederationServlet): + """Add/remove room from group + """ + PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)$" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, room_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.add_room_to_group( + group_id, requester_user_id, room_id, content + ) + + defer.returnValue((200, new_content)) + + @defer.inlineCallbacks + def on_DELETE(self, origin, content, query, group_id, room_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.remove_room_from_group( + group_id, requester_user_id, room_id, + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): + """Update room config in group + """ + PATH = ( + "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)" + "/config/(?P<config_key>[^/]*)$" + ) + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, room_id, config_key): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, 
"requester_user_id doesn't match origin") + + result = yield self.groups_handler.update_room_in_group( + group_id, requester_user_id, room_id, config_key, content, + ) + + defer.returnValue((200, result)) + + +class FederationGroupsUsersServlet(BaseFederationServlet): + """Get the users in a group on behalf of a user + """ + PATH = "/groups/(?P<group_id>[^/]*)/users$" + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.get_users_in_group( + group_id, requester_user_id + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsInvitedUsersServlet(BaseFederationServlet): + """Get the users that have been invited to a group + """ + PATH = "/groups/(?P<group_id>[^/]*)/invited_users$" + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.get_invited_users_in_group( + group_id, requester_user_id + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsInviteServlet(BaseFederationServlet): + """Ask a group server to invite someone to the group + """ + PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, user_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.invite_to_group( + group_id, user_id, requester_user_id, content, + ) + + defer.returnValue((200, 
new_content)) + + +class FederationGroupsAcceptInviteServlet(BaseFederationServlet): + """Accept an invitation from the group server + """ + PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite$" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, user_id): + if get_domain_from_id(user_id) != origin: + raise SynapseError(403, "user_id doesn't match origin") + + new_content = yield self.handler.accept_invite( + group_id, user_id, content, + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsJoinServlet(BaseFederationServlet): + """Attempt to join a group + """ + PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join$" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, user_id): + if get_domain_from_id(user_id) != origin: + raise SynapseError(403, "user_id doesn't match origin") + + new_content = yield self.handler.join_group( + group_id, user_id, content, + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsRemoveUserServlet(BaseFederationServlet): + """Leave or kick a user from the group + """ + PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, user_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.remove_user_from_group( + group_id, user_id, requester_user_id, content, + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsLocalInviteServlet(BaseFederationServlet): + """A group server has invited a local user + """ + PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, user_id): + if get_domain_from_id(group_id) != origin: + raise 
SynapseError(403, "group_id doesn't match origin") + + new_content = yield self.handler.on_invite( + group_id, user_id, content, + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): + """A group server has removed a local user + """ + PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, user_id): + if get_domain_from_id(group_id) != origin: + raise SynapseError(403, "user_id doesn't match origin") + + new_content = yield self.handler.user_removed_from_group( + group_id, user_id, content, + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): + """A group or user's server renews their attestation + """ + PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)$" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, user_id): + # We don't need to check auth here as we check the attestation signatures + + new_content = yield self.handler.on_renew_attestation( + group_id, user_id, content + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): + """Add/remove a room from the group summary, with optional category. + + Matches both: + - /groups/:group/summary/rooms/:room_id + - /groups/:group/summary/categories/:category/rooms/:room_id + """ + PATH = ( + "/groups/(?P<group_id>[^/]*)/summary" + "(/categories/(?P<category_id>[^/]+))?" 
+ "/rooms/(?P<room_id>[^/]*)$" + ) + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, category_id, room_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if category_id == "": + raise SynapseError(400, "category_id cannot be empty string") + + resp = yield self.handler.update_group_summary_room( + group_id, requester_user_id, + room_id=room_id, + category_id=category_id, + content=content, + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_DELETE(self, origin, content, query, group_id, category_id, room_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if category_id == "": + raise SynapseError(400, "category_id cannot be empty string") + + resp = yield self.handler.delete_group_summary_room( + group_id, requester_user_id, + room_id=room_id, + category_id=category_id, + ) + + defer.returnValue((200, resp)) + + +class FederationGroupsCategoriesServlet(BaseFederationServlet): + """Get all categories for a group + """ + PATH = ( + "/groups/(?P<group_id>[^/]*)/categories/$" + ) + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + resp = yield self.handler.get_group_categories( + group_id, requester_user_id, + ) + + defer.returnValue((200, resp)) + + +class FederationGroupsCategoryServlet(BaseFederationServlet): + """Add/remove/get a category in a group + """ + PATH = ( + "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$" + ) + + @defer.inlineCallbacks + def on_GET(self, origin, 
content, query, group_id, category_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + resp = yield self.handler.get_group_category( + group_id, requester_user_id, category_id + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, category_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if category_id == "": + raise SynapseError(400, "category_id cannot be empty string") + + resp = yield self.handler.upsert_group_category( + group_id, requester_user_id, category_id, content, + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_DELETE(self, origin, content, query, group_id, category_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if category_id == "": + raise SynapseError(400, "category_id cannot be empty string") + + resp = yield self.handler.delete_group_category( + group_id, requester_user_id, category_id, + ) + + defer.returnValue((200, resp)) + + +class FederationGroupsRolesServlet(BaseFederationServlet): + """Get roles in a group + """ + PATH = ( + "/groups/(?P<group_id>[^/]*)/roles/$" + ) + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + resp = yield self.handler.get_group_roles( + group_id, requester_user_id, + ) + + defer.returnValue((200, resp)) + + +class 
FederationGroupsRoleServlet(BaseFederationServlet): + """Add/remove/get a role in a group + """ + PATH = ( + "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$" + ) + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id, role_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + resp = yield self.handler.get_group_role( + group_id, requester_user_id, role_id + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, role_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty string") + + resp = yield self.handler.update_group_role( + group_id, requester_user_id, role_id, content, + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_DELETE(self, origin, content, query, group_id, role_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty string") + + resp = yield self.handler.delete_group_role( + group_id, requester_user_id, role_id, + ) + + defer.returnValue((200, resp)) + + +class FederationGroupsSummaryUsersServlet(BaseFederationServlet): + """Add/remove a user from the group summary, with optional role. + + Matches both: + - /groups/:group/summary/users/:user_id + - /groups/:group/summary/roles/:role/users/:user_id + """ + PATH = ( + "/groups/(?P<group_id>[^/]*)/summary" + "(/roles/(?P<role_id>[^/]+))?" 
+ "/users/(?P<user_id>[^/]*)$" + ) + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, role_id, user_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty string") + + resp = yield self.handler.update_group_summary_user( + group_id, requester_user_id, + user_id=user_id, + role_id=role_id, + content=content, + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_DELETE(self, origin, content, query, group_id, role_id, user_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty string") + + resp = yield self.handler.delete_group_summary_user( + group_id, requester_user_id, + user_id=user_id, + role_id=role_id, + ) + + defer.returnValue((200, resp)) + + +class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): + """Get roles in a group + """ + PATH = ( + "/get_groups_publicised$" + ) + + @defer.inlineCallbacks + def on_POST(self, origin, content, query): + resp = yield self.handler.bulk_get_publicised_groups( + content["user_ids"], proxy=False, + ) + + defer.returnValue((200, resp)) + + +class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): + """Sets whether a group is joinable without an invite or knock + """ + PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy$" + + @defer.inlineCallbacks + def on_PUT(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = 
yield self.handler.set_group_join_policy( + group_id, requester_user_id, content + ) + + defer.returnValue((200, new_content)) + + +FEDERATION_SERVLET_CLASSES = ( FederationSendServlet, FederationPullServlet, FederationEventServlet, @@ -625,17 +1190,85 @@ SERVLET_CLASSES = ( FederationThirdPartyInviteExchangeServlet, On3pidBindServlet, OpenIdUserInfo, - PublicRoomList, FederationVersionServlet, ) +ROOM_LIST_CLASSES = ( + PublicRoomList, +) + +GROUP_SERVER_SERVLET_CLASSES = ( + FederationGroupsProfileServlet, + FederationGroupsSummaryServlet, + FederationGroupsRoomsServlet, + FederationGroupsUsersServlet, + FederationGroupsInvitedUsersServlet, + FederationGroupsInviteServlet, + FederationGroupsAcceptInviteServlet, + FederationGroupsJoinServlet, + FederationGroupsRemoveUserServlet, + FederationGroupsSummaryRoomsServlet, + FederationGroupsCategoriesServlet, + FederationGroupsCategoryServlet, + FederationGroupsRolesServlet, + FederationGroupsRoleServlet, + FederationGroupsSummaryUsersServlet, + FederationGroupsAddRoomsServlet, + FederationGroupsAddRoomsConfigServlet, + FederationGroupsSettingJoinPolicyServlet, +) + + +GROUP_LOCAL_SERVLET_CLASSES = ( + FederationGroupsLocalInviteServlet, + FederationGroupsRemoveLocalUserServlet, + FederationGroupsBulkPublicisedServlet, +) + + +GROUP_ATTESTATION_SERVLET_CLASSES = ( + FederationGroupsRenewAttestaionServlet, +) + + def register_servlets(hs, resource, authenticator, ratelimiter): - for servletclass in SERVLET_CLASSES: + for servletclass in FEDERATION_SERVLET_CLASSES: + servletclass( + handler=hs.get_federation_server(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) + + for servletclass in ROOM_LIST_CLASSES: + servletclass( + handler=hs.get_room_list_handler(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) + + for servletclass in GROUP_SERVER_SERVLET_CLASSES: + servletclass( + 
handler=hs.get_groups_server_handler(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) + + for servletclass in GROUP_LOCAL_SERVLET_CLASSES: + servletclass( + handler=hs.get_groups_local_handler(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) + + for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES: servletclass( - handler=hs.get_replication_layer(), + handler=hs.get_groups_attestation_renewer(), authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, - room_list_handler=hs.get_room_list_handler(), ).register(resource) diff --git a/synapse/groups/__init__.py b/synapse/groups/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/synapse/groups/__init__.py diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py new file mode 100644 index 0000000000..6f11fa374b --- /dev/null +++ b/synapse/groups/attestations.py @@ -0,0 +1,199 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Attestations ensure that users and groups can't lie about their memberships. + +When a user joins a group the HS and GS swap attestations, which allow them +both to independently prove to third parties their membership. These +attestations have a validity period so need to be periodically renewed. 
+ +If a user leaves (or gets kicked out of) a group, either side can still use +their attestation to "prove" their membership, until the attestation expires. +Therefore attestations shouldn't be relied on to prove membership in important +cases, but can be for less important situations, e.g. showing a user's membership +of groups on their profile, showing flairs, etc. + +An attestation is a signed blob of json that looks like: + + { + "user_id": "@foo:a.example.com", + "group_id": "+bar:b.example.com", + "valid_until_ms": 1507994728530, + "signatures":{"matrix.org":{"ed25519:auto":"..."}} + } +""" + +import logging +import random + +from twisted.internet import defer + +from synapse.api.errors import SynapseError +from synapse.types import get_domain_from_id +from synapse.util.logcontext import run_in_background + +from signedjson.sign import sign_json + + +logger = logging.getLogger(__name__) + + +# Default validity duration for new attestations we create +DEFAULT_ATTESTATION_LENGTH_MS = 3 * 24 * 60 * 60 * 1000 + +# We add some jitter to the validity duration of attestations so that if we +# add lots of users at once we don't need to renew them all at once. +# The jitter is a multiplier picked randomly between the first and second number +DEFAULT_ATTESTATION_JITTER = (0.9, 1.3) + +# Start trying to update our attestations when they come this close to expiring +UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000 + + +class GroupAttestationSigning(object): + """Creates and verifies group attestations. + """ + def __init__(self, hs): + self.keyring = hs.get_keyring() + self.clock = hs.get_clock() + self.server_name = hs.hostname + self.signing_key = hs.config.signing_key[0] + + @defer.inlineCallbacks + def verify_attestation(self, attestation, group_id, user_id, server_name=None): + """Verifies that the given attestation matches the given parameters. + + An optional server_name can be supplied to explicitly set which server's + signature is expected. 
Otherwise assumes that either the group_id or user_id + is local and uses the other's server as the one to check. + """ + + if not server_name: + if get_domain_from_id(group_id) == self.server_name: + server_name = get_domain_from_id(user_id) + elif get_domain_from_id(user_id) == self.server_name: + server_name = get_domain_from_id(group_id) + else: + raise Exception("Expected either group_id or user_id to be local") + + if user_id != attestation["user_id"]: + raise SynapseError(400, "Attestation has incorrect user_id") + + if group_id != attestation["group_id"]: + raise SynapseError(400, "Attestation has incorrect group_id") + valid_until_ms = attestation["valid_until_ms"] + + # TODO: We also want to check that *new* attestations that people give + # us to store are valid for at least a little while. + if valid_until_ms < self.clock.time_msec(): + raise SynapseError(400, "Attestation expired") + + yield self.keyring.verify_json_for_server(server_name, attestation) + + def create_attestation(self, group_id, user_id): + """Create an attestation for the group_id and user_id with default + validity length. + """ + validity_period = DEFAULT_ATTESTATION_LENGTH_MS + validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER) + valid_until_ms = int(self.clock.time_msec() + validity_period) + + return sign_json({ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": valid_until_ms, + }, self.server_name, self.signing_key) + + +class GroupAttestionRenewer(object): + """Responsible for sending and receiving attestation updates. 
+ """ + + def __init__(self, hs): + self.clock = hs.get_clock() + self.store = hs.get_datastore() + self.assestations = hs.get_groups_attestation_signing() + self.transport_client = hs.get_federation_transport_client() + self.is_mine_id = hs.is_mine_id + self.attestations = hs.get_groups_attestation_signing() + + self._renew_attestations_loop = self.clock.looping_call( + self._renew_attestations, 30 * 60 * 1000, + ) + + @defer.inlineCallbacks + def on_renew_attestation(self, group_id, user_id, content): + """When a remote updates an attestation + """ + attestation = content["attestation"] + + if not self.is_mine_id(group_id) and not self.is_mine_id(user_id): + raise SynapseError(400, "Neither user not group are on this server") + + yield self.attestations.verify_attestation( + attestation, + user_id=user_id, + group_id=group_id, + ) + + yield self.store.update_remote_attestion(group_id, user_id, attestation) + + defer.returnValue({}) + + @defer.inlineCallbacks + def _renew_attestations(self): + """Called periodically to check if we need to update any of our attestations + """ + + now = self.clock.time_msec() + + rows = yield self.store.get_attestations_need_renewals( + now + UPDATE_ATTESTATION_TIME_MS + ) + + @defer.inlineCallbacks + def _renew_attestation(group_id, user_id): + try: + if not self.is_mine_id(group_id): + destination = get_domain_from_id(group_id) + elif not self.is_mine_id(user_id): + destination = get_domain_from_id(user_id) + else: + logger.warn( + "Incorrectly trying to do attestations for user: %r in %r", + user_id, group_id, + ) + yield self.store.remove_attestation_renewal(group_id, user_id) + return + + attestation = self.attestations.create_attestation(group_id, user_id) + + yield self.transport_client.renew_group_attestation( + destination, group_id, user_id, + content={"attestation": attestation}, + ) + + yield self.store.update_attestation_renewal( + group_id, user_id, attestation + ) + except Exception: + logger.exception("Error renewing 
attestation of %r in %r", + user_id, group_id) + + for row in rows: + group_id = row["group_id"] + user_id = row["user_id"] + + run_in_background(_renew_attestation, group_id, user_id) diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py new file mode 100644 index 0000000000..2d95b04e0c --- /dev/null +++ b/synapse/groups/groups_server.py @@ -0,0 +1,950 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging + +from synapse.api.errors import SynapseError +from synapse.types import GroupID, RoomID, UserID, get_domain_from_id +from twisted.internet import defer + +logger = logging.getLogger(__name__) + + +# TODO: Allow users to "knock" or simply join depending on rules +# TODO: Federation admin APIs +# TODO: is_privileged flag to users and is_public to users and rooms +# TODO: Audit log for admins (profile updates, membership changes, users who tried +# to join but were rejected, etc) +# TODO: Flairs + + +class GroupsServerHandler(object): + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + self.room_list_handler = hs.get_room_list_handler() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.keyring = hs.get_keyring() + self.is_mine_id = hs.is_mine_id + self.signing_key = hs.config.signing_key[0] + self.server_name = hs.hostname + self.attestations = hs.get_groups_attestation_signing() + self.transport_client = hs.get_federation_transport_client() + self.profile_handler = hs.get_profile_handler() + + # Ensure attestations get renewed + hs.get_groups_attestation_renewer() + + @defer.inlineCallbacks + def check_group_is_ours(self, group_id, requester_user_id, + and_exists=False, and_is_admin=None): + """Check that the group is ours, and optionally if it exists. + + If group does exist then return group. 
+ + Args: + group_id (str) + and_exists (bool): whether to also check if group exists + and_is_admin (str): whether to also check if given str is a user_id + that is an admin + """ + if not self.is_mine_id(group_id): + raise SynapseError(400, "Group not on this server") + + group = yield self.store.get_group(group_id) + if and_exists and not group: + raise SynapseError(404, "Unknown group") + + is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + if group and not is_user_in_group and not group["is_public"]: + raise SynapseError(404, "Unknown group") + + if and_is_admin: + is_admin = yield self.store.is_user_admin_in_group(group_id, and_is_admin) + if not is_admin: + raise SynapseError(403, "User is not admin in group") + + defer.returnValue(group) + + @defer.inlineCallbacks + def get_group_summary(self, group_id, requester_user_id): + """Get the summary for a group as seen by requester_user_id. + + The group summary consists of the profile of the room, and a curated + list of users and rooms. These list *may* be organised by role/category. + The roles/categories are ordered, and so are the users/rooms within them. + + A user/room may appear in multiple roles/categories. 
+ """ + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + + profile = yield self.get_group_profile(group_id, requester_user_id) + + users, roles = yield self.store.get_users_for_summary_by_role( + group_id, include_private=is_user_in_group, + ) + + # TODO: Add profiles to users + + rooms, categories = yield self.store.get_rooms_for_summary_by_category( + group_id, include_private=is_user_in_group, + ) + + for room_entry in rooms: + room_id = room_entry["room_id"] + joined_users = yield self.store.get_users_in_room(room_id) + entry = yield self.room_list_handler.generate_room_entry( + room_id, len(joined_users), + with_alias=False, allow_private=True, + ) + entry = dict(entry) # so we don't change whats cached + entry.pop("room_id", None) + + room_entry["profile"] = entry + + rooms.sort(key=lambda e: e.get("order", 0)) + + for entry in users: + user_id = entry["user_id"] + + if not self.is_mine_id(requester_user_id): + attestation = yield self.store.get_remote_attestation(group_id, user_id) + if not attestation: + continue + + entry["attestation"] = attestation + else: + entry["attestation"] = self.attestations.create_attestation( + group_id, user_id, + ) + + user_profile = yield self.profile_handler.get_profile_from_cache(user_id) + entry.update(user_profile) + + users.sort(key=lambda e: e.get("order", 0)) + + membership_info = yield self.store.get_users_membership_info_in_group( + group_id, requester_user_id, + ) + + defer.returnValue({ + "profile": profile, + "users_section": { + "users": users, + "roles": roles, + "total_user_count_estimate": 0, # TODO + }, + "rooms_section": { + "rooms": rooms, + "categories": categories, + "total_room_count_estimate": 0, # TODO + }, + "user": membership_info, + }) + + @defer.inlineCallbacks + def update_group_summary_room(self, group_id, requester_user_id, + room_id, category_id, content): + """Add/update a room 
to the group summary + """ + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) + + RoomID.from_string(room_id) # Ensure valid room id + + order = content.get("order", None) + + is_public = _parse_visibility_from_contents(content) + + yield self.store.add_room_to_summary( + group_id=group_id, + room_id=room_id, + category_id=category_id, + order=order, + is_public=is_public, + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def delete_group_summary_room(self, group_id, requester_user_id, + room_id, category_id): + """Remove a room from the summary + """ + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) + + yield self.store.remove_room_from_summary( + group_id=group_id, + room_id=room_id, + category_id=category_id, + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def set_group_join_policy(self, group_id, requester_user_id, content): + """Sets the group join policy. + + Currently supported policies are: + - "invite": an invite must be received and accepted in order to join. + - "open": anyone can join. 
+ """ + yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + join_policy = _parse_join_policy_from_contents(content) + if join_policy is None: + raise SynapseError( + 400, "No value specified for 'm.join_policy'" + ) + + yield self.store.set_group_join_policy(group_id, join_policy=join_policy) + + defer.returnValue({}) + + @defer.inlineCallbacks + def get_group_categories(self, group_id, requester_user_id): + """Get all categories in a group (as seen by user) + """ + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + categories = yield self.store.get_group_categories( + group_id=group_id, + ) + defer.returnValue({"categories": categories}) + + @defer.inlineCallbacks + def get_group_category(self, group_id, requester_user_id, category_id): + """Get a specific category in a group (as seen by user) + """ + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + res = yield self.store.get_group_category( + group_id=group_id, + category_id=category_id, + ) + + defer.returnValue(res) + + @defer.inlineCallbacks + def update_group_category(self, group_id, requester_user_id, category_id, content): + """Add/Update a group category + """ + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) + + is_public = _parse_visibility_from_contents(content) + profile = content.get("profile") + + yield self.store.upsert_group_category( + group_id=group_id, + category_id=category_id, + is_public=is_public, + profile=profile, + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def delete_group_category(self, group_id, requester_user_id, category_id): + """Delete a group category + """ + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id + ) + + yield self.store.remove_group_category( + group_id=group_id, + category_id=category_id, 
+ ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def get_group_roles(self, group_id, requester_user_id): + """Get all roles in a group (as seen by user) + """ + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + roles = yield self.store.get_group_roles( + group_id=group_id, + ) + defer.returnValue({"roles": roles}) + + @defer.inlineCallbacks + def get_group_role(self, group_id, requester_user_id, role_id): + """Get a specific role in a group (as seen by user) + """ + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + res = yield self.store.get_group_role( + group_id=group_id, + role_id=role_id, + ) + defer.returnValue(res) + + @defer.inlineCallbacks + def update_group_role(self, group_id, requester_user_id, role_id, content): + """Add/update a role in a group + """ + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) + + is_public = _parse_visibility_from_contents(content) + + profile = content.get("profile") + + yield self.store.upsert_group_role( + group_id=group_id, + role_id=role_id, + is_public=is_public, + profile=profile, + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def delete_group_role(self, group_id, requester_user_id, role_id): + """Remove role from group + """ + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) + + yield self.store.remove_group_role( + group_id=group_id, + role_id=role_id, + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def update_group_summary_user(self, group_id, requester_user_id, user_id, role_id, + content): + """Add/update a users entry in the group summary + """ + yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id, + ) + + order = content.get("order", None) + + is_public = _parse_visibility_from_contents(content) + + yield 
self.store.add_user_to_summary( + group_id=group_id, + user_id=user_id, + role_id=role_id, + order=order, + is_public=is_public, + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def delete_group_summary_user(self, group_id, requester_user_id, user_id, role_id): + """Remove a user from the group summary + """ + yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id, + ) + + yield self.store.remove_user_from_summary( + group_id=group_id, + user_id=user_id, + role_id=role_id, + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def get_group_profile(self, group_id, requester_user_id): + """Get the group profile as seen by requester_user_id + """ + + yield self.check_group_is_ours(group_id, requester_user_id) + + group = yield self.store.get_group(group_id) + + if group: + cols = [ + "name", "short_description", "long_description", + "avatar_url", "is_public", + ] + group_description = {key: group[key] for key in cols} + group_description["is_openly_joinable"] = group["join_policy"] == "open" + + defer.returnValue(group_description) + else: + raise SynapseError(404, "Unknown group") + + @defer.inlineCallbacks + def update_group_profile(self, group_id, requester_user_id, content): + """Update the group profile + """ + yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id, + ) + + profile = {} + for keyname in ("name", "avatar_url", "short_description", + "long_description"): + if keyname in content: + value = content[keyname] + if not isinstance(value, basestring): + raise SynapseError(400, "%r value is not a string" % (keyname,)) + profile[keyname] = value + + yield self.store.update_group_profile(group_id, profile) + + @defer.inlineCallbacks + def get_users_in_group(self, group_id, requester_user_id): + """Get the users in group as seen by requester_user_id. 
+ + The ordering is arbitrary at the moment + """ + + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + + user_results = yield self.store.get_users_in_group( + group_id, include_private=is_user_in_group, + ) + + chunk = [] + for user_result in user_results: + g_user_id = user_result["user_id"] + is_public = user_result["is_public"] + is_privileged = user_result["is_admin"] + + entry = {"user_id": g_user_id} + + profile = yield self.profile_handler.get_profile_from_cache(g_user_id) + entry.update(profile) + + entry["is_public"] = bool(is_public) + entry["is_privileged"] = bool(is_privileged) + + if not self.is_mine_id(g_user_id): + attestation = yield self.store.get_remote_attestation(group_id, g_user_id) + if not attestation: + continue + + entry["attestation"] = attestation + else: + entry["attestation"] = self.attestations.create_attestation( + group_id, g_user_id, + ) + + chunk.append(entry) + + # TODO: If admin add lists of users whose attestations have timed out + + defer.returnValue({ + "chunk": chunk, + "total_user_count_estimate": len(user_results), + }) + + @defer.inlineCallbacks + def get_invited_users_in_group(self, group_id, requester_user_id): + """Get the users that have been invited to a group as seen by requester_user_id. 
+ + The ordering is arbitrary at the moment + """ + + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + + if not is_user_in_group: + raise SynapseError(403, "User not in group") + + invited_users = yield self.store.get_invited_users_in_group(group_id) + + user_profiles = [] + + for user_id in invited_users: + user_profile = { + "user_id": user_id + } + try: + profile = yield self.profile_handler.get_profile_from_cache(user_id) + user_profile.update(profile) + except Exception as e: + logger.warn("Error getting profile for %s: %s", user_id, e) + user_profiles.append(user_profile) + + defer.returnValue({ + "chunk": user_profiles, + "total_user_count_estimate": len(invited_users), + }) + + @defer.inlineCallbacks + def get_rooms_in_group(self, group_id, requester_user_id): + """Get the rooms in group as seen by requester_user_id + + This returns rooms in order of decreasing number of joined users + """ + + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + + room_results = yield self.store.get_rooms_in_group( + group_id, include_private=is_user_in_group, + ) + + chunk = [] + for room_result in room_results: + room_id = room_result["room_id"] + + joined_users = yield self.store.get_users_in_room(room_id) + entry = yield self.room_list_handler.generate_room_entry( + room_id, len(joined_users), + with_alias=False, allow_private=True, + ) + + if not entry: + continue + + entry["is_public"] = bool(room_result["is_public"]) + + chunk.append(entry) + + chunk.sort(key=lambda e: -e["num_joined_members"]) + + defer.returnValue({ + "chunk": chunk, + "total_room_count_estimate": len(room_results), + }) + + @defer.inlineCallbacks + def add_room_to_group(self, group_id, requester_user_id, room_id, content): + """Add room to group + """ + 
RoomID.from_string(room_id) # Ensure valid room id + + yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + is_public = _parse_visibility_from_contents(content) + + yield self.store.add_room_to_group(group_id, room_id, is_public=is_public) + + defer.returnValue({}) + + @defer.inlineCallbacks + def update_room_in_group(self, group_id, requester_user_id, room_id, config_key, + content): + """Update room in group + """ + RoomID.from_string(room_id) # Ensure valid room id + + yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + if config_key == "m.visibility": + is_public = _parse_visibility_dict(content) + + yield self.store.update_room_in_group_visibility( + group_id, room_id, + is_public=is_public, + ) + else: + raise SynapseError(400, "Uknown config option") + + defer.returnValue({}) + + @defer.inlineCallbacks + def remove_room_from_group(self, group_id, requester_user_id, room_id): + """Remove room from group + """ + yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + yield self.store.remove_room_from_group(group_id, room_id) + + defer.returnValue({}) + + @defer.inlineCallbacks + def invite_to_group(self, group_id, user_id, requester_user_id, content): + """Invite user to group + """ + + group = yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + # TODO: Check if user knocked + # TODO: Check if user is already invited + + content = { + "profile": { + "name": group["name"], + "avatar_url": group["avatar_url"], + }, + "inviter": requester_user_id, + } + + if self.hs.is_mine_id(user_id): + groups_local = self.hs.get_groups_local_handler() + res = yield groups_local.on_invite(group_id, user_id, content) + local_attestation = None + else: + local_attestation = self.attestations.create_attestation(group_id, 
user_id) + content.update({ + "attestation": local_attestation, + }) + + res = yield self.transport_client.invite_to_group_notification( + get_domain_from_id(user_id), group_id, user_id, content + ) + + user_profile = res.get("user_profile", {}) + yield self.store.add_remote_profile_cache( + user_id, + displayname=user_profile.get("displayname"), + avatar_url=user_profile.get("avatar_url"), + ) + + if res["state"] == "join": + if not self.hs.is_mine_id(user_id): + remote_attestation = res["attestation"] + + yield self.attestations.verify_attestation( + remote_attestation, + user_id=user_id, + group_id=group_id, + ) + else: + remote_attestation = None + + yield self.store.add_user_to_group( + group_id, user_id, + is_admin=False, + is_public=False, # TODO + local_attestation=local_attestation, + remote_attestation=remote_attestation, + ) + elif res["state"] == "invite": + yield self.store.add_group_invite( + group_id, user_id, + ) + defer.returnValue({ + "state": "invite" + }) + elif res["state"] == "reject": + defer.returnValue({ + "state": "reject" + }) + else: + raise SynapseError(502, "Unknown state returned by HS") + + @defer.inlineCallbacks + def _add_user(self, group_id, user_id, content): + """Add a user to a group based on a content dict. + + See accept_invite, join_group. 
+ """ + if not self.hs.is_mine_id(user_id): + local_attestation = self.attestations.create_attestation( + group_id, user_id, + ) + + remote_attestation = content["attestation"] + + yield self.attestations.verify_attestation( + remote_attestation, + user_id=user_id, + group_id=group_id, + ) + else: + local_attestation = None + remote_attestation = None + + is_public = _parse_visibility_from_contents(content) + + yield self.store.add_user_to_group( + group_id, user_id, + is_admin=False, + is_public=is_public, + local_attestation=local_attestation, + remote_attestation=remote_attestation, + ) + + defer.returnValue(local_attestation) + + @defer.inlineCallbacks + def accept_invite(self, group_id, requester_user_id, content): + """User tries to accept an invite to the group. + + This is different from them asking to join, and so should error if no + invite exists (and they're not a member of the group) + """ + + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_invited = yield self.store.is_user_invited_to_local_group( + group_id, requester_user_id, + ) + if not is_invited: + raise SynapseError(403, "User not invited to group") + + local_attestation = yield self._add_user(group_id, requester_user_id, content) + + defer.returnValue({ + "state": "join", + "attestation": local_attestation, + }) + + @defer.inlineCallbacks + def join_group(self, group_id, requester_user_id, content): + """User tries to join the group. 
+ + This will error if the group requires an invite/knock to join + """ + + group_info = yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True + ) + if group_info['join_policy'] != "open": + raise SynapseError(403, "Group is not publicly joinable") + + local_attestation = yield self._add_user(group_id, requester_user_id, content) + + defer.returnValue({ + "state": "join", + "attestation": local_attestation, + }) + + @defer.inlineCallbacks + def knock(self, group_id, requester_user_id, content): + """A user requests becoming a member of the group + """ + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + raise NotImplementedError() + + @defer.inlineCallbacks + def accept_knock(self, group_id, requester_user_id, content): + """Accept a users knock to the room. + + Errors if the user hasn't knocked, rather than inviting them. + """ + + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + raise NotImplementedError() + + @defer.inlineCallbacks + def remove_user_from_group(self, group_id, user_id, requester_user_id, content): + """Remove a user from the group; either a user is leaving or an admin + kicked them. 
+ """ + + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_kick = False + if requester_user_id != user_id: + is_admin = yield self.store.is_user_admin_in_group( + group_id, requester_user_id + ) + if not is_admin: + raise SynapseError(403, "User is not admin in group") + + is_kick = True + + yield self.store.remove_user_from_group( + group_id, user_id, + ) + + if is_kick: + if self.hs.is_mine_id(user_id): + groups_local = self.hs.get_groups_local_handler() + yield groups_local.user_removed_from_group(group_id, user_id, {}) + else: + yield self.transport_client.remove_user_from_group_notification( + get_domain_from_id(user_id), group_id, user_id, {} + ) + + if not self.hs.is_mine_id(user_id): + yield self.store.maybe_delete_remote_profile_cache(user_id) + + defer.returnValue({}) + + @defer.inlineCallbacks + def create_group(self, group_id, requester_user_id, content): + group = yield self.check_group_is_ours(group_id, requester_user_id) + + logger.info("Attempting to create group with ID: %r", group_id) + + # parsing the id into a GroupID validates it. 
+ group_id_obj = GroupID.from_string(group_id) + + if group: + raise SynapseError(400, "Group already exists") + + is_admin = yield self.auth.is_server_admin(UserID.from_string(requester_user_id)) + if not is_admin: + if not self.hs.config.enable_group_creation: + raise SynapseError( + 403, "Only a server admin can create groups on this server", + ) + localpart = group_id_obj.localpart + if not localpart.startswith(self.hs.config.group_creation_prefix): + raise SynapseError( + 400, + "Can only create groups with prefix %r on this server" % ( + self.hs.config.group_creation_prefix, + ), + ) + + profile = content.get("profile", {}) + name = profile.get("name") + avatar_url = profile.get("avatar_url") + short_description = profile.get("short_description") + long_description = profile.get("long_description") + user_profile = content.get("user_profile", {}) + + yield self.store.create_group( + group_id, + requester_user_id, + name=name, + avatar_url=avatar_url, + short_description=short_description, + long_description=long_description, + ) + + if not self.hs.is_mine_id(requester_user_id): + remote_attestation = content["attestation"] + + yield self.attestations.verify_attestation( + remote_attestation, + user_id=requester_user_id, + group_id=group_id, + ) + + local_attestation = self.attestations.create_attestation( + group_id, + requester_user_id, + ) + else: + local_attestation = None + remote_attestation = None + + yield self.store.add_user_to_group( + group_id, requester_user_id, + is_admin=True, + is_public=True, # TODO + local_attestation=local_attestation, + remote_attestation=remote_attestation, + ) + + if not self.hs.is_mine_id(requester_user_id): + yield self.store.add_remote_profile_cache( + requester_user_id, + displayname=user_profile.get("displayname"), + avatar_url=user_profile.get("avatar_url"), + ) + + defer.returnValue({ + "group_id": group_id, + }) + + +def _parse_join_policy_from_contents(content): + """Given a content for a request, return the 
specified join policy or None + """ + + join_policy_dict = content.get("m.join_policy") + if join_policy_dict: + return _parse_join_policy_dict(join_policy_dict) + else: + return None + + +def _parse_join_policy_dict(join_policy_dict): + """Given a dict for the "m.join_policy" config return the join policy specified + """ + join_policy_type = join_policy_dict.get("type") + if not join_policy_type: + return "invite" + + if join_policy_type not in ("invite", "open"): + raise SynapseError( + 400, "Synapse only supports 'invite'/'open' join rule" + ) + return join_policy_type + + +def _parse_visibility_from_contents(content): + """Given a content for a request parse out whether the entity should be + public or not + """ + + visibility = content.get("m.visibility") + if visibility: + return _parse_visibility_dict(visibility) + else: + is_public = True + + return is_public + + +def _parse_visibility_dict(visibility): + """Given a dict for the "m.visibility" config return if the entity should + be public or not + """ + vis_type = visibility.get("type") + if not vis_type: + return True + + if vis_type not in ("public", "private"): + raise SynapseError( + 400, "Synapse only supports 'public'/'private' visibility" + ) + return vis_type == "public" diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 5ad408f549..8f8fd82eb0 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -17,10 +17,8 @@ from .register import RegistrationHandler from .room import ( RoomCreationHandler, RoomContextHandler, ) -from .room_member import RoomMemberHandler from .message import MessageHandler from .federation import FederationHandler -from .profile import ProfileHandler from .directory import DirectoryHandler from .admin import AdminHandler from .identity import IdentityHandler @@ -50,9 +48,7 @@ class Handlers(object): self.registration_handler = RegistrationHandler(hs) self.message_handler = MessageHandler(hs) self.room_creation_handler = 
RoomCreationHandler(hs) - self.room_member_handler = RoomMemberHandler(hs) self.federation_handler = FederationHandler(hs) - self.profile_handler = ProfileHandler(hs) self.directory_handler = DirectoryHandler(hs) self.admin_handler = AdminHandler(hs) self.identity_handler = IdentityHandler(hs) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index e83adc8339..e089e66fde 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -53,7 +53,20 @@ class BaseHandler(object): self.event_builder_factory = hs.get_event_builder_factory() - def ratelimit(self, requester): + @defer.inlineCallbacks + def ratelimit(self, requester, update=True): + """Ratelimits requests. + + Args: + requester (Requester) + update (bool): Whether to record that a request is being processed. + Set to False when doing multiple checks for one request (e.g. + to check up front if we would reject the request), and set to + True for the last call for a given request. + + Raises: + LimitExceededError if the request should be ratelimited + """ time_now = self.clock.time() user_id = requester.user.to_string() @@ -67,10 +80,25 @@ class BaseHandler(object): if requester.app_service and not requester.app_service.is_rate_limited(): return + # Check if there is a per user override in the DB. 
+ override = yield self.store.get_ratelimit_for_user(user_id) + if override: + # If overriden with a null Hz then ratelimiting has been entirely + # disabled for the user + if not override.messages_per_second: + return + + messages_per_second = override.messages_per_second + burst_count = override.burst_count + else: + messages_per_second = self.hs.config.rc_messages_per_second + burst_count = self.hs.config.rc_message_burst_count + allowed, time_allowed = self.ratelimiter.send_message( user_id, time_now, - msg_rate_hz=self.hs.config.rc_messages_per_second, - burst_count=self.hs.config.rc_message_burst_count, + msg_rate_hz=messages_per_second, + burst_count=burst_count, + update=update, ) if not allowed: raise LimitExceededError( @@ -130,7 +158,7 @@ class BaseHandler(object): # homeserver. requester = synapse.types.create_requester( target_user, is_guest=True) - handler = self.hs.get_handlers().room_member_handler + handler = self.hs.get_room_member_handler() yield handler.update_membership( requester, target_user, diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 05af54d31b..b596f098fd 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -15,14 +15,21 @@ from twisted.internet import defer +import synapse from synapse.api.constants import EventTypes from synapse.util.metrics import Measure -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.util.logcontext import ( + make_deferred_yieldable, run_in_background, +) import logging logger = logging.getLogger(__name__) +metrics = synapse.metrics.get_metrics_for(__name__) + +events_processed_counter = metrics.register_counter("events_processed") + def log_failure(failure): logger.error( @@ -70,21 +77,25 @@ class ApplicationServicesHandler(object): with Measure(self.clock, "notify_interested_services"): self.is_processing = True try: - upper_bound = self.current_max limit = 100 while True: upper_bound, events = yield 
self.store.get_new_events_for_appservice( - upper_bound, limit + self.current_max, limit ) if not events: break + events_by_room = {} for event in events: + events_by_room.setdefault(event.room_id, []).append(event) + + @defer.inlineCallbacks + def handle_event(event): # Gather interested services services = yield self._get_services_for_event(event) if len(services) == 0: - continue # no services need notifying + return # no services need notifying # Do we know this user exists? If not, poke the user # query API for all services which match that user regex. @@ -100,14 +111,35 @@ class ApplicationServicesHandler(object): # Fork off pushes to these services for service in services: - preserve_fn(self.scheduler.submit_event_for_as)( - service, event - ) + self.scheduler.submit_event_for_as(service, event) + + @defer.inlineCallbacks + def handle_room_events(events): + for event in events: + yield handle_event(event) + + yield make_deferred_yieldable(defer.gatherResults([ + run_in_background(handle_room_events, evs) + for evs in events_by_room.itervalues() + ], consumeErrors=True)) yield self.store.set_appservice_last_pos(upper_bound) - if len(events) < limit: - break + now = self.clock.time_msec() + ts = yield self.store.get_received_ts(events[-1].event_id) + + synapse.metrics.event_processing_positions.set( + upper_bound, "appservice_sender", + ) + + events_processed_counter.inc_by(len(events)) + + synapse.metrics.event_processing_lag.set( + now - ts, "appservice_sender", + ) + synapse.metrics.event_processing_last_ts.set( + ts, "appservice_sender", + ) finally: self.is_processing = False @@ -163,8 +195,11 @@ class ApplicationServicesHandler(object): def query_3pe(self, kind, protocol, fields): services = yield self._get_services_for_3pn(protocol) - results = yield preserve_context_over_deferred(defer.DeferredList([ - preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields) + results = yield make_deferred_yieldable(defer.DeferredList([ + 
run_in_background( + self.appservice_api.query_3pe, + service, kind, protocol, fields, + ) for service in services ], consumeErrors=True)) @@ -225,11 +260,15 @@ class ApplicationServicesHandler(object): event based on the service regex. """ services = self.store.get_app_services() - interested_list = [ - s for s in services if ( - yield s.is_interested(event, self.store) - ) - ] + + # we can't use a list comprehension here. Since python 3, list + # comprehensions use a generator internally. This means you can't yield + # inside of a list comprehension anymore. + interested_list = [] + for s in services: + if (yield s.is_interested(event, self.store)): + interested_list.append(s) + defer.returnValue(interested_list) def _get_services_for_user(self, user_id): diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index fffba34383..a5365c4fe4 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,14 +13,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -from twisted.internet import defer +from twisted.internet import defer, threads from ._base import BaseHandler from synapse.api.constants import LoginType +from synapse.api.errors import ( + AuthError, Codes, InteractiveAuthIncompleteError, LoginError, StoreError, + SynapseError, +) +from synapse.module_api import ModuleApi from synapse.types import UserID -from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError from synapse.util.async import run_on_reactor +from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.logcontext import make_deferred_yieldable from twisted.web.client import PartialDownloadError @@ -44,18 +50,23 @@ class AuthHandler(BaseHandler): """ super(AuthHandler, self).__init__(hs) self.checkers = { - LoginType.PASSWORD: self._check_password_auth, LoginType.RECAPTCHA: self._check_recaptcha, LoginType.EMAIL_IDENTITY: self._check_email_identity, + LoginType.MSISDN: self._check_msisdn, LoginType.DUMMY: self._check_dummy_auth, } self.bcrypt_rounds = hs.config.bcrypt_rounds - self.sessions = {} - account_handler = _AccountHandler( - hs, check_user_exists=self.check_user_exists + # This is not a cache per se, but a store of all current sessions that + # expire after N hours + self.sessions = ExpiringCache( + cache_name="register_sessions", + clock=hs.get_clock(), + expiry_ms=self.SESSION_EXPIRE_MS, + reset_expiry_on_get=True, ) + account_handler = ModuleApi(hs, self) self.password_providers = [ module(config=config, account_handler=account_handler) for module, config in hs.config.password_providers @@ -64,39 +75,120 @@ class AuthHandler(BaseHandler): logger.info("Extra password_providers: %r", self.password_providers) self.hs = hs # FIXME better possibility to access registrationHandler later? 
- self.device_handler = hs.get_device_handler() self.macaroon_gen = hs.get_macaroon_generator() + self._password_enabled = hs.config.password_enabled + + # we keep this as a list despite the O(N^2) implication so that we can + # keep PASSWORD first and avoid confusing clients which pick the first + # type in the list. (NB that the spec doesn't require us to do so and + # clients which favour types that they don't understand over those that + # they do are technically broken) + login_types = [] + if self._password_enabled: + login_types.append(LoginType.PASSWORD) + for provider in self.password_providers: + if hasattr(provider, "get_supported_login_types"): + for t in provider.get_supported_login_types().keys(): + if t not in login_types: + login_types.append(t) + self._supported_login_types = login_types + + @defer.inlineCallbacks + def validate_user_via_ui_auth(self, requester, request_body, clientip): + """ + Checks that the user is who they claim to be, via a UI auth. + + This is used for things like device deletion and password reset where + the user already has a valid access token, but we want to double-check + that it isn't stolen by re-authenticating them. + + Args: + requester (Requester): The user, as given by the access token + + request_body (dict): The body of the request sent by the client + + clientip (str): The IP address of the client. + + Returns: + defer.Deferred[dict]: the parameters for this request (which may + have been given only in a previous call). 
+ + Raises: + InteractiveAuthIncompleteError if the client has not yet completed + any of the permitted login flows + + AuthError if the client has completed a login flow, and it gives + a different user to `requester` + """ + + # build a list of supported flows + flows = [ + [login_type] for login_type in self._supported_login_types + ] + + result, params, _ = yield self.check_auth( + flows, request_body, clientip, + ) + + # find the completed login type + for login_type in self._supported_login_types: + if login_type not in result: + continue + + user_id = result[login_type] + break + else: + # this can't happen + raise Exception( + "check_auth returned True but no successful login type", + ) + + # check that the UI auth matched the access token + if user_id != requester.user.to_string(): + raise AuthError(403, "Invalid auth") + + defer.returnValue(params) @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): """ Takes a dictionary sent by the client in the login / registration - protocol and handles the login flow. + protocol and handles the User-Interactive Auth flow. As a side effect, this function fills in the 'creds' key on the user's session with a map, which maps each auth-type (str) to the relevant identity authenticated by that auth-type (mostly str, but for captcha, bool). + If no auth flows have been completed successfully, raises an + InteractiveAuthIncompleteError. To handle this, you can use + synapse.rest.client.v2_alpha._base.interactive_auth_handler as a + decorator. + Args: flows (list): A list of login flows. Each flow is an ordered list of strings representing auth-types. At least one full flow must be completed in order for auth to be successful. + clientdict: The dictionary from the client root level, not the 'auth' key: this method prompts for auth if none is sent. + clientip (str): The IP address of the client. 
+ Returns: - A tuple of (authed, dict, dict, session_id) where authed is true if - the client has successfully completed an auth flow. If it is true - the first dict contains the authenticated credentials of each stage. + defer.Deferred[dict, dict, str]: a deferred tuple of + (creds, params, session_id). + + 'creds' contains the authenticated credentials of each stage. - If authed is false, the first dictionary is the server response to - the login request and should be passed back to the client. + 'params' contains the parameters for this request (which may + have been given only in a previous call). - In either case, the second dict contains the parameters for this - request (which may have been given only in a previous call). + 'session_id' is the ID of this session, either passed in by the + client or assigned by this call - session_id is the ID of this session, either passed in by the client - or assigned by the call to check_auth + Raises: + InteractiveAuthIncompleteError if the client has not yet completed + all the stages in any of the permitted flows. 
""" authdict = None @@ -124,11 +216,8 @@ class AuthHandler(BaseHandler): clientdict = session['clientdict'] if not authdict: - defer.returnValue( - ( - False, self._auth_dict_for_flows(flows, session), - clientdict, session['id'] - ) + raise InteractiveAuthIncompleteError( + self._auth_dict_for_flows(flows, session), ) if 'creds' not in session: @@ -139,14 +228,12 @@ class AuthHandler(BaseHandler): errordict = {} if 'type' in authdict: login_type = authdict['type'] - if login_type not in self.checkers: - raise LoginError(400, "", Codes.UNRECOGNIZED) try: - result = yield self.checkers[login_type](authdict, clientip) + result = yield self._check_auth_dict(authdict, clientip) if result: creds[login_type] = result self._save_session(session) - except LoginError, e: + except LoginError as e: if login_type == LoginType.EMAIL_IDENTITY: # riot used to have a bug where it would request a new # validation token (thus sending a new email) each time it @@ -155,7 +242,7 @@ class AuthHandler(BaseHandler): # # Grandfather in the old behaviour for now to avoid # breaking old riot deployments. - raise e + raise # this step failed. Merge the error dict into the response # so that the client can have another go. @@ -172,12 +259,14 @@ class AuthHandler(BaseHandler): "Auth completed with creds: %r. 
Client dict has keys: %r", creds, clientdict.keys() ) - defer.returnValue((True, creds, clientdict, session['id'])) + defer.returnValue((creds, clientdict, session['id'])) ret = self._auth_dict_for_flows(flows, session) ret['completed'] = creds.keys() ret.update(errordict) - defer.returnValue((False, ret, clientdict, session['id'])) + raise InteractiveAuthIncompleteError( + ret, + ) @defer.inlineCallbacks def add_oob_auth(self, stagetype, authdict, clientip): @@ -249,16 +338,37 @@ class AuthHandler(BaseHandler): sess = self._get_session_info(session_id) return sess.setdefault('serverdict', {}).get(key, default) - def _check_password_auth(self, authdict, _): - if "user" not in authdict or "password" not in authdict: - raise LoginError(400, "", Codes.MISSING_PARAM) + @defer.inlineCallbacks + def _check_auth_dict(self, authdict, clientip): + """Attempt to validate the auth dict provided by a client - user_id = authdict["user"] - password = authdict["password"] - if not user_id.startswith('@'): - user_id = UserID.create(user_id, self.hs.hostname).to_string() + Args: + authdict (object): auth dict provided by the client + clientip (str): IP address of the client - return self._check_password(user_id, password) + Returns: + Deferred: result of the stage verification. + + Raises: + StoreError if there was a problem accessing the database + SynapseError if there was a problem with the request + LoginError if there was an authentication problem. 
+ """ + login_type = authdict['type'] + checker = self.checkers.get(login_type) + if checker is not None: + res = yield checker(authdict, clientip) + defer.returnValue(res) + + # build a v1-login-style dict out of the authdict and fall back to the + # v1 code + user_id = authdict.get("user") + + if user_id is None: + raise SynapseError(400, "", Codes.MISSING_PARAM) + + (canonical_id, callback) = yield self.validate_login(user_id, authdict) + defer.returnValue(canonical_id) @defer.inlineCallbacks def _check_recaptcha(self, authdict, clientip): @@ -307,31 +417,47 @@ class AuthHandler(BaseHandler): defer.returnValue(True) raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) - @defer.inlineCallbacks def _check_email_identity(self, authdict, _): + return self._check_threepid('email', authdict) + + def _check_msisdn(self, authdict, _): + return self._check_threepid('msisdn', authdict) + + @defer.inlineCallbacks + def _check_dummy_auth(self, authdict, _): + yield run_on_reactor() + defer.returnValue(True) + + @defer.inlineCallbacks + def _check_threepid(self, medium, authdict): yield run_on_reactor() if 'threepid_creds' not in authdict: raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) threepid_creds = authdict['threepid_creds'] + identity_handler = self.hs.get_handlers().identity_handler - logger.info("Getting validated threepid. threepidcreds: %r" % (threepid_creds,)) + logger.info("Getting validated threepid. 
threepidcreds: %r", (threepid_creds,)) threepid = yield identity_handler.threepid_from_creds(threepid_creds) if not threepid: raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) + if threepid['medium'] != medium: + raise LoginError( + 401, + "Expecting threepid of type '%s', got '%s'" % ( + medium, threepid['medium'], + ), + errcode=Codes.UNAUTHORIZED + ) + threepid['threepid_creds'] = authdict['threepid_creds'] defer.returnValue(threepid) - @defer.inlineCallbacks - def _check_dummy_auth(self, authdict, _): - yield run_on_reactor() - defer.returnValue(True) - def _get_params_recaptcha(self): return {"public_key": self.hs.config.recaptcha_public_key} @@ -371,26 +497,8 @@ class AuthHandler(BaseHandler): return self.sessions[session_id] - def validate_password_login(self, user_id, password): - """ - Authenticates the user with their username and password. - - Used only by the v1 login API. - - Args: - user_id (str): complete @user:id - password (str): Password - Returns: - defer.Deferred: (str) canonical user id - Raises: - StoreError if there was a problem accessing the database - LoginError if there was an authentication problem. - """ - return self._check_password(user_id, password) - @defer.inlineCallbacks - def get_access_token_for_user_id(self, user_id, device_id=None, - initial_display_name=None): + def get_access_token_for_user_id(self, user_id, device_id=None): """ Creates a new access token for the user with the given user ID. @@ -404,13 +512,10 @@ class AuthHandler(BaseHandler): device_id (str|None): the device ID to associate with the tokens. None to leave the tokens unassociated with a device (deprecated: we should always have a device ID) - initial_display_name (str): display name to associate with the - device if it needs re-registering Returns: The access token for the user's session. Raises: StoreError if there was a problem storing the token. - LoginError if there was an authentication problem. 
""" logger.info("Logging in user %s on device %s", user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id) @@ -420,9 +525,11 @@ class AuthHandler(BaseHandler): # really don't want is active access_tokens without a record of the # device, so we double-check it here. if device_id is not None: - yield self.device_handler.check_device_registered( - user_id, device_id, initial_display_name - ) + try: + yield self.store.get_device(user_id, device_id) + except StoreError: + yield self.store.delete_access_token(access_token) + raise StoreError(400, "Login raced against device deletion") defer.returnValue(access_token) @@ -474,29 +581,115 @@ class AuthHandler(BaseHandler): ) defer.returnValue(result) + def get_supported_login_types(self): + """Get a the login types supported for the /login API + + By default this is just 'm.login.password' (unless password_enabled is + False in the config file), but password auth providers can provide + other login types. + + Returns: + Iterable[str]: login types + """ + return self._supported_login_types + @defer.inlineCallbacks - def _check_password(self, user_id, password): - """Authenticate a user against the LDAP and local databases. + def validate_login(self, username, login_submission): + """Authenticates the user for the /login API - user_id is checked case insensitively against the local database, but - will throw if there are multiple inexact matches. + Also used by the user-interactive auth flow to validate + m.login.password auth types. 
Args: - user_id (str): complete @user:id + username (str): username supplied by the user + login_submission (dict): the whole of the login submission + (including 'type' and other relevant fields) Returns: - (str) the canonical_user_id + Deferred[str, func]: canonical user id, and optional callback + to be called once the access token and device id are issued Raises: - LoginError if login fails + StoreError if there was a problem accessing the database + SynapseError if there was a problem with the request + LoginError if there was an authentication problem. """ + + if username.startswith('@'): + qualified_user_id = username + else: + qualified_user_id = UserID( + username, self.hs.hostname + ).to_string() + + login_type = login_submission.get("type") + known_login_type = False + + # special case to check for "password" for the check_password interface + # for the auth providers + password = login_submission.get("password") + if login_type == LoginType.PASSWORD: + if not self._password_enabled: + raise SynapseError(400, "Password login has been disabled.") + if not password: + raise SynapseError(400, "Missing parameter: password") + for provider in self.password_providers: - is_valid = yield provider.check_password(user_id, password) - if is_valid: - defer.returnValue(user_id) + if (hasattr(provider, "check_password") + and login_type == LoginType.PASSWORD): + known_login_type = True + is_valid = yield provider.check_password( + qualified_user_id, password, + ) + if is_valid: + defer.returnValue((qualified_user_id, None)) + + if (not hasattr(provider, "get_supported_login_types") + or not hasattr(provider, "check_auth")): + # this password provider doesn't understand custom login types + continue + + supported_login_types = provider.get_supported_login_types() + if login_type not in supported_login_types: + # this password provider doesn't understand this login type + continue + + known_login_type = True + login_fields = supported_login_types[login_type] + + 
missing_fields = [] + login_dict = {} + for f in login_fields: + if f not in login_submission: + missing_fields.append(f) + else: + login_dict[f] = login_submission[f] + if missing_fields: + raise SynapseError( + 400, "Missing parameters for login type %s: %s" % ( + login_type, + missing_fields, + ), + ) - canonical_user_id = yield self._check_local_password(user_id, password) + result = yield provider.check_auth( + username, login_type, login_dict, + ) + if result: + if isinstance(result, str): + result = (result, None) + defer.returnValue(result) + + if login_type == LoginType.PASSWORD: + known_login_type = True + + canonical_user_id = yield self._check_local_password( + qualified_user_id, password, + ) - if canonical_user_id: - defer.returnValue(canonical_user_id) + if canonical_user_id: + defer.returnValue((canonical_user_id, None)) + + if not known_login_type: + raise SynapseError(400, "Unknown login type %s" % login_type) # unknown username or invalid password. We raise a 403 here, but note # that if we're doing user-interactive login, it turns all LoginErrors @@ -522,7 +715,7 @@ class AuthHandler(BaseHandler): if not lookupres: defer.returnValue(None) (user_id, password_hash) = lookupres - result = self.validate_hash(password, password_hash) + result = yield self.validate_hash(password, password_hash) if not result: logger.warn("Failed password login for user %s", user_id) defer.returnValue(None) @@ -546,22 +739,65 @@ class AuthHandler(BaseHandler): raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) @defer.inlineCallbacks - def set_password(self, user_id, newpassword, requester=None): - password_hash = self.hash(newpassword) + def delete_access_token(self, access_token): + """Invalidate a single access token - except_access_token_id = requester.access_token_id if requester else None + Args: + access_token (str): access token to be deleted - try: - yield self.store.user_set_password_hash(user_id, password_hash) - except StoreError as e: - if 
e.code == 404: - raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) - raise e - yield self.store.user_delete_access_tokens( - user_id, except_access_token_id + Returns: + Deferred + """ + user_info = yield self.auth.get_user_by_access_token(access_token) + yield self.store.delete_access_token(access_token) + + # see if any of our auth providers want to know about this + for provider in self.password_providers: + if hasattr(provider, "on_logged_out"): + yield provider.on_logged_out( + user_id=str(user_info["user"]), + device_id=user_info["device_id"], + access_token=access_token, + ) + + # delete pushers associated with this access token + if user_info["token_id"] is not None: + yield self.hs.get_pusherpool().remove_pushers_by_access_token( + str(user_info["user"]), (user_info["token_id"], ) + ) + + @defer.inlineCallbacks + def delete_access_tokens_for_user(self, user_id, except_token_id=None, + device_id=None): + """Invalidate access tokens belonging to a user + + Args: + user_id (str): ID of user the tokens belong to + except_token_id (str|None): access_token ID which should *not* be + deleted + device_id (str|None): ID of device the tokens are associated with. 
+ If None, tokens associated with any device (or no device) will + be deleted + Returns: + Deferred + """ + tokens_and_devices = yield self.store.user_delete_access_tokens( + user_id, except_token_id=except_token_id, device_id=device_id, ) - yield self.hs.get_pusherpool().remove_pushers_by_user( - user_id, except_access_token_id + + # see if any of our auth providers want to know about this + for provider in self.password_providers: + if hasattr(provider, "on_logged_out"): + for token, token_id, device_id in tokens_and_devices: + yield provider.on_logged_out( + user_id=user_id, + device_id=device_id, + access_token=token, + ) + + # delete pushers associated with the access tokens + yield self.hs.get_pusherpool().remove_pushers_by_access_token( + user_id, (token_id for _, token_id, _ in tokens_and_devices), ) @defer.inlineCallbacks @@ -599,16 +835,6 @@ class AuthHandler(BaseHandler): logger.debug("Saving session %s", session) session["last_used"] = self.hs.get_clock().time_msec() self.sessions[session["id"]] = session - self._prune_sessions() - - def _prune_sessions(self): - for sid, sess in self.sessions.items(): - last_used = 0 - if 'last_used' in sess: - last_used = sess['last_used'] - now = self.hs.get_clock().time_msec() - if last_used < now - AuthHandler.SESSION_EXPIRE_MS: - del self.sessions[sid] def hash(self, password): """Computes a secure hash of password. @@ -617,10 +843,13 @@ class AuthHandler(BaseHandler): password (str): Password to hash. Returns: - Hashed password (str). + Deferred(str): Hashed password. """ - return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, - bcrypt.gensalt(self.bcrypt_rounds)) + def _do_hash(): + return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, + bcrypt.gensalt(self.bcrypt_rounds)) + + return make_deferred_yieldable(threads.deferToThread(_do_hash)) def validate_hash(self, password, stored_hash): """Validates that self.hash(password) == stored_hash. 
@@ -630,13 +859,19 @@ class AuthHandler(BaseHandler): stored_hash (str): Expected hash value. Returns: - Whether self.hash(password) == stored_hash (bool). + Deferred(bool): Whether self.hash(password) == stored_hash. """ + + def _do_validate_hash(): + return bcrypt.checkpw( + password.encode('utf8') + self.hs.config.password_pepper, + stored_hash.encode('utf8') + ) + if stored_hash: - return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, - stored_hash.encode('utf8')) == stored_hash + return make_deferred_yieldable(threads.deferToThread(_do_validate_hash)) else: - return False + return defer.succeed(False) class MacaroonGeneartor(object): @@ -679,30 +914,3 @@ class MacaroonGeneartor(object): macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon - - -class _AccountHandler(object): - """A proxy object that gets passed to password auth providers so they - can register new users etc if necessary. - """ - def __init__(self, hs, check_user_exists): - self.hs = hs - - self._check_user_exists = check_user_exists - - def check_user_exists(self, user_id): - """Check if user exissts. - - Returns: - Deferred(bool) - """ - return self._check_user_exists(user_id) - - def register(self, localpart): - """Registers a new user with given localpart - - Returns: - Deferred: a 2-tuple of (user_id, access_token) - """ - reg = self.hs.get_handlers().registration_handler - return reg.register(localpart=localpart) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py new file mode 100644 index 0000000000..b1d3814909 --- /dev/null +++ b/synapse/handlers/deactivate_account.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.internet import defer + +from ._base import BaseHandler + +import logging + +logger = logging.getLogger(__name__) + + +class DeactivateAccountHandler(BaseHandler): + """Handler which deals with deactivating user accounts.""" + def __init__(self, hs): + super(DeactivateAccountHandler, self).__init__(hs) + self._auth_handler = hs.get_auth_handler() + self._device_handler = hs.get_device_handler() + + @defer.inlineCallbacks + def deactivate_account(self, user_id): + """Deactivate a user's account + + Args: + user_id (str): ID of user to be deactivated + + Returns: + Deferred + """ + # FIXME: Theoretically there is a race here wherein user resets + # password using threepid. + + # first delete any devices belonging to the user, which will also + # delete corresponding access tokens. + yield self._device_handler.delete_all_devices_for_user(user_id) + # then delete any remaining access tokens which weren't associated with + # a device. + yield self._auth_handler.delete_access_tokens_for_user(user_id) + + yield self.store.user_delete_threepids(user_id) + yield self.store.user_set_password_hash(user_id, None) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index ca7137f315..f7457a7082 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -14,8 +14,11 @@ # limitations under the License. 
from synapse.api import errors from synapse.api.constants import EventTypes +from synapse.api.errors import FederationDeniedError from synapse.util import stringutils from synapse.util.async import Linearizer +from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.retryutils import NotRetryingDestination from synapse.util.metrics import measure_func from synapse.types import get_domain_from_id, RoomStreamToken from twisted.internet import defer @@ -32,14 +35,17 @@ class DeviceHandler(BaseHandler): self.hs = hs self.state = hs.get_state_handler() + self._auth_handler = hs.get_auth_handler() self.federation_sender = hs.get_federation_sender() - self.federation = hs.get_replication_layer() - self._remote_edue_linearizer = Linearizer(name="remote_device_list") - self.federation.register_edu_handler( - "m.device_list_update", self._incoming_device_list_update, + self._edu_updater = DeviceListEduUpdater(hs, self) + + federation_registry = hs.get_federation_registry() + + federation_registry.register_edu_handler( + "m.device_list_update", self._edu_updater.incoming_device_list_update, ) - self.federation.register_query_handler( + federation_registry.register_query_handler( "user_devices", self.on_federation_query_user_devices, ) @@ -103,7 +109,7 @@ class DeviceHandler(BaseHandler): device_map = yield self.store.get_devices_by_user(user_id) ips = yield self.store.get_last_client_ip_by_device( - devices=((user_id, device_id) for device_id in device_map.keys()) + user_id, device_id=None ) devices = device_map.values() @@ -130,7 +136,7 @@ class DeviceHandler(BaseHandler): except errors.StoreError: raise errors.NotFoundError ips = yield self.store.get_last_client_ip_by_device( - devices=((user_id, device_id),) + user_id, device_id, ) _update_device_from_client_ips(device, ips) defer.returnValue(device) @@ -149,16 +155,15 @@ class DeviceHandler(BaseHandler): try: yield self.store.delete_device(user_id, device_id) - except errors.StoreError, e: + except 
errors.StoreError as e: if e.code == 404: # no match pass else: raise - yield self.store.user_delete_access_tokens( + yield self._auth_handler.delete_access_tokens_for_user( user_id, device_id=device_id, - delete_refresh_tokens=True, ) yield self.store.delete_e2e_keys_by_device( @@ -168,6 +173,57 @@ class DeviceHandler(BaseHandler): yield self.notify_device_update(user_id, [device_id]) @defer.inlineCallbacks + def delete_all_devices_for_user(self, user_id, except_device_id=None): + """Delete all of the user's devices + + Args: + user_id (str): + except_device_id (str|None): optional device id which should not + be deleted + + Returns: + defer.Deferred: + """ + device_map = yield self.store.get_devices_by_user(user_id) + device_ids = device_map.keys() + if except_device_id is not None: + device_ids = [d for d in device_ids if d != except_device_id] + yield self.delete_devices(user_id, device_ids) + + @defer.inlineCallbacks + def delete_devices(self, user_id, device_ids): + """ Delete several devices + + Args: + user_id (str): + device_ids (List[str]): The list of device IDs to delete + + Returns: + defer.Deferred: + """ + + try: + yield self.store.delete_devices(user_id, device_ids) + except errors.StoreError as e: + if e.code == 404: + # no match + pass + else: + raise + + # Delete access tokens and e2e keys for each device. Not optimised as it is not + # considered as part of a critical path. 
+ for device_id in device_ids: + yield self._auth_handler.delete_access_tokens_for_user( + user_id, device_id=device_id, + ) + yield self.store.delete_e2e_keys_by_device( + user_id=user_id, device_id=device_id + ) + + yield self.notify_device_update(user_id, device_ids) + + @defer.inlineCallbacks def update_device(self, user_id, device_id, content): """ Update the given device @@ -187,7 +243,7 @@ class DeviceHandler(BaseHandler): new_display_name=content.get("display_name") ) yield self.notify_device_update(user_id, [device_id]) - except errors.StoreError, e: + except errors.StoreError as e: if e.code == 404: raise errors.NotFoundError() else: @@ -212,8 +268,7 @@ class DeviceHandler(BaseHandler): user_id, device_ids, list(hosts) ) - rooms = yield self.store.get_rooms_for_user(user_id) - room_ids = [r.room_id for r in rooms] + room_ids = yield self.store.get_rooms_for_user(user_id) yield self.notifier.on_new_event( "device_list_key", position, rooms=room_ids, @@ -234,8 +289,9 @@ class DeviceHandler(BaseHandler): user_id (str) from_token (StreamToken) """ - rooms = yield self.store.get_rooms_for_user(user_id) - room_ids = set(r.room_id for r in rooms) + now_token = yield self.hs.get_event_sources().get_current_token() + + room_ids = yield self.store.get_rooms_for_user(user_id) # First we check if any devices have changed changed = yield self.store.get_user_whose_devices_changed( @@ -245,11 +301,30 @@ class DeviceHandler(BaseHandler): # Then work out if any users have since joined rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key) + member_events = yield self.store.get_membership_changes_for_user( + user_id, from_token.room_key, now_token.room_key + ) + rooms_changed.update(event.room_id for event in member_events) + stream_ordering = RoomStreamToken.parse_stream_token( - from_token.room_key).stream + from_token.room_key + ).stream possibly_changed = set(changed) + possibly_left = set() for room_id in rooms_changed: + current_state_ids = 
yield self.store.get_current_state_ids(room_id) + + # The user may have left the room + # TODO: Check if they actually did or if we were just invited. + if room_id not in room_ids: + for key, event_id in current_state_ids.iteritems(): + etype, state_key = key + if etype != EventTypes.Member: + continue + possibly_left.add(state_key) + continue + # Fetch the current state at the time. try: event_ids = yield self.store.get_forward_extremeties_for_room( @@ -260,8 +335,6 @@ class DeviceHandler(BaseHandler): # ordering: treat it the same as a new room event_ids = [] - current_state_ids = yield self.state.get_current_state_ids(room_id) - # special-case for an empty prev state: include all members # in the changed list if not event_ids: @@ -272,9 +345,25 @@ class DeviceHandler(BaseHandler): possibly_changed.add(state_key) continue + current_member_id = current_state_ids.get((EventTypes.Member, user_id)) + if not current_member_id: + continue + # mapping from event_id -> state_dict prev_state_ids = yield self.store.get_state_ids_for_events(event_ids) + # Check if we've joined the room? If so we just blindly add all the users to + # the "possibly changed" users. + for state_dict in prev_state_ids.itervalues(): + member_event = state_dict.get((EventTypes.Member, user_id), None) + if not member_event or member_event != current_member_id: + for key, event_id in current_state_ids.iteritems(): + etype, state_key = key + if etype != EventTypes.Member: + continue + possibly_changed.add(state_key) + break + # If there has been any change in membership, include them in the # possibly changed list. We'll check if they are joined below, # and we're not toooo worried about spuriously adding users. @@ -285,94 +374,208 @@ class DeviceHandler(BaseHandler): # check if this member has changed since any of the extremities # at the stream_ordering, and add them to the list if so. 
- for state_dict in prev_state_ids.values(): + for state_dict in prev_state_ids.itervalues(): prev_event_id = state_dict.get(key, None) if not prev_event_id or prev_event_id != event_id: - possibly_changed.add(state_key) + if state_key != user_id: + possibly_changed.add(state_key) break - users_who_share_room = yield self.store.get_users_who_share_room_with_user( - user_id - ) + if possibly_changed or possibly_left: + users_who_share_room = yield self.store.get_users_who_share_room_with_user( + user_id + ) - # Take the intersection of the users whose devices may have changed - # and those that actually still share a room with the user - defer.returnValue(users_who_share_room & possibly_changed) + # Take the intersection of the users whose devices may have changed + # and those that actually still share a room with the user + possibly_joined = possibly_changed & users_who_share_room + possibly_left = (possibly_changed | possibly_left) - users_who_share_room + else: + possibly_joined = [] + possibly_left = [] + + defer.returnValue({ + "changed": list(possibly_joined), + "left": list(possibly_left), + }) - @measure_func("_incoming_device_list_update") @defer.inlineCallbacks - def _incoming_device_list_update(self, origin, edu_content): - user_id = edu_content["user_id"] - device_id = edu_content["device_id"] - stream_id = edu_content["stream_id"] - prev_ids = edu_content.get("prev_id", []) + def on_federation_query_user_devices(self, user_id): + stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) + defer.returnValue({ + "user_id": user_id, + "stream_id": stream_id, + "devices": devices, + }) + + @defer.inlineCallbacks + def user_left_room(self, user, room_id): + user_id = user.to_string() + room_ids = yield self.store.get_rooms_for_user(user_id) + if not room_ids: + # We no longer share rooms with this user, so we'll no longer + # receive device updates. Mark this in DB. 
+ yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id) + + +def _update_device_from_client_ips(device, client_ips): + ip = client_ips.get((device["user_id"], device["device_id"]), {}) + device.update({ + "last_seen_ts": ip.get("last_seen"), + "last_seen_ip": ip.get("ip"), + }) + + +class DeviceListEduUpdater(object): + "Handles incoming device list updates from federation and updates the DB" + + def __init__(self, hs, device_handler): + self.store = hs.get_datastore() + self.federation = hs.get_federation_client() + self.clock = hs.get_clock() + self.device_handler = device_handler + + self._remote_edu_linearizer = Linearizer(name="remote_device_list") + + # user_id -> list of updates waiting to be handled. + self._pending_updates = {} + + # Recently seen stream ids. We don't bother keeping these in the DB, + # but they're useful to have them about to reduce the number of spurious + # resyncs. + self._seen_updates = ExpiringCache( + cache_name="device_update_edu", + clock=self.clock, + max_len=10000, + expiry_ms=30 * 60 * 1000, + iterable=True, + ) + + @defer.inlineCallbacks + def incoming_device_list_update(self, origin, edu_content): + """Called on incoming device list update from federation. Responsible + for parsing the EDU and adding to pending updates list. + """ + + user_id = edu_content.pop("user_id") + device_id = edu_content.pop("device_id") + stream_id = str(edu_content.pop("stream_id")) # They may come as ints + prev_ids = edu_content.pop("prev_id", []) + prev_ids = [str(p) for p in prev_ids] # They may come as ints if get_domain_from_id(user_id) != origin: # TODO: Raise? logger.warning("Got device list update edu for %r from %r", user_id, origin) return - rooms = yield self.store.get_rooms_for_user(user_id) - if not rooms: + room_ids = yield self.store.get_rooms_for_user(user_id) + if not room_ids: # We don't share any rooms with this user. Ignore update, as we # probably won't get any further updates. 
return - with (yield self._remote_edue_linearizer.queue(user_id)): - # If the prev id matches whats in our cache table, then we don't need - # to resync the users device list, otherwise we do. - resync = True - if len(prev_ids) == 1: - extremity = yield self.store.get_device_list_last_stream_id_for_remote( - user_id - ) - logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids) - if str(extremity) == str(prev_ids[0]): - resync = False + self._pending_updates.setdefault(user_id, []).append( + (device_id, stream_id, prev_ids, edu_content) + ) + + yield self._handle_device_updates(user_id) + + @measure_func("_incoming_device_list_update") + @defer.inlineCallbacks + def _handle_device_updates(self, user_id): + "Actually handle pending updates." + + with (yield self._remote_edu_linearizer.queue(user_id)): + pending_updates = self._pending_updates.pop(user_id, []) + if not pending_updates: + # This can happen since we batch updates + return + + # Given a list of updates we check if we need to resync. This + # happens if we've missed updates. + resync = yield self._need_to_do_resync(user_id, pending_updates) if resync: # Fetch all devices for the user. - result = yield self.federation.query_user_devices(origin, user_id) + origin = get_domain_from_id(user_id) + try: + result = yield self.federation.query_user_devices(origin, user_id) + except NotRetryingDestination: + # TODO: Remember that we are now out of sync and try again + # later + logger.warn( + "Failed to handle device list update for %s," + " we're not retrying the remote", + user_id, + ) + # We abort on exceptions rather than accepting the update + # as otherwise synapse will 'forget' that its device list + # is out of date. If we bail then we will retry the resync + # next time we get a device list update for this user_id. + # This makes it more likely that the device lists will + # eventually become consistent. 
+ return + except FederationDeniedError as e: + logger.info(e) + return + except Exception: + # TODO: Remember that we are now out of sync and try again + # later + logger.exception( + "Failed to handle device list update for %s", user_id + ) + return + stream_id = result["stream_id"] devices = result["devices"] yield self.store.update_remote_device_list_cache( user_id, devices, stream_id, ) device_ids = [device["device_id"] for device in devices] - yield self.notify_device_update(user_id, device_ids) + yield self.device_handler.notify_device_update(user_id, device_ids) else: # Simply update the single device, since we know that is the only # change (becuase of the single prev_id matching the current cache) - content = dict(edu_content) - for key in ("user_id", "device_id", "stream_id", "prev_ids"): - content.pop(key, None) - yield self.store.update_remote_device_list_cache_entry( - user_id, device_id, content, stream_id, + for device_id, stream_id, prev_ids, content in pending_updates: + yield self.store.update_remote_device_list_cache_entry( + user_id, device_id, content, stream_id, + ) + + yield self.device_handler.notify_device_update( + user_id, [device_id for device_id, _, _, _ in pending_updates] ) - yield self.notify_device_update(user_id, [device_id]) - @defer.inlineCallbacks - def on_federation_query_user_devices(self, user_id): - stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) - defer.returnValue({ - "user_id": user_id, - "stream_id": stream_id, - "devices": devices, - }) + self._seen_updates.setdefault(user_id, set()).update( + stream_id for _, stream_id, _, _ in pending_updates + ) @defer.inlineCallbacks - def user_left_room(self, user, room_id): - user_id = user.to_string() - rooms = yield self.store.get_rooms_for_user(user_id) - if not rooms: - # We no longer share rooms with this user, so we'll no longer - # receive device updates. Mark this in DB. 
- yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id) + def _need_to_do_resync(self, user_id, updates): + """Given a list of updates for a user figure out if we need to do a full + resync, or whether we have enough data that we can just apply the delta. + """ + seen_updates = self._seen_updates.get(user_id, set()) + extremity = yield self.store.get_device_list_last_stream_id_for_remote( + user_id + ) -def _update_device_from_client_ips(device, client_ips): - ip = client_ips.get((device["user_id"], device["device_id"]), {}) - device.update({ - "last_seen_ts": ip.get("last_seen"), - "last_seen_ip": ip.get("ip"), - }) + stream_id_in_updates = set() # stream_ids in updates list + for _, stream_id, prev_ids, _ in updates: + if not prev_ids: + # We always do a resync if there are no previous IDs + defer.returnValue(True) + + for prev_id in prev_ids: + if prev_id == extremity: + continue + elif prev_id in seen_updates: + continue + elif prev_id in stream_id_in_updates: + continue + else: + defer.returnValue(True) + + stream_id_in_updates.add(stream_id) + + defer.returnValue(False) diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index f7fad15c62..f147a20b73 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -17,7 +17,8 @@ import logging from twisted.internet import defer -from synapse.types import get_domain_from_id +from synapse.api.errors import SynapseError +from synapse.types import get_domain_from_id, UserID from synapse.util.stringutils import random_string @@ -33,10 +34,10 @@ class DeviceMessageHandler(object): """ self.store = hs.get_datastore() self.notifier = hs.get_notifier() - self.is_mine_id = hs.is_mine_id + self.is_mine = hs.is_mine self.federation = hs.get_federation_sender() - hs.get_replication_layer().register_edu_handler( + hs.get_federation_registry().register_edu_handler( "m.direct_to_device", self.on_direct_to_device_edu ) @@ -52,6 +53,12 @@ class 
DeviceMessageHandler(object): message_type = content["type"] message_id = content["message_id"] for user_id, by_device in content["messages"].items(): + # we use UserID.from_string to catch invalid user ids + if not self.is_mine(UserID.from_string(user_id)): + logger.warning("Request for keys for non-local user %s", + user_id) + raise SynapseError(400, "Not a user here") + messages_by_device = { device_id: { "content": message_content, @@ -77,7 +84,8 @@ class DeviceMessageHandler(object): local_messages = {} remote_messages = {} for user_id, by_device in messages.items(): - if self.is_mine_id(user_id): + # we use UserID.from_string to catch invalid user ids + if self.is_mine(UserID.from_string(user_id)): messages_by_device = { device_id: { "content": message_content, diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 1b5317edf5..c5b6e75e03 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -34,12 +34,15 @@ class DirectoryHandler(BaseHandler): self.state = hs.get_state_handler() self.appservice_handler = hs.get_application_service_handler() + self.event_creation_handler = hs.get_event_creation_handler() - self.federation = hs.get_replication_layer() - self.federation.register_query_handler( + self.federation = hs.get_federation_client() + hs.get_federation_registry().register_query_handler( "directory", self.on_directory_query ) + self.spam_checker = hs.get_spam_checker() + @defer.inlineCallbacks def _create_association(self, room_alias, room_id, servers=None, creator=None): # general association creation for both human users and app services @@ -73,6 +76,11 @@ class DirectoryHandler(BaseHandler): # association creation for human users # TODO(erikj): Do user auth. 
+ if not self.spam_checker.user_may_create_room_alias(user_id, room_alias): + raise SynapseError( + 403, "This user is not permitted to create this alias", + ) + can_create = yield self.can_modify_alias( room_alias, user_id=user_id @@ -175,6 +183,7 @@ class DirectoryHandler(BaseHandler): "room_alias": room_alias.to_string(), }, retry_on_dns_fail=False, + ignore_backoff=True, ) except CodeMessageException as e: logging.warn("Error retrieving alias") @@ -241,8 +250,7 @@ class DirectoryHandler(BaseHandler): def send_room_alias_update_event(self, requester, user_id, room_id): aliases = yield self.store.get_aliases_for_room(room_id) - msg_handler = self.hs.get_handlers().message_handler - yield msg_handler.create_and_send_nonmember_event( + yield self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Aliases, @@ -264,8 +272,7 @@ class DirectoryHandler(BaseHandler): if not alias_event or alias_event.content.get("alias", "") != alias_str: return - msg_handler = self.hs.get_handlers().message_handler - yield msg_handler.create_and_send_nonmember_event( + yield self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.CanonicalAlias, @@ -326,6 +333,14 @@ class DirectoryHandler(BaseHandler): room_id (str) visibility (str): "public" or "private" """ + if not self.spam_checker.user_may_publish_room( + requester.user.to_string(), room_id + ): + raise AuthError( + 403, + "This user is not permitted to publish rooms to the room list" + ) + if requester.is_guest: raise AuthError(403, "Guests cannot edit the published room list") diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index e40495d1ab..25aec624af 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except 
in compliance with the License. @@ -13,16 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import ujson as json +import simplejson as json import logging from canonicaljson import encode_canonical_json from twisted.internet import defer -from synapse.api.errors import SynapseError, CodeMessageException -from synapse.types import get_domain_from_id -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred -from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination +from synapse.api.errors import ( + SynapseError, CodeMessageException, FederationDeniedError, +) +from synapse.types import get_domain_from_id, UserID +from synapse.util.logcontext import make_deferred_yieldable, run_in_background +from synapse.util.retryutils import NotRetryingDestination logger = logging.getLogger(__name__) @@ -30,15 +33,15 @@ logger = logging.getLogger(__name__) class E2eKeysHandler(object): def __init__(self, hs): self.store = hs.get_datastore() - self.federation = hs.get_replication_layer() + self.federation = hs.get_federation_client() self.device_handler = hs.get_device_handler() - self.is_mine_id = hs.is_mine_id + self.is_mine = hs.is_mine self.clock = hs.get_clock() # doesn't really work as part of the generic query API, because the # query request requires an object POST, but we abuse the # "query handler" interface. 
- self.federation.register_query_handler( + hs.get_federation_registry().register_query_handler( "client_keys", self.on_federation_query_client_keys ) @@ -70,7 +73,8 @@ class E2eKeysHandler(object): remote_queries = {} for user_id, device_ids in device_keys_query.items(): - if self.is_mine_id(user_id): + # we use UserID.from_string to catch invalid user ids + if self.is_mine(UserID.from_string(user_id)): local_query[user_id] = device_ids else: remote_queries[user_id] = device_ids @@ -121,38 +125,23 @@ class E2eKeysHandler(object): def do_remote_query(destination): destination_query = remote_queries_not_in_cache[destination] try: - limiter = yield get_retry_limiter( - destination, self.clock, self.store + remote_result = yield self.federation.query_client_keys( + destination, + {"device_keys": destination_query}, + timeout=timeout ) - with limiter: - remote_result = yield self.federation.query_client_keys( - destination, - {"device_keys": destination_query}, - timeout=timeout - ) for user_id, keys in remote_result["device_keys"].items(): if user_id in destination_query: results[user_id] = keys - except CodeMessageException as e: - failures[destination] = { - "status": e.code, "message": e.message - } - except NotRetryingDestination as e: - failures[destination] = { - "status": 503, "message": "Not ready for retry", - } except Exception as e: - # include ConnectionRefused and other errors - failures[destination] = { - "status": 503, "message": e.message - } + failures[destination] = _exception_to_failure(e) - yield preserve_context_over_deferred(defer.gatherResults([ - preserve_fn(do_remote_query)(destination) + yield make_deferred_yieldable(defer.gatherResults([ + run_in_background(do_remote_query, destination) for destination in remote_queries_not_in_cache - ])) + ], consumeErrors=True)) defer.returnValue({ "device_keys": results, "failures": failures, @@ -174,7 +163,8 @@ class E2eKeysHandler(object): result_dict = {} for user_id, device_ids in query.items(): - if 
not self.is_mine_id(user_id): + # we use UserID.from_string to catch invalid user ids + if not self.is_mine(UserID.from_string(user_id)): logger.warning("Request for keys for non-local user %s", user_id) raise SynapseError(400, "Not a user here") @@ -217,7 +207,8 @@ class E2eKeysHandler(object): remote_queries = {} for user_id, device_keys in query.get("one_time_keys", {}).items(): - if self.is_mine_id(user_id): + # we use UserID.from_string to catch invalid user ids + if self.is_mine(UserID.from_string(user_id)): for device_id, algorithm in device_keys.items(): local_query.append((user_id, device_id, algorithm)) else: @@ -239,36 +230,31 @@ class E2eKeysHandler(object): def claim_client_keys(destination): device_keys = remote_queries[destination] try: - limiter = yield get_retry_limiter( - destination, self.clock, self.store + remote_result = yield self.federation.claim_client_keys( + destination, + {"one_time_keys": device_keys}, + timeout=timeout ) - with limiter: - remote_result = yield self.federation.claim_client_keys( - destination, - {"one_time_keys": device_keys}, - timeout=timeout - ) - for user_id, keys in remote_result["one_time_keys"].items(): - if user_id in device_keys: - json_result[user_id] = keys - except CodeMessageException as e: - failures[destination] = { - "status": e.code, "message": e.message - } - except NotRetryingDestination as e: - failures[destination] = { - "status": 503, "message": "Not ready for retry", - } + for user_id, keys in remote_result["one_time_keys"].items(): + if user_id in device_keys: + json_result[user_id] = keys except Exception as e: - # include ConnectionRefused and other errors - failures[destination] = { - "status": 503, "message": e.message - } + failures[destination] = _exception_to_failure(e) - yield preserve_context_over_deferred(defer.gatherResults([ - preserve_fn(claim_client_keys)(destination) + yield make_deferred_yieldable(defer.gatherResults([ + run_in_background(claim_client_keys, destination) for 
destination in remote_queries - ])) + ], consumeErrors=True)) + + logger.info( + "Claimed one-time-keys: %s", + ",".join(( + "%s for %s:%s" % (key_id, user_id, device_id) + for user_id, user_keys in json_result.iteritems() + for device_id, device_keys in user_keys.iteritems() + for key_id, _ in device_keys.iteritems() + )), + ) defer.returnValue({ "one_time_keys": json_result, @@ -296,19 +282,8 @@ class E2eKeysHandler(object): one_time_keys = keys.get("one_time_keys", None) if one_time_keys: - logger.info( - "Adding %d one_time_keys for device %r for user %r at %d", - len(one_time_keys), device_id, user_id, time_now - ) - key_list = [] - for key_id, key_json in one_time_keys.items(): - algorithm, key_id = key_id.split(":") - key_list.append(( - algorithm, key_id, encode_canonical_json(key_json) - )) - - yield self.store.add_e2e_one_time_keys( - user_id, device_id, time_now, key_list + yield self._upload_one_time_keys_for_user( + user_id, device_id, time_now, one_time_keys, ) # the device should have been registered already, but it may have been @@ -316,8 +291,88 @@ class E2eKeysHandler(object): # old access_token without an associated device_id. Either way, we # need to double-check the device is registered to avoid ending up with # keys without a corresponding device. 
- self.device_handler.check_device_registered(user_id, device_id) + yield self.device_handler.check_device_registered(user_id, device_id) result = yield self.store.count_e2e_one_time_keys(user_id, device_id) defer.returnValue({"one_time_key_counts": result}) + + @defer.inlineCallbacks + def _upload_one_time_keys_for_user(self, user_id, device_id, time_now, + one_time_keys): + logger.info( + "Adding one_time_keys %r for device %r for user %r at %d", + one_time_keys.keys(), device_id, user_id, time_now, + ) + + # make a list of (alg, id, key) tuples + key_list = [] + for key_id, key_obj in one_time_keys.items(): + algorithm, key_id = key_id.split(":") + key_list.append(( + algorithm, key_id, key_obj + )) + + # First we check if we have already persisted any of the keys. + existing_key_map = yield self.store.get_e2e_one_time_keys( + user_id, device_id, [k_id for _, k_id, _ in key_list] + ) + + new_keys = [] # Keys that we need to insert. (alg, id, json) tuples. + for algorithm, key_id, key in key_list: + ex_json = existing_key_map.get((algorithm, key_id), None) + if ex_json: + if not _one_time_keys_match(ex_json, key): + raise SynapseError( + 400, + ("One time key %s:%s already exists. " + "Old key: %s; new key: %r") % + (algorithm, key_id, ex_json, key) + ) + else: + new_keys.append((algorithm, key_id, encode_canonical_json(key))) + + yield self.store.add_e2e_one_time_keys( + user_id, device_id, time_now, new_keys + ) + + +def _exception_to_failure(e): + if isinstance(e, CodeMessageException): + return { + "status": e.code, "message": e.message, + } + + if isinstance(e, NotRetryingDestination): + return { + "status": 503, "message": "Not ready for retry", + } + + if isinstance(e, FederationDeniedError): + return { + "status": 403, "message": "Federation Denied", + } + + # include ConnectionRefused and other errors + # + # Note that some Exceptions (notably twisted's ResponseFailed etc) don't + # give a string for e.message, which simplejson then fails to serialize. 
+ return { + "status": 503, "message": str(e.message), + } + + +def _one_time_keys_match(old_key_json, new_key): + old_key = json.loads(old_key_json) + + # if either is a string rather than an object, they must match exactly + if not isinstance(old_key, dict) or not isinstance(new_key, dict): + return old_key == new_key + + # otherwise, we strip off the 'signatures' if any, because it's legitimate + # for different upload attempts to have different signatures. + old_key.pop("signatures", None) + new_key_copy = dict(new_key) + new_key_copy.pop("signatures", None) + + return old_key == new_key_copy diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 996bfd0e23..f39233d846 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,24 +15,30 @@ # limitations under the License. 
"""Contains handlers for federation events.""" + +import itertools +import logging +import sys + from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json +import six +from six.moves import http_client +from twisted.internet import defer from unpaddedbase64 import decode_base64 from ._base import BaseHandler from synapse.api.errors import ( AuthError, FederationError, StoreError, CodeMessageException, SynapseError, + FederationDeniedError, ) from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.events.validator import EventValidator -from synapse.util import unwrapFirstError -from synapse.util.logcontext import ( - PreserveLoggingContext, preserve_fn, preserve_context_over_deferred -) +from synapse.util import unwrapFirstError, logcontext from synapse.util.metrics import measure_func from synapse.util.logutils import log_function -from synapse.util.async import run_on_reactor +from synapse.util.async import run_on_reactor, Linearizer from synapse.util.frozenutils import unfreeze from synapse.crypto.event_signing import ( compute_event_signature, add_hashes_and_signatures, @@ -42,13 +49,8 @@ from synapse.events.utils import prune_event from synapse.util.retryutils import NotRetryingDestination -from synapse.push.action_generator import ActionGenerator from synapse.util.distributor import user_joined_room -from twisted.internet import defer - -import itertools -import logging logger = logging.getLogger(__name__) @@ -70,38 +72,268 @@ class FederationHandler(BaseHandler): self.hs = hs self.store = hs.get_datastore() - self.replication_layer = hs.get_replication_layer() + self.replication_layer = hs.get_federation_client() self.state_handler = hs.get_state_handler() self.server_name = hs.hostname self.keyring = hs.get_keyring() - - self.replication_layer.set_handler(self) + self.action_generator = hs.get_action_generator() + self.is_mine_id = hs.is_mine_id + self.pusher_pool = hs.get_pusherpool() + 
self.spam_checker = hs.get_spam_checker() + self.event_creation_handler = hs.get_event_creation_handler() # When joining a room we need to queue any events for that room up self.room_queues = {} + self._room_pdu_linearizer = Linearizer("fed_room_pdu") - @log_function @defer.inlineCallbacks - def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None): - """ Called by the ReplicationLayer when we have a new pdu. We need to - do auth checks and put it through the StateHandler. + @log_function + def on_receive_pdu(self, origin, pdu, get_missing=True): + """ Process a PDU received via a federation /send/ transaction, or + via backfill of missing prev_events + + Args: + origin (str): server which initiated the /send/ transaction. Will + be used to fetch missing events or state. + pdu (FrozenEvent): received PDU + get_missing (bool): True if we should fetch missing prev_events - auth_chain and state are None if we already have the necessary state - and prev_events in the db + Returns (Deferred): completes with None """ - event = pdu - logger.debug("Got event: %s", event.event_id) + # We reprocess pdus when we have seen them only as outliers + existing = yield self.get_persisted_pdu( + origin, pdu.event_id, do_auth=False + ) + + # FIXME: Currently we fetch an event again when we already have it + # if it has been marked as an outlier. + + already_seen = ( + existing and ( + not existing.internal_metadata.is_outlier() + or pdu.internal_metadata.is_outlier() + ) + ) + if already_seen: + logger.debug("Already seen pdu %s", pdu.event_id) + return + + # do some initial sanity-checking of the event. In particular, make + # sure it doesn't have hundreds of prev_events or auth_events, which + # could cause a huge state resolution or cascade of event fetches. 
+ try: + self._sanity_check_event(pdu) + except SynapseError as err: + raise FederationError( + "ERROR", + err.code, + err.msg, + affected=pdu.event_id, + ) # If we are currently in the process of joining this room, then we # queue up events for later processing. - if event.room_id in self.room_queues: - self.room_queues[event.room_id].append((pdu, origin)) + if pdu.room_id in self.room_queues: + logger.info("Ignoring PDU %s for room %s from %s for now; join " + "in progress", pdu.event_id, pdu.room_id, origin) + self.room_queues[pdu.room_id].append((pdu, origin)) return - logger.debug("Processing event: %s", event.event_id) + # If we're no longer in the room just ditch the event entirely. This + # is probably an old server that has come back and thinks we're still + # in the room (or we've been rejoined to the room by a state reset). + # + # If we were never in the room then maybe our database got vaped and + # we should check if we *are* in fact in the room. If we are then we + # can magically rejoin the room. + is_in_room = yield self.auth.check_host_in_room( + pdu.room_id, + self.server_name + ) + if not is_in_room: + was_in_room = yield self.store.was_host_joined( + pdu.room_id, self.server_name, + ) + if was_in_room: + logger.info( + "Ignoring PDU %s for room %s from %s as we've left the room!", + pdu.event_id, pdu.room_id, origin, + ) + return + + state = None + + auth_chain = [] + + fetch_state = False + + # Get missing pdus if necessary. + if not pdu.internal_metadata.is_outlier(): + # We only backfill backwards to the min depth. 
+ min_depth = yield self.get_min_depth_for_context( + pdu.room_id + ) + + logger.debug( + "_handle_new_pdu min_depth for %s: %d", + pdu.room_id, min_depth + ) + + prevs = {e_id for e_id, _ in pdu.prev_events} + seen = yield self.store.have_seen_events(prevs) + + if min_depth and pdu.depth < min_depth: + # This is so that we don't notify the user about this + # message, to work around the fact that some events will + # reference really really old events we really don't want to + # send to the clients. + pdu.internal_metadata.outlier = True + elif min_depth and pdu.depth > min_depth: + if get_missing and prevs - seen: + # If we're missing stuff, ensure we only fetch stuff one + # at a time. + logger.info( + "Acquiring lock for room %r to fetch %d missing events: %r...", + pdu.room_id, len(prevs - seen), list(prevs - seen)[:5], + ) + with (yield self._room_pdu_linearizer.queue(pdu.room_id)): + logger.info( + "Acquired lock for room %r to fetch %d missing events", + pdu.room_id, len(prevs - seen), + ) + + yield self._get_missing_events_for_pdu( + origin, pdu, prevs, min_depth + ) + + # Update the set of things we've seen after trying to + # fetch the missing stuff + seen = yield self.store.have_seen_events(prevs) + + if not prevs - seen: + logger.info( + "Found all missing prev events for %s", pdu.event_id + ) + elif prevs - seen: + logger.info( + "Not fetching %d missing events for room %r,event %s: %r...", + len(prevs - seen), pdu.room_id, pdu.event_id, + list(prevs - seen)[:5], + ) + + if prevs - seen: + logger.info( + "Still missing %d events for room %r: %r...", + len(prevs - seen), pdu.room_id, list(prevs - seen)[:5] + ) + fetch_state = True - logger.debug("Event: %s", event) + if fetch_state: + # We need to get the state at this event, since we haven't + # processed all the prev events. 
+ logger.debug( + "_handle_new_pdu getting state for %s", + pdu.room_id + ) + try: + state, auth_chain = yield self.replication_layer.get_state_for_room( + origin, pdu.room_id, pdu.event_id, + ) + except Exception: + logger.exception("Failed to get state for event: %s", pdu.event_id) + + yield self._process_received_pdu( + origin, + pdu, + state=state, + auth_chain=auth_chain, + ) + + @defer.inlineCallbacks + def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): + """ + Args: + origin (str): Origin of the pdu. Will be called to get the missing events + pdu: received pdu + prevs (set(str)): List of event ids which we are missing + min_depth (int): Minimum depth of events to return. + """ + # We recalculate seen, since it may have changed. + seen = yield self.store.have_seen_events(prevs) + + if not prevs - seen: + return + + latest = yield self.store.get_latest_event_ids_in_room( + pdu.room_id + ) + + # We add the prev events that we have seen to the latest + # list to ensure the remote server doesn't give them to us + latest = set(latest) + latest |= seen + + logger.info( + "Missing %d events for room %r pdu %s: %r...", + len(prevs - seen), pdu.room_id, pdu.event_id, list(prevs - seen)[:5] + ) + + # XXX: we set timeout to 10s to help workaround + # https://github.com/matrix-org/synapse/issues/1733. + # The reason is to avoid holding the linearizer lock + # whilst processing inbound /send transactions, causing + # FDs to stack up and block other inbound transactions + # which empirically can currently take up to 30 minutes. + # + # N.B. this explicitly disables retry attempts. + # + # N.B. this also increases our chances of falling back to + # fetching fresh state for the room if the missing event + # can't be found, which slightly reduces our security. + # it may also increase our DAG extremity count for the room, + # causing additional state resolution? See #1760. + # However, fetching state doesn't hold the linearizer lock + # apparently. 
+ # + # see https://github.com/matrix-org/synapse/pull/1744 + + missing_events = yield self.replication_layer.get_missing_events( + origin, + pdu.room_id, + earliest_events_ids=list(latest), + latest_events=[pdu], + limit=10, + min_depth=min_depth, + timeout=10000, + ) + + logger.info( + "Got %d events: %r...", + len(missing_events), [e.event_id for e in missing_events[:5]] + ) + + # We want to sort these by depth so we process them and + # tell clients about them in order. + missing_events.sort(key=lambda x: x.depth) + + for e in missing_events: + logger.info("Handling found event %s", e.event_id) + yield self.on_receive_pdu( + origin, + e, + get_missing=False + ) + + @log_function + @defer.inlineCallbacks + def _process_received_pdu(self, origin, pdu, state, auth_chain): + """ Called when we have a new pdu. We need to do auth checks and put it + through the StateHandler. + """ + event = pdu + + logger.debug("Processing event: %s", event) # FIXME (erikj): Awful hack to make the case where we are not currently # in the room work @@ -140,9 +372,7 @@ class FederationHandler(BaseHandler): if auth_chain: event_ids |= {e.event_id for e in auth_chain} - seen_ids = set( - (yield self.store.have_events(event_ids)).keys() - ) + seen_ids = yield self.store.have_seen_events(event_ids) if state and auth_chain is not None: # If we have any state or auth_chain given to us by the replication @@ -181,13 +411,6 @@ class FederationHandler(BaseHandler): affected=event.event_id, ) - # if we're receiving valid events from an origin, - # it's probably a good idea to mark it as not in retry-state - # for sending (although this is a bit of a leap) - retry_timings = yield self.store.get_destination_retry_timings(origin) - if retry_timings and retry_timings["retry_last_ts"]: - self.store.set_destination_retry_timings(origin, 0, 0) - room = yield self.store.get_room(event.room_id) if not room: @@ -206,11 +429,10 @@ class FederationHandler(BaseHandler): target_user = 
UserID.from_string(target_user_id) extra_users.append(target_user) - with PreserveLoggingContext(): - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=extra_users - ) + self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, + extra_users=extra_users + ) if event.type == EventTypes.Member: if event.membership == Membership.JOIN: @@ -249,7 +471,7 @@ class FederationHandler(BaseHandler): def check_match(id): try: return server_name == get_domain_from_id(id) - except: + except Exception: return False # Parses mapping `event_id -> (type, state_key) -> state event_id` @@ -287,7 +509,7 @@ class FederationHandler(BaseHandler): continue try: domain = get_domain_from_id(ev.state_key) - except: + except Exception: continue if domain != server_name: @@ -314,9 +536,16 @@ class FederationHandler(BaseHandler): def backfill(self, dest, room_id, limit, extremities): """ Trigger a backfill request to `dest` for the given `room_id` - This will attempt to get more events from the remote. This may return - be successfull and still return no events if the other side has no new - events to offer. + This will attempt to get more events from the remote. If the other side + has no new events to offer, this will return an empty list. + + As the events are received, we check their signatures, and also do some + sanity-checking on them. If any of the backfilled events are invalid, + this method throws a SynapseError. + + TODO: make this more useful to distinguish failures of the remote + server from invalid events (there is probably no point in trying to + re-fetch invalid events from every other HS in the room.) 
""" if dest == self.server_name: raise SynapseError(400, "Can't backfill from self.") @@ -328,6 +557,16 @@ class FederationHandler(BaseHandler): extremities=extremities, ) + # ideally we'd sanity check the events here for excess prev_events etc, + # but it's hard to reject events at this point without completely + # breaking backfill in the same way that it is currently broken by + # events whose signature we cannot verify (#3121). + # + # So for now we accept the events anyway. #3124 tracks this. + # + # for ev in events: + # self._sanity_check_event(ev) + # Don't bother processing events we already have. seen_events = yield self.store.have_events_in_timeline( set(e.event_id for e in events) @@ -398,9 +637,10 @@ class FederationHandler(BaseHandler): missing_auth - failed_to_fetch ) - results = yield preserve_context_over_deferred(defer.gatherResults( + results = yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(self.replication_layer.get_pdu)( + logcontext.run_in_background( + self.replication_layer.get_pdu, [dest], event_id, outlier=True, @@ -420,7 +660,7 @@ class FederationHandler(BaseHandler): failed_to_fetch = missing_auth - set(auth_events) - seen_events = yield self.store.have_events( + seen_events = yield self.store.have_seen_events( set(auth_events.keys()) | set(state_events.keys()) ) @@ -526,7 +766,7 @@ class FederationHandler(BaseHandler): joined_domains[dom] = min(d, old_d) else: joined_domains[dom] = d - except: + except Exception: pass return sorted(joined_domains.items(), key=lambda d: d[1]) @@ -570,6 +810,9 @@ class FederationHandler(BaseHandler): except NotRetryingDestination as e: logger.info(e.message) continue + except FederationDeniedError as e: + logger.info(e) + continue except Exception as e: logger.exception( "Failed to backfill from %s because %s", @@ -592,10 +835,13 @@ class FederationHandler(BaseHandler): event_ids = list(extremities.keys()) logger.debug("calling resolve_state_groups in _maybe_backfill") - 
states = yield preserve_context_over_deferred(defer.gatherResults([ - preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e]) - for e in event_ids - ])) + resolve = logcontext.preserve_fn( + self.state_handler.resolve_state_groups_for_events + ) + states = yield logcontext.make_deferred_yieldable(defer.gatherResults( + [resolve(room_id, [e]) for e in event_ids], + consumeErrors=True, + )) states = dict(zip(event_ids, [s.state for s in states])) state_map = yield self.store.get_events( @@ -624,6 +870,38 @@ class FederationHandler(BaseHandler): defer.returnValue(False) + def _sanity_check_event(self, ev): + """ + Do some early sanity checks of a received event + + In particular, checks it doesn't have an excessive number of + prev_events or auth_events, which could cause a huge state resolution + or cascade of event fetches. + + Args: + ev (synapse.events.EventBase): event to be checked + + Returns: None + + Raises: + SynapseError if the event does not pass muster + """ + if len(ev.prev_events) > 20: + logger.warn("Rejecting event %s which has %i prev_events", + ev.event_id, len(ev.prev_events)) + raise SynapseError( + http_client.BAD_REQUEST, + "Too many prev_events", + ) + + if len(ev.auth_events) > 10: + logger.warn("Rejecting event %s which has %i auth_events", + ev.event_id, len(ev.auth_events)) + raise SynapseError( + http_client.BAD_REQUEST, + "Too many auth_events", + ) + @defer.inlineCallbacks def send_invite(self, target_host, event): """ Sends the invite to the remote server for signing. 
@@ -641,7 +919,11 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def on_event_auth(self, event_id): - auth = yield self.store.get_auth_chain([event_id]) + event = yield self.store.get_event(event_id) + auth = yield self.store.get_auth_chain( + [auth_id for auth_id, _ in event.auth_events], + include_given=True + ) for event in auth: event.signatures.update( @@ -670,8 +952,6 @@ class FederationHandler(BaseHandler): """ logger.debug("Joining %s to %s", joinee, room_id) - yield self.store.clean_room_for_join(room_id) - origin, event = yield self._make_and_verify_event( target_hosts, room_id, @@ -680,7 +960,15 @@ class FederationHandler(BaseHandler): content, ) + # This shouldn't happen, because the RoomMemberHandler has a + # linearizer lock which only allows one operation per user per room + # at a time - so this is just paranoia. + assert (room_id not in self.room_queues) + self.room_queues[room_id] = [] + + yield self.store.clean_room_for_join(room_id) + handled_events = set() try: @@ -714,7 +1002,7 @@ class FederationHandler(BaseHandler): room_creator_user_id="", is_public=False ) - except: + except Exception: # FIXME pass @@ -722,29 +1010,45 @@ class FederationHandler(BaseHandler): origin, auth_chain, state, event ) - with PreserveLoggingContext(): - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=[joinee] - ) + self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, + extra_users=[joinee] + ) logger.debug("Finished joining %s to %s", joinee, room_id) finally: room_queue = self.room_queues[room_id] del self.room_queues[room_id] - for p, origin in room_queue: - if p.event_id in handled_events: - continue + # we don't need to wait for the queued events to be processed - + # it's just a best-effort thing at this point. We do want to do + # them roughly in order, though, otherwise we'll end up making + # lots of requests for missing prev_events which we do actually + # have. 
Hence we fire off the deferred, but don't wait for it. - try: - self.on_receive_pdu(origin, p) - except: - logger.exception("Couldn't handle pdu") + logcontext.run_in_background(self._handle_queued_pdus, room_queue) defer.returnValue(True) @defer.inlineCallbacks + def _handle_queued_pdus(self, room_queue): + """Process PDUs which got queued up while we were busy send_joining. + + Args: + room_queue (list[FrozenEvent, str]): list of PDUs to be processed + and the servers that sent them + """ + for p, origin in room_queue: + try: + logger.info("Processing queued PDU %s which was received " + "while we were joining %s", p.event_id, p.room_id) + yield self.on_receive_pdu(origin, p) + except Exception as e: + logger.warn( + "Error handling queued PDU %s from %s: %s", + p.event_id, origin, e) + + @defer.inlineCallbacks @log_function def on_make_join_request(self, room_id, user_id): """ We've received a /make_join/ request, so we create a partial @@ -762,8 +1066,7 @@ class FederationHandler(BaseHandler): }) try: - message_handler = self.hs.get_handlers().message_handler - event, context = yield message_handler._create_new_client_event( + event, context = yield self.event_creation_handler.create_new_client_event( builder=builder, ) except AuthError as e: @@ -791,9 +1094,19 @@ class FederationHandler(BaseHandler): ) event.internal_metadata.outlier = False - # Send this event on behalf of the origin server since they may not - # have an up to data view of the state of the room at this event so - # will not know which servers to send the event to. + # Send this event on behalf of the origin server. + # + # The reasons we have the destination server rather than the origin + # server send it are slightly mysterious: the origin server should have + # all the neccessary state once it gets the response to the send_join, + # so it could send the event itself if it wanted to. 
It may be that + # doing it this way reduces failure modes, or avoids certain attacks + # where a new server selectively tells a subset of the federation that + # it has joined. + # + # The fact is that, as of the current writing, Synapse doesn't send out + # the join event over federation after joining, and changing it now + # would introduce the danger of backwards-compatibility problems. event.internal_metadata.send_on_behalf_of = origin context, event_stream_id, max_stream_id = yield self._handle_new_event( @@ -812,10 +1125,9 @@ class FederationHandler(BaseHandler): target_user = UserID.from_string(target_user_id) extra_users.append(target_user) - with PreserveLoggingContext(): - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, extra_users=extra_users - ) + self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, extra_users=extra_users + ) if event.type == EventTypes.Member: if event.content["membership"] == Membership.JOIN: @@ -823,9 +1135,7 @@ class FederationHandler(BaseHandler): yield user_joined_room(self.distributor, user, event.room_id) state_ids = context.prev_state_ids.values() - auth_chain = yield self.store.get_auth_chain(set( - [event.event_id] + state_ids - )) + auth_chain = yield self.store.get_auth_chain(state_ids) state = yield self.store.get_events(context.prev_state_ids.values()) @@ -842,6 +1152,34 @@ class FederationHandler(BaseHandler): """ event = pdu + if event.state_key is None: + raise SynapseError(400, "The invite event did not have a state key") + + is_blocked = yield self.store.is_room_blocked(event.room_id) + if is_blocked: + raise SynapseError(403, "This room has been blocked on this server") + + if self.hs.config.block_non_admin_invites: + raise SynapseError(403, "This server does not accept room invites") + + if not self.spam_checker.user_may_invite( + event.sender, event.state_key, event.room_id, + ): + raise SynapseError( + 403, "This user is not permitted to send invites to this 
server/user" + ) + + membership = event.content.get("membership") + if event.type != EventTypes.Member or membership != Membership.INVITE: + raise SynapseError(400, "The event was not an m.room.member invite event") + + sender_domain = get_domain_from_id(event.sender) + if sender_domain != origin: + raise SynapseError(400, "The invite event was not from the server sending it") + + if not self.is_mine_id(event.state_key): + raise SynapseError(400, "The invite event must be for this server") + event.internal_metadata.outlier = True event.internal_metadata.invite_from_remote = True @@ -861,48 +1199,38 @@ class FederationHandler(BaseHandler): ) target_user = UserID.from_string(event.state_key) - with PreserveLoggingContext(): - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=[target_user], - ) + self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, + extra_users=[target_user], + ) defer.returnValue(event) @defer.inlineCallbacks def do_remotely_reject_invite(self, target_hosts, room_id, user_id): - try: - origin, event = yield self._make_and_verify_event( - target_hosts, - room_id, - user_id, - "leave" - ) - signed_event = self._sign_event(event) - except SynapseError: - raise - except CodeMessageException as e: - logger.warn("Failed to reject invite: %s", e) - raise SynapseError(500, "Failed to reject invite") - - # Try the host we successfully got a response to /make_join/ - # request first. + origin, event = yield self._make_and_verify_event( + target_hosts, + room_id, + user_id, + "leave" + ) + # Mark as outlier as we don't have any state for this event; we're not + # even in the room. + event.internal_metadata.outlier = True + event = self._sign_event(event) + + # Try the host that we succesfully called /make_leave/ on first for + # the /send_leave/ request. 
try: target_hosts.remove(origin) target_hosts.insert(0, origin) except ValueError: pass - try: - yield self.replication_layer.send_leave( - target_hosts, - signed_event - ) - except SynapseError: - raise - except CodeMessageException as e: - logger.warn("Failed to reject invite: %s", e) - raise SynapseError(500, "Failed to reject invite") + yield self.replication_layer.send_leave( + target_hosts, + event + ) context = yield self.state_handler.compute_event_context(event) @@ -978,8 +1306,7 @@ class FederationHandler(BaseHandler): "state_key": user_id, }) - message_handler = self.hs.get_handlers().message_handler - event, context = yield message_handler._create_new_client_event( + event, context = yield self.event_creation_handler.create_new_client_event( builder=builder, ) @@ -1023,10 +1350,9 @@ class FederationHandler(BaseHandler): target_user = UserID.from_string(target_user_id) extra_users.append(target_user) - with PreserveLoggingContext(): - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, extra_users=extra_users - ) + self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, extra_users=extra_users + ) defer.returnValue(None) @@ -1061,7 +1387,7 @@ class FederationHandler(BaseHandler): for event in res: # We sign these again because there was a bug where we # incorrectly signed things the first time round - if self.hs.is_mine_id(event.event_id): + if self.is_mine_id(event.event_id): event.signatures.update( compute_event_signature( event, @@ -1096,7 +1422,7 @@ class FederationHandler(BaseHandler): if prev_id != event.event_id: results[(event.type, event.state_key)] = prev_id else: - del results[(event.type, event.state_key)] + results.pop((event.type, event.state_key), None) defer.returnValue(results.values()) else: @@ -1134,7 +1460,7 @@ class FederationHandler(BaseHandler): ) if event: - if self.hs.is_mine_id(event.event_id): + if self.is_mine_id(event.event_id): # FIXME: This is a temporary work around where we 
occasionally # return events slightly differently than when they were # originally signed @@ -1178,23 +1504,33 @@ class FederationHandler(BaseHandler): auth_events=auth_events, ) - if not event.internal_metadata.is_outlier(): - action_generator = ActionGenerator(self.hs) - yield action_generator.handle_push_actions_for_event( - event, context + try: + if not event.internal_metadata.is_outlier() and not backfilled: + yield self.action_generator.handle_push_actions_for_event( + event, context + ) + + event_stream_id, max_stream_id = yield self.store.persist_event( + event, + context=context, + backfilled=backfilled, ) + except: # noqa: E722, as we reraise the exception this is fine. + tp, value, tb = sys.exc_info() - event_stream_id, max_stream_id = yield self.store.persist_event( - event, - context=context, - backfilled=backfilled, - ) + logcontext.run_in_background( + self.store.remove_push_actions_from_staging, + event.event_id, + ) + + six.reraise(tp, value, tb) if not backfilled: # this intentionally does not yield: we don't care about the result # and don't need to wait for it. - preserve_fn(self.hs.get_pusherpool().on_new_notifications)( - event_stream_id, max_stream_id + logcontext.run_in_background( + self.pusher_pool.on_new_notifications, + event_stream_id, max_stream_id, ) defer.returnValue((context, event_stream_id, max_stream_id)) @@ -1206,16 +1542,17 @@ class FederationHandler(BaseHandler): a bunch of outliers, but not a chunk of individual events that depend on each other for state calculations. 
""" - contexts = yield preserve_context_over_deferred(defer.gatherResults( + contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(self._prep_event)( + logcontext.run_in_background( + self._prep_event, origin, ev_info["event"], state=ev_info.get("state"), auth_events=ev_info.get("auth_events"), ) for ev_info in event_infos - ] + ], consumeErrors=True, )) yield self.store.persist_events( @@ -1325,7 +1662,17 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def _prep_event(self, origin, event, state=None, auth_events=None): + """ + + Args: + origin: + event: + state: + auth_events: + Returns: + Deferred, which resolves to synapse.events.snapshot.EventContext + """ context = yield self.state_handler.compute_event_context( event, old_state=state, ) @@ -1362,7 +1709,7 @@ class FederationHandler(BaseHandler): context.rejected = RejectedReason.AUTH_ERROR - if event.type == EventTypes.GuestAccess: + if event.type == EventTypes.GuestAccess and not context.rejected: yield self.maybe_kick_guest_users(event) defer.returnValue(context) @@ -1379,7 +1726,11 @@ class FederationHandler(BaseHandler): pass # Now get the current auth_chain for the event. - local_auth_chain = yield self.store.get_auth_chain([event_id]) + event = yield self.store.get_event(event_id) + local_auth_chain = yield self.store.get_auth_chain( + [auth_id for auth_id, _ in event.auth_events], + include_given=True + ) # TODO: Check if we would now reject event_id. If so we need to tell # everyone. @@ -1427,6 +1778,17 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks @log_function def do_auth(self, origin, event, context, auth_events): + """ + + Args: + origin (str): + event (synapse.events.FrozenEvent): + context (synapse.events.snapshot.EventContext): + auth_events (dict[(str, str)->str]): + + Returns: + defer.Deferred[None] + """ # Check if we have all the auth events. 
current_state = set(e.event_id for e in auth_events.values()) event_auth_events = set(e_id for e_id, _ in event.auth_events) @@ -1437,7 +1799,8 @@ class FederationHandler(BaseHandler): event_key = None if event_auth_events - current_state: - have_events = yield self.store.have_events( + # TODO: can we use store.have_seen_events here instead? + have_events = yield self.store.get_seen_events_with_rejections( event_auth_events - current_state ) else: @@ -1460,12 +1823,12 @@ class FederationHandler(BaseHandler): origin, event.room_id, event.event_id ) - seen_remotes = yield self.store.have_events( + seen_remotes = yield self.store.have_seen_events( [e.event_id for e in remote_auth_chain] ) for e in remote_auth_chain: - if e.event_id in seen_remotes.keys(): + if e.event_id in seen_remotes: continue if e.event_id == event.event_id: @@ -1492,11 +1855,11 @@ class FederationHandler(BaseHandler): except AuthError: pass - have_events = yield self.store.have_events( + have_events = yield self.store.get_seen_events_with_rejections( [e_id for e_id, _ in event.auth_events] ) seen_events = set(have_events.keys()) - except: + except Exception: # FIXME: logger.exception("Failed to get auth chain") @@ -1509,18 +1872,18 @@ class FederationHandler(BaseHandler): # Do auth conflict res. 
logger.info("Different auth: %s", different_auth) - different_events = yield preserve_context_over_deferred(defer.gatherResults( - [ - preserve_fn(self.store.get_event)( + different_events = yield logcontext.make_deferred_yieldable( + defer.gatherResults([ + logcontext.run_in_background( + self.store.get_event, d, allow_none=True, allow_rejected=False, ) for d in different_auth if d in have_events and not have_events[d] - ], - consumeErrors=True - )).addErrback(unwrapFirstError) + ], consumeErrors=True) + ).addErrback(unwrapFirstError) if different_events: local_view = dict(auth_events) @@ -1539,16 +1902,9 @@ class FederationHandler(BaseHandler): current_state = set(e.event_id for e in auth_events.values()) different_auth = event_auth_events - current_state - context.current_state_ids = dict(context.current_state_ids) - context.current_state_ids.update({ - k: a.event_id for k, a in auth_events.items() - if k != event_key - }) - context.prev_state_ids = dict(context.prev_state_ids) - context.prev_state_ids.update({ - k: a.event_id for k, a in auth_events.items() - }) - context.state_group = self.store.get_next_state_group() + yield self._update_context_for_auth_events( + event, context, auth_events, event_key, + ) if different_auth and not event.internal_metadata.is_outlier(): logger.info("Different auth after resolution: %s", different_auth) @@ -1572,7 +1928,9 @@ class FederationHandler(BaseHandler): auth_ids = yield self.auth.compute_auth_events( event, context.prev_state_ids ) - local_auth_chain = yield self.store.get_auth_chain(auth_ids) + local_auth_chain = yield self.store.get_auth_chain( + auth_ids, include_given=True + ) try: # 2. Get remote difference. @@ -1583,13 +1941,13 @@ class FederationHandler(BaseHandler): local_auth_chain, ) - seen_remotes = yield self.store.have_events( + seen_remotes = yield self.store.have_seen_events( [e.event_id for e in result["auth_chain"]] ) # 3. Process any remote auth chain events we haven't seen. 
for ev in result["auth_chain"]: - if ev.event_id in seen_remotes.keys(): + if ev.event_id in seen_remotes: continue if ev.event_id == event.event_id: @@ -1619,23 +1977,16 @@ class FederationHandler(BaseHandler): except AuthError: pass - except: + except Exception: # FIXME: logger.exception("Failed to query auth chain") # 4. Look at rejects and their proofs. # TODO. - context.current_state_ids = dict(context.current_state_ids) - context.current_state_ids.update({ - k: a.event_id for k, a in auth_events.items() - if k != event_key - }) - context.prev_state_ids = dict(context.prev_state_ids) - context.prev_state_ids.update({ - k: a.event_id for k, a in auth_events.items() - }) - context.state_group = self.store.get_next_state_group() + yield self._update_context_for_auth_events( + event, context, auth_events, event_key, + ) try: self.auth.check(event, auth_events=auth_events) @@ -1644,6 +1995,45 @@ class FederationHandler(BaseHandler): raise e @defer.inlineCallbacks + def _update_context_for_auth_events(self, event, context, auth_events, + event_key): + """Update the state_ids in an event context after auth event resolution, + storing the changes as a new state group. + + Args: + event (Event): The event we're handling the context for + + context (synapse.events.snapshot.EventContext): event context + to be updated + + auth_events (dict[(str, str)->str]): Events to update in the event + context. + + event_key ((str, str)): (type, state_key) for the current event. + this will not be included in the current_state in the context. 
+ """ + state_updates = { + k: a.event_id for k, a in auth_events.iteritems() + if k != event_key + } + context.current_state_ids = dict(context.current_state_ids) + context.current_state_ids.update(state_updates) + if context.delta_ids is not None: + context.delta_ids = dict(context.delta_ids) + context.delta_ids.update(state_updates) + context.prev_state_ids = dict(context.prev_state_ids) + context.prev_state_ids.update({ + k: a.event_id for k, a in auth_events.iteritems() + }) + context.state_group = yield self.store.store_state_group( + event.event_id, + event.room_id, + prev_group=context.prev_group, + delta_ids=context.delta_ids, + current_state_ids=context.current_state_ids, + ) + + @defer.inlineCallbacks def construct_auth_difference(self, local_auth, remote_auth): """ Given a local and remote auth chain, find the differences. This assumes that we have already processed all events in remote_auth @@ -1686,7 +2076,7 @@ class FederationHandler(BaseHandler): def get_next(it, opt=None): try: return it.next() - except: + except Exception: return opt current_local = get_next(local_iter) @@ -1811,8 +2201,7 @@ class FederationHandler(BaseHandler): if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)): builder = self.event_builder_factory.new(event_dict) EventValidator().validate_new(builder) - message_handler = self.hs.get_handlers().message_handler - event, context = yield message_handler._create_new_client_event( + event, context = yield self.event_creation_handler.create_new_client_event( builder=builder ) @@ -1827,7 +2216,7 @@ class FederationHandler(BaseHandler): raise e yield self._check_signature(event, context) - member_handler = self.hs.get_handlers().room_member_handler + member_handler = self.hs.get_room_member_handler() yield member_handler.send_membership_event(None, event, context) else: destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id)) @@ -1840,10 +2229,17 @@ class FederationHandler(BaseHandler): 
@defer.inlineCallbacks @log_function def on_exchange_third_party_invite_request(self, origin, room_id, event_dict): + """Handle an exchange_third_party_invite request from a remote server + + The remote server will call this when it wants to turn a 3pid invite + into a normal m.room.member invite. + + Returns: + Deferred: resolves (to None) + """ builder = self.event_builder_factory.new(event_dict) - message_handler = self.hs.get_handlers().message_handler - event, context = yield message_handler._create_new_client_event( + event, context = yield self.event_creation_handler.create_new_client_event( builder=builder, ) @@ -1858,10 +2254,13 @@ class FederationHandler(BaseHandler): raise e yield self._check_signature(event, context) + # XXX we send the invite here, but send_membership_event also sends it, + # so we end up making two requests. I think this is redundant. returned_invite = yield self.send_invite(origin, event) # TODO: Make sure the signatures actually are correct. event.signatures.update(returned_invite.signatures) - member_handler = self.hs.get_handlers().room_member_handler + + member_handler = self.hs.get_room_member_handler() yield member_handler.send_membership_event(None, event, context) @defer.inlineCallbacks @@ -1890,8 +2289,9 @@ class FederationHandler(BaseHandler): builder = self.event_builder_factory.new(event_dict) EventValidator().validate_new(builder) - message_handler = self.hs.get_handlers().message_handler - event, context = yield message_handler._create_new_client_event(builder=builder) + event, context = yield self.event_creation_handler.create_new_client_event( + builder=builder, + ) defer.returnValue((event, context)) @defer.inlineCallbacks diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py new file mode 100644 index 0000000000..977993e7d4 --- /dev/null +++ b/synapse/handlers/groups_local.py @@ -0,0 +1,471 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import SynapseError +from synapse.types import get_domain_from_id + +import logging + +logger = logging.getLogger(__name__) + + +def _create_rerouter(func_name): + """Returns a function that looks at the group id and calls the function + on federation or the local group server if the group is local + """ + def f(self, group_id, *args, **kwargs): + if self.is_mine_id(group_id): + return getattr(self.groups_server_handler, func_name)( + group_id, *args, **kwargs + ) + else: + destination = get_domain_from_id(group_id) + return getattr(self.transport_client, func_name)( + destination, group_id, *args, **kwargs + ) + return f + + +class GroupsLocalHandler(object): + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + self.room_list_handler = hs.get_room_list_handler() + self.groups_server_handler = hs.get_groups_server_handler() + self.transport_client = hs.get_federation_transport_client() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.keyring = hs.get_keyring() + self.is_mine_id = hs.is_mine_id + self.signing_key = hs.config.signing_key[0] + self.server_name = hs.hostname + self.notifier = hs.get_notifier() + self.attestations = hs.get_groups_attestation_signing() + + self.profile_handler = hs.get_profile_handler() + + # Ensure attestations get renewed + hs.get_groups_attestation_renewer() + + # The following functions 
merely route the query to the local groups server + # or federation depending on if the group is local or remote + + get_group_profile = _create_rerouter("get_group_profile") + update_group_profile = _create_rerouter("update_group_profile") + get_rooms_in_group = _create_rerouter("get_rooms_in_group") + + get_invited_users_in_group = _create_rerouter("get_invited_users_in_group") + + add_room_to_group = _create_rerouter("add_room_to_group") + update_room_in_group = _create_rerouter("update_room_in_group") + remove_room_from_group = _create_rerouter("remove_room_from_group") + + update_group_summary_room = _create_rerouter("update_group_summary_room") + delete_group_summary_room = _create_rerouter("delete_group_summary_room") + + update_group_category = _create_rerouter("update_group_category") + delete_group_category = _create_rerouter("delete_group_category") + get_group_category = _create_rerouter("get_group_category") + get_group_categories = _create_rerouter("get_group_categories") + + update_group_summary_user = _create_rerouter("update_group_summary_user") + delete_group_summary_user = _create_rerouter("delete_group_summary_user") + + update_group_role = _create_rerouter("update_group_role") + delete_group_role = _create_rerouter("delete_group_role") + get_group_role = _create_rerouter("get_group_role") + get_group_roles = _create_rerouter("get_group_roles") + + set_group_join_policy = _create_rerouter("set_group_join_policy") + + @defer.inlineCallbacks + def get_group_summary(self, group_id, requester_user_id): + """Get the group summary for a group. + + If the group is remote we check that the users have valid attestations. 
+ """ + if self.is_mine_id(group_id): + res = yield self.groups_server_handler.get_group_summary( + group_id, requester_user_id + ) + else: + res = yield self.transport_client.get_group_summary( + get_domain_from_id(group_id), group_id, requester_user_id, + ) + + group_server_name = get_domain_from_id(group_id) + + # Loop through the users and validate the attestations. + chunk = res["users_section"]["users"] + valid_users = [] + for entry in chunk: + g_user_id = entry["user_id"] + attestation = entry.pop("attestation", {}) + try: + if get_domain_from_id(g_user_id) != group_server_name: + yield self.attestations.verify_attestation( + attestation, + group_id=group_id, + user_id=g_user_id, + server_name=get_domain_from_id(g_user_id), + ) + valid_users.append(entry) + except Exception as e: + logger.info("Failed to verify user is in group: %s", e) + + res["users_section"]["users"] = valid_users + + res["users_section"]["users"].sort(key=lambda e: e.get("order", 0)) + res["rooms_section"]["rooms"].sort(key=lambda e: e.get("order", 0)) + + # Add `is_publicised` flag to indicate whether the user has publicised their + # membership of the group on their profile + result = yield self.store.get_publicised_groups_for_user(requester_user_id) + is_publicised = group_id in result + + res.setdefault("user", {})["is_publicised"] = is_publicised + + defer.returnValue(res) + + @defer.inlineCallbacks + def create_group(self, group_id, user_id, content): + """Create a group + """ + + logger.info("Asking to create group with ID: %r", group_id) + + if self.is_mine_id(group_id): + res = yield self.groups_server_handler.create_group( + group_id, user_id, content + ) + local_attestation = None + remote_attestation = None + else: + local_attestation = self.attestations.create_attestation(group_id, user_id) + content["attestation"] = local_attestation + + content["user_profile"] = yield self.profile_handler.get_profile(user_id) + + res = yield self.transport_client.create_group( + 
get_domain_from_id(group_id), group_id, user_id, content, + ) + + remote_attestation = res["attestation"] + yield self.attestations.verify_attestation( + remote_attestation, + group_id=group_id, + user_id=user_id, + server_name=get_domain_from_id(group_id), + ) + + is_publicised = content.get("publicise", False) + token = yield self.store.register_user_group_membership( + group_id, user_id, + membership="join", + is_admin=True, + local_attestation=local_attestation, + remote_attestation=remote_attestation, + is_publicised=is_publicised, + ) + self.notifier.on_new_event( + "groups_key", token, users=[user_id], + ) + + defer.returnValue(res) + + @defer.inlineCallbacks + def get_users_in_group(self, group_id, requester_user_id): + """Get users in a group + """ + if self.is_mine_id(group_id): + res = yield self.groups_server_handler.get_users_in_group( + group_id, requester_user_id + ) + defer.returnValue(res) + + group_server_name = get_domain_from_id(group_id) + + res = yield self.transport_client.get_users_in_group( + get_domain_from_id(group_id), group_id, requester_user_id, + ) + + chunk = res["chunk"] + valid_entries = [] + for entry in chunk: + g_user_id = entry["user_id"] + attestation = entry.pop("attestation", {}) + try: + if get_domain_from_id(g_user_id) != group_server_name: + yield self.attestations.verify_attestation( + attestation, + group_id=group_id, + user_id=g_user_id, + server_name=get_domain_from_id(g_user_id), + ) + valid_entries.append(entry) + except Exception as e: + logger.info("Failed to verify user is in group: %s", e) + + res["chunk"] = valid_entries + + defer.returnValue(res) + + @defer.inlineCallbacks + def join_group(self, group_id, user_id, content): + """Request to join a group + """ + if self.is_mine_id(group_id): + yield self.groups_server_handler.join_group( + group_id, user_id, content + ) + local_attestation = None + remote_attestation = None + else: + local_attestation = self.attestations.create_attestation(group_id, user_id) + 
content["attestation"] = local_attestation + + res = yield self.transport_client.join_group( + get_domain_from_id(group_id), group_id, user_id, content, + ) + + remote_attestation = res["attestation"] + + yield self.attestations.verify_attestation( + remote_attestation, + group_id=group_id, + user_id=user_id, + server_name=get_domain_from_id(group_id), + ) + + # TODO: Check that the group is public and we're being added publically + is_publicised = content.get("publicise", False) + + token = yield self.store.register_user_group_membership( + group_id, user_id, + membership="join", + is_admin=False, + local_attestation=local_attestation, + remote_attestation=remote_attestation, + is_publicised=is_publicised, + ) + self.notifier.on_new_event( + "groups_key", token, users=[user_id], + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def accept_invite(self, group_id, user_id, content): + """Accept an invite to a group + """ + if self.is_mine_id(group_id): + yield self.groups_server_handler.accept_invite( + group_id, user_id, content + ) + local_attestation = None + remote_attestation = None + else: + local_attestation = self.attestations.create_attestation(group_id, user_id) + content["attestation"] = local_attestation + + res = yield self.transport_client.accept_group_invite( + get_domain_from_id(group_id), group_id, user_id, content, + ) + + remote_attestation = res["attestation"] + + yield self.attestations.verify_attestation( + remote_attestation, + group_id=group_id, + user_id=user_id, + server_name=get_domain_from_id(group_id), + ) + + # TODO: Check that the group is public and we're being added publically + is_publicised = content.get("publicise", False) + + token = yield self.store.register_user_group_membership( + group_id, user_id, + membership="join", + is_admin=False, + local_attestation=local_attestation, + remote_attestation=remote_attestation, + is_publicised=is_publicised, + ) + self.notifier.on_new_event( + "groups_key", token, users=[user_id], 
+ ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def invite(self, group_id, user_id, requester_user_id, config): + """Invite a user to a group + """ + content = { + "requester_user_id": requester_user_id, + "config": config, + } + if self.is_mine_id(group_id): + res = yield self.groups_server_handler.invite_to_group( + group_id, user_id, requester_user_id, content, + ) + else: + res = yield self.transport_client.invite_to_group( + get_domain_from_id(group_id), group_id, user_id, requester_user_id, + content, + ) + + defer.returnValue(res) + + @defer.inlineCallbacks + def on_invite(self, group_id, user_id, content): + """One of our users were invited to a group + """ + # TODO: Support auto join and rejection + + if not self.is_mine_id(user_id): + raise SynapseError(400, "User not on this server") + + local_profile = {} + if "profile" in content: + if "name" in content["profile"]: + local_profile["name"] = content["profile"]["name"] + if "avatar_url" in content["profile"]: + local_profile["avatar_url"] = content["profile"]["avatar_url"] + + token = yield self.store.register_user_group_membership( + group_id, user_id, + membership="invite", + content={"profile": local_profile, "inviter": content["inviter"]}, + ) + self.notifier.on_new_event( + "groups_key", token, users=[user_id], + ) + try: + user_profile = yield self.profile_handler.get_profile(user_id) + except Exception as e: + logger.warn("No profile for user %s: %s", user_id, e) + user_profile = {} + + defer.returnValue({"state": "invite", "user_profile": user_profile}) + + @defer.inlineCallbacks + def remove_user_from_group(self, group_id, user_id, requester_user_id, content): + """Remove a user from a group + """ + if user_id == requester_user_id: + token = yield self.store.register_user_group_membership( + group_id, user_id, + membership="leave", + ) + self.notifier.on_new_event( + "groups_key", token, users=[user_id], + ) + + # TODO: Should probably remember that we tried to leave so that we can + 
# retry if the group server is currently down. + + if self.is_mine_id(group_id): + res = yield self.groups_server_handler.remove_user_from_group( + group_id, user_id, requester_user_id, content, + ) + else: + content["requester_user_id"] = requester_user_id + res = yield self.transport_client.remove_user_from_group( + get_domain_from_id(group_id), group_id, requester_user_id, + user_id, content, + ) + + defer.returnValue(res) + + @defer.inlineCallbacks + def user_removed_from_group(self, group_id, user_id, content): + """One of our users was removed/kicked from a group + """ + # TODO: Check if user in group + token = yield self.store.register_user_group_membership( + group_id, user_id, + membership="leave", + ) + self.notifier.on_new_event( + "groups_key", token, users=[user_id], + ) + + @defer.inlineCallbacks + def get_joined_groups(self, user_id): + group_ids = yield self.store.get_joined_groups(user_id) + defer.returnValue({"groups": group_ids}) + + @defer.inlineCallbacks + def get_publicised_groups_for_user(self, user_id): + if self.hs.is_mine_id(user_id): + result = yield self.store.get_publicised_groups_for_user(user_id) + + # Check AS associated groups for this user - this depends on the + # RegExps in the AS registration file (under `users`) + for app_service in self.store.get_app_services(): + result.extend(app_service.get_groups_for_user(user_id)) + + defer.returnValue({"groups": result}) + else: + bulk_result = yield self.transport_client.bulk_get_publicised_groups( + get_domain_from_id(user_id), [user_id], + ) + result = bulk_result.get("users", {}).get(user_id) + # TODO: Verify attestations + defer.returnValue({"groups": result}) + + @defer.inlineCallbacks + def bulk_get_publicised_groups(self, user_ids, proxy=True): + destinations = {} + local_users = set() + + for user_id in user_ids: + if self.hs.is_mine_id(user_id): + local_users.add(user_id) + else: + destinations.setdefault( + get_domain_from_id(user_id), set() + ).add(user_id) + + if not proxy 
and destinations: + raise SynapseError(400, "Some user_ids are not local") + + results = {} + failed_results = [] + for destination, dest_user_ids in destinations.iteritems(): + try: + r = yield self.transport_client.bulk_get_publicised_groups( + destination, list(dest_user_ids), + ) + results.update(r["users"]) + except Exception: + failed_results.extend(dest_user_ids) + + for uid in local_users: + results[uid] = yield self.store.get_publicised_groups_for_user( + uid + ) + + # Check AS associated groups for this user - this depends on the + # RegExps in the AS registration file (under `users`) + for app_service in self.store.get_app_services(): + results[uid].extend(app_service.get_groups_for_user(uid)) + + defer.returnValue({"users": results}) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 559e5d5a71..91a0898860 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,18 +15,20 @@ # limitations under the License. 
"""Utilities for interacting with Identity Servers""" + +import logging + +import simplejson as json + from twisted.internet import defer from synapse.api.errors import ( - CodeMessageException + MatrixCodeMessageException, CodeMessageException ) from ._base import BaseHandler from synapse.util.async import run_on_reactor from synapse.api.errors import SynapseError, Codes -import json -import logging - logger = logging.getLogger(__name__) @@ -89,6 +92,9 @@ class IdentityHandler(BaseHandler): ), {'sid': creds['sid'], 'client_secret': client_secret} ) + except MatrixCodeMessageException as e: + logger.info("getValidated3pid failed with Matrix error: %r", e) + raise SynapseError(e.code, e.msg, e.errcode) except CodeMessageException as e: data = json.loads(e.msg) @@ -150,7 +156,7 @@ class IdentityHandler(BaseHandler): params.update(kwargs) try: - data = yield self.http_client.post_urlencoded_get_json( + data = yield self.http_client.post_json_get_json( "https://%s%s" % ( id_server, "/_matrix/identity/api/v1/validate/email/requestToken" @@ -158,6 +164,46 @@ class IdentityHandler(BaseHandler): params ) defer.returnValue(data) + except MatrixCodeMessageException as e: + logger.info("Proxied requestToken failed with Matrix error: %r", e) + raise SynapseError(e.code, e.msg, e.errcode) + except CodeMessageException as e: + logger.info("Proxied requestToken failed: %r", e) + raise e + + @defer.inlineCallbacks + def requestMsisdnToken( + self, id_server, country, phone_number, + client_secret, send_attempt, **kwargs + ): + yield run_on_reactor() + + if not self._should_trust_id_server(id_server): + raise SynapseError( + 400, "Untrusted ID server '%s'" % id_server, + Codes.SERVER_NOT_TRUSTED + ) + + params = { + 'country': country, + 'phone_number': phone_number, + 'client_secret': client_secret, + 'send_attempt': send_attempt, + } + params.update(kwargs) + + try: + data = yield self.http_client.post_json_get_json( + "https://%s%s" % ( + id_server, + 
"/_matrix/identity/api/v1/validate/msisdn/requestToken" + ), + params + ) + defer.returnValue(data) + except MatrixCodeMessageException as e: + logger.info("Proxied requestToken failed with Matrix error: %r", e) + raise SynapseError(e.code, e.msg, e.errcode) except CodeMessageException as e: logger.info("Proxied requestToken failed: %r", e) raise e diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index e0ade4c164..cd33a86599 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -19,6 +19,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, Codes from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator +from synapse.handlers.presence import format_user_presence_state from synapse.streams.config import PaginationConfig from synapse.types import ( UserID, StreamToken, @@ -26,7 +27,7 @@ from synapse.types import ( from synapse.util import unwrapFirstError from synapse.util.async import concurrently_execute from synapse.util.caches.snapshot_cache import SnapshotCache -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -162,10 +163,11 @@ class InitialSyncHandler(BaseHandler): lambda states: states[event.event_id] ) - (messages, token), current_state = yield preserve_context_over_deferred( + (messages, token), current_state = yield make_deferred_yieldable( defer.gatherResults( [ - preserve_fn(self.store.get_recent_events_for_room)( + run_in_background( + self.store.get_recent_events_for_room, event.room_id, limit=limit, end_token=room_end_token, @@ -213,7 +215,7 @@ class InitialSyncHandler(BaseHandler): }) d["account_data"] = account_data_events - except: + except Exception: logger.exception("Failed to get snapshot") 
yield concurrently_execute(handle_room, room_list, 10) @@ -225,9 +227,17 @@ class InitialSyncHandler(BaseHandler): "content": content, }) + now = self.clock.time_msec() + ret = { "rooms": rooms_ret, - "presence": presence, + "presence": [ + { + "type": "m.presence", + "content": format_user_presence_state(event, now), + } + for event in presence + ], "account_data": account_data_events, "receipts": receipt, "end": now_token.to_string(), @@ -382,9 +392,10 @@ class InitialSyncHandler(BaseHandler): presence, receipts, (messages, token) = yield defer.gatherResults( [ - preserve_fn(get_presence)(), - preserve_fn(get_receipts)(), - preserve_fn(self.store.get_recent_events_for_room)( + run_in_background(get_presence), + run_in_background(get_receipts), + run_in_background( + self.store.get_recent_events_for_room, room_id, limit=limit, end_token=now_token.room_key, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 7a498af5a2..b793fc4df7 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2017 - 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,31 +13,64 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import logging +import simplejson +import sys -from twisted.internet import defer +from canonicaljson import encode_canonical_json +import six +from twisted.internet import defer, reactor +from twisted.python.failure import Failure -from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import AuthError, Codes, SynapseError, LimitExceededError +from synapse.api.constants import EventTypes, Membership, MAX_DEPTH +from synapse.api.errors import AuthError, Codes, SynapseError from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator -from synapse.push.action_generator import ActionGenerator from synapse.types import ( UserID, RoomAlias, RoomStreamToken, ) from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter -from synapse.util.logcontext import preserve_fn +from synapse.util.logcontext import run_in_background from synapse.util.metrics import measure_func +from synapse.util.frozenutils import frozendict_json_encoder +from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client +from synapse.replication.http.send_event import send_event_to_master from ._base import BaseHandler -from canonicaljson import encode_canonical_json +logger = logging.getLogger(__name__) -import logging -import random -logger = logging.getLogger(__name__) +class PurgeStatus(object): + """Object tracking the status of a purge request + + This class contains information on the progress of a purge request, for + return by get_purge_status. + + Attributes: + status (int): Tracks whether this request has completed. 
One of + STATUS_{ACTIVE,COMPLETE,FAILED} + """ + + STATUS_ACTIVE = 0 + STATUS_COMPLETE = 1 + STATUS_FAILED = 2 + + STATUS_TEXT = { + STATUS_ACTIVE: "active", + STATUS_COMPLETE: "complete", + STATUS_FAILED: "failed", + } + + def __init__(self): + self.status = PurgeStatus.STATUS_ACTIVE + + def asdict(self): + return { + "status": PurgeStatus.STATUS_TEXT[self.status] + } class MessageHandler(BaseHandler): @@ -46,25 +80,89 @@ class MessageHandler(BaseHandler): self.hs = hs self.state = hs.get_state_handler() self.clock = hs.get_clock() - self.validator = EventValidator() self.pagination_lock = ReadWriteLock() + self._purges_in_progress_by_room = set() + # map from purge id to PurgeStatus + self._purges_by_id = {} - # We arbitrarily limit concurrent event creation for a room to 5. - # This is to stop us from diverging history *too* much. - self.limiter = Limiter(max_count=5) + def start_purge_history(self, room_id, topological_ordering, + delete_local_events=False): + """Start off a history purge on a room. + + Args: + room_id (str): The room to purge from + + topological_ordering (int): minimum topo ordering to preserve + delete_local_events (bool): True to delete local events as well as + remote ones + + Returns: + str: unique ID for this purge transaction. + """ + if room_id in self._purges_in_progress_by_room: + raise SynapseError( + 400, + "History purge already in progress for %s" % (room_id, ), + ) + + purge_id = random_string(16) + + # we log the purge_id here so that it can be tied back to the + # request id in the log lines. 
+ logger.info("[purge] starting purge_id %s", purge_id) + + self._purges_by_id[purge_id] = PurgeStatus() + run_in_background( + self._purge_history, + purge_id, room_id, topological_ordering, delete_local_events, + ) + return purge_id @defer.inlineCallbacks - def purge_history(self, room_id, event_id): - event = yield self.store.get_event(event_id) + def _purge_history(self, purge_id, room_id, topological_ordering, + delete_local_events): + """Carry out a history purge on a room. - if event.room_id != room_id: - raise SynapseError(400, "Event is for wrong room.") + Args: + purge_id (str): The id for this purge + room_id (str): The room to purge from + topological_ordering (int): minimum topo ordering to preserve + delete_local_events (bool): True to delete local events as well as + remote ones - depth = event.depth + Returns: + Deferred + """ + self._purges_in_progress_by_room.add(room_id) + try: + with (yield self.pagination_lock.write(room_id)): + yield self.store.purge_history( + room_id, topological_ordering, delete_local_events, + ) + logger.info("[purge] complete") + self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE + except Exception: + logger.error("[purge] failed: %s", Failure().getTraceback().rstrip()) + self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED + finally: + self._purges_in_progress_by_room.discard(room_id) + + # remove the purge from the list 24 hours after it completes + def clear_purge(): + del self._purges_by_id[purge_id] + reactor.callLater(24 * 3600, clear_purge) + + def get_purge_status(self, purge_id): + """Get the current status of an active purge - with (yield self.pagination_lock.write(room_id)): - yield self.store.delete_old_state(room_id, depth) + Args: + purge_id (str): purge_id returned by start_purge_history + + Returns: + PurgeStatus|None + """ + return self._purges_by_id.get(purge_id) @defer.inlineCallbacks def get_messages(self, requester, room_id=None, pagin_config=None, @@ -175,7 +273,167 @@ class 
MessageHandler(BaseHandler): defer.returnValue(chunk) @defer.inlineCallbacks - def create_event(self, event_dict, token_id=None, txn_id=None, prev_event_ids=None): + def get_room_data(self, user_id=None, room_id=None, + event_type=None, state_key="", is_guest=False): + """ Get data from a room. + + Args: + event : The room path event + Returns: + The path data content. + Raises: + SynapseError if something went wrong. + """ + membership, membership_event_id = yield self._check_in_room_or_world_readable( + room_id, user_id + ) + + if membership == Membership.JOIN: + data = yield self.state_handler.get_current_state( + room_id, event_type, state_key + ) + elif membership == Membership.LEAVE: + key = (event_type, state_key) + room_state = yield self.store.get_state_for_events( + [membership_event_id], [key] + ) + data = room_state[membership_event_id].get(key) + + defer.returnValue(data) + + @defer.inlineCallbacks + def _check_in_room_or_world_readable(self, room_id, user_id): + try: + # check_user_was_in_room will return the most recent membership + # event for the user if: + # * The user is a non-guest user, and was ever in the room + # * The user is a guest user, and has joined the room + # else it will throw. + member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + defer.returnValue((member_event.membership, member_event.event_id)) + return + except AuthError: + visibility = yield self.state_handler.get_current_state( + room_id, EventTypes.RoomHistoryVisibility, "" + ) + if ( + visibility and + visibility.content["history_visibility"] == "world_readable" + ): + defer.returnValue((Membership.JOIN, None)) + return + raise AuthError( + 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN + ) + + @defer.inlineCallbacks + def get_state_events(self, user_id, room_id, is_guest=False): + """Retrieve all state events for a given room. If the user is + joined to the room then return the current state. 
If the user has + left the room return the state events from when they left. + + Args: + user_id(str): The user requesting state events. + room_id(str): The room ID to get all state events from. + Returns: + A list of dicts representing state events. [{}, {}, {}] + """ + membership, membership_event_id = yield self._check_in_room_or_world_readable( + room_id, user_id + ) + + if membership == Membership.JOIN: + room_state = yield self.state_handler.get_current_state(room_id) + elif membership == Membership.LEAVE: + room_state = yield self.store.get_state_for_events( + [membership_event_id], None + ) + room_state = room_state[membership_event_id] + + now = self.clock.time_msec() + defer.returnValue( + [serialize_event(c, now) for c in room_state.values()] + ) + + @defer.inlineCallbacks + def get_joined_members(self, requester, room_id): + """Get all the joined members in the room and their profile information. + + If the user has left the room return the state events from when they left. + + Args: + requester(Requester): The user requesting state events. + room_id(str): The room ID to get all state events from. + Returns: + A dict of user_id to profile info + """ + user_id = requester.user.to_string() + if not requester.app_service: + # We check AS auth after fetching the room membership, as it + # requires us to pull out all joined members anyway. + membership, _ = yield self._check_in_room_or_world_readable( + room_id, user_id + ) + if membership != Membership.JOIN: + raise NotImplementedError( + "Getting joined members after leaving is not implemented" + ) + + users_with_profile = yield self.state.get_current_user_in_room(room_id) + + # If this is an AS, double check that they are allowed to see the members. 
+ # This can either be because the AS user is in the room or becuase there + # is a user in the room that the AS is "interested in" + if requester.app_service and user_id not in users_with_profile: + for uid in users_with_profile: + if requester.app_service.is_interested_in_user(uid): + break + else: + # Loop fell through, AS has no interested users in room + raise AuthError(403, "Appservice not in room") + + defer.returnValue({ + user_id: { + "avatar_url": profile.avatar_url, + "display_name": profile.display_name, + } + for user_id, profile in users_with_profile.iteritems() + }) + + +class EventCreationHandler(object): + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.state = hs.get_state_handler() + self.clock = hs.get_clock() + self.validator = EventValidator() + self.profile_handler = hs.get_profile_handler() + self.event_builder_factory = hs.get_event_builder_factory() + self.server_name = hs.hostname + self.ratelimiter = hs.get_ratelimiter() + self.notifier = hs.get_notifier() + self.config = hs.config + + self.http_client = hs.get_simple_http_client() + + # This is only used to get at ratelimit function, and maybe_kick_guest_users + self.base_handler = BaseHandler(hs) + + self.pusher_pool = hs.get_pusherpool() + + # We arbitrarily limit concurrent event creation for a room to 5. + # This is to stop us from diverging history *too* much. + self.limiter = Limiter(max_count=5) + + self.action_generator = hs.get_action_generator() + + self.spam_checker = hs.get_spam_checker() + + @defer.inlineCallbacks + def create_event(self, requester, event_dict, token_id=None, txn_id=None, + prev_events_and_hashes=None): """ Given a dict from a client, create a new event. @@ -185,49 +443,56 @@ class MessageHandler(BaseHandler): Adds display names to Join membership events. 
Args: + requester event_dict (dict): An entire event token_id (str) txn_id (str) - prev_event_ids (list): The prev event ids to use when creating the event + + prev_events_and_hashes (list[(str, dict[str, str], int)]|None): + the forward extremities to use as the prev_events for the + new event. For each event, a tuple of (event_id, hashes, depth) + where *hashes* is a map from algorithm to hash. + + If None, they will be requested from the database. Returns: Tuple of created event (FrozenEvent), Context """ builder = self.event_builder_factory.new(event_dict) - with (yield self.limiter.queue(builder.room_id)): - self.validator.validate_new(builder) - - if builder.type == EventTypes.Member: - membership = builder.content.get("membership", None) - target = UserID.from_string(builder.state_key) - - if membership in {Membership.JOIN, Membership.INVITE}: - # If event doesn't include a display name, add one. - profile = self.hs.get_handlers().profile_handler - content = builder.content - - try: - if "displayname" not in content: - content["displayname"] = yield profile.get_displayname(target) - if "avatar_url" not in content: - content["avatar_url"] = yield profile.get_avatar_url(target) - except Exception as e: - logger.info( - "Failed to get profile information for %r: %s", - target, e - ) + self.validator.validate_new(builder) + + if builder.type == EventTypes.Member: + membership = builder.content.get("membership", None) + target = UserID.from_string(builder.state_key) + + if membership in {Membership.JOIN, Membership.INVITE}: + # If event doesn't include a display name, add one. 
+ profile = self.profile_handler + content = builder.content + + try: + if "displayname" not in content: + content["displayname"] = yield profile.get_displayname(target) + if "avatar_url" not in content: + content["avatar_url"] = yield profile.get_avatar_url(target) + except Exception as e: + logger.info( + "Failed to get profile information for %r: %s", + target, e + ) - if token_id is not None: - builder.internal_metadata.token_id = token_id + if token_id is not None: + builder.internal_metadata.token_id = token_id - if txn_id is not None: - builder.internal_metadata.txn_id = txn_id + if txn_id is not None: + builder.internal_metadata.txn_id = txn_id - event, context = yield self._create_new_client_event( - builder=builder, - prev_event_ids=prev_event_ids, - ) + event, context = yield self.create_new_client_event( + builder=builder, + requester=requester, + prev_events_and_hashes=prev_events_and_hashes, + ) defer.returnValue((event, context)) @@ -248,21 +513,6 @@ class MessageHandler(BaseHandler): "Tried to send member event through non-member codepath" ) - # We check here if we are currently being rate limited, so that we - # don't do unnecessary work. We check again just before we actually - # send the event. - time_now = self.clock.time() - allowed, time_allowed = self.ratelimiter.send_message( - event.sender, time_now, - msg_rate_hz=self.hs.config.rc_messages_per_second, - burst_count=self.hs.config.rc_message_burst_count, - update=False, - ) - if not allowed: - raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)), - ) - user = UserID.from_string(event.sender) assert self.hs.is_mine(user), "User must be our own: %s" % (user,) @@ -279,12 +529,6 @@ class MessageHandler(BaseHandler): ratelimit=ratelimit, ) - if event.type == EventTypes.Message: - presence = self.hs.get_presence_handler() - # We don't want to block sending messages on any presence code. This - # matters as sometimes presence code can take a while. 
- preserve_fn(presence.bump_presence_active_time)(user) - @defer.inlineCallbacks def deduplicate_state_event(self, event, context): """ @@ -318,144 +562,87 @@ class MessageHandler(BaseHandler): See self.create_event and self.send_nonmember_event. """ - event, context = yield self.create_event( - event_dict, - token_id=requester.access_token_id, - txn_id=txn_id - ) - yield self.send_nonmember_event( - requester, - event, - context, - ratelimit=ratelimit, - ) - defer.returnValue(event) - - @defer.inlineCallbacks - def get_room_data(self, user_id=None, room_id=None, - event_type=None, state_key="", is_guest=False): - """ Get data from a room. - - Args: - event : The room path event - Returns: - The path data content. - Raises: - SynapseError if something went wrong. - """ - membership, membership_event_id = yield self._check_in_room_or_world_readable( - room_id, user_id - ) - if membership == Membership.JOIN: - data = yield self.state_handler.get_current_state( - room_id, event_type, state_key + # We limit the number of concurrent event sends in a room so that we + # don't fork the DAG too much. If we don't limit then we can end up in + # a situation where event persistence can't keep up, causing + # extremities to pile up, which in turn leads to state resolution + # taking longer. 
+ with (yield self.limiter.queue(event_dict["room_id"])): + event, context = yield self.create_event( + requester, + event_dict, + token_id=requester.access_token_id, + txn_id=txn_id ) - elif membership == Membership.LEAVE: - key = (event_type, state_key) - room_state = yield self.store.get_state_for_events( - [membership_event_id], [key] - ) - data = room_state[membership_event_id].get(key) - defer.returnValue(data) + spam_error = self.spam_checker.check_event_for_spam(event) + if spam_error: + if not isinstance(spam_error, basestring): + spam_error = "Spam is not permitted here" + raise SynapseError( + 403, spam_error, Codes.FORBIDDEN + ) - @defer.inlineCallbacks - def _check_in_room_or_world_readable(self, room_id, user_id): - try: - # check_user_was_in_room will return the most recent membership - # event for the user if: - # * The user is a non-guest user, and was ever in the room - # * The user is a guest user, and has joined the room - # else it will throw. - member_event = yield self.auth.check_user_was_in_room(room_id, user_id) - defer.returnValue((member_event.membership, member_event.event_id)) - return - except AuthError: - visibility = yield self.state_handler.get_current_state( - room_id, EventTypes.RoomHistoryVisibility, "" - ) - if ( - visibility and - visibility.content["history_visibility"] == "world_readable" - ): - defer.returnValue((Membership.JOIN, None)) - return - raise AuthError( - 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN + yield self.send_nonmember_event( + requester, + event, + context, + ratelimit=ratelimit, ) + defer.returnValue(event) + @measure_func("create_new_client_event") @defer.inlineCallbacks - def get_state_events(self, user_id, room_id, is_guest=False): - """Retrieve all state events for a given room. If the user is - joined to the room then return the current state. If the user has - left the room return the state events from when they left. 
+ def create_new_client_event(self, builder, requester=None, + prev_events_and_hashes=None): + """Create a new event for a local client Args: - user_id(str): The user requesting state events. - room_id(str): The room ID to get all state events from. + builder (EventBuilder): + + requester (synapse.types.Requester|None): + + prev_events_and_hashes (list[(str, dict[str, str], int)]|None): + the forward extremities to use as the prev_events for the + new event. For each event, a tuple of (event_id, hashes, depth) + where *hashes* is a map from algorithm to hash. + + If None, they will be requested from the database. + Returns: - A list of dicts representing state events. [{}, {}, {}] + Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)] """ - membership, membership_event_id = yield self._check_in_room_or_world_readable( - room_id, user_id - ) - if membership == Membership.JOIN: - room_state = yield self.state_handler.get_current_state(room_id) - elif membership == Membership.LEAVE: - room_state = yield self.store.get_state_for_events( - [membership_event_id], None + if prev_events_and_hashes is not None: + assert len(prev_events_and_hashes) <= 10, \ + "Attempting to create an event with %i prev_events" % ( + len(prev_events_and_hashes), ) - room_state = room_state[membership_event_id] - - now = self.clock.time_msec() - defer.returnValue( - [serialize_event(c, now) for c in room_state.values()] - ) - - @measure_func("_create_new_client_event") - @defer.inlineCallbacks - def _create_new_client_event(self, builder, prev_event_ids=None): - if prev_event_ids: - prev_events = yield self.store.add_event_hashes(prev_event_ids) - prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids) - depth = prev_max_depth + 1 else: - latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room( - builder.room_id, - ) - - # We want to limit the max number of prev events we point to in our - # new event - if len(latest_ret) > 10: - # Sort by 
reverse depth, so we point to the most recent. - latest_ret.sort(key=lambda a: -a[2]) - new_latest_ret = latest_ret[:5] - - # We also randomly point to some of the older events, to make - # sure that we don't completely ignore the older events. - if latest_ret[5:]: - sample_size = min(5, len(latest_ret[5:])) - new_latest_ret.extend(random.sample(latest_ret[5:], sample_size)) - latest_ret = new_latest_ret - - if latest_ret: - depth = max([d for _, _, d in latest_ret]) + 1 - else: - depth = 1 + prev_events_and_hashes = \ + yield self.store.get_prev_events_for_room(builder.room_id) + + if prev_events_and_hashes: + depth = max([d for _, _, d in prev_events_and_hashes]) + 1 + # we cap depth of generated events, to ensure that they are not + # rejected by other servers (and so that they can be persisted in + # the db) + depth = min(depth, MAX_DEPTH) + else: + depth = 1 - prev_events = [ - (event_id, prev_hashes) - for event_id, prev_hashes, _ in latest_ret - ] + prev_events = [ + (event_id, prev_hashes) + for event_id, prev_hashes, _ in prev_events_and_hashes + ] builder.prev_events = prev_events builder.depth = depth - state_handler = self.state_handler - - context = yield state_handler.compute_event_context(builder) + context = yield self.state.compute_event_context(builder) + if requester: + context.app_service = requester.app_service if builder.is_state(): builder.prev_state = yield self.store.add_event_hashes( @@ -488,12 +675,21 @@ class MessageHandler(BaseHandler): event, context, ratelimit=True, - extra_users=[] + extra_users=[], ): - # We now need to go and hit out to wherever we need to hit out to. + """Processes a new event. This includes checking auth, persisting it, + notifying users, sending to remote servers, etc. - if ratelimit: - self.ratelimit(requester) + If called from a worker will hit out to the master process for final + processing. 
+ + Args: + requester (Requester) + event (FrozenEvent) + context (EventContext) + ratelimit (bool) + extra_users (list(UserID)): Any extra users to notify about event + """ try: yield self.auth.check_from_context(event, context) @@ -501,7 +697,72 @@ class MessageHandler(BaseHandler): logger.warn("Denying new event %r because %s", event, err) raise err - yield self.maybe_kick_guest_users(event, context) + # Ensure that we can round trip before trying to persist in db + try: + dump = frozendict_json_encoder.encode(event.content) + simplejson.loads(dump) + except Exception: + logger.exception("Failed to encode content: %r", event.content) + raise + + yield self.action_generator.handle_push_actions_for_event( + event, context + ) + + try: + # If we're a worker we need to hit out to the master. + if self.config.worker_app: + yield send_event_to_master( + self.http_client, + host=self.config.worker_replication_host, + port=self.config.worker_replication_http_port, + requester=requester, + event=event, + context=context, + ratelimit=ratelimit, + extra_users=extra_users, + ) + return + + yield self.persist_and_notify_client_event( + requester, + event, + context, + ratelimit=ratelimit, + extra_users=extra_users, + ) + except: # noqa: E722, as we reraise the exception this is fine. + # Ensure that we actually remove the entries in the push actions + # staging area, if we calculated them. + tp, value, tb = sys.exc_info() + + run_in_background( + self.store.remove_push_actions_from_staging, + event.event_id, + ) + + six.reraise(tp, value, tb) + + @defer.inlineCallbacks + def persist_and_notify_client_event( + self, + requester, + event, + context, + ratelimit=True, + extra_users=[], + ): + """Called when we have fully built the event, have already + calculated the push actions for the event, and checked auth. + + This should only be run on master. 
+ """ + assert not self.config.worker_app + + if ratelimit: + yield self.base_handler.ratelimit(requester) + + yield self.base_handler.maybe_kick_guest_users(event, context) if event.type == EventTypes.CanonicalAlias: # Check the alias is acually valid (at this time at least) @@ -531,9 +792,9 @@ class MessageHandler(BaseHandler): state_to_include_ids = [ e_id - for k, e_id in context.current_state_ids.items() + for k, e_id in context.current_state_ids.iteritems() if k[0] in self.hs.config.room_invite_state_types - or k[0] == EventTypes.Member and k[1] == event.sender + or k == (EventTypes.Member, event.sender) ] state_to_include = yield self.store.get_events(state_to_include_ids) @@ -545,7 +806,7 @@ class MessageHandler(BaseHandler): "content": e.content, "sender": e.sender, } - for e in state_to_include.values() + for e in state_to_include.itervalues() ] invitee = UserID.from_string(event.state_key) @@ -594,30 +855,39 @@ class MessageHandler(BaseHandler): "Changing the room create event is forbidden", ) - action_generator = ActionGenerator(self.hs) - yield action_generator.handle_push_actions_for_event( - event, context - ) - (event_stream_id, max_stream_id) = yield self.store.persist_event( event, context=context ) # this intentionally does not yield: we don't care about the result # and don't need to wait for it. - preserve_fn(self.hs.get_pusherpool().on_new_notifications)( + run_in_background( + self.pusher_pool.on_new_notifications, event_stream_id, max_stream_id ) @defer.inlineCallbacks def _notify(): yield run_on_reactor() - yield self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=extra_users - ) + try: + self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, + extra_users=extra_users + ) + except Exception: + logger.exception("Error notifying about new room event") - preserve_fn(_notify)() + run_in_background(_notify) - # If invite, remove room_state from unsigned before sending. 
- event.unsigned.pop("invite_room_state", None) + if event.type == EventTypes.Message: + # We don't want to block sending messages on any presence code. This + # matters as sometimes presence code can take a while. + run_in_background(self._bump_active_time, requester.user) + + @defer.inlineCallbacks + def _bump_active_time(self, user): + try: + presence = self.hs.get_presence_handler() + yield presence.bump_presence_active_time(user) + except Exception: + logger.exception("Error bumping presence active time") diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index da610e430f..585f3e4da2 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -29,7 +29,9 @@ from synapse.api.errors import SynapseError from synapse.api.constants import PresenceState from synapse.storage.presence import UserPresenceState -from synapse.util.logcontext import preserve_fn +from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.async import Linearizer +from synapse.util.logcontext import run_in_background from synapse.util.logutils import log_function from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -91,29 +93,30 @@ class PresenceHandler(object): self.store = hs.get_datastore() self.wheel_timer = WheelTimer() self.notifier = hs.get_notifier() - self.replication = hs.get_replication_layer() self.federation = hs.get_federation_sender() self.state = hs.get_state_handler() - self.replication.register_edu_handler( + federation_registry = hs.get_federation_registry() + + federation_registry.register_edu_handler( "m.presence", self.incoming_presence ) - self.replication.register_edu_handler( + federation_registry.register_edu_handler( "m.presence_invite", lambda origin, content: self.invite_presence( observed_user=UserID.from_string(content["observed_user"]), observer_user=UserID.from_string(content["observer_user"]), ) ) - self.replication.register_edu_handler( + 
federation_registry.register_edu_handler( "m.presence_accept", lambda origin, content: self.accept_presence( observed_user=UserID.from_string(content["observed_user"]), observer_user=UserID.from_string(content["observer_user"]), ) ) - self.replication.register_edu_handler( + federation_registry.register_edu_handler( "m.presence_deny", lambda origin, content: self.deny_presence( observed_user=UserID.from_string(content["observed_user"]), @@ -186,6 +189,7 @@ class PresenceHandler(object): # process_id to millisecond timestamp last updated. self.external_process_to_current_syncs = {} self.external_process_last_updated_ms = {} + self.external_sync_linearizer = Linearizer(name="external_sync_linearizer") # Start a LoopingCall in 30s that fires every 5s. # The initial delay is to allow disconnected clients a chance to @@ -251,6 +255,14 @@ class PresenceHandler(object): logger.info("Finished _persist_unpersisted_changes") @defer.inlineCallbacks + def _update_states_and_catch_exception(self, new_states): + try: + res = yield self._update_states(new_states) + defer.returnValue(res) + except Exception: + logger.exception("Error updating presence") + + @defer.inlineCallbacks def _update_states(self, new_states): """Updates presence of users. Sets the appropriate timeouts. 
Pokes the notifier and federation if and only if the changed presence state @@ -315,11 +327,7 @@ class PresenceHandler(object): if to_federation_ping: federation_presence_out_counter.inc_by(len(to_federation_ping)) - _, _, hosts_to_states = yield self._get_interested_parties( - to_federation_ping.values() - ) - - self._push_to_remotes(hosts_to_states) + self._push_to_remotes(to_federation_ping.values()) def _handle_timeouts(self): """Checks the presence of users that have timed out and updates as @@ -364,8 +372,8 @@ class PresenceHandler(object): now=now, ) - preserve_fn(self._update_states)(changes) - except: + run_in_background(self._update_states_and_catch_exception, changes) + except Exception: logger.exception("Exception in _handle_timeouts loop") @defer.inlineCallbacks @@ -422,20 +430,23 @@ class PresenceHandler(object): @defer.inlineCallbacks def _end(): - if affect_presence: + try: self.user_to_num_current_syncs[user_id] -= 1 prev_state = yield self.current_state_for_user(user_id) yield self._update_states([prev_state.copy_and_replace( last_user_sync_ts=self.clock.time_msec(), )]) + except Exception: + logger.exception("Error updating presence after sync") @contextmanager def _user_syncing(): try: yield finally: - preserve_fn(_end)() + if affect_presence: + run_in_background(_end) defer.returnValue(_user_syncing()) @@ -508,6 +519,73 @@ class PresenceHandler(object): self.external_process_to_current_syncs[process_id] = syncing_user_ids @defer.inlineCallbacks + def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec): + """Update the syncing users for an external process as a delta. + + Args: + process_id (str): An identifier for the process the users are + syncing against. This allows synapse to process updates + as user start and stop syncing against a given process. 
+ user_id (str): The user who has started or stopped syncing + is_syncing (bool): Whether or not the user is now syncing + sync_time_msec(int): Time in ms when the user was last syncing + """ + with (yield self.external_sync_linearizer.queue(process_id)): + prev_state = yield self.current_state_for_user(user_id) + + process_presence = self.external_process_to_current_syncs.setdefault( + process_id, set() + ) + + updates = [] + if is_syncing and user_id not in process_presence: + if prev_state.state == PresenceState.OFFLINE: + updates.append(prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=sync_time_msec, + last_user_sync_ts=sync_time_msec, + )) + else: + updates.append(prev_state.copy_and_replace( + last_user_sync_ts=sync_time_msec, + )) + process_presence.add(user_id) + elif user_id in process_presence: + updates.append(prev_state.copy_and_replace( + last_user_sync_ts=sync_time_msec, + )) + + if not is_syncing: + process_presence.discard(user_id) + + if updates: + yield self._update_states(updates) + + self.external_process_last_updated_ms[process_id] = self.clock.time_msec() + + @defer.inlineCallbacks + def update_external_syncs_clear(self, process_id): + """Marks all users that had been marked as syncing by a given process + as offline. + + Used when the process has stopped/disappeared. + """ + with (yield self.external_sync_linearizer.queue(process_id)): + process_presence = self.external_process_to_current_syncs.pop( + process_id, set() + ) + prev_states = yield self.current_state_for_users(process_presence) + time_now_ms = self.clock.time_msec() + + yield self._update_states([ + prev_state.copy_and_replace( + last_user_sync_ts=time_now_ms, + ) + for prev_state in prev_states.itervalues() + ]) + self.external_process_last_updated_ms.pop(process_id, None) + + @defer.inlineCallbacks def current_state_for_user(self, user_id): """Get the current presence state for a user. 
""" @@ -526,14 +604,14 @@ class PresenceHandler(object): for user_id in user_ids } - missing = [user_id for user_id, state in states.items() if not state] + missing = [user_id for user_id, state in states.iteritems() if not state] if missing: # There are things not in our in memory cache. Lets pull them out of # the database. res = yield self.store.get_presence_for_users(missing) states.update(res) - missing = [user_id for user_id, state in states.items() if not state] + missing = [user_id for user_id, state in states.iteritems() if not state] if missing: new = { user_id: UserPresenceState.default(user_id) @@ -545,89 +623,39 @@ class PresenceHandler(object): defer.returnValue(states) @defer.inlineCallbacks - def _get_interested_parties(self, states, calculate_remote_hosts=True): - """Given a list of states return which entities (rooms, users, servers) - are interested in the given states. - - Returns: - 3-tuple: `(room_ids_to_states, users_to_states, hosts_to_states)`, - with each item being a dict of `entity_name` -> `[UserPresenceState]` - """ - room_ids_to_states = {} - users_to_states = {} - for state in states: - events = yield self.store.get_rooms_for_user(state.user_id) - for e in events: - room_ids_to_states.setdefault(e.room_id, []).append(state) - - plist = yield self.store.get_presence_list_observers_accepted(state.user_id) - for u in plist: - users_to_states.setdefault(u, []).append(state) - - # Always notify self - users_to_states.setdefault(state.user_id, []).append(state) - - hosts_to_states = {} - if calculate_remote_hosts: - for room_id, states in room_ids_to_states.items(): - local_states = filter(lambda s: self.is_mine_id(s.user_id), states) - if not local_states: - continue - - users = yield self.store.get_users_in_room(room_id) - hosts = set(get_domain_from_id(u) for u in users) - - for host in hosts: - hosts_to_states.setdefault(host, []).extend(local_states) - - for user_id, states in users_to_states.items(): - local_states = filter(lambda s: 
self.is_mine_id(s.user_id), states) - if not local_states: - continue - - host = get_domain_from_id(user_id) - hosts_to_states.setdefault(host, []).extend(local_states) - - # TODO: de-dup hosts_to_states, as a single host might have multiple - # of same presence - - defer.returnValue((room_ids_to_states, users_to_states, hosts_to_states)) - - @defer.inlineCallbacks def _persist_and_notify(self, states): """Persist states in the database, poke the notifier and send to interested remote servers """ stream_id, max_token = yield self.store.update_presence(states) - parties = yield self._get_interested_parties(states) - room_ids_to_states, users_to_states, hosts_to_states = parties + parties = yield get_interested_parties(self.store, states) + room_ids_to_states, users_to_states = parties self.notifier.on_new_event( "presence_key", stream_id, rooms=room_ids_to_states.keys(), - users=[UserID.from_string(u) for u in users_to_states.keys()] + users=[UserID.from_string(u) for u in users_to_states] ) - self._push_to_remotes(hosts_to_states) + self._push_to_remotes(states) @defer.inlineCallbacks def notify_for_states(self, state, stream_id): - parties = yield self._get_interested_parties([state]) - room_ids_to_states, users_to_states, hosts_to_states = parties + parties = yield get_interested_parties(self.store, [state]) + room_ids_to_states, users_to_states = parties self.notifier.on_new_event( "presence_key", stream_id, rooms=room_ids_to_states.keys(), - users=[UserID.from_string(u) for u in users_to_states.keys()] + users=[UserID.from_string(u) for u in users_to_states] ) - def _push_to_remotes(self, hosts_to_states): + def _push_to_remotes(self, states): """Sends state updates to remote servers. 
Args: - hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]` + states (list(UserPresenceState)) """ - for host, states in hosts_to_states.items(): - self.federation.send_presence(host, states) + self.federation.send_presence(states) @defer.inlineCallbacks def incoming_presence(self, origin, content): @@ -719,9 +747,7 @@ class PresenceHandler(object): for state in updates ]) else: - defer.returnValue([ - format_user_presence_state(state, now) for state in updates - ]) + defer.returnValue(updates) @defer.inlineCallbacks def set_state(self, target_user, state, ignore_status_msg=False): @@ -766,18 +792,17 @@ class PresenceHandler(object): # don't need to send to local clients here, as that is done as part # of the event stream/sync. # TODO: Only send to servers not already in the room. - user_ids = yield self.store.get_users_in_room(room_id) if self.is_mine(user): state = yield self.current_state_for_user(user.to_string()) - hosts = set(get_domain_from_id(u) for u in user_ids) - self._push_to_remotes({host: (state,) for host in hosts}) + self._push_to_remotes([state]) else: + user_ids = yield self.store.get_users_in_room(room_id) user_ids = filter(self.is_mine_id, user_ids) states = yield self.current_state_for_users(user_ids) - self._push_to_remotes({user.domain: states.values()}) + self._push_to_remotes(states.values()) @defer.inlineCallbacks def get_presence_list(self, observer_user, accepted=None): @@ -795,6 +820,9 @@ class PresenceHandler(object): as_event=False, ) + now = self.clock.time_msec() + results[:] = [format_user_presence_state(r, now) for r in results] + is_accepted = { row["observed_user_id"]: row["accepted"] for row in presence_list } @@ -847,6 +875,7 @@ class PresenceHandler(object): ) state_dict = yield self.get_state(observed_user, as_event=False) + state_dict = format_user_presence_state(state_dict, self.clock.time_msec()) self.federation.send_edu( destination=observer_user.domain, @@ -910,11 +939,12 @@ class 
PresenceHandler(object): def is_visible(self, observed_user, observer_user): """Returns whether a user can see another user's presence. """ - observer_rooms = yield self.store.get_rooms_for_user(observer_user.to_string()) - observed_rooms = yield self.store.get_rooms_for_user(observed_user.to_string()) - - observer_room_ids = set(r.room_id for r in observer_rooms) - observed_room_ids = set(r.room_id for r in observed_rooms) + observer_room_ids = yield self.store.get_rooms_for_user( + observer_user.to_string() + ) + observed_room_ids = yield self.store.get_rooms_for_user( + observed_user.to_string() + ) if observer_room_ids & observed_room_ids: defer.returnValue(True) @@ -979,14 +1009,18 @@ def should_notify(old_state, new_state): return False -def format_user_presence_state(state, now): +def format_user_presence_state(state, now, include_user_id=True): """Convert UserPresenceState to a format that can be sent down to clients and to other servers. + + The "user_id" is optional so that this function can be used to format presence + updates for client /sync responses and for federation /send requests. """ content = { "presence": state.state, - "user_id": state.user_id, } + if include_user_id: + content["user_id"] = state.user_id if state.last_active_ts: content["last_active_ago"] = now - state.last_active_ts if state.status_msg and state.state != PresenceState.OFFLINE: @@ -1025,7 +1059,6 @@ class PresenceEventSource(object): # sending down the rare duplicate is not a concern. 
with Measure(self.clock, "presence.get_new_events"): - user_id = user.to_string() if from_key is not None: from_key = int(from_key) @@ -1034,18 +1067,7 @@ class PresenceEventSource(object): max_token = self.store.get_current_presence_token() - plist = yield self.store.get_presence_list_accepted(user.localpart) - users_interested_in = set(row["observed_user_id"] for row in plist) - users_interested_in.add(user_id) # So that we receive our own presence - - users_who_share_room = yield self.store.get_users_who_share_room_with_user( - user_id - ) - users_interested_in.update(users_who_share_room) - - if explicit_room_id: - user_ids = yield self.store.get_users_in_room(explicit_room_id) - users_interested_in.update(user_ids) + users_interested_in = yield self._get_interested_in(user, explicit_room_id) user_ids_changed = set() changed = None @@ -1073,16 +1095,13 @@ class PresenceEventSource(object): updates = yield presence.current_state_for_users(user_ids_changed) - now = self.clock.time_msec() - - defer.returnValue(([ - { - "type": "m.presence", - "content": format_user_presence_state(s, now), - } - for s in updates.values() - if include_offline or s.state != PresenceState.OFFLINE - ], max_token)) + if include_offline: + defer.returnValue((updates.values(), max_token)) + else: + defer.returnValue(([ + s for s in updates.itervalues() + if s.state != PresenceState.OFFLINE + ], max_token)) def get_current_key(self): return self.store.get_current_presence_token() @@ -1090,6 +1109,31 @@ class PresenceEventSource(object): def get_pagination_rows(self, user, pagination_config, key): return self.get_new_events(user, from_key=None, include_offline=False) + @cachedInlineCallbacks(num_args=2, cache_context=True) + def _get_interested_in(self, user, explicit_room_id, cache_context): + """Returns the set of users that the given user should see presence + updates for + """ + user_id = user.to_string() + plist = yield self.store.get_presence_list_accepted( + user.localpart, 
on_invalidate=cache_context.invalidate, + ) + users_interested_in = set(row["observed_user_id"] for row in plist) + users_interested_in.add(user_id) # So that we receive our own presence + + users_who_share_room = yield self.store.get_users_who_share_room_with_user( + user_id, on_invalidate=cache_context.invalidate, + ) + users_interested_in.update(users_who_share_room) + + if explicit_room_id: + user_ids = yield self.store.get_users_in_room( + explicit_room_id, on_invalidate=cache_context.invalidate, + ) + users_interested_in.update(user_ids) + + defer.returnValue(users_interested_in) + def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now): """Checks the presence of users that have timed out and updates as @@ -1157,14 +1201,17 @@ def handle_timeout(state, is_mine, syncing_user_ids, now): # If there are have been no sync for a while (and none ongoing), # set presence to offline if user_id not in syncing_user_ids: - if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT: + # If the user has done something recently but hasn't synced, + # don't set them as offline. + sync_or_active = max(state.last_user_sync_ts, state.last_active_ts) + if now - sync_or_active > SYNC_ONLINE_TIMEOUT: state = state.copy_and_replace( state=PresenceState.OFFLINE, status_msg=None, ) changed = True else: - # We expect to be poked occaisonally by the other side. + # We expect to be poked occasionally by the other side. # This is to protect against forgetful/buggy servers, so that # no one gets stuck online forever. if now - state.last_federation_update_ts > FEDERATION_TIMEOUT: @@ -1255,3 +1302,66 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now): persist_and_notify = True return new_state, persist_and_notify, federation_ping + + +@defer.inlineCallbacks +def get_interested_parties(store, states): + """Given a list of states return which entities (rooms, users) + are interested in the given states. 
+ + Args: + states (list(UserPresenceState)) + + Returns: + 2-tuple: `(room_ids_to_states, users_to_states)`, + with each item being a dict of `entity_name` -> `[UserPresenceState]` + """ + room_ids_to_states = {} + users_to_states = {} + for state in states: + room_ids = yield store.get_rooms_for_user(state.user_id) + for room_id in room_ids: + room_ids_to_states.setdefault(room_id, []).append(state) + + plist = yield store.get_presence_list_observers_accepted(state.user_id) + for u in plist: + users_to_states.setdefault(u, []).append(state) + + # Always notify self + users_to_states.setdefault(state.user_id, []).append(state) + + defer.returnValue((room_ids_to_states, users_to_states)) + + +@defer.inlineCallbacks +def get_interested_remotes(store, states, state_handler): + """Given a list of presence states figure out which remote servers + should be sent which. + + All the presence states should be for local users only. + + Args: + store (DataStore) + states (list(UserPresenceState)) + + Returns: + Deferred list of ([destinations], [UserPresenceState]), where for + each row the list of UserPresenceState should be sent to each + destination + """ + hosts_and_states = [] + + # First we look up the rooms each user is in (as well as any explicit + # subscriptions), then for each distinct room we look up the remote + # hosts in those rooms. 
+ room_ids_to_states, users_to_states = yield get_interested_parties(store, states) + + for room_id, states in room_ids_to_states.iteritems(): + hosts = yield state_handler.get_current_hosts_in_room(room_id) + hosts_and_states.append((hosts, states)) + + for user_id, states in users_to_states.iteritems(): + host = get_domain_from_id(user_id) + hosts_and_states.append(([host], states)) + + defer.returnValue(hosts_and_states) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 87f74dfb8e..3465a787ab 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -17,25 +17,87 @@ import logging from twisted.internet import defer -import synapse.types from synapse.api.errors import SynapseError, AuthError, CodeMessageException -from synapse.types import UserID +from synapse.types import UserID, get_domain_from_id from ._base import BaseHandler - logger = logging.getLogger(__name__) class ProfileHandler(BaseHandler): + PROFILE_UPDATE_MS = 60 * 1000 + PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 def __init__(self, hs): super(ProfileHandler, self).__init__(hs) - self.federation = hs.get_replication_layer() - self.federation.register_query_handler( + self.federation = hs.get_federation_client() + hs.get_federation_registry().register_query_handler( "profile", self.on_profile_query ) + self.user_directory_handler = hs.get_user_directory_handler() + + if hs.config.worker_app is None: + self.clock.looping_call( + self._update_remote_profile_cache, self.PROFILE_UPDATE_MS, + ) + + @defer.inlineCallbacks + def get_profile(self, user_id): + target_user = UserID.from_string(user_id) + if self.hs.is_mine(target_user): + displayname = yield self.store.get_profile_displayname( + target_user.localpart + ) + avatar_url = yield self.store.get_profile_avatar_url( + target_user.localpart + ) + + defer.returnValue({ + "displayname": displayname, + "avatar_url": avatar_url, + }) + else: + try: + result = yield self.federation.make_query( + 
destination=target_user.domain, + query_type="profile", + args={ + "user_id": user_id, + }, + ignore_backoff=True, + ) + defer.returnValue(result) + except CodeMessageException as e: + if e.code != 404: + logger.exception("Failed to get displayname") + + raise + + @defer.inlineCallbacks + def get_profile_from_cache(self, user_id): + """Get the profile information from our local cache. If the user is + ours then the profile information will always be corect. Otherwise, + it may be out of date/missing. + """ + target_user = UserID.from_string(user_id) + if self.hs.is_mine(target_user): + displayname = yield self.store.get_profile_displayname( + target_user.localpart + ) + avatar_url = yield self.store.get_profile_avatar_url( + target_user.localpart + ) + + defer.returnValue({ + "displayname": displayname, + "avatar_url": avatar_url, + }) + else: + profile = yield self.store.get_from_remote_profile_cache(user_id) + defer.returnValue(profile or {}) + @defer.inlineCallbacks def get_displayname(self, target_user): if self.hs.is_mine(target_user): @@ -52,14 +114,15 @@ class ProfileHandler(BaseHandler): args={ "user_id": target_user.to_string(), "field": "displayname", - } + }, + ignore_backoff=True, ) except CodeMessageException as e: if e.code != 404: logger.exception("Failed to get displayname") raise - except: + except Exception: logger.exception("Failed to get displayname") else: defer.returnValue(result["displayname"]) @@ -81,7 +144,13 @@ class ProfileHandler(BaseHandler): target_user.localpart, new_displayname ) - yield self._update_join_states(requester) + if self.hs.config.user_directory_search_all_users: + profile = yield self.store.get_profileinfo(target_user.localpart) + yield self.user_directory_handler.handle_local_profile_change( + target_user.to_string(), profile + ) + + yield self._update_join_states(requester, target_user) @defer.inlineCallbacks def get_avatar_url(self, target_user): @@ -99,13 +168,14 @@ class ProfileHandler(BaseHandler): args={ 
"user_id": target_user.to_string(), "field": "avatar_url", - } + }, + ignore_backoff=True, ) except CodeMessageException as e: if e.code != 404: logger.exception("Failed to get avatar_url") raise - except: + except Exception: logger.exception("Failed to get avatar_url") defer.returnValue(result["avatar_url"]) @@ -124,7 +194,13 @@ class ProfileHandler(BaseHandler): target_user.localpart, new_avatar_url ) - yield self._update_join_states(requester) + if self.hs.config.user_directory_search_all_users: + profile = yield self.store.get_profileinfo(target_user.localpart) + yield self.user_directory_handler.handle_local_profile_change( + target_user.to_string(), profile + ) + + yield self._update_join_states(requester, target_user) @defer.inlineCallbacks def on_profile_query(self, args): @@ -149,34 +225,71 @@ class ProfileHandler(BaseHandler): defer.returnValue(response) @defer.inlineCallbacks - def _update_join_states(self, requester): - user = requester.user - if not self.hs.is_mine(user): + def _update_join_states(self, requester, target_user): + if not self.hs.is_mine(target_user): return - self.ratelimit(requester) + yield self.ratelimit(requester) - joins = yield self.store.get_rooms_for_user( - user.to_string(), + room_ids = yield self.store.get_rooms_for_user( + target_user.to_string(), ) - for j in joins: - handler = self.hs.get_handlers().room_member_handler + for room_id in room_ids: + handler = self.hs.get_room_member_handler() try: - # Assume the user isn't a guest because we don't let guests set - # profile or avatar data. - # XXX why are we recreating `requester` here for each room? - # what was wrong with the `requester` we were passed? - requester = synapse.types.create_requester(user) + # Assume the target_user isn't a guest, + # because we don't let guests set profile or avatar data. yield handler.update_membership( requester, - user, - j.room_id, + target_user, + room_id, "join", # We treat a profile update like a join. 
ratelimit=False, # Try to hide that these events aren't atomic. ) except Exception as e: logger.warn( "Failed to update join event for room %s - %s", - j.room_id, str(e.message) + room_id, str(e.message) ) + + def _update_remote_profile_cache(self): + """Called periodically to check profiles of remote users we haven't + checked in a while. + """ + entries = yield self.store.get_remote_profile_cache_entries_that_expire( + last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS + ) + + for user_id, displayname, avatar_url in entries: + is_subscribed = yield self.store.is_subscribed_remote_profile_for_user( + user_id, + ) + if not is_subscribed: + yield self.store.maybe_delete_remote_profile_cache(user_id) + continue + + try: + profile = yield self.federation.make_query( + destination=get_domain_from_id(user_id), + query_type="profile", + args={ + "user_id": user_id, + }, + ignore_backoff=True, + ) + except Exception: + logger.exception("Failed to get avatar_url") + + yield self.store.update_remote_profile_cache( + user_id, displayname, avatar_url + ) + continue + + new_name = profile.get("displayname") + new_avatar = profile.get("avatar_url") + + # We always hit update to update the last_check timestamp + yield self.store.update_remote_profile_cache( + user_id, new_name, new_avatar + ) diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py new file mode 100644 index 0000000000..5142ae153d --- /dev/null +++ b/synapse/handlers/read_marker.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseHandler + +from twisted.internet import defer + +from synapse.util.async import Linearizer + +import logging +logger = logging.getLogger(__name__) + + +class ReadMarkerHandler(BaseHandler): + def __init__(self, hs): + super(ReadMarkerHandler, self).__init__(hs) + self.server_name = hs.config.server_name + self.store = hs.get_datastore() + self.read_marker_linearizer = Linearizer(name="read_marker") + self.notifier = hs.get_notifier() + + @defer.inlineCallbacks + def received_client_read_marker(self, room_id, user_id, event_id): + """Updates the read marker for a given user in a given room if the event ID given + is ahead in the stream relative to the current read marker. + + This uses a notifier to indicate that account data should be sent down /sync if + the read marker has changed. 
+ """ + + with (yield self.read_marker_linearizer.queue((room_id, user_id))): + existing_read_marker = yield self.store.get_account_data_for_room_and_type( + user_id, room_id, "m.fully_read", + ) + + should_update = True + + if existing_read_marker: + # Only update if the new marker is ahead in the stream + should_update = yield self.store.is_event_after( + event_id, + existing_read_marker['event_id'] + ) + + if should_update: + content = { + "event_id": event_id + } + max_id = yield self.store.add_account_data_to_room( + user_id, room_id, "m.fully_read", content + ) + self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 50aa513935..2e0672161c 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from synapse.util import logcontext from ._base import BaseHandler @@ -34,7 +35,7 @@ class ReceiptsHandler(BaseHandler): self.store = hs.get_datastore() self.hs = hs self.federation = hs.get_federation_sender() - hs.get_replication_layer().register_edu_handler( + hs.get_federation_registry().register_edu_handler( "m.receipt", self._received_remote_receipt ) self.clock = self.hs.get_clock() @@ -59,6 +60,8 @@ class ReceiptsHandler(BaseHandler): is_new = yield self._handle_new_receipts([receipt]) if is_new: + # fire off a process in the background to send the receipt to + # remote servers self._push_remotes([receipt]) @defer.inlineCallbacks @@ -126,42 +129,46 @@ class ReceiptsHandler(BaseHandler): defer.returnValue(True) + @logcontext.preserve_fn # caller should not yield on this @defer.inlineCallbacks def _push_remotes(self, receipts): """Given a list of receipts, works out which remote servers should be poked and pokes them. 
""" - # TODO: Some of this stuff should be coallesced. - for receipt in receipts: - room_id = receipt["room_id"] - receipt_type = receipt["receipt_type"] - user_id = receipt["user_id"] - event_ids = receipt["event_ids"] - data = receipt["data"] - - users = yield self.state.get_current_user_in_room(room_id) - remotedomains = set(get_domain_from_id(u) for u in users) - remotedomains = remotedomains.copy() - remotedomains.discard(self.server_name) - - logger.debug("Sending receipt to: %r", remotedomains) - - for domain in remotedomains: - self.federation.send_edu( - destination=domain, - edu_type="m.receipt", - content={ - room_id: { - receipt_type: { - user_id: { - "event_ids": event_ids, - "data": data, + try: + # TODO: Some of this stuff should be coallesced. + for receipt in receipts: + room_id = receipt["room_id"] + receipt_type = receipt["receipt_type"] + user_id = receipt["user_id"] + event_ids = receipt["event_ids"] + data = receipt["data"] + + users = yield self.state.get_current_user_in_room(room_id) + remotedomains = set(get_domain_from_id(u) for u in users) + remotedomains = remotedomains.copy() + remotedomains.discard(self.server_name) + + logger.debug("Sending receipt to: %r", remotedomains) + + for domain in remotedomains: + self.federation.send_edu( + destination=domain, + edu_type="m.receipt", + content={ + room_id: { + receipt_type: { + user_id: { + "event_ids": event_ids, + "data": data, + } } - } + }, }, - }, - key=(room_id, receipt_type, user_id), - ) + key=(room_id, receipt_type, user_id), + ) + except Exception: + logger.exception("Error pushing receipts to remote servers") @defer.inlineCallbacks def get_receipts_for_room(self, room_id, to_key): @@ -210,10 +217,9 @@ class ReceiptEventSource(object): else: from_key = None - rooms = yield self.store.get_rooms_for_user(user.to_string()) - rooms = [room.room_id for room in rooms] + room_ids = yield self.store.get_rooms_for_user(user.to_string()) events = yield 
self.store.get_linearized_receipts_for_rooms( - rooms, + room_ids, from_key=from_key, to_key=to_key, ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 03c6a85fc6..f83c6b3cf8 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -15,7 +15,6 @@ """Contains functions for registering clients.""" import logging -import urllib from twisted.internet import defer @@ -23,8 +22,10 @@ from synapse.api.errors import ( AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError ) from synapse.http.client import CaptchaServerHttpClient -from synapse.types import UserID -from synapse.util.async import run_on_reactor +from synapse import types +from synapse.types import UserID, create_requester, RoomID, RoomAlias +from synapse.util.async import run_on_reactor, Linearizer +from synapse.util.threepids import check_3pid_allowed from ._base import BaseHandler logger = logging.getLogger(__name__) @@ -36,21 +37,33 @@ class RegistrationHandler(BaseHandler): super(RegistrationHandler, self).__init__(hs) self.auth = hs.get_auth() + self._auth_handler = hs.get_auth_handler() + self.profile_handler = hs.get_profile_handler() + self.user_directory_handler = hs.get_user_directory_handler() self.captcha_client = CaptchaServerHttpClient(hs) self._next_generated_user_id = None self.macaroon_gen = hs.get_macaroon_generator() + self._generate_user_id_linearizer = Linearizer( + name="_generate_user_id_linearizer", + ) + @defer.inlineCallbacks def check_username(self, localpart, guest_access_token=None, assigned_user_id=None): - yield run_on_reactor() + if types.contains_invalid_mxid_characters(localpart): + raise SynapseError( + 400, + "User ID can only contain characters a-z, 0-9, or '=_-./'", + Codes.INVALID_USERNAME + ) - if urllib.quote(localpart.encode('utf-8')) != localpart: + if not localpart: raise SynapseError( 400, - "User ID can only contain characters a-z, 0-9, or '_-./'", + "User ID cannot be empty", 
Codes.INVALID_USERNAME ) @@ -73,7 +86,7 @@ class RegistrationHandler(BaseHandler): "A different user ID has already been registered for this session", ) - yield self.check_user_id_not_appservice_exclusive(user_id) + self.check_user_id_not_appservice_exclusive(user_id) users = yield self.store.get_users_by_id_case_insensitive(user_id) if users: @@ -123,7 +136,7 @@ class RegistrationHandler(BaseHandler): yield run_on_reactor() password_hash = None if password: - password_hash = self.auth_handler().hash(password) + password_hash = yield self.auth_handler().hash(password) if localpart: yield self.check_username(localpart, guest_access_token=guest_access_token) @@ -158,6 +171,13 @@ class RegistrationHandler(BaseHandler): ), admin=admin, ) + + if self.hs.config.user_directory_search_all_users: + profile = yield self.store.get_profileinfo(localpart) + yield self.user_directory_handler.handle_local_profile_change( + user_id, profile + ) + else: # autogen a sequential user ID attempts = 0 @@ -185,10 +205,17 @@ class RegistrationHandler(BaseHandler): token = None attempts += 1 + # auto-join the user to any rooms we're supposed to dump them into + fake_requester = create_requester(user_id) + for r in self.hs.config.auto_join_rooms: + try: + yield self._join_user_to_room(fake_requester, r) + except Exception as e: + logger.error("Failed to join new user to %r: %r", r, e) + # We used to generate default identicons here, but nowadays # we want clients to generate their own as part of their branding # rather than there being consistent matrix-wide ones, so we don't. - defer.returnValue((user_id, token)) @defer.inlineCallbacks @@ -246,11 +273,10 @@ class RegistrationHandler(BaseHandler): """ Registers email_id as SAML2 Based Auth. """ - if urllib.quote(localpart) != localpart: + if types.contains_invalid_mxid_characters(localpart): raise SynapseError( 400, - "User ID must only contain characters which do not" - " require URL encoding." 
+ "User ID can only contain characters a-z, 0-9, or '=_-./'", ) user = UserID(localpart, self.hs.hostname) user_id = user.to_string() @@ -279,12 +305,12 @@ class RegistrationHandler(BaseHandler): """ for c in threepidCreds: - logger.info("validating theeepidcred sid %s on id server %s", + logger.info("validating threepidcred sid %s on id server %s", c['sid'], c['idServer']) try: identity_handler = self.hs.get_handlers().identity_handler threepid = yield identity_handler.threepid_from_creds(c) - except: + except Exception: logger.exception("Couldn't validate 3pid") raise RegistrationError(400, "Couldn't validate 3pid") @@ -293,6 +319,11 @@ class RegistrationHandler(BaseHandler): logger.info("got threepid with medium '%s' and address '%s'", threepid['medium'], threepid['address']) + if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']): + raise RegistrationError( + 403, "Third party identifier is not allowed" + ) + @defer.inlineCallbacks def bind_emails(self, user_id, threepidCreds): """Links emails with a user ID and informs an identity server. 
@@ -325,9 +356,11 @@ class RegistrationHandler(BaseHandler): @defer.inlineCallbacks def _generate_user_id(self, reseed=False): if reseed or self._next_generated_user_id is None: - self._next_generated_user_id = ( - yield self.store.find_next_generated_user_id_localpart() - ) + with (yield self._generate_user_id_linearizer.queue(())): + if reseed or self._next_generated_user_id is None: + self._next_generated_user_id = ( + yield self.store.find_next_generated_user_id_localpart() + ) id = self._next_generated_user_id self._next_generated_user_id += 1 @@ -411,13 +444,12 @@ class RegistrationHandler(BaseHandler): create_profile_with_localpart=user.localpart, ) else: - yield self.store.user_delete_access_tokens(user_id=user_id) + yield self._auth_handler.delete_access_tokens_for_user(user_id) yield self.store.add_access_token_to_user(user_id=user_id, token=token) if displayname is not None: logger.info("setting user display name: %s -> %s", user_id, displayname) - profile_handler = self.hs.get_handlers().profile_handler - yield profile_handler.set_displayname( + yield self.profile_handler.set_displayname( user, requester, displayname, by_admin=True, ) @@ -427,16 +459,59 @@ class RegistrationHandler(BaseHandler): return self.hs.get_auth_handler() @defer.inlineCallbacks - def guest_access_token_for(self, medium, address, inviter_user_id): + def get_or_register_3pid_guest(self, medium, address, inviter_user_id): + """Get a guest access token for a 3PID, creating a guest account if + one doesn't already exist. + + Args: + medium (str) + address (str) + inviter_user_id (str): The user ID who is trying to invite the + 3PID + + Returns: + Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the + 3PID guest account. 
+ """ access_token = yield self.store.get_3pid_guest_access_token(medium, address) if access_token: - defer.returnValue(access_token) + user_info = yield self.auth.get_user_by_access_token( + access_token + ) - _, access_token = yield self.register( + defer.returnValue((user_info["user"].to_string(), access_token)) + + user_id, access_token = yield self.register( generate_token=True, make_guest=True ) access_token = yield self.store.save_or_get_3pid_guest_access_token( medium, address, access_token, inviter_user_id ) - defer.returnValue(access_token) + + defer.returnValue((user_id, access_token)) + + @defer.inlineCallbacks + def _join_user_to_room(self, requester, room_identifier): + room_id = None + room_member_handler = self.hs.get_room_member_handler() + if RoomID.is_valid(room_identifier): + room_id = room_identifier + elif RoomAlias.is_valid(room_identifier): + room_alias = RoomAlias.from_string(room_identifier) + room_id, remote_room_hosts = ( + yield room_member_handler.lookup_room_alias(room_alias) + ) + room_id = room_id.to_string() + else: + raise SynapseError(400, "%s was not legal room ID or room alias" % ( + room_identifier, + )) + + yield room_member_handler.update_membership( + requester=requester, + target=requester.user, + room_id=room_id, + remote_room_hosts=remote_room_hosts, + action="join", + ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 7e7671c9a2..8df8fcbbad 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -60,8 +61,14 @@ class RoomCreationHandler(BaseHandler): }, } + def __init__(self, hs): + super(RoomCreationHandler, self).__init__(hs) + + self.spam_checker = hs.get_spam_checker() + self.event_creation_handler = hs.get_event_creation_handler() + @defer.inlineCallbacks - def create_room(self, requester, config): + def create_room(self, requester, config, ratelimit=True): """ Creates a new room. Args: @@ -75,14 +82,18 @@ class RoomCreationHandler(BaseHandler): """ user_id = requester.user.to_string() - self.ratelimit(requester) + if not self.spam_checker.user_may_create_room(user_id): + raise SynapseError(403, "You are not permitted to create rooms") + + if ratelimit: + yield self.ratelimit(requester) if "room_alias_name" in config: for wchar in string.whitespace: if wchar in config["room_alias_name"]: raise SynapseError(400, "Invalid characters in room alias") - room_alias = RoomAlias.create( + room_alias = RoomAlias( config["room_alias_name"], self.hs.hostname, ) @@ -99,7 +110,7 @@ class RoomCreationHandler(BaseHandler): for i in invite_list: try: UserID.from_string(i) - except: + except Exception: raise SynapseError(400, "Invalid user_id: %s" % (i,)) invite_3pid_list = config.get("invite_3pid", []) @@ -114,7 +125,7 @@ class RoomCreationHandler(BaseHandler): while attempts < 5: try: random_string = stringutils.random_string(18) - gen_room_id = RoomID.create( + gen_room_id = RoomID( random_string, self.hs.hostname, ) @@ -154,24 +165,23 @@ class RoomCreationHandler(BaseHandler): creation_content = config.get("creation_content", {}) - msg_handler = self.hs.get_handlers().message_handler - room_member_handler = self.hs.get_handlers().room_member_handler + room_member_handler = self.hs.get_room_member_handler() yield self._send_events_for_new_room( requester, room_id, - msg_handler, room_member_handler, preset_config=preset_config, invite_list=invite_list, initial_state=initial_state, creation_content=creation_content, room_alias=room_alias, + 
power_level_content_override=config.get("power_level_content_override", {}) ) if "name" in config: name = config["name"] - yield msg_handler.create_and_send_nonmember_event( + yield self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Name, @@ -184,7 +194,7 @@ class RoomCreationHandler(BaseHandler): if "topic" in config: topic = config["topic"] - yield msg_handler.create_and_send_nonmember_event( + yield self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Topic, @@ -195,12 +205,12 @@ class RoomCreationHandler(BaseHandler): }, ratelimit=False) - content = {} - is_direct = config.get("is_direct", None) - if is_direct: - content["is_direct"] = is_direct - for invitee in invite_list: + content = {} + is_direct = config.get("is_direct", None) + if is_direct: + content["is_direct"] = is_direct + yield room_member_handler.update_membership( requester, UserID.from_string(invitee), @@ -214,7 +224,7 @@ class RoomCreationHandler(BaseHandler): id_server = invite_3pid["id_server"] address = invite_3pid["address"] medium = invite_3pid["medium"] - yield self.hs.get_handlers().room_member_handler.do_3pid_invite( + yield self.hs.get_room_member_handler().do_3pid_invite( room_id, requester.user, medium, @@ -239,13 +249,13 @@ class RoomCreationHandler(BaseHandler): self, creator, # A Requester object. 
room_id, - msg_handler, room_member_handler, preset_config, invite_list, initial_state, creation_content, - room_alias + room_alias, + power_level_content_override, ): def create(etype, content, **kwargs): e = { @@ -261,7 +271,7 @@ class RoomCreationHandler(BaseHandler): @defer.inlineCallbacks def send(etype, content, **kwargs): event = create(etype, content, **kwargs) - yield msg_handler.create_and_send_nonmember_event( + yield self.event_creation_handler.create_and_send_nonmember_event( creator, event, ratelimit=False @@ -291,7 +301,15 @@ class RoomCreationHandler(BaseHandler): ratelimit=False, ) - if (EventTypes.PowerLevels, '') not in initial_state: + # We treat the power levels override specially as this needs to be one + # of the first events that get sent into a room. + pl_content = initial_state.pop((EventTypes.PowerLevels, ''), None) + if pl_content is not None: + yield send( + etype=EventTypes.PowerLevels, + content=pl_content, + ) + else: power_level_content = { "users": { creator_id: 100, @@ -316,6 +334,8 @@ class RoomCreationHandler(BaseHandler): for invitee in invite_list: power_level_content["users"][invitee] = 100 + power_level_content.update(power_level_content_override) + yield send( etype=EventTypes.PowerLevels, content=power_level_content, @@ -356,7 +376,7 @@ class RoomCreationHandler(BaseHandler): class RoomContextHandler(BaseHandler): @defer.inlineCallbacks - def get_event_context(self, user, room_id, event_id, limit, is_guest): + def get_event_context(self, user, room_id, event_id, limit): """Retrieves events, pagination tokens and state around a given event in a room. 
@@ -375,12 +395,15 @@ class RoomContextHandler(BaseHandler): now_token = yield self.hs.get_event_sources().get_current_token() + users = yield self.store.get_users_in_room(room_id) + is_peeking = user.to_string() not in users + def filter_evts(events): return filter_events_for_client( self.store, user.to_string(), events, - is_peeking=is_guest + is_peeking=is_peeking ) event = yield self.store.get_event(event_id, get_prev_content=True, @@ -452,12 +475,9 @@ class RoomEventSource(object): user.to_string() ) if app_service: - events, end_key = yield self.store.get_appservice_room_stream( - service=app_service, - from_key=from_key, - to_key=to_key, - limit=limit, - ) + # We no longer support AS users using /sync directly. + # See https://github.com/matrix-org/matrix-doc/issues/1144 + raise NotImplementedError() else: room_events = yield self.store.get_membership_changes_for_user( user.to_string(), from_key, to_key diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 19eebbd43f..5757bb7f8a 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -15,12 +15,15 @@ from twisted.internet import defer +from six.moves import range + from ._base import BaseHandler from synapse.api.constants import ( EventTypes, JoinRules, ) from synapse.util.async import concurrently_execute +from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.response_cache import ResponseCache from synapse.types import ThirdPartyInstanceID @@ -42,8 +45,9 @@ EMTPY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) class RoomListHandler(BaseHandler): def __init__(self, hs): super(RoomListHandler, self).__init__(hs) - self.response_cache = ResponseCache(hs) - self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000) + self.response_cache = ResponseCache(hs, "room_list") + self.remote_response_cache = ResponseCache(hs, "remote_room_list", + timeout_ms=30 * 1000) def get_local_public_room_list(self, limit=None, 
since_token=None, search_filter=None, @@ -62,23 +66,24 @@ class RoomListHandler(BaseHandler): appservice and network id to use an appservice specific one. Setting to None returns all public rooms across all lists. """ + logger.info( + "Getting public room list: limit=%r, since=%r, search=%r, network=%r", + limit, since_token, bool(search_filter), network_tuple, + ) if search_filter: # We explicitly don't bother caching searches or requests for # appservice specific lists. + logger.info("Bypassing cache as search request.") return self._get_public_room_list( limit, since_token, search_filter, network_tuple=network_tuple, ) key = (limit, since_token, network_tuple) - result = self.response_cache.get(key) - if not result: - result = self.response_cache.set( - key, - self._get_public_room_list( - limit, since_token, network_tuple=network_tuple - ) - ) - return result + return self.response_cache.wrap( + key, + self._get_public_room_list, + limit, since_token, network_tuple=network_tuple, + ) @defer.inlineCallbacks def _get_public_room_list(self, limit=None, since_token=None, @@ -91,7 +96,6 @@ class RoomListHandler(BaseHandler): rooms_to_order_value = {} rooms_to_num_joined = {} - rooms_to_latest_event_ids = {} newly_visible = [] newly_unpublished = [] @@ -116,19 +120,26 @@ class RoomListHandler(BaseHandler): @defer.inlineCallbacks def get_order_for_room(room_id): - latest_event_ids = rooms_to_latest_event_ids.get(room_id, None) - if not latest_event_ids: + # Most of the rooms won't have changed between the since token and + # now (especially if the since token is "now"). So, we can ask what + # the current users are in a room (that will hit a cache) and then + # check if the room has changed since the since token. (We have to + # do it in that order to avoid races). + # If things have changed then fall back to getting the current state + # at the since token. 
+ joined_users = yield self.store.get_users_in_room(room_id) + if self.store.has_room_changed_since(room_id, stream_token): latest_event_ids = yield self.store.get_forward_extremeties_for_room( room_id, stream_token ) - rooms_to_latest_event_ids[room_id] = latest_event_ids - if not latest_event_ids: - return + if not latest_event_ids: + return + + joined_users = yield self.state_handler.get_current_user_in_room( + room_id, latest_event_ids, + ) - joined_users = yield self.state_handler.get_current_user_in_room( - room_id, latest_event_ids, - ) num_joined_users = len(joined_users) rooms_to_num_joined[room_id] = num_joined_users @@ -138,6 +149,8 @@ class RoomListHandler(BaseHandler): # We want larger rooms to be first, hence negating num_joined_users rooms_to_order_value[room_id] = (-num_joined_users, room_id) + logger.info("Getting ordering for %i rooms since %s", + len(room_ids), stream_token) yield concurrently_execute(get_order_for_room, room_ids, 10) sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1]) @@ -165,34 +178,43 @@ class RoomListHandler(BaseHandler): rooms_to_scan = rooms_to_scan[:since_token.current_limit] rooms_to_scan.reverse() - # Actually generate the entries. _generate_room_entry will append to - # chunk but will stop if len(chunk) > limit - chunk = [] - if limit and not search_filter: + logger.info("After sorting and filtering, %i rooms remain", + len(rooms_to_scan)) + + # _append_room_entry_to_chunk will append to chunk but will stop if + # len(chunk) > limit + # + # Normally we will generate enough results on the first iteration here, + # but if there is a search filter, _append_room_entry_to_chunk may + # filter some results out, in which case we loop again. + # + # We don't want to scan over the entire range either as that + # would potentially waste a lot of work. + # + # XXX if there is no limit, we may end up DoSing the server with + # calls to get_current_state_ids for every single room on the + # server. 
Surely we should cap this somehow? + # + if limit: step = limit + 1 - for i in xrange(0, len(rooms_to_scan), step): - # We iterate here because the vast majority of cases we'll stop - # at first iteration, but occaisonally _generate_room_entry - # won't append to the chunk and so we need to loop again. - # We don't want to scan over the entire range either as that - # would potentially waste a lot of work. - yield concurrently_execute( - lambda r: self._generate_room_entry( - r, rooms_to_num_joined[r], - chunk, limit, search_filter - ), - rooms_to_scan[i:i + step], 10 - ) - if len(chunk) >= limit + 1: - break else: + # step cannot be zero + step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1 + + chunk = [] + for i in range(0, len(rooms_to_scan), step): + batch = rooms_to_scan[i:i + step] + logger.info("Processing %i rooms for result", len(batch)) yield concurrently_execute( - lambda r: self._generate_room_entry( + lambda r: self._append_room_entry_to_chunk( r, rooms_to_num_joined[r], chunk, limit, search_filter ), - rooms_to_scan, 5 + batch, 5, ) + logger.info("Now %i rooms in result", len(chunk)) + if len(chunk) >= limit + 1: + break chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"])) @@ -256,21 +278,36 @@ class RoomListHandler(BaseHandler): defer.returnValue(results) @defer.inlineCallbacks - def _generate_room_entry(self, room_id, num_joined_users, chunk, limit, - search_filter): + def _append_room_entry_to_chunk(self, room_id, num_joined_users, chunk, limit, + search_filter): + """Generate the entry for a room in the public room list and append it + to the `chunk` if it matches the search filter + """ if limit and len(chunk) > limit + 1: # We've already got enough, so lets just drop it. 
return + result = yield self.generate_room_entry(room_id, num_joined_users) + + if result and _matches_room_entry(result, search_filter): + chunk.append(result) + + @cachedInlineCallbacks(num_args=1, cache_context=True) + def generate_room_entry(self, room_id, num_joined_users, cache_context, + with_alias=True, allow_private=False): + """Returns the entry for a room + """ result = { "room_id": room_id, "num_joined_members": num_joined_users, } - current_state_ids = yield self.state_handler.get_current_state_ids(room_id) + current_state_ids = yield self.store.get_current_state_ids( + room_id, on_invalidate=cache_context.invalidate, + ) event_map = yield self.store.get_events([ - event_id for key, event_id in current_state_ids.items() + event_id for key, event_id in current_state_ids.iteritems() if key[0] in ( EventTypes.JoinRules, EventTypes.Name, @@ -291,12 +328,15 @@ class RoomListHandler(BaseHandler): join_rules_event = current_state.get((EventTypes.JoinRules, "")) if join_rules_event: join_rule = join_rules_event.content.get("join_rule", None) - if join_rule and join_rule != JoinRules.PUBLIC: + if not allow_private and join_rule and join_rule != JoinRules.PUBLIC: defer.returnValue(None) - aliases = yield self.store.get_aliases_for_room(room_id) - if aliases: - result["aliases"] = aliases + if with_alias: + aliases = yield self.store.get_aliases_for_room( + room_id, on_invalidate=cache_context.invalidate + ) + if aliases: + result["aliases"] = aliases name_event = yield current_state.get((EventTypes.Name, "")) if name_event: @@ -334,8 +374,7 @@ class RoomListHandler(BaseHandler): if avatar_url: result["avatar_url"] = avatar_url - if _matches_room_entry(result, search_filter): - chunk.append(result) + defer.returnValue(result) @defer.inlineCallbacks def get_remote_public_room_list(self, server_name, limit=None, since_token=None, @@ -365,7 +404,7 @@ class RoomListHandler(BaseHandler): def _get_remote_list_cached(self, server_name, limit=None, since_token=None, 
search_filter=None, include_all_networks=False, third_party_instance_id=None,): - repl_layer = self.hs.get_replication_layer() + repl_layer = self.hs.get_federation_client() if search_filter: # We can't cache when asking for search return repl_layer.get_public_rooms( @@ -378,18 +417,14 @@ class RoomListHandler(BaseHandler): server_name, limit, since_token, include_all_networks, third_party_instance_id, ) - result = self.remote_response_cache.get(key) - if not result: - result = self.remote_response_cache.set( - key, - repl_layer.get_public_rooms( - server_name, limit=limit, since_token=since_token, - search_filter=search_filter, - include_all_networks=include_all_networks, - third_party_instance_id=third_party_instance_id, - ) - ) - return result + return self.remote_response_cache.wrap( + key, + repl_layer.get_public_rooms, + server_name, limit=limit, since_token=since_token, + search_filter=search_filter, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, + ) class RoomListNextBatch(namedtuple("RoomListNextBatch", ( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index b2806555cf..714583f1d5 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- +import abc import logging from signedjson.key import decode_verify_key_bytes @@ -29,47 +30,139 @@ from synapse.api.errors import AuthError, SynapseError, Codes from synapse.types import UserID, RoomID from synapse.util.async import Linearizer from synapse.util.distributor import user_left_room, user_joined_room -from ._base import BaseHandler + logger = logging.getLogger(__name__) id_server_scheme = "https://" -class RoomMemberHandler(BaseHandler): +class RoomMemberHandler(object): # TODO(paul): This handler currently contains a messy conflation of # low-level API that works on UserID objects and so on, and REST-level # API that takes ID strings and returns pagination chunks. These concerns # ought to be separated out a lot better. + __metaclass__ = abc.ABCMeta + def __init__(self, hs): - super(RoomMemberHandler, self).__init__(hs) + self.hs = hs + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.state_handler = hs.get_state_handler() + self.config = hs.config + self.simple_http_client = hs.get_simple_http_client() + + self.federation_handler = hs.get_handlers().federation_handler + self.directory_handler = hs.get_handlers().directory_handler + self.registration_handler = hs.get_handlers().registration_handler + self.profile_handler = hs.get_profile_handler() + self.event_creation_hander = hs.get_event_creation_handler() self.member_linearizer = Linearizer(name="member") self.clock = hs.get_clock() + self.spam_checker = hs.get_spam_checker() - self.distributor = hs.get_distributor() - self.distributor.declare("user_joined_room") - self.distributor.declare("user_left_room") + @abc.abstractmethod + def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + """Try and join a room that this server is not in + + Args: + requester (Requester) + remote_room_hosts (list[str]): List of servers that can be used + to join via. 
+ room_id (str): Room that we are trying to join + user (UserID): User who is trying to join + content (dict): A dict that should be used as the content of the + join event. + + Returns: + Deferred + """ + raise NotImplementedError() + + @abc.abstractmethod + def _remote_reject_invite(self, remote_room_hosts, room_id, target): + """Attempt to reject an invite for a room this server is not in. If we + fail to do so we locally mark the invite as rejected. + + Args: + requester (Requester) + remote_room_hosts (list[str]): List of servers to use to try and + reject invite + room_id (str) + target (UserID): The user rejecting the invite + + Returns: + Deferred[dict]: A dictionary to be returned to the client, may + include event_id etc, or nothing if we locally rejected + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id): + """Get a guest access token for a 3PID, creating a guest account if + one doesn't already exist. + + Args: + requester (Requester) + medium (str) + address (str) + inviter_user_id (str): The user ID who is trying to invite the + 3PID + + Returns: + Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the + 3PID guest account. + """ + raise NotImplementedError() + + @abc.abstractmethod + def _user_joined_room(self, target, room_id): + """Notifies distributor on master process that the user has joined the + room. + + Args: + target (UserID) + room_id (str) + + Returns: + Deferred|None + """ + raise NotImplementedError() + + @abc.abstractmethod + def _user_left_room(self, target, room_id): + """Notifies distributor on master process that the user has left the + room. 
+ + Args: + target (UserID) + room_id (str) + + Returns: + Deferred|None + """ + raise NotImplementedError() @defer.inlineCallbacks def _local_membership_update( self, requester, target, room_id, membership, - prev_event_ids, + prev_events_and_hashes, txn_id=None, ratelimit=True, content=None, ): if content is None: content = {} - msg_handler = self.hs.get_handlers().message_handler content["membership"] = membership if requester.is_guest: content["kind"] = "guest" - event, context = yield msg_handler.create_event( + event, context = yield self.event_creation_hander.create_event( + requester, { "type": EventTypes.Member, "content": content, @@ -82,16 +175,18 @@ class RoomMemberHandler(BaseHandler): }, token_id=requester.access_token_id, txn_id=txn_id, - prev_event_ids=prev_event_ids, + prev_events_and_hashes=prev_events_and_hashes, ) # Check if this event matches the previous membership event for the user. - duplicate = yield msg_handler.deduplicate_state_event(event, context) + duplicate = yield self.event_creation_hander.deduplicate_state_event( + event, context, + ) if duplicate is not None: # Discard the new event since this membership change is a no-op. 
defer.returnValue(duplicate) - yield msg_handler.handle_new_client_event( + yield self.event_creation_hander.handle_new_client_event( requester, event, context, @@ -113,40 +208,16 @@ class RoomMemberHandler(BaseHandler): prev_member_event = yield self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: - yield user_joined_room(self.distributor, target, room_id) + yield self._user_joined_room(target, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: prev_member_event = yield self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: - user_left_room(self.distributor, target, room_id) + yield self._user_left_room(target, room_id) defer.returnValue(event) @defer.inlineCallbacks - def remote_join(self, remote_room_hosts, room_id, user, content): - if len(remote_room_hosts) == 0: - raise SynapseError(404, "No known servers") - - # We don't do an auth check if we are doing an invite - # join dance for now, since we're kinda implicitly checking - # that we are allowed to join when we decide whether or not we - # need to do the invite/join dance. - yield self.hs.get_handlers().federation_handler.do_invite_join( - remote_room_hosts, - room_id, - user.to_string(), - content, - ) - yield user_joined_room(self.distributor, user, room_id) - - def reject_remote_invite(self, user_id, room_id, remote_room_hosts): - return self.hs.get_handlers().federation_handler.do_remotely_reject_invite( - remote_room_hosts, - room_id, - user_id - ) - - @defer.inlineCallbacks def update_membership( self, requester, @@ -192,14 +263,19 @@ class RoomMemberHandler(BaseHandler): content_specified = bool(content) if content is None: content = {} + else: + # We do a copy here as we potentially change some keys + # later on. 
+ content = dict(content) effective_membership_state = action if action in ["kick", "unban"]: effective_membership_state = "leave" + # if this is a join with a 3pid signature, we may need to turn a 3pid + # invite into a normal invite before we can handle the join. if third_party_signed is not None: - replication = self.hs.get_replication_layer() - yield replication.exchange_third_party_invite( + yield self.federation_handler.exchange_third_party_invite( third_party_signed["sender"], target.to_string(), room_id, @@ -209,7 +285,41 @@ class RoomMemberHandler(BaseHandler): if not remote_room_hosts: remote_room_hosts = [] - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + if effective_membership_state not in ("leave", "ban",): + is_blocked = yield self.store.is_room_blocked(room_id) + if is_blocked: + raise SynapseError(403, "This room has been blocked on this server") + + if effective_membership_state == "invite": + block_invite = False + is_requester_admin = yield self.auth.is_server_admin( + requester.user, + ) + if not is_requester_admin: + if self.config.block_non_admin_invites: + logger.info( + "Blocking invite: user is not admin and non-admin " + "invites disabled" + ) + block_invite = True + + if not self.spam_checker.user_may_invite( + requester.user.to_string(), target.to_string(), room_id, + ): + logger.info("Blocking invite due to spam checker") + block_invite = True + + if block_invite: + raise SynapseError( + 403, "Invites have been disabled on this server", + ) + + prev_events_and_hashes = yield self.store.get_prev_events_for_room( + room_id, + ) + latest_event_ids = ( + event_id for (event_id, _, _) in prev_events_and_hashes + ) current_state_ids = yield self.state_handler.get_current_state_ids( room_id, latest_event_ids=latest_event_ids, ) @@ -250,13 +360,13 @@ class RoomMemberHandler(BaseHandler): raise AuthError(403, "Guest access not allowed") if not is_host_in_room: - inviter = yield self.get_inviter(target.to_string(), 
room_id) + inviter = yield self._get_inviter(target.to_string(), room_id) if inviter and not self.hs.is_mine(inviter): remote_room_hosts.append(inviter.domain) content["membership"] = Membership.JOIN - profile = self.hs.get_handlers().profile_handler + profile = self.profile_handler if not content_specified: content["displayname"] = yield profile.get_displayname(target) content["avatar_url"] = yield profile.get_avatar_url(target) @@ -264,15 +374,15 @@ class RoomMemberHandler(BaseHandler): if requester.is_guest: content["kind"] = "guest" - ret = yield self.remote_join( - remote_room_hosts, room_id, target, content + ret = yield self._remote_join( + requester, remote_room_hosts, room_id, target, content ) defer.returnValue(ret) elif effective_membership_state == Membership.LEAVE: if not is_host_in_room: # perhaps we've been invited - inviter = yield self.get_inviter(target.to_string(), room_id) + inviter = yield self._get_inviter(target.to_string(), room_id) if not inviter: raise SynapseError(404, "Not a known room") @@ -286,20 +396,10 @@ class RoomMemberHandler(BaseHandler): else: # send the rejection to the inviter's HS. 
remote_room_hosts = remote_room_hosts + [inviter.domain] - - try: - ret = yield self.reject_remote_invite( - target.to_string(), room_id, remote_room_hosts - ) - defer.returnValue(ret) - except SynapseError as e: - logger.warn("Failed to reject invite: %s", e) - - yield self.store.locally_reject_invite( - target.to_string(), room_id - ) - - defer.returnValue({}) + res = yield self._remote_reject_invite( + requester, remote_room_hosts, room_id, target, + ) + defer.returnValue(res) res = yield self._local_membership_update( requester=requester, @@ -308,7 +408,7 @@ class RoomMemberHandler(BaseHandler): membership=effective_membership_state, txn_id=txn_id, ratelimit=ratelimit, - prev_event_ids=latest_event_ids, + prev_events_and_hashes=prev_events_and_hashes, content=content, ) defer.returnValue(res) @@ -354,8 +454,9 @@ class RoomMemberHandler(BaseHandler): else: requester = synapse.types.create_requester(target_user) - message_handler = self.hs.get_handlers().message_handler - prev_event = yield message_handler.deduplicate_state_event(event, context) + prev_event = yield self.event_creation_hander.deduplicate_state_event( + event, context, + ) if prev_event is not None: return @@ -367,7 +468,12 @@ class RoomMemberHandler(BaseHandler): # so don't really fit into the general auth process. 
raise AuthError(403, "Guest access not allowed") - yield message_handler.handle_new_client_event( + if event.membership not in (Membership.LEAVE, Membership.BAN): + is_blocked = yield self.store.is_room_blocked(room_id) + if is_blocked: + raise SynapseError(403, "This room has been blocked on this server") + + yield self.event_creation_hander.handle_new_client_event( requester, event, context, @@ -389,12 +495,12 @@ class RoomMemberHandler(BaseHandler): prev_member_event = yield self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: - yield user_joined_room(self.distributor, target_user, room_id) + yield self._user_joined_room(target_user, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: prev_member_event = yield self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: - user_left_room(self.distributor, target_user, room_id) + yield self._user_left_room(target_user, room_id) @defer.inlineCallbacks def _can_guest_join(self, current_state_ids): @@ -428,7 +534,7 @@ class RoomMemberHandler(BaseHandler): Raises: SynapseError if room alias could not be found. 
""" - directory_handler = self.hs.get_handlers().directory_handler + directory_handler = self.directory_handler mapping = yield directory_handler.get_association(room_alias) if not mapping: @@ -440,7 +546,7 @@ class RoomMemberHandler(BaseHandler): defer.returnValue((RoomID.from_string(room_id), servers)) @defer.inlineCallbacks - def get_inviter(self, user_id, room_id): + def _get_inviter(self, user_id, room_id): invite = yield self.store.get_invite_for_user_in_room( user_id=user_id, room_id=room_id, @@ -459,6 +565,16 @@ class RoomMemberHandler(BaseHandler): requester, txn_id ): + if self.config.block_non_admin_invites: + is_requester_admin = yield self.auth.is_server_admin( + requester.user, + ) + if not is_requester_admin: + raise SynapseError( + 403, "Invites have been disabled on this server", + Codes.FORBIDDEN, + ) + invitee = yield self._lookup_3pid( id_server, medium, address ) @@ -496,7 +612,7 @@ class RoomMemberHandler(BaseHandler): str: the matrix ID of the 3pid, or None if it is not recognized. 
""" try: - data = yield self.hs.get_simple_http_client().get_json( + data = yield self.simple_http_client.get_json( "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,), { "medium": medium, @@ -507,7 +623,7 @@ class RoomMemberHandler(BaseHandler): if "mxid" in data: if "signatures" not in data: raise AuthError(401, "No signatures on 3pid binding") - self.verify_any_signature(data, id_server) + yield self._verify_any_signature(data, id_server) defer.returnValue(data["mxid"]) except IOError as e: @@ -515,11 +631,11 @@ class RoomMemberHandler(BaseHandler): defer.returnValue(None) @defer.inlineCallbacks - def verify_any_signature(self, data, server_hostname): + def _verify_any_signature(self, data, server_hostname): if server_hostname not in data["signatures"]: raise AuthError(401, "No signature from server %s" % (server_hostname,)) for key_name, signature in data["signatures"][server_hostname].items(): - key_data = yield self.hs.get_simple_http_client().get_json( + key_data = yield self.simple_http_client.get_json( "%s%s/_matrix/identity/api/v1/pubkey/%s" % (id_server_scheme, server_hostname, key_name,), ) @@ -544,7 +660,7 @@ class RoomMemberHandler(BaseHandler): user, txn_id ): - room_state = yield self.hs.get_state_handler().get_current_state(room_id) + room_state = yield self.state_handler.get_current_state(room_id) inviter_display_name = "" inviter_avatar_url = "" @@ -575,6 +691,7 @@ class RoomMemberHandler(BaseHandler): token, public_keys, fallback_public_key, display_name = ( yield self._ask_id_server_for_third_party_invite( + requester=requester, id_server=id_server, medium=medium, address=address, @@ -589,8 +706,7 @@ class RoomMemberHandler(BaseHandler): ) ) - msg_handler = self.hs.get_handlers().message_handler - yield msg_handler.create_and_send_nonmember_event( + yield self.event_creation_hander.create_and_send_nonmember_event( requester, { "type": EventTypes.ThirdPartyInvite, @@ -612,6 +728,7 @@ class RoomMemberHandler(BaseHandler): 
@defer.inlineCallbacks def _ask_id_server_for_third_party_invite( self, + requester, id_server, medium, address, @@ -628,6 +745,7 @@ class RoomMemberHandler(BaseHandler): Asks an identity server for a third party invite. Args: + requester (Requester) id_server (str): hostname + optional port for the identity server. medium (str): The literal string "email". address (str): The third party address being invited. @@ -669,24 +787,20 @@ class RoomMemberHandler(BaseHandler): "sender_avatar_url": inviter_avatar_url, } - if self.hs.config.invite_3pid_guest: - registration_handler = self.hs.get_handlers().registration_handler - guest_access_token = yield registration_handler.guest_access_token_for( + if self.config.invite_3pid_guest: + guest_access_token, guest_user_id = yield self.get_or_register_3pid_guest( + requester=requester, medium=medium, address=address, inviter_user_id=inviter_user_id, ) - guest_user_info = yield self.hs.get_auth().get_user_by_access_token( - guest_access_token - ) - invite_config.update({ "guest_access_token": guest_access_token, - "guest_user_id": guest_user_info["user"].to_string(), + "guest_user_id": guest_user_id, }) - data = yield self.hs.get_simple_http_client().post_urlencoded_get_json( + data = yield self.simple_http_client.post_urlencoded_get_json( is_url, invite_config ) @@ -709,25 +823,6 @@ class RoomMemberHandler(BaseHandler): defer.returnValue((token, public_keys, fallback_public_key, display_name)) @defer.inlineCallbacks - def forget(self, user, room_id): - user_id = user.to_string() - - member = yield self.state_handler.get_current_state( - room_id=room_id, - event_type=EventTypes.Member, - state_key=user_id - ) - membership = member.membership if member else None - - if membership is not None and membership != Membership.LEAVE: - raise SynapseError(400, "User %s in room %s" % ( - user_id, room_id - )) - - if membership: - yield self.store.forget(user_id, room_id) - - @defer.inlineCallbacks def _is_host_in_room(self, 
current_state_ids): # Have we just created the room, and is this about to be the very # first member event? @@ -735,10 +830,11 @@ class RoomMemberHandler(BaseHandler): if len(current_state_ids) == 1 and create_event_id: defer.returnValue(self.hs.is_mine_id(create_event_id)) - for (etype, state_key), event_id in current_state_ids.items(): + for etype, state_key in current_state_ids: if etype != EventTypes.Member or not self.hs.is_mine_id(state_key): continue + event_id = current_state_ids[(etype, state_key)] event = yield self.store.get_event(event_id, allow_none=True) if not event: continue @@ -747,3 +843,102 @@ class RoomMemberHandler(BaseHandler): defer.returnValue(True) defer.returnValue(False) + + +class RoomMemberMasterHandler(RoomMemberHandler): + def __init__(self, hs): + super(RoomMemberMasterHandler, self).__init__(hs) + + self.distributor = hs.get_distributor() + self.distributor.declare("user_joined_room") + self.distributor.declare("user_left_room") + + @defer.inlineCallbacks + def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + """Implements RoomMemberHandler._remote_join + """ + # filter ourselves out of remote_room_hosts: do_invite_join ignores it + # and if it is the only entry we'd like to return a 404 rather than a + # 500. + + remote_room_hosts = [ + host for host in remote_room_hosts if host != self.hs.hostname + ] + + if len(remote_room_hosts) == 0: + raise SynapseError(404, "No known servers") + + # We don't do an auth check if we are doing an invite + # join dance for now, since we're kinda implicitly checking + # that we are allowed to join when we decide whether or not we + # need to do the invite/join dance. 
+ yield self.federation_handler.do_invite_join( + remote_room_hosts, + room_id, + user.to_string(), + content, + ) + yield self._user_joined_room(user, room_id) + + @defer.inlineCallbacks + def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target): + """Implements RoomMemberHandler._remote_reject_invite + """ + fed_handler = self.federation_handler + try: + ret = yield fed_handler.do_remotely_reject_invite( + remote_room_hosts, + room_id, + target.to_string(), + ) + defer.returnValue(ret) + except Exception as e: + # if we were unable to reject the exception, just mark + # it as rejected on our end and plough ahead. + # + # The 'except' clause is very broad, but we need to + # capture everything from DNS failures upwards + # + logger.warn("Failed to reject invite: %s", e) + + yield self.store.locally_reject_invite( + target.to_string(), room_id + ) + defer.returnValue({}) + + def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id): + """Implements RoomMemberHandler.get_or_register_3pid_guest + """ + rg = self.registration_handler + return rg.get_or_register_3pid_guest(medium, address, inviter_user_id) + + def _user_joined_room(self, target, room_id): + """Implements RoomMemberHandler._user_joined_room + """ + return user_joined_room(self.distributor, target, room_id) + + def _user_left_room(self, target, room_id): + """Implements RoomMemberHandler._user_left_room + """ + return user_left_room(self.distributor, target, room_id) + + @defer.inlineCallbacks + def forget(self, user, room_id): + user_id = user.to_string() + + member = yield self.state_handler.get_current_state( + room_id=room_id, + event_type=EventTypes.Member, + state_key=user_id + ) + membership = member.membership if member else None + + if membership is not None and membership not in [ + Membership.LEAVE, Membership.BAN + ]: + raise SynapseError(400, "User %s in room %s" % ( + user_id, room_id + )) + + if membership: + yield self.store.forget(user_id, 
room_id) diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py new file mode 100644 index 0000000000..493aec1e48 --- /dev/null +++ b/synapse/handlers/room_member_worker.py @@ -0,0 +1,102 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.api.errors import SynapseError +from synapse.handlers.room_member import RoomMemberHandler +from synapse.replication.http.membership import ( + remote_join, remote_reject_invite, get_or_register_3pid_guest, + notify_user_membership_change, +) + + +logger = logging.getLogger(__name__) + + +class RoomMemberWorkerHandler(RoomMemberHandler): + @defer.inlineCallbacks + def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + """Implements RoomMemberHandler._remote_join + """ + if len(remote_room_hosts) == 0: + raise SynapseError(404, "No known servers") + + ret = yield remote_join( + self.simple_http_client, + host=self.config.worker_replication_host, + port=self.config.worker_replication_http_port, + requester=requester, + remote_room_hosts=remote_room_hosts, + room_id=room_id, + user_id=user.to_string(), + content=content, + ) + + yield self._user_joined_room(user, room_id) + + defer.returnValue(ret) + + def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target): + """Implements RoomMemberHandler._remote_reject_invite + """ + 
return remote_reject_invite( + self.simple_http_client, + host=self.config.worker_replication_host, + port=self.config.worker_replication_http_port, + requester=requester, + remote_room_hosts=remote_room_hosts, + room_id=room_id, + user_id=target.to_string(), + ) + + def _user_joined_room(self, target, room_id): + """Implements RoomMemberHandler._user_joined_room + """ + return notify_user_membership_change( + self.simple_http_client, + host=self.config.worker_replication_host, + port=self.config.worker_replication_http_port, + user_id=target.to_string(), + room_id=room_id, + change="joined", + ) + + def _user_left_room(self, target, room_id): + """Implements RoomMemberHandler._user_left_room + """ + return notify_user_membership_change( + self.simple_http_client, + host=self.config.worker_replication_host, + port=self.config.worker_replication_http_port, + user_id=target.to_string(), + room_id=room_id, + change="left", + ) + + def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id): + """Implements RoomMemberHandler.get_or_register_3pid_guest + """ + return get_or_register_3pid_guest( + self.simple_http_client, + host=self.config.worker_replication_host, + port=self.config.worker_replication_http_port, + requester=requester, + medium=medium, + address=address, + inviter_user_id=inviter_user_id, + ) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index df75d70fac..9772ed1a0e 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -61,7 +61,7 @@ class SearchHandler(BaseHandler): assert batch_group is not None assert batch_group_key is not None assert batch_token is not None - except: + except Exception: raise SynapseError(400, "Invalid batch") try: diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py new file mode 100644 index 0000000000..e057ae54c9 --- /dev/null +++ b/synapse/handlers/set_password.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector 
Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.internet import defer + +from synapse.api.errors import Codes, StoreError, SynapseError +from ._base import BaseHandler + +logger = logging.getLogger(__name__) + + +class SetPasswordHandler(BaseHandler): + """Handler which deals with changing user account passwords""" + def __init__(self, hs): + super(SetPasswordHandler, self).__init__(hs) + self._auth_handler = hs.get_auth_handler() + self._device_handler = hs.get_device_handler() + + @defer.inlineCallbacks + def set_password(self, user_id, newpassword, requester=None): + password_hash = yield self._auth_handler.hash(newpassword) + + except_device_id = requester.device_id if requester else None + except_access_token_id = requester.access_token_id if requester else None + + try: + yield self.store.user_set_password_hash(user_id, password_hash) + except StoreError as e: + if e.code == 404: + raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) + raise e + + # we want to log out all of the user's other sessions. First delete + # all his other devices. + yield self._device_handler.delete_all_devices_for_user( + user_id, except_device_id=except_device_id, + ) + + # and now delete any access tokens which weren't associated with + # devices (or were associated with this device). 
+ yield self._auth_handler.delete_access_tokens_for_user( + user_id, except_token_id=except_access_token_id, + ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index d7dcd1ce5b..b52e4c2aff 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -20,6 +20,7 @@ from synapse.util.metrics import Measure, measure_func from synapse.util.caches.response_cache import ResponseCache from synapse.push.clientformat import format_push_rules_for_user from synapse.visibility import filter_events_for_client +from synapse.types import RoomStreamToken from twisted.internet import defer @@ -51,6 +52,7 @@ class TimelineBatch(collections.namedtuple("TimelineBatch", [ to tell if room needs to be part of the sync result. """ return bool(self.events) + __bool__ = __nonzero__ # python3 class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ @@ -75,6 +77,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ # nb the notification count does not, er, count: if there's nothing # else in the result, we don't need to send it. 
) + __bool__ = __nonzero__ # python3 class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [ @@ -94,6 +97,7 @@ class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [ or self.state or self.account_data ) + __bool__ = __nonzero__ # python3 class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [ @@ -105,6 +109,30 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [ def __nonzero__(self): """Invited rooms should always be reported to the client""" return True + __bool__ = __nonzero__ # python3 + + +class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [ + "join", + "invite", + "leave", +])): + __slots__ = [] + + def __nonzero__(self): + return bool(self.join or self.invite or self.leave) + __bool__ = __nonzero__ # python3 + + +class DeviceLists(collections.namedtuple("DeviceLists", [ + "changed", # list of user_ids whose devices may have changed + "left", # list of user_ids whose devices we no longer track +])): + __slots__ = [] + + def __nonzero__(self): + return bool(self.changed or self.left) + __bool__ = __nonzero__ # python3 class SyncResult(collections.namedtuple("SyncResult", [ @@ -116,6 +144,9 @@ class SyncResult(collections.namedtuple("SyncResult", [ "archived", # ArchivedSyncResult for each archived room. "to_device", # List of direct messages for the device. 
"device_lists", # List of user_ids whose devices have chanegd + "device_one_time_keys_count", # Dict of algorithm to count for one time keys + # for this device + "groups", ])): __slots__ = [] @@ -131,8 +162,10 @@ class SyncResult(collections.namedtuple("SyncResult", [ self.archived or self.account_data or self.to_device or - self.device_lists + self.device_lists or + self.groups ) + __bool__ = __nonzero__ # python3 class SyncHandler(object): @@ -143,7 +176,7 @@ class SyncHandler(object): self.presence_handler = hs.get_presence_handler() self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() - self.response_cache = ResponseCache(hs) + self.response_cache = ResponseCache(hs, "sync") self.state = hs.get_state_handler() def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, @@ -154,15 +187,11 @@ class SyncHandler(object): Returns: A Deferred SyncResult. """ - result = self.response_cache.get(sync_config.request_key) - if not result: - result = self.response_cache.set( - sync_config.request_key, - self._wait_for_sync_for_user( - sync_config, since_token, timeout, full_state - ) - ) - return result + return self.response_cache.wrap( + sync_config.request_key, + self._wait_for_sync_for_user, + sync_config, since_token, timeout, full_state, + ) @defer.inlineCallbacks def _wait_for_sync_for_user(self, sync_config, since_token, timeout, @@ -209,10 +238,10 @@ class SyncHandler(object): defer.returnValue(rules) @defer.inlineCallbacks - def ephemeral_by_room(self, sync_config, now_token, since_token=None): + def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None): """Get the ephemeral events for each room the user is in Args: - sync_config (SyncConfig): The flags, filters and user for the sync. + sync_result_builder(SyncResultBuilder) now_token (StreamToken): Where the server is currently up to. since_token (StreamToken): Where the server was when the client last synced. 
@@ -222,11 +251,12 @@ class SyncHandler(object): typing events for that room. """ + sync_config = sync_result_builder.sync_config + with Measure(self.clock, "ephemeral_by_room"): typing_key = since_token.typing_key if since_token else "0" - rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string()) - room_ids = [room.room_id for room in rooms] + room_ids = sync_result_builder.joined_room_ids typing_source = self.event_sources.sources["typing"] typing, typing_key = yield typing_source.get_new_events( @@ -288,10 +318,20 @@ class SyncHandler(object): if recents: recents = sync_config.filter_collection.filter_room_timeline(recents) + + # We check if there are any state events, if there are then we pass + # all current state events to the filter_events function. This is to + # ensure that we always include current state in the timeline + current_state_ids = frozenset() + if any(e.is_state() for e in recents): + current_state_ids = yield self.state.get_current_state_ids(room_id) + current_state_ids = frozenset(current_state_ids.itervalues()) + recents = yield filter_events_for_client( self.store, sync_config.user.to_string(), recents, + always_include_ids=current_state_ids, ) else: recents = [] @@ -323,10 +363,20 @@ class SyncHandler(object): loaded_recents = sync_config.filter_collection.filter_room_timeline( events ) + + # We check if there are any state events, if there are then we pass + # all current state events to the filter_events function. 
This is to + # ensure that we always include current state in the timeline + current_state_ids = frozenset() + if any(e.is_state() for e in loaded_recents): + current_state_ids = yield self.state.get_current_state_ids(room_id) + current_state_ids = frozenset(current_state_ids.itervalues()) + loaded_recents = yield filter_events_for_client( self.store, sync_config.user.to_string(), loaded_recents, + always_include_ids=current_state_ids, ) loaded_recents.extend(recents) recents = loaded_recents @@ -520,10 +570,22 @@ class SyncHandler(object): # Always use the `now_token` in `SyncResultBuilder` now_token = yield self.event_sources.get_current_token() + user_id = sync_config.user.to_string() + app_service = self.store.get_app_service_by_user_id(user_id) + if app_service: + # We no longer support AS users using /sync directly. + # See https://github.com/matrix-org/matrix-doc/issues/1144 + raise NotImplementedError() + else: + joined_room_ids = yield self.get_rooms_for_user_at( + user_id, now_token.room_stream_id, + ) + sync_result_builder = SyncResultBuilder( sync_config, full_state, since_token=since_token, now_token=now_token, + joined_room_ids=joined_room_ids, ) account_data_by_room = yield self._generate_sync_entry_for_account_data( @@ -533,7 +595,8 @@ class SyncHandler(object): res = yield self._generate_sync_entry_for_rooms( sync_result_builder, account_data_by_room ) - newly_joined_rooms, newly_joined_users = res + newly_joined_rooms, newly_joined_users, _, _ = res + _, _, newly_left_rooms, newly_left_users = res block_all_presence_data = ( since_token is None and @@ -547,9 +610,22 @@ class SyncHandler(object): yield self._generate_sync_entry_for_to_device(sync_result_builder) device_lists = yield self._generate_sync_entry_for_device_list( - sync_result_builder + sync_result_builder, + newly_joined_rooms=newly_joined_rooms, + newly_joined_users=newly_joined_users, + newly_left_rooms=newly_left_rooms, + newly_left_users=newly_left_users, ) + device_id = 
sync_config.device_id + one_time_key_counts = {} + if device_id: + one_time_key_counts = yield self.store.count_e2e_one_time_keys( + user_id, device_id + ) + + yield self._generate_sync_entry_for_groups(sync_result_builder) + defer.returnValue(SyncResult( presence=sync_result_builder.presence, account_data=sync_result_builder.account_data, @@ -558,31 +634,103 @@ class SyncHandler(object): archived=sync_result_builder.archived, to_device=sync_result_builder.to_device, device_lists=device_lists, + groups=sync_result_builder.groups, + device_one_time_keys_count=one_time_key_counts, next_batch=sync_result_builder.now_token, )) + @measure_func("_generate_sync_entry_for_groups") + @defer.inlineCallbacks + def _generate_sync_entry_for_groups(self, sync_result_builder): + user_id = sync_result_builder.sync_config.user.to_string() + since_token = sync_result_builder.since_token + now_token = sync_result_builder.now_token + + if since_token and since_token.groups_key: + results = yield self.store.get_groups_changes_for_user( + user_id, since_token.groups_key, now_token.groups_key, + ) + else: + results = yield self.store.get_all_groups_for_user( + user_id, now_token.groups_key, + ) + + invited = {} + joined = {} + left = {} + for result in results: + membership = result["membership"] + group_id = result["group_id"] + gtype = result["type"] + content = result["content"] + + if membership == "join": + if gtype == "membership": + # TODO: Add profile + content.pop("membership", None) + joined[group_id] = content["content"] + else: + joined.setdefault(group_id, {})[gtype] = content + elif membership == "invite": + if gtype == "membership": + content.pop("membership", None) + invited[group_id] = content["content"] + else: + if gtype == "membership": + left[group_id] = content["content"] + + sync_result_builder.groups = GroupsSyncResult( + join=joined, + invite=invited, + leave=left, + ) + @measure_func("_generate_sync_entry_for_device_list") @defer.inlineCallbacks - def 
_generate_sync_entry_for_device_list(self, sync_result_builder): + def _generate_sync_entry_for_device_list(self, sync_result_builder, + newly_joined_rooms, newly_joined_users, + newly_left_rooms, newly_left_users): user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token if since_token and since_token.device_list_key: - rooms = yield self.store.get_rooms_for_user(user_id) - room_ids = set(r.room_id for r in rooms) - - user_ids_changed = set() changed = yield self.store.get_user_whose_devices_changed( since_token.device_list_key ) - for other_user_id in changed: - other_rooms = yield self.store.get_rooms_for_user(other_user_id) - if room_ids.intersection(e.room_id for e in other_rooms): - user_ids_changed.add(other_user_id) - defer.returnValue(user_ids_changed) + # TODO: Be more clever than this, i.e. remove users who we already + # share a room with? + for room_id in newly_joined_rooms: + joined_users = yield self.state.get_current_user_in_room(room_id) + newly_joined_users.update(joined_users) + + for room_id in newly_left_rooms: + left_users = yield self.state.get_current_user_in_room(room_id) + newly_left_users.update(left_users) + + # TODO: Check that these users are actually new, i.e. either they + # weren't in the previous sync *or* they left and rejoined. 
+ changed.update(newly_joined_users) + + if not changed and not newly_left_users: + defer.returnValue(DeviceLists( + changed=[], + left=newly_left_users, + )) + + users_who_share_room = yield self.store.get_users_who_share_room_with_user( + user_id + ) + + defer.returnValue(DeviceLists( + changed=users_who_share_room & changed, + left=set(newly_left_users) - users_who_share_room, + )) else: - defer.returnValue([]) + defer.returnValue(DeviceLists( + changed=[], + left=[], + )) @defer.inlineCallbacks def _generate_sync_entry_for_to_device(self, sync_result_builder): @@ -609,14 +757,14 @@ class SyncHandler(object): deleted = yield self.store.delete_messages_for_device( user_id, device_id, since_stream_id ) - logger.info("Deleted %d to-device messages up to %d", - deleted, since_stream_id) + logger.debug("Deleted %d to-device messages up to %d", + deleted, since_stream_id) messages, stream_id = yield self.store.get_new_messages_for_device( user_id, device_id, since_stream_id, now_token.to_device_key ) - logger.info( + logger.debug( "Returning %d to-device messages between %d and %d (current token: %d)", len(messages), since_stream_id, stream_id, now_token.to_device_key ) @@ -721,14 +869,14 @@ class SyncHandler(object): extra_users_ids.update(users) extra_users_ids.discard(user.to_string()) - states = yield self.presence_handler.get_states( - extra_users_ids, - as_event=True, - ) - presence.extend(states) + if extra_users_ids: + states = yield self.presence_handler.get_states( + extra_users_ids, + ) + presence.extend(states) - # Deduplicate the presence entries so that there's at most one per user - presence = {p["content"]["user_id"]: p for p in presence}.values() + # Deduplicate the presence entries so that there's at most one per user + presence = {p.user_id: p for p in presence}.values() presence = sync_config.filter_collection.filter_presence( presence @@ -746,8 +894,8 @@ class SyncHandler(object): account_data_by_room(dict): Dictionary of per room account data 
Returns: - Deferred(tuple): Returns a 2-tuple of - `(newly_joined_rooms, newly_joined_users)` + Deferred(tuple): Returns a 4-tuple of + `(newly_joined_rooms, newly_joined_users, newly_left_rooms, newly_left_users)` """ user_id = sync_result_builder.sync_config.user.to_string() block_all_room_ephemeral = ( @@ -759,12 +907,27 @@ class SyncHandler(object): ephemeral_by_room = {} else: now_token, ephemeral_by_room = yield self.ephemeral_by_room( - sync_result_builder.sync_config, + sync_result_builder, now_token=sync_result_builder.now_token, since_token=sync_result_builder.since_token, ) sync_result_builder.now_token = now_token + # We check up front if anything has changed, if it hasn't then there is + # no point in going futher. + since_token = sync_result_builder.since_token + if not sync_result_builder.full_state: + if since_token and not ephemeral_by_room and not account_data_by_room: + have_changed = yield self._have_rooms_changed(sync_result_builder) + if not have_changed: + tags_by_room = yield self.store.get_updated_tags( + user_id, + since_token.account_data_key, + ) + if not tags_by_room: + logger.debug("no-oping sync") + defer.returnValue(([], [], [], [])) + ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id=user_id, ) @@ -774,17 +937,17 @@ class SyncHandler(object): else: ignored_users = frozenset() - if sync_result_builder.since_token: + if since_token: res = yield self._get_rooms_changed(sync_result_builder, ignored_users) - room_entries, invited, newly_joined_rooms = res + room_entries, invited, newly_joined_rooms, newly_left_rooms = res tags_by_room = yield self.store.get_updated_tags( - user_id, - sync_result_builder.since_token.account_data_key, + user_id, since_token.account_data_key, ) else: res = yield self._get_all_rooms(sync_result_builder, ignored_users) room_entries, invited, newly_joined_rooms = res + newly_left_rooms = [] tags_by_room = yield self.store.get_tags_for_user(user_id) 
@@ -805,17 +968,55 @@ class SyncHandler(object): # Now we want to get any newly joined users newly_joined_users = set() - if sync_result_builder.since_token: + newly_left_users = set() + if since_token: for joined_sync in sync_result_builder.joined: it = itertools.chain( - joined_sync.timeline.events, joined_sync.state.values() + joined_sync.timeline.events, joined_sync.state.itervalues() ) for event in it: if event.type == EventTypes.Member: if event.membership == Membership.JOIN: newly_joined_users.add(event.state_key) + else: + prev_content = event.unsigned.get("prev_content", {}) + prev_membership = prev_content.get("membership", None) + if prev_membership == Membership.JOIN: + newly_left_users.add(event.state_key) + + newly_left_users -= newly_joined_users + + defer.returnValue(( + newly_joined_rooms, + newly_joined_users, + newly_left_rooms, + newly_left_users, + )) + + @defer.inlineCallbacks + def _have_rooms_changed(self, sync_result_builder): + """Returns whether there may be any new events that should be sent down + the sync. Returns True if there are. + """ + user_id = sync_result_builder.sync_config.user.to_string() + since_token = sync_result_builder.since_token + now_token = sync_result_builder.now_token - defer.returnValue((newly_joined_rooms, newly_joined_users)) + assert since_token + + # Get a list of membership change events that have happened. 
+ rooms_changed = yield self.store.get_membership_changes_for_user( + user_id, since_token.room_key, now_token.room_key + ) + + if rooms_changed: + defer.returnValue(True) + + stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream + for room_id in sync_result_builder.joined_room_ids: + if self.store.has_room_changed_since(room_id, stream_id): + defer.returnValue(True) + defer.returnValue(False) @defer.inlineCallbacks def _get_rooms_changed(self, sync_result_builder, ignored_users): @@ -836,14 +1037,6 @@ class SyncHandler(object): assert since_token - app_service = self.store.get_app_service_by_user_id(user_id) - if app_service: - rooms = yield self.store.get_app_service_rooms(app_service) - joined_room_ids = set(r.room_id for r in rooms) - else: - rooms = yield self.store.get_rooms_for_user(user_id) - joined_room_ids = set(r.room_id for r in rooms) - # Get a list of membership change events that have happened. rooms_changed = yield self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key @@ -854,16 +1047,29 @@ class SyncHandler(object): mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) newly_joined_rooms = [] + newly_left_rooms = [] room_entries = [] invited = [] - for room_id, events in mem_change_events_by_room_id.items(): + for room_id, events in mem_change_events_by_room_id.iteritems(): non_joins = [e for e in events if e.membership != Membership.JOIN] has_join = len(non_joins) != len(events) # We want to figure out if we joined the room at some point since # the last sync (even if we have since left). This is to make sure # we do send down the room, and with full state, where necessary - if room_id in joined_room_ids or has_join: + + old_state_ids = None + if room_id in sync_result_builder.joined_room_ids and non_joins: + # Always include if the user (re)joined the room, especially + # important so that device list changes are calculated correctly. 
+ # If there are non join member events, but we are still in the room, + # then the user must have left and joined + newly_joined_rooms.append(room_id) + + # User is in the room so we don't need to do the invite/leave checks + continue + + if room_id in sync_result_builder.joined_room_ids or has_join: old_state_ids = yield self.get_state_at(room_id, since_token) old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None) old_mem_ev = None @@ -874,12 +1080,33 @@ class SyncHandler(object): if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: newly_joined_rooms.append(room_id) - if room_id in joined_room_ids: - continue + # If user is in the room then we don't need to do the invite/leave checks + if room_id in sync_result_builder.joined_room_ids: + continue if not non_joins: continue + # Check if we have left the room. This can either be because we were + # joined before *or* that we since joined and then left. + if events[-1].membership != Membership.JOIN: + if has_join: + newly_left_rooms.append(room_id) + else: + if not old_state_ids: + old_state_ids = yield self.get_state_at(room_id, since_token) + old_mem_ev_id = old_state_ids.get( + (EventTypes.Member, user_id), + None, + ) + old_mem_ev = None + if old_mem_ev_id: + old_mem_ev = yield self.store.get_event( + old_mem_ev_id, allow_none=True + ) + if old_mem_ev and old_mem_ev.membership == Membership.JOIN: + newly_left_rooms.append(room_id) + # Only bother if we're still currently invited should_invite = non_joins[-1].membership == Membership.INVITE if should_invite: @@ -921,7 +1148,7 @@ class SyncHandler(object): # Get all events for rooms we're currently joined to. 
room_to_events = yield self.store.get_room_events_stream_for_rooms( - room_ids=joined_room_ids, + room_ids=sync_result_builder.joined_room_ids, from_key=since_token.room_key, to_key=now_token.room_key, limit=timeline_limit + 1, @@ -929,7 +1156,7 @@ class SyncHandler(object): # We loop through all room ids, even if there are no new events, in case # there are non room events taht we need to notify about. - for room_id in joined_room_ids: + for room_id in sync_result_builder.joined_room_ids: room_entry = room_to_events.get(room_id, None) if room_entry: @@ -957,7 +1184,7 @@ class SyncHandler(object): upto_token=since_token, )) - defer.returnValue((room_entries, invited, newly_joined_rooms)) + defer.returnValue((room_entries, invited, newly_joined_rooms, newly_left_rooms)) @defer.inlineCallbacks def _get_all_rooms(self, sync_result_builder, ignored_users): @@ -1137,6 +1364,54 @@ class SyncHandler(object): else: raise Exception("Unrecognized rtype: %r", room_builder.rtype) + @defer.inlineCallbacks + def get_rooms_for_user_at(self, user_id, stream_ordering): + """Get set of joined rooms for a user at the given stream ordering. + + The stream ordering *must* be recent, otherwise this may throw an + exception if older than a month. (This function is called with the + current token, which should be perfectly fine). + + Args: + user_id (str) + stream_ordering (int) + + ReturnValue: + Deferred[frozenset[str]]: Set of room_ids the user is in at given + stream_ordering. + """ + joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering( + user_id, + ) + + joined_room_ids = set() + + # We need to check that the stream ordering of the join for each room + # is before the stream_ordering asked for. This might not be the case + # if the user joins a room between us getting the current token and + # calling `get_rooms_for_user_with_stream_ordering`. 
+ # If the membership's stream ordering is after the given stream + # ordering, we need to go and work out if the user was in the room + # before. + for room_id, membership_stream_ordering in joined_rooms: + if membership_stream_ordering <= stream_ordering: + joined_room_ids.add(room_id) + continue + + logger.info("User joined room after current token: %s", room_id) + + extrems = yield self.store.get_forward_extremeties_for_room( + room_id, stream_ordering, + ) + users_in_room = yield self.state.get_current_user_in_room( + room_id, extrems, + ) + if user_id in users_in_room: + joined_room_ids.add(room_id) + + joined_room_ids = frozenset(joined_room_ids) + defer.returnValue(joined_room_ids) + def _action_has_highlight(actions): for action in actions: @@ -1186,7 +1461,8 @@ def _calculate_state(timeline_contains, timeline_start, previous, current): class SyncResultBuilder(object): "Used to help build up a new SyncResult for a user" - def __init__(self, sync_config, full_state, since_token, now_token): + def __init__(self, sync_config, full_state, since_token, now_token, + joined_room_ids): """ Args: sync_config(SyncConfig) @@ -1198,6 +1474,7 @@ class SyncResultBuilder(object): self.full_state = full_state self.since_token = since_token self.now_token = now_token + self.joined_room_ids = joined_room_ids self.presence = [] self.account_data = [] @@ -1205,6 +1482,8 @@ class SyncResultBuilder(object): self.invited = [] self.archived = [] self.device = [] + self.groups = None + self.to_device = [] class RoomSyncResultBuilder(object): diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 0eea7f8f9c..5d9736e88f 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, AuthError -from synapse.util.logcontext import preserve_fn +from synapse.util.logcontext import run_in_background from synapse.util.metrics import Measure from 
synapse.util.wheel_timer import WheelTimer from synapse.types import UserID, get_domain_from_id @@ -24,7 +24,6 @@ from synapse.types import UserID, get_domain_from_id import logging from collections import namedtuple -import ujson as json logger = logging.getLogger(__name__) @@ -57,7 +56,7 @@ class TypingHandler(object): self.federation = hs.get_federation_sender() - hs.get_replication_layer().register_edu_handler("m.typing", self._recv_edu) + hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu) hs.get_distributor().observe("user_left_room", self.user_left_room) @@ -90,7 +89,7 @@ class TypingHandler(object): until = self._member_typing_until.get(member, None) if not until or until <= now: logger.info("Timing out typing for: %s", member.user_id) - preserve_fn(self._stopped_typing)(member) + self._stopped_typing(member) continue # Check if we need to resend a keep alive over federation for this @@ -98,7 +97,8 @@ class TypingHandler(object): if self.hs.is_mine_id(member.user_id): last_fed_poke = self._member_last_federation_poke.get(member, None) if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now: - preserve_fn(self._push_remote)( + run_in_background( + self._push_remote, member=member, typing=True ) @@ -148,7 +148,7 @@ class TypingHandler(object): # No point sending another notification defer.returnValue(None) - yield self._push_update( + self._push_update( member=member, typing=True, ) @@ -172,7 +172,7 @@ class TypingHandler(object): member = RoomMember(room_id=room_id, user_id=target_user_id) - yield self._stopped_typing(member) + self._stopped_typing(member) @defer.inlineCallbacks def user_left_room(self, user, room_id): @@ -181,7 +181,6 @@ class TypingHandler(object): member = RoomMember(room_id=room_id, user_id=user_id) yield self._stopped_typing(member) - @defer.inlineCallbacks def _stopped_typing(self, member): if member.user_id not in self._room_typing.get(member.room_id, set()): # No point @@ -190,16 +189,15 @@ 
class TypingHandler(object): self._member_typing_until.pop(member, None) self._member_last_federation_poke.pop(member, None) - yield self._push_update( + self._push_update( member=member, typing=False, ) - @defer.inlineCallbacks def _push_update(self, member, typing): if self.hs.is_mine_id(member.user_id): # Only send updates for changes to our own users. - yield self._push_remote(member, typing) + run_in_background(self._push_remote, member, typing) self._push_update_local( member=member, @@ -208,28 +206,31 @@ class TypingHandler(object): @defer.inlineCallbacks def _push_remote(self, member, typing): - users = yield self.state.get_current_user_in_room(member.room_id) - self._member_last_federation_poke[member] = self.clock.time_msec() + try: + users = yield self.state.get_current_user_in_room(member.room_id) + self._member_last_federation_poke[member] = self.clock.time_msec() - now = self.clock.time_msec() - self.wheel_timer.insert( - now=now, - obj=member, - then=now + FEDERATION_PING_INTERVAL, - ) + now = self.clock.time_msec() + self.wheel_timer.insert( + now=now, + obj=member, + then=now + FEDERATION_PING_INTERVAL, + ) - for domain in set(get_domain_from_id(u) for u in users): - if domain != self.server_name: - self.federation.send_edu( - destination=domain, - edu_type="m.typing", - content={ - "room_id": member.room_id, - "user_id": member.user_id, - "typing": typing, - }, - key=member, - ) + for domain in set(get_domain_from_id(u) for u in users): + if domain != self.server_name: + self.federation.send_edu( + destination=domain, + edu_type="m.typing", + content={ + "room_id": member.room_id, + "user_id": member.user_id, + "typing": typing, + }, + key=member, + ) + except Exception: + logger.exception("Error pushing typing notif to remotes") @defer.inlineCallbacks def _recv_edu(self, origin, content): @@ -288,11 +289,13 @@ class TypingHandler(object): for room_id, serial in self._room_serials.items(): if last_id < serial and serial <= current_id: typing = 
self._room_typing[room_id] - typing_bytes = json.dumps(list(typing), ensure_ascii=False) - rows.append((serial, room_id, typing_bytes)) + rows.append((serial, room_id, list(typing))) rows.sort() return rows + def get_current_token(self): + return self._latest_room_serial + class TypingNotificationEventSource(object): def __init__(self, hs): diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py new file mode 100644 index 0000000000..714f0195c8 --- /dev/null +++ b/synapse/handlers/user_directory.py @@ -0,0 +1,681 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from twisted.internet import defer + +from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.storage.roommember import ProfileInfo +from synapse.util.metrics import Measure +from synapse.util.async import sleep +from synapse.types import get_localpart_from_id + + +logger = logging.getLogger(__name__) + + +class UserDirectoryHandler(object): + """Handles querying of and keeping updated the user_directory. + + N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY + + The user directory is filled with users who this server can see are joined to a + world_readable or publically joinable room. We keep a database table up to date + by streaming changes of the current state and recalculating whether users should + be in the directory or not when necessary. 
+ + For each user in the directory we also store a room_id which is public and that the + user is joined to. This allows us to ignore history_visibility and join_rules changes + for that user in all other public rooms, as we know they'll still be in at least + one public room. + """ + + INITIAL_ROOM_SLEEP_MS = 50 + INITIAL_ROOM_SLEEP_COUNT = 100 + INITIAL_ROOM_BATCH_SIZE = 100 + INITIAL_USER_SLEEP_MS = 10 + + def __init__(self, hs): + self.store = hs.get_datastore() + self.state = hs.get_state_handler() + self.server_name = hs.hostname + self.clock = hs.get_clock() + self.notifier = hs.get_notifier() + self.is_mine_id = hs.is_mine_id + self.update_user_directory = hs.config.update_user_directory + self.search_all_users = hs.config.user_directory_search_all_users + + # When start up for the first time we need to populate the user_directory. + # This is a set of user_id's we've inserted already + self.initially_handled_users = set() + self.initially_handled_users_in_public = set() + + self.initially_handled_users_share = set() + self.initially_handled_users_share_private_room = set() + + # The current position in the current_state_delta stream + self.pos = None + + # Guard to ensure we only process deltas one at a time + self._is_processing = False + + if self.update_user_directory: + self.notifier.add_replication_callback(self.notify_new_event) + + # We kick this off so that we don't have to wait for a change before + # we start populating the user directory + self.clock.call_later(0, self.notify_new_event) + + def search_users(self, user_id, search_term, limit): + """Searches for users in directory + + Returns: + dict of the form:: + + { + "limited": <bool>, # whether there were more results or not + "results": [ # Ordered by best match first + { + "user_id": <user_id>, + "display_name": <display_name>, + "avatar_url": <avatar_url> + } + ] + } + """ + return self.store.search_user_dir(user_id, search_term, limit) + + @defer.inlineCallbacks + def 
notify_new_event(self): + """Called when there may be more deltas to process + """ + if not self.update_user_directory: + return + + if self._is_processing: + return + + self._is_processing = True + try: + yield self._unsafe_process() + finally: + self._is_processing = False + + @defer.inlineCallbacks + def handle_local_profile_change(self, user_id, profile): + """Called to update index of our local user profiles when they change + irrespective of any rooms the user may be in. + """ + yield self.store.update_profile_in_user_dir( + user_id, profile.display_name, profile.avatar_url, None, + ) + + @defer.inlineCallbacks + def _unsafe_process(self): + # If self.pos is None then means we haven't fetched it from DB + if self.pos is None: + self.pos = yield self.store.get_user_directory_stream_pos() + + # If still None then we need to do the initial fill of directory + if self.pos is None: + yield self._do_initial_spam() + self.pos = yield self.store.get_user_directory_stream_pos() + + # Loop round handling deltas until we're up to date + while True: + with Measure(self.clock, "user_dir_delta"): + deltas = yield self.store.get_current_state_deltas(self.pos) + if not deltas: + return + + logger.info("Handling %d state deltas", len(deltas)) + yield self._handle_deltas(deltas) + + self.pos = deltas[-1]["stream_id"] + yield self.store.update_user_directory_stream_pos(self.pos) + + @defer.inlineCallbacks + def _do_initial_spam(self): + """Populates the user_directory from the current state of the DB, used + when synapse first starts with user_directory support + """ + new_pos = yield self.store.get_max_stream_id_in_current_state_deltas() + + # Delete any existing entries just in case there are any + yield self.store.delete_all_from_user_dir() + + # We process by going through each existing room at a time. + room_ids = yield self.store.get_all_rooms() + + logger.info("Doing initial update of user directory. 
%d rooms", len(room_ids)) + num_processed_rooms = 0 + + for room_id in room_ids: + logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids)) + yield self._handle_initial_room(room_id) + num_processed_rooms += 1 + yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.) + + logger.info("Processed all rooms.") + + if self.search_all_users: + num_processed_users = 0 + user_ids = yield self.store.get_all_local_users() + logger.info("Doing initial update of user directory. %d users", len(user_ids)) + for user_id in user_ids: + # We add profiles for all users even if they don't match the + # include pattern, just in case we want to change it in future + logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids)) + yield self._handle_local_user(user_id) + num_processed_users += 1 + yield sleep(self.INITIAL_USER_SLEEP_MS / 1000.) + + logger.info("Processed all users") + + self.initially_handled_users = None + self.initially_handled_users_in_public = None + self.initially_handled_users_share = None + self.initially_handled_users_share_private_room = None + + yield self.store.update_user_directory_stream_pos(new_pos) + + @defer.inlineCallbacks + def _handle_initial_room(self, room_id): + """Called when we initially fill out user_directory one room at a time + """ + is_in_room = yield self.store.is_host_joined(room_id, self.server_name) + if not is_in_room: + return + + is_public = yield self.store.is_room_world_readable_or_publicly_joinable(room_id) + + users_with_profile = yield self.state.get_current_user_in_room(room_id) + user_ids = set(users_with_profile) + unhandled_users = user_ids - self.initially_handled_users + + yield self.store.add_profiles_to_user_dir( + room_id, { + user_id: users_with_profile[user_id] for user_id in unhandled_users + } + ) + + self.initially_handled_users |= unhandled_users + + if is_public: + yield self.store.add_users_to_public_room( + room_id, + user_ids=user_ids - self.initially_handled_users_in_public + ) + 
self.initially_handled_users_in_public |= user_ids + + # We now go and figure out the new users who share rooms with user entries + # We sleep aggressively here as otherwise it can starve resources. + # We also batch up inserts/updates, but try to avoid too many at once. + to_insert = set() + to_update = set() + count = 0 + for user_id in user_ids: + if count % self.INITIAL_ROOM_SLEEP_COUNT == 0: + yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.) + + if not self.is_mine_id(user_id): + count += 1 + continue + + if self.store.get_if_app_services_interested_in_user(user_id): + count += 1 + continue + + for other_user_id in user_ids: + if user_id == other_user_id: + continue + + if count % self.INITIAL_ROOM_SLEEP_COUNT == 0: + yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.) + count += 1 + + user_set = (user_id, other_user_id) + + if user_set in self.initially_handled_users_share_private_room: + continue + + if user_set in self.initially_handled_users_share: + if is_public: + continue + to_update.add(user_set) + else: + to_insert.add(user_set) + + if is_public: + self.initially_handled_users_share.add(user_set) + else: + self.initially_handled_users_share_private_room.add(user_set) + + if len(to_insert) > self.INITIAL_ROOM_BATCH_SIZE: + yield self.store.add_users_who_share_room( + room_id, not is_public, to_insert, + ) + to_insert.clear() + + if len(to_update) > self.INITIAL_ROOM_BATCH_SIZE: + yield self.store.update_users_who_share_room( + room_id, not is_public, to_update, + ) + to_update.clear() + + if to_insert: + yield self.store.add_users_who_share_room( + room_id, not is_public, to_insert, + ) + to_insert.clear() + + if to_update: + yield self.store.update_users_who_share_room( + room_id, not is_public, to_update, + ) + to_update.clear() + + @defer.inlineCallbacks + def _handle_deltas(self, deltas): + """Called with the state deltas to process + """ + for delta in deltas: + typ = delta["type"] + state_key = delta["state_key"] + room_id = delta["room_id"] + 
event_id = delta["event_id"] + prev_event_id = delta["prev_event_id"] + + logger.debug("Handling: %r %r, %s", typ, state_key, event_id) + + # For join rule and visibility changes we need to check if the room + # may have become public or not and add/remove the users in said room + if typ in (EventTypes.RoomHistoryVisibility, EventTypes.JoinRules): + yield self._handle_room_publicity_change( + room_id, prev_event_id, event_id, typ, + ) + elif typ == EventTypes.Member: + change = yield self._get_key_change( + prev_event_id, event_id, + key_name="membership", + public_value=Membership.JOIN, + ) + + if change is None: + # Handle any profile changes + yield self._handle_profile_change( + state_key, room_id, prev_event_id, event_id, + ) + continue + + if not change: + # Need to check if the server left the room entirely, if so + # we might need to remove all the users in that room + is_in_room = yield self.store.is_host_joined( + room_id, self.server_name, + ) + if not is_in_room: + logger.info("Server left room: %r", room_id) + # Fetch all the users that we marked as being in user + # directory due to being in the room and then check if + # need to remove those users or not + user_ids = yield self.store.get_users_in_dir_due_to_room(room_id) + for user_id in user_ids: + yield self._handle_remove_user(room_id, user_id) + return + else: + logger.debug("Server is still in room: %r", room_id) + + if change: # The user joined + event = yield self.store.get_event(event_id, allow_none=True) + profile = ProfileInfo( + avatar_url=event.content.get("avatar_url"), + display_name=event.content.get("displayname"), + ) + + yield self._handle_new_user(room_id, state_key, profile) + else: # The user left + yield self._handle_remove_user(room_id, state_key) + else: + logger.debug("Ignoring irrelevant type: %r", typ) + + @defer.inlineCallbacks + def _handle_room_publicity_change(self, room_id, prev_event_id, event_id, typ): + """Handle a room having potentially changed from/to 
world_readable/publically + joinable. + + Args: + room_id (str) + prev_event_id (str|None): The previous event before the state change + event_id (str|None): The new event after the state change + typ (str): Type of the event + """ + logger.debug("Handling change for %s: %s", typ, room_id) + + if typ == EventTypes.RoomHistoryVisibility: + change = yield self._get_key_change( + prev_event_id, event_id, + key_name="history_visibility", + public_value="world_readable", + ) + elif typ == EventTypes.JoinRules: + change = yield self._get_key_change( + prev_event_id, event_id, + key_name="join_rule", + public_value=JoinRules.PUBLIC, + ) + else: + raise Exception("Invalid event type") + # If change is None, no change. True => become world_readable/public, + # False => was world_readable/public + if change is None: + logger.debug("No change") + return + + # There's been a change to or from being world readable. + + is_public = yield self.store.is_room_world_readable_or_publicly_joinable( + room_id + ) + + logger.debug("Change: %r, is_public: %r", change, is_public) + + if change and not is_public: + # If we became world readable but room isn't currently public then + # we ignore the change + return + elif not change and is_public: + # If we stopped being world readable but are still public, + # ignore the change + return + + if change: + users_with_profile = yield self.state.get_current_user_in_room(room_id) + for user_id, profile in users_with_profile.iteritems(): + yield self._handle_new_user(room_id, user_id, profile) + else: + users = yield self.store.get_users_in_public_due_to_room(room_id) + for user_id in users: + yield self._handle_remove_user(room_id, user_id) + + @defer.inlineCallbacks + def _handle_local_user(self, user_id): + """Adds a new local roomless user into the user_directory_search table. + Used to populate up the user index when we have an + user_directory_search_all_users specified. 
+ """ + logger.debug("Adding new local user to dir, %r", user_id) + + profile = yield self.store.get_profileinfo(get_localpart_from_id(user_id)) + + row = yield self.store.get_user_in_directory(user_id) + if not row: + yield self.store.add_profiles_to_user_dir(None, {user_id: profile}) + + @defer.inlineCallbacks + def _handle_new_user(self, room_id, user_id, profile): + """Called when we might need to add user to directory + + Args: + room_id (str): room_id that user joined or started being public + user_id (str) + """ + logger.debug("Adding new user to dir, %r", user_id) + + row = yield self.store.get_user_in_directory(user_id) + if not row: + yield self.store.add_profiles_to_user_dir(room_id, {user_id: profile}) + + is_public = yield self.store.is_room_world_readable_or_publicly_joinable( + room_id + ) + + if is_public: + row = yield self.store.get_user_in_public_room(user_id) + if not row: + yield self.store.add_users_to_public_room(room_id, [user_id]) + else: + logger.debug("Not adding new user to public dir, %r", user_id) + + # Now we update users who share rooms with users. We do this by getting + # all the current users in the room and seeing which aren't already + # marked in the database as sharing with `user_id` + + users_with_profile = yield self.state.get_current_user_in_room(room_id) + + to_insert = set() + to_update = set() + + is_appservice = self.store.get_if_app_services_interested_in_user(user_id) + + # First, if they're our user then we need to update for every user + if self.is_mine_id(user_id) and not is_appservice: + # Returns a map of other_user_id -> shared_private. We only need + # to update mappings if for users that either don't share a room + # already (aren't in the map) or, if the room is private, those that + # only share a public room. 
+ user_ids_shared = yield self.store.get_users_who_share_room_from_dir( + user_id + ) + + for other_user_id in users_with_profile: + if user_id == other_user_id: + continue + + shared_is_private = user_ids_shared.get(other_user_id) + if shared_is_private is True: + # We've already marked in the database they share a private room + continue + elif shared_is_private is False: + # They already share a public room, so only update if this is + # a private room + if not is_public: + to_update.add((user_id, other_user_id)) + elif shared_is_private is None: + # This is the first time they both share a room + to_insert.add((user_id, other_user_id)) + + # Next we need to update for every local user in the room + for other_user_id in users_with_profile: + if user_id == other_user_id: + continue + + is_appservice = self.store.get_if_app_services_interested_in_user( + other_user_id + ) + if self.is_mine_id(other_user_id) and not is_appservice: + shared_is_private = yield self.store.get_if_users_share_a_room( + other_user_id, user_id, + ) + if shared_is_private is True: + # We've already marked in the database they share a private room + continue + elif shared_is_private is False: + # They already share a public room, so only update if this is + # a private room + if not is_public: + to_update.add((other_user_id, user_id)) + elif shared_is_private is None: + # This is the first time they both share a room + to_insert.add((other_user_id, user_id)) + + if to_insert: + yield self.store.add_users_who_share_room( + room_id, not is_public, to_insert, + ) + + if to_update: + yield self.store.update_users_who_share_room( + room_id, not is_public, to_update, + ) + + @defer.inlineCallbacks + def _handle_remove_user(self, room_id, user_id): + """Called when we might need to remove user to directory + + Args: + room_id (str): room_id that user left or stopped being public that + user_id (str) + """ + logger.debug("Maybe removing user %r", user_id) + + row = yield 
self.store.get_user_in_directory(user_id) + update_user_dir = row and row["room_id"] == room_id + + row = yield self.store.get_user_in_public_room(user_id) + update_user_in_public = row and row["room_id"] == room_id + + if (update_user_in_public or update_user_dir): + # XXX: Make this faster? + rooms = yield self.store.get_rooms_for_user(user_id) + for j_room_id in rooms: + if (not update_user_in_public and not update_user_dir): + break + + is_in_room = yield self.store.is_host_joined( + j_room_id, self.server_name, + ) + + if not is_in_room: + continue + + if update_user_dir: + update_user_dir = False + yield self.store.update_user_in_user_dir(user_id, j_room_id) + + is_public = yield self.store.is_room_world_readable_or_publicly_joinable( + j_room_id + ) + + if update_user_in_public and is_public: + yield self.store.update_user_in_public_user_list(user_id, j_room_id) + update_user_in_public = False + + if update_user_dir: + yield self.store.remove_from_user_dir(user_id) + elif update_user_in_public: + yield self.store.remove_from_user_in_public_room(user_id) + + # Now handle users_who_share_rooms. 
+ + # Get a list of user tuples that were in the DB due to this room and + # users (this includes tuples where the other user matches `user_id`) + user_tuples = yield self.store.get_users_in_share_dir_with_room_id( + user_id, room_id, + ) + + for user_id, other_user_id in user_tuples: + # For each user tuple get a list of rooms that they still share, + # trying to find a private room, and update the entry in the DB + rooms = yield self.store.get_rooms_in_common_for_users(user_id, other_user_id) + + # If they dont share a room anymore, remove the mapping + if not rooms: + yield self.store.remove_user_who_share_room( + user_id, other_user_id, + ) + continue + + found_public_share = None + for j_room_id in rooms: + is_public = yield self.store.is_room_world_readable_or_publicly_joinable( + j_room_id + ) + + if is_public: + found_public_share = j_room_id + else: + found_public_share = None + yield self.store.update_users_who_share_room( + room_id, not is_public, [(user_id, other_user_id)], + ) + break + + if found_public_share: + yield self.store.update_users_who_share_room( + room_id, not is_public, [(user_id, other_user_id)], + ) + + @defer.inlineCallbacks + def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id): + """Check member event changes for any profile changes and update the + database if there are. 
+ """ + if not prev_event_id or not event_id: + return + + prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + event = yield self.store.get_event(event_id, allow_none=True) + + if not prev_event or not event: + return + + if event.membership != Membership.JOIN: + return + + prev_name = prev_event.content.get("displayname") + new_name = event.content.get("displayname") + + prev_avatar = prev_event.content.get("avatar_url") + new_avatar = event.content.get("avatar_url") + + if prev_name != new_name or prev_avatar != new_avatar: + yield self.store.update_profile_in_user_dir( + user_id, new_name, new_avatar, room_id, + ) + + @defer.inlineCallbacks + def _get_key_change(self, prev_event_id, event_id, key_name, public_value): + """Given two events check if the `key_name` field in content changed + from not matching `public_value` to doing so. + + For example, check if `history_visibility` (`key_name`) changed from + `shared` to `world_readable` (`public_value`). + + Returns: + None if the field in the events either both match `public_value` + or if neither do, i.e. there has been no change. 
+ True if it didnt match `public_value` but now does + False if it did match `public_value` but now doesn't + """ + prev_event = None + event = None + if prev_event_id: + prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + + if event_id: + event = yield self.store.get_event(event_id, allow_none=True) + + if not event and not prev_event: + logger.debug("Neither event exists: %r %r", prev_event_id, event_id) + defer.returnValue(None) + + prev_value = None + value = None + + if prev_event: + prev_value = prev_event.content.get(key_name) + + if event: + value = event.content.get(key_name) + + logger.debug("prev_value: %r -> value: %r", prev_value, value) + + if value == public_value and prev_value != public_value: + defer.returnValue(True) + elif value != public_value and prev_value == public_value: + defer.returnValue(False) + else: + defer.returnValue(None) diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index bfebb0f644..054372e179 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,3 +13,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet.defer import CancelledError +from twisted.python import failure + +from synapse.api.errors import SynapseError + + +class RequestTimedOutError(SynapseError): + """Exception representing timeout of an outbound request""" + def __init__(self): + super(RequestTimedOutError, self).__init__(504, "Timed out") + + +def cancelled_to_request_timed_out_error(value, timeout): + """Turns CancelledErrors into RequestTimedOutErrors. 
+ + For use with async.add_timeout_to_deferred + """ + if isinstance(value, failure.Failure): + value.trap(CancelledError) + raise RequestTimedOutError() + return value diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py new file mode 100644 index 0000000000..343e932cb1 --- /dev/null +++ b/synapse/http/additional_resource.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.http.server import wrap_request_handler +from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET + + +class AdditionalResource(Resource): + """Resource wrapper for additional_resources + + If the user has configured additional_resources, we need to wrap the + handler class with a Resource so that we can map it into the resource tree. + + This class is also where we wrap the request handler with logging, metrics, + and exception handling. + """ + def __init__(self, hs, handler): + """Initialise AdditionalResource + + The ``handler`` should return a deferred which completes when it has + done handling the request. It should write a response with + ``request.write()``, and call ``request.finish()``. + + Args: + hs (synapse.server.HomeServer): homeserver + handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred): + function to be called to handle the request. 
+ """ + Resource.__init__(self) + self._handler = handler + + # these are required by the request_handler wrapper + self.version_string = hs.version_string + self.clock = hs.get_clock() + + def render(self, request): + self._async_render(request) + return NOT_DONE_YET + + @wrap_request_handler + def _async_render(self, request): + return self._handler(request) diff --git a/synapse/http/client.py b/synapse/http/client.py index ca2f770f5d..70a19d9b74 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,9 +17,12 @@ from OpenSSL import SSL from OpenSSL.SSL import VERIFY_NONE from synapse.api.errors import ( - CodeMessageException, SynapseError, Codes, + CodeMessageException, MatrixCodeMessageException, SynapseError, Codes, ) -from synapse.util.logcontext import preserve_context_over_fn +from synapse.http import cancelled_to_request_timed_out_error +from synapse.util.async import add_timeout_to_deferred +from synapse.util.caches import CACHE_SIZE_FACTOR +from synapse.util.logcontext import make_deferred_yieldable import synapse.metrics from synapse.http.endpoint import SpiderEndpoint @@ -29,13 +33,14 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.web.client import ( BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent, readBody, PartialDownloadError, + HTTPConnectionPool, ) from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer from twisted.web.http import PotentialDataLoss from twisted.web.http_headers import Headers from twisted.web._newclient import ResponseDone -from StringIO import StringIO +from six import StringIO import simplejson as json import logging @@ -63,92 +68,139 @@ class SimpleHttpClient(object): """ def __init__(self, 
hs): self.hs = hs + + pool = HTTPConnectionPool(reactor) + + # the pusher makes lots of concurrent SSL connections to sygnal, and + # tends to do so in batches, so we need to allow the pool to keep lots + # of idle connections around. + pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5)) + pool.cachedConnectionTimeout = 2 * 60 + # The default context factory in Twisted 14.0.0 (which we require) is # BrowserLikePolicyForHTTPS which will do regular cert validation # 'like a browser' self.agent = Agent( reactor, connectTimeout=15, - contextFactory=hs.get_http_client_context_factory() + contextFactory=hs.get_http_client_context_factory(), + pool=pool, ) self.user_agent = hs.version_string + self.clock = hs.get_clock() if hs.config.user_agent_suffix: self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,) + @defer.inlineCallbacks def request(self, method, uri, *args, **kwargs): # A small wrapper around self.agent.request() so we can easily attach # counters to it outgoing_requests_counter.inc(method) - d = preserve_context_over_fn( - self.agent.request, - method, uri, *args, **kwargs - ) logger.info("Sending request %s %s", method, uri) - def _cb(response): + try: + request_deferred = self.agent.request( + method, uri, *args, **kwargs + ) + add_timeout_to_deferred( + request_deferred, + 60, cancelled_to_request_timed_out_error, + ) + response = yield make_deferred_yieldable(request_deferred) + incoming_responses_counter.inc(method, response.code) logger.info( "Received response to %s %s: %s", method, uri, response.code ) - return response - - def _eb(failure): + defer.returnValue(response) + except Exception as e: incoming_responses_counter.inc(method, "ERR") logger.info( "Error sending request to %s %s: %s %s", - method, uri, failure.type, failure.getErrorMessage() + method, uri, type(e).__name__, e.message ) - return failure + raise e - d.addCallbacks(_cb, _eb) + @defer.inlineCallbacks + def post_urlencoded_get_json(self, uri, args={}, 
headers=None): + """ + Args: + uri (str): + args (dict[str, str|List[str]]): query params + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header - return d + Returns: + Deferred[object]: parsed json + """ - @defer.inlineCallbacks - def post_urlencoded_get_json(self, uri, args={}): # TODO: Do we ever want to log message contents? logger.debug("post_urlencoded_get_json args: %s", args) query_bytes = urllib.urlencode(encode_urlencode_args(args), True) + actual_headers = { + b"Content-Type": [b"application/x-www-form-urlencoded"], + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "POST", uri.encode("ascii"), - headers=Headers({ - b"Content-Type": [b"application/x-www-form-urlencoded"], - b"User-Agent": [self.user_agent], - }), + headers=Headers(actual_headers), bodyProducer=FileBodyProducer(StringIO(query_bytes)) ) - body = yield preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) defer.returnValue(json.loads(body)) @defer.inlineCallbacks - def post_json_get_json(self, uri, post_json): + def post_json_get_json(self, uri, post_json, headers=None): + """ + + Args: + uri (str): + post_json (object): + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header + + Returns: + Deferred[object]: parsed json + """ json_str = encode_canonical_json(post_json) logger.debug("HTTP POST %s -> %s", json_str, uri) + actual_headers = { + b"Content-Type": [b"application/json"], + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "POST", uri.encode("ascii"), - headers=Headers({ - b"Content-Type": [b"application/json"], - b"User-Agent": [self.user_agent], - }), + headers=Headers(actual_headers), bodyProducer=FileBodyProducer(StringIO(json_str)) ) - body = yield 
preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) + + if 200 <= response.code < 300: + defer.returnValue(json.loads(body)) + else: + raise self._exceptionFromFailedRequest(response, body) defer.returnValue(json.loads(body)) @defer.inlineCallbacks - def get_json(self, uri, args={}): + def get_json(self, uri, args={}, headers=None): """ Gets some json from the given URI. Args: @@ -157,6 +209,8 @@ class SimpleHttpClient(object): None. **Note**: The value of each key is assumed to be an iterable and *not* a string. + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header Returns: Deferred: Succeeds when we get *any* 2xx HTTP response, with the HTTP body as JSON. @@ -164,11 +218,14 @@ class SimpleHttpClient(object): On a non-2xx HTTP response. The response body will be used as the error message. """ - body = yield self.get_raw(uri, args) - defer.returnValue(json.loads(body)) + try: + body = yield self.get_raw(uri, args, headers=headers) + defer.returnValue(json.loads(body)) + except CodeMessageException as e: + raise self._exceptionFromFailedRequest(e.code, e.msg) @defer.inlineCallbacks - def put_json(self, uri, json_body, args={}): + def put_json(self, uri, json_body, args={}, headers=None): """ Puts some json to the given URI. Args: @@ -178,6 +235,8 @@ class SimpleHttpClient(object): None. **Note**: The value of each key is assumed to be an iterable and *not* a string. + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header Returns: Deferred: Succeeds when we get *any* 2xx HTTP response, with the HTTP body as JSON. 
@@ -190,17 +249,21 @@ class SimpleHttpClient(object): json_str = encode_canonical_json(json_body) + actual_headers = { + b"Content-Type": [b"application/json"], + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "PUT", uri.encode("ascii"), - headers=Headers({ - b"User-Agent": [self.user_agent], - "Content-Type": ["application/json"] - }), + headers=Headers(actual_headers), bodyProducer=FileBodyProducer(StringIO(json_str)) ) - body = yield preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: defer.returnValue(json.loads(body)) @@ -211,7 +274,7 @@ class SimpleHttpClient(object): raise CodeMessageException(response.code, body) @defer.inlineCallbacks - def get_raw(self, uri, args={}): + def get_raw(self, uri, args={}, headers=None): """ Gets raw text from the given URI. Args: @@ -220,6 +283,8 @@ class SimpleHttpClient(object): None. **Note**: The value of each key is assumed to be an iterable and *not* a string. + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header Returns: Deferred: Succeeds when we get *any* 2xx HTTP response, with the HTTP body at text. 
@@ -231,46 +296,65 @@ class SimpleHttpClient(object): query_bytes = urllib.urlencode(args, True) uri = "%s?%s" % (uri, query_bytes) + actual_headers = { + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "GET", uri.encode("ascii"), - headers=Headers({ - b"User-Agent": [self.user_agent], - }) + headers=Headers(actual_headers), ) - body = yield preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: defer.returnValue(body) else: raise CodeMessageException(response.code, body) + def _exceptionFromFailedRequest(self, response, body): + try: + jsonBody = json.loads(body) + errcode = jsonBody['errcode'] + error = jsonBody['error'] + return MatrixCodeMessageException(response.code, error, errcode) + except (ValueError, KeyError): + return CodeMessageException(response.code, body) + # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. # The two should be factored out. @defer.inlineCallbacks - def get_file(self, url, output_stream, max_size=None): + def get_file(self, url, output_stream, max_size=None, headers=None): """GETs a file from a given URL Args: url (str): The URL to GET output_stream (file): File to write the response body to. + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header Returns: A (int,dict,string,int) tuple of the file length, dict of the response headers, absolute URI of the response and HTTP response code. 
""" + actual_headers = { + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "GET", url.encode("ascii"), - headers=Headers({ - b"User-Agent": [self.user_agent], - }) + headers=Headers(actual_headers), ) - headers = dict(response.headers.getAllRawHeaders()) + resp_headers = dict(response.headers.getAllRawHeaders()) - if 'Content-Length' in headers and headers['Content-Length'] > max_size: + if 'Content-Length' in resp_headers and resp_headers['Content-Length'] > max_size: logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) raise SynapseError( 502, @@ -291,10 +375,9 @@ class SimpleHttpClient(object): # straight back in again try: - length = yield preserve_context_over_fn( - _readBodyToFile, - response, output_stream, max_size - ) + length = yield make_deferred_yieldable(_readBodyToFile( + response, output_stream, max_size, + )) except Exception as e: logger.exception("Failed to download body") raise SynapseError( @@ -303,7 +386,9 @@ class SimpleHttpClient(object): Codes.UNKNOWN, ) - defer.returnValue((length, headers, response.request.absoluteURI, response.code)) + defer.returnValue( + (length, resp_headers, response.request.absoluteURI, response.code), + ) # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. @@ -371,7 +456,7 @@ class CaptchaServerHttpClient(SimpleHttpClient): ) try: - body = yield preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) defer.returnValue(body) except PartialDownloadError as e: # twisted dislikes google's response, no content length. 
@@ -422,7 +507,7 @@ class SpiderHttpClient(SimpleHttpClient): reactor, SpiderEndpointFactory(hs) ) - ), [('gzip', GzipDecoder)] + ), [(b'gzip', GzipDecoder)] ) # We could look like Chrome: # self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko) diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index 564ae4c10d..87a482650d 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet import defer, reactor from twisted.internet.error import ConnectError @@ -30,7 +29,10 @@ logger = logging.getLogger(__name__) SERVER_CACHE = {} - +# our record of an individual server which can be tried to reach a destination. +# +# "host" is the hostname acquired from the SRV record. Except when there's +# no SRV record, in which case it is the original hostname. 
_Server = collections.namedtuple( "_Server", "priority weight host port expires" ) @@ -224,9 +226,10 @@ class SRVClientEndpoint(object): return self.default_server else: raise ConnectError( - "Not server available for %s" % self.service_name + "No server available for %s" % self.service_name ) + # look for all servers with the same priority min_priority = self.servers[0].priority weight_indexes = list( (index, server.weight + 1) @@ -236,11 +239,22 @@ class SRVClientEndpoint(object): total_weight = sum(weight for index, weight in weight_indexes) target_weight = random.randint(0, total_weight) - for index, weight in weight_indexes: target_weight -= weight if target_weight <= 0: server = self.servers[index] + # XXX: this looks totally dubious: + # + # (a) we never reuse a server until we have been through + # all of the servers at the same priority, so if the + # weights are A: 100, B:1, we always do ABABAB instead of + # AAAA...AAAB (approximately). + # + # (b) After using all the servers at the lowest priority, + # we move onto the next priority. We should only use the + # second priority if servers at the top priority are + # unreachable. 
+ # del self.servers[index] self.used_servers.append(server) return server @@ -277,7 +291,7 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t if (len(answers) == 1 and answers[0].type == dns.SRV and answers[0].payload - and answers[0].payload.target == dns.Name('.')): + and answers[0].payload.target == dns.Name(b'.')): raise ConnectError("Service %s unavailable" % service_name) for answer in answers: @@ -285,26 +299,14 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t continue payload = answer.payload - host = str(payload.target) - srv_ttl = answer.ttl - - try: - answers, _, _ = yield dns_client.lookupAddress(host) - except DNSNameError: - continue - for answer in answers: - if answer.type == dns.A and answer.payload: - ip = answer.payload.dottedQuad() - host_ttl = min(srv_ttl, answer.ttl) - - servers.append(_Server( - host=ip, - port=int(payload.port), - priority=int(payload.priority), - weight=int(payload.weight), - expires=int(clock.time()) + host_ttl, - )) + servers.append(_Server( + host=str(payload.target), + port=int(payload.port), + priority=int(payload.priority), + weight=int(payload.weight), + expires=int(clock.time()) + answer.ttl, + )) servers.sort() cache[service_name] = list(servers) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 78b92cef36..4b2b85464d 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,23 +13,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- - from twisted.internet import defer, reactor, protocol from twisted.internet.error import DNSLookupError from twisted.web.client import readBody, HTTPConnectionPool, Agent from twisted.web.http_headers import Headers from twisted.web._newclient import ResponseDone +from synapse.http import cancelled_to_request_timed_out_error from synapse.http.endpoint import matrix_federation_endpoint -from synapse.util.async import sleep -from synapse.util.logcontext import preserve_context_over_fn import synapse.metrics +from synapse.util.async import sleep, add_timeout_to_deferred +from synapse.util import logcontext +from synapse.util.logcontext import make_deferred_yieldable +import synapse.util.retryutils from canonicaljson import encode_canonical_json from synapse.api.errors import ( - SynapseError, Codes, HttpResponseException, + SynapseError, Codes, HttpResponseException, FederationDeniedError, ) from signedjson.sign import sign_json @@ -39,8 +41,7 @@ import logging import random import sys import urllib -import urlparse - +from six.moves.urllib import parse as urlparse logger = logging.getLogger(__name__) outbound_logger = logging.getLogger("synapse.http.outbound") @@ -94,6 +95,7 @@ class MatrixFederationHttpClient(object): reactor, MatrixFederationEndpointFactory(hs), pool=pool ) self.clock = hs.get_clock() + self._store = hs.get_datastore() self.version_string = hs.version_string self._next_id = 1 @@ -103,123 +105,161 @@ class MatrixFederationHttpClient(object): ) @defer.inlineCallbacks - def _create_request(self, destination, method, path_bytes, - body_callback, headers_dict={}, param_bytes=b"", - query_bytes=b"", retry_on_dns_fail=True, - timeout=None, long_retries=False): - """ Creates and sends a request to the given url - """ - headers_dict[b"User-Agent"] = [self.version_string] - headers_dict[b"Host"] = [destination] + def _request(self, destination, method, path, + body_callback, headers_dict={}, param_bytes=b"", + query_bytes=b"", retry_on_dns_fail=True, + 
timeout=None, long_retries=False, + ignore_backoff=False, + backoff_on_404=False): + """ Creates and sends a request to the given server + Args: + destination (str): The remote server to send the HTTP request to. + method (str): HTTP method + path (str): The HTTP path + ignore_backoff (bool): true to ignore the historical backoff data + and try the request anyway. + backoff_on_404 (bool): Back off if we get a 404 - url_bytes = self._create_url( - destination, path_bytes, param_bytes, query_bytes - ) + Returns: + Deferred: resolves with the http response object on success. - txn_id = "%s-O-%s" % (method, self._next_id) - self._next_id = (self._next_id + 1) % (sys.maxint - 1) + Fails with ``HTTPRequestException``: if we get an HTTP response + code >= 300. - outbound_logger.info( - "{%s} [%s] Sending request: %s %s", - txn_id, destination, method, url_bytes - ) + Fails with ``NotRetryingDestination`` if we are not yet ready + to retry this server. - # XXX: Would be much nicer to retry only at the transaction-layer - # (once we have reliable transactions in place) - if long_retries: - retries_left = MAX_LONG_RETRIES - else: - retries_left = MAX_SHORT_RETRIES + Fails with ``FederationDeniedError`` if this destination + is not on our federation whitelist - http_url_bytes = urlparse.urlunparse( - ("", "", path_bytes, param_bytes, query_bytes, "") + (May also fail with plenty of other Exceptions for things like DNS + failures, connection failures, SSL failures.) 
+ """ + if ( + self.hs.config.federation_domain_whitelist and + destination not in self.hs.config.federation_domain_whitelist + ): + raise FederationDeniedError(destination) + + limiter = yield synapse.util.retryutils.get_retry_limiter( + destination, + self.clock, + self._store, + backoff_on_404=backoff_on_404, + ignore_backoff=ignore_backoff, ) - log_result = None - try: - while True: - producer = None - if body_callback: - producer = body_callback(method, http_url_bytes, headers_dict) - - try: - def send_request(): - request_deferred = preserve_context_over_fn( - self.agent.request, + destination = destination.encode("ascii") + path_bytes = path.encode("ascii") + with limiter: + headers_dict[b"User-Agent"] = [self.version_string] + headers_dict[b"Host"] = [destination] + + url_bytes = self._create_url( + destination, path_bytes, param_bytes, query_bytes + ) + + txn_id = "%s-O-%s" % (method, self._next_id) + self._next_id = (self._next_id + 1) % (sys.maxint - 1) + + outbound_logger.info( + "{%s} [%s] Sending request: %s %s", + txn_id, destination, method, url_bytes + ) + + # XXX: Would be much nicer to retry only at the transaction-layer + # (once we have reliable transactions in place) + if long_retries: + retries_left = MAX_LONG_RETRIES + else: + retries_left = MAX_SHORT_RETRIES + + http_url_bytes = urlparse.urlunparse( + ("", "", path_bytes, param_bytes, query_bytes, "") + ) + + log_result = None + try: + while True: + producer = None + if body_callback: + producer = body_callback(method, http_url_bytes, headers_dict) + + try: + request_deferred = self.agent.request( method, url_bytes, Headers(headers_dict), producer ) - - return self.clock.time_bound_deferred( + add_timeout_to_deferred( + request_deferred, + timeout / 1000. if timeout else 60, + cancelled_to_request_timed_out_error, + ) + response = yield make_deferred_yieldable( request_deferred, - time_out=timeout / 1000. 
if timeout else 60, ) - response = yield preserve_context_over_fn(send_request) + log_result = "%d %s" % (response.code, response.phrase,) + break + except Exception as e: + if not retry_on_dns_fail and isinstance(e, DNSLookupError): + logger.warn( + "DNS Lookup failed to %s with %s", + destination, + e + ) + log_result = "DNS Lookup failed to %s with %s" % ( + destination, e + ) + raise - log_result = "%d %s" % (response.code, response.phrase,) - break - except Exception as e: - if not retry_on_dns_fail and isinstance(e, DNSLookupError): logger.warn( - "DNS Lookup failed to %s with %s", + "{%s} Sending request failed to %s: %s %s: %s", + txn_id, destination, - e - ) - log_result = "DNS Lookup failed to %s with %s" % ( - destination, e + method, + url_bytes, + _flatten_response_never_received(e), ) - raise - - logger.warn( - "{%s} Sending request failed to %s: %s %s: %s - %s", - txn_id, - destination, - method, - url_bytes, - type(e).__name__, - _flatten_response_never_received(e), - ) - - log_result = "%s - %s" % ( - type(e).__name__, _flatten_response_never_received(e), - ) - - if retries_left and not timeout: - if long_retries: - delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left) - delay = min(delay, 60) - delay *= random.uniform(0.8, 1.4) + + log_result = _flatten_response_never_received(e) + + if retries_left and not timeout: + if long_retries: + delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left) + delay = min(delay, 60) + delay *= random.uniform(0.8, 1.4) + else: + delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left) + delay = min(delay, 2) + delay *= random.uniform(0.8, 1.4) + + yield sleep(delay) + retries_left -= 1 else: - delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left) - delay = min(delay, 2) - delay *= random.uniform(0.8, 1.4) - - yield sleep(delay) - retries_left -= 1 - else: - raise - finally: - outbound_logger.info( - "{%s} [%s] Result: %s", - txn_id, - destination, - log_result, - ) + raise + finally: + outbound_logger.info( + "{%s} [%s] 
Result: %s", + txn_id, + destination, + log_result, + ) - if 200 <= response.code < 300: - pass - else: - # :'( - # Update transactions table? - body = yield preserve_context_over_fn(readBody, response) - raise HttpResponseException( - response.code, response.phrase, body - ) + if 200 <= response.code < 300: + pass + else: + # :'( + # Update transactions table? + with logcontext.PreserveLoggingContext(): + body = yield readBody(response) + raise HttpResponseException( + response.code, response.phrase, body + ) - defer.returnValue(response) + defer.returnValue(response) def sign_request(self, destination, method, url_bytes, headers_dict, content=None): @@ -247,14 +287,18 @@ class MatrixFederationHttpClient(object): headers_dict[b"Authorization"] = auth_headers @defer.inlineCallbacks - def put_json(self, destination, path, data={}, json_data_callback=None, - long_retries=False, timeout=None): + def put_json(self, destination, path, args={}, data={}, + json_data_callback=None, + long_retries=False, timeout=None, + ignore_backoff=False, + backoff_on_404=False): """ Sends the specifed json data using PUT Args: destination (str): The remote server to send the HTTP request to. path (str): The HTTP path. + args (dict): query params data (dict): A dict containing the data that will be used as the request body. This will be encoded as JSON. json_data_callback (callable): A callable returning the dict to @@ -263,11 +307,24 @@ class MatrixFederationHttpClient(object): retry for a short or long time. timeout(int): How long to try (in ms) the destination for before giving up. None indicates no timeout. + ignore_backoff (bool): true to ignore the historical backoff data + and try the request anyway. + backoff_on_404 (bool): True if we should count a 404 response as + a failure of the server (and should therefore back off future + requests) Returns: Deferred: Succeeds when we get a 2xx HTTP response. The result - will be the decoded JSON body. 
On a 4xx or 5xx error response a - CodeMessageException is raised. + will be the decoded JSON body. + + Fails with ``HTTPRequestException`` if we get an HTTP response + code >= 300. + + Fails with ``NotRetryingDestination`` if we are not yet ready + to retry this server. + + Fails with ``FederationDeniedError`` if this destination + is not on our federation whitelist """ if not json_data_callback: @@ -282,26 +339,30 @@ class MatrixFederationHttpClient(object): producer = _JsonProducer(json_data) return producer - response = yield self._create_request( - destination.encode("ascii"), + response = yield self._request( + destination, "PUT", - path.encode("ascii"), + path, body_callback=body_callback, headers_dict={"Content-Type": ["application/json"]}, + query_bytes=encode_query_args(args), long_retries=long_retries, timeout=timeout, + ignore_backoff=ignore_backoff, + backoff_on_404=backoff_on_404, ) if 200 <= response.code < 300: # We need to update the transactions table to say it was sent? check_content_type_is_json(response.headers) - body = yield preserve_context_over_fn(readBody, response) + with logcontext.PreserveLoggingContext(): + body = yield readBody(response) defer.returnValue(json.loads(body)) @defer.inlineCallbacks def post_json(self, destination, path, data={}, long_retries=False, - timeout=None): + timeout=None, ignore_backoff=False, args={}): """ Sends the specifed json data using POST Args: @@ -314,11 +375,21 @@ class MatrixFederationHttpClient(object): retry for a short or long time. timeout(int): How long to try (in ms) the destination for before giving up. None indicates no timeout. - + ignore_backoff (bool): true to ignore the historical backoff data and + try the request anyway. + args (dict): query params Returns: Deferred: Succeeds when we get a 2xx HTTP response. The result - will be the decoded JSON body. On a 4xx or 5xx error response a - CodeMessageException is raised. + will be the decoded JSON body. 
+ + Fails with ``HTTPRequestException`` if we get an HTTP response + code >= 300. + + Fails with ``NotRetryingDestination`` if we are not yet ready + to retry this server. + + Fails with ``FederationDeniedError`` if this destination + is not on our federation whitelist """ def body_callback(method, url_bytes, headers_dict): @@ -327,27 +398,30 @@ class MatrixFederationHttpClient(object): ) return _JsonProducer(data) - response = yield self._create_request( - destination.encode("ascii"), + response = yield self._request( + destination, "POST", - path.encode("ascii"), + path, + query_bytes=encode_query_args(args), body_callback=body_callback, headers_dict={"Content-Type": ["application/json"]}, long_retries=long_retries, timeout=timeout, + ignore_backoff=ignore_backoff, ) if 200 <= response.code < 300: # We need to update the transactions table to say it was sent? check_content_type_is_json(response.headers) - body = yield preserve_context_over_fn(readBody, response) + with logcontext.PreserveLoggingContext(): + body = yield readBody(response) defer.returnValue(json.loads(body)) @defer.inlineCallbacks def get_json(self, destination, path, args={}, retry_on_dns_fail=True, - timeout=None): + timeout=None, ignore_backoff=False): """ GETs some json from the given host homeserver and path Args: @@ -359,57 +433,122 @@ class MatrixFederationHttpClient(object): timeout (int): How long to try (in ms) the destination for before giving up. None indicates no timeout and that the request will be retried. + ignore_backoff (bool): true to ignore the historical backoff data + and try the request anyway. Returns: - Deferred: Succeeds when we get *any* HTTP response. + Deferred: Succeeds when we get a 2xx HTTP response. The result + will be the decoded JSON body. + + Fails with ``HTTPRequestException`` if we get an HTTP response + code >= 300. + + Fails with ``NotRetryingDestination`` if we are not yet ready + to retry this server. 
- The result of the deferred is a tuple of `(code, response)`, - where `response` is a dict representing the decoded JSON body. + Fails with ``FederationDeniedError`` if this destination + is not on our federation whitelist """ logger.debug("get_json args: %s", args) - encoded_args = {} - for k, vs in args.items(): - if isinstance(vs, basestring): - vs = [vs] - encoded_args[k] = [v.encode("UTF-8") for v in vs] - - query_bytes = urllib.urlencode(encoded_args, True) logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail) def body_callback(method, url_bytes, headers_dict): self.sign_request(destination, method, url_bytes, headers_dict) return None - response = yield self._create_request( - destination.encode("ascii"), + response = yield self._request( + destination, "GET", - path.encode("ascii"), - query_bytes=query_bytes, + path, + query_bytes=encode_query_args(args), body_callback=body_callback, retry_on_dns_fail=retry_on_dns_fail, timeout=timeout, + ignore_backoff=ignore_backoff, ) if 200 <= response.code < 300: # We need to update the transactions table to say it was sent? check_content_type_is_json(response.headers) - body = yield preserve_context_over_fn(readBody, response) + with logcontext.PreserveLoggingContext(): + body = yield readBody(response) + + defer.returnValue(json.loads(body)) + + @defer.inlineCallbacks + def delete_json(self, destination, path, long_retries=False, + timeout=None, ignore_backoff=False, args={}): + """Send a DELETE request to the remote expecting some json response + + Args: + destination (str): The remote server to send the HTTP request + to. + path (str): The HTTP path. + long_retries (bool): A boolean that indicates whether we should + retry for a short or long time. + timeout(int): How long to try (in ms) the destination for before + giving up. None indicates no timeout. + ignore_backoff (bool): true to ignore the historical backoff data and + try the request anyway. 
+ Returns: + Deferred: Succeeds when we get a 2xx HTTP response. The result + will be the decoded JSON body. + + Fails with ``HTTPRequestException`` if we get an HTTP response + code >= 300. + + Fails with ``NotRetryingDestination`` if we are not yet ready + to retry this server. + + Fails with ``FederationDeniedError`` if this destination + is not on our federation whitelist + """ + + response = yield self._request( + destination, + "DELETE", + path, + query_bytes=encode_query_args(args), + headers_dict={"Content-Type": ["application/json"]}, + long_retries=long_retries, + timeout=timeout, + ignore_backoff=ignore_backoff, + ) + + if 200 <= response.code < 300: + # We need to update the transactions table to say it was sent? + check_content_type_is_json(response.headers) + + with logcontext.PreserveLoggingContext(): + body = yield readBody(response) defer.returnValue(json.loads(body)) @defer.inlineCallbacks def get_file(self, destination, path, output_stream, args={}, - retry_on_dns_fail=True, max_size=None): + retry_on_dns_fail=True, max_size=None, + ignore_backoff=False): """GETs a file from a given homeserver Args: destination (str): The remote server to send the HTTP request to. path (str): The HTTP path to GET. output_stream (file): File to write the response body to. args (dict): Optional dictionary used to create the query string. + ignore_backoff (bool): true to ignore the historical backoff data + and try the request anyway. Returns: - A (int,dict) tuple of the file length and a dict of the response - headers. + Deferred: resolves with an (int,dict) tuple of the file length and + a dict of the response headers. + + Fails with ``HTTPRequestException`` if we get an HTTP response code + >= 300 + + Fails with ``NotRetryingDestination`` if we are not yet ready + to retry this server. 
+ + Fails with ``FederationDeniedError`` if this destination + is not on our federation whitelist """ encoded_args = {} @@ -419,29 +558,30 @@ class MatrixFederationHttpClient(object): encoded_args[k] = [v.encode("UTF-8") for v in vs] query_bytes = urllib.urlencode(encoded_args, True) - logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail) + logger.debug("Query bytes: %s Retry DNS: %s", query_bytes, retry_on_dns_fail) def body_callback(method, url_bytes, headers_dict): self.sign_request(destination, method, url_bytes, headers_dict) return None - response = yield self._create_request( - destination.encode("ascii"), + response = yield self._request( + destination, "GET", - path.encode("ascii"), + path, query_bytes=query_bytes, body_callback=body_callback, - retry_on_dns_fail=retry_on_dns_fail + retry_on_dns_fail=retry_on_dns_fail, + ignore_backoff=ignore_backoff, ) headers = dict(response.headers.getAllRawHeaders()) try: - length = yield preserve_context_over_fn( - _readBodyToFile, - response, output_stream, max_size - ) - except: + with logcontext.PreserveLoggingContext(): + length = yield _readBodyToFile( + response, output_stream, max_size + ) + except Exception: logger.exception("Failed to download body") raise @@ -506,12 +646,14 @@ class _JsonProducer(object): def _flatten_response_never_received(e): if hasattr(e, "reasons"): - return ", ".join( + reasons = ", ".join( _flatten_response_never_received(f.value) for f in e.reasons ) + + return "%s:[%s]" % (type(e).__name__, reasons) else: - return "%s: %s" % (type(e).__name__, e.message,) + return repr(e) def check_content_type_is_json(headers): @@ -538,3 +680,15 @@ def check_content_type_is_json(headers): raise RuntimeError( "Content-Type not application/json: was '%s'" % c_type ) + + +def encode_query_args(args): + encoded_args = {} + for k, vs in args.items(): + if isinstance(vs, basestring): + vs = [vs] + encoded_args[k] = [v.encode("UTF-8") for v in vs] + + query_bytes = 
urllib.urlencode(encoded_args, True) + + return query_bytes diff --git a/synapse/http/server.py b/synapse/http/server.py index 14715878c5..55b9ad5251 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -28,6 +29,7 @@ from canonicaljson import ( ) from twisted.internet import defer +from twisted.python import failure from twisted.web import server, resource from twisted.web.server import NOT_DONE_YET from twisted.web.util import redirectTo @@ -35,42 +37,86 @@ from twisted.web.util import redirectTo import collections import logging import urllib -import ujson +import simplejson logger = logging.getLogger(__name__) metrics = synapse.metrics.get_metrics_for(__name__) -incoming_requests_counter = metrics.register_counter( - "requests", +# total number of responses served, split by method/servlet/tag +response_count = metrics.register_counter( + "response_count", labels=["method", "servlet", "tag"], + alternative_names=( + # the following are all deprecated aliases for the same metric + metrics.name_prefix + x for x in ( + "_requests", + "_response_time:count", + "_response_ru_utime:count", + "_response_ru_stime:count", + "_response_db_txn_count:count", + "_response_db_txn_duration:count", + ) + ) +) + +requests_counter = metrics.register_counter( + "requests_received", + labels=["method", "servlet", ], ) + outgoing_responses_counter = metrics.register_counter( "responses", labels=["method", "code"], ) -response_timer = metrics.register_distribution( - "response_time", - labels=["method", "servlet", "tag"] +response_timer = metrics.register_counter( + "response_time_seconds", + labels=["method", "servlet", "tag"], + alternative_names=( + metrics.name_prefix + "_response_time:total", + ), ) -response_ru_utime = 
metrics.register_distribution( - "response_ru_utime", labels=["method", "servlet", "tag"] +response_ru_utime = metrics.register_counter( + "response_ru_utime_seconds", labels=["method", "servlet", "tag"], + alternative_names=( + metrics.name_prefix + "_response_ru_utime:total", + ), ) -response_ru_stime = metrics.register_distribution( - "response_ru_stime", labels=["method", "servlet", "tag"] +response_ru_stime = metrics.register_counter( + "response_ru_stime_seconds", labels=["method", "servlet", "tag"], + alternative_names=( + metrics.name_prefix + "_response_ru_stime:total", + ), ) -response_db_txn_count = metrics.register_distribution( - "response_db_txn_count", labels=["method", "servlet", "tag"] +response_db_txn_count = metrics.register_counter( + "response_db_txn_count", labels=["method", "servlet", "tag"], + alternative_names=( + metrics.name_prefix + "_response_db_txn_count:total", + ), ) -response_db_txn_duration = metrics.register_distribution( - "response_db_txn_duration", labels=["method", "servlet", "tag"] +# seconds spent waiting for db txns, excluding scheduling time, when processing +# this request +response_db_txn_duration = metrics.register_counter( + "response_db_txn_duration_seconds", labels=["method", "servlet", "tag"], + alternative_names=( + metrics.name_prefix + "_response_db_txn_duration:total", + ), ) +# seconds spent waiting for a db connection, when processing this request +response_db_sched_duration = metrics.register_counter( + "response_db_sched_duration_seconds", labels=["method", "servlet", "tag"] +) + +# size in bytes of the response written +response_size = metrics.register_counter( + "response_size", labels=["method", "servlet", "tag"] +) _next_request_id = 0 @@ -106,7 +152,12 @@ def wrap_request_handler(request_handler, include_metrics=False): with LoggingContext(request_id) as request_context: with Measure(self.clock, "wrapped_request_handler"): request_metrics = RequestMetrics() - request_metrics.start(self.clock, 
name=self.__class__.__name__) + # we start the request metrics timer here with an initial stab + # at the servlet name. For most requests that name will be + # JsonResource (or a subclass), and JsonResource._async_render + # will update it once it picks a servlet. + servlet_name = self.__class__.__name__ + request_metrics.start(self.clock, name=servlet_name) request_context.request = request_id with request.processing(): @@ -115,6 +166,7 @@ def wrap_request_handler(request_handler, include_metrics=False): if include_metrics: yield request_handler(self, request, request_metrics) else: + requests_counter.inc(request.method, servlet_name) yield request_handler(self, request) except CodeMessageException as e: code = e.code @@ -130,13 +182,18 @@ def wrap_request_handler(request_handler, include_metrics=False): pretty_print=_request_user_agent_is_curl(request), version_string=self.version_string, ) - except: - logger.exception( - "Failed handle request %s.%s on %r: %r", + except Exception: + # failure.Failure() fishes the original Failure out + # of our stack, and thus gives us a sensible stack + # trace. + f = failure.Failure() + logger.error( + "Failed handle request %s.%s on %r: %r: %s", request_handler.__module__, request_handler.__name__, self, - request + request, + f.getTraceback().rstrip(), ) respond_with_json( request, @@ -145,7 +202,9 @@ def wrap_request_handler(request_handler, include_metrics=False): "error": "Internal server error", "errcode": Codes.UNKNOWN, }, - send_cors=True + send_cors=True, + pretty_print=_request_user_agent_is_curl(request), + version_string=self.version_string, ) finally: try: @@ -183,7 +242,7 @@ class JsonResource(HttpServer, resource.Resource): """ This implements the HttpServer interface and provides JSON support for Resources. 
- Register callbacks via register_path() + Register callbacks via register_paths() Callbacks can return a tuple of status code and a dict in which case the the dict will automatically be sent to the client as a JSON object. @@ -230,57 +289,62 @@ class JsonResource(HttpServer, resource.Resource): This checks if anyone has registered a callback for that method and path. """ - if request.method == "OPTIONS": - self._send_response(request, 200, {}) - return + callback, group_dict = self._get_handler_for_request(request) - # Loop through all the registered callbacks to check if the method - # and path regex match - for path_entry in self.path_regexs.get(request.method, []): - m = path_entry.pattern.match(request.path) - if not m: - continue + servlet_instance = getattr(callback, "__self__", None) + if servlet_instance is not None: + servlet_classname = servlet_instance.__class__.__name__ + else: + servlet_classname = "%r" % callback - # We found a match! Trigger callback and then return the - # returned response. We pass both the request and any - # matched groups from the regex to the callback. + request_metrics.name = servlet_classname + requests_counter.inc(request.method, servlet_classname) - callback = path_entry.callback + # Now trigger the callback. If it returns a response, we send it + # here. If it throws an exception, that is handled by the wrapper + # installed by @request_handler. 
- kwargs = intern_dict({ - name: urllib.unquote(value).decode("UTF-8") if value else value - for name, value in m.groupdict().items() - }) + kwargs = intern_dict({ + name: urllib.unquote(value).decode("UTF-8") if value else value + for name, value in group_dict.items() + }) - callback_return = yield callback(request, **kwargs) - if callback_return is not None: - code, response = callback_return - self._send_response(request, code, response) + callback_return = yield callback(request, **kwargs) + if callback_return is not None: + code, response = callback_return + self._send_response(request, code, response) - servlet_instance = getattr(callback, "__self__", None) - if servlet_instance is not None: - servlet_classname = servlet_instance.__class__.__name__ - else: - servlet_classname = "%r" % callback + def _get_handler_for_request(self, request): + """Finds a callback method to handle the given request - request_metrics.name = servlet_classname + Args: + request (twisted.web.http.Request): + + Returns: + Tuple[Callable, dict[str, str]]: callback method, and the dict + mapping keys to path components as specified in the handler's + path match regexp. - return + The callback will normally be a method registered via + register_paths, so will return (possibly via Deferred) either + None, or a tuple of (http code, response body). + """ + if request.method == b"OPTIONS": + return _options_handler, {} + + # Loop through all the registered callbacks to check if the method + # and path regex match + for path_entry in self.path_regexs.get(request.method, []): + m = path_entry.pattern.match(request.path) + if m: + # We found a match! + return path_entry.callback, m.groupdict() # Huh. No one wanted to handle that? Fiiiiiine. Send 400. 
- raise UnrecognizedRequestError() + return _unrecognised_request_handler, {} def _send_response(self, request, code, response_json_object, response_code_message=None): - # could alternatively use request.notifyFinish() and flip a flag when - # the Deferred fires, but since the flag is RIGHT THERE it seems like - # a waste. - if request._disconnected: - logger.warn( - "Not sending response to request %s, already disconnected.", - request) - return - outgoing_responses_counter.inc(request.method, str(code)) # TODO: Only enable CORS for the requests that need it. @@ -294,6 +358,34 @@ class JsonResource(HttpServer, resource.Resource): ) +def _options_handler(request): + """Request handler for OPTIONS requests + + This is a request handler suitable for return from + _get_handler_for_request. It returns a 200 and an empty body. + + Args: + request (twisted.web.http.Request): + + Returns: + Tuple[int, dict]: http code, response body. + """ + return 200, {} + + +def _unrecognised_request_handler(request): + """Request handler for unrecognised requests + + This is a request handler suitable for return from + _get_handler_for_request. It actually just raises an + UnrecognizedRequestError. 
+ + Args: + request (twisted.web.http.Request): + """ + raise UnrecognizedRequestError() + + class RequestMetrics(object): def start(self, clock, name): self.start = clock.time_msec() @@ -314,7 +406,7 @@ class RequestMetrics(object): ) return - incoming_requests_counter.inc(request.method, self.name, tag) + response_count.inc(request.method, self.name, tag) response_timer.inc_by( clock.time_msec() - self.start, request.method, @@ -333,9 +425,14 @@ class RequestMetrics(object): context.db_txn_count, request.method, self.name, tag ) response_db_txn_duration.inc_by( - context.db_txn_duration, request.method, self.name, tag + context.db_txn_duration_ms / 1000., request.method, self.name, tag + ) + response_db_sched_duration.inc_by( + context.db_sched_duration_ms / 1000., request.method, self.name, tag ) + response_size.inc_by(request.sentLength, request.method, self.name, tag) + class RootRedirect(resource.Resource): """Redirects the root '/' path to another path.""" @@ -356,14 +453,22 @@ class RootRedirect(resource.Resource): def respond_with_json(request, code, json_object, send_cors=False, response_code_message=None, pretty_print=False, version_string="", canonical_json=True): + # could alternatively use request.notifyFinish() and flip a flag when + # the Deferred fires, but since the flag is RIGHT THERE it seems like + # a waste. + if request._disconnected: + logger.warn( + "Not sending response to request %s, already disconnected.", + request) + return + if pretty_print: json_bytes = encode_pretty_printed_json(json_object) + "\n" else: if canonical_json or synapse.events.USE_FROZEN_DICTS: json_bytes = encode_canonical_json(json_object) else: - # ujson doesn't like frozen_dicts. 
- json_bytes = ujson.dumps(json_object, ensure_ascii=False) + json_bytes = simplejson.dumps(json_object) return respond_with_json_bytes( request, code, json_bytes, @@ -390,6 +495,7 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False, request.setHeader(b"Content-Type", b"application/json") request.setHeader(b"Server", version_string) request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),)) + request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate") if send_cors: set_cors_headers(request) @@ -412,7 +518,7 @@ def set_cors_headers(request): ) request.setHeader( "Access-Control-Allow-Headers", - "Origin, X-Requested-With, Content-Type, Accept" + "Origin, X-Requested-With, Content-Type, Accept, Authorization" ) @@ -437,9 +543,9 @@ def finish_request(request): def _request_user_agent_is_curl(request): user_agents = request.requestHeaders.getRawHeaders( - "User-Agent", default=[] + b"User-Agent", default=[] ) for user_agent in user_agents: - if "curl" in user_agent: + if b"curl" in user_agent: return True return False diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 8c22d6f00f..ef8e62901b 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -48,7 +48,7 @@ def parse_integer_from_args(args, name, default=None, required=False): if name in args: try: return int(args[name][0]) - except: + except Exception: message = "Query parameter %r must be an integer" % (name,) raise SynapseError(400, message) else: @@ -88,7 +88,7 @@ def parse_boolean_from_args(args, name, default=None, required=False): "true": True, "false": False, }[args[name][0]] - except: + except Exception: message = ( "Boolean query parameter %r must be one of" " ['true', 'false']" @@ -148,11 +148,13 @@ def parse_string_from_args(args, name, default=None, required=False, return default -def parse_json_value_from_request(request): +def parse_json_value_from_request(request, allow_empty_body=False): """Parse a JSON value from the body 
of a twisted HTTP request. Args: request: the twisted HTTP request. + allow_empty_body (bool): if True, an empty body will be accepted and + turned into None Returns: The JSON value. @@ -162,28 +164,39 @@ def parse_json_value_from_request(request): """ try: content_bytes = request.content.read() - except: + except Exception: raise SynapseError(400, "Error reading JSON content.") + if not content_bytes and allow_empty_body: + return None + try: content = simplejson.loads(content_bytes) - except simplejson.JSONDecodeError: + except Exception as e: + logger.warn("Unable to parse JSON: %s", e) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) return content -def parse_json_object_from_request(request): +def parse_json_object_from_request(request, allow_empty_body=False): """Parse a JSON object from the body of a twisted HTTP request. Args: request: the twisted HTTP request. + allow_empty_body (bool): if True, an empty body will be accepted and + turned into an empty dict. Raises: SynapseError if the request body couldn't be decoded as JSON or if it wasn't a JSON object. """ - content = parse_json_value_from_request(request) + content = parse_json_value_from_request( + request, allow_empty_body=allow_empty_body, + ) + + if allow_empty_body and content is None: + return {} if type(content) != dict: message = "Content must be a JSON object." @@ -192,6 +205,16 @@ def parse_json_object_from_request(request): return content +def assert_params_in_request(body, required): + absent = [] + for k in required: + if k not in body: + absent.append(k) + + if len(absent) > 0: + raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + + class RestServlet(object): """ A Synapse REST Servlet. 
diff --git a/synapse/http/site.py b/synapse/http/site.py index 4b09d7ee66..c8b46e1af2 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -20,7 +20,7 @@ import logging import re import time -ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$') +ACCESS_TOKEN_RE = re.compile(br'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$') class SynapseRequest(Request): @@ -43,12 +43,12 @@ class SynapseRequest(Request): def get_redacted_uri(self): return ACCESS_TOKEN_RE.sub( - r'\1<redacted>\3', + br'\1<redacted>\3', self.uri ) def get_user_agent(self): - return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1] + return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1] def started_processing(self): self.site.access_logger.info( @@ -66,14 +66,15 @@ class SynapseRequest(Request): context = LoggingContext.current_context() ru_utime, ru_stime = context.get_resource_usage() db_txn_count = context.db_txn_count - db_txn_duration = context.db_txn_duration - except: + db_txn_duration_ms = context.db_txn_duration_ms + db_sched_duration_ms = context.db_sched_duration_ms + except Exception: ru_utime, ru_stime = (0, 0) - db_txn_count, db_txn_duration = (0, 0) + db_txn_count, db_txn_duration_ms = (0, 0) self.site.access_logger.info( "%s - %s - {%s}" - " Processed request: %dms (%dms, %dms) (%dms/%d)" + " Processed request: %dms (%dms, %dms) (%dms/%dms/%d)" " %sB %s \"%s %s %s\" \"%s\"", self.getClientIP(), self.site.site_tag, @@ -81,7 +82,8 @@ class SynapseRequest(Request): int(time.time() * 1000) - self.start_time, int(ru_utime * 1000), int(ru_stime * 1000), - int(db_txn_duration * 1000), + db_sched_duration_ms, + db_txn_duration_ms, int(db_txn_count), self.sentLength, self.code, diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 2265e6e8d6..e3b831db67 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -17,12 +17,13 @@ import logging import functools import time import gc +import platform from 
twisted.internet import reactor from .metric import ( CounterMetric, CallbackMetric, DistributionMetric, CacheMetric, - MemoryUsageMetric, + MemoryUsageMetric, GaugeMetric, ) from .process_collector import register_process_collector @@ -30,6 +31,7 @@ from .process_collector import register_process_collector logger = logging.getLogger(__name__) +running_on_pypy = platform.python_implementation() == 'PyPy' all_metrics = [] all_collectors = [] @@ -57,15 +59,38 @@ class Metrics(object): return metric def register_counter(self, *args, **kwargs): + """ + Returns: + CounterMetric + """ return self._register(CounterMetric, *args, **kwargs) + def register_gauge(self, *args, **kwargs): + """ + Returns: + GaugeMetric + """ + return self._register(GaugeMetric, *args, **kwargs) + def register_callback(self, *args, **kwargs): + """ + Returns: + CallbackMetric + """ return self._register(CallbackMetric, *args, **kwargs) def register_distribution(self, *args, **kwargs): + """ + Returns: + DistributionMetric + """ return self._register(DistributionMetric, *args, **kwargs) def register_cache(self, *args, **kwargs): + """ + Returns: + CacheMetric + """ return self._register(CacheMetric, *args, **kwargs) @@ -126,6 +151,32 @@ reactor_metrics = get_metrics_for("python.twisted.reactor") tick_time = reactor_metrics.register_distribution("tick_time") pending_calls_metric = reactor_metrics.register_distribution("pending_calls") +synapse_metrics = get_metrics_for("synapse") + +# Used to track where various components have processed in the event stream, +# e.g. federation sending, appservice sending, etc. 
+event_processing_positions = synapse_metrics.register_gauge( + "event_processing_positions", labels=["name"], +) + +# Used to track the current max events stream position +event_persisted_position = synapse_metrics.register_gauge( + "event_persisted_position", +) + +# Used to track the received_ts of the last event processed by various +# components +event_processing_last_ts = synapse_metrics.register_gauge( + "event_processing_last_ts", labels=["name"], +) + +# Used to track the lag processing events. This is the time difference +# between the last processed event's received_ts and the time it was +# finished being processed. +event_processing_lag = synapse_metrics.register_gauge( + "event_processing_lag", labels=["name"], +) + def runUntilCurrentTimer(func): @@ -146,13 +197,21 @@ def runUntilCurrentTimer(func): num_pending += 1 num_pending += len(reactor.threadCallQueue) - start = time.time() * 1000 ret = func(*args, **kwargs) end = time.time() * 1000 + + # record the amount of wallclock time spent running pending calls. + # This is a proxy for the actual amount of time between reactor polls, + # since about 25% of time is actually spent running things triggered by + # I/O events, but that is harder to capture without rewriting half the + # reactor. tick_time.inc_by(end - start) pending_calls_metric.inc_by(num_pending) + if running_on_pypy: + return ret + # Check if we need to do a manual GC (since its been disabled), and do # one if necessary. 
threshold = gc.get_threshold() @@ -185,6 +244,7 @@ try: # We manually run the GC each reactor tick so that we can get some metrics # about time spent doing GC, - gc.disable() + if not running_on_pypy: + gc.disable() except AttributeError: pass diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py index e87b2b80a7..fbba94e633 100644 --- a/synapse/metrics/metric.py +++ b/synapse/metrics/metric.py @@ -15,18 +15,39 @@ from itertools import chain +import logging +import re +logger = logging.getLogger(__name__) -# TODO(paul): I can't believe Python doesn't have one of these -def map_concat(func, items): - # flatten a list-of-lists - return list(chain.from_iterable(map(func, items))) + +def flatten(items): + """Flatten a list of lists + + Args: + items: iterable[iterable[X]] + + Returns: + list[X]: flattened list + """ + return list(chain.from_iterable(items)) class BaseMetric(object): + """Base class for metrics which report a single value per label set + """ - def __init__(self, name, labels=[]): - self.name = name + def __init__(self, name, labels=[], alternative_names=[]): + """ + Args: + name (str): principal name for this metric + labels (list(str)): names of the labels which will be reported + for this metric + alternative_names (iterable(str)): list of alternative names for + this metric. This can be useful to provide a migration path + when renaming metrics. 
+ """ + self._names = [name] + list(alternative_names) self.labels = labels # OK not to clone as we never write it def dimension(self): @@ -36,8 +57,7 @@ class BaseMetric(object): return not len(self.labels) def _render_labelvalue(self, value): - # TODO: some kind of value escape - return '"%s"' % (value) + return '"%s"' % (_escape_label_value(value),) def _render_key(self, values): if self.is_scalar(): @@ -47,19 +67,60 @@ class BaseMetric(object): for k, v in zip(self.labels, values)]) ) + def _render_for_labels(self, label_values, value): + """Render this metric for a single set of labels + + Args: + label_values (list[str]): values for each of the labels + value: value of the metric at with these labels + + Returns: + iterable[str]: rendered metric + """ + rendered_labels = self._render_key(label_values) + return ( + "%s%s %.12g" % (name, rendered_labels, value) + for name in self._names + ) + + def render(self): + """Render this metric + + Each metric is rendered as: + + name{label1="val1",label2="val2"} value + + https://prometheus.io/docs/instrumenting/exposition_formats/#text-format-details + + Returns: + iterable[str]: rendered metrics + """ + raise NotImplementedError() + class CounterMetric(BaseMetric): """The simplest kind of metric; one that stores a monotonically-increasing - integer that counts events.""" + value that counts events or running totals. + + Example use cases for Counters: + - Number of requests processed + - Number of items that were inserted into a queue + - Total amount of data that a system has processed + Counters can only go up (and be reset when the process restarts). + """ def __init__(self, *args, **kwargs): super(CounterMetric, self).__init__(*args, **kwargs) + # dict[list[str]]: value for each set of label values. the keys are the + # label values, in the same order as the labels in self.labels. + # + # (if the metric is a scalar, the (single) key is the empty tuple). 
self.counts = {} # Scalar metrics are never empty if self.is_scalar(): - self.counts[()] = 0 + self.counts[()] = 0. def inc_by(self, incr, *values): if len(values) != self.dimension(): @@ -77,11 +138,41 @@ class CounterMetric(BaseMetric): def inc(self, *values): self.inc_by(1, *values) - def render_item(self, k): - return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])] + def render(self): + return flatten( + self._render_for_labels(k, self.counts[k]) + for k in sorted(self.counts.keys()) + ) + + +class GaugeMetric(BaseMetric): + """A metric that can go up or down + """ + + def __init__(self, *args, **kwargs): + super(GaugeMetric, self).__init__(*args, **kwargs) + + # dict[list[str]]: value for each set of label values. the keys are the + # label values, in the same order as the labels in self.labels. + # + # (if the metric is a scalar, the (single) key is the empty tuple). + self.guages = {} + + def set(self, v, *values): + if len(values) != self.dimension(): + raise ValueError( + "Expected as many values to inc() as labels (%d)" % (self.dimension()) + ) + + # TODO: should assert that the tag values are all strings + + self.guages[values] = v def render(self): - return map_concat(self.render_item, sorted(self.counts.keys())) + return flatten( + self._render_for_labels(k, self.guages[k]) + for k in sorted(self.guages.keys()) + ) class CallbackMetric(BaseMetric): @@ -95,13 +186,19 @@ class CallbackMetric(BaseMetric): self.callback = callback def render(self): - value = self.callback() + try: + value = self.callback() + except Exception: + logger.exception("Failed to render %s", self.name) + return ["# FAILED to render " + self.name] if self.is_scalar(): - return ["%s %.12g" % (self.name, value)] + return list(self._render_for_labels([], value)) - return ["%s%s %.12g" % (self.name, self._render_key(k), value[k]) - for k in sorted(value.keys())] + return flatten( + self._render_for_labels(k, value[k]) + for k in sorted(value.keys()) + ) class 
DistributionMetric(object): @@ -126,7 +223,9 @@ class DistributionMetric(object): class CacheMetric(object): - __slots__ = ("name", "cache_name", "hits", "misses", "size_callback") + __slots__ = ( + "name", "cache_name", "hits", "misses", "evicted_size", "size_callback", + ) def __init__(self, name, size_callback, cache_name): self.name = name @@ -134,6 +233,7 @@ class CacheMetric(object): self.hits = 0 self.misses = 0 + self.evicted_size = 0 self.size_callback = size_callback @@ -143,6 +243,9 @@ class CacheMetric(object): def inc_misses(self): self.misses += 1 + def inc_evictions(self, size=1): + self.evicted_size += size + def render(self): size = self.size_callback() hits = self.hits @@ -152,6 +255,9 @@ class CacheMetric(object): """%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits), """%s:total{name="%s"} %d""" % (self.name, self.cache_name, total), """%s:size{name="%s"} %d""" % (self.name, self.cache_name, size), + """%s:evicted_size{name="%s"} %d""" % ( + self.name, self.cache_name, self.evicted_size + ), ] @@ -193,3 +299,29 @@ class MemoryUsageMetric(object): "process_psutil_rss:total %d" % sum_rss, "process_psutil_rss:count %d" % len_rss, ] + + +def _escape_character(m): + """Replaces a single character with its escape sequence. 
+ + Args: + m (re.MatchObject): A match object whose first group is the single + character to replace + + Returns: + str + """ + c = m.group(1) + if c == "\\": + return "\\\\" + elif c == "\"": + return "\\\"" + elif c == "\n": + return "\\n" + return c + + +def _escape_label_value(value): + """Takes a label value and escapes quotes, newlines and backslashes + """ + return re.sub(r"([\n\"\\])", _escape_character, value) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py new file mode 100644 index 0000000000..097c844d31 --- /dev/null +++ b/synapse/module_api/__init__.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.internet import defer + +from synapse.types import UserID + + +class ModuleApi(object): + """A proxy object that gets passed to password auth providers so they + can register new users etc if necessary. + """ + def __init__(self, hs, auth_handler): + self.hs = hs + + self._store = hs.get_datastore() + self._auth = hs.get_auth() + self._auth_handler = auth_handler + + def get_user_by_req(self, req, allow_guest=False): + """Check the access_token provided for a request + + Args: + req (twisted.web.server.Request): Incoming HTTP request + allow_guest (bool): True if guest users should be allowed. 
If this + is False, and the access token is for a guest user, an + AuthError will be thrown + Returns: + twisted.internet.defer.Deferred[synapse.types.Requester]: + the requester for this request + Raises: + synapse.api.errors.AuthError: if no user by that token exists, + or the token is invalid. + """ + return self._auth.get_user_by_req(req, allow_guest) + + def get_qualified_user_id(self, username): + """Qualify a user id, if necessary + + Takes a user id provided by the user and adds the @ and :domain to + qualify it, if necessary + + Args: + username (str): provided user id + + Returns: + str: qualified @user:id + """ + if username.startswith('@'): + return username + return UserID(username, self.hs.hostname).to_string() + + def check_user_exists(self, user_id): + """Check if user exists. + + Args: + user_id (str): Complete @user:id + + Returns: + Deferred[str|None]: Canonical (case-corrected) user_id, or None + if the user is not registered. + """ + return self._auth_handler.check_user_exists(user_id) + + def register(self, localpart): + """Registers a new user with given localpart + + Returns: + Deferred: a 2-tuple of (user_id, access_token) + """ + reg = self.hs.get_handlers().registration_handler + return reg.register(localpart=localpart) + + @defer.inlineCallbacks + def invalidate_access_token(self, access_token): + """Invalidate an access token for a user + + Args: + access_token(str): access token + + Returns: + twisted.internet.defer.Deferred - resolves once the access token + has been removed. + + Raises: + synapse.api.errors.AuthError: the access token is invalid + """ + # see if the access token corresponds to a device + user_info = yield self._auth.get_user_by_access_token(access_token) + device_id = user_info.get("device_id") + user_id = user_info["user"].to_string() + if device_id: + # delete the device, which will also delete its access tokens + yield self.hs.get_device_handler().delete_device(user_id, device_id) + else: + # no associated device. 
Just delete the access token. + yield self._auth_handler.delete_access_token(access_token) + + def run_db_interaction(self, desc, func, *args, **kwargs): + """Run a function with a database connection + + Args: + desc (str): description for the transaction, for metrics etc + func (func): function to be run. Passed a database cursor object + as well as *args and **kwargs + *args: positional args to be passed to func + **kwargs: named args to be passed to func + + Returns: + Deferred[object]: result of func + """ + return self._store.runInteraction(desc, func, *args, **kwargs) diff --git a/synapse/notifier.py b/synapse/notifier.py index 8051a7a842..8355c7d621 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -14,13 +14,17 @@ # limitations under the License. from twisted.internet import defer + from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError +from synapse.handlers.presence import format_user_presence_state -from synapse.util import DeferredTimedOutError from synapse.util.logutils import log_function -from synapse.util.async import ObservableDeferred -from synapse.util.logcontext import PreserveLoggingContext, preserve_fn +from synapse.util.async import ( + ObservableDeferred, add_timeout_to_deferred, + DeferredTimeoutError, +) +from synapse.util.logcontext import PreserveLoggingContext, run_in_background from synapse.util.metrics import Measure from synapse.types import StreamToken from synapse.visibility import filter_events_for_client @@ -37,6 +41,10 @@ metrics = synapse.metrics.get_metrics_for(__name__) notified_events_counter = metrics.register_counter("notified_events") +users_woken_by_stream_counter = metrics.register_counter( + "users_woken_by_stream", labels=["stream"] +) + # TODO(paul): Should be shared somewhere def count(func, l): @@ -73,6 +81,13 @@ class _NotifierUserStream(object): self.user_id = user_id self.rooms = set(rooms) self.current_token = current_token + + # The last token for which 
we should wake up any streams that have a + # token that comes before it. This gets updated everytime we get poked. + # We start it at the current token since if we get any streams + # that have a token from before we have no idea whether they should be + # woken up or not, so lets just wake them up. + self.last_notified_token = current_token self.last_notified_ms = time_now_ms with PreserveLoggingContext(): @@ -89,9 +104,12 @@ class _NotifierUserStream(object): self.current_token = self.current_token.copy_and_advance( stream_key, stream_id ) + self.last_notified_token = self.current_token self.last_notified_ms = time_now_ms noify_deferred = self.notify_deferred + users_woken_by_stream_counter.inc(stream_key) + with PreserveLoggingContext(): self.notify_deferred = ObservableDeferred(defer.Deferred()) noify_deferred.callback(self.current_token) @@ -113,8 +131,14 @@ class _NotifierUserStream(object): def new_listener(self, token): """Returns a deferred that is resolved when there is a new token greater than the given token. + + Args: + token: The token from which we are streaming from, i.e. we shouldn't + notify for things that happened before this. """ - if self.current_token.is_after(token): + # Immediately wake up stream if something has already since happened + # since their last token. 
+ if self.last_notified_token.is_after(token): return _NotificationListener(defer.succeed(self.current_token)) else: return _NotificationListener(self.notify_deferred.observe()) @@ -123,6 +147,7 @@ class _NotifierUserStream(object): class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))): def __nonzero__(self): return bool(self.events) + __bool__ = __nonzero__ # python3 class Notifier(object): @@ -142,6 +167,8 @@ class Notifier(object): self.store = hs.get_datastore() self.pending_new_room_events = [] + self.replication_callbacks = [] + self.clock = hs.get_clock() self.appservice_handler = hs.get_application_service_handler() @@ -181,7 +208,12 @@ class Notifier(object): lambda: len(self.user_to_user_stream), ) - @preserve_fn + def add_replication_callback(self, cb): + """Add a callback that will be called when some new data is available. + Callback is not given any arguments. + """ + self.replication_callbacks.append(cb) + def on_new_room_event(self, event, room_stream_id, max_room_stream_id, extra_users=[]): """ Used by handlers to inform the notifier something has happened @@ -195,15 +227,13 @@ class Notifier(object): until all previous events have been persisted before notifying the client streams. """ - with PreserveLoggingContext(): - self.pending_new_room_events.append(( - room_stream_id, event, extra_users - )) - self._notify_pending_new_room_events(max_room_stream_id) + self.pending_new_room_events.append(( + room_stream_id, event, extra_users + )) + self._notify_pending_new_room_events(max_room_stream_id) - self.notify_replication() + self.notify_replication() - @preserve_fn def _notify_pending_new_room_events(self, max_room_stream_id): """Notify for the room events that were queued waiting for a previous event to be persisted. 
@@ -221,11 +251,10 @@ class Notifier(object): else: self._on_new_room_event(event, room_stream_id, extra_users) - @preserve_fn def _on_new_room_event(self, event, room_stream_id, extra_users=[]): """Notify any user streams that are interested in this room event""" # poke any interested application service. - self.appservice_handler.notify_interested_services(room_stream_id) + run_in_background(self._notify_app_services, room_stream_id) if self.federation_sender: self.federation_sender.notify_new_events(room_stream_id) @@ -239,7 +268,13 @@ class Notifier(object): rooms=[event.room_id], ) - @preserve_fn + @defer.inlineCallbacks + def _notify_app_services(self, room_stream_id): + try: + yield self.appservice_handler.notify_interested_services(room_stream_id) + except Exception: + logger.exception("Error notifying application services of event") + def on_new_event(self, stream_key, new_token, users=[], rooms=[]): """ Used to inform listeners that something has happend event wise. @@ -261,17 +296,15 @@ class Notifier(object): for user_stream in user_streams: try: user_stream.notify(stream_key, new_token, time_now_ms) - except: + except Exception: logger.exception("Failed to notify listener") self.notify_replication() - @preserve_fn def on_new_replication_data(self): """Used to inform replication listeners that something has happend without waking up any of the normal user event streams""" - with PreserveLoggingContext(): - self.notify_replication() + self.notify_replication() @defer.inlineCallbacks def wait_for_events(self, user_id, timeout, callback, room_ids=None, @@ -283,8 +316,7 @@ class Notifier(object): if user_stream is None: current_token = yield self.event_sources.get_current_token() if room_ids is None: - rooms = yield self.store.get_rooms_for_user(user_id) - room_ids = [room.room_id for room in rooms] + room_ids = yield self.store.get_rooms_for_user(user_id) user_stream = _NotifierUserStream( user_id=user_id, rooms=room_ids, @@ -294,40 +326,45 @@ class 
Notifier(object): self._register_with_keys(user_stream) result = None + prev_token = from_token if timeout: end_time = self.clock.time_msec() + timeout - prev_token = from_token while not result: try: - current_token = user_stream.current_token - - result = yield callback(prev_token, current_token) - if result: - break - now = self.clock.time_msec() if end_time <= now: break # Now we wait for the _NotifierUserStream to be told there # is a new token. - # We need to supply the token we supplied to callback so - # that we don't miss any current_token updates. - prev_token = current_token listener = user_stream.new_listener(prev_token) + add_timeout_to_deferred( + listener.deferred, + (end_time - now) / 1000., + ) with PreserveLoggingContext(): - yield self.clock.time_bound_deferred( - listener.deferred, - time_out=(end_time - now) / 1000. - ) - except DeferredTimedOutError: + yield listener.deferred + + current_token = user_stream.current_token + + result = yield callback(prev_token, current_token) + if result: + break + + # Update the prev_token to the current_token since nothing + # has happened between the old prev_token and the current_token + prev_token = current_token + except DeferredTimeoutError: break except defer.CancelledError: break - else: + + if result is None: + # This happened if there was no timeout or if the timeout had + # already expired. 
current_token = user_stream.current_token - result = yield callback(from_token, current_token) + result = yield callback(prev_token, current_token) defer.returnValue(result) @@ -388,6 +425,15 @@ class Notifier(object): new_events, is_peeking=is_peeking, ) + elif name == "presence": + now = self.clock.time_msec() + new_events[:] = [ + { + "type": "m.presence", + "content": format_user_presence_state(event, now), + } + for event in new_events + ] events.extend(new_events) end_token = end_token.copy_and_replace(keyname, new_key) @@ -420,8 +466,7 @@ class Notifier(object): @defer.inlineCallbacks def _get_room_ids(self, user, explicit_room_id): - joined_rooms = yield self.store.get_rooms_for_user(user.to_string()) - joined_room_ids = map(lambda r: r.room_id, joined_rooms) + joined_room_ids = yield self.store.get_rooms_for_user(user.to_string()) if explicit_room_id: if explicit_room_id in joined_room_ids: defer.returnValue(([explicit_room_id], True)) @@ -478,6 +523,15 @@ class Notifier(object): self.replication_deferred = ObservableDeferred(defer.Deferred()) deferred.callback(None) + # the callbacks may well outlast the current request, so we run + # them in the sentinel logcontext. + # + # (ideally it would be up to the callbacks to know if they were + # starting off background processes and drop the logcontext + # accordingly, but that requires more changes) + for cb in self.replication_callbacks: + cb() + @defer.inlineCallbacks def wait_for_replication(self, callback, timeout): """Wait for an event to happen. @@ -506,13 +560,14 @@ class Notifier(object): if end_time <= now: break + add_timeout_to_deferred( + listener.deferred.addTimeout, + (end_time - now) / 1000., + ) try: with PreserveLoggingContext(): - yield self.clock.time_bound_deferred( - listener.deferred, - time_out=(end_time - now) / 1000. 
- ) - except DeferredTimedOutError: + yield listener.deferred + except DeferredTimeoutError: break except defer.CancelledError: break diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index 3f75d3f921..8f619a7a1b 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -15,7 +15,7 @@ from twisted.internet import defer -from .bulk_push_rule_evaluator import evaluator_for_event +from .bulk_push_rule_evaluator import BulkPushRuleEvaluator from synapse.util.metrics import Measure @@ -24,11 +24,12 @@ import logging logger = logging.getLogger(__name__) -class ActionGenerator: +class ActionGenerator(object): def __init__(self, hs): self.hs = hs self.clock = hs.get_clock() self.store = hs.get_datastore() + self.bulk_evaluator = BulkPushRuleEvaluator(hs) # really we want to get all user ids and all profile tags too, # since we want the actions for each profile tag for every user and # also actions for a client with no profile tag for each user. @@ -38,16 +39,7 @@ class ActionGenerator: @defer.inlineCallbacks def handle_push_actions_for_event(self, event, context): - with Measure(self.clock, "evaluator_for_event"): - bulk_evaluator = yield evaluator_for_event( - event, self.hs, self.store, context - ) - with Measure(self.clock, "action_for_event_by_user"): - actions_by_user = yield bulk_evaluator.action_for_event_by_user( + yield self.bulk_evaluator.action_for_event_by_user( event, context ) - - context.push_actions = [ - (uid, actions) for uid, actions in actions_by_user.items() - ] diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 85effdfa46..7a18afe5f9 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -1,4 +1,5 @@ # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -238,6 +239,28 @@ BASE_APPEND_OVERRIDE_RULES = [ } ] }, + { + 'rule_id': 'global/override/.m.rule.roomnotif', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'content.body', + 'pattern': '@room', + '_id': '_roomnotif_content', + }, + { + 'kind': 'sender_notification_permission', + 'key': 'room', + '_id': '_roomnotif_pl', + }, + ], + 'actions': [ + 'notify', { + 'set_tweak': 'highlight', + 'value': True, + } + ] + } ] diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 78b095c903..7c680659b6 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015 OpenMarket Ltd +# Copyright 2017 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,88 +20,166 @@ from twisted.internet import defer from .push_rule_evaluator import PushRuleEvaluatorForEvent -from synapse.api.constants import EventTypes -from synapse.visibility import filter_events_for_clients_context +from synapse.event_auth import get_user_power_level +from synapse.api.constants import EventTypes, Membership +from synapse.metrics import get_metrics_for +from synapse.util.caches import metrics as cache_metrics +from synapse.util.caches.descriptors import cached +from synapse.util.async import Linearizer +from synapse.state import POWER_KEY + +from collections import namedtuple logger = logging.getLogger(__name__) -@defer.inlineCallbacks -def evaluator_for_event(event, hs, store, context): - rules_by_user = yield store.bulk_get_push_rules_for_room( - event, context - ) - - # if this event is an invite event, we may need to run rules for the user - # who's been invited, otherwise they won't get told they've been invited - if event.type == 'm.room.member' and event.content['membership'] == 'invite': - invited_user = event.state_key - if invited_user and 
hs.is_mine_id(invited_user): - has_pusher = yield store.user_has_pusher(invited_user) - if has_pusher: - rules_by_user = dict(rules_by_user) - rules_by_user[invited_user] = yield store.get_push_rules_for_user( - invited_user - ) +rules_by_room = {} - defer.returnValue(BulkPushRuleEvaluator( - event.room_id, rules_by_user, store - )) +push_metrics = get_metrics_for(__name__) +push_rules_invalidation_counter = push_metrics.register_counter( + "push_rules_invalidation_counter" +) +push_rules_state_size_counter = push_metrics.register_counter( + "push_rules_state_size_counter" +) -class BulkPushRuleEvaluator: - """ - Runs push rules for all users in a room. - This is faster than running PushRuleEvaluator for each user because it - fetches all the rules for all the users in one (batched) db query - rather than doing multiple queries per-user. It currently uses - the same logic to run the actual rules, but could be optimised further - (see https://matrix.org/jira/browse/SYN-562) +# Measures whether we use the fast path of using state deltas, or if we have to +# recalculate from scratch +push_rules_delta_state_cache_metric = cache_metrics.register_cache( + "cache", + size_callback=lambda: 0, # Meaningless size, as this isn't a cache that stores values + cache_name="push_rules_delta_state_cache_metric", +) + + +class BulkPushRuleEvaluator(object): + """Calculates the outcome of push rules for an event for all users in the + room at once. 
""" - def __init__(self, room_id, rules_by_user, store): - self.room_id = room_id - self.rules_by_user = rules_by_user - self.store = store + + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + self.room_push_rule_cache_metrics = cache_metrics.register_cache( + "cache", + size_callback=lambda: 0, # There's not good value for this + cache_name="room_push_rule_cache", + ) @defer.inlineCallbacks - def action_for_event_by_user(self, event, context): - actions_by_user = {} + def _get_rules_for_event(self, event, context): + """This gets the rules for all users in the room at the time of the event, + as well as the push rules for the invitee if the event is an invite. + + Returns: + dict of user_id -> push_rules + """ + room_id = event.room_id + rules_for_room = self._get_rules_for_room(room_id) + + rules_by_user = yield rules_for_room.get_rules(event, context) + + # if this event is an invite event, we may need to run rules for the user + # who's been invited, otherwise they won't get told they've been invited + if event.type == 'm.room.member' and event.content['membership'] == 'invite': + invited = event.state_key + if invited and self.hs.is_mine_id(invited): + has_pusher = yield self.store.user_has_pusher(invited) + if has_pusher: + rules_by_user = dict(rules_by_user) + rules_by_user[invited] = yield self.store.get_push_rules_for_user( + invited + ) - # None of these users can be peeking since this list of users comes - # from the set of users in the room, so we know for sure they're all - # actually in the room. 
- user_tuples = [ - (u, False) for u in self.rules_by_user.keys() - ] + defer.returnValue(rules_by_user) - filtered_by_user = yield filter_events_for_clients_context( - self.store, user_tuples, [event], {event.event_id: context} + @cached() + def _get_rules_for_room(self, room_id): + """Get the current RulesForRoom object for the given room id + + Returns: + RulesForRoom + """ + # It's important that RulesForRoom gets added to self._get_rules_for_room.cache + # before any lookup methods get called on it as otherwise there may be + # a race if invalidate_all gets called (which assumes its in the cache) + return RulesForRoom( + self.hs, room_id, self._get_rules_for_room.cache, + self.room_push_rule_cache_metrics, ) + @defer.inlineCallbacks + def _get_power_levels_and_sender_level(self, event, context): + pl_event_id = context.prev_state_ids.get(POWER_KEY) + if pl_event_id: + # fastpath: if there's a power level event, that's all we need, and + # not having a power level event is an extreme edge case + pl_event = yield self.store.get_event(pl_event_id) + auth_events = {POWER_KEY: pl_event} + else: + auth_events_ids = yield self.auth.compute_auth_events( + event, context.prev_state_ids, for_verification=False, + ) + auth_events = yield self.store.get_events(auth_events_ids) + auth_events = { + (e.type, e.state_key): e for e in auth_events.itervalues() + } + + sender_level = get_user_power_level(event.sender, auth_events) + + pl_event = auth_events.get(POWER_KEY) + + defer.returnValue((pl_event.content if pl_event else {}, sender_level)) + + @defer.inlineCallbacks + def action_for_event_by_user(self, event, context): + """Given an event and context, evaluate the push rules and insert the + results into the event_push_actions_staging table. 
+ + Returns: + Deferred + """ + rules_by_user = yield self._get_rules_for_event(event, context) + actions_by_user = {} + room_members = yield self.store.get_joined_users_from_context( event, context ) - evaluator = PushRuleEvaluatorForEvent(event, len(room_members)) + (power_levels, sender_power_level) = ( + yield self._get_power_levels_and_sender_level(event, context) + ) + + evaluator = PushRuleEvaluatorForEvent( + event, len(room_members), sender_power_level, power_levels, + ) condition_cache = {} - for uid, rules in self.rules_by_user.items(): - display_name = room_members.get(uid, {}).get("display_name", None) + for uid, rules in rules_by_user.iteritems(): + if event.sender == uid: + continue + + if not event.is_state(): + is_ignored = yield self.store.is_ignored_by(event.sender, uid) + if is_ignored: + continue + + display_name = None + profile_info = room_members.get(uid) + if profile_info: + display_name = profile_info.display_name + if not display_name: # Handle the case where we are pushing a membership event to # that user, as they might not be already joined. if event.type == EventTypes.Member and event.state_key == uid: display_name = event.content.get("displayname", None) - filtered = filtered_by_user[uid] - if len(filtered) == 0: - continue - - if filtered[0].sender == uid: - continue - for rule in rules: if 'enabled' in rule and not rule['enabled']: continue @@ -111,9 +190,16 @@ class BulkPushRuleEvaluator: if matches: actions = [x for x in rule['actions'] if x != 'dont_notify'] if actions and 'notify' in actions: + # Push rules say we should notify the user of this event actions_by_user[uid] = actions break - defer.returnValue(actions_by_user) + + # Mark in the DB staging area the push actions for users who should be + # notified for this event. 
(This will then get handled when we persist + # the event) + yield self.store.add_push_actions_to_staging( + event.event_id, actions_by_user, + ) def _condition_checker(evaluator, conditions, uid, display_name, cache): @@ -134,3 +220,264 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache): return False return True + + +class RulesForRoom(object): + """Caches push rules for users in a room. + + This efficiently handles users joining/leaving the room by not invalidating + the entire cache for the room. + """ + + def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metrics): + """ + Args: + hs (HomeServer) + room_id (str) + rules_for_room_cache(Cache): The cache object that caches these + RoomsForUser objects. + room_push_rule_cache_metrics (CacheMetric) + """ + self.room_id = room_id + self.is_mine_id = hs.is_mine_id + self.store = hs.get_datastore() + self.room_push_rule_cache_metrics = room_push_rule_cache_metrics + + self.linearizer = Linearizer(name="rules_for_room") + + self.member_map = {} # event_id -> (user_id, state) + self.rules_by_user = {} # user_id -> rules + + # The last state group we updated the caches for. If the state_group of + # a new event comes along, we know that we can just return the cached + # result. + # On invalidation of the rules themselves (if the user changes them), + # we invalidate everything and set state_group to `object()` + self.state_group = object() + + # A sequence number to keep track of when we're allowed to update the + # cache. We bump the sequence number when we invalidate the cache. If + # the sequence number changes while we're calculating stuff we should + # not update the cache with it. + self.sequence = 0 + + # A cache of user_ids that we *know* aren't interesting, e.g. user_ids + # owned by AS's, or remote users, etc. (I.e. users we will never need to + # calculate push for) + # These never need to be invalidated as we will never set up push for + # them. 
+ self.uninteresting_user_set = set() + + # We need to be clever on the invalidating caches callbacks, as + # otherwise the invalidation callback holds a reference to the object, + # potentially causing it to leak. + # To get around this we pass a function that on invalidations looks ups + # the RoomsForUser entry in the cache, rather than keeping a reference + # to self around in the callback. + self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id) + + @defer.inlineCallbacks + def get_rules(self, event, context): + """Given an event context return the rules for all users who are + currently in the room. + """ + state_group = context.state_group + + if state_group and self.state_group == state_group: + logger.debug("Using cached rules for %r", self.room_id) + self.room_push_rule_cache_metrics.inc_hits() + defer.returnValue(self.rules_by_user) + + with (yield self.linearizer.queue(())): + if state_group and self.state_group == state_group: + logger.debug("Using cached rules for %r", self.room_id) + self.room_push_rule_cache_metrics.inc_hits() + defer.returnValue(self.rules_by_user) + + self.room_push_rule_cache_metrics.inc_misses() + + ret_rules_by_user = {} + missing_member_event_ids = {} + if state_group and self.state_group == context.prev_group: + # If we have a simple delta then we can reuse most of the previous + # results. 
+ ret_rules_by_user = self.rules_by_user + current_state_ids = context.delta_ids + + push_rules_delta_state_cache_metric.inc_hits() + else: + current_state_ids = context.current_state_ids + push_rules_delta_state_cache_metric.inc_misses() + + push_rules_state_size_counter.inc_by(len(current_state_ids)) + + logger.debug( + "Looking for member changes in %r %r", state_group, current_state_ids + ) + + # Loop through to see which member events we've seen and have rules + # for and which we need to fetch + for key in current_state_ids: + typ, user_id = key + if typ != EventTypes.Member: + continue + + if user_id in self.uninteresting_user_set: + continue + + if not self.is_mine_id(user_id): + self.uninteresting_user_set.add(user_id) + continue + + if self.store.get_if_app_services_interested_in_user(user_id): + self.uninteresting_user_set.add(user_id) + continue + + event_id = current_state_ids[key] + + res = self.member_map.get(event_id, None) + if res: + user_id, state = res + if state == Membership.JOIN: + rules = self.rules_by_user.get(user_id, None) + if rules: + ret_rules_by_user[user_id] = rules + continue + + # If a user has left a room we remove their push rule. If they + # joined then we readd it later in _update_rules_with_member_event_ids + ret_rules_by_user.pop(user_id, None) + missing_member_event_ids[user_id] = event_id + + if missing_member_event_ids: + # If we have some memebr events we haven't seen, look them up + # and fetch push rules for them if appropriate. 
+ logger.debug("Found new member events %r", missing_member_event_ids) + yield self._update_rules_with_member_event_ids( + ret_rules_by_user, missing_member_event_ids, state_group, event + ) + else: + # The push rules didn't change but lets update the cache anyway + self.update_cache( + self.sequence, + members={}, # There were no membership changes + rules_by_user=ret_rules_by_user, + state_group=state_group + ) + + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "Returning push rules for %r %r", + self.room_id, ret_rules_by_user.keys(), + ) + defer.returnValue(ret_rules_by_user) + + @defer.inlineCallbacks + def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_ids, + state_group, event): + """Update the partially filled rules_by_user dict by fetching rules for + any newly joined users in the `member_event_ids` list. + + Args: + ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets + updated with any new rules. + member_event_ids (list): List of event ids for membership events that + have happened since the last time we filled rules_by_user + state_group: The state group we are currently computing push rules + for. Used when updating the cache. + """ + sequence = self.sequence + + rows = yield self.store._simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=member_event_ids.values(), + retcols=('user_id', 'membership', 'event_id'), + keyvalues={}, + batch_size=500, + desc="_get_rules_for_member_event_ids", + ) + + members = { + row["event_id"]: (row["user_id"], row["membership"]) + for row in rows + } + + # If the event is a join event then it will be in current state evnts + # map but not in the DB, so we have to explicitly insert it. 
+ if event.type == EventTypes.Member: + for event_id in member_event_ids.itervalues(): + if event_id == event.event_id: + members[event_id] = (event.state_key, event.membership) + + if logger.isEnabledFor(logging.DEBUG): + logger.debug("Found members %r: %r", self.room_id, members.values()) + + interested_in_user_ids = set( + user_id for user_id, membership in members.itervalues() + if membership == Membership.JOIN + ) + + logger.debug("Joined: %r", interested_in_user_ids) + + if_users_with_pushers = yield self.store.get_if_users_have_pushers( + interested_in_user_ids, + on_invalidate=self.invalidate_all_cb, + ) + + user_ids = set( + uid for uid, have_pusher in if_users_with_pushers.iteritems() if have_pusher + ) + + logger.debug("With pushers: %r", user_ids) + + users_with_receipts = yield self.store.get_users_with_read_receipts_in_room( + self.room_id, on_invalidate=self.invalidate_all_cb, + ) + + logger.debug("With receipts: %r", users_with_receipts) + + # any users with pushers must be ours: they have pushers + for uid in users_with_receipts: + if uid in interested_in_user_ids: + user_ids.add(uid) + + rules_by_user = yield self.store.bulk_get_push_rules( + user_ids, on_invalidate=self.invalidate_all_cb, + ) + + ret_rules_by_user.update( + item for item in rules_by_user.iteritems() if item[0] is not None + ) + + self.update_cache(sequence, members, ret_rules_by_user, state_group) + + def invalidate_all(self): + # Note: Don't hand this function directly to an invalidation callback + # as it keeps a reference to self and will stop this instance from being + # GC'd if it gets dropped from the rules_to_user cache. 
Instead use + # `self.invalidate_all_cb` + logger.debug("Invalidating RulesForRoom for %r", self.room_id) + self.sequence += 1 + self.state_group = object() + self.member_map = {} + self.rules_by_user = {} + push_rules_invalidation_counter.inc() + + def update_cache(self, sequence, members, rules_by_user, state_group): + if sequence == self.sequence: + self.member_map.update(members) + self.rules_by_user = rules_by_user + self.state_group = state_group + + +class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))): + # We rely on _CacheContext implementing __eq__ and __hash__ sensibly, + # which namedtuple does for us (i.e. two _CacheContext are the same if + # their caches and keys match). This is important in particular to + # dedupe when we add callbacks to lru cache nodes, otherwise the number + # of callbacks would grow. + def __call__(self): + rules = self.cache.get(self.room_id, None, update_metrics=False) + if rules: + rules.invalidate_all() diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 2eb325c7c7..ba7286cb72 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -21,7 +21,6 @@ import logging from synapse.util.metrics import Measure from synapse.util.logcontext import LoggingContext -from mailer import Mailer logger = logging.getLogger(__name__) @@ -56,8 +55,10 @@ class EmailPusher(object): This shares quite a bit of code with httpusher: it would be good to factor out the common parts """ - def __init__(self, hs, pusherdict): + def __init__(self, hs, pusherdict, mailer): self.hs = hs + self.mailer = mailer + self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() self.pusher_id = pusherdict['id'] @@ -73,23 +74,16 @@ class EmailPusher(object): self.processing = False - if self.hs.config.email_enable_notifs: - if 'data' in pusherdict and 'brand' in pusherdict['data']: - app_name = pusherdict['data']['brand'] - else: - app_name = self.hs.config.email_app_name - - self.mailer = 
Mailer(self.hs, app_name) - else: - self.mailer = None - @defer.inlineCallbacks def on_started(self): if self.mailer is not None: - self.throttle_params = yield self.store.get_throttle_params_by_room( - self.pusher_id - ) - yield self._process() + try: + self.throttle_params = yield self.store.get_throttle_params_by_room( + self.pusher_id + ) + yield self._process() + except Exception: + logger.exception("Error starting email pusher") def on_stop(self): if self.timed_call: @@ -130,7 +124,7 @@ class EmailPusher(object): starting_max_ordering = self.max_stream_ordering try: yield self._unsafe_process() - except: + except Exception: logger.exception("Exception processing notifs") if self.max_stream_ordering == starting_max_ordering: break @@ -218,7 +212,8 @@ class EmailPusher(object): ) def seconds_until(self, ts_msec): - return (ts_msec - self.clock.time_msec()) / 1000 + secs = (ts_msec - self.clock.time_msec()) / 1000 + return max(secs, 0) def get_room_throttle_ms(self, room_id): if room_id in self.throttle_params: diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index c0f8176e3d..b077e1a446 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,21 +13,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from synapse.push import PusherConfigException +import logging from twisted.internet import defer, reactor from twisted.internet.error import AlreadyCalled, AlreadyCancelled -import logging -import push_rule_evaluator -import push_tools - +from . import push_rule_evaluator +from . 
import push_tools +import synapse +from synapse.push import PusherConfigException from synapse.util.logcontext import LoggingContext from synapse.util.metrics import Measure logger = logging.getLogger(__name__) +metrics = synapse.metrics.get_metrics_for(__name__) + +http_push_processed_counter = metrics.register_counter( + "http_pushes_processed", +) + +http_push_failed_counter = metrics.register_counter( + "http_pushes_failed", +) + class HttpPusher(object): INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes @@ -84,7 +94,10 @@ class HttpPusher(object): @defer.inlineCallbacks def on_started(self): - yield self._process() + try: + yield self._process() + except Exception: + logger.exception("Error starting http pusher") @defer.inlineCallbacks def on_new_notifications(self, min_stream_ordering, max_stream_ordering): @@ -131,7 +144,7 @@ class HttpPusher(object): starting_max_ordering = self.max_stream_ordering try: yield self._unsafe_process() - except: + except Exception: logger.exception("Exception processing notifs") if self.max_stream_ordering == starting_max_ordering: break @@ -151,9 +164,16 @@ class HttpPusher(object): self.user_id, self.last_stream_ordering, self.max_stream_ordering ) + logger.info( + "Processing %i unprocessed push actions for %s starting at " + "stream_ordering %s", + len(unprocessed), self.name, self.last_stream_ordering, + ) + for push_action in unprocessed: processed = yield self._process_one(push_action) if processed: + http_push_processed_counter.inc() self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action['stream_ordering'] yield self.store.update_pusher_last_stream_ordering_and_success( @@ -168,6 +188,7 @@ class HttpPusher(object): self.failing_since ) else: + http_push_failed_counter.inc() if not self.failing_since: self.failing_since = self.clock.time_msec() yield self.store.update_pusher_failing_since( @@ -244,6 +265,26 @@ class HttpPusher(object): @defer.inlineCallbacks def 
_build_notification_dict(self, event, tweaks, badge): + if self.data.get('format') == 'event_id_only': + d = { + 'notification': { + 'event_id': event.event_id, + 'room_id': event.room_id, + 'counts': { + 'unread': badge, + }, + 'devices': [ + { + 'app_id': self.app_id, + 'pushkey': self.pushkey, + 'pushkey_ts': long(self.pushkey_ts / 1000), + 'data': self.data_minus_url, + } + ] + } + } + defer.returnValue(d) + ctx = yield push_tools.get_context_for_event( self.store, self.state_handler, event, self.user_id ) @@ -275,7 +316,7 @@ class HttpPusher(object): if event.type == 'm.room.member': d['notification']['membership'] = event.content['membership'] d['notification']['user_is_target'] = event.state_key == self.user_id - if 'content' in event: + if self.hs.config.push_include_content and 'content' in event: d['notification']['content'] = event.content # We no longer send aliases separately, instead, we send the human @@ -294,8 +335,11 @@ class HttpPusher(object): defer.returnValue([]) try: resp = yield self.http_client.post_json_get_json(self.url, notification_dict) - except: - logger.warn("Failed to push %s ", self.url) + except Exception: + logger.warn( + "Failed to push event %s to %s", + event.event_id, self.name, exc_info=True, + ) defer.returnValue(False) rejected = [] if 'rejected' in resp: @@ -304,7 +348,7 @@ class HttpPusher(object): @defer.inlineCallbacks def _send_badge(self, badge): - logger.info("Sending updated badge count %d to %r", badge, self.user_id) + logger.info("Sending updated badge count %d to %s", badge, self.name) d = { 'notification': { 'id': '', @@ -325,8 +369,11 @@ class HttpPusher(object): } try: resp = yield self.http_client.post_json_get_json(self.url, d) - except: - logger.exception("Failed to push %s ", self.url) + except Exception: + logger.warn( + "Failed to send badge count to %s", + self.name, exc_info=True, + ) defer.returnValue(False) rejected = [] if 'rejected' in resp: diff --git a/synapse/push/mailer.py 
b/synapse/push/mailer.py index 62d794f22b..b5cd9b426a 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -78,23 +78,17 @@ ALLOWED_ATTRS = { class Mailer(object): - def __init__(self, hs, app_name): + def __init__(self, hs, app_name, notif_template_html, notif_template_text): self.hs = hs + self.notif_template_html = notif_template_html + self.notif_template_text = notif_template_text + self.store = self.hs.get_datastore() self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() - loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir) self.app_name = app_name + logger.info("Created Mailer for app_name %s" % app_name) - env = jinja2.Environment(loader=loader) - env.filters["format_ts"] = format_ts_filter - env.filters["mxc_to_http"] = self.mxc_to_http_filter - self.notif_template_html = env.get_template( - self.hs.config.email_notif_template_html - ) - self.notif_template_text = env.get_template( - self.hs.config.email_notif_template_text - ) @defer.inlineCallbacks def send_notification_mail(self, app_id, user_id, email_address, @@ -139,7 +133,7 @@ class Mailer(object): @defer.inlineCallbacks def _fetch_room_state(room_id): - room_state = yield self.state_handler.get_current_state_ids(room_id) + room_state = yield self.store.get_current_state_ids(room_id) state_by_room[room_id] = room_state # Run at most 3 of these at once: sync does 10 at a time but email @@ -200,7 +194,11 @@ class Mailer(object): yield sendmail( self.hs.config.email_smtp_host, raw_from, raw_to, multipart_msg.as_string(), - port=self.hs.config.email_smtp_port + port=self.hs.config.email_smtp_port, + requireAuthentication=self.hs.config.email_smtp_user is not None, + username=self.hs.config.email_smtp_user, + password=self.hs.config.email_smtp_pass, + requireTransportSecurity=self.hs.config.require_transport_security ) @defer.inlineCallbacks @@ -477,28 +475,6 @@ class Mailer(object): urllib.urlencode(params), ) - def 
mxc_to_http_filter(self, value, width, height, resize_method="crop"): - if value[0:6] != "mxc://": - return "" - - serverAndMediaId = value[6:] - fragment = None - if '#' in serverAndMediaId: - (serverAndMediaId, fragment) = serverAndMediaId.split('#', 1) - fragment = "#" + fragment - - params = { - "width": width, - "height": height, - "method": resize_method, - } - return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( - self.hs.config.public_baseurl, - serverAndMediaId, - urllib.urlencode(params), - fragment or "", - ) - def safe_markup(raw_html): return jinja2.Markup(bleach.linkify(bleach.clean( @@ -539,3 +515,52 @@ def string_ordinal_total(s): def format_ts_filter(value, format): return time.strftime(format, time.localtime(value / 1000)) + + +def load_jinja2_templates(config): + """Load the jinja2 email templates from disk + + Returns: + (notif_template_html, notif_template_text) + """ + logger.info("loading jinja2") + + loader = jinja2.FileSystemLoader(config.email_template_dir) + env = jinja2.Environment(loader=loader) + env.filters["format_ts"] = format_ts_filter + env.filters["mxc_to_http"] = _create_mxc_to_http_filter(config) + + notif_template_html = env.get_template( + config.email_notif_template_html + ) + notif_template_text = env.get_template( + config.email_notif_template_text + ) + + return notif_template_html, notif_template_text + + +def _create_mxc_to_http_filter(config): + def mxc_to_http_filter(value, width, height, resize_method="crop"): + if value[0:6] != "mxc://": + return "" + + serverAndMediaId = value[6:] + fragment = None + if '#' in serverAndMediaId: + (serverAndMediaId, fragment) = serverAndMediaId.split('#', 1) + fragment = "#" + fragment + + params = { + "width": width, + "height": height, + "method": resize_method, + } + return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( + config.public_baseurl, + serverAndMediaId, + urllib.urlencode(params), + fragment or "", + ) + + return mxc_to_http_filter diff --git 
a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 4db76f18bd..3601f2d365 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,6 +18,7 @@ import logging import re from synapse.types import UserID +from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -28,6 +30,21 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") def _room_member_count(ev, condition, room_member_count): + return _test_ineq_condition(condition, room_member_count) + + +def _sender_notification_permission(ev, condition, sender_power_level, power_levels): + notif_level_key = condition.get('key') + if notif_level_key is None: + return False + + notif_levels = power_levels.get('notifications', {}) + room_notif_level = notif_levels.get(notif_level_key, 50) + + return sender_power_level >= room_notif_level + + +def _test_ineq_condition(condition, number): if 'is' not in condition: return False m = INEQUALITY_EXPR.match(condition['is']) @@ -40,15 +57,15 @@ def _room_member_count(ev, condition, room_member_count): rhs = int(rhs) if ineq == '' or ineq == '==': - return room_member_count == rhs + return number == rhs elif ineq == '<': - return room_member_count < rhs + return number < rhs elif ineq == '>': - return room_member_count > rhs + return number > rhs elif ineq == '>=': - return room_member_count >= rhs + return number >= rhs elif ineq == '<=': - return room_member_count <= rhs + return number <= rhs else: return False @@ -64,9 +81,11 @@ def tweaks_for_actions(actions): class PushRuleEvaluatorForEvent(object): - def __init__(self, event, room_member_count): + def __init__(self, 
event, room_member_count, sender_power_level, power_levels): self._event = event self._room_member_count = room_member_count + self._sender_power_level = sender_power_level + self._power_levels = power_levels # Maps strings of e.g. 'content.body' -> event["content"]["body"] self._value_cache = _flatten_dict(event) @@ -80,6 +99,10 @@ class PushRuleEvaluatorForEvent(object): return _room_member_count( self._event, condition, self._room_member_count ) + elif condition['kind'] == 'sender_notification_permission': + return _sender_notification_permission( + self._event, condition, self._sender_power_level, self._power_levels, + ) else: return True @@ -125,6 +148,11 @@ class PushRuleEvaluatorForEvent(object): return self._value_cache.get(dotted_key, None) +# Caches (glob, word_boundary) -> regex for push. See _glob_matches +regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR) +register_cache("regex_push_cache", regex_cache) + + def _glob_matches(glob, value, word_boundary=False): """Tests if value matches glob. @@ -137,47 +165,78 @@ def _glob_matches(glob, value, word_boundary=False): Returns: bool """ + try: - if IS_GLOB.search(glob): - r = re.escape(glob) - - r = r.replace(r'\*', '.*?') - r = r.replace(r'\?', '.') - - # handle [abc], [a-z] and [!a-z] style ranges. 
- r = GLOB_REGEX.sub( - lambda x: ( - '[%s%s]' % ( - x.group(1) and '^' or '', - x.group(2).replace(r'\\\-', '-') - ) - ), - r, - ) - if word_boundary: - r = r"\b%s\b" % (r,) - r = _compile_regex(r) - - return r.search(value) - else: - r = r + "$" - r = _compile_regex(r) - - return r.match(value) - elif word_boundary: - r = re.escape(glob) - r = r"\b%s\b" % (r,) - r = _compile_regex(r) - - return r.search(value) - else: - return value.lower() == glob.lower() + r = regex_cache.get((glob, word_boundary), None) + if not r: + r = _glob_to_re(glob, word_boundary) + regex_cache[(glob, word_boundary)] = r + return r.search(value) except re.error: logger.warn("Failed to parse glob to regex: %r", glob) return False -def _flatten_dict(d, prefix=[], result={}): +def _glob_to_re(glob, word_boundary): + """Generates regex for a given glob. + + Args: + glob (string) + word_boundary (bool): Whether to match against word boundaries or entire + string. Defaults to False. + + Returns: + regex object + """ + if IS_GLOB.search(glob): + r = re.escape(glob) + + r = r.replace(r'\*', '.*?') + r = r.replace(r'\?', '.') + + # handle [abc], [a-z] and [!a-z] style ranges. + r = GLOB_REGEX.sub( + lambda x: ( + '[%s%s]' % ( + x.group(1) and '^' or '', + x.group(2).replace(r'\\\-', '-') + ) + ), + r, + ) + if word_boundary: + r = _re_word_boundary(r) + + return re.compile(r, flags=re.IGNORECASE) + else: + r = "^" + r + "$" + + return re.compile(r, flags=re.IGNORECASE) + elif word_boundary: + r = re.escape(glob) + r = _re_word_boundary(r) + + return re.compile(r, flags=re.IGNORECASE) + else: + r = "^" + re.escape(glob) + "$" + return re.compile(r, flags=re.IGNORECASE) + + +def _re_word_boundary(r): + """ + Adds word boundary characters to the start and end of an + expression to require that the match occur as a whole word, + but do so respecting the fact that strings starting or ending + with non-word characters will change word boundaries. + """ + # we can't use \b as it chokes on unicode. 
however \W seems to be okay + # as shorthand for [^0-9A-Za-z_]. + return r"(^|\W)%s(\W|$)" % (r,) + + +def _flatten_dict(d, prefix=[], result=None): + if result is None: + result = {} for key, value in d.items(): if isinstance(value, basestring): result[".".join(prefix + [key])] = value.lower() @@ -185,16 +244,3 @@ def _flatten_dict(d, prefix=[], result={}): _flatten_dict(value, prefix=(prefix + [key]), result=result) return result - - -regex_cache = LruCache(5000) - - -def _compile_regex(regex_str): - r = regex_cache.get(regex_str, None) - if r: - return r - - r = re.compile(regex_str, flags=re.IGNORECASE) - regex_cache[regex_str] = r - return r diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index a27476bbad..6835f54e97 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -17,15 +17,12 @@ from twisted.internet import defer from synapse.push.presentable_names import ( calculate_room_name, name_from_member_event ) -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred @defer.inlineCallbacks def get_badge_count(store, user_id): - invites, joins = yield preserve_context_over_deferred(defer.gatherResults([ - preserve_fn(store.get_invited_rooms_for_user)(user_id), - preserve_fn(store.get_rooms_for_user)(user_id), - ], consumeErrors=True)) + invites = yield store.get_invited_rooms_for_user(user_id) + joins = yield store.get_rooms_for_user(user_id) my_receipts_by_room = yield store.get_receipts_for_user( user_id, "m.read", @@ -33,13 +30,13 @@ def get_badge_count(store, user_id): badge = len(invites) - for r in joins: - if r.room_id in my_receipts_by_room: - last_unread_event_id = my_receipts_by_room[r.room_id] + for room_id in joins: + if room_id in my_receipts_by_room: + last_unread_event_id = my_receipts_by_room[room_id] notifs = yield ( store.get_unread_event_push_actions_by_room_for_user( - r.room_id, user_id, last_unread_event_id + room_id, user_id, last_unread_event_id ) ) # return one badge 
count per conversation, as count per diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index de9c33b936..5aa6667e91 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from httppusher import HttpPusher +from .httppusher import HttpPusher import logging logger = logging.getLogger(__name__) @@ -26,22 +26,54 @@ logger = logging.getLogger(__name__) # process works fine) try: from synapse.push.emailpusher import EmailPusher -except: + from synapse.push.mailer import Mailer, load_jinja2_templates +except Exception: pass -def create_pusher(hs, pusherdict): - logger.info("trying to create_pusher for %r", pusherdict) +class PusherFactory(object): + def __init__(self, hs): + self.hs = hs - PUSHER_TYPES = { - "http": HttpPusher, - } + self.pusher_types = { + "http": HttpPusher, + } - logger.info("email enable notifs: %r", hs.config.email_enable_notifs) - if hs.config.email_enable_notifs: - PUSHER_TYPES["email"] = EmailPusher - logger.info("defined email pusher type") + logger.info("email enable notifs: %r", hs.config.email_enable_notifs) + if hs.config.email_enable_notifs: + self.mailers = {} # app_name -> Mailer - if pusherdict['kind'] in PUSHER_TYPES: - logger.info("found pusher") - return PUSHER_TYPES[pusherdict['kind']](hs, pusherdict) + templates = load_jinja2_templates(hs.config) + self.notif_template_html, self.notif_template_text = templates + + self.pusher_types["email"] = self._create_email_pusher + + logger.info("defined email pusher type") + + def create_pusher(self, pusherdict): + logger.info("trying to create_pusher for %r", pusherdict) + + if pusherdict['kind'] in self.pusher_types: + logger.info("found pusher") + return self.pusher_types[pusherdict['kind']](self.hs, pusherdict) + + def _create_email_pusher(self, _hs, pusherdict): + app_name = self._app_name_from_pusherdict(pusherdict) + mailer = 
self.mailers.get(app_name) + if not mailer: + mailer = Mailer( + hs=self.hs, + app_name=app_name, + notif_template_html=self.notif_template_html, + notif_template_text=self.notif_template_text, + ) + self.mailers[app_name] = mailer + return EmailPusher(self.hs, pusherdict, mailer) + + def _app_name_from_pusherdict(self, pusherdict): + if 'data' in pusherdict and 'brand' in pusherdict['data']: + app_name = pusherdict['data']['brand'] + else: + app_name = self.hs.config.email_app_name + + return app_name diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 3837be523d..750d11ca38 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -14,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer -import pusher -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.push.pusher import PusherFactory from synapse.util.async import run_on_reactor - -import logging +from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) @@ -28,6 +28,7 @@ logger = logging.getLogger(__name__) class PusherPool: def __init__(self, _hs): self.hs = _hs + self.pusher_factory = PusherFactory(_hs) self.start_pushers = _hs.config.start_pushers self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() @@ -48,7 +49,7 @@ class PusherPool: # will then get pulled out of the database, # recreated, added and started: this means we have only one # code path adding pushers. 
- pusher.create_pusher(self.hs, { + self.pusher_factory.create_pusher({ "id": None, "user_name": user_id, "kind": kind, @@ -102,19 +103,25 @@ class PusherPool: yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) @defer.inlineCallbacks - def remove_pushers_by_user(self, user_id, except_access_token_id=None): - all = yield self.store.get_all_pushers() - logger.info( - "Removing all pushers for user %s except access tokens id %r", - user_id, except_access_token_id - ) - for p in all: - if p['user_name'] == user_id and p['access_token'] != except_access_token_id: + def remove_pushers_by_access_token(self, user_id, access_tokens): + """Remove the pushers for a given user corresponding to a set of + access_tokens. + + Args: + user_id (str): user to remove pushers for + access_tokens (Iterable[int]): access token *ids* to remove pushers + for + """ + tokens = set(access_tokens) + for p in (yield self.store.get_pushers_by_user_id(user_id)): + if p['access_token'] in tokens: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", p['app_id'], p['pushkey'], p['user_name'] ) - yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) + yield self.remove_pusher( + p['app_id'], p['pushkey'], p['user_name'], + ) @defer.inlineCallbacks def on_new_notifications(self, min_stream_id, max_stream_id): @@ -130,13 +137,16 @@ class PusherPool: if u in self.pushers: for p in self.pushers[u].values(): deferreds.append( - preserve_fn(p.on_new_notifications)( - min_stream_id, max_stream_id + run_in_background( + p.on_new_notifications, + min_stream_id, max_stream_id, ) ) - yield preserve_context_over_deferred(defer.gatherResults(deferreds)) - except: + yield make_deferred_yieldable( + defer.gatherResults(deferreds, consumeErrors=True), + ) + except Exception: logger.exception("Exception in pusher on_new_notifications") @defer.inlineCallbacks @@ -157,11 +167,16 @@ class PusherPool: if u in self.pushers: for p in self.pushers[u].values(): deferreds.append( 
- preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id) + run_in_background( + p.on_new_receipts, + min_stream_id, max_stream_id, + ) ) - yield preserve_context_over_deferred(defer.gatherResults(deferreds)) - except: + yield make_deferred_yieldable( + defer.gatherResults(deferreds, consumeErrors=True), + ) + except Exception: logger.exception("Exception in pusher on_new_receipts") @defer.inlineCallbacks @@ -186,8 +201,8 @@ class PusherPool: logger.info("Starting %d pushers", len(pushers)) for pusherdict in pushers: try: - p = pusher.create_pusher(self.hs, pusherdict) - except: + p = self.pusher_factory.create_pusher(pusherdict) + except Exception: logger.exception("Couldn't start a pusher: caught Exception") continue if p: @@ -200,7 +215,7 @@ class PusherPool: if appid_pushkey in byuser: byuser[appid_pushkey].on_stop() byuser[appid_pushkey] = p - preserve_fn(p.on_started)() + run_in_background(p.on_started) logger.info("Started pushers") diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 7817b0cd91..216db4d164 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -1,4 +1,6 @@ # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,26 +19,43 @@ from distutils.version import LooseVersion logger = logging.getLogger(__name__) +# this dict maps from python package name to a list of modules we expect it to +# provide. +# +# the key is a "requirement specifier", as used as a parameter to `pip +# install`[1], or an `install_requires` argument to `setuptools.setup` [2]. +# +# the value is a sequence of strings; each entry should be the name of the +# python module, optionally followed by a version assertion which can be either +# ">=<ver>" or "==<ver>". 
+# +# [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers. +# [2] https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-dependencies REQUIREMENTS = { + "jsonschema>=2.5.1": ["jsonschema>=2.5.1"], "frozendict>=0.4": ["frozendict"], "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"], - "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"], + "canonicaljson>=1.1.3": ["canonicaljson>=1.1.3"], "signedjson>=1.0.0": ["signedjson>=1.0.0"], - "pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"], + "pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"], "service_identity>=1.0.0": ["service_identity>=1.0.0"], "Twisted>=16.0.0": ["twisted>=16.0.0"], - "pyopenssl>=0.14": ["OpenSSL>=0.14"], + + # We use crypto.get_elliptic_curve which is only supported in >=0.15 + "pyopenssl>=0.15": ["OpenSSL>=0.15"], + "pyyaml": ["yaml"], "pyasn1": ["pyasn1"], "daemonize": ["daemonize"], - "py-bcrypt": ["bcrypt"], + "bcrypt": ["bcrypt>=3.1.0"], "pillow": ["PIL"], "pydenticon": ["pydenticon"], - "ujson": ["ujson"], "blist": ["blist"], - "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"], + "pysaml2>=3.0.0": ["saml2>=3.0.0"], "pymacaroons-pynacl": ["pymacaroons"], "msgpack-python>=0.3.0": ["msgpack"], + "phonenumbers>=8.2.0": ["phonenumbers"], + "six": ["six"], } CONDITIONAL_REQUIREMENTS = { "web_client": { @@ -55,6 +74,9 @@ CONDITIONAL_REQUIREMENTS = { "psutil": { "psutil>=2.0.0": ["psutil>=2.0.0"], }, + "affinity": { + "affinity": ["affinity"], + }, } diff --git a/synapse/replication/expire_cache.py b/synapse/replication/expire_cache.py deleted file mode 100644 index c05a50d7a6..0000000000 --- a/synapse/replication/expire_cache.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.http.server import respond_with_json_bytes, request_handler -from synapse.http.servlet import parse_json_object_from_request - -from twisted.web.resource import Resource -from twisted.web.server import NOT_DONE_YET - - -class ExpireCacheResource(Resource): - """ - HTTP endpoint for expiring storage caches. - - POST /_synapse/replication/expire_cache HTTP/1.1 - Content-Type: application/json - - { - "invalidate": [ - { - "name": "func_name", - "keys": ["key1", "key2"] - } - ] - } - """ - - def __init__(self, hs): - Resource.__init__(self) # Resource is old-style, so no super() - - self.store = hs.get_datastore() - self.version_string = hs.version_string - self.clock = hs.get_clock() - - def render_POST(self, request): - self._async_render_POST(request) - return NOT_DONE_YET - - @request_handler() - def _async_render_POST(self, request): - content = parse_json_object_from_request(request) - - for row in content["invalidate"]: - name = row["name"] - keys = tuple(row["keys"]) - - getattr(self.store, name).invalidate(keys) - - respond_with_json_bytes(request, 200, "{}") diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py new file mode 100644 index 0000000000..1d7a607529 --- /dev/null +++ b/synapse/replication/http/__init__.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.http.server import JsonResource +from synapse.replication.http import membership, send_event + + +REPLICATION_PREFIX = "/_synapse/replication" + + +class ReplicationRestResource(JsonResource): + def __init__(self, hs): + JsonResource.__init__(self, hs, canonical_json=False) + self.register_servlets(hs) + + def register_servlets(self, hs): + send_event.register_servlets(hs, self) + membership.register_servlets(hs, self) diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py new file mode 100644 index 0000000000..e66c4e881f --- /dev/null +++ b/synapse/replication/http/membership.py @@ -0,0 +1,334 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import re + +from twisted.internet import defer + +from synapse.api.errors import SynapseError, MatrixCodeMessageException +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.types import Requester, UserID +from synapse.util.distributor import user_left_room, user_joined_room + +logger = logging.getLogger(__name__) + + +@defer.inlineCallbacks +def remote_join(client, host, port, requester, remote_room_hosts, + room_id, user_id, content): + """Ask the master to do a remote join for the given user to the given room + + Args: + client (SimpleHttpClient) + host (str): host of master + port (int): port on master listening for HTTP replication + requester (Requester) + remote_room_hosts (list[str]): Servers to try and join via + room_id (str) + user_id (str) + content (dict): The event content to use for the join event + + Returns: + Deferred + """ + uri = "http://%s:%s/_synapse/replication/remote_join" % (host, port) + + payload = { + "requester": requester.serialize(), + "remote_room_hosts": remote_room_hosts, + "room_id": room_id, + "user_id": user_id, + "content": content, + } + + try: + result = yield client.post_json_get_json(uri, payload) + except MatrixCodeMessageException as e: + # We convert to SynapseError as we know that it was a SynapseError + # on the master process that we should send to the client. (And + # importantly, not stack traces everywhere) + raise SynapseError(e.code, e.msg, e.errcode) + defer.returnValue(result) + + +@defer.inlineCallbacks +def remote_reject_invite(client, host, port, requester, remote_room_hosts, + room_id, user_id): + """Ask master to reject the invite for the user and room. 
+ + Args: + client (SimpleHttpClient) + host (str): host of master + port (int): port on master listening for HTTP replication + requester (Requester) + remote_room_hosts (list[str]): Servers to try and reject via + room_id (str) + user_id (str) + + Returns: + Deferred + """ + uri = "http://%s:%s/_synapse/replication/remote_reject_invite" % (host, port) + + payload = { + "requester": requester.serialize(), + "remote_room_hosts": remote_room_hosts, + "room_id": room_id, + "user_id": user_id, + } + + try: + result = yield client.post_json_get_json(uri, payload) + except MatrixCodeMessageException as e: + # We convert to SynapseError as we know that it was a SynapseError + # on the master process that we should send to the client. (And + # importantly, not stack traces everywhere) + raise SynapseError(e.code, e.msg, e.errcode) + defer.returnValue(result) + + +@defer.inlineCallbacks +def get_or_register_3pid_guest(client, host, port, requester, + medium, address, inviter_user_id): + """Ask the master to get/create a guest account for given 3PID. + + Args: + client (SimpleHttpClient) + host (str): host of master + port (int): port on master listening for HTTP replication + requester (Requester) + medium (str) + address (str) + inviter_user_id (str): The user ID who is trying to invite the + 3PID + + Returns: + Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the + 3PID guest account. + """ + + uri = "http://%s:%s/_synapse/replication/get_or_register_3pid_guest" % (host, port) + + payload = { + "requester": requester.serialize(), + "medium": medium, + "address": address, + "inviter_user_id": inviter_user_id, + } + + try: + result = yield client.post_json_get_json(uri, payload) + except MatrixCodeMessageException as e: + # We convert to SynapseError as we know that it was a SynapseError + # on the master process that we should send to the client. 
(And + # importantly, not stack traces everywhere) + raise SynapseError(e.code, e.msg, e.errcode) + defer.returnValue(result) + + +@defer.inlineCallbacks +def notify_user_membership_change(client, host, port, user_id, room_id, change): + """Notify master that a user has joined or left the room + + Args: + client (SimpleHttpClient) + host (str): host of master + port (int): port on master listening for HTTP replication. + user_id (str) + room_id (str) + change (str): Either "joined" or "left" + + Returns: + Deferred + """ + assert change in ("joined", "left") + + uri = "http://%s:%s/_synapse/replication/user_%s_room" % (host, port, change) + + payload = { + "user_id": user_id, + "room_id": room_id, + } + + try: + result = yield client.post_json_get_json(uri, payload) + except MatrixCodeMessageException as e: + # We convert to SynapseError as we know that it was a SynapseError + # on the master process that we should send to the client. (And + # importantly, not stack traces everywhere) + raise SynapseError(e.code, e.msg, e.errcode) + defer.returnValue(result) + + +class ReplicationRemoteJoinRestServlet(RestServlet): + PATTERNS = [re.compile("^/_synapse/replication/remote_join$")] + + def __init__(self, hs): + super(ReplicationRemoteJoinRestServlet, self).__init__() + + self.federation_handler = hs.get_handlers().federation_handler + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + @defer.inlineCallbacks + def on_POST(self, request): + content = parse_json_object_from_request(request) + + remote_room_hosts = content["remote_room_hosts"] + room_id = content["room_id"] + user_id = content["user_id"] + event_content = content["content"] + + requester = Requester.deserialize(self.store, content["requester"]) + + if requester.user: + request.authenticated_entity = requester.user.to_string() + + logger.info( + "remote_join: %s into room: %s", + user_id, room_id, + ) + + yield self.federation_handler.do_invite_join( + remote_room_hosts, + room_id, + user_id,
+ event_content, + ) + + defer.returnValue((200, {})) + + +class ReplicationRemoteRejectInviteRestServlet(RestServlet): + PATTERNS = [re.compile("^/_synapse/replication/remote_reject_invite$")] + + def __init__(self, hs): + super(ReplicationRemoteRejectInviteRestServlet, self).__init__() + + self.federation_handler = hs.get_handlers().federation_handler + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + @defer.inlineCallbacks + def on_POST(self, request): + content = parse_json_object_from_request(request) + + remote_room_hosts = content["remote_room_hosts"] + room_id = content["room_id"] + user_id = content["user_id"] + + requester = Requester.deserialize(self.store, content["requester"]) + + if requester.user: + request.authenticated_entity = requester.user.to_string() + + logger.info( + "remote_reject_invite: %s out of room: %s", + user_id, room_id, + ) + + try: + event = yield self.federation_handler.do_remotely_reject_invite( + remote_room_hosts, + room_id, + user_id, + ) + ret = event.get_pdu_json() + except Exception as e: + # if we were unable to reject the invite, just mark + # it as rejected on our end and plough ahead.
+ # + # The 'except' clause is very broad, but we need to + # capture everything from DNS failures upwards + # + logger.warn("Failed to reject invite: %s", e) + + yield self.store.locally_reject_invite( + user_id, room_id + ) + ret = {} + + defer.returnValue((200, ret)) + + +class ReplicationRegister3PIDGuestRestServlet(RestServlet): + PATTERNS = [re.compile("^/_synapse/replication/get_or_register_3pid_guest$")] + + def __init__(self, hs): + super(ReplicationRegister3PIDGuestRestServlet, self).__init__() + + self.registeration_handler = hs.get_handlers().registration_handler + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + @defer.inlineCallbacks + def on_POST(self, request): + content = parse_json_object_from_request(request) + + medium = content["medium"] + address = content["address"] + inviter_user_id = content["inviter_user_id"] + + requester = Requester.deserialize(self.store, content["requester"]) + + if requester.user: + request.authenticated_entity = requester.user.to_string() + + logger.info("get_or_register_3pid_guest: %r", content) + + ret = yield self.registeration_handler.get_or_register_3pid_guest( + medium, address, inviter_user_id, + ) + + defer.returnValue((200, ret)) + + +class ReplicationUserJoinedLeftRoomRestServlet(RestServlet): + PATTERNS = [re.compile("^/_synapse/replication/user_(?P<change>joined|left)_room$")] + + def __init__(self, hs): + super(ReplicationUserJoinedLeftRoomRestServlet, self).__init__() + + self.registeration_handler = hs.get_handlers().registration_handler + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self.distributor = hs.get_distributor() + + def on_POST(self, request, change): + content = parse_json_object_from_request(request) + + user_id = content["user_id"] + room_id = content["room_id"] + + logger.info("user membership change: %s in %s", user_id, room_id) + + user = UserID.from_string(user_id) + + if change == "joined": + user_joined_room(self.distributor, user, room_id) + elif
change == "left": + user_left_room(self.distributor, user, room_id) + else: + # Exception() does not %-interpolate its arguments, so format the + # message explicitly rather than passing `change` as a second arg. + raise Exception("Unrecognized change: %r" % (change,)) + + return (200, {}) + + +def register_servlets(hs, http_server): + ReplicationRemoteJoinRestServlet(hs).register(http_server) + ReplicationRemoteRejectInviteRestServlet(hs).register(http_server) + ReplicationRegister3PIDGuestRestServlet(hs).register(http_server) + ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py new file mode 100644 index 0000000000..a9baa2c1c3 --- /dev/null +++ b/synapse/replication/http/send_event.py @@ -0,0 +1,160 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +from twisted.internet import defer + +from synapse.api.errors import ( + SynapseError, MatrixCodeMessageException, CodeMessageException, +) +from synapse.events import FrozenEvent +from synapse.events.snapshot import EventContext +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.util.async import sleep +from synapse.util.caches.response_cache import ResponseCache +from synapse.util.metrics import Measure +from synapse.types import Requester, UserID + +import logging +import re + +logger = logging.getLogger(__name__) + + +@defer.inlineCallbacks +def send_event_to_master(client, host, port, requester, event, context, + ratelimit, extra_users): + """Send event to be handled on the master + + Args: + client (SimpleHttpClient) + host (str): host of master + port (int): port on master listening for HTTP replication + requester (Requester) + event (FrozenEvent) + context (EventContext) + ratelimit (bool) + extra_users (list(UserID)): Any extra users to notify about event + """ + uri = "http://%s:%s/_synapse/replication/send_event/%s" % ( + host, port, event.event_id, + ) + + payload = { + "event": event.get_pdu_json(), + "internal_metadata": event.internal_metadata.get_dict(), + "rejected_reason": event.rejected_reason, + "context": context.serialize(event), + "requester": requester.serialize(), + "ratelimit": ratelimit, + "extra_users": [u.to_string() for u in extra_users], + } + + try: + # We keep retrying the same request for timeouts. This is so that we + # have a good idea that the request has either succeeded or failed on + # the master, and so whether we should clean up or not. + while True: + try: + result = yield client.put_json(uri, payload) + break + except CodeMessageException as e: + if e.code != 504: + raise + + logger.warn("send_event request timed out") + + # If we timed out we probably don't need to worry about backing + # off too much, but lets just wait a little anyway. 
+ yield sleep(1) + except MatrixCodeMessageException as e: + # We convert to SynapseError as we know that it was a SynapseError + # on the master process that we should send to the client. (And + # importantly, not stack traces everywhere) + raise SynapseError(e.code, e.msg, e.errcode) + defer.returnValue(result) + + +class ReplicationSendEventRestServlet(RestServlet): + """Handles events newly created on workers, including persisting and + notifying. + + The API looks like: + + POST /_synapse/replication/send_event/:event_id + + { + "event": { .. serialized event .. }, + "internal_metadata": { .. serialized internal_metadata .. }, + "rejected_reason": .., // The event.rejected_reason field + "context": { .. serialized event context .. }, + "requester": { .. serialized requester .. }, + "ratelimit": true, + "extra_users": [], + } + """ + PATTERNS = [re.compile("^/_synapse/replication/send_event/(?P<event_id>[^/]+)$")] + + def __init__(self, hs): + super(ReplicationSendEventRestServlet, self).__init__() + + self.event_creation_handler = hs.get_event_creation_handler() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + # The responses are tiny, so we may as well cache them for a while + self.response_cache = ResponseCache(hs, "send_event", timeout_ms=30 * 60 * 1000) + + def on_PUT(self, request, event_id): + return self.response_cache.wrap( + event_id, + self._handle_request, + request + ) + + @defer.inlineCallbacks + def _handle_request(self, request): + with Measure(self.clock, "repl_send_event_parse"): + content = parse_json_object_from_request(request) + + event_dict = content["event"] + internal_metadata = content["internal_metadata"] + rejected_reason = content["rejected_reason"] + event = FrozenEvent(event_dict, internal_metadata, rejected_reason) + + requester = Requester.deserialize(self.store, content["requester"]) + context = yield EventContext.deserialize(self.store, content["context"]) + + ratelimit = content["ratelimit"] + extra_users = 
[UserID.from_string(u) for u in content["extra_users"]] + + if requester.user: + request.authenticated_entity = requester.user.to_string() + + logger.info( + "Got event to send with ID: %s into room: %s", + event.event_id, event.room_id, + ) + + yield self.event_creation_handler.persist_and_notify_client_event( + requester, event, context, + ratelimit=ratelimit, + extra_users=extra_users, + ) + + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + ReplicationSendEventRestServlet(hs).register(http_server) diff --git a/synapse/replication/presence_resource.py b/synapse/replication/presence_resource.py deleted file mode 100644 index fc18130ab4..0000000000 --- a/synapse/replication/presence_resource.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.http.server import respond_with_json_bytes, request_handler -from synapse.http.servlet import parse_json_object_from_request - -from twisted.web.resource import Resource -from twisted.web.server import NOT_DONE_YET -from twisted.internet import defer - - -class PresenceResource(Resource): - """ - HTTP endpoint for marking users as syncing. 
- - POST /_synapse/replication/presence HTTP/1.1 - Content-Type: application/json - - { - "process_id": "<process_id>", - "syncing_users": ["<user_id>"] - } - """ - - def __init__(self, hs): - Resource.__init__(self) # Resource is old-style, so no super() - - self.version_string = hs.version_string - self.clock = hs.get_clock() - self.presence_handler = hs.get_presence_handler() - - def render_POST(self, request): - self._async_render_POST(request) - return NOT_DONE_YET - - @request_handler() - @defer.inlineCallbacks - def _async_render_POST(self, request): - content = parse_json_object_from_request(request) - - process_id = content["process_id"] - syncing_user_ids = content["syncing_users"] - - yield self.presence_handler.update_external_syncs( - process_id, set(syncing_user_ids) - ) - - respond_with_json_bytes(request, 200, "{}") diff --git a/synapse/replication/pusher_resource.py b/synapse/replication/pusher_resource.py deleted file mode 100644 index 9b01ab3c13..0000000000 --- a/synapse/replication/pusher_resource.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from synapse.http.server import respond_with_json_bytes, request_handler -from synapse.http.servlet import parse_json_object_from_request - -from twisted.web.resource import Resource -from twisted.web.server import NOT_DONE_YET -from twisted.internet import defer - - -class PusherResource(Resource): - """ - HTTP endpoint for deleting rejected pushers - """ - - def __init__(self, hs): - Resource.__init__(self) # Resource is old-style, so no super() - - self.version_string = hs.version_string - self.store = hs.get_datastore() - self.notifier = hs.get_notifier() - self.clock = hs.get_clock() - - def render_POST(self, request): - self._async_render_POST(request) - return NOT_DONE_YET - - @request_handler() - @defer.inlineCallbacks - def _async_render_POST(self, request): - content = parse_json_object_from_request(request) - - for remove in content["remove"]: - yield self.store.delete_pusher_by_app_id_pushkey_user_id( - remove["app_id"], - remove["push_key"], - remove["user_id"], - ) - - self.notifier.on_new_replication_data() - - respond_with_json_bytes(request, 200, "{}") diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py deleted file mode 100644 index d8eb14592b..0000000000 --- a/synapse/replication/resource.py +++ /dev/null @@ -1,576 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from synapse.http.servlet import parse_integer, parse_string -from synapse.http.server import request_handler, finish_request -from synapse.replication.pusher_resource import PusherResource -from synapse.replication.presence_resource import PresenceResource -from synapse.replication.expire_cache import ExpireCacheResource -from synapse.api.errors import SynapseError - -from twisted.web.resource import Resource -from twisted.web.server import NOT_DONE_YET -from twisted.internet import defer - -import ujson as json - -import collections -import logging - -logger = logging.getLogger(__name__) - -REPLICATION_PREFIX = "/_synapse/replication" - -STREAM_NAMES = ( - ("events",), - ("presence",), - ("typing",), - ("receipts",), - ("user_account_data", "room_account_data", "tag_account_data",), - ("backfill",), - ("push_rules",), - ("pushers",), - ("caches",), - ("to_device",), - ("public_rooms",), - ("federation",), - ("device_lists",), -) - - -class ReplicationResource(Resource): - """ - HTTP endpoint for extracting data from synapse. - - The streams of data returned by the endpoint are controlled by the - parameters given to the API. To return a given stream pass a query - parameter with a position in the stream to return data from or the - special value "-1" to return data from the start of the stream. - - If there is no data for any of the supplied streams after the given - position then the request will block until there is data for one - of the streams. This allows clients to long-poll this API. - - The possible streams are: - - * "streams": A special stream returing the positions of other streams. - * "events": The new events seen on the server. - * "presence": Presence updates. - * "typing": Typing updates. - * "receipts": Receipt updates. - * "user_account_data": Top-level per user account data. - * "room_account_data: Per room per user account data. - * "tag_account_data": Per room per user tags. 
- * "backfill": Old events that have been backfilled from other servers. - * "push_rules": Per user changes to push rules. - * "pushers": Per user changes to their pushers. - * "caches": Cache invalidations. - - The API takes two additional query parameters: - - * "timeout": How long to wait before returning an empty response. - * "limit": The maximum number of rows to return for the selected streams. - - The response is a JSON object with keys for each stream with updates. Under - each key is a JSON object with: - - * "position": The current position of the stream. - * "field_names": The names of the fields in each row. - * "rows": The updates as an array of arrays. - - There are a number of ways this API could be used: - - 1) To replicate the contents of the backing database to another database. - 2) To be notified when the contents of a shared backing database changes. - 3) To "tail" the activity happening on a server for debugging. - - In the first case the client would track all of the streams and store it's - own copy of the data. - - In the second case the client might theoretically just be able to follow - the "streams" stream to track where the other streams are. However in - practise it will probably need to get the contents of the streams in - order to expire the any in-memory caches. Whether it gets the contents - of the streams from this replication API or directly from the backing - store is a matter of taste. - - In the third case the client would use the "streams" stream to find what - streams are available and their current positions. Then it can start - long-polling this replication API for new data on those streams. 
- """ - - def __init__(self, hs): - Resource.__init__(self) # Resource is old-style, so no super() - - self.version_string = hs.version_string - self.store = hs.get_datastore() - self.sources = hs.get_event_sources() - self.presence_handler = hs.get_presence_handler() - self.typing_handler = hs.get_typing_handler() - self.federation_sender = hs.get_federation_sender() - self.notifier = hs.notifier - self.clock = hs.get_clock() - self.config = hs.get_config() - - self.putChild("remove_pushers", PusherResource(hs)) - self.putChild("syncing_users", PresenceResource(hs)) - self.putChild("expire_cache", ExpireCacheResource(hs)) - - def render_GET(self, request): - self._async_render_GET(request) - return NOT_DONE_YET - - @defer.inlineCallbacks - def current_replication_token(self): - stream_token = yield self.sources.get_current_token() - backfill_token = yield self.store.get_current_backfill_token() - push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() - pushers_token = self.store.get_pushers_stream_token() - caches_token = self.store.get_cache_stream_token() - public_rooms_token = self.store.get_current_public_room_stream_id() - federation_token = self.federation_sender.get_current_token() - device_list_token = self.store.get_device_stream_token() - - defer.returnValue(_ReplicationToken( - room_stream_token, - int(stream_token.presence_key), - int(stream_token.typing_key), - int(stream_token.receipt_key), - int(stream_token.account_data_key), - backfill_token, - push_rules_token, - pushers_token, - 0, # State stream is no longer a thing - caches_token, - int(stream_token.to_device_key), - int(public_rooms_token), - int(federation_token), - int(device_list_token), - )) - - @request_handler() - @defer.inlineCallbacks - def _async_render_GET(self, request): - limit = parse_integer(request, "limit", 100) - timeout = parse_integer(request, "timeout", 10 * 1000) - - request.setHeader(b"Content-Type", b"application/json") - - request_streams = { - 
name: parse_integer(request, name) - for names in STREAM_NAMES for name in names - } - request_streams["streams"] = parse_string(request, "streams") - - federation_ack = parse_integer(request, "federation_ack", None) - - def replicate(): - return self.replicate( - request_streams, limit, - federation_ack=federation_ack - ) - - writer = yield self.notifier.wait_for_replication(replicate, timeout) - result = writer.finish() - - for stream_name, stream_content in result.items(): - logger.info( - "Replicating %d rows of %s from %s -> %s", - len(stream_content["rows"]), - stream_name, - request_streams.get(stream_name), - stream_content["position"], - ) - - request.write(json.dumps(result, ensure_ascii=False)) - finish_request(request) - - @defer.inlineCallbacks - def replicate(self, request_streams, limit, federation_ack=None): - writer = _Writer() - current_token = yield self.current_replication_token() - logger.debug("Replicating up to %r", current_token) - - if limit == 0: - raise SynapseError(400, "Limit cannot be 0") - - yield self.account_data(writer, current_token, limit, request_streams) - yield self.events(writer, current_token, limit, request_streams) - # TODO: implement limit - yield self.presence(writer, current_token, request_streams) - yield self.typing(writer, current_token, request_streams) - yield self.receipts(writer, current_token, limit, request_streams) - yield self.push_rules(writer, current_token, limit, request_streams) - yield self.pushers(writer, current_token, limit, request_streams) - yield self.caches(writer, current_token, limit, request_streams) - yield self.to_device(writer, current_token, limit, request_streams) - yield self.public_rooms(writer, current_token, limit, request_streams) - yield self.device_lists(writer, current_token, limit, request_streams) - self.federation(writer, current_token, limit, request_streams, federation_ack) - self.streams(writer, current_token, request_streams) - - logger.debug("Replicated %d rows", 
writer.total) - defer.returnValue(writer) - - def streams(self, writer, current_token, request_streams): - request_token = request_streams.get("streams") - - streams = [] - - if request_token is not None: - if request_token == "-1": - for names, position in zip(STREAM_NAMES, current_token): - streams.extend((name, position) for name in names) - else: - items = zip( - STREAM_NAMES, - current_token, - _ReplicationToken(request_token) - ) - for names, current_id, last_id in items: - if last_id < current_id: - streams.extend((name, current_id) for name in names) - - if streams: - writer.write_header_and_rows( - "streams", streams, ("name", "position"), - position=str(current_token) - ) - - @defer.inlineCallbacks - def events(self, writer, current_token, limit, request_streams): - request_events = request_streams.get("events") - request_backfill = request_streams.get("backfill") - - if request_events is not None or request_backfill is not None: - if request_events is None: - request_events = current_token.events - if request_backfill is None: - request_backfill = current_token.backfill - - no_new_tokens = ( - request_events == current_token.events - and request_backfill == current_token.backfill - ) - if no_new_tokens: - return - - res = yield self.store.get_all_new_events( - request_backfill, request_events, - current_token.backfill, current_token.events, - limit - ) - - upto_events_token = _position_from_rows( - res.new_forward_events, current_token.events - ) - - upto_backfill_token = _position_from_rows( - res.new_backfill_events, current_token.backfill - ) - - if request_events != upto_events_token: - writer.write_header_and_rows("events", res.new_forward_events, ( - "position", "internal", "json", "state_group" - ), position=upto_events_token) - - if request_backfill != upto_backfill_token: - writer.write_header_and_rows("backfill", res.new_backfill_events, ( - "position", "internal", "json", "state_group", - ), position=upto_backfill_token) - - 
writer.write_header_and_rows( - "forward_ex_outliers", res.forward_ex_outliers, - ("position", "event_id", "state_group"), - ) - writer.write_header_and_rows( - "backward_ex_outliers", res.backward_ex_outliers, - ("position", "event_id", "state_group"), - ) - - @defer.inlineCallbacks - def presence(self, writer, current_token, request_streams): - current_position = current_token.presence - - request_presence = request_streams.get("presence") - - if request_presence is not None and request_presence != current_position: - presence_rows = yield self.presence_handler.get_all_presence_updates( - request_presence, current_position - ) - upto_token = _position_from_rows(presence_rows, current_position) - writer.write_header_and_rows("presence", presence_rows, ( - "position", "user_id", "state", "last_active_ts", - "last_federation_update_ts", "last_user_sync_ts", - "status_msg", "currently_active", - ), position=upto_token) - - @defer.inlineCallbacks - def typing(self, writer, current_token, request_streams): - current_position = current_token.typing - - request_typing = request_streams.get("typing") - - if request_typing is not None and request_typing != current_position: - # If they have a higher token than current max, we can assume that - # they had been talking to a previous instance of the master. Since - # we reset the token on restart, the best (but hacky) thing we can - # do is to simply resend down all the typing notifications. 
- if request_typing > current_position: - request_typing = 0 - - typing_rows = yield self.typing_handler.get_all_typing_updates( - request_typing, current_position - ) - upto_token = _position_from_rows(typing_rows, current_position) - writer.write_header_and_rows("typing", typing_rows, ( - "position", "room_id", "typing" - ), position=upto_token) - - @defer.inlineCallbacks - def receipts(self, writer, current_token, limit, request_streams): - current_position = current_token.receipts - - request_receipts = request_streams.get("receipts") - - if request_receipts is not None and request_receipts != current_position: - receipts_rows = yield self.store.get_all_updated_receipts( - request_receipts, current_position, limit - ) - upto_token = _position_from_rows(receipts_rows, current_position) - writer.write_header_and_rows("receipts", receipts_rows, ( - "position", "room_id", "receipt_type", "user_id", "event_id", "data" - ), position=upto_token) - - @defer.inlineCallbacks - def account_data(self, writer, current_token, limit, request_streams): - current_position = current_token.account_data - - user_account_data = request_streams.get("user_account_data") - room_account_data = request_streams.get("room_account_data") - tag_account_data = request_streams.get("tag_account_data") - - if user_account_data is not None or room_account_data is not None: - if user_account_data is None: - user_account_data = current_position - if room_account_data is None: - room_account_data = current_position - - no_new_tokens = ( - user_account_data == current_position - and room_account_data == current_position - ) - if no_new_tokens: - return - - user_rows, room_rows = yield self.store.get_all_updated_account_data( - user_account_data, room_account_data, current_position, limit - ) - - upto_users_token = _position_from_rows(user_rows, current_position) - upto_rooms_token = _position_from_rows(room_rows, current_position) - - writer.write_header_and_rows("user_account_data", user_rows, ( - 
"position", "user_id", "type", "content" - ), position=upto_users_token) - writer.write_header_and_rows("room_account_data", room_rows, ( - "position", "user_id", "room_id", "type", "content" - ), position=upto_rooms_token) - - if tag_account_data is not None: - tag_rows = yield self.store.get_all_updated_tags( - tag_account_data, current_position, limit - ) - upto_tag_token = _position_from_rows(tag_rows, current_position) - writer.write_header_and_rows("tag_account_data", tag_rows, ( - "position", "user_id", "room_id", "tags" - ), position=upto_tag_token) - - @defer.inlineCallbacks - def push_rules(self, writer, current_token, limit, request_streams): - current_position = current_token.push_rules - - push_rules = request_streams.get("push_rules") - - if push_rules is not None and push_rules != current_position: - rows = yield self.store.get_all_push_rule_updates( - push_rules, current_position, limit - ) - upto_token = _position_from_rows(rows, current_position) - writer.write_header_and_rows("push_rules", rows, ( - "position", "event_stream_ordering", "user_id", "rule_id", "op", - "priority_class", "priority", "conditions", "actions" - ), position=upto_token) - - @defer.inlineCallbacks - def pushers(self, writer, current_token, limit, request_streams): - current_position = current_token.pushers - - pushers = request_streams.get("pushers") - - if pushers is not None and pushers != current_position: - updated, deleted = yield self.store.get_all_updated_pushers( - pushers, current_position, limit - ) - upto_token = _position_from_rows(updated, current_position) - writer.write_header_and_rows("pushers", updated, ( - "position", "user_id", "access_token", "profile_tag", "kind", - "app_id", "app_display_name", "device_display_name", "pushkey", - "ts", "lang", "data" - ), position=upto_token) - writer.write_header_and_rows("deleted_pushers", deleted, ( - "position", "user_id", "app_id", "pushkey" - ), position=upto_token) - - @defer.inlineCallbacks - def caches(self, 
writer, current_token, limit, request_streams): - current_position = current_token.caches - - caches = request_streams.get("caches") - - if caches is not None and caches != current_position: - updated_caches = yield self.store.get_all_updated_caches( - caches, current_position, limit - ) - upto_token = _position_from_rows(updated_caches, current_position) - writer.write_header_and_rows("caches", updated_caches, ( - "position", "cache_func", "keys", "invalidation_ts" - ), position=upto_token) - - @defer.inlineCallbacks - def to_device(self, writer, current_token, limit, request_streams): - current_position = current_token.to_device - - to_device = request_streams.get("to_device") - - if to_device is not None and to_device != current_position: - to_device_rows = yield self.store.get_all_new_device_messages( - to_device, current_position, limit - ) - upto_token = _position_from_rows(to_device_rows, current_position) - writer.write_header_and_rows("to_device", to_device_rows, ( - "position", "user_id", "device_id", "message_json" - ), position=upto_token) - - @defer.inlineCallbacks - def public_rooms(self, writer, current_token, limit, request_streams): - current_position = current_token.public_rooms - - public_rooms = request_streams.get("public_rooms") - - if public_rooms is not None and public_rooms != current_position: - public_rooms_rows = yield self.store.get_all_new_public_rooms( - public_rooms, current_position, limit - ) - upto_token = _position_from_rows(public_rooms_rows, current_position) - writer.write_header_and_rows("public_rooms", public_rooms_rows, ( - "position", "room_id", "visibility", "appservice_id", "network_id", - ), position=upto_token) - - def federation(self, writer, current_token, limit, request_streams, federation_ack): - if self.config.send_federation: - return - - current_position = current_token.federation - - federation = request_streams.get("federation") - - if federation is not None and federation != current_position: - 
federation_rows = self.federation_sender.get_replication_rows( - federation, limit, federation_ack=federation_ack, - ) - upto_token = _position_from_rows(federation_rows, current_position) - writer.write_header_and_rows("federation", federation_rows, ( - "position", "type", "content", - ), position=upto_token) - - @defer.inlineCallbacks - def device_lists(self, writer, current_token, limit, request_streams): - current_position = current_token.device_lists - - device_lists = request_streams.get("device_lists") - - if device_lists is not None and device_lists != current_position: - changes = yield self.store.get_all_device_list_changes_for_remotes( - device_lists, - ) - writer.write_header_and_rows("device_lists", changes, ( - "position", "user_id", "destination", - ), position=current_position) - - -class _Writer(object): - """Writes the streams as a JSON object as the response to the request""" - def __init__(self): - self.streams = {} - self.total = 0 - - def write_header_and_rows(self, name, rows, fields, position=None): - if position is None: - if rows: - position = rows[-1][0] - else: - return - - self.streams[name] = { - "position": position if type(position) is int else str(position), - "field_names": fields, - "rows": rows, - } - - self.total += len(rows) - - def __nonzero__(self): - return bool(self.total) - - def finish(self): - return self.streams - - -class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( - "events", "presence", "typing", "receipts", "account_data", "backfill", - "push_rules", "pushers", "state", "caches", "to_device", "public_rooms", - "federation", "device_lists", -))): - __slots__ = [] - - def __new__(cls, *args): - if len(args) == 1: - streams = [int(value) for value in args[0].split("_")] - if len(streams) < len(cls._fields): - streams.extend([0] * (len(cls._fields) - len(streams))) - return cls(*streams) - else: - return super(_ReplicationToken, cls).__new__(cls, *args) - - def __str__(self): - return 
"_".join(str(value) for value in self) - - -def _position_from_rows(rows, current_position): - """Calculates a position to return for a stream. Ideally we want to return the - position of the last row, as that will be the most correct. However, if there - are no rows we fall back to using the current position to stop us from - repeatedly hitting the storage layer unncessarily thinking there are updates. - (Not all advances of the token correspond to an actual update) - - We can't just always return the current position, as we often limit the - number of rows we replicate, and so the stream may lag. The assumption is - that if the storage layer returns no new rows then we are not lagging and - we are at the `current_position`. - """ - if rows: - return rows[-1][0] - return current_position diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 18076e0f3b..61f5590c53 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -15,7 +15,6 @@ from synapse.storage._base import SQLBaseStore from synapse.storage.engines import PostgresEngine -from twisted.internet import defer from ._slaved_id_tracker import SlavedIdTracker @@ -26,7 +25,7 @@ logger = logging.getLogger(__name__) class BaseSlavedStore(SQLBaseStore): def __init__(self, db_conn, hs): - super(BaseSlavedStore, self).__init__(hs) + super(BaseSlavedStore, self).__init__(db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = SlavedIdTracker( db_conn, "cache_invalidation_stream", "stream_id", @@ -34,8 +33,7 @@ class BaseSlavedStore(SQLBaseStore): else: self._cache_id_gen = None - self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache" - self.http_client = hs.get_simple_http_client() + self.hs = hs def stream_positions(self): pos = {} @@ -43,33 +41,20 @@ class BaseSlavedStore(SQLBaseStore): pos["caches"] = self._cache_id_gen.get_current_token() return pos - def 
process_replication(self, result): - stream = result.get("caches") - if stream: - for row in stream["rows"]: - ( - position, cache_func, keys, invalidation_ts, - ) = row - + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "caches": + self._cache_id_gen.advance(token) + for row in rows: try: - getattr(self, cache_func).invalidate(tuple(keys)) + getattr(self, row.cache_func).invalidate(tuple(row.keys)) except AttributeError: - logger.info("Got unexpected cache_func: %r", cache_func) - self._cache_id_gen.advance(int(stream["position"])) - return defer.succeed(None) + # We probably haven't pulled in the cache in this worker, + # which is fine. + pass def _invalidate_cache_and_stream(self, txn, cache_func, keys): txn.call_after(cache_func.invalidate, keys) txn.call_after(self._send_invalidation_poke, cache_func, keys) - @defer.inlineCallbacks def _send_invalidation_poke(self, cache_func, keys): - try: - yield self.http_client.post_json_get_json(self.expire_cache_url, { - "invalidate": [{ - "name": cache_func.__name__, - "keys": list(keys), - }] - }) - except: - logger.exception("Failed to poke on expire_cache") + self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys) diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py index 24b5c79d4a..9d1d173b2f 100644 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -27,4 +27,9 @@ class SlavedIdTracker(object): self._current = (max if self.step > 0 else min)(self._current, new_id) def get_current_token(self): + """ + + Returns: + int + """ return self._current diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index 77c64722c7..d9ba6d69b1 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -1,5 +1,6 @@ # 
-*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,50 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker -from synapse.storage import DataStore -from synapse.storage.account_data import AccountDataStore -from synapse.storage.tags import TagsStore -from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.storage.account_data import AccountDataWorkerStore +from synapse.storage.tags import TagsWorkerStore -class SlavedAccountDataStore(BaseSlavedStore): +class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): - super(SlavedAccountDataStore, self).__init__(db_conn, hs) self._account_data_id_gen = SlavedIdTracker( db_conn, "account_data_max_stream_id", "stream_id", ) - self._account_data_stream_cache = StreamChangeCache( - "AccountDataAndTagsChangeCache", - self._account_data_id_gen.get_current_token(), - ) - - get_account_data_for_user = ( - AccountDataStore.__dict__["get_account_data_for_user"] - ) - - get_global_account_data_by_type_for_users = ( - AccountDataStore.__dict__["get_global_account_data_by_type_for_users"] - ) - get_global_account_data_by_type_for_user = ( - AccountDataStore.__dict__["get_global_account_data_by_type_for_user"] - ) - - get_tags_for_user = TagsStore.__dict__["get_tags_for_user"] - get_tags_for_room = ( - DataStore.get_tags_for_room.__func__ - ) - get_account_data_for_room = ( - DataStore.get_account_data_for_room.__func__ - ) - - get_updated_tags = DataStore.get_updated_tags.__func__ 
- get_updated_account_data_for_user = ( - DataStore.get_updated_account_data_for_user.__func__ - ) + super(SlavedAccountDataStore, self).__init__(db_conn, hs) def get_max_account_data_stream_id(self): return self._account_data_id_gen.get_current_token() @@ -69,38 +40,29 @@ class SlavedAccountDataStore(BaseSlavedStore): result["tag_account_data"] = position return result - def process_replication(self, result): - stream = result.get("user_account_data") - if stream: - self._account_data_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - position, user_id, data_type = row[:3] - self.get_global_account_data_by_type_for_user.invalidate( - (data_type, user_id,) - ) - self.get_account_data_for_user.invalidate((user_id,)) + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "tag_account_data": + self._account_data_id_gen.advance(token) + for row in rows: + self.get_tags_for_user.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed( - user_id, position + row.user_id, token ) - - stream = result.get("room_account_data") - if stream: - self._account_data_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - position, user_id = row[:2] - self.get_account_data_for_user.invalidate((user_id,)) - self._account_data_stream_cache.entity_has_changed( - user_id, position + elif stream_name == "account_data": + self._account_data_id_gen.advance(token) + for row in rows: + if not row.room_id: + self.get_global_account_data_by_type_for_user.invalidate( + (row.data_type, row.user_id,) + ) + self.get_account_data_for_user.invalidate((row.user_id,)) + self.get_account_data_for_room.invalidate((row.user_id, row.room_id,)) + self.get_account_data_for_room_and_type.invalidate( + (row.user_id, row.room_id, row.data_type,), ) - - stream = result.get("tag_account_data") - if stream: - self._account_data_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - position, user_id = row[:2] - 
self.get_tags_for_user.invalidate((user_id,)) self._account_data_stream_cache.entity_has_changed( - user_id, position + row.user_id, token ) - - return super(SlavedAccountDataStore, self).process_replication(result) + return super(SlavedAccountDataStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py index a374f2f1a2..8cae3076f4 100644 --- a/synapse/replication/slave/storage/appservice.py +++ b/synapse/replication/slave/storage/appservice.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,28 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore -from synapse.storage import DataStore -from synapse.config.appservice import load_appservices +from synapse.storage.appservice import ( + ApplicationServiceWorkerStore, ApplicationServiceTransactionWorkerStore, +) -class SlavedApplicationServiceStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedApplicationServiceStore, self).__init__(db_conn, hs) - self.services_cache = load_appservices( - hs.config.server_name, - hs.config.app_service_config_files - ) - - get_app_service_by_token = DataStore.get_app_service_by_token.__func__ - get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__ - get_app_services = DataStore.get_app_services.__func__ - get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__ - create_appservice_txn = DataStore.create_appservice_txn.__func__ - get_appservices_by_state = DataStore.get_appservices_by_state.__func__ - get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__ - _get_last_txn = 
DataStore._get_last_txn.__func__ - complete_appservice_txn = DataStore.complete_appservice_txn.__func__ - get_appservice_state = DataStore.get_appservice_state.__func__ - set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__ - set_appservice_state = DataStore.set_appservice_state.__func__ +class SlavedApplicationServiceStore(ApplicationServiceTransactionWorkerStore, + ApplicationServiceWorkerStore): + pass diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py new file mode 100644 index 0000000000..352c9a2aa8 --- /dev/null +++ b/synapse/replication/slave/storage/client_ips.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ._base import BaseSlavedStore +from synapse.storage.client_ips import LAST_SEEN_GRANULARITY +from synapse.util.caches import CACHE_SIZE_FACTOR +from synapse.util.caches.descriptors import Cache + + +class SlavedClientIpStore(BaseSlavedStore): + def __init__(self, db_conn, hs): + super(SlavedClientIpStore, self).__init__(db_conn, hs) + + self.client_ip_last_seen = Cache( + name="client_ip_last_seen", + keylen=4, + max_entries=50000 * CACHE_SIZE_FACTOR, + ) + + def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): + now = int(self._clock.time_msec()) + key = (user_id, access_token, ip) + + try: + last_seen = self.client_ip_last_seen.get(key) + except KeyError: + last_seen = None + + # Rate-limited inserts + if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: + return + + self.hs.get_tcp_replication().send_user_ip( + user_id, access_token, ip, user_agent, device_id, now + ) diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index cc860f9f9b..6f3fb64770 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -17,6 +17,7 @@ from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker from synapse.storage import DataStore from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.util.caches.expiringcache import ExpiringCache class SlavedDeviceInboxStore(BaseSlavedStore): @@ -34,6 +35,13 @@ class SlavedDeviceInboxStore(BaseSlavedStore): self._device_inbox_id_gen.get_current_token() ) + self._last_device_delete_cache = ExpiringCache( + cache_name="last_device_delete_cache", + clock=self._clock, + max_len=10000, + expiry_ms=30 * 60 * 1000, + ) + get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__ get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__ get_new_device_msgs_for_remote = 
DataStore.get_new_device_msgs_for_remote.__func__ @@ -45,21 +53,18 @@ class SlavedDeviceInboxStore(BaseSlavedStore): result["to_device"] = self._device_inbox_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("to_device") - if stream: - self._device_inbox_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - stream_id = row[0] - entity = row[1] - - if entity.startswith("@"): + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "to_device": + self._device_inbox_id_gen.advance(token) + for row in rows: + if row.entity.startswith("@"): self._device_inbox_stream_cache.entity_has_changed( - entity, stream_id + row.entity, token ) else: self._device_federation_outbox_stream_cache.entity_has_changed( - entity, stream_id + row.entity, token ) - - return super(SlavedDeviceInboxStore, self).process_replication(result) + return super(SlavedDeviceInboxStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index ca46aa17b6..7687867aee 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -16,6 +16,7 @@ from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker from synapse.storage import DataStore +from synapse.storage.end_to_end_keys import EndToEndKeyStore from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -45,28 +46,25 @@ class SlavedDeviceStore(BaseSlavedStore): _mark_as_sent_devices_by_remote_txn = ( DataStore._mark_as_sent_devices_by_remote_txn.__func__ ) + count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"] def stream_positions(self): result = super(SlavedDeviceStore, self).stream_positions() result["device_lists"] = self._device_list_id_gen.get_current_token() return result - def process_replication(self, result): - stream = 
result.get("device_lists") - if stream: - self._device_list_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - stream_id = row[0] - user_id = row[1] - destination = row[2] - + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "device_lists": + self._device_list_id_gen.advance(token) + for row in rows: self._device_list_stream_cache.entity_has_changed( - user_id, stream_id + row.user_id, token ) - if destination: + if row.destination: self._device_list_federation_stream_cache.entity_has_changed( - destination, stream_id + row.destination, token ) - - return super(SlavedDeviceStore, self).process_replication(result) + return super(SlavedDeviceStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py index 7301d885f2..6deecd3963 100644 --- a/synapse/replication/slave/storage/directory.py +++ b/synapse/replication/slave/storage/directory.py @@ -14,10 +14,8 @@ # limitations under the License. from ._base import BaseSlavedStore -from synapse.storage.directory import DirectoryStore +from synapse.storage.directory import DirectoryWorkerStore -class DirectoryStore(BaseSlavedStore): - get_aliases_for_room = DirectoryStore.__dict__[ - "get_aliases_for_room" - ] +class DirectoryStore(DirectoryWorkerStore, BaseSlavedStore): + pass diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index d72ff6055c..b1f64ef0d8 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,22 +13,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker - -from synapse.api.constants import EventTypes -from synapse.events import FrozenEvent -from synapse.storage import DataStore -from synapse.storage.roommember import RoomMemberStore -from synapse.storage.event_federation import EventFederationStore -from synapse.storage.event_push_actions import EventPushActionsStore -from synapse.storage.state import StateStore -from synapse.storage.stream import StreamStore -from synapse.util.caches.stream_change_cache import StreamChangeCache - -import ujson as json import logging +from synapse.api.constants import EventTypes +from synapse.storage.event_federation import EventFederationWorkerStore +from synapse.storage.event_push_actions import EventPushActionsWorkerStore +from synapse.storage.events_worker import EventsWorkerStore +from synapse.storage.roommember import RoomMemberWorkerStore +from synapse.storage.state import StateGroupWorkerStore +from synapse.storage.stream import StreamWorkerStore +from synapse.storage.signatures import SignatureWorkerStore +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker logger = logging.getLogger(__name__) @@ -41,152 +38,33 @@ logger = logging.getLogger(__name__) # the method descriptor on the DataStore and chuck them into our class. 
-class SlavedEventStore(BaseSlavedStore): +class SlavedEventStore(EventFederationWorkerStore, + RoomMemberWorkerStore, + EventPushActionsWorkerStore, + StreamWorkerStore, + EventsWorkerStore, + StateGroupWorkerStore, + SignatureWorkerStore, + BaseSlavedStore): def __init__(self, db_conn, hs): - super(SlavedEventStore, self).__init__(db_conn, hs) self._stream_id_gen = SlavedIdTracker( db_conn, "events", "stream_ordering", ) self._backfill_id_gen = SlavedIdTracker( db_conn, "events", "stream_ordering", step=-1 ) - events_max = self._stream_id_gen.get_current_token() - event_cache_prefill, min_event_val = self._get_cache_dict( - db_conn, "events", - entity_column="room_id", - stream_column="stream_ordering", - max_value=events_max, - ) - self._events_stream_cache = StreamChangeCache( - "EventsRoomStreamChangeCache", min_event_val, - prefilled_cache=event_cache_prefill, - ) - self._membership_stream_cache = StreamChangeCache( - "MembershipStreamChangeCache", events_max, - ) - self.stream_ordering_month_ago = 0 - self._stream_order_on_start = self.get_room_max_stream_ordering() + super(SlavedEventStore, self).__init__(db_conn, hs) # Cached functions can't be accessed through a class instance so we need # to reach inside the __dict__ to extract them. 
- get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"] - get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"] - get_users_who_share_room_with_user = ( - RoomMemberStore.__dict__["get_users_who_share_room_with_user"] - ) - get_latest_event_ids_in_room = EventFederationStore.__dict__[ - "get_latest_event_ids_in_room" - ] - get_invited_rooms_for_user = RoomMemberStore.__dict__[ - "get_invited_rooms_for_user" - ] - get_unread_event_push_actions_by_room_for_user = ( - EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"] - ) - _get_state_group_for_events = ( - StateStore.__dict__["_get_state_group_for_events"] - ) - _get_state_group_for_event = ( - StateStore.__dict__["_get_state_group_for_event"] - ) - _get_state_groups_from_groups = ( - StateStore.__dict__["_get_state_groups_from_groups"] - ) - _get_state_groups_from_groups_txn = ( - DataStore._get_state_groups_from_groups_txn.__func__ - ) - _get_state_group_from_group = ( - StateStore.__dict__["_get_state_group_from_group"] - ) - get_recent_event_ids_for_room = ( - StreamStore.__dict__["get_recent_event_ids_for_room"] - ) - - get_unread_push_actions_for_user_in_range_for_http = ( - DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__ - ) - get_unread_push_actions_for_user_in_range_for_email = ( - DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__ - ) - get_push_action_users_in_range = ( - DataStore.get_push_action_users_in_range.__func__ - ) - get_event = DataStore.get_event.__func__ - get_events = DataStore.get_events.__func__ - get_rooms_for_user_where_membership_is = ( - DataStore.get_rooms_for_user_where_membership_is.__func__ - ) - get_membership_changes_for_user = ( - DataStore.get_membership_changes_for_user.__func__ - ) - get_room_events_max_id = DataStore.get_room_events_max_id.__func__ - get_room_events_stream_for_room = ( - DataStore.get_room_events_stream_for_room.__func__ - ) - get_events_around = 
DataStore.get_events_around.__func__ - get_state_for_event = DataStore.get_state_for_event.__func__ - get_state_for_events = DataStore.get_state_for_events.__func__ - get_state_groups = DataStore.get_state_groups.__func__ - get_state_groups_ids = DataStore.get_state_groups_ids.__func__ - get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__ - get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__ - get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__ - get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__ - _get_joined_users_from_context = ( - RoomMemberStore.__dict__["_get_joined_users_from_context"] - ) - - get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__ - get_room_events_stream_for_rooms = ( - DataStore.get_room_events_stream_for_rooms.__func__ - ) - is_host_joined = DataStore.is_host_joined.__func__ - _is_host_joined = RoomMemberStore.__dict__["_is_host_joined"] - get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__ - - _set_before_and_after = staticmethod(DataStore._set_before_and_after) - - _get_events = DataStore._get_events.__func__ - _get_events_from_cache = DataStore._get_events_from_cache.__func__ - _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__ - _enqueue_events = DataStore._enqueue_events.__func__ - _do_fetch = DataStore._do_fetch.__func__ - _fetch_event_rows = DataStore._fetch_event_rows.__func__ - _get_event_from_row = DataStore._get_event_from_row.__func__ - _get_rooms_for_user_where_membership_is_txn = ( - DataStore._get_rooms_for_user_where_membership_is_txn.__func__ - ) - _get_members_rows_txn = DataStore._get_members_rows_txn.__func__ - _get_state_for_groups = DataStore._get_state_for_groups.__func__ - _get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__ - _get_events_around_txn = DataStore._get_events_around_txn.__func__ - _get_some_state_from_cache = 
DataStore._get_some_state_from_cache.__func__ + def get_room_max_stream_ordering(self): + return self._stream_id_gen.get_current_token() - get_backfill_events = DataStore.get_backfill_events.__func__ - _get_backfill_events = DataStore._get_backfill_events.__func__ - get_missing_events = DataStore.get_missing_events.__func__ - _get_missing_events = DataStore._get_missing_events.__func__ - - get_auth_chain = DataStore.get_auth_chain.__func__ - get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__ - _get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__ - - get_room_max_stream_ordering = DataStore.get_room_max_stream_ordering.__func__ - - get_forward_extremeties_for_room = ( - DataStore.get_forward_extremeties_for_room.__func__ - ) - _get_forward_extremeties_for_room = ( - EventFederationStore.__dict__["_get_forward_extremeties_for_room"] - ) - - get_all_new_events_stream = DataStore.get_all_new_events_stream.__func__ - - get_federation_out_pos = DataStore.get_federation_out_pos.__func__ - update_federation_out_pos = DataStore.update_federation_out_pos.__func__ + def get_room_min_stream_ordering(self): + return self._backfill_id_gen.get_current_token() def stream_positions(self): result = super(SlavedEventStore, self).stream_positions() @@ -194,84 +72,47 @@ class SlavedEventStore(BaseSlavedStore): result["backfill"] = -self._backfill_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("events") - if stream: - self._stream_id_gen.advance(int(stream["position"])) - - if stream["rows"]: - logger.info("Got %d event rows", len(stream["rows"])) - - for row in stream["rows"]: - self._process_replication_row( - row, backfilled=False, + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "events": + self._stream_id_gen.advance(token) + for row in rows: + self.invalidate_caches_for_event( + token, row.event_id, row.room_id, row.type, row.state_key, + row.redacts, + 
backfilled=False, ) - - stream = result.get("backfill") - if stream: - self._backfill_id_gen.advance(-int(stream["position"])) - for row in stream["rows"]: - self._process_replication_row( - row, backfilled=True, + elif stream_name == "backfill": + self._backfill_id_gen.advance(-token) + for row in rows: + self.invalidate_caches_for_event( + -token, row.event_id, row.room_id, row.type, row.state_key, + row.redacts, + backfilled=True, ) - - stream = result.get("forward_ex_outliers") - if stream: - self._stream_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - event_id = row[1] - self._invalidate_get_event_cache(event_id) - - stream = result.get("backward_ex_outliers") - if stream: - self._backfill_id_gen.advance(-int(stream["position"])) - for row in stream["rows"]: - event_id = row[1] - self._invalidate_get_event_cache(event_id) - - return super(SlavedEventStore, self).process_replication(result) - - def _process_replication_row(self, row, backfilled): - internal = json.loads(row[1]) - event_json = json.loads(row[2]) - event = FrozenEvent(event_json, internal_metadata_dict=internal) - self.invalidate_caches_for_event( - event, backfilled, + return super(SlavedEventStore, self).process_replication_rows( + stream_name, token, rows ) - def invalidate_caches_for_event(self, event, backfilled): - self._invalidate_get_event_cache(event.event_id) + def invalidate_caches_for_event(self, stream_ordering, event_id, room_id, + etype, state_key, redacts, backfilled): + self._invalidate_get_event_cache(event_id) - self.get_latest_event_ids_in_room.invalidate((event.room_id,)) + self.get_latest_event_ids_in_room.invalidate((room_id,)) self.get_unread_event_push_actions_by_room_for_user.invalidate_many( - (event.room_id,) + (room_id,) ) if not backfilled: self._events_stream_cache.entity_has_changed( - event.room_id, event.internal_metadata.stream_ordering + room_id, stream_ordering ) - # self.get_unread_event_push_actions_by_room_for_user.invalidate_many( - 
# (event.room_id,) - # ) + if redacts: + self._invalidate_get_event_cache(redacts) - if event.type == EventTypes.Redaction: - self._invalidate_get_event_cache(event.redacts) - - if event.type == EventTypes.Member: + if etype == EventTypes.Member: self._membership_stream_cache.entity_has_changed( - event.state_key, event.internal_metadata.stream_ordering + state_key, stream_ordering ) - self.get_invited_rooms_for_user.invalidate((event.state_key,)) - - if not event.is_state(): - return - - if backfilled: - return - - if (not event.internal_metadata.is_invite_from_remote() - and event.internal_metadata.is_outlier()): - return + self.get_invited_rooms_for_user.invalidate((state_key,)) diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py new file mode 100644 index 0000000000..0bc4bce5b0 --- /dev/null +++ b/synapse/replication/slave/storage/groups.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker +from synapse.storage import DataStore +from synapse.util.caches.stream_change_cache import StreamChangeCache + + +class SlavedGroupServerStore(BaseSlavedStore): + def __init__(self, db_conn, hs): + super(SlavedGroupServerStore, self).__init__(db_conn, hs) + + self.hs = hs + + self._group_updates_id_gen = SlavedIdTracker( + db_conn, "local_group_updates", "stream_id", + ) + self._group_updates_stream_cache = StreamChangeCache( + "_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(), + ) + + get_groups_changes_for_user = DataStore.get_groups_changes_for_user.__func__ + get_group_stream_token = DataStore.get_group_stream_token.__func__ + get_all_groups_for_user = DataStore.get_all_groups_for_user.__func__ + + def stream_positions(self): + result = super(SlavedGroupServerStore, self).stream_positions() + result["groups"] = self._group_updates_id_gen.get_current_token() + return result + + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "groups": + self._group_updates_id_gen.advance(token) + for row in rows: + self._group_updates_stream_cache.entity_has_changed( + row.user_id, token + ) + + return super(SlavedGroupServerStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 40f6c9a386..cfb9280181 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -39,6 +39,16 @@ class SlavedPresenceStore(BaseSlavedStore): _get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"] get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"] + # XXX: This is a bit broken because we don't persist the accepted list in a + # way that can be replicated. This means that we don't have a way to + # invalidate the cache correctly. 
+ get_presence_list_accepted = PresenceStore.__dict__[ + "get_presence_list_accepted" + ] + get_presence_list_observers_accepted = PresenceStore.__dict__[ + "get_presence_list_observers_accepted" + ] + def get_current_presence_token(self): return self._presence_id_gen.get_current_token() @@ -48,14 +58,14 @@ class SlavedPresenceStore(BaseSlavedStore): result["presence"] = position return result - def process_replication(self, result): - stream = result.get("presence") - if stream: - self._presence_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - position, user_id = row[:2] + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "presence": + self._presence_id_gen.advance(token) + for row in rows: self.presence_stream_cache.entity_has_changed( - user_id, position + row.user_id, token ) - - return super(SlavedPresenceStore, self).process_replication(result) + self._get_presence_for_user.invalidate((row.user_id,)) + return super(SlavedPresenceStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py new file mode 100644 index 0000000000..46c28d4171 --- /dev/null +++ b/synapse/replication/slave/storage/profile.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.storage.profile import ProfileWorkerStore + + +class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore): + pass diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 21ceb0213a..bb2c40b6e3 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,29 +16,15 @@ from .events import SlavedEventStore from ._slaved_id_tracker import SlavedIdTracker -from synapse.storage import DataStore -from synapse.storage.push_rule import PushRuleStore -from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.storage.push_rule import PushRulesWorkerStore -class SlavedPushRuleStore(SlavedEventStore): +class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore): def __init__(self, db_conn, hs): - super(SlavedPushRuleStore, self).__init__(db_conn, hs) self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id", ) - self.push_rules_stream_cache = StreamChangeCache( - "PushRulesStreamChangeCache", - self._push_rules_stream_id_gen.get_current_token(), - ) - - get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"] - get_push_rules_enabled_for_user = ( - PushRuleStore.__dict__["get_push_rules_enabled_for_user"] - ) - have_push_rules_changed_for_user = ( - DataStore.have_push_rules_changed_for_user.__func__ - ) + super(SlavedPushRuleStore, self).__init__(db_conn, hs) def get_push_rules_stream_token(self): return ( @@ -45,23 +32,23 @@ class SlavedPushRuleStore(SlavedEventStore): self._stream_id_gen.get_current_token(), ) + def 
get_max_push_rules_stream_id(self): + return self._push_rules_stream_id_gen.get_current_token() + def stream_positions(self): result = super(SlavedPushRuleStore, self).stream_positions() result["push_rules"] = self._push_rules_stream_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("push_rules") - if stream: - for row in stream["rows"]: - position = row[0] - user_id = row[2] - self.get_push_rules_for_user.invalidate((user_id,)) - self.get_push_rules_enabled_for_user.invalidate((user_id,)) + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "push_rules": + self._push_rules_stream_id_gen.advance(token) + for row in rows: + self.get_push_rules_for_user.invalidate((row.user_id,)) + self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) self.push_rules_stream_cache.entity_has_changed( - user_id, position + row.user_id, token ) - - self._push_rules_stream_id_gen.advance(int(stream["position"])) - - return super(SlavedPushRuleStore, self).process_replication(result) + return super(SlavedPushRuleStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index d88206b3bb..a7cd5a7291 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -16,10 +17,10 @@ from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker -from synapse.storage import DataStore +from synapse.storage.pusher import PusherWorkerStore -class SlavedPusherStore(BaseSlavedStore): +class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): super(SlavedPusherStore, self).__init__(db_conn, hs) @@ -28,25 +29,14 @@ class SlavedPusherStore(BaseSlavedStore): extra_tables=[("deleted_pushers", "stream_id")], ) - get_all_pushers = DataStore.get_all_pushers.__func__ - get_pushers_by = DataStore.get_pushers_by.__func__ - get_pushers_by_app_id_and_pushkey = ( - DataStore.get_pushers_by_app_id_and_pushkey.__func__ - ) - _decode_pushers_rows = DataStore._decode_pushers_rows.__func__ - def stream_positions(self): result = super(SlavedPusherStore, self).stream_positions() result["pushers"] = self._pushers_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("pushers") - if stream: - self._pushers_id_gen.advance(int(stream["position"])) - - stream = result.get("deleted_pushers") - if stream: - self._pushers_id_gen.advance(int(stream["position"])) - - return super(SlavedPusherStore, self).process_replication(result) + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "pushers": + self._pushers_id_gen.advance(token) + return super(SlavedPusherStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index ac9662d399..1647072f65 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -16,9 +17,7 @@ from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker -from synapse.storage import DataStore -from synapse.storage.receipts import ReceiptsStore -from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.storage.receipts import ReceiptsWorkerStore # So, um, we want to borrow a load of functions intended for reading from # a DataStore, but we don't want to take functions that either write to the @@ -29,56 +28,43 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache # the method descriptor on the DataStore and chuck them into our class. -class SlavedReceiptsStore(BaseSlavedStore): +class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): - super(SlavedReceiptsStore, self).__init__(db_conn, hs) - + # We instantiate this first as the ReceiptsWorkerStore constructor + # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = SlavedIdTracker( db_conn, "receipts_linearized", "stream_id" ) - self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() - ) - - get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"] - get_linearized_receipts_for_room = ( - ReceiptsStore.__dict__["get_linearized_receipts_for_room"] - ) - _get_linearized_receipts_for_rooms = ( - ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"] - ) - get_last_receipt_event_id_for_user = ( - ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"] - ) - - get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__ - get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__ + super(SlavedReceiptsStore, self).__init__(db_conn, hs) - get_linearized_receipts_for_rooms = ( - DataStore.get_linearized_receipts_for_rooms.__func__ - ) + def get_max_receipt_stream_id(self): + return self._receipts_id_gen.get_current_token() def stream_positions(self): result 
= super(SlavedReceiptsStore, self).stream_positions() result["receipts"] = self._receipts_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("receipts") - if stream: - self._receipts_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - position, room_id, receipt_type, user_id = row[:4] - self.invalidate_caches_for_receipt(room_id, receipt_type, user_id) - self._receipts_stream_cache.entity_has_changed(room_id, position) - - return super(SlavedReceiptsStore, self).process_replication(result) - def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): self.get_receipts_for_user.invalidate((user_id, receipt_type)) self.get_linearized_receipts_for_room.invalidate_many((room_id,)) self.get_last_receipt_event_id_for_user.invalidate( (user_id, room_id, receipt_type) ) + self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) + self.get_receipts_for_room.invalidate((room_id, receipt_type)) + + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "receipts": + self._receipts_id_gen.advance(token) + for row in rows: + self.invalidate_caches_for_receipt( + row.room_id, row.receipt_type, row.user_id + ) + self._receipts_stream_cache.entity_has_changed(row.room_id, token) + + return super(SlavedReceiptsStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py index e27c7332d2..7323bf0f1e 100644 --- a/synapse/replication/slave/storage/registration.py +++ b/synapse/replication/slave/storage/registration.py @@ -14,20 +14,8 @@ # limitations under the License. 
from ._base import BaseSlavedStore -from synapse.storage import DataStore -from synapse.storage.registration import RegistrationStore +from synapse.storage.registration import RegistrationWorkerStore -class SlavedRegistrationStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedRegistrationStore, self).__init__(db_conn, hs) - - # TODO: use the cached version and invalidate deleted tokens - get_user_by_access_token = RegistrationStore.__dict__[ - "get_user_by_access_token" - ] - - _query_for_auth = DataStore._query_for_auth.__func__ - get_user_by_id = RegistrationStore.__dict__[ - "get_user_by_id" - ] +class SlavedRegistrationStore(RegistrationWorkerStore, BaseSlavedStore): + pass diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 6df9a25ef3..5ae1670157 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -14,41 +14,29 @@ # limitations under the License. from ._base import BaseSlavedStore -from synapse.storage import DataStore -from synapse.storage.room import RoomStore +from synapse.storage.room import RoomWorkerStore from ._slaved_id_tracker import SlavedIdTracker -class RoomStore(BaseSlavedStore): +class RoomStore(RoomWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): super(RoomStore, self).__init__(db_conn, hs) self._public_room_id_gen = SlavedIdTracker( db_conn, "public_room_list_stream", "stream_id" ) - get_public_room_ids = DataStore.get_public_room_ids.__func__ - get_current_public_room_stream_id = ( - DataStore.get_current_public_room_stream_id.__func__ - ) - get_public_room_ids_at_stream_id = ( - RoomStore.__dict__["get_public_room_ids_at_stream_id"] - ) - get_public_room_ids_at_stream_id_txn = ( - DataStore.get_public_room_ids_at_stream_id_txn.__func__ - ) - get_published_at_stream_id_txn = ( - DataStore.get_published_at_stream_id_txn.__func__ - ) - get_public_room_changes = DataStore.get_public_room_changes.__func__ + 
def get_current_public_room_stream_id(self): + return self._public_room_id_gen.get_current_token() def stream_positions(self): result = super(RoomStore, self).stream_positions() result["public_rooms"] = self._public_room_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("public_rooms") - if stream: - self._public_room_id_gen.advance(int(stream["position"])) + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "public_rooms": + self._public_room_id_gen.advance(token) - return super(RoomStore, self).process_replication(result) + return super(RoomStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py new file mode 100644 index 0000000000..81c2ea7ee9 --- /dev/null +++ b/synapse/replication/tcp/__init__.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module implements the TCP replication protocol used by synapse to +communicate between the master process and its workers (when they're enabled). 
+ +Further details can be found in docs/tcp_replication.rst + + +Structure of the module: + * client.py - the client classes used for workers to connect to master + * command.py - the definitions of all the valid commands + * protocol.py - contains bot the client and server protocol implementations, + these should not be used directly + * resource.py - the server classes that accepts and handle client connections + * streams.py - the definitons of all the valid streams + +""" diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py new file mode 100644 index 0000000000..6d2513c4e2 --- /dev/null +++ b/synapse/replication/tcp/client.py @@ -0,0 +1,203 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A replication client for use by synapse workers. +""" + +from twisted.internet import reactor, defer +from twisted.internet.protocol import ReconnectingClientFactory + +from .commands import ( + FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand, + UserIpCommand, +) +from .protocol import ClientReplicationStreamProtocol + +import logging + +logger = logging.getLogger(__name__) + + +class ReplicationClientFactory(ReconnectingClientFactory): + """Factory for building connections to the master. Will reconnect if the + connection is lost. + + Accepts a handler that will be called when new data is available or data + is required. 
+ """ + maxDelay = 5 # Try at least once every N seconds + + def __init__(self, hs, client_name, handler): + self.client_name = client_name + self.handler = handler + self.server_name = hs.config.server_name + self._clock = hs.get_clock() # As self.clock is defined in super class + + reactor.addSystemEventTrigger("before", "shutdown", self.stopTrying) + + def startedConnecting(self, connector): + logger.info("Connecting to replication: %r", connector.getDestination()) + + def buildProtocol(self, addr): + logger.info("Connected to replication: %r", addr) + self.resetDelay() + return ClientReplicationStreamProtocol( + self.client_name, self.server_name, self._clock, self.handler + ) + + def clientConnectionLost(self, connector, reason): + logger.error("Lost replication conn: %r", reason) + ReconnectingClientFactory.clientConnectionLost(self, connector, reason) + + def clientConnectionFailed(self, connector, reason): + logger.error("Failed to connect to replication: %r", reason) + ReconnectingClientFactory.clientConnectionFailed( + self, connector, reason + ) + + +class ReplicationClientHandler(object): + """A base handler that can be passed to the ReplicationClientFactory. + + By default proxies incoming replication data to the SlaveStore. + """ + def __init__(self, store): + self.store = store + + # The current connection. None if we are currently (re)connecting + self.connection = None + + # Any pending commands to be sent once a new connection has been + # established + self.pending_commands = [] + + # Map from string -> deferred, to wake up when receiveing a SYNC with + # the given string. + # Used for tests. + self.awaiting_syncs = {} + + def start_replication(self, hs): + """Helper method to start a replication connection to the remote server + using TCP. 
+ """ + client_name = hs.config.worker_name + factory = ReplicationClientFactory(hs, client_name, self) + host = hs.config.worker_replication_host + port = hs.config.worker_replication_port + reactor.connectTCP(host, port, factory) + + def on_rdata(self, stream_name, token, rows): + """Called when we get new replication data. By default this just pokes + the slave store. + + Can be overriden in subclasses to handle more. + """ + logger.info("Received rdata %s -> %s", stream_name, token) + self.store.process_replication_rows(stream_name, token, rows) + + def on_position(self, stream_name, token): + """Called when we get new position data. By default this just pokes + the slave store. + + Can be overriden in subclasses to handle more. + """ + self.store.process_replication_rows(stream_name, token, []) + + def on_sync(self, data): + """When we received a SYNC we wake up any deferreds that were waiting + for the sync with the given data. + + Used by tests. + """ + d = self.awaiting_syncs.pop(data, None) + if d: + d.callback(data) + + def get_streams_to_replicate(self): + """Called when a new connection has been established and we need to + subscribe to streams. + + Returns a dictionary of stream name to token. + """ + args = self.store.stream_positions() + user_account_data = args.pop("user_account_data", None) + room_account_data = args.pop("room_account_data", None) + if user_account_data: + args["account_data"] = user_account_data + elif room_account_data: + args["account_data"] = room_account_data + return args + + def get_currently_syncing_users(self): + """Get the list of currently syncing users (if any). This is called + when a connection has been established and we need to send the + currently syncing users. (Overriden by the synchrotron's only) + """ + return [] + + def send_command(self, cmd): + """Send a command to master (when we get establish a connection if we + don't have one already.) 
+ """ + if self.connection: + self.connection.send_command(cmd) + else: + logger.warn("Queuing command as not connected: %r", cmd.NAME) + self.pending_commands.append(cmd) + + def send_federation_ack(self, token): + """Ack data for the federation stream. This allows the master to drop + data stored purely in memory. + """ + self.send_command(FederationAckCommand(token)) + + def send_user_sync(self, user_id, is_syncing, last_sync_ms): + """Poke the master that a user has started/stopped syncing. + """ + self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms)) + + def send_remove_pusher(self, app_id, push_key, user_id): + """Poke the master to remove a pusher for a user + """ + cmd = RemovePusherCommand(app_id, push_key, user_id) + self.send_command(cmd) + + def send_invalidate_cache(self, cache_func, keys): + """Poke the master to invalidate a cache. + """ + cmd = InvalidateCacheCommand(cache_func.__name__, keys) + self.send_command(cmd) + + def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen): + """Tell the master that the user made a request. + """ + cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen) + self.send_command(cmd) + + def await_sync(self, data): + """Returns a deferred that is resolved when we receive a SYNC command + with given data. + + Used by tests. + """ + return self.awaiting_syncs.setdefault(data, defer.Deferred()) + + def update_connection(self, connection): + """Called when a connection has been established (or lost with None). 
+ """ + self.connection = connection + if connection: + for cmd in self.pending_commands: + connection.send_command(cmd) + self.pending_commands = [] diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py new file mode 100644 index 0000000000..12aac3cc6b --- /dev/null +++ b/synapse/replication/tcp/commands.py @@ -0,0 +1,386 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Defines the various valid commands + +The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are +allowed to be sent by which side. +""" + +import logging +import simplejson + + +logger = logging.getLogger(__name__) + +_json_encoder = simplejson.JSONEncoder(namedtuple_as_object=False) + + +class Command(object): + """The base command class. + + All subclasses must set the NAME variable which equates to the name of the + command on the wire. + + A full command line on the wire is constructed from `NAME + " " + to_line()` + + The default implementation creates a command of form `<NAME> <data>` + """ + NAME = None + + def __init__(self, data): + self.data = data + + @classmethod + def from_line(cls, line): + """Deserialises a line from the wire into this command. `line` does not + include the command. + """ + return cls(line) + + def to_line(self): + """Serialises the comamnd for the wire. Does not include the command + prefix. 
+ """ + return self.data + + +class ServerCommand(Command): + """Sent by the server on new connection and includes the server_name. + + Format:: + + SERVER <server_name> + """ + NAME = "SERVER" + + +class RdataCommand(Command): + """Sent by server when a subscribed stream has an update. + + Format:: + + RDATA <stream_name> <token> <row_json> + + The `<token>` may either be a numeric stream id OR "batch". The latter case + is used to support sending multiple updates with the same stream ID. This + is done by sending an RDATA for each row, with all but the last RDATA having + a token of "batch" and the last having the final stream ID. + + The client should batch all incoming RDATA with a token of "batch" (per + stream_name) until it sees an RDATA with a numeric stream ID. + + `<token>` of "batch" maps to the instance variable `token` being None. + + An example of a batched series of RDATA:: + + RDATA presence batch ["@foo:example.com", "online", ...] + RDATA presence batch ["@bar:example.com", "online", ...] + RDATA presence 59 ["@baz:example.com", "online", ...] + """ + NAME = "RDATA" + + def __init__(self, stream_name, token, row): + self.stream_name = stream_name + self.token = token + self.row = row + + @classmethod + def from_line(cls, line): + stream_name, token, row_json = line.split(" ", 2) + return cls( + stream_name, + None if token == "batch" else int(token), + simplejson.loads(row_json) + ) + + def to_line(self): + return " ".join(( + self.stream_name, + str(self.token) if self.token is not None else "batch", + _json_encoder.encode(self.row), + )) + + +class PositionCommand(Command): + """Sent by the client to tell the client the stream postition without + needing to send an RDATA. 
+ """ + NAME = "POSITION" + + def __init__(self, stream_name, token): + self.stream_name = stream_name + self.token = token + + @classmethod + def from_line(cls, line): + stream_name, token = line.split(" ", 1) + return cls(stream_name, int(token)) + + def to_line(self): + return " ".join((self.stream_name, str(self.token),)) + + +class ErrorCommand(Command): + """Sent by either side if there was an ERROR. The data is a string describing + the error. + """ + NAME = "ERROR" + + +class PingCommand(Command): + """Sent by either side as a keep alive. The data is arbitary (often timestamp) + """ + NAME = "PING" + + +class NameCommand(Command): + """Sent by client to inform the server of the client's identity. The data + is the name + """ + NAME = "NAME" + + +class ReplicateCommand(Command): + """Sent by the client to subscribe to the stream. + + Format:: + + REPLICATE <stream_name> <token> + + Where <token> may be either: + * a numeric stream_id to stream updates from + * "NOW" to stream all subsequent updates. + + The <stream_name> can be "ALL" to subscribe to all known streams, in which + case the <token> must be set to "NOW", i.e.:: + + REPLICATE ALL NOW + """ + NAME = "REPLICATE" + + def __init__(self, stream_name, token): + self.stream_name = stream_name + self.token = token + + @classmethod + def from_line(cls, line): + stream_name, token = line.split(" ", 1) + if token in ("NOW", "now"): + token = "NOW" + else: + token = int(token) + return cls(stream_name, token) + + def to_line(self): + return " ".join((self.stream_name, str(self.token),)) + + +class UserSyncCommand(Command): + """Sent by the client to inform the server that a user has started or + stopped syncing. Used to calculate presence on the master. + + Includes a timestamp of when the last user sync was. 
+ + Format:: + + USER_SYNC <user_id> <state> <last_sync_ms> + + Where <state> is either "start" or "stop" + """ + NAME = "USER_SYNC" + + def __init__(self, user_id, is_syncing, last_sync_ms): + self.user_id = user_id + self.is_syncing = is_syncing + self.last_sync_ms = last_sync_ms + + @classmethod + def from_line(cls, line): + user_id, state, last_sync_ms = line.split(" ", 2) + + if state not in ("start", "end"): + raise Exception("Invalid USER_SYNC state %r" % (state,)) + + return cls(user_id, state == "start", int(last_sync_ms)) + + def to_line(self): + return " ".join(( + self.user_id, "start" if self.is_syncing else "end", str(self.last_sync_ms), + )) + + +class FederationAckCommand(Command): + """Sent by the client when it has processed up to a given point in the + federation stream. This allows the master to drop in-memory caches of the + federation stream. + + This must only be sent from one worker (i.e. the one sending federation) + + Format:: + + FEDERATION_ACK <token> + """ + NAME = "FEDERATION_ACK" + + def __init__(self, token): + self.token = token + + @classmethod + def from_line(cls, line): + return cls(int(line)) + + def to_line(self): + return str(self.token) + + +class SyncCommand(Command): + """Used for testing. The client protocol implementation allows waiting + on a SYNC command with a specified data. + """ + NAME = "SYNC" + + +class RemovePusherCommand(Command): + """Sent by the client to request the master remove the given pusher. 
+ + Format:: + + REMOVE_PUSHER <app_id> <push_key> <user_id> + """ + NAME = "REMOVE_PUSHER" + + def __init__(self, app_id, push_key, user_id): + self.user_id = user_id + self.app_id = app_id + self.push_key = push_key + + @classmethod + def from_line(cls, line): + app_id, push_key, user_id = line.split(" ", 2) + + return cls(app_id, push_key, user_id) + + def to_line(self): + return " ".join((self.app_id, self.push_key, self.user_id)) + + +class InvalidateCacheCommand(Command): + """Sent by the client to invalidate an upstream cache. + + THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE + NOT DISASTROUS IF WE DROP ON THE FLOOR. + + Mainly used to invalidate destination retry timing caches. + + Format:: + + INVALIDATE_CACHE <cache_func> <keys_json> + + Where <keys_json> is a json list. + """ + NAME = "INVALIDATE_CACHE" + + def __init__(self, cache_func, keys): + self.cache_func = cache_func + self.keys = keys + + @classmethod + def from_line(cls, line): + cache_func, keys_json = line.split(" ", 1) + + return cls(cache_func, simplejson.loads(keys_json)) + + def to_line(self): + return " ".join(( + self.cache_func, _json_encoder.encode(self.keys), + )) + + +class UserIpCommand(Command): + """Sent periodically when a worker sees activity from a client. 
+ + Format:: + + USER_IP <user_id>, <access_token>, <ip>, <device_id>, <last_seen>, <user_agent> + """ + NAME = "USER_IP" + + def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen): + self.user_id = user_id + self.access_token = access_token + self.ip = ip + self.user_agent = user_agent + self.device_id = device_id + self.last_seen = last_seen + + @classmethod + def from_line(cls, line): + user_id, jsn = line.split(" ", 1) + + access_token, ip, user_agent, device_id, last_seen = simplejson.loads(jsn) + + return cls( + user_id, access_token, ip, user_agent, device_id, last_seen + ) + + def to_line(self): + return self.user_id + " " + _json_encoder.encode(( + self.access_token, self.ip, self.user_agent, self.device_id, + self.last_seen, + )) + + +# Map of command name to command type. +COMMAND_MAP = { + cmd.NAME: cmd + for cmd in ( + ServerCommand, + RdataCommand, + PositionCommand, + ErrorCommand, + PingCommand, + NameCommand, + ReplicateCommand, + UserSyncCommand, + FederationAckCommand, + SyncCommand, + RemovePusherCommand, + InvalidateCacheCommand, + UserIpCommand, + ) +} + +# The commands the server is allowed to send +VALID_SERVER_COMMANDS = ( + ServerCommand.NAME, + RdataCommand.NAME, + PositionCommand.NAME, + ErrorCommand.NAME, + PingCommand.NAME, + SyncCommand.NAME, +) + +# The commands the client is allowed to send +VALID_CLIENT_COMMANDS = ( + NameCommand.NAME, + ReplicateCommand.NAME, + PingCommand.NAME, + UserSyncCommand.NAME, + FederationAckCommand.NAME, + RemovePusherCommand.NAME, + InvalidateCacheCommand.NAME, + UserIpCommand.NAME, + ErrorCommand.NAME, +) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py new file mode 100644 index 0000000000..d7d38464b2 --- /dev/null +++ b/synapse/replication/tcp/protocol.py @@ -0,0 +1,655 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file 
except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This module contains the implementation of both the client and server +protocols. + +The basic structure of the protocol is line based, where the initial word of +each line specifies the command. The rest of the line is parsed based on the +command. For example, the `RDATA` command is defined as:: + + RDATA <stream_name> <token> <row_json> + +(Note that `<row_json>` may contains spaces, but cannot contain newlines.) + +Blank lines are ignored. + +# Example + +An example iteraction is shown below. Each line is prefixed with '>' or '<' to +indicate which side is sending, these are *not* included on the wire:: + + * connection established * + > SERVER localhost:8823 + > PING 1490197665618 + < NAME synapse.app.appservice + < PING 1490197665618 + < REPLICATE events 1 + < REPLICATE backfill 1 + < REPLICATE caches 1 + > POSITION events 1 + > POSITION backfill 1 + > POSITION caches 1 + > RDATA caches 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513] + > RDATA events 14 ["$149019767112vOHxz:localhost:8823", + "!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null] + < PING 1490197675618 + > ERROR server stopping + * connection closed by server * +""" + +from twisted.internet import defer +from twisted.protocols.basic import LineOnlyReceiver +from twisted.python.failure import Failure + +from .commands import ( + COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS, + ErrorCommand, ServerCommand, RdataCommand, PositionCommand, PingCommand, + NameCommand, ReplicateCommand, 
UserSyncCommand, SyncCommand, +) +from .streams import STREAMS_MAP + +from synapse.util.stringutils import random_string +from synapse.metrics.metric import CounterMetric + +import logging +import synapse.metrics +import struct +import fcntl + + +metrics = synapse.metrics.get_metrics_for(__name__) + +connection_close_counter = metrics.register_counter( + "close_reason", labels=["reason_type"], +) + + +# A list of all connected protocols. This allows us to send metrics about the +# connections. +connected_connections = [] + + +logger = logging.getLogger(__name__) + + +PING_TIME = 5000 +PING_TIMEOUT_MULTIPLIER = 5 +PING_TIMEOUT_MS = PING_TIME * PING_TIMEOUT_MULTIPLIER + + +class ConnectionStates(object): + CONNECTING = "connecting" + ESTABLISHED = "established" + PAUSED = "paused" + CLOSED = "closed" + + +class BaseReplicationStreamProtocol(LineOnlyReceiver): + """Base replication protocol shared between client and server. + + Reads lines (ignoring blank ones) and parses them into command classes, + asserting that they are valid for the given direction, i.e. server commands + are only sent by the server. + + On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed + command. + + It also sends `PING` periodically, and correctly times out remote connections + (if they send a `PING` command) + """ + delimiter = b'\n' + + VALID_INBOUND_COMMANDS = [] # Valid commands we expect to receive + VALID_OUTBOUND_COMMANDS = [] # Valid commans we can send + + max_line_buffer = 10000 + + def __init__(self, clock): + self.clock = clock + + self.last_received_command = self.clock.time_msec() + self.last_sent_command = 0 + self.time_we_closed = None # When we requested the connection be closed + + self.received_ping = False # Have we reecived a ping from the other side + + self.state = ConnectionStates.CONNECTING + + self.name = "anon" # The name sent by a client. + self.conn_id = random_string(5) # To dedupe in case of name clashes. 
+ + # List of pending commands to send once we've established the connection + self.pending_commands = [] + + # The LoopingCall for sending pings. + self._send_ping_loop = None + + self.inbound_commands_counter = CounterMetric( + "inbound_commands", labels=["command"], + ) + self.outbound_commands_counter = CounterMetric( + "outbound_commands", labels=["command"], + ) + + def connectionMade(self): + logger.info("[%s] Connection established", self.id()) + + self.state = ConnectionStates.ESTABLISHED + + connected_connections.append(self) # Register connection for metrics + + self.transport.registerProducer(self, True) # For the *Producing callbacks + + self._send_pending_commands() + + # Starts sending pings + self._send_ping_loop = self.clock.looping_call(self.send_ping, 5000) + + # Always send the initial PING so that the other side knows that they + # can time us out. + self.send_command(PingCommand(self.clock.time_msec())) + + def send_ping(self): + """Periodically sends a ping and checks if we should close the connection + due to the other side timing out. + """ + now = self.clock.time_msec() + + if self.time_we_closed: + if now - self.time_we_closed > PING_TIMEOUT_MS: + logger.info( + "[%s] Failed to close connection gracefully, aborting", self.id() + ) + self.transport.abortConnection() + else: + if now - self.last_sent_command >= PING_TIME: + self.send_command(PingCommand(now)) + + if self.received_ping and now - self.last_received_command > PING_TIMEOUT_MS: + logger.info( + "[%s] Connection hasn't received command in %r ms. 
Closing.", + self.id(), now - self.last_received_command + ) + self.send_error("ping timeout") + + def lineReceived(self, line): + """Called when we've received a line + """ + if line.strip() == "": + # Ignore blank lines + return + + line = line.decode("utf-8") + cmd_name, rest_of_line = line.split(" ", 1) + + if cmd_name not in self.VALID_INBOUND_COMMANDS: + logger.error("[%s] invalid command %s", self.id(), cmd_name) + self.send_error("invalid command: %s", cmd_name) + return + + self.last_received_command = self.clock.time_msec() + + self.inbound_commands_counter.inc(cmd_name) + + cmd_cls = COMMAND_MAP[cmd_name] + try: + cmd = cmd_cls.from_line(rest_of_line) + except Exception as e: + logger.exception( + "[%s] failed to parse line %r: %r", self.id(), cmd_name, rest_of_line + ) + self.send_error( + "failed to parse line for %r: %r (%r):" % (cmd_name, e, rest_of_line) + ) + return + + # Now lets try and call on_<CMD_NAME> function + try: + getattr(self, "on_%s" % (cmd_name,))(cmd) + except Exception: + logger.exception("[%s] Failed to handle line: %r", self.id(), line) + + def close(self): + logger.warn("[%s] Closing connection", self.id()) + self.time_we_closed = self.clock.time_msec() + self.transport.loseConnection() + self.on_connection_closed() + + def send_error(self, error_string, *args): + """Send an error to remote and close the connection. + """ + self.send_command(ErrorCommand(error_string % args)) + self.close() + + def send_command(self, cmd, do_buffer=True): + """Send a command if connection has been established. + + Args: + cmd (Command) + do_buffer (bool): Whether to buffer the message or always attempt + to send the command. This is mostly used to send an error + message if we're about to close the connection due our buffers + becoming full. 
+ """ + if self.state == ConnectionStates.CLOSED: + logger.debug("[%s] Not sending, connection closed", self.id()) + return + + if do_buffer and self.state != ConnectionStates.ESTABLISHED: + self._queue_command(cmd) + return + + self.outbound_commands_counter.inc(cmd.NAME) + + string = "%s %s" % (cmd.NAME, cmd.to_line(),) + if "\n" in string: + raise Exception("Unexpected newline in command: %r", string) + + self.sendLine(string.encode("utf-8")) + + self.last_sent_command = self.clock.time_msec() + + def _queue_command(self, cmd): + """Queue the command until the connection is ready to write to again. + """ + logger.debug("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd) + self.pending_commands.append(cmd) + + if len(self.pending_commands) > self.max_line_buffer: + # The other side is failing to keep up and out buffers are becoming + # full, so lets close the connection. + # XXX: should we squawk more loudly? + logger.error("[%s] Remote failed to keep up", self.id()) + self.send_command(ErrorCommand("Failed to keep up"), do_buffer=False) + self.close() + + def _send_pending_commands(self): + """Send any queued commandes + """ + pending = self.pending_commands + self.pending_commands = [] + for cmd in pending: + self.send_command(cmd) + + def on_PING(self, line): + self.received_ping = True + + def on_ERROR(self, cmd): + logger.error("[%s] Remote reported error: %r", self.id(), cmd.data) + + def pauseProducing(self): + """This is called when both the kernel send buffer and the twisted + tcp connection send buffers have become full. + + We don't actually have any control over those sizes, so we buffer some + commands ourselves before knifing the connection due to the remote + failing to keep up. + """ + logger.info("[%s] Pause producing", self.id()) + self.state = ConnectionStates.PAUSED + + def resumeProducing(self): + """The remote has caught up after we started buffering! 
+ """ + logger.info("[%s] Resume producing", self.id()) + self.state = ConnectionStates.ESTABLISHED + self._send_pending_commands() + + def stopProducing(self): + """We're never going to send any more data (normally because either + we or the remote has closed the connection) + """ + logger.info("[%s] Stop producing", self.id()) + self.on_connection_closed() + + def connectionLost(self, reason): + logger.info("[%s] Replication connection closed: %r", self.id(), reason) + if isinstance(reason, Failure): + connection_close_counter.inc(reason.type.__name__) + else: + connection_close_counter.inc(reason.__class__.__name__) + + try: + # Remove us from list of connections to be monitored + connected_connections.remove(self) + except ValueError: + pass + + # Stop the looping call sending pings. + if self._send_ping_loop and self._send_ping_loop.running: + self._send_ping_loop.stop() + + self.on_connection_closed() + + def on_connection_closed(self): + logger.info("[%s] Connection was closed", self.id()) + + self.state = ConnectionStates.CLOSED + self.pending_commands = [] + + if self.transport: + self.transport.unregisterProducer() + + def __str__(self): + return "ReplicationConnection<name=%s,conn_id=%s,addr=%s>" % ( + self.name, self.conn_id, self.addr, + ) + + def id(self): + return "%s-%s" % (self.name, self.conn_id) + + +class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): + VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS + VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS + + def __init__(self, server_name, clock, streamer, addr): + BaseReplicationStreamProtocol.__init__(self, clock) # Old style class + + self.server_name = server_name + self.streamer = streamer + self.addr = addr + + # The streams the client has subscribed to and is up to date with + self.replication_streams = set() + + # The streams the client is currently subscribing to. 
+ self.connecting_streams = set() + + # Map from stream name to list of updates to send once we've finished + # subscribing the client to the stream. + self.pending_rdata = {} + + def connectionMade(self): + self.send_command(ServerCommand(self.server_name)) + BaseReplicationStreamProtocol.connectionMade(self) + self.streamer.new_connection(self) + + def on_NAME(self, cmd): + logger.info("[%s] Renamed to %r", self.id(), cmd.data) + self.name = cmd.data + + def on_USER_SYNC(self, cmd): + self.streamer.on_user_sync( + self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms, + ) + + def on_REPLICATE(self, cmd): + stream_name = cmd.stream_name + token = cmd.token + + if stream_name == "ALL": + # Subscribe to all streams we're publishing to. + for stream in self.streamer.streams_by_name.iterkeys(): + self.subscribe_to_stream(stream, token) + else: + self.subscribe_to_stream(stream_name, token) + + def on_FEDERATION_ACK(self, cmd): + self.streamer.federation_ack(cmd.token) + + def on_REMOVE_PUSHER(self, cmd): + self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) + + def on_INVALIDATE_CACHE(self, cmd): + self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) + + def on_USER_IP(self, cmd): + self.streamer.on_user_ip( + cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id, + cmd.last_seen, + ) + + @defer.inlineCallbacks + def subscribe_to_stream(self, stream_name, token): + """Subscribe the remote to a streams. + + This invloves checking if they've missed anything and sending those + updates down if they have. During that time new updates for the stream + are queued and sent once we've sent down any missed updates. 
+ """ + self.replication_streams.discard(stream_name) + self.connecting_streams.add(stream_name) + + try: + # Get missing updates + updates, current_token = yield self.streamer.get_stream_updates( + stream_name, token, + ) + + # Send all the missing updates + for update in updates: + token, row = update[0], update[1] + self.send_command(RdataCommand(stream_name, token, row)) + + # We send a POSITION command to ensure that they have an up to + # date token (especially useful if we didn't send any updates + # above) + self.send_command(PositionCommand(stream_name, current_token)) + + # Now we can send any updates that came in while we were subscribing + pending_rdata = self.pending_rdata.pop(stream_name, []) + for token, update in pending_rdata: + # Only send updates newer than the current token + if token > current_token: + self.send_command(RdataCommand(stream_name, token, update)) + + # They're now fully subscribed + self.replication_streams.add(stream_name) + except Exception as e: + logger.exception("[%s] Failed to handle REPLICATE command", self.id()) + self.send_error("failed to handle replicate: %r", e) + finally: + self.connecting_streams.discard(stream_name) + + def stream_update(self, stream_name, token, data): + """Called when a new update is available to stream to clients. 
+ + We need to check if the client is interested in the stream or not + """ + if stream_name in self.replication_streams: + # The client is subscribed to the stream + self.send_command(RdataCommand(stream_name, token, data)) + elif stream_name in self.connecting_streams: + # The client is being subscribed to the stream + logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token) + self.pending_rdata.setdefault(stream_name, []).append((token, data)) + else: + # The client isn't subscribed + logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token) + + def send_sync(self, data): + self.send_command(SyncCommand(data)) + + def on_connection_closed(self): + BaseReplicationStreamProtocol.on_connection_closed(self) + self.streamer.lost_connection(self) + + +class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): + VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS + VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS + + def __init__(self, client_name, server_name, clock, handler): + BaseReplicationStreamProtocol.__init__(self, clock) + + self.client_name = client_name + self.server_name = server_name + self.handler = handler + + # Map of stream to batched updates. See RdataCommand for info on how + # batching works. 
+ self.pending_batches = {} + + def connectionMade(self): + self.send_command(NameCommand(self.client_name)) + BaseReplicationStreamProtocol.connectionMade(self) + + # Once we've connected subscribe to the necessary streams + for stream_name, token in self.handler.get_streams_to_replicate().iteritems(): + self.replicate(stream_name, token) + + # Tell the server if we have any users currently syncing (should only + # happen on synchrotrons) + currently_syncing = self.handler.get_currently_syncing_users() + now = self.clock.time_msec() + for user_id in currently_syncing: + self.send_command(UserSyncCommand(user_id, True, now)) + + # We've now finished connecting to so inform the client handler + self.handler.update_connection(self) + + def on_SERVER(self, cmd): + if cmd.data != self.server_name: + logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) + self.send_error("Wrong remote") + + def on_RDATA(self, cmd): + stream_name = cmd.stream_name + inbound_rdata_count.inc(stream_name) + + try: + row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row) + except Exception: + logger.exception( + "[%s] Failed to parse RDATA: %r %r", + self.id(), stream_name, cmd.row + ) + raise + + if cmd.token is None: + # I.e. this is part of a batch of updates for this stream. 
Batch + # until we get an update for the stream with a non None token + self.pending_batches.setdefault(stream_name, []).append(row) + else: + # Check if this is the last of a batch of updates + rows = self.pending_batches.pop(stream_name, []) + rows.append(row) + + self.handler.on_rdata(stream_name, cmd.token, rows) + + def on_POSITION(self, cmd): + self.handler.on_position(cmd.stream_name, cmd.token) + + def on_SYNC(self, cmd): + self.handler.on_sync(cmd.data) + + def replicate(self, stream_name, token): + """Send the subscription request to the server + """ + if stream_name not in STREAMS_MAP: + raise Exception("Invalid stream name %r" % (stream_name,)) + + logger.info( + "[%s] Subscribing to replication stream: %r from %r", + self.id(), stream_name, token + ) + + self.send_command(ReplicateCommand(stream_name, token)) + + def on_connection_closed(self): + BaseReplicationStreamProtocol.on_connection_closed(self) + self.handler.update_connection(None) + + +# The following simply registers metrics for the replication connections + +metrics.register_callback( + "pending_commands", + lambda: { + (p.name, p.conn_id): len(p.pending_commands) + for p in connected_connections + }, + labels=["name", "conn_id"], +) + + +def transport_buffer_size(protocol): + if protocol.transport: + size = len(protocol.transport.dataBuffer) + protocol.transport._tempDataLen + return size + return 0 + + +metrics.register_callback( + "transport_send_buffer", + lambda: { + (p.name, p.conn_id): transport_buffer_size(p) + for p in connected_connections + }, + labels=["name", "conn_id"], +) + + +def transport_kernel_read_buffer_size(protocol, read=True): + SIOCINQ = 0x541B + SIOCOUTQ = 0x5411 + + if protocol.transport: + fileno = protocol.transport.getHandle().fileno() + if read: + op = SIOCINQ + else: + op = SIOCOUTQ + size = struct.unpack("I", fcntl.ioctl(fileno, op, '\0\0\0\0'))[0] + return size + return 0 + + +metrics.register_callback( + "transport_kernel_send_buffer", + lambda: { + 
(p.name, p.conn_id): transport_kernel_read_buffer_size(p, False) + for p in connected_connections + }, + labels=["name", "conn_id"], +) + + +metrics.register_callback( + "transport_kernel_read_buffer", + lambda: { + (p.name, p.conn_id): transport_kernel_read_buffer_size(p, True) + for p in connected_connections + }, + labels=["name", "conn_id"], +) + + +metrics.register_callback( + "inbound_commands", + lambda: { + (k[0], p.name, p.conn_id): count + for p in connected_connections + for k, count in p.inbound_commands_counter.counts.iteritems() + }, + labels=["command", "name", "conn_id"], +) + +metrics.register_callback( + "outbound_commands", + lambda: { + (k[0], p.name, p.conn_id): count + for p in connected_connections + for k, count in p.outbound_commands_counter.counts.iteritems() + }, + labels=["command", "name", "conn_id"], +) + +# number of updates received for each RDATA stream +inbound_rdata_count = metrics.register_counter( + "inbound_rdata_count", + labels=["stream_name"], +) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py new file mode 100644 index 0000000000..a41af4fd6c --- /dev/null +++ b/synapse/replication/tcp/resource.py @@ -0,0 +1,307 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The server side of the replication stream. 
+""" + +from twisted.internet import defer, reactor +from twisted.internet.protocol import Factory + +from .streams import STREAMS_MAP, FederationStream +from .protocol import ServerReplicationStreamProtocol + +from synapse.util.metrics import Measure, measure_func + +import logging +import synapse.metrics + + +metrics = synapse.metrics.get_metrics_for(__name__) +stream_updates_counter = metrics.register_counter( + "stream_updates", labels=["stream_name"] +) +user_sync_counter = metrics.register_counter("user_sync") +federation_ack_counter = metrics.register_counter("federation_ack") +remove_pusher_counter = metrics.register_counter("remove_pusher") +invalidate_cache_counter = metrics.register_counter("invalidate_cache") +user_ip_cache_counter = metrics.register_counter("user_ip_cache") + +logger = logging.getLogger(__name__) + + +class ReplicationStreamProtocolFactory(Factory): + """Factory for new replication connections. + """ + def __init__(self, hs): + self.streamer = ReplicationStreamer(hs) + self.clock = hs.get_clock() + self.server_name = hs.config.server_name + + def buildProtocol(self, addr): + return ServerReplicationStreamProtocol( + self.server_name, + self.clock, + self.streamer, + addr + ) + + +class ReplicationStreamer(object): + """Handles replication connections. + + This needs to be poked when new replication data may be available. When new + data is available it will propagate to all connected clients. + """ + + def __init__(self, hs): + self.store = hs.get_datastore() + self.presence_handler = hs.get_presence_handler() + self.clock = hs.get_clock() + self.notifier = hs.get_notifier() + + # Current connections. + self.connections = [] + + metrics.register_callback("total_connections", lambda: len(self.connections)) + + # List of streams that clients can subscribe to. + # We only support federation stream if federation sending hase been + # disabled on the master. 
+ self.streams = [ + stream(hs) for stream in STREAMS_MAP.itervalues() + if stream != FederationStream or not hs.config.send_federation + ] + + self.streams_by_name = {stream.NAME: stream for stream in self.streams} + + metrics.register_callback( + "connections_per_stream", + lambda: { + (stream_name,): len([ + conn for conn in self.connections + if stream_name in conn.replication_streams + ]) + for stream_name in self.streams_by_name + }, + labels=["stream_name"], + ) + + self.federation_sender = None + if not hs.config.send_federation: + self.federation_sender = hs.get_federation_sender() + + self.notifier.add_replication_callback(self.on_notifier_poke) + + # Keeps track of whether we are currently checking for updates + self.is_looping = False + self.pending_updates = False + + reactor.addSystemEventTrigger("before", "shutdown", self.on_shutdown) + + def on_shutdown(self): + # close all connections on shutdown + for conn in self.connections: + conn.send_error("server shutting down") + + @defer.inlineCallbacks + def on_notifier_poke(self): + """Checks if there is actually any new data and sends it to the + connections if there are. + + This should get called each time new data is available, even if it + is currently being executed, so that nothing gets missed + """ + if not self.connections: + # Don't bother if nothing is listening. We still need to advance + # the stream tokens otherwise they'll fall beihind forever + for stream in self.streams: + stream.discard_updates_and_advance() + return + + # If we're in the process of checking for new updates, mark that fact + # and return + if self.is_looping: + logger.debug("Noitifier poke loop already running") + self.pending_updates = True + return + + self.pending_updates = True + self.is_looping = True + + try: + # Keep looping while there have been pokes about potential updates. + # This protects against the race where a stream we already checked + # gets an update while we're handling other streams. 
+ while self.pending_updates: + self.pending_updates = False + + with Measure(self.clock, "repl.stream.get_updates"): + # First we tell the streams that they should update their + # current tokens. + for stream in self.streams: + stream.advance_current_token() + + for stream in self.streams: + if stream.last_token == stream.upto_token: + continue + + logger.debug( + "Getting stream: %s: %s -> %s", + stream.NAME, stream.last_token, stream.upto_token + ) + try: + updates, current_token = yield stream.get_updates() + except Exception: + logger.info("Failed to handle stream %s", stream.NAME) + raise + + logger.debug( + "Sending %d updates to %d connections", + len(updates), len(self.connections), + ) + + if updates: + logger.info( + "Streaming: %s -> %s", stream.NAME, updates[-1][0] + ) + stream_updates_counter.inc_by(len(updates), stream.NAME) + + # Some streams return multiple rows with the same stream IDs, + # we need to make sure they get sent out in batches. We do + # this by setting the current token to all but the last of + # a series of updates with the same token to have a None + # token. See RdataCommand for more details. + batched_updates = _batch_updates(updates) + + for conn in self.connections: + for token, row in batched_updates: + try: + conn.stream_update(stream.NAME, token, row) + except Exception: + logger.exception("Failed to replicate") + + logger.debug("No more pending updates, breaking poke loop") + finally: + self.pending_updates = False + self.is_looping = False + + @measure_func("repl.get_stream_updates") + def get_stream_updates(self, stream_name, token): + """For a given stream get all updates since token. This is called when + a client first subscribes to a stream. 
+ """ + stream = self.streams_by_name.get(stream_name, None) + if not stream: + raise Exception("unknown stream %s", stream_name) + + return stream.get_updates_since(token) + + @measure_func("repl.federation_ack") + def federation_ack(self, token): + """We've received an ack for federation stream from a client. + """ + federation_ack_counter.inc() + if self.federation_sender: + self.federation_sender.federation_ack(token) + + @measure_func("repl.on_user_sync") + @defer.inlineCallbacks + def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): + """A client has started/stopped syncing on a worker. + """ + user_sync_counter.inc() + yield self.presence_handler.update_external_syncs_row( + conn_id, user_id, is_syncing, last_sync_ms, + ) + + @measure_func("repl.on_remove_pusher") + @defer.inlineCallbacks + def on_remove_pusher(self, app_id, push_key, user_id): + """A client has asked us to remove a pusher + """ + remove_pusher_counter.inc() + yield self.store.delete_pusher_by_app_id_pushkey_user_id( + app_id=app_id, pushkey=push_key, user_id=user_id + ) + + self.notifier.on_new_replication_data() + + @measure_func("repl.on_invalidate_cache") + def on_invalidate_cache(self, cache_func, keys): + """The client has asked us to invalidate a cache + """ + invalidate_cache_counter.inc() + getattr(self.store, cache_func).invalidate(tuple(keys)) + + @measure_func("repl.on_user_ip") + @defer.inlineCallbacks + def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen): + """The client saw a user request + """ + user_ip_cache_counter.inc() + yield self.store.insert_client_ip( + user_id, access_token, ip, user_agent, device_id, last_seen, + ) + + def send_sync_to_all_connections(self, data): + """Sends a SYNC command to all clients. + + Used in tests. 
+ """ + for conn in self.connections: + conn.send_sync(data) + + def new_connection(self, connection): + """A new client connection has been established + """ + self.connections.append(connection) + + def lost_connection(self, connection): + """A client connection has been lost + """ + try: + self.connections.remove(connection) + except ValueError: + pass + + # We need to tell the presence handler that the connection has been + # lost so that it can handle any ongoing syncs on that connection. + self.presence_handler.update_external_syncs_clear(connection.conn_id) + + +def _batch_updates(updates): + """Takes a list of updates of form [(token, row)] and sets the token to + None for all rows where the next row has the same token. This is used to + implement batching. + + For example: + + [(1, _), (1, _), (2, _), (3, _), (3, _)] + + becomes: + + [(None, _), (1, _), (2, _), (None, _), (3, _)] + """ + if not updates: + return [] + + new_updates = [] + for i, update in enumerate(updates[:-1]): + if update[0] == updates[i + 1][0]: + new_updates.append((None, update[1])) + else: + new_updates.append(update) + + new_updates.append(updates[-1]) + return new_updates diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py new file mode 100644 index 0000000000..4c60bf79f9 --- /dev/null +++ b/synapse/replication/tcp/streams.py @@ -0,0 +1,506 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Defines all the valid streams that clients can subscribe to, and the format +of the rows returned by each stream. + +Each stream is defined by the following information: + + stream name: The name of the stream + row type: The type that is used to serialise/deserialse the row + current_token: The function that returns the current token for the stream + update_function: The function that returns a list of updates between two tokens +""" + +from twisted.internet import defer +from collections import namedtuple + +import logging + + +logger = logging.getLogger(__name__) + + +MAX_EVENTS_BEHIND = 10000 + + +EventStreamRow = namedtuple("EventStreamRow", ( + "event_id", # str + "room_id", # str + "type", # str + "state_key", # str, optional + "redacts", # str, optional +)) +BackfillStreamRow = namedtuple("BackfillStreamRow", ( + "event_id", # str + "room_id", # str + "type", # str + "state_key", # str, optional + "redacts", # str, optional +)) +PresenceStreamRow = namedtuple("PresenceStreamRow", ( + "user_id", # str + "state", # str + "last_active_ts", # int + "last_federation_update_ts", # int + "last_user_sync_ts", # int + "status_msg", # str + "currently_active", # bool +)) +TypingStreamRow = namedtuple("TypingStreamRow", ( + "room_id", # str + "user_ids", # list(str) +)) +ReceiptsStreamRow = namedtuple("ReceiptsStreamRow", ( + "room_id", # str + "receipt_type", # str + "user_id", # str + "event_id", # str + "data", # dict +)) +PushRulesStreamRow = namedtuple("PushRulesStreamRow", ( + "user_id", # str +)) +PushersStreamRow = namedtuple("PushersStreamRow", ( + "user_id", # str + "app_id", # str + "pushkey", # str + "deleted", # bool +)) +CachesStreamRow = namedtuple("CachesStreamRow", ( + "cache_func", # str + "keys", # list(str) + "invalidation_ts", # int +)) +PublicRoomsStreamRow = namedtuple("PublicRoomsStreamRow", ( + "room_id", # str + "visibility", # str + "appservice_id", # str, optional + "network_id", # str, optional +)) +DeviceListsStreamRow = 
namedtuple("DeviceListsStreamRow", ( + "user_id", # str + "destination", # str +)) +ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ( + "entity", # str +)) +FederationStreamRow = namedtuple("FederationStreamRow", ( + "type", # str, the type of data as defined in the BaseFederationRows + "data", # dict, serialization of a federation.send_queue.BaseFederationRow +)) +TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", ( + "user_id", # str + "room_id", # str + "data", # dict +)) +AccountDataStreamRow = namedtuple("AccountDataStream", ( + "user_id", # str + "room_id", # str + "data_type", # str + "data", # dict +)) +CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", ( + "room_id", # str + "type", # str + "state_key", # str + "event_id", # str, optional +)) +GroupsStreamRow = namedtuple("GroupsStreamRow", ( + "group_id", # str + "user_id", # str + "type", # str + "content", # dict +)) + + +class Stream(object): + """Base class for the streams. + + Provides a `get_updates()` function that returns new updates since the last + time it was called up until the point `advance_current_token` was called. + """ + NAME = None # The name of the stream + ROW_TYPE = None # The type of the row + _LIMITED = True # Whether the update function takes a limit + + def __init__(self, hs): + # The token from which we last asked for updates + self.last_token = self.current_token() + + # The token that we will get updates up to + self.upto_token = self.current_token() + + def advance_current_token(self): + """Updates `upto_token` to "now", which updates up until which point + get_updates[_since] will fetch rows till. + """ + self.upto_token = self.current_token() + + def discard_updates_and_advance(self): + """Called when the stream should advance but the updates would be discarded, + e.g. when there are no currently connected workers. 
+ """ + self.upto_token = self.current_token() + self.last_token = self.upto_token + + @defer.inlineCallbacks + def get_updates(self): + """Gets all updates since the last time this function was called (or + since the stream was constructed if it hadn't been called before), + until the `upto_token` + + Returns: + (list(ROW_TYPE), int): list of updates plus the token used as an + upper bound of the updates (i.e. the "current token") + """ + updates, current_token = yield self.get_updates_since(self.last_token) + self.last_token = current_token + + defer.returnValue((updates, current_token)) + + @defer.inlineCallbacks + def get_updates_since(self, from_token): + """Like get_updates except allows specifying from when we should + stream updates + + Returns: + (list(ROW_TYPE), int): list of updates plus the token used as an + upper bound of the updates (i.e. the "current token") + """ + if from_token in ("NOW", "now"): + defer.returnValue(([], self.upto_token)) + + current_token = self.upto_token + + from_token = int(from_token) + + if from_token == current_token: + defer.returnValue(([], current_token)) + + if self._LIMITED: + rows = yield self.update_function( + from_token, current_token, + limit=MAX_EVENTS_BEHIND + 1, + ) + + if len(rows) >= MAX_EVENTS_BEHIND: + raise Exception("stream %s has fallen behined" % (self.NAME)) + else: + rows = yield self.update_function( + from_token, current_token, + ) + + updates = [(row[0], self.ROW_TYPE(*row[1:])) for row in rows] + + defer.returnValue((updates, current_token)) + + def current_token(self): + """Gets the current token of the underlying streams. Should be provided + by the sub classes + + Returns: + int + """ + raise NotImplementedError() + + def update_function(self, from_token, current_token, limit=None): + """Get updates between from_token and to_token. If Stream._LIMITED is + True then limit is provided, otherwise it's not. 
+ + Returns: + Deferred(list(tuple)): the first entry in the tuple is the token for + that update, and the rest of the tuple gets used to construct + a ``ROW_TYPE`` instance + """ + raise NotImplementedError() + + +class EventsStream(Stream): + """We received a new event, or an event went from being an outlier to not + """ + NAME = "events" + ROW_TYPE = EventStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + self.current_token = store.get_current_events_token + self.update_function = store.get_all_new_forward_event_rows + + super(EventsStream, self).__init__(hs) + + +class BackfillStream(Stream): + """We fetched some old events and either we had never seen that event before + or it went from being an outlier to not. + """ + NAME = "backfill" + ROW_TYPE = BackfillStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + self.current_token = store.get_current_backfill_token + self.update_function = store.get_all_new_backfill_event_rows + + super(BackfillStream, self).__init__(hs) + + +class PresenceStream(Stream): + NAME = "presence" + _LIMITED = False + ROW_TYPE = PresenceStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + presence_handler = hs.get_presence_handler() + + self.current_token = store.get_current_presence_token + self.update_function = presence_handler.get_all_presence_updates + + super(PresenceStream, self).__init__(hs) + + +class TypingStream(Stream): + NAME = "typing" + _LIMITED = False + ROW_TYPE = TypingStreamRow + + def __init__(self, hs): + typing_handler = hs.get_typing_handler() + + self.current_token = typing_handler.get_current_token + self.update_function = typing_handler.get_all_typing_updates + + super(TypingStream, self).__init__(hs) + + +class ReceiptsStream(Stream): + NAME = "receipts" + ROW_TYPE = ReceiptsStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_max_receipt_stream_id + self.update_function = store.get_all_updated_receipts + + 
super(ReceiptsStream, self).__init__(hs) + + +class PushRulesStream(Stream): + """A user has changed their push rules + """ + NAME = "push_rules" + ROW_TYPE = PushRulesStreamRow + + def __init__(self, hs): + self.store = hs.get_datastore() + super(PushRulesStream, self).__init__(hs) + + def current_token(self): + push_rules_token, _ = self.store.get_push_rules_stream_token() + return push_rules_token + + @defer.inlineCallbacks + def update_function(self, from_token, to_token, limit): + rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit) + defer.returnValue([(row[0], row[2]) for row in rows]) + + +class PushersStream(Stream): + """A user has added/changed/removed a pusher + """ + NAME = "pushers" + ROW_TYPE = PushersStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_pushers_stream_token + self.update_function = store.get_all_updated_pushers_rows + + super(PushersStream, self).__init__(hs) + + +class CachesStream(Stream): + """A cache was invalidated on the master and no other stream would invalidate + the cache on the workers + """ + NAME = "caches" + ROW_TYPE = CachesStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_cache_stream_token + self.update_function = store.get_all_updated_caches + + super(CachesStream, self).__init__(hs) + + +class PublicRoomsStream(Stream): + """The public rooms list changed + """ + NAME = "public_rooms" + ROW_TYPE = PublicRoomsStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_current_public_room_stream_id + self.update_function = store.get_all_new_public_rooms + + super(PublicRoomsStream, self).__init__(hs) + + +class DeviceListsStream(Stream): + """Someone added/changed/removed a device + """ + NAME = "device_lists" + _LIMITED = False + ROW_TYPE = DeviceListsStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = 
store.get_device_stream_token + self.update_function = store.get_all_device_list_changes_for_remotes + + super(DeviceListsStream, self).__init__(hs) + + +class ToDeviceStream(Stream): + """New to_device messages for a client + """ + NAME = "to_device" + ROW_TYPE = ToDeviceStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_to_device_stream_token + self.update_function = store.get_all_new_device_messages + + super(ToDeviceStream, self).__init__(hs) + + +class FederationStream(Stream): + """Data to be sent over federation. Only available when master has federation + sending disabled. + """ + NAME = "federation" + ROW_TYPE = FederationStreamRow + + def __init__(self, hs): + federation_sender = hs.get_federation_sender() + + self.current_token = federation_sender.get_current_token + self.update_function = federation_sender.get_replication_rows + + super(FederationStream, self).__init__(hs) + + +class TagAccountDataStream(Stream): + """Someone added/removed a tag for a room + """ + NAME = "tag_account_data" + ROW_TYPE = TagAccountDataStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_max_account_data_stream_id + self.update_function = store.get_all_updated_tags + + super(TagAccountDataStream, self).__init__(hs) + + +class AccountDataStream(Stream): + """Global or per room account data was changed + """ + NAME = "account_data" + ROW_TYPE = AccountDataStreamRow + + def __init__(self, hs): + self.store = hs.get_datastore() + + self.current_token = self.store.get_max_account_data_stream_id + + super(AccountDataStream, self).__init__(hs) + + @defer.inlineCallbacks + def update_function(self, from_token, to_token, limit): + global_results, room_results = yield self.store.get_all_updated_account_data( + from_token, from_token, to_token, limit + ) + + results = list(room_results) + results.extend( + (stream_id, user_id, None, account_data_type, content,) + for stream_id, 
user_id, account_data_type, content in global_results + ) + + defer.returnValue(results) + + +class CurrentStateDeltaStream(Stream): + """Current state for a room was changed + """ + NAME = "current_state_deltas" + ROW_TYPE = CurrentStateDeltaStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_max_current_state_delta_stream_id + self.update_function = store.get_all_updated_current_state_deltas + + super(CurrentStateDeltaStream, self).__init__(hs) + + +class GroupServerStream(Stream): + NAME = "groups" + ROW_TYPE = GroupsStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_group_stream_token + self.update_function = store.get_all_groups_changes + + super(GroupServerStream, self).__init__(hs) + + +STREAMS_MAP = { + stream.NAME: stream + for stream in ( + EventsStream, + BackfillStream, + PresenceStream, + TypingStream, + ReceiptsStream, + PushRulesStream, + PushersStream, + CachesStream, + PublicRoomsStream, + DeviceListsStream, + ToDeviceStream, + FederationStream, + TagAccountDataStream, + AccountDataStream, + CurrentStateDeltaStream, + GroupServerStream, + ) +} diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index f9f5a3e077..16f5a73b95 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -40,6 +40,7 @@ from synapse.rest.client.v2_alpha import ( register, auth, receipts, + read_marker, keys, tokenrefresh, tags, @@ -50,6 +51,8 @@ from synapse.rest.client.v2_alpha import ( devices, thirdparty, sendtodevice, + user_directory, + groups, ) from synapse.http.server import JsonResource @@ -88,6 +91,7 @@ class ClientRestResource(JsonResource): register.register_servlets(hs, client_resource) auth.register_servlets(hs, client_resource) receipts.register_servlets(hs, client_resource) + read_marker.register_servlets(hs, client_resource) keys.register_servlets(hs, client_resource) tokenrefresh.register_servlets(hs, client_resource) 
tags.register_servlets(hs, client_resource) @@ -98,3 +102,5 @@ class ClientRestResource(JsonResource): devices.register_servlets(hs, client_resource) thirdparty.register_servlets(hs, client_resource) sendtodevice.register_servlets(hs, client_resource) + user_directory.register_servlets(hs, client_resource) + groups.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 29fcd72375..efd5c9873d 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,8 +16,9 @@ from twisted.internet import defer -from synapse.api.errors import AuthError, SynapseError -from synapse.types import UserID +from synapse.api.constants import Membership +from synapse.api.errors import AuthError, SynapseError, Codes, NotFoundError +from synapse.types import UserID, create_requester from synapse.http.servlet import parse_json_object_from_request from .base import ClientV1RestServlet, client_path_patterns @@ -112,12 +114,18 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet): class PurgeHistoryRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns( - "/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" + "/admin/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?" 
) def __init__(self, hs): + """ + + Args: + hs (synapse.server.HomeServer) + """ super(PurgeHistoryRestServlet, self).__init__(hs) self.handlers = hs.get_handlers() + self.store = hs.get_datastore() @defer.inlineCallbacks def on_POST(self, request, room_id, event_id): @@ -127,17 +135,114 @@ class PurgeHistoryRestServlet(ClientV1RestServlet): if not is_admin: raise AuthError(403, "You are not a server admin") - yield self.handlers.message_handler.purge_history(room_id, event_id) + body = parse_json_object_from_request(request, allow_empty_body=True) + + delete_local_events = bool(body.get("delete_local_events", False)) + + # establish the topological ordering we should keep events from. The + # user can provide an event_id in the URL or the request body, or can + # provide a timestamp in the request body. + if event_id is None: + event_id = body.get('purge_up_to_event_id') + + if event_id is not None: + event = yield self.store.get_event(event_id) + + if event.room_id != room_id: + raise SynapseError(400, "Event is for wrong room.") + + depth = event.depth + logger.info( + "[purge] purging up to depth %i (event_id %s)", + depth, event_id, + ) + elif 'purge_up_to_ts' in body: + ts = body['purge_up_to_ts'] + if not isinstance(ts, int): + raise SynapseError( + 400, "purge_up_to_ts must be an int", + errcode=Codes.BAD_JSON, + ) + + stream_ordering = ( + yield self.store.find_first_stream_ordering_after_ts(ts) + ) + + room_event_after_stream_ordering = ( + yield self.store.get_room_event_after_stream_ordering( + room_id, stream_ordering, + ) + ) + if room_event_after_stream_ordering: + (_, depth, _) = room_event_after_stream_ordering + else: + logger.warn( + "[purge] purging events not possible: No event found " + "(received_ts %i => stream_ordering %i)", + ts, stream_ordering, + ) + raise SynapseError( + 404, + "there is no event to be purged", + errcode=Codes.NOT_FOUND, + ) + logger.info( + "[purge] purging up to depth %i (received_ts %i => " + "stream_ordering %i)", + 
depth, ts, stream_ordering, + ) + else: + raise SynapseError( + 400, + "must specify purge_up_to_event_id or purge_up_to_ts", + errcode=Codes.BAD_JSON, + ) + + purge_id = yield self.handlers.message_handler.start_purge_history( + room_id, depth, + delete_local_events=delete_local_events, + ) - defer.returnValue((200, {})) + defer.returnValue((200, { + "purge_id": purge_id, + })) + + +class PurgeHistoryStatusRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns( + "/admin/purge_history_status/(?P<purge_id>[^/]+)" + ) + + def __init__(self, hs): + """ + + Args: + hs (synapse.server.HomeServer) + """ + super(PurgeHistoryStatusRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + + @defer.inlineCallbacks + def on_GET(self, request, purge_id): + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + purge_status = self.handlers.message_handler.get_purge_status(purge_id) + if purge_status is None: + raise NotFoundError("purge id '%s' not found" % purge_id) + + defer.returnValue((200, purge_status.asdict())) class DeactivateAccountRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)") def __init__(self, hs): - self.store = hs.get_datastore() super(DeactivateAccountRestServlet, self).__init__(hs) + self._deactivate_account_handler = hs.get_deactivate_account_handler() @defer.inlineCallbacks def on_POST(self, request, target_user_id): @@ -148,18 +253,171 @@ class DeactivateAccountRestServlet(ClientV1RestServlet): if not is_admin: raise AuthError(403, "You are not a server admin") - # FIXME: Theoretically there is a race here wherein user resets password - # using threepid. 
- yield self.store.user_delete_access_tokens(target_user_id) - yield self.store.user_delete_threepids(target_user_id) - yield self.store.user_set_password_hash(target_user_id, None) - + yield self._deactivate_account_handler.deactivate_account(target_user_id) defer.returnValue((200, {})) +class ShutdownRoomRestServlet(ClientV1RestServlet): + """Shuts down a room by removing all local users from the room and blocking + all future invites and joins to the room. Any local aliases will be repointed + to a new room created by `new_room_user_id` and kicked users will be auto + joined to the new room. + """ + PATTERNS = client_path_patterns("/admin/shutdown_room/(?P<room_id>[^/]+)") + + DEFAULT_MESSAGE = ( + "Sharing illegal content on this server is not permitted and rooms in" + " violation will be blocked." + ) + + def __init__(self, hs): + super(ShutdownRoomRestServlet, self).__init__(hs) + self.store = hs.get_datastore() + self.handlers = hs.get_handlers() + self.state = hs.get_state_handler() + self.event_creation_handler = hs.get_event_creation_handler() + self.room_member_handler = hs.get_room_member_handler() + + @defer.inlineCallbacks + def on_POST(self, request, room_id): + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + if not is_admin: + raise AuthError(403, "You are not a server admin") + + content = parse_json_object_from_request(request) + + new_room_user_id = content.get("new_room_user_id") + if not new_room_user_id: + raise SynapseError(400, "Please provide field `new_room_user_id`") + + room_creator_requester = create_requester(new_room_user_id) + + message = content.get("message", self.DEFAULT_MESSAGE) + room_name = content.get("room_name", "Content Violation Notification") + + info = yield self.handlers.room_creation_handler.create_room( + room_creator_requester, + config={ + "preset": "public_chat", + "name": room_name, + "power_level_content_override": { + "users_default": -10, + }, 
+ }, + ratelimit=False, + ) + new_room_id = info["room_id"] + + yield self.event_creation_handler.create_and_send_nonmember_event( + room_creator_requester, + { + "type": "m.room.message", + "content": {"body": message, "msgtype": "m.text"}, + "room_id": new_room_id, + "sender": new_room_user_id, + }, + ratelimit=False, + ) + + requester_user_id = requester.user.to_string() + + logger.info("Shutting down room %r", room_id) + + yield self.store.block_room(room_id, requester_user_id) + + users = yield self.state.get_current_user_in_room(room_id) + kicked_users = [] + for user_id in users: + if not self.hs.is_mine_id(user_id): + continue + + logger.info("Kicking %r from %r...", user_id, room_id) + + target_requester = create_requester(user_id) + yield self.room_member_handler.update_membership( + requester=target_requester, + target=target_requester.user, + room_id=room_id, + action=Membership.LEAVE, + content={}, + ratelimit=False + ) + + yield self.room_member_handler.forget(target_requester.user, room_id) + + yield self.room_member_handler.update_membership( + requester=target_requester, + target=target_requester.user, + room_id=new_room_id, + action=Membership.JOIN, + content={}, + ratelimit=False + ) + + kicked_users.append(user_id) + + aliases_for_room = yield self.store.get_aliases_for_room(room_id) + + yield self.store.update_aliases_for_room( + room_id, new_room_id, requester_user_id + ) + + defer.returnValue((200, { + "kicked_users": kicked_users, + "local_aliases": aliases_for_room, + "new_room_id": new_room_id, + })) + + +class QuarantineMediaInRoom(ClientV1RestServlet): + """Quarantines all media in a room so that no one can download it via + this server. 
+ """ + PATTERNS = client_path_patterns("/admin/quarantine_media/(?P<room_id>[^/]+)") + + def __init__(self, hs): + super(QuarantineMediaInRoom, self).__init__(hs) + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def on_POST(self, request, room_id): + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + if not is_admin: + raise AuthError(403, "You are not a server admin") + + num_quarantined = yield self.store.quarantine_media_ids_in_room( + room_id, requester.user.to_string(), + ) + + defer.returnValue((200, {"num_quarantined": num_quarantined})) + + +class ListMediaInRoom(ClientV1RestServlet): + """Lists all of the media in a given room. + """ + PATTERNS = client_path_patterns("/admin/room/(?P<room_id>[^/]+)/media") + + def __init__(self, hs): + super(ListMediaInRoom, self).__init__(hs) + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def on_GET(self, request, room_id): + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + if not is_admin: + raise AuthError(403, "You are not a server admin") + + local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id) + + defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs})) + + class ResetPasswordRestServlet(ClientV1RestServlet): """Post request to allow an administrator reset password for a user. - This need a user have a administrator access in Synapse. + This needs user to have administrator access in Synapse. 
Example: http://localhost:8008/_matrix/client/api/v1/admin/reset_password/ @user:to_reset_password?access_token=admin_access_token @@ -177,12 +435,12 @@ class ResetPasswordRestServlet(ClientV1RestServlet): super(ResetPasswordRestServlet, self).__init__(hs) self.hs = hs self.auth = hs.get_auth() - self.auth_handler = hs.get_auth_handler() + self._set_password_handler = hs.get_set_password_handler() @defer.inlineCallbacks def on_POST(self, request, target_user_id): """Post request to allow an administrator reset password for a user. - This need a user have a administrator access in Synapse. + This needs user to have administrator access in Synapse. """ UserID.from_string(target_user_id) requester = yield self.auth.get_user_by_req(request) @@ -198,7 +456,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet): logger.info("new_password: %r", new_password) - yield self.auth_handler.set_password( + yield self._set_password_handler.set_password( target_user_id, new_password, requester ) defer.returnValue((200, {})) @@ -206,7 +464,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet): class GetUsersPaginatedRestServlet(ClientV1RestServlet): """Get request to get specific number of users from Synapse. - This need a user have a administrator access in Synapse. + This needs user to have administrator access in Synapse. Example: http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/ @admin:user?access_token=admin_access_token&start=0&limit=10 @@ -225,7 +483,7 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, target_user_id): """Get request to get specific number of users from Synapse. - This need a user have a administrator access in Synapse. + This needs user to have administrator access in Synapse. 
""" target_user = UserID.from_string(target_user_id) requester = yield self.auth.get_user_by_req(request) @@ -258,7 +516,7 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, target_user_id): """Post request to get specific number of users from Synapse.. - This need a user have a administrator access in Synapse. + This needs user to have administrator access in Synapse. Example: http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/ @admin:user?access_token=admin_access_token @@ -296,7 +554,7 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet): class SearchUsersRestServlet(ClientV1RestServlet): """Get request to search user table for specific users according to search term. - This need a user have a administrator access in Synapse. + This needs user to have administrator access in Synapse. Example: http://localhost:8008/_matrix/client/api/v1/admin/search_users/ @admin:user?access_token=admin_access_token&term=alice @@ -316,7 +574,7 @@ class SearchUsersRestServlet(ClientV1RestServlet): def on_GET(self, request, target_user_id): """Get request to search user table for specific users according to search term. - This need a user have a administrator access in Synapse. + This needs user to have a administrator access in Synapse. 
""" target_user = UserID.from_string(target_user_id) requester = yield self.auth.get_user_by_req(request) @@ -347,9 +605,13 @@ class SearchUsersRestServlet(ClientV1RestServlet): def register_servlets(hs, http_server): WhoisRestServlet(hs).register(http_server) PurgeMediaCacheRestServlet(hs).register(http_server) + PurgeHistoryStatusRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server) PurgeHistoryRestServlet(hs).register(http_server) UsersRestServlet(hs).register(http_server) ResetPasswordRestServlet(hs).register(http_server) GetUsersPaginatedRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server) + ShutdownRoomRestServlet(hs).register(http_server) + QuarantineMediaInRoom(hs).register(http_server) + ListMediaInRoom(hs).register(http_server) diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py index c7aa0bbf59..197335d7aa 100644 --- a/synapse/rest/client/v1/base.py +++ b/synapse/rest/client/v1/base.py @@ -52,6 +52,10 @@ class ClientV1RestServlet(RestServlet): """A base Synapse REST Servlet for the client version 1 API. """ + # This subclass was presumably created to allow the auth for the v1 + # protocol version to be different, however this behaviour was removed. 
+ # it may no longer be necessary + def __init__(self, hs): """ Args: @@ -59,5 +63,5 @@ class ClientV1RestServlet(RestServlet): """ self.hs = hs self.builder_factory = hs.get_event_builder_factory() - self.auth = hs.get_v1auth() + self.auth = hs.get_auth() self.txns = HttpTransactionCache(hs.get_clock()) diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 8930f1826f..1c3933380f 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -39,6 +39,7 @@ class ClientDirectoryServer(ClientV1RestServlet): def __init__(self, hs): super(ClientDirectoryServer, self).__init__(hs) + self.store = hs.get_datastore() self.handlers = hs.get_handlers() @defer.inlineCallbacks @@ -70,7 +71,10 @@ class ClientDirectoryServer(ClientV1RestServlet): logger.debug("Got servers: %s", servers) # TODO(erikj): Check types. - # TODO(erikj): Check that room exists + + room = yield self.store.get_room(room_id) + if room is None: + raise SynapseError(400, "Room does not exist") dir_handler = self.handlers.directory_handler @@ -89,7 +93,7 @@ class ClientDirectoryServer(ClientV1RestServlet): ) except SynapseError as e: raise e - except: + except Exception: logger.exception("Failed to create association") raise except AuthError: diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 72057f1b0c..34df5be4e9 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -19,12 +19,13 @@ from synapse.api.errors import SynapseError, LoginError, Codes from synapse.types import UserID from synapse.http.server import finish_request from synapse.http.servlet import parse_json_object_from_request +from synapse.util.msisdn import phone_number_to_msisdn from .base import ClientV1RestServlet, client_path_patterns import simplejson as json import urllib -import urlparse +from six.moves.urllib import parse as urlparse import logging from saml2 import BINDING_HTTP_POST @@ -33,13 +34,57 @@ 
from saml2.client import Saml2Client import xml.etree.ElementTree as ET +from twisted.web.client import PartialDownloadError + logger = logging.getLogger(__name__) +def login_submission_legacy_convert(submission): + """ + If the input login submission is an old style object + (ie. with top-level user / medium / address) convert it + to a typed object. + """ + if "user" in submission: + submission["identifier"] = { + "type": "m.id.user", + "user": submission["user"], + } + del submission["user"] + + if "medium" in submission and "address" in submission: + submission["identifier"] = { + "type": "m.id.thirdparty", + "medium": submission["medium"], + "address": submission["address"], + } + del submission["medium"] + del submission["address"] + + +def login_id_thirdparty_from_phone(identifier): + """ + Convert a phone login identifier type to a generic threepid identifier + Args: + identifier(dict): Login identifier dict of type 'm.id.phone' + + Returns: Login identifier dict of type 'm.id.threepid' + """ + if "country" not in identifier or "number" not in identifier: + raise SynapseError(400, "Invalid phone-type identifier") + + msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"]) + + return { + "type": "m.id.thirdparty", + "medium": "msisdn", + "address": msisdn, + } + + class LoginRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/login$") - PASS_TYPE = "m.login.password" SAML2_TYPE = "m.login.saml2" CAS_TYPE = "m.login.cas" TOKEN_TYPE = "m.login.token" @@ -48,7 +93,6 @@ class LoginRestServlet(ClientV1RestServlet): def __init__(self, hs): super(LoginRestServlet, self).__init__(hs) self.idp_redirect_url = hs.config.saml2_idp_redirect_url - self.password_enabled = hs.config.password_enabled self.saml2_enabled = hs.config.saml2_enabled self.jwt_enabled = hs.config.jwt_enabled self.jwt_secret = hs.config.jwt_secret @@ -75,8 +119,10 @@ class LoginRestServlet(ClientV1RestServlet): # fall back to the fallback API if they don't 
understand one of the # login flow types returned. flows.append({"type": LoginRestServlet.TOKEN_TYPE}) - if self.password_enabled: - flows.append({"type": LoginRestServlet.PASS_TYPE}) + + flows.extend(( + {"type": t} for t in self.auth_handler.get_supported_login_types() + )) return (200, {"flows": flows}) @@ -87,14 +133,8 @@ class LoginRestServlet(ClientV1RestServlet): def on_POST(self, request): login_submission = parse_json_object_from_request(request) try: - if login_submission["type"] == LoginRestServlet.PASS_TYPE: - if not self.password_enabled: - raise SynapseError(400, "Password login has been disabled.") - - result = yield self.do_password_login(login_submission) - defer.returnValue(result) - elif self.saml2_enabled and (login_submission["type"] == - LoginRestServlet.SAML2_TYPE): + if self.saml2_enabled and (login_submission["type"] == + LoginRestServlet.SAML2_TYPE): relay_state = "" if "relay_state" in login_submission: relay_state = "&RelayState=" + urllib.quote( @@ -111,49 +151,102 @@ class LoginRestServlet(ClientV1RestServlet): result = yield self.do_token_login(login_submission) defer.returnValue(result) else: - raise SynapseError(400, "Bad login type.") + result = yield self._do_other_login(login_submission) + defer.returnValue(result) except KeyError: raise SynapseError(400, "Missing JSON keys.") @defer.inlineCallbacks - def do_password_login(self, login_submission): - if 'medium' in login_submission and 'address' in login_submission: - address = login_submission['address'] - if login_submission['medium'] == 'email': + def _do_other_login(self, login_submission): + """Handle non-token/saml/jwt logins + + Args: + login_submission: + + Returns: + (int, object): HTTP code/response + """ + # Log the request we got, but only certain fields to minimise the chance of + # logging someone's password (even if they accidentally put it in the wrong + # field) + logger.info( + "Got login request with identifier: %r, medium: %r, address: %r, user: %r", + 
login_submission.get('identifier'), + login_submission.get('medium'), + login_submission.get('address'), + login_submission.get('user'), + ) + login_submission_legacy_convert(login_submission) + + if "identifier" not in login_submission: + raise SynapseError(400, "Missing param: identifier") + + identifier = login_submission["identifier"] + if "type" not in identifier: + raise SynapseError(400, "Login identifier has no type") + + # convert phone type identifiers to generic threepids + if identifier["type"] == "m.id.phone": + identifier = login_id_thirdparty_from_phone(identifier) + + # convert threepid identifiers to user IDs + if identifier["type"] == "m.id.thirdparty": + address = identifier.get('address') + medium = identifier.get('medium') + + if medium is None or address is None: + raise SynapseError(400, "Invalid thirdparty identifier") + + if medium == 'email': # For emails, transform the address to lowercase. # We store all email addreses as lowercase in the DB. # (See add_threepid in synapse/handlers/auth.py) address = address.lower() user_id = yield self.hs.get_datastore().get_user_id_by_threepid( - login_submission['medium'], address + medium, address, ) if not user_id: + logger.warn( + "unknown 3pid identifier medium %s, address %r", + medium, address, + ) raise LoginError(403, "", errcode=Codes.FORBIDDEN) - else: - user_id = login_submission['user'] - if not user_id.startswith('@'): - user_id = UserID.create( - user_id, self.hs.hostname - ).to_string() + identifier = { + "type": "m.id.user", + "user": user_id, + } + + # by this point, the identifier should be an m.id.user: if it's anything + # else, we haven't understood it. 
+ if identifier["type"] != "m.id.user": + raise SynapseError(400, "Unknown login identifier type") + if "user" not in identifier: + raise SynapseError(400, "User identifier is missing 'user' key") auth_handler = self.auth_handler - user_id = yield auth_handler.validate_password_login( - user_id=user_id, - password=login_submission["password"], + canonical_user_id, callback = yield auth_handler.validate_login( + identifier["user"], + login_submission, + ) + + device_id = yield self._register_device( + canonical_user_id, login_submission, ) - device_id = yield self._register_device(user_id, login_submission) access_token = yield auth_handler.get_access_token_for_user_id( - user_id, device_id, - login_submission.get("initial_device_display_name"), + canonical_user_id, device_id, ) + result = { - "user_id": user_id, # may have changed + "user_id": canonical_user_id, "access_token": access_token, "home_server": self.hs.hostname, "device_id": device_id, } + if callback is not None: + yield callback(result) + defer.returnValue((200, result)) @defer.inlineCallbacks @@ -166,7 +259,6 @@ class LoginRestServlet(ClientV1RestServlet): device_id = yield self._register_device(user_id, login_submission) access_token = yield auth_handler.get_access_token_for_user_id( user_id, device_id, - login_submission.get("initial_device_display_name"), ) result = { "user_id": user_id, # may have changed @@ -200,7 +292,7 @@ class LoginRestServlet(ClientV1RestServlet): if user is None: raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) - user_id = UserID.create(user, self.hs.hostname).to_string() + user_id = UserID(user, self.hs.hostname).to_string() auth_handler = self.auth_handler registered_user_id = yield auth_handler.check_user_exists(user_id) if registered_user_id: @@ -209,7 +301,6 @@ class LoginRestServlet(ClientV1RestServlet): ) access_token = yield auth_handler.get_access_token_for_user_id( registered_user_id, device_id, - 
login_submission.get("initial_device_display_name"), ) result = { @@ -341,7 +432,12 @@ class CasTicketServlet(ClientV1RestServlet): "ticket": request.args["ticket"], "service": self.cas_service_url } - body = yield http_client.get_raw(uri, args) + try: + body = yield http_client.get_raw(uri, args) + except PartialDownloadError as pde: + # Twisted raises this error if the connection is closed, + # even if that's being used old-http style to signal end-of-data + body = pde.response result = yield self.handle_cas_response(request, body, client_redirect_url) defer.returnValue(result) @@ -361,7 +457,7 @@ class CasTicketServlet(ClientV1RestServlet): if required_value != actual_value: raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) - user_id = UserID.create(user, self.hs.hostname).to_string() + user_id = UserID(user, self.hs.hostname).to_string() auth_handler = self.auth_handler registered_user_id = yield auth_handler.check_user_exists(user_id) if not registered_user_id: diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index 1358d0acab..e092158cb7 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -16,6 +16,7 @@ from twisted.internet import defer from synapse.api.auth import get_access_token_from_request +from synapse.api.errors import AuthError from .base import ClientV1RestServlet, client_path_patterns @@ -30,15 +31,33 @@ class LogoutRestServlet(ClientV1RestServlet): def __init__(self, hs): super(LogoutRestServlet, self).__init__(hs) - self.store = hs.get_datastore() + self._auth = hs.get_auth() + self._auth_handler = hs.get_auth_handler() + self._device_handler = hs.get_device_handler() def on_OPTIONS(self, request): return (200, {}) @defer.inlineCallbacks def on_POST(self, request): - access_token = get_access_token_from_request(request) - yield self.store.delete_access_token(access_token) + try: + requester = yield self.auth.get_user_by_req(request) + except AuthError: + # this 
implies the access token has already been deleted. + defer.returnValue((401, { + "errcode": "M_UNKNOWN_TOKEN", + "error": "Access Token unknown or expired" + })) + else: + if requester.device_id is None: + # the acccess token wasn't associated with a device. + # Just delete the access token + access_token = get_access_token_from_request(request) + yield self._auth_handler.delete_access_token(access_token) + else: + yield self._device_handler.delete_device( + requester.user.to_string(), requester.device_id) + defer.returnValue((200, {})) @@ -47,8 +66,9 @@ class LogoutAllRestServlet(ClientV1RestServlet): def __init__(self, hs): super(LogoutAllRestServlet, self).__init__(hs) - self.store = hs.get_datastore() self.auth = hs.get_auth() + self._auth_handler = hs.get_auth_handler() + self._device_handler = hs.get_device_handler() def on_OPTIONS(self, request): return (200, {}) @@ -57,7 +77,13 @@ class LogoutAllRestServlet(ClientV1RestServlet): def on_POST(self, request): requester = yield self.auth.get_user_by_req(request) user_id = requester.user.to_string() - yield self.store.user_delete_access_tokens(user_id) + + # first delete all of the user's devices + yield self._device_handler.delete_all_devices_for_user(user_id) + + # .. and then delete any access tokens which weren't associated with + # devices. 
+ yield self._auth_handler.delete_access_tokens_for_user(user_id) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index eafdce865e..4a73813c58 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, AuthError from synapse.types import UserID +from synapse.handlers.presence import format_user_presence_state from synapse.http.servlet import parse_json_object_from_request from .base import ClientV1RestServlet, client_path_patterns @@ -33,6 +34,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): def __init__(self, hs): super(PresenceStatusRestServlet, self).__init__(hs) self.presence_handler = hs.get_presence_handler() + self.clock = hs.get_clock() @defer.inlineCallbacks def on_GET(self, request, user_id): @@ -48,6 +50,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): raise AuthError(403, "You are not allowed to see their presence.") state = yield self.presence_handler.get_state(target_user=user) + state = format_user_presence_state(state, self.clock.time_msec()) defer.returnValue((200, state)) @@ -75,7 +78,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): raise KeyError() except SynapseError as e: raise e - except: + except Exception: raise SynapseError(400, "Unable to parse state") yield self.presence_handler.set_state(user, state) diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 1a5045c9ec..e4e3611a14 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -26,13 +26,13 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): def __init__(self, hs): super(ProfileDisplaynameRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.profile_handler = hs.get_profile_handler() @defer.inlineCallbacks def on_GET(self, request, user_id): user = 
UserID.from_string(user_id) - displayname = yield self.handlers.profile_handler.get_displayname( + displayname = yield self.profile_handler.get_displayname( user, ) @@ -52,10 +52,10 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): try: new_name = content["displayname"] - except: + except Exception: defer.returnValue((400, "Unable to parse name")) - yield self.handlers.profile_handler.set_displayname( + yield self.profile_handler.set_displayname( user, requester, new_name, is_admin) defer.returnValue((200, {})) @@ -69,13 +69,13 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): def __init__(self, hs): super(ProfileAvatarURLRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.profile_handler = hs.get_profile_handler() @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) - avatar_url = yield self.handlers.profile_handler.get_avatar_url( + avatar_url = yield self.profile_handler.get_avatar_url( user, ) @@ -94,10 +94,10 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): content = parse_json_object_from_request(request) try: new_name = content["avatar_url"] - except: + except Exception: defer.returnValue((400, "Unable to parse name")) - yield self.handlers.profile_handler.set_avatar_url( + yield self.profile_handler.set_avatar_url( user, requester, new_name, is_admin) defer.returnValue((200, {})) @@ -111,16 +111,16 @@ class ProfileRestServlet(ClientV1RestServlet): def __init__(self, hs): super(ProfileRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.profile_handler = hs.get_profile_handler() @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) - displayname = yield self.handlers.profile_handler.get_displayname( + displayname = yield self.profile_handler.get_displayname( user, ) - avatar_url = yield self.handlers.profile_handler.get_avatar_url( + avatar_url = yield self.profile_handler.get_avatar_url( user, ) diff 
--git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 9a2ed6ed88..0206e664c1 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -73,6 +73,7 @@ class PushersSetRestServlet(ClientV1RestServlet): def __init__(self, hs): super(PushersSetRestServlet, self).__init__(hs) self.notifier = hs.get_notifier() + self.pusher_pool = self.hs.get_pusherpool() @defer.inlineCallbacks def on_POST(self, request): @@ -81,12 +82,10 @@ class PushersSetRestServlet(ClientV1RestServlet): content = parse_json_object_from_request(request) - pusher_pool = self.hs.get_pusherpool() - if ('pushkey' in content and 'app_id' in content and 'kind' in content and content['kind'] is None): - yield pusher_pool.remove_pusher( + yield self.pusher_pool.remove_pusher( content['app_id'], content['pushkey'], user_id=user.to_string() ) defer.returnValue((200, {})) @@ -109,14 +108,14 @@ class PushersSetRestServlet(ClientV1RestServlet): append = content['append'] if not append: - yield pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( + yield self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( app_id=content['app_id'], pushkey=content['pushkey'], not_user_id=user.to_string() ) try: - yield pusher_pool.add_pusher( + yield self.pusher_pool.add_pusher( user_id=user.to_string(), access_token=requester.access_token_id, kind=content['kind'], @@ -151,7 +150,8 @@ class PushersRemoveRestServlet(RestServlet): super(RestServlet, self).__init__() self.hs = hs self.notifier = hs.get_notifier() - self.auth = hs.get_v1auth() + self.auth = hs.get_auth() + self.pusher_pool = self.hs.get_pusherpool() @defer.inlineCallbacks def on_GET(self, request): @@ -161,10 +161,8 @@ class PushersRemoveRestServlet(RestServlet): app_id = parse_string(request, "app_id", required=True) pushkey = parse_string(request, "pushkey", required=True) - pusher_pool = self.hs.get_pusherpool() - try: - yield pusher_pool.remove_pusher( + yield self.pusher_pool.remove_pusher( 
app_id=app_id, pushkey=pushkey, user_id=user.to_string(), diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index ecf7e311a9..9b3022e0b0 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -30,6 +30,8 @@ from hashlib import sha1 import hmac import logging +from six import string_types + logger = logging.getLogger(__name__) @@ -70,10 +72,15 @@ class RegisterRestServlet(ClientV1RestServlet): self.handlers = hs.get_handlers() def on_GET(self, request): + + require_email = 'email' in self.hs.config.registrations_require_3pid + require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid + + flows = [] if self.hs.config.enable_registration_captcha: - return ( - 200, - {"flows": [ + # only support the email-only flow if we don't require MSISDN 3PIDs + if not require_msisdn: + flows.extend([ { "type": LoginType.RECAPTCHA, "stages": [ @@ -82,27 +89,34 @@ class RegisterRestServlet(ClientV1RestServlet): LoginType.PASSWORD ] }, + ]) + # only support 3PIDless registration if no 3PIDs are required + if not require_email and not require_msisdn: + flows.extend([ { "type": LoginType.RECAPTCHA, "stages": [LoginType.RECAPTCHA, LoginType.PASSWORD] } - ]} - ) + ]) else: - return ( - 200, - {"flows": [ + # only support the email-only flow if we don't require MSISDN 3PIDs + if require_email or not require_msisdn: + flows.extend([ { "type": LoginType.EMAIL_IDENTITY, "stages": [ LoginType.EMAIL_IDENTITY, LoginType.PASSWORD ] - }, + } + ]) + # only support 3PIDless registration if no 3PIDs are required + if not require_email and not require_msisdn: + flows.extend([ { "type": LoginType.PASSWORD } - ]} - ) + ]) + return (200, {"flows": flows}) @defer.inlineCallbacks def on_POST(self, request): @@ -321,11 +335,11 @@ class RegisterRestServlet(ClientV1RestServlet): def _do_shared_secret(self, request, register_json, session): yield run_on_reactor() - if not isinstance(register_json.get("mac", None), 
basestring): + if not isinstance(register_json.get("mac", None), string_types): raise SynapseError(400, "Expected mac.") - if not isinstance(register_json.get("user", None), basestring): + if not isinstance(register_json.get("user", None), string_types): raise SynapseError(400, "Expected 'user' key.") - if not isinstance(register_json.get("password", None), basestring): + if not isinstance(register_json.get("password", None), string_types): raise SynapseError(400, "Expected 'password' key.") if not self.hs.config.registration_shared_secret: @@ -336,9 +350,9 @@ class RegisterRestServlet(ClientV1RestServlet): admin = register_json.get("admin", None) # Its important to check as we use null bytes as HMAC field separators - if "\x00" in user: + if b"\x00" in user: raise SynapseError(400, "Invalid user") - if "\x00" in password: + if b"\x00" in password: raise SynapseError(400, "Invalid password") # str() because otherwise hmac complains that 'unicode' does not @@ -346,20 +360,20 @@ class RegisterRestServlet(ClientV1RestServlet): got_mac = str(register_json["mac"]) want_mac = hmac.new( - key=self.hs.config.registration_shared_secret, + key=self.hs.config.registration_shared_secret.encode(), digestmod=sha1, ) want_mac.update(user) - want_mac.update("\x00") + want_mac.update(b"\x00") want_mac.update(password) - want_mac.update("\x00") - want_mac.update("admin" if admin else "notadmin") + want_mac.update(b"\x00") + want_mac.update(b"admin" if admin else b"notadmin") want_mac = want_mac.hexdigest() if compare_digest(want_mac, got_mac): handler = self.handlers.registration_handler user_id, token = yield handler.register( - localpart=user, + localpart=user.lower(), password=password, admin=bool(admin), ) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 728e3df0e3..fcf9c9ab44 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 
2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,9 +28,10 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, parse_integer ) +from six.moves.urllib import parse as urlparse + import logging -import urllib -import ujson as json +import simplejson as json logger = logging.getLogger(__name__) @@ -82,6 +84,8 @@ class RoomStateEventRestServlet(ClientV1RestServlet): def __init__(self, hs): super(RoomStateEventRestServlet, self).__init__(hs) self.handlers = hs.get_handlers() + self.event_creation_hander = hs.get_event_creation_handler() + self.room_member_handler = hs.get_room_member_handler() def register(self, http_server): # /room/$roomid/state/$eventtype @@ -154,7 +158,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): if event_type == EventTypes.Member: membership = content.get("membership", None) - event = yield self.handlers.room_member_handler.update_membership( + event = yield self.room_member_handler.update_membership( requester, target=UserID.from_string(state_key), room_id=room_id, @@ -162,15 +166,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet): content=content, ) else: - msg_handler = self.handlers.message_handler - event, context = yield msg_handler.create_event( + event = yield self.event_creation_hander.create_and_send_nonmember_event( + requester, event_dict, - token_id=requester.access_token_id, txn_id=txn_id, ) - yield msg_handler.send_nonmember_event(requester, event, context) - ret = {} if event: ret = {"event_id": event.event_id} @@ -182,7 +183,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): def __init__(self, hs): super(RoomSendEventRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.event_creation_hander = hs.get_event_creation_handler() def register(self, http_server): # /rooms/$roomid/send/$event_type[/$txn_id] @@ -194,15 +195,19 @@ class 
RoomSendEventRestServlet(ClientV1RestServlet): requester = yield self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) - msg_handler = self.handlers.message_handler - event = yield msg_handler.create_and_send_nonmember_event( + event_dict = { + "type": event_type, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + } + + if 'ts' in request.args and requester.app_service: + event_dict['origin_server_ts'] = parse_integer(request, "ts", 0) + + event = yield self.event_creation_hander.create_and_send_nonmember_event( requester, - { - "type": event_type, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - }, + event_dict, txn_id=txn_id, ) @@ -221,7 +226,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): class JoinRoomAliasServlet(ClientV1RestServlet): def __init__(self, hs): super(JoinRoomAliasServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.room_member_handler = hs.get_room_member_handler() def register(self, http_server): # /join/$room_identifier[/$txn_id] @@ -237,7 +242,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): try: content = parse_json_object_from_request(request) - except: + except Exception: # Turns out we used to ignore the body entirely, and some clients # cheekily send invalid bodies. 
content = {} @@ -246,10 +251,10 @@ class JoinRoomAliasServlet(ClientV1RestServlet): room_id = room_identifier try: remote_room_hosts = request.args["server_name"] - except: + except Exception: remote_room_hosts = None elif RoomAlias.is_valid(room_identifier): - handler = self.handlers.room_member_handler + handler = self.room_member_handler room_alias = RoomAlias.from_string(room_identifier) room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias) room_id = room_id.to_string() @@ -258,7 +263,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): room_identifier, )) - yield self.handlers.room_member_handler.update_membership( + yield self.room_member_handler.update_membership( requester=requester, target=requester.user, room_id=room_id, @@ -397,16 +402,18 @@ class JoinedRoomMemberListRestServlet(ClientV1RestServlet): def __init__(self, hs): super(JoinedRoomMemberListRestServlet, self).__init__(hs) - self.state = hs.get_state_handler() + self.message_handler = hs.get_handlers().message_handler @defer.inlineCallbacks def on_GET(self, request, room_id): - yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) - users_with_profile = yield self.state.get_current_user_in_room(room_id) + users_with_profile = yield self.message_handler.get_joined_members( + requester, room_id, + ) defer.returnValue((200, { - "joined": users_with_profile + "joined": users_with_profile, })) @@ -427,7 +434,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet): as_client_event = "raw" not in request.args filter_bytes = request.args.get("filter", None) if filter_bytes: - filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8") + filter_json = urlparse.unquote(filter_bytes[-1]).decode("UTF-8") event_filter = Filter(json.loads(filter_json)) else: event_filter = None @@ -484,13 +491,35 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet): defer.returnValue((200, content)) -class RoomEventContext(ClientV1RestServlet): +class 
RoomEventServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns( + "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$" + ) + + def __init__(self, hs): + super(RoomEventServlet, self).__init__(hs) + self.clock = hs.get_clock() + self.event_handler = hs.get_event_handler() + + @defer.inlineCallbacks + def on_GET(self, request, room_id, event_id): + requester = yield self.auth.get_user_by_req(request) + event = yield self.event_handler.get_event(requester.user, event_id) + + time_now = self.clock.time_msec() + if event: + defer.returnValue((200, serialize_event(event, time_now))) + else: + defer.returnValue((404, "Event not found.")) + + +class RoomEventContextServlet(ClientV1RestServlet): PATTERNS = client_path_patterns( "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$" ) def __init__(self, hs): - super(RoomEventContext, self).__init__(hs) + super(RoomEventContextServlet, self).__init__(hs) self.clock = hs.get_clock() self.handlers = hs.get_handlers() @@ -505,7 +534,6 @@ class RoomEventContext(ClientV1RestServlet): room_id, event_id, limit, - requester.is_guest, ) if not results: @@ -531,7 +559,7 @@ class RoomEventContext(ClientV1RestServlet): class RoomForgetRestServlet(ClientV1RestServlet): def __init__(self, hs): super(RoomForgetRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.room_member_handler = hs.get_room_member_handler() def register(self, http_server): PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget") @@ -544,7 +572,7 @@ class RoomForgetRestServlet(ClientV1RestServlet): allow_guest=False, ) - yield self.handlers.room_member_handler.forget( + yield self.room_member_handler.forget( user=requester.user, room_id=room_id, ) @@ -562,12 +590,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet): def __init__(self, hs): super(RoomMembershipRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.room_member_handler = hs.get_room_member_handler() def register(self, http_server): # 
/rooms/$roomid/[invite|join|leave] PATTERNS = ("/rooms/(?P<room_id>[^/]*)/" - "(?P<membership_action>join|invite|leave|ban|unban|kick|forget)") + "(?P<membership_action>join|invite|leave|ban|unban|kick)") register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks @@ -585,13 +613,13 @@ class RoomMembershipRestServlet(ClientV1RestServlet): try: content = parse_json_object_from_request(request) - except: + except Exception: # Turns out we used to ignore the body entirely, and some clients # cheekily send invalid bodies. content = {} if membership_action == "invite" and self._has_3pid_invite_keys(content): - yield self.handlers.room_member_handler.do_3pid_invite( + yield self.room_member_handler.do_3pid_invite( room_id, requester.user, content["medium"], @@ -613,7 +641,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): if 'reason' in content and membership_action in ['kick', 'ban']: event_content = {'reason': content['reason']} - yield self.handlers.room_member_handler.update_membership( + yield self.room_member_handler.update_membership( requester=requester, target=target, room_id=room_id, @@ -623,7 +651,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet): content=event_content, ) - defer.returnValue((200, {})) + return_value = {} + + if membership_action == "join": + return_value["room_id"] = room_id + + defer.returnValue((200, return_value)) def _has_3pid_invite_keys(self, content): for key in {"id_server", "medium", "address"}: @@ -641,6 +674,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): def __init__(self, hs): super(RoomRedactEventRestServlet, self).__init__(hs) self.handlers = hs.get_handlers() + self.event_creation_handler = hs.get_event_creation_handler() def register(self, http_server): PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") @@ -651,8 +685,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): requester = yield self.auth.get_user_by_req(request) content = 
parse_json_object_from_request(request) - msg_handler = self.handlers.message_handler - event = yield msg_handler.create_and_send_nonmember_event( + event = yield self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Redaction, @@ -686,8 +719,8 @@ class RoomTypingRestServlet(ClientV1RestServlet): def on_PUT(self, request, room_id, user_id): requester = yield self.auth.get_user_by_req(request) - room_id = urllib.unquote(room_id) - target_user = UserID.from_string(urllib.unquote(user_id)) + room_id = urlparse.unquote(room_id) + target_user = UserID.from_string(urlparse.unquote(user_id)) content = parse_json_object_from_request(request) @@ -749,8 +782,7 @@ class JoinedRoomsRestServlet(ClientV1RestServlet): def on_GET(self, request): requester = yield self.auth.get_user_by_req(request, allow_guest=True) - rooms = yield self.store.get_rooms_for_user(requester.user.to_string()) - room_ids = set(r.room_id for r in rooms) # Ensure they're unique. + room_ids = yield self.store.get_rooms_for_user(requester.user.to_string()) defer.returnValue((200, {"joined_rooms": list(room_ids)})) @@ -802,4 +834,5 @@ def register_servlets(hs, http_server): RoomTypingRestServlet(hs).register(http_server) SearchRestServlet(hs).register(http_server) JoinedRoomsRestServlet(hs).register(http_server) - RoomEventContext(hs).register(http_server) + RoomEventServlet(hs).register(http_server) + RoomEventContextServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 03141c623c..c43b30b73a 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -28,7 +28,10 @@ class VoipRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req( + request, + self.hs.config.turn_allow_guests + ) turnUris = self.hs.config.turn_uris turnSecret = 
self.hs.config.turn_shared_secret diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 20e765f48f..77434937ff 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -15,12 +15,13 @@ """This module contains base REST classes for constructing client v1 servlets. """ - -from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX +import logging import re -import logging +from twisted.internet import defer +from synapse.api.errors import InteractiveAuthIncompleteError +from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX logger = logging.getLogger(__name__) @@ -47,3 +48,47 @@ def client_v2_patterns(path_regex, releases=(0,), new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release) patterns.append(re.compile("^" + new_prefix + path_regex)) return patterns + + +def set_timeline_upper_limit(filter_json, filter_timeline_limit): + if filter_timeline_limit < 0: + return # no upper limits + timeline = filter_json.get('room', {}).get('timeline', {}) + if 'limit' in timeline: + filter_json['room']['timeline']["limit"] = min( + filter_json['room']['timeline']['limit'], + filter_timeline_limit) + + +def interactive_auth_handler(orig): + """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors + + Takes a on_POST method which returns a deferred (errcode, body) response + and adds exception handling to turn a InteractiveAuthIncompleteError into + a 401 response. + + Normal usage is: + + @interactive_auth_handler + @defer.inlineCallbacks + def on_POST(self, request): + # ... 
+ yield self.auth_handler.check_auth + """ + def wrapped(*args, **kwargs): + res = defer.maybeDeferred(orig, *args, **kwargs) + res.addErrback(_catch_incomplete_interactive_auth) + return res + return wrapped + + +def _catch_incomplete_interactive_auth(f): + """helper for interactive_auth_handler + + Catches InteractiveAuthIncompleteErrors and turns them into 401 responses + + Args: + f (failure.Failure): + """ + f.trap(InteractiveAuthIncompleteError) + return 401, f.value.result diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 398e7f5eb0..30523995af 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,27 +13,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import logging from twisted.internet import defer +from synapse.api.auth import has_access_token from synapse.api.constants import LoginType -from synapse.api.errors import LoginError, SynapseError, Codes -from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import ( + RestServlet, assert_params_in_request, + parse_json_object_from_request, +) from synapse.util.async import run_on_reactor - -from ._base import client_v2_patterns - -import logging - +from synapse.util.msisdn import phone_number_to_msisdn +from synapse.util.threepids import check_3pid_allowed +from ._base import client_v2_patterns, interactive_auth_handler logger = logging.getLogger(__name__) -class PasswordRequestTokenRestServlet(RestServlet): +class EmailPasswordRequestTokenRestServlet(RestServlet): PATTERNS = client_v2_patterns("/account/password/email/requestToken$") def __init__(self, hs): - super(PasswordRequestTokenRestServlet, self).__init__() + super(EmailPasswordRequestTokenRestServlet, self).__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler @@ -40,14 +44,14 @@ class PasswordRequestTokenRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) - required = ['id_server', 'client_secret', 'email', 'send_attempt'] - absent = [] - for k in required: - if k not in body: - absent.append(k) + assert_params_in_request(body, [ + 'id_server', 'client_secret', 'email', 'send_attempt' + ]) - if absent: - raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + if not check_3pid_allowed(self.hs, "email", body['email']): + raise SynapseError( + 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + ) existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( 'email', body['email'] @@ -60,6 +64,42 @@ class PasswordRequestTokenRestServlet(RestServlet): defer.returnValue((200, ret)) 
+class MsisdnPasswordRequestTokenRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/account/password/msisdn/requestToken$") + + def __init__(self, hs): + super(MsisdnPasswordRequestTokenRestServlet, self).__init__() + self.hs = hs + self.datastore = self.hs.get_datastore() + self.identity_handler = hs.get_handlers().identity_handler + + @defer.inlineCallbacks + def on_POST(self, request): + body = parse_json_object_from_request(request) + + assert_params_in_request(body, [ + 'id_server', 'client_secret', + 'country', 'phone_number', 'send_attempt', + ]) + + msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) + + if not check_3pid_allowed(self.hs, "msisdn", msisdn): + raise SynapseError( + 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + ) + + existingUid = yield self.datastore.get_user_id_by_threepid( + 'msisdn', msisdn + ) + + if existingUid is None: + raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND) + + ret = yield self.identity_handler.requestMsisdnToken(**body) + defer.returnValue((200, ret)) + + class PasswordRestServlet(RestServlet): PATTERNS = client_v2_patterns("/account/password$") @@ -68,55 +108,62 @@ class PasswordRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() + self.datastore = self.hs.get_datastore() + self._set_password_handler = hs.get_set_password_handler() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): - yield run_on_reactor() - body = parse_json_object_from_request(request) - authed, result, params, _ = yield self.auth_handler.check_auth([ - [LoginType.PASSWORD], - [LoginType.EMAIL_IDENTITY] - ], body, self.hs.get_ip_from_request(request)) - - if not authed: - defer.returnValue((401, result)) - - user_id = None - requester = None - - if LoginType.PASSWORD in result: - # if using password, they should also be logged in + # there are two possibilities here. 
Either the user does not have an + # access token, and needs to do a password reset; or they have one and + # need to validate their identity. + # + # In the first case, we offer a couple of means of identifying + # themselves (email and msisdn, though it's unclear if msisdn actually + # works). + # + # In the second case, we require a password to confirm their identity. + + if has_access_token(request): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - if user_id != result[LoginType.PASSWORD]: - raise LoginError(400, "", Codes.UNKNOWN) - elif LoginType.EMAIL_IDENTITY in result: - threepid = result[LoginType.EMAIL_IDENTITY] - if 'medium' not in threepid or 'address' not in threepid: - raise SynapseError(500, "Malformed threepid") - if threepid['medium'] == 'email': - # For emails, transform the address to lowercase. - # We store all email addreses as lowercase in the DB. - # (See add_threepid in synapse/handlers/auth.py) - threepid['address'] = threepid['address'].lower() - # if using email, we must know about the email they're authing with! 
- threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( - threepid['medium'], threepid['address'] + params = yield self.auth_handler.validate_user_via_ui_auth( + requester, body, self.hs.get_ip_from_request(request), ) - if not threepid_user_id: - raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) - user_id = threepid_user_id + user_id = requester.user.to_string() else: - logger.error("Auth succeeded but no known type!", result.keys()) - raise SynapseError(500, "", Codes.UNKNOWN) + requester = None + result, params, _ = yield self.auth_handler.check_auth( + [[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]], + body, self.hs.get_ip_from_request(request), + ) + + if LoginType.EMAIL_IDENTITY in result: + threepid = result[LoginType.EMAIL_IDENTITY] + if 'medium' not in threepid or 'address' not in threepid: + raise SynapseError(500, "Malformed threepid") + if threepid['medium'] == 'email': + # For emails, transform the address to lowercase. + # We store all email addreses as lowercase in the DB. + # (See add_threepid in synapse/handlers/auth.py) + threepid['address'] = threepid['address'].lower() + # if using email, we must know about the email they're authing with! 
+ threepid_user_id = yield self.datastore.get_user_id_by_threepid( + threepid['medium'], threepid['address'] + ) + if not threepid_user_id: + raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) + user_id = threepid_user_id + else: + logger.error("Auth succeeded but no known type!", result.keys()) + raise SynapseError(500, "", Codes.UNKNOWN) if 'new_password' not in params: raise SynapseError(400, "", Codes.MISSING_PARAM) new_password = params['new_password'] - yield self.auth_handler.set_password( + yield self._set_password_handler.set_password( user_id, new_password, requester ) @@ -130,52 +177,43 @@ class DeactivateAccountRestServlet(RestServlet): PATTERNS = client_v2_patterns("/account/deactivate$") def __init__(self, hs): + super(DeactivateAccountRestServlet, self).__init__() self.hs = hs - self.store = hs.get_datastore() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - super(DeactivateAccountRestServlet, self).__init__() + self._deactivate_account_handler = hs.get_deactivate_account_handler() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) - authed, result, params, _ = yield self.auth_handler.check_auth([ - [LoginType.PASSWORD], - ], body, self.hs.get_ip_from_request(request)) - - if not authed: - defer.returnValue((401, result)) - - user_id = None - requester = None - - if LoginType.PASSWORD in result: - # if using password, they should also be logged in - requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - if user_id != result[LoginType.PASSWORD]: - raise LoginError(400, "", Codes.UNKNOWN) - else: - logger.error("Auth succeeded but no known type!", result.keys()) - raise SynapseError(500, "", Codes.UNKNOWN) + requester = yield self.auth.get_user_by_req(request) - # FIXME: Theoretically there is a race here wherein user resets password - # using threepid. 
- yield self.store.user_delete_access_tokens(user_id) - yield self.store.user_delete_threepids(user_id) - yield self.store.user_set_password_hash(user_id, None) + # allow ASes to dectivate their own users + if requester.app_service: + yield self._deactivate_account_handler.deactivate_account( + requester.user.to_string() + ) + defer.returnValue((200, {})) + yield self.auth_handler.validate_user_via_ui_auth( + requester, body, self.hs.get_ip_from_request(request), + ) + yield self._deactivate_account_handler.deactivate_account( + requester.user.to_string(), + ) defer.returnValue((200, {})) -class ThreepidRequestTokenRestServlet(RestServlet): +class EmailThreepidRequestTokenRestServlet(RestServlet): PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$") def __init__(self, hs): self.hs = hs - super(ThreepidRequestTokenRestServlet, self).__init__() + super(EmailThreepidRequestTokenRestServlet, self).__init__() self.identity_handler = hs.get_handlers().identity_handler + self.datastore = self.hs.get_datastore() @defer.inlineCallbacks def on_POST(self, request): @@ -190,7 +228,12 @@ class ThreepidRequestTokenRestServlet(RestServlet): if absent: raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) - existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( + if not check_3pid_allowed(self.hs, "email", body['email']): + raise SynapseError( + 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + ) + + existingUid = yield self.datastore.get_user_id_by_threepid( 'email', body['email'] ) @@ -201,6 +244,49 @@ class ThreepidRequestTokenRestServlet(RestServlet): defer.returnValue((200, ret)) +class MsisdnThreepidRequestTokenRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/account/3pid/msisdn/requestToken$") + + def __init__(self, hs): + self.hs = hs + super(MsisdnThreepidRequestTokenRestServlet, self).__init__() + self.identity_handler = hs.get_handlers().identity_handler + self.datastore = 
self.hs.get_datastore() + + @defer.inlineCallbacks + def on_POST(self, request): + body = parse_json_object_from_request(request) + + required = [ + 'id_server', 'client_secret', + 'country', 'phone_number', 'send_attempt', + ] + absent = [] + for k in required: + if k not in body: + absent.append(k) + + if absent: + raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + + msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) + + if not check_3pid_allowed(self.hs, "msisdn", msisdn): + raise SynapseError( + 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + ) + + existingUid = yield self.datastore.get_user_id_by_threepid( + 'msisdn', msisdn + ) + + if existingUid is not None: + raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) + + ret = yield self.identity_handler.requestMsisdnToken(**body) + defer.returnValue((200, ret)) + + class ThreepidRestServlet(RestServlet): PATTERNS = client_v2_patterns("/account/3pid$") @@ -210,6 +296,7 @@ class ThreepidRestServlet(RestServlet): self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() + self.datastore = self.hs.get_datastore() @defer.inlineCallbacks def on_GET(self, request): @@ -217,7 +304,7 @@ class ThreepidRestServlet(RestServlet): requester = yield self.auth.get_user_by_req(request) - threepids = yield self.hs.get_datastore().user_get_threepids( + threepids = yield self.datastore.user_get_threepids( requester.user.to_string() ) @@ -258,7 +345,7 @@ class ThreepidRestServlet(RestServlet): if 'bind' in body and body['bind']: logger.debug( - "Binding emails %s to %s", + "Binding threepid %s to %s", threepid, user_id ) yield self.identity_handler.bind_threepid( @@ -301,10 +388,27 @@ class ThreepidDeleteRestServlet(RestServlet): defer.returnValue((200, {})) +class WhoamiRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/account/whoami$") + + def __init__(self, hs): 
+ super(WhoamiRestServlet, self).__init__() + self.auth = hs.get_auth() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request) + + defer.returnValue((200, {'user_id': requester.user.to_string()})) + + def register_servlets(hs, http_server): - PasswordRequestTokenRestServlet(hs).register(http_server) + EmailPasswordRequestTokenRestServlet(hs).register(http_server) + MsisdnPasswordRequestTokenRestServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server) - ThreepidRequestTokenRestServlet(hs).register(http_server) + EmailThreepidRequestTokenRestServlet(hs).register(http_server) + MsisdnThreepidRequestTokenRestServlet(hs).register(http_server) ThreepidRestServlet(hs).register(http_server) ThreepidDeleteRestServlet(hs).register(http_server) + WhoamiRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index b16079cece..0e0a187efd 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -16,7 +16,7 @@ from ._base import client_v2_patterns from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, SynapseError from twisted.internet import defer @@ -82,6 +82,13 @@ class RoomAccountDataServlet(RestServlet): body = parse_json_object_from_request(request) + if account_data_type == "m.fully_read": + raise SynapseError( + 405, + "Cannot set m.fully_read through this API." 
+ " Use /rooms/!roomId:server.name/read_markers" + ) + max_id = yield self.store.add_account_data_to_room( user_id, room_id, account_data_type, body ) diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index a1feaf3d54..35d58b367a 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -17,15 +17,15 @@ import logging from twisted.internet import defer -from synapse.api import constants, errors +from synapse.api import errors from synapse.http import servlet -from ._base import client_v2_patterns +from ._base import client_v2_patterns, interactive_auth_handler logger = logging.getLogger(__name__) class DevicesRestServlet(servlet.RestServlet): - PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False) + PATTERNS = client_v2_patterns("/devices$", v2_alpha=False) def __init__(self, hs): """ @@ -46,9 +46,53 @@ class DevicesRestServlet(servlet.RestServlet): defer.returnValue((200, {"devices": devices})) +class DeleteDevicesRestServlet(servlet.RestServlet): + """ + API for bulk deletion of devices. Accepts a JSON object with a devices + key which lists the device_ids to delete. Requires user interactive auth. 
+ """ + PATTERNS = client_v2_patterns("/delete_devices", v2_alpha=False) + + def __init__(self, hs): + super(DeleteDevicesRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + self.auth_handler = hs.get_auth_handler() + + @interactive_auth_handler + @defer.inlineCallbacks + def on_POST(self, request): + requester = yield self.auth.get_user_by_req(request) + + try: + body = servlet.parse_json_object_from_request(request) + except errors.SynapseError as e: + if e.errcode == errors.Codes.NOT_JSON: + # deal with older clients which didn't pass a J*DELETESON dict + # the same as those that pass an empty dict + body = {} + else: + raise e + + if 'devices' not in body: + raise errors.SynapseError( + 400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM + ) + + yield self.auth_handler.validate_user_via_ui_auth( + requester, body, self.hs.get_ip_from_request(request), + ) + + yield self.device_handler.delete_devices( + requester.user.to_string(), + body['devices'], + ) + defer.returnValue((200, {})) + + class DeviceRestServlet(servlet.RestServlet): - PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", - releases=[], v2_alpha=False) + PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False) def __init__(self, hs): """ @@ -70,8 +114,11 @@ class DeviceRestServlet(servlet.RestServlet): ) defer.returnValue((200, device)) + @interactive_auth_handler @defer.inlineCallbacks def on_DELETE(self, request, device_id): + requester = yield self.auth.get_user_by_req(request) + try: body = servlet.parse_json_object_from_request(request) @@ -83,17 +130,12 @@ class DeviceRestServlet(servlet.RestServlet): else: raise - authed, result, params, _ = yield self.auth_handler.check_auth([ - [constants.LoginType.PASSWORD], - ], body, self.hs.get_ip_from_request(request)) - - if not authed: - defer.returnValue((401, result)) + yield self.auth_handler.validate_user_via_ui_auth( + requester, 
body, self.hs.get_ip_from_request(request), + ) - requester = yield self.auth.get_user_by_req(request) yield self.device_handler.delete_device( - requester.user.to_string(), - device_id, + requester.user.to_string(), device_id, ) defer.returnValue((200, {})) @@ -111,5 +153,6 @@ class DeviceRestServlet(servlet.RestServlet): def register_servlets(hs, http_server): + DeleteDevicesRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server) DeviceRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index b4084fec62..1b9dc4528d 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -20,6 +20,7 @@ from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import UserID from ._base import client_v2_patterns +from ._base import set_timeline_upper_limit import logging @@ -49,7 +50,7 @@ class GetFilterRestServlet(RestServlet): try: filter_id = int(filter_id) - except: + except Exception: raise SynapseError(400, "Invalid filter_id") try: @@ -85,6 +86,11 @@ class CreateFilterRestServlet(RestServlet): raise AuthError(403, "Can only create filters for local users") content = parse_json_object_from_request(request) + set_timeline_upper_limit( + content, + self.hs.config.filter_timeline_limit + ) + filter_id = yield self.filtering.add_user_filter( user_localpart=target_user.localpart, user_filter=content, diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py new file mode 100644 index 0000000000..3bb1ec2af6 --- /dev/null +++ b/synapse/rest/client/v2_alpha/groups.py @@ -0,0 +1,786 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.types import GroupID + +from ._base import client_v2_patterns + +import logging + +logger = logging.getLogger(__name__) + + +class GroupServlet(RestServlet): + """Get the group profile + """ + PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/profile$") + + def __init__(self, hs): + super(GroupServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + group_description = yield self.groups_handler.get_group_profile( + group_id, + requester_user_id, + ) + + defer.returnValue((200, group_description)) + + @defer.inlineCallbacks + def on_POST(self, request, group_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + yield self.groups_handler.update_group_profile( + group_id, requester_user_id, content, + ) + + defer.returnValue((200, {})) + + +class GroupSummaryServlet(RestServlet): + """Get the full group summary + """ + PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/summary$") + + def __init__(self, hs): + super(GroupSummaryServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = 
hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + get_group_summary = yield self.groups_handler.get_group_summary( + group_id, + requester_user_id, + ) + + defer.returnValue((200, get_group_summary)) + + +class GroupSummaryRoomsCatServlet(RestServlet): + """Update/delete a rooms entry in the summary. + + Matches both: + - /groups/:group/summary/rooms/:room_id + - /groups/:group/summary/categories/:category/rooms/:room_id + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/summary" + "(/categories/(?P<category_id>[^/]+))?" + "/rooms/(?P<room_id>[^/]*)$" + ) + + def __init__(self, hs): + super(GroupSummaryRoomsCatServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id, category_id, room_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + resp = yield self.groups_handler.update_group_summary_room( + group_id, requester_user_id, + room_id=room_id, + category_id=category_id, + content=content, + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_DELETE(self, request, group_id, category_id, room_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + resp = yield self.groups_handler.delete_group_summary_room( + group_id, requester_user_id, + room_id=room_id, + category_id=category_id, + ) + + defer.returnValue((200, resp)) + + +class GroupCategoryServlet(RestServlet): + """Get/add/update/delete a group category + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$" + ) + + def __init__(self, hs): + 
super(GroupCategoryServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id, category_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + category = yield self.groups_handler.get_group_category( + group_id, requester_user_id, + category_id=category_id, + ) + + defer.returnValue((200, category)) + + @defer.inlineCallbacks + def on_PUT(self, request, group_id, category_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + resp = yield self.groups_handler.update_group_category( + group_id, requester_user_id, + category_id=category_id, + content=content, + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_DELETE(self, request, group_id, category_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + resp = yield self.groups_handler.delete_group_category( + group_id, requester_user_id, + category_id=category_id, + ) + + defer.returnValue((200, resp)) + + +class GroupCategoriesServlet(RestServlet): + """Get all group categories + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/categories/$" + ) + + def __init__(self, hs): + super(GroupCategoriesServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + category = yield self.groups_handler.get_group_categories( + group_id, requester_user_id, + ) + + defer.returnValue((200, category)) + + +class GroupRoleServlet(RestServlet): + 
"""Get/add/update/delete a group role + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$" + ) + + def __init__(self, hs): + super(GroupRoleServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id, role_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + category = yield self.groups_handler.get_group_role( + group_id, requester_user_id, + role_id=role_id, + ) + + defer.returnValue((200, category)) + + @defer.inlineCallbacks + def on_PUT(self, request, group_id, role_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + resp = yield self.groups_handler.update_group_role( + group_id, requester_user_id, + role_id=role_id, + content=content, + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_DELETE(self, request, group_id, role_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + resp = yield self.groups_handler.delete_group_role( + group_id, requester_user_id, + role_id=role_id, + ) + + defer.returnValue((200, resp)) + + +class GroupRolesServlet(RestServlet): + """Get all group roles + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/roles/$" + ) + + def __init__(self, hs): + super(GroupRolesServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + category = yield self.groups_handler.get_group_roles( + group_id, 
requester_user_id, + ) + + defer.returnValue((200, category)) + + +class GroupSummaryUsersRoleServlet(RestServlet): + """Update/delete a user's entry in the summary. + + Matches both: + - /groups/:group/summary/users/:room_id + - /groups/:group/summary/roles/:role/users/:user_id + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/summary" + "(/roles/(?P<role_id>[^/]+))?" + "/users/(?P<user_id>[^/]*)$" + ) + + def __init__(self, hs): + super(GroupSummaryUsersRoleServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id, role_id, user_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + resp = yield self.groups_handler.update_group_summary_user( + group_id, requester_user_id, + user_id=user_id, + role_id=role_id, + content=content, + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_DELETE(self, request, group_id, role_id, user_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + resp = yield self.groups_handler.delete_group_summary_user( + group_id, requester_user_id, + user_id=user_id, + role_id=role_id, + ) + + defer.returnValue((200, resp)) + + +class GroupRoomServlet(RestServlet): + """Get all rooms in a group + """ + PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/rooms$") + + def __init__(self, hs): + super(GroupRoomServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + result = yield 
self.groups_handler.get_rooms_in_group(group_id, requester_user_id) + + defer.returnValue((200, result)) + + +class GroupUsersServlet(RestServlet): + """Get all users in a group + """ + PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/users$") + + def __init__(self, hs): + super(GroupUsersServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id) + + defer.returnValue((200, result)) + + +class GroupInvitedUsersServlet(RestServlet): + """Get users invited to a group + """ + PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/invited_users$") + + def __init__(self, hs): + super(GroupInvitedUsersServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + result = yield self.groups_handler.get_invited_users_in_group( + group_id, + requester_user_id, + ) + + defer.returnValue((200, result)) + + +class GroupSettingJoinPolicyServlet(RestServlet): + """Set group join policy + """ + PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$") + + def __init__(self, hs): + super(GroupSettingJoinPolicyServlet, self).__init__() + self.auth = hs.get_auth() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + + result = 
yield self.groups_handler.set_group_join_policy( + group_id, + requester_user_id, + content, + ) + + defer.returnValue((200, result)) + + +class GroupCreateServlet(RestServlet): + """Create a group + """ + PATTERNS = client_v2_patterns("/create_group$") + + def __init__(self, hs): + super(GroupCreateServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + self.server_name = hs.hostname + + @defer.inlineCallbacks + def on_POST(self, request): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + # TODO: Create group on remote server + content = parse_json_object_from_request(request) + localpart = content.pop("localpart") + group_id = GroupID(localpart, self.server_name).to_string() + + result = yield self.groups_handler.create_group( + group_id, + requester_user_id, + content, + ) + + defer.returnValue((200, result)) + + +class GroupAdminRoomsServlet(RestServlet): + """Add a room to the group + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$" + ) + + def __init__(self, hs): + super(GroupAdminRoomsServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id, room_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + result = yield self.groups_handler.add_room_to_group( + group_id, requester_user_id, room_id, content, + ) + + defer.returnValue((200, result)) + + @defer.inlineCallbacks + def on_DELETE(self, request, group_id, room_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + result = yield self.groups_handler.remove_room_from_group( + group_id, requester_user_id, 
room_id, + ) + + defer.returnValue((200, result)) + + +class GroupAdminRoomsConfigServlet(RestServlet): + """Update the config of a room in a group + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)" + "/config/(?P<config_key>[^/]*)$" + ) + + def __init__(self, hs): + super(GroupAdminRoomsConfigServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id, room_id, config_key): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + result = yield self.groups_handler.update_room_in_group( + group_id, requester_user_id, room_id, config_key, content, + ) + + defer.returnValue((200, result)) + + +class GroupAdminUsersInviteServlet(RestServlet): + """Invite a user to the group + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$" + ) + + def __init__(self, hs): + super(GroupAdminUsersInviteServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + self.store = hs.get_datastore() + self.is_mine_id = hs.is_mine_id + + @defer.inlineCallbacks + def on_PUT(self, request, group_id, user_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + config = content.get("config", {}) + result = yield self.groups_handler.invite( + group_id, user_id, requester_user_id, config, + ) + + defer.returnValue((200, result)) + + +class GroupAdminUsersKickServlet(RestServlet): + """Kick a user from the group + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$" + ) + + def __init__(self, hs): + 
super(GroupAdminUsersKickServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id, user_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + result = yield self.groups_handler.remove_user_from_group( + group_id, user_id, requester_user_id, content, + ) + + defer.returnValue((200, result)) + + +class GroupSelfLeaveServlet(RestServlet): + """Leave a joined group + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/self/leave$" + ) + + def __init__(self, hs): + super(GroupSelfLeaveServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + result = yield self.groups_handler.remove_user_from_group( + group_id, requester_user_id, requester_user_id, content, + ) + + defer.returnValue((200, result)) + + +class GroupSelfJoinServlet(RestServlet): + """Attempt to join a group, or knock + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/self/join$" + ) + + def __init__(self, hs): + super(GroupSelfJoinServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + result = yield self.groups_handler.join_group( + group_id, requester_user_id, content, + ) + + defer.returnValue((200, result)) + + 
+class GroupSelfAcceptInviteServlet(RestServlet): + """Accept a group invite + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/self/accept_invite$" + ) + + def __init__(self, hs): + super(GroupSelfAcceptInviteServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + result = yield self.groups_handler.accept_invite( + group_id, requester_user_id, content, + ) + + defer.returnValue((200, result)) + + +class GroupSelfUpdatePublicityServlet(RestServlet): + """Update whether we publicise a users membership of a group + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/self/update_publicity$" + ) + + def __init__(self, hs): + super(GroupSelfUpdatePublicityServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + publicise = content["publicise"] + yield self.store.update_group_publicity( + group_id, requester_user_id, publicise, + ) + + defer.returnValue((200, {})) + + +class PublicisedGroupsForUserServlet(RestServlet): + """Get the list of groups a user is advertising + """ + PATTERNS = client_v2_patterns( + "/publicised_groups/(?P<user_id>[^/]*)$" + ) + + def __init__(self, hs): + super(PublicisedGroupsForUserServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.store = hs.get_datastore() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, user_id): + yield 
self.auth.get_user_by_req(request, allow_guest=True) + + result = yield self.groups_handler.get_publicised_groups_for_user( + user_id + ) + + defer.returnValue((200, result)) + + +class PublicisedGroupsForUsersServlet(RestServlet): + """Get the list of groups a user is advertising + """ + PATTERNS = client_v2_patterns( + "/publicised_groups$" + ) + + def __init__(self, hs): + super(PublicisedGroupsForUsersServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.store = hs.get_datastore() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_POST(self, request): + yield self.auth.get_user_by_req(request, allow_guest=True) + + content = parse_json_object_from_request(request) + user_ids = content["user_ids"] + + result = yield self.groups_handler.bulk_get_publicised_groups( + user_ids + ) + + defer.returnValue((200, result)) + + +class GroupsForUserServlet(RestServlet): + """Get all groups the logged in user is joined to + """ + PATTERNS = client_v2_patterns( + "/joined_groups$" + ) + + def __init__(self, hs): + super(GroupsForUserServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + result = yield self.groups_handler.get_joined_groups(requester_user_id) + + defer.returnValue((200, result)) + + +def register_servlets(hs, http_server): + GroupServlet(hs).register(http_server) + GroupSummaryServlet(hs).register(http_server) + GroupInvitedUsersServlet(hs).register(http_server) + GroupUsersServlet(hs).register(http_server) + GroupRoomServlet(hs).register(http_server) + GroupSettingJoinPolicyServlet(hs).register(http_server) + GroupCreateServlet(hs).register(http_server) + GroupAdminRoomsServlet(hs).register(http_server) + 
GroupAdminRoomsConfigServlet(hs).register(http_server) + GroupAdminUsersInviteServlet(hs).register(http_server) + GroupAdminUsersKickServlet(hs).register(http_server) + GroupSelfLeaveServlet(hs).register(http_server) + GroupSelfJoinServlet(hs).register(http_server) + GroupSelfAcceptInviteServlet(hs).register(http_server) + GroupsForUserServlet(hs).register(http_server) + GroupCategoryServlet(hs).register(http_server) + GroupCategoriesServlet(hs).register(http_server) + GroupSummaryRoomsCatServlet(hs).register(http_server) + GroupRoleServlet(hs).register(http_server) + GroupRolesServlet(hs).register(http_server) + GroupSelfUpdatePublicityServlet(hs).register(http_server) + GroupSummaryUsersRoleServlet(hs).register(http_server) + PublicisedGroupsForUserServlet(hs).register(http_server) + PublicisedGroupsForUsersServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 6a3cfe84f8..3cc87ea63f 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -53,8 +53,7 @@ class KeyUploadServlet(RestServlet): }, } """ - PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$", - releases=()) + PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") def __init__(self, hs): """ @@ -128,10 +127,7 @@ class KeyQueryServlet(RestServlet): } } } } } } """ - PATTERNS = client_v2_patterns( - "/keys/query$", - releases=() - ) + PATTERNS = client_v2_patterns("/keys/query$") def __init__(self, hs): """ @@ -160,10 +156,7 @@ class KeyChangesServlet(RestServlet): 200 OK { "changed": ["@foo:example.com"] } """ - PATTERNS = client_v2_patterns( - "/keys/changes$", - releases=() - ) + PATTERNS = client_v2_patterns("/keys/changes$") def __init__(self, hs): """ @@ -188,13 +181,11 @@ class KeyChangesServlet(RestServlet): user_id = requester.user.to_string() - changed = yield self.device_handler.get_user_ids_changed( + results = yield 
self.device_handler.get_user_ids_changed( user_id, from_token, ) - defer.returnValue((200, { - "changed": list(changed), - })) + defer.returnValue((200, results)) class OneTimeKeyServlet(RestServlet): @@ -215,10 +206,7 @@ class OneTimeKeyServlet(RestServlet): } } } } """ - PATTERNS = client_v2_patterns( - "/keys/claim$", - releases=() - ) + PATTERNS = client_v2_patterns("/keys/claim$") def __init__(self, hs): super(OneTimeKeyServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index fd2a3d69d4..ec170109fe 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) class NotificationsServlet(RestServlet): - PATTERNS = client_v2_patterns("/notifications$", releases=()) + PATTERNS = client_v2_patterns("/notifications$") def __init__(self, hs): super(NotificationsServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py new file mode 100644 index 0000000000..2f8784fe06 --- /dev/null +++ b/synapse/rest/client/v2_alpha/read_marker.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from twisted.internet import defer + +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from ._base import client_v2_patterns + +import logging + + +logger = logging.getLogger(__name__) + + +class ReadMarkerRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$") + + def __init__(self, hs): + super(ReadMarkerRestServlet, self).__init__() + self.auth = hs.get_auth() + self.receipts_handler = hs.get_receipts_handler() + self.read_marker_handler = hs.get_read_marker_handler() + self.presence_handler = hs.get_presence_handler() + + @defer.inlineCallbacks + def on_POST(self, request, room_id): + requester = yield self.auth.get_user_by_req(request) + + yield self.presence_handler.bump_presence_active_time(requester.user) + + body = parse_json_object_from_request(request) + + read_event_id = body.get("m.read", None) + if read_event_id: + yield self.receipts_handler.received_client_receipt( + room_id, + "m.read", + user_id=requester.user.to_string(), + event_id=read_event_id + ) + + read_marker_event_id = body.get("m.fully_read", None) + if read_marker_event_id: + yield self.read_marker_handler.received_client_read_marker( + room_id, + user_id=requester.user.to_string(), + event_id=read_marker_event_id + ) + + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + ReadMarkerRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index ccca5a12d5..5cab00aea9 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015 - 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -16,17 +17,25 @@ from twisted.internet import defer import synapse +import synapse.types from synapse.api.auth import get_access_token_from_request, has_access_token from synapse.api.constants import LoginType from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError -from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.servlet import ( + RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string +) +from synapse.util.msisdn import phone_number_to_msisdn +from synapse.util.threepids import check_3pid_allowed -from ._base import client_v2_patterns +from ._base import client_v2_patterns, interactive_auth_handler import logging import hmac from hashlib import sha1 from synapse.util.async import run_on_reactor +from synapse.util.ratelimitutils import FederationRateLimiter + +from six import string_types # We ought to be using hmac.compare_digest() but on older pythons it doesn't @@ -43,7 +52,7 @@ else: logger = logging.getLogger(__name__) -class RegisterRequestTokenRestServlet(RestServlet): +class EmailRegisterRequestTokenRestServlet(RestServlet): PATTERNS = client_v2_patterns("/register/email/requestToken$") def __init__(self, hs): @@ -51,7 +60,7 @@ class RegisterRequestTokenRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(RegisterRequestTokenRestServlet, self).__init__() + super(EmailRegisterRequestTokenRestServlet, self).__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler @@ -59,14 +68,14 @@ class RegisterRequestTokenRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) - required = ['id_server', 'client_secret', 'email', 'send_attempt'] - absent = [] - for k in required: - if k not in body: - absent.append(k) + assert_params_in_request(body, [ + 'id_server', 'client_secret', 'email', 'send_attempt' + ]) - if len(absent) > 0: - raise SynapseError(400, "Missing params: 
%r" % absent, Codes.MISSING_PARAM) + if not check_3pid_allowed(self.hs, "email", body['email']): + raise SynapseError( + 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + ) existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( 'email', body['email'] @@ -79,6 +88,86 @@ class RegisterRequestTokenRestServlet(RestServlet): defer.returnValue((200, ret)) +class MsisdnRegisterRequestTokenRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/register/msisdn/requestToken$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(MsisdnRegisterRequestTokenRestServlet, self).__init__() + self.hs = hs + self.identity_handler = hs.get_handlers().identity_handler + + @defer.inlineCallbacks + def on_POST(self, request): + body = parse_json_object_from_request(request) + + assert_params_in_request(body, [ + 'id_server', 'client_secret', + 'country', 'phone_number', + 'send_attempt', + ]) + + msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) + + if not check_3pid_allowed(self.hs, "msisdn", msisdn): + raise SynapseError( + 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + ) + + existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( + 'msisdn', msisdn + ) + + if existingUid is not None: + raise SynapseError( + 400, "Phone number is already in use", Codes.THREEPID_IN_USE + ) + + ret = yield self.identity_handler.requestMsisdnToken(**body) + defer.returnValue((200, ret)) + + +class UsernameAvailabilityRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/register/available") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(UsernameAvailabilityRestServlet, self).__init__() + self.hs = hs + self.registration_handler = hs.get_handlers().registration_handler + self.ratelimiter = FederationRateLimiter( + hs.get_clock(), + # Time window of 2s + window_size=2000, + # Artificially delay requests if rate 
> sleep_limit/window_size + sleep_limit=1, + # Amount of artificial delay to apply + sleep_msec=1000, + # Error with 429 if more than reject_limit requests are queued + reject_limit=1, + # Allow 1 request at a time + concurrent_requests=1, + ) + + @defer.inlineCallbacks + def on_GET(self, request): + ip = self.hs.get_ip_from_request(request) + with self.ratelimiter.ratelimit(ip) as wait_deferred: + yield wait_deferred + + username = parse_string(request, "username", required=True) + + yield self.registration_handler.check_username(username) + + defer.returnValue((200, {"available": True})) + + class RegisterRestServlet(RestServlet): PATTERNS = client_v2_patterns("/register$") @@ -95,9 +184,11 @@ class RegisterRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_handlers().registration_handler self.identity_handler = hs.get_handlers().identity_handler + self.room_member_handler = hs.get_room_member_handler() self.device_handler = hs.get_device_handler() self.macaroon_gen = hs.get_macaroon_generator() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): yield run_on_reactor() @@ -121,14 +212,14 @@ class RegisterRestServlet(RestServlet): # in sessions. Pull out the username/password provided to us. desired_password = None if 'password' in body: - if (not isinstance(body['password'], basestring) or + if (not isinstance(body['password'], string_types) or len(body['password']) > 512): raise SynapseError(400, "Invalid password") desired_password = body["password"] desired_username = None if 'username' in body: - if (not isinstance(body['username'], basestring) or + if (not isinstance(body['username'], string_types) or len(body['username']) > 512): raise SynapseError(400, "Invalid username") desired_username = body['username'] @@ -146,15 +237,30 @@ class RegisterRestServlet(RestServlet): # 'user' key not 'username'). Since this is a new addition, we'll # fallback to 'username' if they gave one. 
desired_username = body.get("user", desired_username) + + # XXX we should check that desired_username is valid. Currently + # we give appservices carte blanche for any insanity in mxids, + # because the IRC bridges rely on being able to register stupid + # IDs. + access_token = get_access_token_from_request(request) - if isinstance(desired_username, basestring): + if isinstance(desired_username, string_types): result = yield self._do_appservice_registration( desired_username, access_token, body ) defer.returnValue((200, result)) # we throw for non 200 responses return + # for either shared secret or regular registration, downcase the + # provided username before attempting to register it. This should mean + # that people who try to register with upper-case in their usernames + # don't get a nasty surprise. (Note that we treat username + # case-insenstively in login, so they are free to carry on imagining + # that their username is CrAzYh4cKeR if that keeps them happy) + if desired_username is not None: + desired_username = desired_username.lower() + # == Shared Secret Registration == (e.g. create new user scripts) if 'mac' in body: # FIXME: Should we really be determining if this is shared secret @@ -200,32 +306,86 @@ class RegisterRestServlet(RestServlet): assigned_user_id=registered_user_id, ) + # Only give msisdn flows if the x_show_msisdn flag is given: + # this is a hack to work around the fact that clients were shipped + # that use fallback registration if they see any flows that they don't + # recognise, which means we break registration for these clients if we + # advertise msisdn flows. Once usage of Riot iOS <=0.3.9 and Riot + # Android <=0.6.9 have fallen below an acceptable threshold, this + # parameter should go away and we should always advertise msisdn flows. 
+ show_msisdn = False + if 'x_show_msisdn' in body and body['x_show_msisdn']: + show_msisdn = True + + # FIXME: need a better error than "no auth flow found" for scenarios + # where we required 3PID for registration but the user didn't give one + require_email = 'email' in self.hs.config.registrations_require_3pid + require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid + + flows = [] if self.hs.config.enable_registration_captcha: - flows = [ - [LoginType.RECAPTCHA], - [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA] - ] + # only support 3PIDless registration if no 3PIDs are required + if not require_email and not require_msisdn: + flows.extend([[LoginType.RECAPTCHA]]) + # only support the email-only flow if we don't require MSISDN 3PIDs + if not require_msisdn: + flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]]) + + if show_msisdn: + # only support the MSISDN-only flow if we don't require email 3PIDs + if not require_email: + flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]]) + # always let users provide both MSISDN & email + flows.extend([ + [LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA], + ]) else: - flows = [ - [LoginType.DUMMY], - [LoginType.EMAIL_IDENTITY] - ] - - authed, auth_result, params, session_id = yield self.auth_handler.check_auth( + # only support 3PIDless registration if no 3PIDs are required + if not require_email and not require_msisdn: + flows.extend([[LoginType.DUMMY]]) + # only support the email-only flow if we don't require MSISDN 3PIDs + if not require_msisdn: + flows.extend([[LoginType.EMAIL_IDENTITY]]) + + if show_msisdn: + # only support the MSISDN-only flow if we don't require email 3PIDs + if not require_email or require_msisdn: + flows.extend([[LoginType.MSISDN]]) + # always let users provide both MSISDN & email + flows.extend([ + [LoginType.MSISDN, LoginType.EMAIL_IDENTITY] + ]) + + auth_result, params, session_id = yield self.auth_handler.check_auth( flows, body, 
self.hs.get_ip_from_request(request) ) - if not authed: - defer.returnValue((401, auth_result)) - return + # Check that we're not trying to register a denied 3pid. + # + # the user-facing checks will probably already have happened in + # /register/email/requestToken when we requested a 3pid, but that's not + # guaranteed. + + if auth_result: + for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: + if login_type in auth_result: + medium = auth_result[login_type]['medium'] + address = auth_result[login_type]['address'] + + if not check_3pid_allowed(self.hs, medium, address): + raise SynapseError( + 403, "Third party identifier is not allowed", + Codes.THREEPID_DENIED, + ) if registered_user_id is not None: logger.info( "Already registered user ID %r for this session", registered_user_id ) - # don't re-register the email address + # don't re-register the threepids add_email = False + add_msisdn = False else: # NB: This may be from the auth handler and NOT from the POST if 'password' not in params: @@ -236,6 +396,9 @@ class RegisterRestServlet(RestServlet): new_password = params.get("password", None) guest_access_token = params.get("guest_access_token", None) + if desired_username is not None: + desired_username = desired_username.lower() + (registered_user_id, _) = yield self.registration_handler.register( localpart=desired_username, password=new_password, @@ -250,6 +413,7 @@ class RegisterRestServlet(RestServlet): ) add_email = True + add_msisdn = True return_dict = yield self._create_registration_details( registered_user_id, params @@ -262,6 +426,13 @@ class RegisterRestServlet(RestServlet): params.get("bind_email") ) + if add_msisdn and auth_result and LoginType.MSISDN in auth_result: + threepid = auth_result[LoginType.MSISDN] + yield self._register_msisdn_threepid( + registered_user_id, threepid, return_dict["access_token"], + params.get("bind_msisdn") + ) + defer.returnValue((200, return_dict)) def on_OPTIONS(self, _): @@ -278,15 +449,24 @@ class 
RegisterRestServlet(RestServlet): def _do_shared_secret_registration(self, username, password, body): if not self.hs.config.registration_shared_secret: raise SynapseError(400, "Shared secret registration is not enabled") + if not username: + raise SynapseError( + 400, "username must be specified", errcode=Codes.BAD_JSON, + ) - user = username.encode("utf-8") + # use the username from the original request rather than the + # downcased one in `username` for the mac calculation + user = body["username"].encode("utf-8") # str() because otherwise hmac complains that 'unicode' does not # have the buffer interface got_mac = str(body["mac"]) + # FIXME this is different to the /v1/register endpoint, which + # includes the password and admin flag in the hashed text. Why are + # these different? want_mac = hmac.new( - key=self.hs.config.registration_shared_secret, + key=self.hs.config.registration_shared_secret.encode(), msg=user, digestmod=sha1, ).hexdigest() @@ -323,8 +503,9 @@ class RegisterRestServlet(RestServlet): """ reqd = ('medium', 'address', 'validated_at') if any(x not in threepid for x in reqd): + # This will only happen if the ID server returns a malformed response logger.info("Can't add incomplete 3pid") - defer.returnValue() + return yield self.auth_handler.add_threepid( user_id, @@ -372,6 +553,43 @@ class RegisterRestServlet(RestServlet): logger.info("bind_email not specified: not binding email") @defer.inlineCallbacks + def _register_msisdn_threepid(self, user_id, threepid, token, bind_msisdn): + """Add a phone number as a 3pid identifier + + Also optionally binds msisdn to the given user_id on the identity server + + Args: + user_id (str): id of user + threepid (object): m.login.msisdn auth response + token (str): access_token for the user + bind_email (bool): true if the client requested the email to be + bound at the identity server + Returns: + defer.Deferred: + """ + reqd = ('medium', 'address', 'validated_at') + if any(x not in threepid for x in reqd): 
+ # This will only happen if the ID server returns a malformed response + logger.info("Can't add incomplete 3pid") + defer.returnValue() + + yield self.auth_handler.add_threepid( + user_id, + threepid['medium'], + threepid['address'], + threepid['validated_at'], + ) + + if bind_msisdn: + logger.info("bind_msisdn specified: binding") + logger.debug("Binding msisdn %s to %s", threepid, user_id) + yield self.identity_handler.bind_threepid( + threepid['threepid_creds'], user_id + ) + else: + logger.info("bind_msisdn not specified: not binding msisdn") + + @defer.inlineCallbacks def _create_registration_details(self, user_id, params): """Complete registration of newly-registered user @@ -380,25 +598,28 @@ class RegisterRestServlet(RestServlet): Args: (str) user_id: full canonical @user:id (object) params: registration parameters, from which we pull - device_id and initial_device_name + device_id, initial_device_name and inhibit_login Returns: defer.Deferred: (object) dictionary for response from /register """ - device_id = yield self._register_device(user_id, params) + result = { + "user_id": user_id, + "home_server": self.hs.hostname, + } + if not params.get("inhibit_login", False): + device_id = yield self._register_device(user_id, params) - access_token = ( - yield self.auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id, - initial_display_name=params.get("initial_device_display_name") + access_token = ( + yield self.auth_handler.get_access_token_for_user_id( + user_id, device_id=device_id, + ) ) - ) - defer.returnValue({ - "user_id": user_id, - "access_token": access_token, - "home_server": self.hs.hostname, - "device_id": device_id, - }) + result.update({ + "access_token": access_token, + "device_id": device_id, + }) + defer.returnValue(result) def _register_device(self, user_id, params): """Register a device for a user. @@ -433,7 +654,7 @@ class RegisterRestServlet(RestServlet): # we have nowhere to store it. 
device_id = synapse.api.auth.GUEST_DEVICE_ID initial_display_name = params.get("initial_device_display_name") - self.device_handler.check_device_registered( + yield self.device_handler.check_device_registered( user_id, device_id, initial_display_name ) @@ -449,5 +670,7 @@ class RegisterRestServlet(RestServlet): def register_servlets(hs, http_server): - RegisterRequestTokenRestServlet(hs).register(http_server) + EmailRegisterRequestTokenRestServlet(hs).register(http_server) + MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) + UsernameAvailabilityRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index d607bd2970..90bdb1db15 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class SendToDeviceRestServlet(servlet.RestServlet): PATTERNS = client_v2_patterns( "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$", - releases=[], v2_alpha=False + v2_alpha=False ) def __init__(self, hs): diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index b3d8001638..eb91c0b293 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.http.servlet import ( RestServlet, parse_string, parse_integer, parse_boolean ) +from synapse.handlers.presence import format_user_presence_state from synapse.handlers.sync import SyncConfig from synapse.types import StreamToken from synapse.events.utils import ( @@ -27,12 +28,12 @@ from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION from synapse.api.errors import SynapseError from synapse.api.constants import PresenceState from ._base import client_v2_patterns +from ._base import set_timeline_upper_limit -import copy import 
itertools import logging -import ujson as json +import simplejson as json logger = logging.getLogger(__name__) @@ -78,6 +79,7 @@ class SyncRestServlet(RestServlet): def __init__(self, hs): super(SyncRestServlet, self).__init__() + self.hs = hs self.auth = hs.get_auth() self.sync_handler = hs.get_sync_handler() self.clock = hs.get_clock() @@ -108,7 +110,7 @@ class SyncRestServlet(RestServlet): filter_id = parse_string(request, "filter", default=None) full_state = parse_boolean(request, "full_state", default=False) - logger.info( + logger.debug( "/sync: user=%r, timeout=%r, since=%r," " set_presence=%r, filter_id=%r, device_id=%r" % ( user, timeout, since, set_presence, filter_id, device_id @@ -121,7 +123,9 @@ class SyncRestServlet(RestServlet): if filter_id.startswith('{'): try: filter_object = json.loads(filter_id) - except: + set_timeline_upper_limit(filter_object, + self.hs.config.filter_timeline_limit) + except Exception: raise SynapseError(400, "Invalid filter JSON") self.filtering.check_valid_filter(filter_object) filter = FilterCollection(filter_object) @@ -160,27 +164,35 @@ class SyncRestServlet(RestServlet): ) time_now = self.clock.time_msec() + response_content = self.encode_response( + time_now, sync_result, requester.access_token_id, filter + ) + + defer.returnValue((200, response_content)) - joined = self.encode_joined( - sync_result.joined, time_now, requester.access_token_id, filter.event_fields + @staticmethod + def encode_response(time_now, sync_result, access_token_id, filter): + joined = SyncRestServlet.encode_joined( + sync_result.joined, time_now, access_token_id, filter.event_fields ) - invited = self.encode_invited( - sync_result.invited, time_now, requester.access_token_id + invited = SyncRestServlet.encode_invited( + sync_result.invited, time_now, access_token_id, ) - archived = self.encode_archived( - sync_result.archived, time_now, requester.access_token_id, + archived = SyncRestServlet.encode_archived( + sync_result.archived, time_now, 
access_token_id, filter.event_fields, ) - response_content = { + return { "account_data": {"events": sync_result.account_data}, "to_device": {"events": sync_result.to_device}, "device_lists": { - "changed": list(sync_result.device_lists), + "changed": list(sync_result.device_lists.changed), + "left": list(sync_result.device_lists.left), }, - "presence": self.encode_presence( + "presence": SyncRestServlet.encode_presence( sync_result.presence, time_now ), "rooms": { @@ -188,20 +200,32 @@ class SyncRestServlet(RestServlet): "invite": invited, "leave": archived, }, + "groups": { + "join": sync_result.groups.join, + "invite": sync_result.groups.invite, + "leave": sync_result.groups.leave, + }, + "device_one_time_keys_count": sync_result.device_one_time_keys_count, "next_batch": sync_result.next_batch.to_string(), } - defer.returnValue((200, response_content)) - - def encode_presence(self, events, time_now): - formatted = [] - for event in events: - event = copy.deepcopy(event) - event['sender'] = event['content'].pop('user_id') - formatted.append(event) - return {"events": formatted} + @staticmethod + def encode_presence(events, time_now): + return { + "events": [ + { + "type": "m.presence", + "sender": event.user_id, + "content": format_user_presence_state( + event, time_now, include_user_id=False + ), + } + for event in events + ] + } - def encode_joined(self, rooms, time_now, token_id, event_fields): + @staticmethod + def encode_joined(rooms, time_now, token_id, event_fields): """ Encode the joined rooms in a sync result @@ -220,13 +244,14 @@ class SyncRestServlet(RestServlet): """ joined = {} for room in rooms: - joined[room.room_id] = self.encode_room( + joined[room.room_id] = SyncRestServlet.encode_room( room, time_now, token_id, only_fields=event_fields ) return joined - def encode_invited(self, rooms, time_now, token_id): + @staticmethod + def encode_invited(rooms, time_now, token_id): """ Encode the invited rooms in a sync result @@ -247,6 +272,7 @@ class 
SyncRestServlet(RestServlet): invite = serialize_event( room.invite, time_now, token_id=token_id, event_format=format_event_for_client_v2_without_room_id, + is_invite=True, ) unsigned = dict(invite.get("unsigned", {})) invite["unsigned"] = unsigned @@ -258,7 +284,8 @@ class SyncRestServlet(RestServlet): return invited - def encode_archived(self, rooms, time_now, token_id, event_fields): + @staticmethod + def encode_archived(rooms, time_now, token_id, event_fields): """ Encode the archived rooms in a sync result @@ -277,7 +304,7 @@ class SyncRestServlet(RestServlet): """ joined = {} for room in rooms: - joined[room.room_id] = self.encode_room( + joined[room.room_id] = SyncRestServlet.encode_room( room, time_now, token_id, joined=False, only_fields=event_fields ) diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 31f94bc6e9..6773b9ba60 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) class ThirdPartyProtocolsServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=()) + PATTERNS = client_v2_patterns("/thirdparty/protocols") def __init__(self, hs): super(ThirdPartyProtocolsServlet, self).__init__() @@ -36,15 +36,14 @@ class ThirdPartyProtocolsServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): - yield self.auth.get_user_by_req(request) + yield self.auth.get_user_by_req(request, allow_guest=True) protocols = yield self.appservice_handler.get_3pe_protocols() defer.returnValue((200, protocols)) class ThirdPartyProtocolServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$", - releases=()) + PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$") def __init__(self, hs): super(ThirdPartyProtocolServlet, self).__init__() @@ -54,7 +53,7 @@ class ThirdPartyProtocolServlet(RestServlet): 
@defer.inlineCallbacks def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request) + yield self.auth.get_user_by_req(request, allow_guest=True) protocols = yield self.appservice_handler.get_3pe_protocols( only_protocol=protocol, @@ -66,8 +65,7 @@ class ThirdPartyProtocolServlet(RestServlet): class ThirdPartyUserServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$", - releases=()) + PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$") def __init__(self, hs): super(ThirdPartyUserServlet, self).__init__() @@ -77,7 +75,7 @@ class ThirdPartyUserServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request) + yield self.auth.get_user_by_req(request, allow_guest=True) fields = request.args fields.pop("access_token", None) @@ -90,8 +88,7 @@ class ThirdPartyUserServlet(RestServlet): class ThirdPartyLocationServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$", - releases=()) + PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$") def __init__(self, hs): super(ThirdPartyLocationServlet, self).__init__() @@ -101,7 +98,7 @@ class ThirdPartyLocationServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request) + yield self.auth.get_user_by_req(request, allow_guest=True) fields = request.args fields.pop("access_token", None) diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py new file mode 100644 index 0000000000..2d4a43c353 --- /dev/null +++ b/synapse/rest/client/v2_alpha/user_directory.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.api.errors import SynapseError +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from ._base import client_v2_patterns + +logger = logging.getLogger(__name__) + + +class UserDirectorySearchRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/user_directory/search$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(UserDirectorySearchRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.user_directory_handler = hs.get_user_directory_handler() + + @defer.inlineCallbacks + def on_POST(self, request): + """Searches for users in directory + + Returns: + dict of the form:: + + { + "limited": <bool>, # whether there were more results or not + "results": [ # Ordered by best match first + { + "user_id": <user_id>, + "display_name": <display_name>, + "avatar_url": <avatar_url> + } + ] + } + """ + requester = yield self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + + body = parse_json_object_from_request(request) + + limit = body.get("limit", 10) + limit = min(limit, 50) + + try: + search_term = body["search_term"] + except Exception: + raise SynapseError(400, "`search_term` is required field") + + results = yield self.user_directory_handler.search_users( + user_id, search_term, limit, + ) + + defer.returnValue((200, results)) + + +def register_servlets(hs, http_server): + UserDirectorySearchRestServlet(hs).register(http_server) diff --git 
a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index e984ea47db..2ecb15deee 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -30,6 +30,7 @@ class VersionsRestServlet(RestServlet): "r0.0.1", "r0.1.0", "r0.2.0", + "r0.3.0", ] }) diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index ff95269ba8..be68d9a096 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -84,12 +84,11 @@ class LocalKey(Resource): } old_verify_keys = {} - for key in self.config.old_signing_keys: - key_id = "%s:%s" % (key.alg, key.version) + for key_id, key in self.config.old_signing_keys.items(): verify_key_bytes = key.encode() old_verify_keys[key_id] = { u"key": encode_base64(verify_key_bytes), - u"expired_ts": key.expired, + u"expired_ts": key.expired_ts, } tls_fingerprints = self.config.tls_fingerprints diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 9fe2013657..17e6079cba 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -93,6 +93,7 @@ class RemoteKey(Resource): self.store = hs.get_datastore() self.version_string = hs.version_string self.clock = hs.get_clock() + self.federation_domain_whitelist = hs.config.federation_domain_whitelist def render_GET(self, request): self.async_render_GET(request) @@ -137,6 +138,13 @@ class RemoteKey(Resource): logger.info("Handling query for keys %r", query) store_queries = [] for server_name, key_ids in query.items(): + if ( + self.federation_domain_whitelist is not None and + server_name not in self.federation_domain_whitelist + ): + logger.debug("Federation denied with %s", server_name) + continue + if not key_ids: key_ids = (None,) for key_id in key_ids: @@ -213,7 +221,7 @@ class RemoteKey(Resource): ) except KeyLookupError as e: logger.info("Failed to fetch key: %s", e) - except: + 
except Exception: logger.exception("Failed to get key for %r", server_name) yield self.query_keys( request, query, query_remote_on_cache_miss=False diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index b9600f2167..c0d2f06855 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -17,6 +17,7 @@ from synapse.http.server import respond_with_json, finish_request from synapse.api.errors import ( cs_error, Codes, SynapseError ) +from synapse.util import logcontext from twisted.internet import defer from twisted.protocols.basic import FileSender @@ -27,7 +28,7 @@ import os import logging import urllib -import urlparse +from six.moves.urllib import parse as urlparse logger = logging.getLogger(__name__) @@ -44,7 +45,7 @@ def parse_media_id(request): except UnicodeDecodeError: pass return server_name, media_id, file_name - except: + except Exception: raise SynapseError( 404, "Invalid media id token %r" % (request.postpath,), @@ -69,42 +70,133 @@ def respond_with_file(request, media_type, file_path, logger.debug("Responding with %r", file_path) if os.path.isfile(file_path): - request.setHeader(b"Content-Type", media_type.encode("UTF-8")) - if upload_name: - if is_ascii(upload_name): - request.setHeader( - b"Content-Disposition", - b"inline; filename=%s" % ( - urllib.quote(upload_name.encode("utf-8")), - ), - ) - else: - request.setHeader( - b"Content-Disposition", - b"inline; filename*=utf-8''%s" % ( - urllib.quote(upload_name.encode("utf-8")), - ), - ) - - # cache for at least a day. - # XXX: we might want to turn this off for data we don't want to - # recommend caching as it's sensitive or private - or at least - # select private. 
don't bother setting Expires as all our - # clients are smart enough to be happy with Cache-Control - request.setHeader( - b"Cache-Control", b"public,max-age=86400,s-maxage=86400" - ) if file_size is None: stat = os.stat(file_path) file_size = stat.st_size - request.setHeader( - b"Content-Length", b"%d" % (file_size,) - ) + add_file_headers(request, media_type, file_size, upload_name) with open(file_path, "rb") as f: - yield FileSender().beginFileTransfer(f, request) + yield logcontext.make_deferred_yieldable( + FileSender().beginFileTransfer(f, request) + ) finish_request(request) else: respond_404(request) + + +def add_file_headers(request, media_type, file_size, upload_name): + """Adds the correct response headers in preparation for responding with the + media. + + Args: + request (twisted.web.http.Request) + media_type (str): The media/content type. + file_size (int): Size in bytes of the media, if known. + upload_name (str): The name of the requested file, if any. + """ + request.setHeader(b"Content-Type", media_type.encode("UTF-8")) + if upload_name: + if is_ascii(upload_name): + request.setHeader( + b"Content-Disposition", + b"inline; filename=%s" % ( + urllib.quote(upload_name.encode("utf-8")), + ), + ) + else: + request.setHeader( + b"Content-Disposition", + b"inline; filename*=utf-8''%s" % ( + urllib.quote(upload_name.encode("utf-8")), + ), + ) + + # cache for at least a day. + # XXX: we might want to turn this off for data we don't want to + # recommend caching as it's sensitive or private - or at least + # select private. don't bother setting Expires as all our + # clients are smart enough to be happy with Cache-Control + request.setHeader( + b"Cache-Control", b"public,max-age=86400,s-maxage=86400" + ) + + request.setHeader( + b"Content-Length", b"%d" % (file_size,) + ) + + +@defer.inlineCallbacks +def respond_with_responder(request, responder, media_type, file_size, upload_name=None): + """Responds to the request with given responder. 
If responder is None then + returns 404. + + Args: + request (twisted.web.http.Request) + responder (Responder|None) + media_type (str): The media/content type. + file_size (int|None): Size in bytes of the media. If not known it should be None + upload_name (str|None): The name of the requested file, if any. + """ + if not responder: + respond_404(request) + return + + logger.debug("Responding to media request with responder %s") + add_file_headers(request, media_type, file_size, upload_name) + with responder: + yield responder.write_to_consumer(request) + finish_request(request) + + +class Responder(object): + """Represents a response that can be streamed to the requester. + + Responder is a context manager which *must* be used, so that any resources + held can be cleaned up. + """ + def write_to_consumer(self, consumer): + """Stream response into consumer + + Args: + consumer (IConsumer) + + Returns: + Deferred: Resolves once the response has finished being written + """ + pass + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + +class FileInfo(object): + """Details about a requested/uploaded file. + + Attributes: + server_name (str): The server name where the media originated from, + or None if local. + file_id (str): The local ID of the file. For local files this is the + same as the media_id + url_cache (bool): If the file is for the url preview cache + thumbnail (bool): Whether the file is a thumbnail or not. + thumbnail_width (int) + thumbnail_height (int) + thumbnail_method (str) + thumbnail_type (str): Content type of thumbnail, e.g. 
image/png + """ + def __init__(self, server_name, file_id, url_cache=False, + thumbnail=False, thumbnail_width=None, thumbnail_height=None, + thumbnail_method=None, thumbnail_type=None): + self.server_name = server_name + self.file_id = file_id + self.url_cache = url_cache + self.thumbnail = thumbnail + self.thumbnail_width = thumbnail_width + self.thumbnail_height = thumbnail_height + self.thumbnail_method = thumbnail_method + self.thumbnail_type = thumbnail_type diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index dfb87ffd15..fe7e17596f 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -12,8 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import synapse.http.servlet -from ._base import parse_media_id, respond_with_file, respond_404 +from ._base import parse_media_id, respond_404 from twisted.web.resource import Resource from synapse.http.server import request_handler, set_cors_headers @@ -31,12 +32,12 @@ class DownloadResource(Resource): def __init__(self, hs, media_repo): Resource.__init__(self) - self.filepaths = media_repo.filepaths self.media_repo = media_repo self.server_name = hs.hostname - self.store = hs.get_datastore() - self.version_string = hs.version_string + + # Both of these are expected by @request_handler() self.clock = hs.get_clock() + self.version_string = hs.version_string def render_GET(self, request): self._async_render_GET(request) @@ -56,43 +57,16 @@ class DownloadResource(Resource): ) server_name, media_id, name = parse_media_id(request) if server_name == self.server_name: - yield self._respond_local_file(request, media_id, name) + yield self.media_repo.get_local_media(request, media_id, name) else: - yield self._respond_remote_file( - request, server_name, media_id, name - ) - - 
@defer.inlineCallbacks - def _respond_local_file(self, request, media_id, name): - media_info = yield self.store.get_local_media(media_id) - if not media_info: - respond_404(request) - return - - media_type = media_info["media_type"] - media_length = media_info["media_length"] - upload_name = name if name else media_info["upload_name"] - file_path = self.filepaths.local_media_filepath(media_id) - - yield respond_with_file( - request, media_type, file_path, media_length, - upload_name=upload_name, - ) - - @defer.inlineCallbacks - def _respond_remote_file(self, request, server_name, media_id, name): - media_info = yield self.media_repo.get_remote_media(server_name, media_id) - - media_type = media_info["media_type"] - media_length = media_info["media_length"] - filesystem_id = media_info["filesystem_id"] - upload_name = name if name else media_info["upload_name"] - - file_path = self.filepaths.remote_media_filepath( - server_name, filesystem_id - ) - - yield respond_with_file( - request, media_type, file_path, media_length, - upload_name=upload_name, - ) + allow_remote = synapse.http.servlet.parse_boolean( + request, "allow_remote", default=True) + if not allow_remote: + logger.info( + "Rejecting request for remote media %s/%s due to allow_remote", + server_name, media_id, + ) + respond_404(request) + return + + yield self.media_repo.get_remote_media(request, server_name, media_id, name) diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index 0137458f71..d5164e47e0 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -14,60 +14,200 @@ # limitations under the License. 
import os +import re +import functools + +NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") + + +def _wrap_in_base_path(func): + """Takes a function that returns a relative path and turns it into an + absolute path based on the location of the primary media store + """ + @functools.wraps(func) + def _wrapped(self, *args, **kwargs): + path = func(self, *args, **kwargs) + return os.path.join(self.base_path, path) + + return _wrapped class MediaFilePaths(object): + """Describes where files are stored on disk. - def __init__(self, base_path): - self.base_path = base_path + Most of the functions have a `*_rel` variant which returns a file path that + is relative to the base media store path. This is mainly used when we want + to write to the backup media store (when one is configured) + """ - def default_thumbnail(self, default_top_level, default_sub_type, width, - height, content_type, method): + def __init__(self, primary_base_path): + self.base_path = primary_base_path + + def default_thumbnail_rel(self, default_top_level, default_sub_type, width, + height, content_type, method): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % ( width, height, top_level_type, sub_type, method ) return os.path.join( - self.base_path, "default_thumbnails", default_top_level, + "default_thumbnails", default_top_level, default_sub_type, file_name ) - def local_media_filepath(self, media_id): + default_thumbnail = _wrap_in_base_path(default_thumbnail_rel) + + def local_media_filepath_rel(self, media_id): return os.path.join( - self.base_path, "local_content", + "local_content", media_id[0:2], media_id[2:4], media_id[4:] ) - def local_media_thumbnail(self, media_id, width, height, content_type, - method): + local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) + + def local_media_thumbnail_rel(self, media_id, width, height, content_type, + method): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % ( width, 
height, top_level_type, sub_type, method ) return os.path.join( - self.base_path, "local_thumbnails", + "local_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], file_name ) - def remote_media_filepath(self, server_name, file_id): + local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel) + + def remote_media_filepath_rel(self, server_name, file_id): return os.path.join( - self.base_path, "remote_content", server_name, + "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:] ) - def remote_media_thumbnail(self, server_name, file_id, width, height, - content_type, method): + remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) + + def remote_media_thumbnail_rel(self, server_name, file_id, width, height, + content_type, method): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) return os.path.join( - self.base_path, "remote_thumbnail", server_name, + "remote_thumbnail", server_name, file_id[0:2], file_id[2:4], file_id[4:], file_name ) + remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel) + def remote_media_thumbnail_dir(self, server_name, file_id): return os.path.join( self.base_path, "remote_thumbnail", server_name, file_id[0:2], file_id[2:4], file_id[4:], ) + + def url_cache_filepath_rel(self, media_id): + if NEW_FORMAT_ID_RE.match(media_id): + # Media id is of the form <DATE><RANDOM_STRING> + # E.g.: 2017-09-28-fsdRDt24DS234dsf + return os.path.join( + "url_cache", + media_id[:10], media_id[11:] + ) + else: + return os.path.join( + "url_cache", + media_id[0:2], media_id[2:4], media_id[4:], + ) + + url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) + + def url_cache_filepath_dirs_to_delete(self, media_id): + "The dirs to try and remove if we delete the media_id file" + if NEW_FORMAT_ID_RE.match(media_id): + return [ + os.path.join( + self.base_path, "url_cache", + media_id[:10], + ), + ] + else: + return [ + 
os.path.join( + self.base_path, "url_cache", + media_id[0:2], media_id[2:4], + ), + os.path.join( + self.base_path, "url_cache", + media_id[0:2], + ), + ] + + def url_cache_thumbnail_rel(self, media_id, width, height, content_type, + method): + # Media id is of the form <DATE><RANDOM_STRING> + # E.g.: 2017-09-28-fsdRDt24DS234dsf + + top_level_type, sub_type = content_type.split("/") + file_name = "%i-%i-%s-%s-%s" % ( + width, height, top_level_type, sub_type, method + ) + + if NEW_FORMAT_ID_RE.match(media_id): + return os.path.join( + "url_cache_thumbnails", + media_id[:10], media_id[11:], + file_name + ) + else: + return os.path.join( + "url_cache_thumbnails", + media_id[0:2], media_id[2:4], media_id[4:], + file_name + ) + + url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) + + def url_cache_thumbnail_directory(self, media_id): + # Media id is of the form <DATE><RANDOM_STRING> + # E.g.: 2017-09-28-fsdRDt24DS234dsf + + if NEW_FORMAT_ID_RE.match(media_id): + return os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[:10], media_id[11:], + ) + else: + return os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[0:2], media_id[2:4], media_id[4:], + ) + + def url_cache_thumbnail_dirs_to_delete(self, media_id): + "The dirs to try and remove if we delete the media_id thumbnails" + # Media id is of the form <DATE><RANDOM_STRING> + # E.g.: 2017-09-28-fsdRDt24DS234dsf + if NEW_FORMAT_ID_RE.match(media_id): + return [ + os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[:10], media_id[11:], + ), + os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[:10], + ), + ] + else: + return [ + os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[0:2], media_id[2:4], media_id[4:], + ), + os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[0:2], media_id[2:4], + ), + os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[0:2], + ), + ] diff --git 
a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 3cbeca503c..9800ce7581 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,26 +14,32 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer, threads +import twisted.internet.error +import twisted.web.http +from twisted.web.resource import Resource + +from ._base import respond_404, FileInfo, respond_with_responder from .upload_resource import UploadResource from .download_resource import DownloadResource from .thumbnail_resource import ThumbnailResource from .identicon_resource import IdenticonResource from .preview_url_resource import PreviewUrlResource from .filepath import MediaFilePaths - -from twisted.web.resource import Resource - from .thumbnailer import Thumbnailer +from .storage_provider import StorageProviderWrapper +from .media_storage import MediaStorage from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.util.stringutils import random_string -from synapse.api.errors import SynapseError - -from twisted.internet import defer, threads +from synapse.api.errors import ( + SynapseError, HttpResponseException, NotFoundError, FederationDeniedError, +) from synapse.util.async import Linearizer from synapse.util.stringutils import is_ascii -from synapse.util.logcontext import preserve_context_over_fn +from synapse.util.logcontext import make_deferred_yieldable +from synapse.util.retryutils import NotRetryingDestination import os import errno @@ -40,12 +47,12 @@ import shutil import cgi import logging -import urlparse +from six.moves.urllib import parse 
as urlparse logger = logging.getLogger(__name__) -UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000 +UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000 class MediaRepository(object): @@ -57,46 +64,90 @@ class MediaRepository(object): self.store = hs.get_datastore() self.max_upload_size = hs.config.max_upload_size self.max_image_pixels = hs.config.max_image_pixels - self.filepaths = MediaFilePaths(hs.config.media_store_path) + + self.primary_base_path = hs.config.media_store_path + self.filepaths = MediaFilePaths(self.primary_base_path) + self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.thumbnail_requirements = hs.config.thumbnail_requirements self.remote_media_linearizer = Linearizer(name="media_remote") self.recently_accessed_remotes = set() + self.recently_accessed_locals = set() + + self.federation_domain_whitelist = hs.config.federation_domain_whitelist + + # List of StorageProviders where we should search for media and + # potentially upload to. + storage_providers = [] + + for clz, provider_config, wrapper_config in hs.config.media_storage_providers: + backend = clz(hs, provider_config) + provider = StorageProviderWrapper( + backend, + store_local=wrapper_config.store_local, + store_remote=wrapper_config.store_remote, + store_synchronous=wrapper_config.store_synchronous, + ) + storage_providers.append(provider) + + self.media_storage = MediaStorage( + self.primary_base_path, self.filepaths, storage_providers, + ) self.clock.looping_call( - self._update_recently_accessed_remotes, - UPDATE_RECENTLY_ACCESSED_REMOTES_TS + self._update_recently_accessed, + UPDATE_RECENTLY_ACCESSED_TS, ) @defer.inlineCallbacks - def _update_recently_accessed_remotes(self): - media = self.recently_accessed_remotes + def _update_recently_accessed(self): + remote_media = self.recently_accessed_remotes self.recently_accessed_remotes = set() + local_media = self.recently_accessed_locals + self.recently_accessed_locals = set() + yield self.store.update_cached_last_access_time( - media, 
self.clock.time_msec() + local_media, remote_media, self.clock.time_msec() ) - @staticmethod - def _makedirs(filepath): - dirname = os.path.dirname(filepath) - if not os.path.exists(dirname): - os.makedirs(dirname) + def mark_recently_accessed(self, server_name, media_id): + """Mark the given media as recently accessed. + + Args: + server_name (str|None): Origin server of media, or None if local + media_id (str): The media ID of the content + """ + if server_name: + self.recently_accessed_remotes.add((server_name, media_id)) + else: + self.recently_accessed_locals.add(media_id) @defer.inlineCallbacks def create_content(self, media_type, upload_name, content, content_length, auth_user): + """Store uploaded content for a local user and return the mxc URL + + Args: + media_type(str): The content type of the file + upload_name(str): The name of the file + content: A file like object that is the content to store + content_length(int): The length of the content + auth_user(str): The user_id of the uploader + + Returns: + Deferred[str]: The mxc url of the stored content + """ media_id = random_string(24) - fname = self.filepaths.local_media_filepath(media_id) - self._makedirs(fname) + file_info = FileInfo( + server_name=None, + file_id=media_id, + ) - # This shouldn't block for very long because the content will have - # already been uploaded at this point. 
- with open(fname, "wb") as f: - f.write(content) + fname = yield self.media_storage.store_file(content, file_info) logger.info("Stored local media in file %r", fname) @@ -108,104 +159,275 @@ class MediaRepository(object): media_length=content_length, user_id=auth_user, ) - media_info = { - "media_type": media_type, - "media_length": content_length, - } - yield self._generate_local_thumbnails(media_id, media_info) + yield self._generate_thumbnails( + None, media_id, media_id, media_type, + ) defer.returnValue("mxc://%s/%s" % (self.server_name, media_id)) @defer.inlineCallbacks - def get_remote_media(self, server_name, media_id): + def get_local_media(self, request, media_id, name): + """Responds to reqests for local media, if exists, or returns 404. + + Args: + request(twisted.web.http.Request) + media_id (str): The media ID of the content. (This is the same as + the file_id for local content.) + name (str|None): Optional name that, if specified, will be used as + the filename in the Content-Disposition header of the response. + + Returns: + Deferred: Resolves once a response has successfully been written + to request + """ + media_info = yield self.store.get_local_media(media_id) + if not media_info or media_info["quarantined_by"]: + respond_404(request) + return + + self.mark_recently_accessed(None, media_id) + + media_type = media_info["media_type"] + media_length = media_info["media_length"] + upload_name = name if name else media_info["upload_name"] + url_cache = media_info["url_cache"] + + file_info = FileInfo( + None, media_id, + url_cache=url_cache, + ) + + responder = yield self.media_storage.fetch_media(file_info) + yield respond_with_responder( + request, responder, media_type, media_length, upload_name, + ) + + @defer.inlineCallbacks + def get_remote_media(self, request, server_name, media_id, name): + """Respond to requests for remote media. 
+ + Args: + request(twisted.web.http.Request) + server_name (str): Remote server_name where the media originated. + media_id (str): The media ID of the content (as defined by the + remote server). + name (str|None): Optional name that, if specified, will be used as + the filename in the Content-Disposition header of the response. + + Returns: + Deferred: Resolves once a response has successfully been written + to request + """ + if ( + self.federation_domain_whitelist is not None and + server_name not in self.federation_domain_whitelist + ): + raise FederationDeniedError(server_name) + + self.mark_recently_accessed(server_name, media_id) + + # We linearize here to ensure that we don't try and download remote + # media multiple times concurrently + key = (server_name, media_id) + with (yield self.remote_media_linearizer.queue(key)): + responder, media_info = yield self._get_remote_media_impl( + server_name, media_id, + ) + + # We deliberately stream the file outside the lock + if responder: + media_type = media_info["media_type"] + media_length = media_info["media_length"] + upload_name = name if name else media_info["upload_name"] + yield respond_with_responder( + request, responder, media_type, media_length, upload_name, + ) + else: + respond_404(request) + + @defer.inlineCallbacks + def get_remote_media_info(self, server_name, media_id): + """Gets the media info associated with the remote file, downloading + if necessary. + + Args: + server_name (str): Remote server_name where the media originated. + media_id (str): The media ID of the content (as defined by the + remote server). 
+ + Returns: + Deferred[dict]: The media_info of the file + """ + if ( + self.federation_domain_whitelist is not None and + server_name not in self.federation_domain_whitelist + ): + raise FederationDeniedError(server_name) + + # We linearize here to ensure that we don't try and download remote + # media multiple times concurrently key = (server_name, media_id) with (yield self.remote_media_linearizer.queue(key)): - media_info = yield self._get_remote_media_impl(server_name, media_id) + responder, media_info = yield self._get_remote_media_impl( + server_name, media_id, + ) + + # Ensure we actually use the responder so that it releases resources + if responder: + with responder: + pass + defer.returnValue(media_info) @defer.inlineCallbacks def _get_remote_media_impl(self, server_name, media_id): + """Looks for media in local cache, if not there then attempt to + download from remote server. + + Args: + server_name (str): Remote server_name where the media originated. + media_id (str): The media ID of the content (as defined by the + remote server). + + Returns: + Deferred[(Responder, media_info)] + """ media_info = yield self.store.get_cached_remote_media( server_name, media_id ) - if not media_info: - media_info = yield self._download_remote_file( - server_name, media_id - ) + + # file_id is the ID we use to track the file locally. If we've already + # seen the file then reuse the existing ID, otherwise genereate a new + # one. 
+ if media_info: + file_id = media_info["filesystem_id"] else: - self.recently_accessed_remotes.add((server_name, media_id)) - yield self.store.update_cached_last_access_time( - [(server_name, media_id)], self.clock.time_msec() - ) - defer.returnValue(media_info) + file_id = random_string(24) - @defer.inlineCallbacks - def _download_remote_file(self, server_name, media_id): - file_id = random_string(24) + file_info = FileInfo(server_name, file_id) + + # If we have an entry in the DB, try and look for it + if media_info: + if media_info["quarantined_by"]: + logger.info("Media is quarantined") + raise NotFoundError() + + responder = yield self.media_storage.fetch_media(file_info) + if responder: + defer.returnValue((responder, media_info)) + + # Failed to find the file anywhere, lets download it. - fname = self.filepaths.remote_media_filepath( - server_name, file_id + media_info = yield self._download_remote_file( + server_name, media_id, file_id ) - self._makedirs(fname) - try: - with open(fname, "wb") as f: - request_path = "/".join(( - "/_matrix/media/v1/download", server_name, media_id, - )) + responder = yield self.media_storage.fetch_media(file_info) + defer.returnValue((responder, media_info)) + + @defer.inlineCallbacks + def _download_remote_file(self, server_name, media_id, file_id): + """Attempt to download the remote file from the given server name, + using the given file_id as the local id. + + Args: + server_name (str): Originating server + media_id (str): The media ID of the content (as defined by the + remote server). This is different than the file_id, which is + locally generated. 
+ file_id (str): Local file ID + + Returns: + Deferred[MediaInfo] + """ + + file_info = FileInfo( + server_name=server_name, + file_id=file_id, + ) + + with self.media_storage.store_into_file(file_info) as (f, fname, finish): + request_path = "/".join(( + "/_matrix/media/v1/download", server_name, media_id, + )) + try: + length, headers = yield self.client.get_file( + server_name, request_path, output_stream=f, + max_size=self.max_upload_size, args={ + # tell the remote server to 404 if it doesn't + # recognise the server_name, to make sure we don't + # end up with a routing loop. + "allow_remote": "false", + } + ) + except twisted.internet.error.DNSLookupError as e: + logger.warn("HTTP error fetching remote media %s/%s: %r", + server_name, media_id, e) + raise NotFoundError() + + except HttpResponseException as e: + logger.warn("HTTP error fetching remote media %s/%s: %s", + server_name, media_id, e.response) + if e.code == twisted.web.http.NOT_FOUND: + raise SynapseError.from_http_response_exception(e) + raise SynapseError(502, "Failed to fetch remote media") + + except SynapseError: + logger.exception("Failed to fetch remote media %s/%s", + server_name, media_id) + raise + except NotRetryingDestination: + logger.warn("Not retrying destination %r", server_name) + raise SynapseError(502, "Failed to fetch remote media") + except Exception: + logger.exception("Failed to fetch remote media %s/%s", + server_name, media_id) + raise SynapseError(502, "Failed to fetch remote media") + + yield finish() + + media_type = headers["Content-Type"][0] + + time_now_ms = self.clock.time_msec() + + content_disposition = headers.get("Content-Disposition", None) + if content_disposition: + _, params = cgi.parse_header(content_disposition[0],) + upload_name = None + + # First check if there is a valid UTF-8 filename + upload_name_utf8 = params.get("filename*", None) + if upload_name_utf8: + if upload_name_utf8.lower().startswith("utf-8''"): + upload_name = upload_name_utf8[7:] + + # 
If there isn't check for an ascii name. + if not upload_name: + upload_name_ascii = params.get("filename", None) + if upload_name_ascii and is_ascii(upload_name_ascii): + upload_name = upload_name_ascii + + if upload_name: + upload_name = urlparse.unquote(upload_name) try: - length, headers = yield self.client.get_file( - server_name, request_path, output_stream=f, - max_size=self.max_upload_size, - ) - except Exception as e: - logger.warn("Failed to fetch remoted media %r", e) - raise SynapseError(502, "Failed to fetch remoted media") - - media_type = headers["Content-Type"][0] - time_now_ms = self.clock.time_msec() - - content_disposition = headers.get("Content-Disposition", None) - if content_disposition: - _, params = cgi.parse_header(content_disposition[0],) - upload_name = None - - # First check if there is a valid UTF-8 filename - upload_name_utf8 = params.get("filename*", None) - if upload_name_utf8: - if upload_name_utf8.lower().startswith("utf-8''"): - upload_name = upload_name_utf8[7:] - - # If there isn't check for an ascii name. 
- if not upload_name: - upload_name_ascii = params.get("filename", None) - if upload_name_ascii and is_ascii(upload_name_ascii): - upload_name = upload_name_ascii - - if upload_name: - upload_name = urlparse.unquote(upload_name) - try: - upload_name = upload_name.decode("utf-8") - except UnicodeDecodeError: - upload_name = None - else: - upload_name = None - - logger.info("Stored remote media in file %r", fname) - - yield self.store.store_cached_remote_media( - origin=server_name, - media_id=media_id, - media_type=media_type, - time_now_ms=self.clock.time_msec(), - upload_name=upload_name, - media_length=length, - filesystem_id=file_id, - ) - except: - os.remove(fname) - raise + upload_name = upload_name.decode("utf-8") + except UnicodeDecodeError: + upload_name = None + else: + upload_name = None + + logger.info("Stored remote media in file %r", fname) + + yield self.store.store_cached_remote_media( + origin=server_name, + media_id=media_id, + media_type=media_type, + time_now_ms=self.clock.time_msec(), + upload_name=upload_name, + media_length=length, + filesystem_id=file_id, + ) media_info = { "media_type": media_type, @@ -215,8 +437,8 @@ class MediaRepository(object): "filesystem_id": file_id, } - yield self._generate_remote_thumbnails( - server_name, media_id, media_info + yield self._generate_thumbnails( + server_name, media_id, file_id, media_type, ) defer.returnValue(media_info) @@ -224,9 +446,8 @@ class MediaRepository(object): def _get_thumbnail_requirements(self, media_type): return self.thumbnail_requirements.get(media_type, ()) - def _generate_thumbnail(self, input_path, t_path, t_width, t_height, + def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type): - thumbnailer = Thumbnailer(input_path) m_width = thumbnailer.width m_height = thumbnailer.height @@ -238,69 +459,126 @@ class MediaRepository(object): return if t_method == "crop": - t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) + t_byte_source = 
thumbnailer.crop(t_width, t_height, t_type) elif t_method == "scale": - t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) + t_width, t_height = thumbnailer.aspect(t_width, t_height) + t_width = min(m_width, t_width) + t_height = min(m_height, t_height) + t_byte_source = thumbnailer.scale(t_width, t_height, t_type) else: - t_len = None + t_byte_source = None - return t_len + return t_byte_source @defer.inlineCallbacks def generate_local_exact_thumbnail(self, media_id, t_width, t_height, - t_method, t_type): - input_path = self.filepaths.local_media_filepath(media_id) - - t_path = self.filepaths.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) + t_method, t_type, url_cache): + input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo( + None, media_id, url_cache=url_cache, + )) - t_len = yield preserve_context_over_fn( - threads.deferToThread, + thumbnailer = Thumbnailer(input_path) + t_byte_source = yield make_deferred_yieldable(threads.deferToThread( self._generate_thumbnail, - input_path, t_path, t_width, t_height, t_method, t_type - ) + thumbnailer, t_width, t_height, t_method, t_type + )) + + if t_byte_source: + try: + file_info = FileInfo( + server_name=None, + file_id=media_id, + url_cache=url_cache, + thumbnail=True, + thumbnail_width=t_width, + thumbnail_height=t_height, + thumbnail_method=t_method, + thumbnail_type=t_type, + ) + + output_path = yield self.media_storage.store_file( + t_byte_source, file_info, + ) + finally: + t_byte_source.close() + + logger.info("Stored thumbnail in file %r", output_path) + + t_len = os.path.getsize(output_path) - if t_len: yield self.store.store_local_thumbnail( media_id, t_width, t_height, t_type, t_method, t_len ) - defer.returnValue(t_path) + defer.returnValue(output_path) @defer.inlineCallbacks def generate_remote_exact_thumbnail(self, server_name, file_id, media_id, t_width, t_height, t_method, t_type): - input_path = 
self.filepaths.remote_media_filepath(server_name, file_id) - - t_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) + input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo( + server_name, file_id, url_cache=False, + )) - t_len = yield preserve_context_over_fn( - threads.deferToThread, + thumbnailer = Thumbnailer(input_path) + t_byte_source = yield make_deferred_yieldable(threads.deferToThread( self._generate_thumbnail, - input_path, t_path, t_width, t_height, t_method, t_type - ) + thumbnailer, t_width, t_height, t_method, t_type + )) + + if t_byte_source: + try: + file_info = FileInfo( + server_name=server_name, + file_id=media_id, + thumbnail=True, + thumbnail_width=t_width, + thumbnail_height=t_height, + thumbnail_method=t_method, + thumbnail_type=t_type, + ) + + output_path = yield self.media_storage.store_file( + t_byte_source, file_info, + ) + finally: + t_byte_source.close() + + logger.info("Stored thumbnail in file %r", output_path) + + t_len = os.path.getsize(output_path) - if t_len: yield self.store.store_remote_media_thumbnail( server_name, media_id, file_id, t_width, t_height, t_type, t_method, t_len ) - defer.returnValue(t_path) + defer.returnValue(output_path) @defer.inlineCallbacks - def _generate_local_thumbnails(self, media_id, media_info): - media_type = media_info["media_type"] + def _generate_thumbnails(self, server_name, media_id, file_id, media_type, + url_cache=False): + """Generate and store thumbnails for an image. + + Args: + server_name (str|None): The server name if remote media, else None if local + media_id (str): The media ID of the content. 
(This is the same as + the file_id for local content) + file_id (str): Local file ID + media_type (str): The content type of the file + url_cache (bool): If we are thumbnailing images downloaded for the URL cache, + used exclusively by the url previewer + + Returns: + Deferred[dict]: Dict with "width" and "height" keys of original image + """ requirements = self._get_thumbnail_requirements(media_type) if not requirements: return - input_path = self.filepaths.local_media_filepath(media_id) + input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo( + server_name, file_id, url_cache=url_cache, + )) + thumbnailer = Thumbnailer(input_path) m_width = thumbnailer.width m_height = thumbnailer.height @@ -312,125 +590,68 @@ class MediaRepository(object): ) return - local_thumbnails = [] - - def generate_thumbnails(): - scales = set() - crops = set() - for r_width, r_height, r_method, r_type in requirements: - if r_method == "scale": - t_width, t_height = thumbnailer.aspect(r_width, r_height) - scales.add(( - min(m_width, t_width), min(m_height, t_height), r_type, - )) - elif r_method == "crop": - crops.add((r_width, r_height, r_type)) - - for t_width, t_height, t_type in scales: - t_method = "scale" - t_path = self.filepaths.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) - - local_thumbnails.append(( - media_id, t_width, t_height, t_type, t_method, t_len + # We deduplicate the thumbnail sizes by ignoring the cropped versions if + # they have the same dimensions of a scaled one. 
+ thumbnails = {} + for r_width, r_height, r_method, r_type in requirements: + if r_method == "crop": + thumbnails.setdefault((r_width, r_height, r_type), r_method) + elif r_method == "scale": + t_width, t_height = thumbnailer.aspect(r_width, r_height) + t_width = min(m_width, t_width) + t_height = min(m_height, t_height) + thumbnails[(t_width, t_height, r_type)] = r_method + + # Now we generate the thumbnails for each dimension, store it + for (t_width, t_height, t_type), t_method in thumbnails.iteritems(): + # Generate the thumbnail + if t_method == "crop": + t_byte_source = yield make_deferred_yieldable(threads.deferToThread( + thumbnailer.crop, + t_width, t_height, t_type, )) - - for t_width, t_height, t_type in crops: - if (t_width, t_height, t_type) in scales: - # If the aspect ratio of the cropped thumbnail matches a purely - # scaled one then there is no point in calculating a separate - # thumbnail. - continue - t_method = "crop" - t_path = self.filepaths.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) - local_thumbnails.append(( - media_id, t_width, t_height, t_type, t_method, t_len + elif t_method == "scale": + t_byte_source = yield make_deferred_yieldable(threads.deferToThread( + thumbnailer.scale, + t_width, t_height, t_type, )) + else: + logger.error("Unrecognized method: %r", t_method) + continue + + if not t_byte_source: + continue + + try: + file_info = FileInfo( + server_name=server_name, + file_id=file_id, + thumbnail=True, + thumbnail_width=t_width, + thumbnail_height=t_height, + thumbnail_method=t_method, + thumbnail_type=t_type, + url_cache=url_cache, + ) - yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails) - - for l in local_thumbnails: - yield self.store.store_local_thumbnail(*l) - - defer.returnValue({ - "width": m_width, - "height": m_height, - }) - - @defer.inlineCallbacks - def 
_generate_remote_thumbnails(self, server_name, media_id, media_info): - media_type = media_info["media_type"] - file_id = media_info["filesystem_id"] - requirements = self._get_thumbnail_requirements(media_type) - if not requirements: - return - - remote_thumbnails = [] + output_path = yield self.media_storage.store_file( + t_byte_source, file_info, + ) + finally: + t_byte_source.close() - input_path = self.filepaths.remote_media_filepath(server_name, file_id) - thumbnailer = Thumbnailer(input_path) - m_width = thumbnailer.width - m_height = thumbnailer.height + t_len = os.path.getsize(output_path) - def generate_thumbnails(): - if m_width * m_height >= self.max_image_pixels: - logger.info( - "Image too large to thumbnail %r x %r > %r", - m_width, m_height, self.max_image_pixels - ) - return - - scales = set() - crops = set() - for r_width, r_height, r_method, r_type in requirements: - if r_method == "scale": - t_width, t_height = thumbnailer.aspect(r_width, r_height) - scales.add(( - min(m_width, t_width), min(m_height, t_height), r_type, - )) - elif r_method == "crop": - crops.add((r_width, r_height, r_type)) - - for t_width, t_height, t_type in scales: - t_method = "scale" - t_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) - remote_thumbnails.append([ + # Write to database + if server_name: + yield self.store.store_remote_media_thumbnail( server_name, media_id, file_id, t_width, t_height, t_type, t_method, t_len - ]) - - for t_width, t_height, t_type in crops: - if (t_width, t_height, t_type) in scales: - # If the aspect ratio of the cropped thumbnail matches a purely - # scaled one then there is no point in calculating a separate - # thumbnail. 
- continue - t_method = "crop" - t_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, t_width, t_height, t_type, t_method ) - self._makedirs(t_path) - t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) - remote_thumbnails.append([ - server_name, media_id, file_id, - t_width, t_height, t_type, t_method, t_len - ]) - - yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails) - - for r in remote_thumbnails: - yield self.store.store_remote_media_thumbnail(*r) + else: + yield self.store.store_local_thumbnail( + media_id, t_width, t_height, t_type, t_method, t_len + ) defer.returnValue({ "width": m_width, @@ -451,6 +672,8 @@ class MediaRepository(object): logger.info("Deleting: %r", key) + # TODO: Should we delete from the backup store + with (yield self.remote_media_linearizer.queue(key)): full_path = self.filepaths.remote_media_filepath(origin, file_id) try: @@ -525,7 +748,11 @@ class MediaRepositoryResource(Resource): self.putChild("upload", UploadResource(hs, media_repo)) self.putChild("download", DownloadResource(hs, media_repo)) - self.putChild("thumbnail", ThumbnailResource(hs, media_repo)) + self.putChild("thumbnail", ThumbnailResource( + hs, media_repo, media_repo.media_storage, + )) self.putChild("identicon", IdenticonResource()) if hs.config.url_preview_enabled: - self.putChild("preview_url", PreviewUrlResource(hs, media_repo)) + self.putChild("preview_url", PreviewUrlResource( + hs, media_repo, media_repo.media_storage, + )) diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py new file mode 100644 index 0000000000..d23fe10b07 --- /dev/null +++ b/synapse/rest/media/v1/media_storage.py @@ -0,0 +1,263 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vecotr Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer, threads +from twisted.protocols.basic import FileSender + +import six + +from ._base import Responder + +from synapse.util.file_consumer import BackgroundFileConsumer +from synapse.util.logcontext import make_deferred_yieldable + +import contextlib +import os +import logging +import shutil +import sys + + +logger = logging.getLogger(__name__) + + +class MediaStorage(object): + """Responsible for storing/fetching files from local sources. + + Args: + local_media_directory (str): Base path where we store media on disk + filepaths (MediaFilePaths) + storage_providers ([StorageProvider]): List of StorageProvider that are + used to fetch and store files. 
+ """ + + def __init__(self, local_media_directory, filepaths, storage_providers): + self.local_media_directory = local_media_directory + self.filepaths = filepaths + self.storage_providers = storage_providers + + @defer.inlineCallbacks + def store_file(self, source, file_info): + """Write `source` to the on disk media store, and also any other + configured storage providers + + Args: + source: A file like object that should be written + file_info (FileInfo): Info about the file to store + + Returns: + Deferred[str]: the file path written to in the primary media store + """ + + with self.store_into_file(file_info) as (f, fname, finish_cb): + # Write to the main repository + yield make_deferred_yieldable(threads.deferToThread( + _write_file_synchronously, source, f, + )) + yield finish_cb() + + defer.returnValue(fname) + + @contextlib.contextmanager + def store_into_file(self, file_info): + """Context manager used to get a file like object to write into, as + described by file_info. + + Actually yields a 3-tuple (file, fname, finish_cb), where file is a file + like object that can be written to, fname is the absolute path of file + on disk, and finish_cb is a function that returns a Deferred. + + fname can be used to read the contents from after upload, e.g. to + generate thumbnails. + + finish_cb must be called and waited on after the file has been + successfully been written to. Should not be called if there was an + error. + + Args: + file_info (FileInfo): Info about the file to store + + Example: + + with media_storage.store_into_file(info) as (f, fname, finish_cb): + # .. write into f ... 
+ yield finish_cb() + """ + + path = self._file_info_to_path(file_info) + fname = os.path.join(self.local_media_directory, path) + + dirname = os.path.dirname(fname) + if not os.path.exists(dirname): + os.makedirs(dirname) + + finished_called = [False] + + @defer.inlineCallbacks + def finish(): + for provider in self.storage_providers: + yield provider.store_file(path, file_info) + + finished_called[0] = True + + try: + with open(fname, "wb") as f: + yield f, fname, finish + except Exception: + t, v, tb = sys.exc_info() + try: + os.remove(fname) + except Exception: + pass + six.reraise(t, v, tb) + + if not finished_called: + raise Exception("Finished callback not called") + + @defer.inlineCallbacks + def fetch_media(self, file_info): + """Attempts to fetch media described by file_info from the local cache + and configured storage providers. + + Args: + file_info (FileInfo) + + Returns: + Deferred[Responder|None]: Returns a Responder if the file was found, + otherwise None. + """ + + path = self._file_info_to_path(file_info) + local_path = os.path.join(self.local_media_directory, path) + if os.path.exists(local_path): + defer.returnValue(FileResponder(open(local_path, "rb"))) + + for provider in self.storage_providers: + res = yield provider.fetch(path, file_info) + if res: + defer.returnValue(res) + + defer.returnValue(None) + + @defer.inlineCallbacks + def ensure_media_is_in_local_cache(self, file_info): + """Ensures that the given file is in the local cache. Attempts to + download it from storage providers if it isn't. 
+ + Args: + file_info (FileInfo) + + Returns: + Deferred[str]: Full path to local file + """ + path = self._file_info_to_path(file_info) + local_path = os.path.join(self.local_media_directory, path) + if os.path.exists(local_path): + defer.returnValue(local_path) + + dirname = os.path.dirname(local_path) + if not os.path.exists(dirname): + os.makedirs(dirname) + + for provider in self.storage_providers: + res = yield provider.fetch(path, file_info) + if res: + with res: + consumer = BackgroundFileConsumer(open(local_path, "w")) + yield res.write_to_consumer(consumer) + yield consumer.wait() + defer.returnValue(local_path) + + raise Exception("file could not be found") + + def _file_info_to_path(self, file_info): + """Converts file_info into a relative path. + + The path is suitable for storing files under a directory, e.g. used to + store files on local FS under the base media repository directory. + + Args: + file_info (FileInfo) + + Returns: + str + """ + if file_info.url_cache: + if file_info.thumbnail: + return self.filepaths.url_cache_thumbnail_rel( + media_id=file_info.file_id, + width=file_info.thumbnail_width, + height=file_info.thumbnail_height, + content_type=file_info.thumbnail_type, + method=file_info.thumbnail_method, + ) + return self.filepaths.url_cache_filepath_rel(file_info.file_id) + + if file_info.server_name: + if file_info.thumbnail: + return self.filepaths.remote_media_thumbnail_rel( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail_width, + height=file_info.thumbnail_height, + content_type=file_info.thumbnail_type, + method=file_info.thumbnail_method + ) + return self.filepaths.remote_media_filepath_rel( + file_info.server_name, file_info.file_id, + ) + + if file_info.thumbnail: + return self.filepaths.local_media_thumbnail_rel( + media_id=file_info.file_id, + width=file_info.thumbnail_width, + height=file_info.thumbnail_height, + content_type=file_info.thumbnail_type, + 
method=file_info.thumbnail_method + ) + return self.filepaths.local_media_filepath_rel( + file_info.file_id, + ) + + +def _write_file_synchronously(source, dest): + """Write `source` to the file like `dest` synchronously. Should be called + from a thread. + + Args: + source: A file like object that's to be written + dest: A file like object to be written to + """ + source.seek(0) # Ensure we read from the start of the file + shutil.copyfileobj(source, dest) + + +class FileResponder(Responder): + """Wraps an open file that can be sent to a request. + + Args: + open_file (file): A file like object to be streamed ot the client, + is closed when finished streaming. + """ + def __init__(self, open_file): + self.open_file = open_file + + def write_to_consumer(self, consumer): + return make_deferred_yieldable( + FileSender().beginFileTransfer(self.open_file, consumer) + ) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.open_file.close() diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 99760d622f..9290d7946f 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -12,39 +12,47 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import cgi +import datetime +import errno +import fnmatch +import itertools +import logging +import os +import re +import shutil +import sys +import traceback +import simplejson as json +import urlparse from twisted.web.server import NOT_DONE_YET from twisted.internet import defer from twisted.web.resource import Resource +from ._base import FileInfo + from synapse.api.errors import ( SynapseError, Codes, ) +from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.stringutils import random_string from synapse.util.caches.expiringcache import ExpiringCache from synapse.http.client import SpiderHttpClient from synapse.http.server import ( - request_handler, respond_with_json_bytes + request_handler, respond_with_json_bytes, + respond_with_json, ) from synapse.util.async import ObservableDeferred from synapse.util.stringutils import is_ascii -import os -import re -import fnmatch -import cgi -import ujson as json -import urlparse -import itertools - -import logging logger = logging.getLogger(__name__) class PreviewUrlResource(Resource): isLeaf = True - def __init__(self, hs, media_repo): + def __init__(self, hs, media_repo, media_storage): Resource.__init__(self) self.auth = hs.get_auth() @@ -56,19 +64,27 @@ class PreviewUrlResource(Resource): self.store = hs.get_datastore() self.client = SpiderHttpClient(hs) self.media_repo = media_repo + self.primary_base_path = media_repo.primary_base_path + self.media_storage = media_storage self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist - # simple memory cache mapping urls to OG metadata - self.cache = ExpiringCache( + # memory cache mapping urls to an ObservableDeferred returning + # JSON-encoded OG metadata + self._cache = ExpiringCache( cache_name="url_previews", clock=self.clock, # don't spider URLs more often than once an hour expiry_ms=60 * 60 * 1000, ) - self.cache.start() + self._cache.start() + + self._cleaner_loop = self.clock.looping_call( + 
self._expire_url_cache_data, 10 * 1000 + ) - self.downloads = {} + def render_OPTIONS(self, request): + return respond_with_json(request, 200, {}, send_cors=True) def render_GET(self, request): self._async_render_GET(request) @@ -86,6 +102,7 @@ class PreviewUrlResource(Resource): else: ts = self.clock.time_msec() + # XXX: we could move this into _do_preview if we wanted. url_tuple = urlparse.urlsplit(url) for entry in self.url_preview_url_blacklist: match = True @@ -118,53 +135,62 @@ class PreviewUrlResource(Resource): Codes.UNKNOWN ) - # first check the memory cache - good to handle all the clients on this - # HS thundering away to preview the same URL at the same time. - og = self.cache.get(url) - if og: - respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) - return + # the in-memory cache: + # * ensures that only one request is active at a time + # * takes load off the DB for the thundering herds + # * also caches any failures (unlike the DB) so we don't keep + # requesting the same endpoint + + observable = self._cache.get(url) + + if not observable: + download = run_in_background( + self._do_preview, + url, requester.user, ts, + ) + observable = ObservableDeferred( + download, + consumeErrors=True + ) + self._cache[url] = observable + else: + logger.info("Returning cached response") - # then check the URL cache in the DB (which will also provide us with + og = yield make_deferred_yieldable(observable.observe()) + respond_with_json_bytes(request, 200, og, send_cors=True) + + @defer.inlineCallbacks + def _do_preview(self, url, user, ts): + """Check the db, and download the URL and build a preview + + Args: + url (str): + user (str): + ts (int): + + Returns: + Deferred[str]: json-encoded og data + """ + # check the URL cache in the DB (which will also provide us with # historical previews, if we have any) cache_result = yield self.store.get_url_cache(url, ts) if ( cache_result and - cache_result["download_ts"] + cache_result["expires"] > ts and 
+ cache_result["expires_ts"] > ts and cache_result["response_code"] / 100 == 2 ): - respond_with_json_bytes( - request, 200, cache_result["og"].encode('utf-8'), - send_cors=True - ) + defer.returnValue(cache_result["og"]) return - # Ensure only one download for a given URL is active at a time - download = self.downloads.get(url) - if download is None: - download = self._download_url(url, requester.user) - download = ObservableDeferred( - download, - consumeErrors=True - ) - self.downloads[url] = download - - @download.addBoth - def callback(media_info): - del self.downloads[url] - return media_info - media_info = yield download.observe() - - # FIXME: we should probably update our cache now anyway, so that - # even if the OG calculation raises, we don't keep hammering on the - # remote server. For now, leave it uncached to aid debugging OG - # calculation problems + media_info = yield self._download_url(url, user) logger.debug("got media_info of '%s'" % media_info) if _is_media(media_info['media_type']): - dims = yield self.media_repo._generate_local_thumbnails( - media_info['filesystem_id'], media_info + file_id = media_info['filesystem_id'] + dims = yield self.media_repo._generate_thumbnails( + None, file_id, file_id, media_info["media_type"], + url_cache=True, ) og = { @@ -204,13 +230,15 @@ class PreviewUrlResource(Resource): # just rely on the caching on the master request to speed things up. 
if 'og:image' in og and og['og:image']: image_info = yield self._download_url( - _rebase_url(og['og:image'], media_info['uri']), requester.user + _rebase_url(og['og:image'], media_info['uri']), user ) if _is_media(image_info['media_type']): # TODO: make sure we don't choke on white-on-transparent images - dims = yield self.media_repo._generate_local_thumbnails( - image_info['filesystem_id'], image_info + file_id = image_info['filesystem_id'] + dims = yield self.media_repo._generate_thumbnails( + None, file_id, file_id, image_info["media_type"], + url_cache=True, ) if dims: og["og:image:width"] = dims['width'] @@ -231,21 +259,20 @@ class PreviewUrlResource(Resource): logger.debug("Calculated OG for %s as %s" % (url, og)) - # store OG in ephemeral in-memory cache - self.cache[url] = og + jsonog = json.dumps(og) # store OG in history-aware DB cache yield self.store.store_url_cache( url, media_info["response_code"], media_info["etag"], - media_info["expires"], - json.dumps(og), + media_info["expires"] + media_info["created_ts"], + jsonog, media_info["filesystem_id"], media_info["created_ts"], ) - respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) + defer.returnValue(jsonog) @defer.inlineCallbacks def _download_url(self, url, user): @@ -253,21 +280,36 @@ class PreviewUrlResource(Resource): # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? 
- # XXX: horrible duplication with base_resource's _download_remote_file() - file_id = random_string(24) + file_id = datetime.date.today().isoformat() + '_' + random_string(16) - fname = self.filepaths.local_media_filepath(file_id) - self.media_repo._makedirs(fname) + file_info = FileInfo( + server_name=None, + file_id=file_id, + url_cache=True, + ) - try: - with open(fname, "wb") as f: + with self.media_storage.store_into_file(file_info) as (f, fname, finish): + try: logger.debug("Trying to get url '%s'" % url) length, headers, uri, code = yield self.client.get_file( url, output_stream=f, max_size=self.max_spider_size, ) + except Exception as e: # FIXME: pass through 404s and other error messages nicely + logger.warn("Error downloading %s: %r", url, e) + raise SynapseError( + 500, "Failed to download content: %s" % ( + traceback.format_exception_only(sys.exc_type, e), + ), + Codes.UNKNOWN, + ) + yield finish() - media_type = headers["Content-Type"][0] + try: + if "Content-Type" in headers: + media_type = headers["Content-Type"][0] + else: + media_type = "application/octet-stream" time_now_ms = self.clock.time_msec() content_disposition = headers.get("Content-Disposition", None) @@ -303,14 +345,15 @@ class PreviewUrlResource(Resource): upload_name=download_name, media_length=length, user_id=user, + url_cache=url, ) except Exception as e: - os.remove(fname) - raise SynapseError( - 500, ("Failed to download content: %s" % e), - Codes.UNKNOWN - ) + logger.error("Error handling downloaded %s: %r", url, e) + # TODO: we really ought to delete the downloaded file in this + # case, since we won't have recorded it in the db, and will + # therefore not expire it. + raise defer.returnValue({ "media_type": media_type, @@ -327,6 +370,95 @@ class PreviewUrlResource(Resource): "etag": headers["ETag"][0] if "ETag" in headers else None, }) + @defer.inlineCallbacks + def _expire_url_cache_data(self): + """Clean up expired url cache content, media and thumbnails. 
+ """ + # TODO: Delete from backup media store + + now = self.clock.time_msec() + + logger.info("Running url preview cache expiry") + + if not (yield self.store.has_completed_background_updates()): + logger.info("Still running DB updates; skipping expiry") + return + + # First we delete expired url cache entries + media_ids = yield self.store.get_expired_url_cache(now) + + removed_media = [] + for media_id in media_ids: + fname = self.filepaths.url_cache_filepath(media_id) + try: + os.remove(fname) + except OSError as e: + # If the path doesn't exist, meh + if e.errno != errno.ENOENT: + logger.warn("Failed to remove media: %r: %s", media_id, e) + continue + + removed_media.append(media_id) + + try: + dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id) + for dir in dirs: + os.rmdir(dir) + except Exception: + pass + + yield self.store.delete_url_cache(removed_media) + + if removed_media: + logger.info("Deleted %d entries from url cache", len(removed_media)) + + # Now we delete old images associated with the url cache. + # These may be cached for a bit on the client (i.e., they + # may have a room open with a preview url thing open). + # So we wait a couple of days before deleting, just in case. 
+ expire_before = now - 2 * 24 * 60 * 60 * 1000 + media_ids = yield self.store.get_url_cache_media_before(expire_before) + + removed_media = [] + for media_id in media_ids: + fname = self.filepaths.url_cache_filepath(media_id) + try: + os.remove(fname) + except OSError as e: + # If the path doesn't exist, meh + if e.errno != errno.ENOENT: + logger.warn("Failed to remove media: %r: %s", media_id, e) + continue + + try: + dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id) + for dir in dirs: + os.rmdir(dir) + except Exception: + pass + + thumbnail_dir = self.filepaths.url_cache_thumbnail_directory(media_id) + try: + shutil.rmtree(thumbnail_dir) + except OSError as e: + # If the path doesn't exist, meh + if e.errno != errno.ENOENT: + logger.warn("Failed to remove media: %r: %s", media_id, e) + continue + + removed_media.append(media_id) + + try: + dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id) + for dir in dirs: + os.rmdir(dir) + except Exception: + pass + + yield self.store.delete_url_cache_media(removed_media) + + logger.info("Deleted %d media from url cache", len(removed_media)) + def decode_and_calc_og(body, media_uri, request_encoding=None): from lxml import etree @@ -424,7 +556,14 @@ def _calc_og(tree, media_uri): from lxml import etree TAGS_TO_REMOVE = ( - "header", "nav", "aside", "footer", "script", "style", etree.Comment + "header", + "nav", + "aside", + "footer", + "script", + "noscript", + "style", + etree.Comment ) # Split all the text nodes into paragraphs (by splitting on new @@ -434,6 +573,8 @@ def _calc_og(tree, media_uri): for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE) ) og['og:description'] = summarize_paragraphs(text_nodes) + else: + og['og:description'] = summarize_paragraphs([og['og:description']]) # TODO: delete the url downloads to stop diskfilling, # as we only ever cared about its OG diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py new file 
mode 100644 index 0000000000..0252afd9d3 --- /dev/null +++ b/synapse/rest/media/v1/storage_provider.py @@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer, threads + +from .media_storage import FileResponder + +from synapse.config._base import Config +from synapse.util.logcontext import run_in_background + +import logging +import os +import shutil + + +logger = logging.getLogger(__name__) + + +class StorageProvider(object): + """A storage provider is a service that can store uploaded media and + retrieve them. + """ + def store_file(self, path, file_info): + """Store the file described by file_info. The actual contents can be + retrieved by reading the file in file_info.upload_path. + + Args: + path (str): Relative path of file in local cache + file_info (FileInfo) + + Returns: + Deferred + """ + pass + + def fetch(self, path, file_info): + """Attempt to fetch the file described by file_info and stream it + into writer. + + Args: + path (str): Relative path of file in local cache + file_info (FileInfo) + + Returns: + Deferred(Responder): Returns a Responder if the provider has the file, + otherwise returns None. + """ + pass + + +class StorageProviderWrapper(StorageProvider): + """Wraps a storage provider and provides various config options + + Args: + backend (StorageProvider) + store_local (bool): Whether to store new local files or not. 
+ store_synchronous (bool): Whether to wait for file to be successfully + uploaded, or todo the upload in the backgroud. + store_remote (bool): Whether remote media should be uploaded + """ + def __init__(self, backend, store_local, store_synchronous, store_remote): + self.backend = backend + self.store_local = store_local + self.store_synchronous = store_synchronous + self.store_remote = store_remote + + def store_file(self, path, file_info): + if not file_info.server_name and not self.store_local: + return defer.succeed(None) + + if file_info.server_name and not self.store_remote: + return defer.succeed(None) + + if self.store_synchronous: + return self.backend.store_file(path, file_info) + else: + # TODO: Handle errors. + def store(): + try: + return self.backend.store_file(path, file_info) + except Exception: + logger.exception("Error storing file") + run_in_background(store) + return defer.succeed(None) + + def fetch(self, path, file_info): + return self.backend.fetch(path, file_info) + + +class FileStorageProviderBackend(StorageProvider): + """A storage provider that stores files in a directory on a filesystem. + + Args: + hs (HomeServer) + config: The config returned by `parse_config`. 
+ """ + + def __init__(self, hs, config): + self.cache_directory = hs.config.media_store_path + self.base_directory = config + + def store_file(self, path, file_info): + """See StorageProvider.store_file""" + + primary_fname = os.path.join(self.cache_directory, path) + backup_fname = os.path.join(self.base_directory, path) + + dirname = os.path.dirname(backup_fname) + if not os.path.exists(dirname): + os.makedirs(dirname) + + return threads.deferToThread( + shutil.copyfile, primary_fname, backup_fname, + ) + + def fetch(self, path, file_info): + """See StorageProvider.fetch""" + + backup_fname = os.path.join(self.base_directory, path) + if os.path.isfile(backup_fname): + return FileResponder(open(backup_fname, "rb")) + + @staticmethod + def parse_config(config): + """Called on startup to parse config supplied. This should parse + the config and raise if there is a problem. + + The returned value is passed into the constructor. + + In this case we only care about a single param, the directory, so let's + just pull that out. + """ + return Config.ensure_directory(config["directory"]) diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index d8f54adc99..58ada49711 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -14,7 +14,10 @@ # limitations under the License. 
-from ._base import parse_media_id, respond_404, respond_with_file +from ._base import ( + parse_media_id, respond_404, respond_with_file, FileInfo, + respond_with_responder, +) from twisted.web.resource import Resource from synapse.http.servlet import parse_string, parse_integer from synapse.http.server import request_handler, set_cors_headers @@ -30,12 +33,12 @@ logger = logging.getLogger(__name__) class ThumbnailResource(Resource): isLeaf = True - def __init__(self, hs, media_repo): + def __init__(self, hs, media_repo, media_storage): Resource.__init__(self) self.store = hs.get_datastore() - self.filepaths = media_repo.filepaths self.media_repo = media_repo + self.media_storage = media_storage self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.server_name = hs.hostname self.version_string = hs.version_string @@ -64,6 +67,7 @@ class ThumbnailResource(Resource): yield self._respond_local_thumbnail( request, media_id, width, height, method, m_type ) + self.media_repo.mark_recently_accessed(None, media_id) else: if self.dynamic_thumbnails: yield self._select_or_generate_remote_thumbnail( @@ -75,6 +79,7 @@ class ThumbnailResource(Resource): request, server_name, media_id, width, height, method, m_type ) + self.media_repo.mark_recently_accessed(server_name, media_id) @defer.inlineCallbacks def _respond_local_thumbnail(self, request, media_id, width, height, @@ -84,11 +89,10 @@ class ThumbnailResource(Resource): if not media_info: respond_404(request) return - - # if media_info["media_type"] == "image/svg+xml": - # file_path = self.filepaths.local_media_filepath(media_id) - # yield respond_with_file(request, media_info["media_type"], file_path) - # return + if media_info["quarantined_by"]: + logger.info("Media is quarantined") + respond_404(request) + return thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) @@ -96,20 +100,25 @@ class ThumbnailResource(Resource): thumbnail_info = self._select_thumbnail( width, height, method, m_type, 
thumbnail_infos ) - t_width = thumbnail_info["thumbnail_width"] - t_height = thumbnail_info["thumbnail_height"] - t_type = thumbnail_info["thumbnail_type"] - t_method = thumbnail_info["thumbnail_method"] - file_path = self.filepaths.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method, + file_info = FileInfo( + server_name=None, file_id=media_id, + url_cache=media_info["url_cache"], + thumbnail=True, + thumbnail_width=thumbnail_info["thumbnail_width"], + thumbnail_height=thumbnail_info["thumbnail_height"], + thumbnail_type=thumbnail_info["thumbnail_type"], + thumbnail_method=thumbnail_info["thumbnail_method"], ) - yield respond_with_file(request, t_type, file_path) + t_type = file_info.thumbnail_type + t_length = thumbnail_info["thumbnail_length"] + + responder = yield self.media_storage.fetch_media(file_info) + yield respond_with_responder(request, responder, t_type, t_length) else: - yield self._respond_default_thumbnail( - request, media_info, width, height, method, m_type, - ) + logger.info("Couldn't find any generated thumbnails") + respond_404(request) @defer.inlineCallbacks def _select_or_generate_local_thumbnail(self, request, media_id, desired_width, @@ -120,11 +129,10 @@ class ThumbnailResource(Resource): if not media_info: respond_404(request) return - - # if media_info["media_type"] == "image/svg+xml": - # file_path = self.filepaths.local_media_filepath(media_id) - # yield respond_with_file(request, media_info["media_type"], file_path) - # return + if media_info["quarantined_by"]: + logger.info("Media is quarantined") + respond_404(request) + return thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) for info in thumbnail_infos: @@ -134,37 +142,43 @@ class ThumbnailResource(Resource): t_type = info["thumbnail_type"] == desired_type if t_w and t_h and t_method and t_type: - file_path = self.filepaths.local_media_thumbnail( - media_id, desired_width, desired_height, desired_type, desired_method, + file_info = 
FileInfo( + server_name=None, file_id=media_id, + url_cache=media_info["url_cache"], + thumbnail=True, + thumbnail_width=info["thumbnail_width"], + thumbnail_height=info["thumbnail_height"], + thumbnail_type=info["thumbnail_type"], + thumbnail_method=info["thumbnail_method"], ) - yield respond_with_file(request, desired_type, file_path) - return - logger.debug("We don't have a local thumbnail of that size. Generating") + t_type = file_info.thumbnail_type + t_length = info["thumbnail_length"] + + responder = yield self.media_storage.fetch_media(file_info) + if responder: + yield respond_with_responder(request, responder, t_type, t_length) + return + + logger.debug("We don't have a thumbnail of that size. Generating") # Okay, so we generate one. file_path = yield self.media_repo.generate_local_exact_thumbnail( - media_id, desired_width, desired_height, desired_method, desired_type + media_id, desired_width, desired_height, desired_method, desired_type, + url_cache=media_info["url_cache"], ) if file_path: yield respond_with_file(request, desired_type, file_path) else: - yield self._respond_default_thumbnail( - request, media_info, desired_width, desired_height, - desired_method, desired_type, - ) + logger.warn("Failed to generate thumbnail") + respond_404(request) @defer.inlineCallbacks def _select_or_generate_remote_thumbnail(self, request, server_name, media_id, desired_width, desired_height, desired_method, desired_type): - media_info = yield self.media_repo.get_remote_media(server_name, media_id) - - # if media_info["media_type"] == "image/svg+xml": - # file_path = self.filepaths.remote_media_filepath(server_name, media_id) - # yield respond_with_file(request, media_info["media_type"], file_path) - # return + media_info = yield self.media_repo.get_remote_media_info(server_name, media_id) thumbnail_infos = yield self.store.get_remote_media_thumbnails( server_name, media_id, @@ -179,14 +193,24 @@ class ThumbnailResource(Resource): t_type = info["thumbnail_type"] == 
desired_type if t_w and t_h and t_method and t_type: - file_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, desired_width, desired_height, - desired_type, desired_method, + file_info = FileInfo( + server_name=server_name, file_id=media_info["filesystem_id"], + thumbnail=True, + thumbnail_width=info["thumbnail_width"], + thumbnail_height=info["thumbnail_height"], + thumbnail_type=info["thumbnail_type"], + thumbnail_method=info["thumbnail_method"], ) - yield respond_with_file(request, desired_type, file_path) - return - logger.debug("We don't have a local thumbnail of that size. Generating") + t_type = file_info.thumbnail_type + t_length = info["thumbnail_length"] + + responder = yield self.media_storage.fetch_media(file_info) + if responder: + yield respond_with_responder(request, responder, t_type, t_length) + return + + logger.debug("We don't have a thumbnail of that size. Generating") # Okay, so we generate one. file_path = yield self.media_repo.generate_remote_exact_thumbnail( @@ -197,22 +221,16 @@ class ThumbnailResource(Resource): if file_path: yield respond_with_file(request, desired_type, file_path) else: - yield self._respond_default_thumbnail( - request, media_info, desired_width, desired_height, - desired_method, desired_type, - ) + logger.warn("Failed to generate thumbnail") + respond_404(request) @defer.inlineCallbacks def _respond_remote_thumbnail(self, request, server_name, media_id, width, height, method, m_type): # TODO: Don't download the whole remote file - # We should proxy the thumbnail from the remote server instead. 
- media_info = yield self.media_repo.get_remote_media(server_name, media_id) - - # if media_info["media_type"] == "image/svg+xml": - # file_path = self.filepaths.remote_media_filepath(server_name, media_id) - # yield respond_with_file(request, media_info["media_type"], file_path) - # return + # We should proxy the thumbnail from the remote server instead of + # downloading the remote file and generating our own thumbnails. + media_info = yield self.media_repo.get_remote_media_info(server_name, media_id) thumbnail_infos = yield self.store.get_remote_media_thumbnails( server_name, media_id, @@ -222,59 +240,23 @@ class ThumbnailResource(Resource): thumbnail_info = self._select_thumbnail( width, height, method, m_type, thumbnail_infos ) - t_width = thumbnail_info["thumbnail_width"] - t_height = thumbnail_info["thumbnail_height"] - t_type = thumbnail_info["thumbnail_type"] - t_method = thumbnail_info["thumbnail_method"] - file_id = thumbnail_info["filesystem_id"] + file_info = FileInfo( + server_name=server_name, file_id=media_info["filesystem_id"], + thumbnail=True, + thumbnail_width=thumbnail_info["thumbnail_width"], + thumbnail_height=thumbnail_info["thumbnail_height"], + thumbnail_type=thumbnail_info["thumbnail_type"], + thumbnail_method=thumbnail_info["thumbnail_method"], + ) + + t_type = file_info.thumbnail_type t_length = thumbnail_info["thumbnail_length"] - file_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, t_width, t_height, t_type, t_method, - ) - yield respond_with_file(request, t_type, file_path, t_length) + responder = yield self.media_storage.fetch_media(file_info) + yield respond_with_responder(request, responder, t_type, t_length) else: - yield self._respond_default_thumbnail( - request, media_info, width, height, method, m_type, - ) - - @defer.inlineCallbacks - def _respond_default_thumbnail(self, request, media_info, width, height, - method, m_type): - # XXX: how is this meant to work? 
store.get_default_thumbnails - # appears to always return [] so won't this always 404? - media_type = media_info["media_type"] - top_level_type = media_type.split("/")[0] - sub_type = media_type.split("/")[-1].split(";")[0] - thumbnail_infos = yield self.store.get_default_thumbnails( - top_level_type, sub_type, - ) - if not thumbnail_infos: - thumbnail_infos = yield self.store.get_default_thumbnails( - top_level_type, "_default", - ) - if not thumbnail_infos: - thumbnail_infos = yield self.store.get_default_thumbnails( - "_default", "_default", - ) - if not thumbnail_infos: + logger.info("Failed to find any generated thumbnails") respond_404(request) - return - - thumbnail_info = self._select_thumbnail( - width, height, "crop", m_type, thumbnail_infos - ) - - t_width = thumbnail_info["thumbnail_width"] - t_height = thumbnail_info["thumbnail_height"] - t_type = thumbnail_info["thumbnail_type"] - t_method = thumbnail_info["thumbnail_method"] - t_length = thumbnail_info["thumbnail_length"] - - file_path = self.filepaths.default_thumbnail( - top_level_type, sub_type, t_width, t_height, t_type, t_method, - ) - yield respond_with_file(request, t_type, file_path, t_length) def _select_thumbnail(self, desired_width, desired_height, desired_method, desired_type, thumbnail_infos): diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 3868d4f65f..e1ee535b9a 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -50,12 +50,16 @@ class Thumbnailer(object): else: return ((max_height * self.width) // self.height, max_height) - def scale(self, output_path, width, height, output_type): - """Rescales the image to the given dimensions""" + def scale(self, width, height, output_type): + """Rescales the image to the given dimensions. 
+ + Returns: + BytesIO: the bytes of the encoded image ready to be written to disk + """ scaled = self.image.resize((width, height), Image.ANTIALIAS) - return self.save_image(scaled, output_type, output_path) + return self._encode_image(scaled, output_type) - def crop(self, output_path, width, height, output_type): + def crop(self, width, height, output_type): """Rescales and crops the image to the given dimensions preserving aspect:: (w_in / h_in) = (w_scaled / h_scaled) @@ -65,6 +69,9 @@ class Thumbnailer(object): Args: max_width: The largest possible width. max_height: The larget possible height. + + Returns: + BytesIO: the bytes of the encoded image ready to be written to disk """ if width * self.height > height * self.width: scaled_height = (width * self.height) // self.width @@ -82,13 +89,9 @@ class Thumbnailer(object): crop_left = (scaled_width - width) // 2 crop_right = width + crop_left cropped = scaled_image.crop((crop_left, 0, crop_right, height)) - return self.save_image(cropped, output_type, output_path) + return self._encode_image(cropped, output_type) - def save_image(self, output_image, output_type, output_path): + def _encode_image(self, output_image, output_type): output_bytes_io = BytesIO() output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80) - output_bytes = output_bytes_io.getvalue() - with open(output_path, "wb") as output_file: - output_file.write(output_bytes) - logger.info("Stored thumbnail in file %r", output_path) - return len(output_bytes) + return output_bytes_io diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 4ab33f73bf..a31e75cb46 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -81,19 +81,19 @@ class UploadResource(Resource): headers = request.requestHeaders if headers.hasHeader("Content-Type"): - media_type = headers.getRawHeaders("Content-Type")[0] + media_type = 
headers.getRawHeaders(b"Content-Type")[0] else: raise SynapseError( msg="Upload request missing 'Content-Type'", code=400, ) - # if headers.hasHeader("Content-Disposition"): - # disposition = headers.getRawHeaders("Content-Disposition")[0] + # if headers.hasHeader(b"Content-Disposition"): + # disposition = headers.getRawHeaders(b"Content-Disposition")[0] # TODO(markjh): parse content-dispostion content_uri = yield self.media_repo.create_content( - media_type, upload_name, request.content.read(), + media_type, upload_name, request.content, content_length, requester.user ) diff --git a/synapse/server.py b/synapse/server.py index c577032041..ebdea6b0c4 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -31,29 +31,47 @@ from synapse.appservice.api import ApplicationServiceApi from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.crypto.keyring import Keyring from synapse.events.builder import EventBuilderFactory -from synapse.federation import initialize_http_replication +from synapse.events.spamcheck import SpamChecker +from synapse.federation.federation_client import FederationClient +from synapse.federation.federation_server import FederationServer from synapse.federation.send_queue import FederationRemoteSendQueue +from synapse.federation.federation_server import FederationHandlerRegistry from synapse.federation.transport.client import TransportLayerClient from synapse.federation.transaction_queue import TransactionQueue from synapse.handlers import Handlers from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.auth import AuthHandler, MacaroonGeneartor +from synapse.handlers.deactivate_account import DeactivateAccountHandler from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.device import DeviceHandler from synapse.handlers.e2e_keys import E2eKeysHandler from synapse.handlers.presence import PresenceHandler from synapse.handlers.room_list import RoomListHandler 
+from synapse.handlers.room_member import RoomMemberMasterHandler +from synapse.handlers.room_member_worker import RoomMemberWorkerHandler +from synapse.handlers.set_password import SetPasswordHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import TypingHandler from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.initial_sync import InitialSyncHandler from synapse.handlers.receipts import ReceiptsHandler +from synapse.handlers.read_marker import ReadMarkerHandler +from synapse.handlers.user_directory import UserDirectoryHandler +from synapse.handlers.groups_local import GroupsLocalHandler +from synapse.handlers.profile import ProfileHandler +from synapse.handlers.message import EventCreationHandler +from synapse.groups.groups_server import GroupsServerHandler +from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.notifier import Notifier +from synapse.push.action_generator import ActionGenerator from synapse.push.pusherpool import PusherPool -from synapse.rest.media.v1.media_repository import MediaRepository -from synapse.state import StateHandler +from synapse.rest.media.v1.media_repository import ( + MediaRepository, + MediaRepositoryResource, +) +from synapse.state import StateHandler, StateResolutionHandler from synapse.storage import DataStore from synapse.streams.events import EventSources from synapse.util import Clock @@ -82,18 +100,14 @@ class HomeServer(object): """ DEPENDENCIES = [ - 'config', - 'clock', 'http_client', 'db_pool', - 'persistence_service', - 'replication_layer', - 'datastore', + 'federation_client', + 'federation_server', 'handlers', - 'v1auth', 'auth', - 'rest_servlet_factory', 'state_handler', + 'state_resolution_handler', 'presence_handler', 'sync_handler', 
'typing_handler', @@ -108,19 +122,12 @@ class HomeServer(object): 'application_service_scheduler', 'application_service_handler', 'device_message_handler', + 'profile_handler', + 'event_creation_handler', + 'deactivate_account_handler', + 'set_password_handler', 'notifier', - 'distributor', - 'client_resource', - 'resource_for_federation', - 'resource_for_static_content', - 'resource_for_web_client', - 'resource_for_content_repo', - 'resource_for_server_key', - 'resource_for_server_key_v2', - 'resource_for_media_repository', - 'resource_for_metrics', 'event_sources', - 'ratelimiter', 'keyring', 'pusherpool', 'event_builder_factory', @@ -128,10 +135,22 @@ class HomeServer(object): 'http_client_context_factory', 'simple_http_client', 'media_repository', + 'media_repository_resource', 'federation_transport_client', 'federation_sender', 'receipts_handler', 'macaroon_generator', + 'tcp_replication', + 'read_marker_handler', + 'action_generator', + 'user_directory_handler', + 'groups_local_handler', + 'groups_server_handler', + 'groups_attestation_signing', + 'groups_attestation_renewer', + 'spam_checker', + 'room_member_handler', + 'federation_registry', ] def __init__(self, hostname, **kwargs): @@ -165,8 +184,26 @@ class HomeServer(object): def is_mine_id(self, string): return string.split(":", 1)[1] == self.hostname - def build_replication_layer(self): - return initialize_http_replication(self) + def get_clock(self): + return self.clock + + def get_datastore(self): + return self.datastore + + def get_config(self): + return self.config + + def get_distributor(self): + return self.distributor + + def get_ratelimiter(self): + return self.ratelimiter + + def build_federation_client(self): + return FederationClient(self) + + def build_federation_server(self): + return FederationServer(self) def build_handlers(self): return Handlers(self) @@ -187,18 +224,12 @@ class HomeServer(object): def build_simple_http_client(self): return SimpleHttpClient(self) - def 
build_v1auth(self): - orf = Auth(self) - # Matrix spec makes no reference to what HTTP status code is returned, - # but the V1 API uses 403 where it means 401, and the webclient - # relies on this behaviour, so V1 gets its own copy of the auth - # with backwards compat behaviour. - orf.TOKEN_NOT_FOUND_HTTP_STATUS = 403 - return orf - def build_state_handler(self): return StateHandler(self) + def build_state_resolution_handler(self): + return StateResolutionHandler(self) + def build_presence_handler(self): return PresenceHandler(self) @@ -244,6 +275,18 @@ class HomeServer(object): def build_initial_sync_handler(self): return InitialSyncHandler(self) + def build_profile_handler(self): + return ProfileHandler(self) + + def build_event_creation_handler(self): + return EventCreationHandler(self) + + def build_deactivate_account_handler(self): + return DeactivateAccountHandler(self) + + def build_set_password_handler(self): + return SetPasswordHandler(self) + def build_event_sources(self): return EventSources(self) @@ -273,6 +316,28 @@ class HomeServer(object): **self.db_config.get("args", {}) ) + def get_db_conn(self, run_new_connection=True): + """Makes a new connection to the database, skipping the db pool + + Returns: + Connection: a connection object implementing the PEP-249 spec + """ + # Any param beginning with cp_ is a parameter for adbapi, and should + # not be passed to the database engine. + db_params = { + k: v for k, v in self.db_config.get("args", {}).items() + if not k.startswith("cp_") + } + db_conn = self.database_engine.module.connect(**db_params) + if run_new_connection: + self.database_engine.on_new_connection(db_conn) + return db_conn + + def build_media_repository_resource(self): + # build the media repo resource. 
This indirects through the HomeServer + # to ensure that we only have a single instance of + return MediaRepositoryResource(self) + def build_media_repository(self): return MediaRepository(self) @@ -290,6 +355,41 @@ class HomeServer(object): def build_receipts_handler(self): return ReceiptsHandler(self) + def build_read_marker_handler(self): + return ReadMarkerHandler(self) + + def build_tcp_replication(self): + raise NotImplementedError() + + def build_action_generator(self): + return ActionGenerator(self) + + def build_user_directory_handler(self): + return UserDirectoryHandler(self) + + def build_groups_local_handler(self): + return GroupsLocalHandler(self) + + def build_groups_server_handler(self): + return GroupsServerHandler(self) + + def build_groups_attestation_signing(self): + return GroupAttestationSigning(self) + + def build_groups_attestation_renewer(self): + return GroupAttestionRenewer(self) + + def build_spam_checker(self): + return SpamChecker(self) + + def build_room_member_handler(self): + if self.config.worker_app: + return RoomMemberWorkerHandler(self) + return RoomMemberMasterHandler(self) + + def build_federation_registry(self): + return FederationHandlerRegistry() + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/server.pyi b/synapse/server.pyi index 9570df5537..c3a9a3847b 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -1,10 +1,16 @@ import synapse.api.auth +import synapse.federation.transaction_queue +import synapse.federation.transport.client import synapse.handlers import synapse.handlers.auth +import synapse.handlers.deactivate_account import synapse.handlers.device import synapse.handlers.e2e_keys -import synapse.storage +import synapse.handlers.set_password +import synapse.rest.media.v1.media_repository import synapse.state +import synapse.storage + class HomeServer(object): def get_auth(self) -> synapse.api.auth.Auth: @@ -27,3 
+33,24 @@ class HomeServer(object): def get_state_handler(self) -> synapse.state.StateHandler: pass + + def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler: + pass + + def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler: + pass + + def get_set_password_handler(self) -> synapse.handlers.set_password.SetPasswordHandler: + pass + + def get_federation_sender(self) -> synapse.federation.transaction_queue.TransactionQueue: + pass + + def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient: + pass + + def get_media_repository_resource(self) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource: + pass + + def get_media_repository(self) -> synapse.rest.media.v1.media_repository.MediaRepository: + pass diff --git a/synapse/state.py b/synapse/state.py index 383d32b163..26093c8434 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -24,13 +24,13 @@ from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.events.snapshot import EventContext from synapse.util.async import Linearizer +from synapse.util.caches import CACHE_SIZE_FACTOR from collections import namedtuple from frozendict import frozendict import logging import hashlib -import os logger = logging.getLogger(__name__) @@ -38,9 +38,6 @@ logger = logging.getLogger(__name__) KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) -CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) - - SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR) EVICTION_TIMEOUT_SECONDS = 60 * 60 @@ -61,7 +58,11 @@ class _StateCacheEntry(object): __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"] def __init__(self, state, state_group, prev_group=None, delta_ids=None): + # dict[(str, str), str] map from (type, state_key) to event_id self.state = frozendict(state) + + # the ID of a state group if one and 
only one is involved. + # otherwise, None otherwise? self.state_group = state_group self.prev_group = prev_group @@ -84,31 +85,19 @@ class _StateCacheEntry(object): class StateHandler(object): - """ Responsible for doing state conflict resolution. + """Fetches bits of state from the stores, and does state resolution + where necessary """ def __init__(self, hs): self.clock = hs.get_clock() self.store = hs.get_datastore() self.hs = hs - - # dict of set of event_ids -> _StateCacheEntry. - self._state_cache = None - self.resolve_linearizer = Linearizer(name="state_resolve_lock") + self._state_resolution_handler = hs.get_state_resolution_handler() def start_caching(self): - logger.debug("start_caching") - - self._state_cache = ExpiringCache( - cache_name="state_cache", - clock=self.clock, - max_len=SIZE_OF_CACHE, - expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, - iterable=True, - reset_expiry_on_get=True, - ) - - self._state_cache.start() + # TODO: remove this shim + self._state_resolution_handler.start_caching() @defer.inlineCallbacks def get_current_state(self, room_id, event_type=None, state_key="", @@ -130,7 +119,7 @@ class StateHandler(object): latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) logger.debug("calling resolve_state_groups from get_current_state") - ret = yield self.resolve_state_groups(room_id, latest_event_ids) + ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) state = ret.state if event_type: @@ -143,25 +132,33 @@ class StateHandler(object): state_map = yield self.store.get_events(state.values(), get_prev_content=False) state = { - key: state_map[e_id] for key, e_id in state.items() if e_id in state_map + key: state_map[e_id] for key, e_id in state.iteritems() if e_id in state_map } defer.returnValue(state) @defer.inlineCallbacks - def get_current_state_ids(self, room_id, event_type=None, state_key="", - latest_event_ids=None): + def get_current_state_ids(self, room_id, latest_event_ids=None): + """Get 
the current state, or the state at a set of events, for a room + + Args: + room_id (str): + + latest_event_ids (iterable[str]|None): if given, the forward + extremities to resolve. If None, we look them up from the + database (via a cache) + + Returns: + Deferred[dict[(str, str), str)]]: the state dict, mapping from + (event_type, state_key) -> event_id + """ if not latest_event_ids: latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) logger.debug("calling resolve_state_groups from get_current_state_ids") - ret = yield self.resolve_state_groups(room_id, latest_event_ids) + ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) state = ret.state - if event_type: - defer.returnValue(state.get((event_type, state_key))) - return - defer.returnValue(state) @defer.inlineCallbacks @@ -169,54 +166,71 @@ class StateHandler(object): if not latest_event_ids: latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) logger.debug("calling resolve_state_groups from get_current_user_in_room") - entry = yield self.resolve_state_groups(room_id, latest_event_ids) - joined_users = yield self.store.get_joined_users_from_state( - room_id, entry.state_id, entry.state - ) + entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) + joined_users = yield self.store.get_joined_users_from_state(room_id, entry) defer.returnValue(joined_users) @defer.inlineCallbacks + def get_current_hosts_in_room(self, room_id, latest_event_ids=None): + if not latest_event_ids: + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + logger.debug("calling resolve_state_groups from get_current_hosts_in_room") + entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) + joined_hosts = yield self.store.get_joined_hosts(room_id, entry) + defer.returnValue(joined_hosts) + + @defer.inlineCallbacks def compute_event_context(self, event, old_state=None): - """ Fills out the context with the 
`current state` of the graph. The - `current state` here is defined to be the state of the event graph - just before the event - i.e. it never includes `event` + """Build an EventContext structure for the event. - If `event` has `auth_events` then this will also fill out the - `auth_events` field on `context` from the `current_state`. + This works out what the current state should be for the event, and + generates a new state group if necessary. Args: - event (EventBase) + event (synapse.events.EventBase): + old_state (dict|None): The state at the event if it can't be + calculated from existing events. This is normally only specified + when receiving an event from federation where we don't have the + prev events for, e.g. when backfilling. Returns: - an EventContext + synapse.events.snapshot.EventContext: """ - context = EventContext() if event.internal_metadata.is_outlier(): # If this is an outlier, then we know it shouldn't have any current # state. Certainly store.get_current_state won't return any, and # persisting the event won't store the state group. + context = EventContext() if old_state: context.prev_state_ids = { (s.type, s.state_key): s.event_id for s in old_state } if event.is_state(): - context.current_state_events = dict(context.prev_state_ids) + context.current_state_ids = dict(context.prev_state_ids) key = (event.type, event.state_key) - context.current_state_events[key] = event.event_id + context.current_state_ids[key] = event.event_id else: - context.current_state_events = context.prev_state_ids + context.current_state_ids = context.prev_state_ids else: context.current_state_ids = {} context.prev_state_ids = {} context.prev_state_events = [] - context.state_group = self.store.get_next_state_group() + + # We don't store state for outliers, so we don't generate a state + # froup for it. + context.state_group = None + defer.returnValue(context) if old_state: + # We already have the state, so we don't need to calculate it. 
+ # Let's just correctly fill out the context and create a + # new state group for it. + + context = EventContext() context.prev_state_ids = { (s.type, s.state_key): s.event_id for s in old_state } - context.state_group = self.store.get_next_state_group() if event.is_state(): key = (event.type, event.state_key) @@ -229,26 +243,29 @@ class StateHandler(object): else: context.current_state_ids = context.prev_state_ids + context.state_group = yield self.store.store_state_group( + event.event_id, + event.room_id, + prev_group=None, + delta_ids=None, + current_state_ids=context.current_state_ids, + ) + context.prev_state_events = [] defer.returnValue(context) logger.debug("calling resolve_state_groups from compute_event_context") - if event.is_state(): - entry = yield self.resolve_state_groups( - event.room_id, [e for e, _ in event.prev_events], - event_type=event.type, - state_key=event.state_key, - ) - else: - entry = yield self.resolve_state_groups( - event.room_id, [e for e, _ in event.prev_events], - ) + entry = yield self.resolve_state_groups_for_events( + event.room_id, [e for e, _ in event.prev_events], + ) curr_state = entry.state + context = EventContext() context.prev_state_ids = curr_state if event.is_state(): - context.state_group = self.store.get_next_state_group() + # If this is a state event then we need to create a new state + # group for the state after this event. 
key = (event.type, event.state_key) if key in context.prev_state_ids: @@ -258,58 +275,176 @@ class StateHandler(object): context.current_state_ids = dict(context.prev_state_ids) context.current_state_ids[key] = event.event_id - context.prev_group = entry.prev_group - context.delta_ids = entry.delta_ids - if context.delta_ids is not None: - context.delta_ids = dict(context.delta_ids) + if entry.state_group: + # If the state at the event has a state group assigned then + # we can use that as the prev group + context.prev_group = entry.state_group + context.delta_ids = { + key: event.event_id + } + elif entry.prev_group: + # If the state at the event only has a prev group, then we can + # use that as a prev group too. + context.prev_group = entry.prev_group + context.delta_ids = dict(entry.delta_ids) context.delta_ids[key] = event.event_id + + context.state_group = yield self.store.store_state_group( + event.event_id, + event.room_id, + prev_group=context.prev_group, + delta_ids=context.delta_ids, + current_state_ids=context.current_state_ids, + ) else: + context.current_state_ids = context.prev_state_ids + context.prev_group = entry.prev_group + context.delta_ids = entry.delta_ids + if entry.state_group is None: - entry.state_group = self.store.get_next_state_group() + entry.state_group = yield self.store.store_state_group( + event.event_id, + event.room_id, + prev_group=entry.prev_group, + delta_ids=entry.delta_ids, + current_state_ids=context.current_state_ids, + ) entry.state_id = entry.state_group context.state_group = entry.state_group - context.current_state_ids = context.prev_state_ids - context.prev_group = entry.prev_group - context.delta_ids = entry.delta_ids context.prev_state_events = [] defer.returnValue(context) @defer.inlineCallbacks - @log_function - def resolve_state_groups(self, room_id, event_ids, event_type=None, state_key=""): + def resolve_state_groups_for_events(self, room_id, event_ids): """ Given a list of event_ids this method fetches the 
state at each event, resolves conflicts between them and returns them. + Args: + room_id (str): + event_ids (list[str]): + Returns: - a Deferred tuple of (`state_group`, `state`, `prev_state`). - `state_group` is the name of a state group if one and only one is - involved. `state` is a map from (type, state_key) to event, and - `prev_state` is a list of event ids. + Deferred[_StateCacheEntry]: resolved state """ logger.debug("resolve_state_groups event_ids %s", event_ids) + # map from state group id to the state in that state group (where + # 'state' is a map from state key to event id) + # dict[int, dict[(str, str), str]] state_groups_ids = yield self.store.get_state_groups_ids( room_id, event_ids ) - logger.debug( - "resolve_state_groups state_groups %s", - state_groups_ids.keys() - ) - - group_names = frozenset(state_groups_ids.keys()) - if len(group_names) == 1: + if len(state_groups_ids) == 1: name, state_list = state_groups_ids.items().pop() + prev_group, delta_ids = yield self.store.get_state_group_delta(name) + defer.returnValue(_StateCacheEntry( state=state_list, state_group=name, - prev_group=name, - delta_ids={}, + prev_group=prev_group, + delta_ids=delta_ids, )) + result = yield self._state_resolution_handler.resolve_state_groups( + room_id, state_groups_ids, None, self._state_map_factory, + ) + defer.returnValue(result) + + def _state_map_factory(self, ev_ids): + return self.store.get_events( + ev_ids, get_prev_content=False, check_redacted=False, + ) + + def resolve_events(self, state_sets, event): + logger.info( + "Resolving state for %s with %d groups", event.room_id, len(state_sets) + ) + state_set_ids = [{ + (ev.type, ev.state_key): ev.event_id + for ev in st + } for st in state_sets] + + state_map = { + ev.event_id: ev + for st in state_sets + for ev in st + } + + with Measure(self.clock, "state._resolve_events"): + new_state = resolve_events_with_state_map(state_set_ids, state_map) + + new_state = { + key: state_map[ev_id] for key, ev_id in 
new_state.iteritems() + } + + return new_state + + +class StateResolutionHandler(object): + """Responsible for doing state conflict resolution. + + Note that the storage layer depends on this handler, so all functions must + be storage-independent. + """ + def __init__(self, hs): + self.clock = hs.get_clock() + + # dict of set of event_ids -> _StateCacheEntry. + self._state_cache = None + self.resolve_linearizer = Linearizer(name="state_resolve_lock") + + def start_caching(self): + logger.debug("start_caching") + + self._state_cache = ExpiringCache( + cache_name="state_cache", + clock=self.clock, + max_len=SIZE_OF_CACHE, + expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, + iterable=True, + reset_expiry_on_get=True, + ) + + self._state_cache.start() + + @defer.inlineCallbacks + @log_function + def resolve_state_groups( + self, room_id, state_groups_ids, event_map, state_map_factory, + ): + """Resolves conflicts between a set of state groups + + Always generates a new state group (unless we hit the cache), so should + not be called for a single state group + + Args: + room_id (str): room we are resolving for (used for logging) + state_groups_ids (dict[int, dict[(str, str), str]]): + map from state group id to the state in that state group + (where 'state' is a map from state key to event id) + + event_map(dict[str,FrozenEvent]|None): + a dict from event_id to event, for any events that we happen to + have in flight (eg, those currently being persisted). This will be + used as a starting point fof finding the state we need; any missing + events will be requested via state_map_factory. + + If None, all events will be fetched via state_map_factory. 
+ + Returns: + Deferred[_StateCacheEntry]: resolved state + """ + logger.debug( + "resolve_state_groups state_groups %s", + state_groups_ids.keys() + ) + + group_names = frozenset(state_groups_ids.keys()) + with (yield self.resolve_linearizer.queue(group_names)): if self._state_cache is not None: cache = self._state_cache.get(group_names, None) @@ -320,56 +455,62 @@ class StateHandler(object): "Resolving state for %s with %d groups", room_id, len(state_groups_ids) ) + # build a map from state key to the event_ids which set that state. + # dict[(str, str), set[str]) state = {} - for st in state_groups_ids.values(): - for key, e_id in st.items(): + for st in state_groups_ids.itervalues(): + for key, e_id in st.iteritems(): state.setdefault(key, set()).add(e_id) + # build a map from state key to the event_ids which set that state, + # including only those where there are state keys in conflict. conflicted_state = { k: list(v) - for k, v in state.items() + for k, v in state.iteritems() if len(v) > 1 } if conflicted_state: logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): - new_state = yield resolve_events( + new_state = yield resolve_events_with_factory( state_groups_ids.values(), - state_map_factory=lambda ev_ids: self.store.get_events( - ev_ids, get_prev_content=False, check_redacted=False, - ), + event_map=event_map, + state_map_factory=state_map_factory, ) else: new_state = { - key: e_ids.pop() for key, e_ids in state.items() + key: e_ids.pop() for key, e_ids in state.iteritems() } - state_group = None - new_state_event_ids = frozenset(new_state.values()) - for sg, events in state_groups_ids.items(): - if new_state_event_ids == frozenset(e_id for e_id in events): - state_group = sg - break - if state_group is None: - # Worker instances don't have access to this method, but we want - # to set the state_group on the main instance to increase cache - # hits. 
- if hasattr(self.store, "get_next_state_group"): - state_group = self.store.get_next_state_group() - - prev_group = None - delta_ids = None - for old_group, old_ids in state_groups_ids.items(): - if not set(new_state.iterkeys()) - set(old_ids.iterkeys()): - n_delta_ids = { - k: v - for k, v in new_state.items() - if old_ids.get(k) != v - } - if not delta_ids or len(n_delta_ids) < len(delta_ids): - prev_group = old_group - delta_ids = n_delta_ids + with Measure(self.clock, "state.create_group_ids"): + # if the new state matches any of the input state groups, we can + # use that state group again. Otherwise we will generate a state_id + # which will be used as a cache key for future resolutions, but + # not get persisted. + state_group = None + new_state_event_ids = frozenset(new_state.itervalues()) + for sg, events in state_groups_ids.iteritems(): + if new_state_event_ids == frozenset(e_id for e_id in events): + state_group = sg + break + + # TODO: We want to create a state group for this set of events, to + # increase cache hits, but we need to make sure that it doesn't + # end up as a prev_group without being added to the database + + prev_group = None + delta_ids = None + for old_group, old_ids in state_groups_ids.iteritems(): + if not set(new_state) - set(old_ids): + n_delta_ids = { + k: v + for k, v in new_state.iteritems() + if old_ids.get(k) != v + } + if not delta_ids or len(n_delta_ids) < len(delta_ids): + prev_group = old_group + delta_ids = n_delta_ids cache = _StateCacheEntry( state=new_state, @@ -383,30 +524,6 @@ class StateHandler(object): defer.returnValue(cache) - def resolve_events(self, state_sets, event): - logger.info( - "Resolving state for %s with %d groups", event.room_id, len(state_sets) - ) - state_set_ids = [{ - (ev.type, ev.state_key): ev.event_id - for ev in st - } for st in state_sets] - - state_map = { - ev.event_id: ev - for st in state_sets - for ev in st - } - - with Measure(self.clock, "state._resolve_events"): - new_state = 
resolve_events(state_set_ids, state_map) - - new_state = { - key: state_map[ev_id] for key, ev_id in new_state.items() - } - - return new_state - def _ordered_events(events): def key_func(e): @@ -415,19 +532,17 @@ def _ordered_events(events): return sorted(events, key=key_func) -def resolve_events(state_sets, state_map_factory): +def resolve_events_with_state_map(state_sets, state_map): """ Args: state_sets(list): List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - state_map_factory(dict|callable): If callable, then will be called - with a list of event_ids that are needed, and should return with - a Deferred of dict of event_id to event. Otherwise, should be - a dict from event_id to event of all events in state_sets. + state_map(dict): a dict from event_id to event, for all events in + state_sets. Returns - dict[(str, str), synapse.events.FrozenEvent] is a map from - (type, state_key) to event. + dict[(str, str), str]: + a map from (type, state_key) to event_id. """ if len(state_sets) == 1: return state_sets[0] @@ -436,13 +551,6 @@ def resolve_events(state_sets, state_map_factory): state_sets, ) - if callable(state_map_factory): - return _resolve_with_state_fac( - unconflicted_state, conflicted_state, state_map_factory - ) - - state_map = state_map_factory - auth_events = _create_auth_events_from_maps( unconflicted_state, conflicted_state, state_map ) @@ -456,6 +564,21 @@ def _seperate(state_sets): """Takes the state_sets and figures out which keys are conflicted and which aren't. i.e., which have multiple different event_ids associated with them in different state sets. + + Args: + state_sets(list[dict[(str, str), str]]): + List of dicts of (type, state_key) -> event_id, which are the + different state groups to resolve. 
+ + Returns: + (dict[(str, str), str], dict[(str, str), set[str]]): + A tuple of (unconflicted_state, conflicted_state), where: + + unconflicted_state is a dict mapping (type, state_key)->event_id + for unconflicted state keys. + + conflicted_state is a dict mapping (type, state_key) to a set of + event ids for conflicted state keys. """ unconflicted_state = dict(state_sets[0]) conflicted_state = {} @@ -486,24 +609,63 @@ def _seperate(state_sets): @defer.inlineCallbacks -def _resolve_with_state_fac(unconflicted_state, conflicted_state, - state_map_factory): +def resolve_events_with_factory(state_sets, event_map, state_map_factory): + """ + Args: + state_sets(list): List of dicts of (type, state_key) -> event_id, + which are the different state groups to resolve. + + event_map(dict[str,FrozenEvent]|None): + a dict from event_id to event, for any events that we happen to + have in flight (eg, those currently being persisted). This will be + used as a starting point fof finding the state we need; any missing + events will be requested via state_map_factory. + + If None, all events will be fetched via state_map_factory. + + state_map_factory(func): will be called + with a list of event_ids that are needed, and should return with + a Deferred of dict of event_id to event. + + Returns + Deferred[dict[(str, str), str]]: + a map from (type, state_key) to event_id. + """ + if len(state_sets) == 1: + defer.returnValue(state_sets[0]) + + unconflicted_state, conflicted_state = _seperate( + state_sets, + ) + needed_events = set( event_id for event_ids in conflicted_state.itervalues() for event_id in event_ids ) + if event_map is not None: + needed_events -= set(event_map.iterkeys()) logger.info("Asking for %d conflicted events", len(needed_events)) + # dict[str, FrozenEvent]: a map from state event id to event. 
Only includes + # the state events which are in conflict (and those in event_map) state_map = yield state_map_factory(needed_events) + if event_map is not None: + state_map.update(event_map) + # get the ids of the auth events which allow us to authenticate the + # conflicted state, picking only from the unconflicting state. + # + # dict[(str, str), str]: a map from state key to event id auth_events = _create_auth_events_from_maps( unconflicted_state, conflicted_state, state_map ) new_needed_events = set(auth_events.itervalues()) new_needed_events -= needed_events + if event_map is not None: + new_needed_events -= set(event_map.iterkeys()) logger.info("Asking for %d auth events", len(new_needed_events)) @@ -541,7 +703,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ auth_events = { key: state_map[ev_id] - for key, ev_id in auth_event_ids.items() + for key, ev_id in auth_event_ids.iteritems() if ev_id in state_map } @@ -549,7 +711,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ resolved_state = _resolve_state_events( conflicted_state, auth_events ) - except: + except Exception: logger.exception("Failed to resolve state") raise @@ -579,7 +741,7 @@ def _resolve_state_events(conflicted_state, auth_events): auth_events.update(resolved_state) - for key, events in conflicted_state.items(): + for key, events in conflicted_state.iteritems(): if key[0] == EventTypes.JoinRules: logger.debug("Resolving conflicted join rules %r", events) resolved_state[key] = _resolve_auth_events( @@ -589,7 +751,7 @@ def _resolve_state_events(conflicted_state, auth_events): auth_events.update(resolved_state) - for key, events in conflicted_state.items(): + for key, events in conflicted_state.iteritems(): if key[0] == EventTypes.Member: logger.debug("Resolving conflicted member lists %r", events) resolved_state[key] = _resolve_auth_events( @@ -599,7 +761,7 @@ def _resolve_state_events(conflicted_state, auth_events): 
auth_events.update(resolved_state) - for key, events in conflicted_state.items(): + for key, events in conflicted_state.iteritems(): if key not in resolved_state: logger.debug("Resolving conflicted state %r:%r", key, events) resolved_state[key] = _resolve_normal_events( diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index d604e7668f..8cdfd50f90 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,13 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.storage.devices import DeviceStore from .appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore ) -from ._base import LoggingTransaction from .directory import DirectoryStore from .events import EventsStore from .presence import PresenceStore, UserPresenceState @@ -37,7 +35,7 @@ from .media_repository import MediaRepositoryStore from .rejections import RejectionsStore from .event_push_actions import EventPushActionsStore from .deviceinbox import DeviceInboxStore - +from .group_server import GroupServerStore from .state import StateStore from .signatures import SignatureStore from .filtering import FilteringStore @@ -49,6 +47,7 @@ from .tags import TagsStore from .account_data import AccountDataStore from .openid import OpenIdStore from .client_ips import ClientIpStore +from .user_directory import UserDirectoryStore from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator from .engines import PostgresEngine @@ -86,6 +85,8 @@ class DataStore(RoomMemberStore, RoomStore, ClientIpStore, DeviceStore, DeviceInboxStore, + UserDirectoryStore, + GroupServerStore, ): def 
__init__(self, db_conn, hs): @@ -101,12 +102,6 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "events", "stream_ordering", step=-1, extra_tables=[("ex_outlier_stream", "event_stream_ordering")] ) - self._receipts_id_gen = StreamIdGenerator( - db_conn, "receipts_linearized", "stream_id" - ) - self._account_data_id_gen = StreamIdGenerator( - db_conn, "account_data_max_stream_id", "stream_id" - ) self._presence_id_gen = StreamIdGenerator( db_conn, "presence_stream", "stream_id" ) @@ -121,7 +116,6 @@ class DataStore(RoomMemberStore, RoomStore, ) self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") - self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") @@ -133,6 +127,9 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")], ) + self._group_updates_id_gen = StreamIdGenerator( + db_conn, "local_group_updates", "stream_id", + ) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = StreamIdGenerator( @@ -141,27 +138,6 @@ class DataStore(RoomMemberStore, RoomStore, else: self._cache_id_gen = None - events_max = self._stream_id_gen.get_current_token() - event_cache_prefill, min_event_val = self._get_cache_dict( - db_conn, "events", - entity_column="room_id", - stream_column="stream_ordering", - max_value=events_max, - ) - self._events_stream_cache = StreamChangeCache( - "EventsRoomStreamChangeCache", min_event_val, - prefilled_cache=event_cache_prefill, - ) - - self._membership_stream_cache = StreamChangeCache( - "MembershipStreamChangeCache", events_max, - ) - - account_max = self._account_data_id_gen.get_current_token() - self._account_data_stream_cache = StreamChangeCache( - "AccountDataAndTagsChangeCache", account_max, - ) - 
self._presence_on_startup = self._get_active_presence(db_conn) presence_cache_prefill, min_presence_val = self._get_cache_dict( @@ -175,18 +151,6 @@ class DataStore(RoomMemberStore, RoomStore, prefilled_cache=presence_cache_prefill ) - push_rules_prefill, push_rules_id = self._get_cache_dict( - db_conn, "push_rules_stream", - entity_column="user_id", - stream_column="stream_id", - max_value=self._push_rules_stream_id_gen.get_current_token()[0], - ) - - self.push_rules_stream_cache = StreamChangeCache( - "PushRulesStreamChangeCache", push_rules_id, - prefilled_cache=push_rules_prefill, - ) - max_device_inbox_id = self._device_inbox_id_gen.get_current_token() device_inbox_prefill, min_device_inbox_id = self._get_cache_dict( db_conn, "device_inbox", @@ -221,23 +185,35 @@ class DataStore(RoomMemberStore, RoomStore, "DeviceListFederationStreamChangeCache", device_list_max, ) - cur = LoggingTransaction( - db_conn.cursor(), - name="_find_stream_orderings_for_times_txn", - database_engine=self.database_engine, - after_callbacks=[] + events_max = self._stream_id_gen.get_current_token() + curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( + db_conn, "current_state_delta_stream", + entity_column="room_id", + stream_column="stream_id", + max_value=events_max, # As we share the stream id with events token + limit=1000, + ) + self._curr_state_delta_stream_cache = StreamChangeCache( + "_curr_state_delta_stream_cache", min_curr_state_delta_id, + prefilled_cache=curr_state_delta_prefill, ) - self._find_stream_orderings_for_times_txn(cur) - cur.close() - self.find_stream_orderings_looping_call = self._clock.looping_call( - self._find_stream_orderings_for_times, 60 * 60 * 1000 + _group_updates_prefill, min_group_updates_id = self._get_cache_dict( + db_conn, "local_group_updates", + entity_column="user_id", + stream_column="stream_id", + max_value=self._group_updates_id_gen.get_current_token(), + limit=1000, + ) + self._group_updates_stream_cache = 
StreamChangeCache( + "_group_updates_stream_cache", min_group_updates_id, + prefilled_cache=_group_updates_prefill, ) self._stream_order_on_start = self.get_room_max_stream_ordering() self._min_stream_order_on_start = self.get_room_min_stream_ordering() - super(DataStore, self).__init__(hs) + super(DataStore, self).__init__(db_conn, hs) def take_presence_startup_info(self): active_on_startup = self._presence_on_startup @@ -266,36 +242,110 @@ class DataStore(RoomMemberStore, RoomStore, return [UserPresenceState(**row) for row in rows] - @defer.inlineCallbacks def count_daily_users(self): """ Counts the number of users who used this homeserver in the last 24 hours. """ def _count_users(txn): - txn.execute( - "SELECT COUNT(DISTINCT user_id) AS users" - " FROM user_ips" - " WHERE last_seen > ?", - # This is close enough to a day for our purposes. - (int(self._clock.time_msec()) - (1000 * 60 * 60 * 24),) - ) - rows = self.cursor_to_dict(txn) - if rows: - return rows[0]["users"] - return 0 + yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - ret = yield self.runInteraction("count_users", _count_users) - defer.returnValue(ret) + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT user_id FROM user_ips + WHERE last_seen > ? 
+ GROUP BY user_id + ) u + """ - def get_user_ip_and_agents(self, user): - return self._simple_select_list( - table="user_ips", - keyvalues={"user_id": user.to_string()}, - retcols=[ - "access_token", "ip", "user_agent", "last_seen" - ], - desc="get_user_ip_and_agents", - ) + txn.execute(sql, (yesterday,)) + count, = txn.fetchone() + return count + + return self.runInteraction("count_users", _count_users) + + def count_r30_users(self): + """ + Counts the number of 30 day retained users, defined as:- + * Users who have created their accounts more than 30 days ago + * Where last seen at most 30 days ago + * Where account creation and last_seen are > 30 days apart + + Returns counts globaly for a given user as well as breaking + by platform + """ + def _count_r30_users(txn): + thirty_days_in_secs = 86400 * 30 + now = int(self._clock.time()) + thirty_days_ago_in_secs = now - thirty_days_in_secs + + sql = """ + SELECT platform, COALESCE(count(*), 0) FROM ( + SELECT + users.name, platform, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen, + CASE + WHEN user_agent LIKE '%%Android%%' THEN 'android' + WHEN user_agent LIKE '%%iOS%%' THEN 'ios' + WHEN user_agent LIKE '%%Electron%%' THEN 'electron' + WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' + WHEN user_agent LIKE '%%Gecko%%' THEN 'web' + ELSE 'unknown' + END + AS platform + FROM user_ips + ) uip + ON users.name = uip.user_id + AND users.appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? 
+ AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, platform, users.creation_ts + ) u GROUP BY platform + """ + + results = {} + txn.execute(sql, (thirty_days_ago_in_secs, + thirty_days_ago_in_secs)) + + for row in txn: + if row[0] is 'unknown': + pass + results[row[0]] = row[1] + + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT users.name, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen + FROM user_ips + ) uip + ON users.name = uip.user_id + AND appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, users.creation_ts + ) u + """ + + txn.execute(sql, (thirty_days_ago_in_secs, + thirty_days_ago_in_secs)) + + count, = txn.fetchone() + results['all'] = count + + return results + + return self.runInteraction("count_r30_users", _count_r30_users) def get_users(self): """Function to reterive a list of users in users table. diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b0dc391190..2262776ab2 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -16,9 +16,7 @@ import logging from synapse.api.errors import StoreError from synapse.util.logcontext import LoggingContext, PreserveLoggingContext -from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.caches.descriptors import Cache -from synapse.util.caches import intern_dict from synapse.storage.engines import PostgresEngine import synapse.metrics @@ -28,10 +26,6 @@ from twisted.internet import defer import sys import time import threading -import os - - -CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) logger = logging.getLogger(__name__) @@ -53,20 +47,27 @@ class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. 
Adds logging and metrics to the .execute() method.""" - __slots__ = ["txn", "name", "database_engine", "after_callbacks"] + __slots__ = [ + "txn", "name", "database_engine", "after_callbacks", "exception_callbacks", + ] - def __init__(self, txn, name, database_engine, after_callbacks): + def __init__(self, txn, name, database_engine, after_callbacks, + exception_callbacks): object.__setattr__(self, "txn", txn) object.__setattr__(self, "name", name) object.__setattr__(self, "database_engine", database_engine) object.__setattr__(self, "after_callbacks", after_callbacks) + object.__setattr__(self, "exception_callbacks", exception_callbacks) - def call_after(self, callback, *args): + def call_after(self, callback, *args, **kwargs): """Call the given callback on the main twisted thread after the transaction has finished. Used to invalidate the caches on the correct thread. """ - self.after_callbacks.append((callback, args)) + self.after_callbacks.append((callback, args, kwargs)) + + def call_on_exception(self, callback, *args, **kwargs): + self.exception_callbacks.append((callback, args, kwargs)) def __getattr__(self, name): return getattr(self.txn, name) @@ -74,13 +75,22 @@ class LoggingTransaction(object): def __setattr__(self, name, value): setattr(self.txn, name, value) + def __iter__(self): + return self.txn.__iter__() + def execute(self, sql, *args): self._do_execute(self.txn.execute, sql, *args) def executemany(self, sql, *args): self._do_execute(self.txn.executemany, sql, *args) + def _make_sql_one_line(self, sql): + "Strip newlines out of SQL so that the loggers in the DB are on one line" + return " ".join(l.strip() for l in sql.splitlines() if l.strip()) + def _do_execute(self, func, sql, *args): + sql = self._make_sql_one_line(sql) + # TODO(paul): Maybe use 'info' and 'debug' for values? 
sql_logger.debug("[SQL] {%s} %s", self.name, sql) @@ -91,7 +101,7 @@ class LoggingTransaction(object): "[SQL values] {%s} %r", self.name, args[0] ) - except: + except Exception: # Don't let logging failures stop SQL from working pass @@ -127,7 +137,7 @@ class PerformanceCounters(object): def interval(self, interval_duration, limit=3): counters = [] - for name, (count, cum_time) in self.current_counters.items(): + for name, (count, cum_time) in self.current_counters.iteritems(): prev_count, prev_time = self.previous_counters.get(name, (0, 0)) counters.append(( (cum_time - prev_time) / interval_duration, @@ -150,7 +160,7 @@ class PerformanceCounters(object): class SQLBaseStore(object): _TXN_ID = 0 - def __init__(self, hs): + def __init__(self, db_conn, hs): self.hs = hs self._clock = hs.get_clock() self._db_pool = hs.get_db_pool() @@ -168,10 +178,6 @@ class SQLBaseStore(object): self._get_event_cache = Cache("*getEvent*", keylen=3, max_entries=hs.config.event_cache_size) - self._state_group_cache = DictionaryCache( - "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR - ) - self._event_fetch_lock = threading.Condition() self._event_fetch_list = [] self._event_fetch_ongoing = 0 @@ -209,8 +215,8 @@ class SQLBaseStore(object): self._clock.looping_call(loop, 10000) - def _new_transaction(self, conn, desc, after_callbacks, logging_context, - func, *args, **kwargs): + def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks, + logging_context, func, *args, **kwargs): start = time.time() * 1000 txn_id = self._TXN_ID @@ -229,7 +235,8 @@ class SQLBaseStore(object): try: txn = conn.cursor() txn = LoggingTransaction( - txn, name, self.database_engine, after_callbacks + txn, name, self.database_engine, after_callbacks, + exception_callbacks, ) r = func(txn, *args, **kwargs) conn.commit() @@ -284,47 +291,66 @@ class SQLBaseStore(object): @defer.inlineCallbacks def runInteraction(self, desc, func, *args, **kwargs): - """Wraps the .runInteraction() method on the 
underlying db_pool.""" - current_context = LoggingContext.current_context() + """Starts a transaction on the database and runs a given function - start_time = time.time() * 1000 + Arguments: + desc (str): description of the transaction, for logging and metrics + func (func): callback function, which will be called with a + database transaction (twisted.enterprise.adbapi.Transaction) as + its first argument, followed by `args` and `kwargs`. + + args (list): positional args to pass to `func` + kwargs (dict): named args to pass to `func` + + Returns: + Deferred: The result of func + """ + current_context = LoggingContext.current_context() after_callbacks = [] + exception_callbacks = [] def inner_func(conn, *args, **kwargs): - with LoggingContext("runInteraction") as context: - sql_scheduling_timer.inc_by(time.time() * 1000 - start_time) + return self._new_transaction( + conn, desc, after_callbacks, exception_callbacks, current_context, + func, *args, **kwargs + ) - if self.database_engine.is_connection_closed(conn): - logger.debug("Reconnecting closed database connection") - conn.reconnect() + try: + result = yield self.runWithConnection(inner_func, *args, **kwargs) - current_context.copy_to(context) - return self._new_transaction( - conn, desc, after_callbacks, current_context, - func, *args, **kwargs - ) + for after_callback, after_args, after_kwargs in after_callbacks: + after_callback(*after_args, **after_kwargs) + except: # noqa: E722, as we reraise the exception this is fine. 
+ for after_callback, after_args, after_kwargs in exception_callbacks: + after_callback(*after_args, **after_kwargs) + raise - try: - with PreserveLoggingContext(): - result = yield self._db_pool.runWithConnection( - inner_func, *args, **kwargs - ) - finally: - for after_callback, after_args in after_callbacks: - after_callback(*after_args) defer.returnValue(result) @defer.inlineCallbacks def runWithConnection(self, func, *args, **kwargs): - """Wraps the .runInteraction() method on the underlying db_pool.""" + """Wraps the .runWithConnection() method on the underlying db_pool. + + Arguments: + func (func): callback function, which will be called with a + database connection (twisted.enterprise.adbapi.Connection) as + its first argument, followed by `args` and `kwargs`. + args (list): positional args to pass to `func` + kwargs (dict): named args to pass to `func` + + Returns: + Deferred: The result of func + """ current_context = LoggingContext.current_context() start_time = time.time() * 1000 def inner_func(conn, *args, **kwargs): with LoggingContext("runWithConnection") as context: - sql_scheduling_timer.inc_by(time.time() * 1000 - start_time) + sched_duration_ms = time.time() * 1000 - start_time + sql_scheduling_timer.inc_by(sched_duration_ms) + current_context.add_database_scheduled(sched_duration_ms) if self.database_engine.is_connection_closed(conn): logger.debug("Reconnecting closed database connection") @@ -350,9 +376,9 @@ class SQLBaseStore(object): Returns: A list of dicts where the key is the column header. 
""" - col_headers = list(column[0] for column in cursor.description) + col_headers = list(intern(str(column[0])) for column in cursor.description) results = list( - intern_dict(dict(zip(col_headers, row))) for row in cursor.fetchall() + dict(zip(col_headers, row)) for row in cursor ) return results @@ -417,6 +443,11 @@ class SQLBaseStore(object): txn.execute(sql, vals) + def _simple_insert_many(self, table, values, desc): + return self.runInteraction( + desc, self._simple_insert_many_txn, table, values + ) + @staticmethod def _simple_insert_many_txn(txn, table, values): if not values: @@ -452,23 +483,53 @@ class SQLBaseStore(object): txn.executemany(sql, vals) + @defer.inlineCallbacks def _simple_upsert(self, table, keyvalues, values, insertion_values={}, desc="_simple_upsert", lock=True): """ + + `lock` should generally be set to True (the default), but can be set + to False if either of the following are true: + + * there is a UNIQUE INDEX on the key columns. In this case a conflict + will cause an IntegrityError in which case this function will retry + the update. + + * we somehow know that we are the only thread which will be updating + this table. + Args: table (str): The table to upsert into keyvalues (dict): The unique key tables and their new values values (dict): The nonunique columns and their new values - insertion_values (dict): key/values to use when inserting + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. Returns: Deferred(bool): True if a new entry was created, False if an existing one was updated. 
""" - return self.runInteraction( - desc, - self._simple_upsert_txn, table, keyvalues, values, insertion_values, - lock - ) + attempts = 0 + while True: + try: + result = yield self.runInteraction( + desc, + self._simple_upsert_txn, table, keyvalues, values, insertion_values, + lock=lock + ) + defer.returnValue(result) + except self.database_engine.module.IntegrityError as e: + attempts += 1 + if attempts >= 5: + # don't retry forever, because things other than races + # can cause IntegrityErrors + raise + + # presumably we raced with another transaction: let's retry. + logger.warn( + "IntegrityError when upserting into %s; retrying: %s", + table, e + ) def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={}, lock=True): @@ -476,45 +537,38 @@ class SQLBaseStore(object): if lock: self.database_engine.lock_table(txn, table) - # Try to update + # First try to update. sql = "UPDATE %s SET %s WHERE %s" % ( table, ", ".join("%s = ?" % (k,) for k in values), " AND ".join("%s = ?" % (k,) for k in keyvalues) ) sqlargs = values.values() + keyvalues.values() - logger.debug( - "[SQL] %s Args=%s", - sql, sqlargs, - ) txn.execute(sql, sqlargs) - if txn.rowcount == 0: - # We didn't update and rows so insert a new one - allvalues = {} - allvalues.update(keyvalues) - allvalues.update(values) - allvalues.update(insertion_values) + if txn.rowcount > 0: + # successfully updated at least one row. + return False - sql = "INSERT INTO %s (%s) VALUES (%s)" % ( - table, - ", ".join(k for k in allvalues), - ", ".join("?" for _ in allvalues) - ) - logger.debug( - "[SQL] %s Args=%s", - sql, keyvalues.values(), - ) - txn.execute(sql, allvalues.values()) + # We didn't update any rows so insert a new one + allvalues = {} + allvalues.update(keyvalues) + allvalues.update(values) + allvalues.update(insertion_values) - return True - else: - return False + sql = "INSERT INTO %s (%s) VALUES (%s)" % ( + table, + ", ".join(k for k in allvalues), + ", ".join("?" 
for _ in allvalues) + ) + txn.execute(sql, allvalues.values()) + # successfully inserted + return True def _simple_select_one(self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"): """Executes a SELECT query on the named table, which is expected to - return a single row, returning a single column from it. + return a single row, returning multiple columns from it. Args: table : string giving the table name @@ -567,22 +621,20 @@ class SQLBaseStore(object): @staticmethod def _simple_select_onecol_txn(txn, table, keyvalues, retcol): - if keyvalues: - where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) - else: - where = "" - sql = ( - "SELECT %(retcol)s FROM %(table)s %(where)s" + "SELECT %(retcol)s FROM %(table)s" ) % { "retcol": retcol, "table": table, - "where": where, } - txn.execute(sql, keyvalues.values()) + if keyvalues: + sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys()) + txn.execute(sql, keyvalues.values()) + else: + txn.execute(sql) - return [r[0] for r in txn.fetchall()] + return [r[0] for r in txn] def _simple_select_onecol(self, table, keyvalues, retcol, desc="_simple_select_onecol"): @@ -591,7 +643,7 @@ class SQLBaseStore(object): Args: table (str): table name - keyvalues (dict): column names and values to select the rows with + keyvalues (dict|None): column names and values to select the rows with retcol (str): column whos value we wish to retrieve. Returns: @@ -715,7 +767,7 @@ class SQLBaseStore(object): ) values.extend(iterable) - for key, value in keyvalues.items(): + for key, value in keyvalues.iteritems(): clauses.append("%s = ?" 
% (key,)) values.append(value) @@ -728,6 +780,33 @@ class SQLBaseStore(object): txn.execute(sql, values) return cls.cursor_to_dict(txn) + def _simple_update(self, table, keyvalues, updatevalues, desc): + return self.runInteraction( + desc, + self._simple_update_txn, + table, keyvalues, updatevalues, + ) + + @staticmethod + def _simple_update_txn(txn, table, keyvalues, updatevalues): + if keyvalues: + where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys()) + else: + where = "" + + update_sql = "UPDATE %s SET %s %s" % ( + table, + ", ".join("%s = ?" % (k,) for k in updatevalues), + where, + ) + + txn.execute( + update_sql, + updatevalues.values() + keyvalues.values() + ) + + return txn.rowcount + def _simple_update_one(self, table, keyvalues, updatevalues, desc="_simple_update_one"): """Executes an UPDATE query on the named table, setting new values for @@ -753,27 +832,13 @@ class SQLBaseStore(object): table, keyvalues, updatevalues, ) - @staticmethod - def _simple_update_one_txn(txn, table, keyvalues, updatevalues): - if keyvalues: - where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) - else: - where = "" - - update_sql = "UPDATE %s SET %s %s" % ( - table, - ", ".join("%s = ?" 
% (k,) for k in updatevalues), - where, - ) - - txn.execute( - update_sql, - updatevalues.values() + keyvalues.values() - ) + @classmethod + def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): + rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues) - if txn.rowcount == 0: + if rowcount == 0: raise StoreError(404, "No row found") - if txn.rowcount > 1: + if rowcount > 1: raise StoreError(500, "More than one row matched") @staticmethod @@ -843,6 +908,47 @@ class SQLBaseStore(object): return txn.execute(sql, keyvalues.values()) + def _simple_delete_many(self, table, column, iterable, keyvalues, desc): + return self.runInteraction( + desc, self._simple_delete_many_txn, table, column, iterable, keyvalues + ) + + @staticmethod + def _simple_delete_many_txn(txn, table, column, iterable, keyvalues): + """Executes a DELETE query on the named table. + + Filters rows by if value of `column` is in `iterable`. + + Args: + txn : Transaction object + table : string giving the table name + column : column name to test for inclusion against `iterable` + iterable : list + keyvalues : dict of column names and values to select the rows with + """ + if not iterable: + return + + sql = "DELETE FROM %s" % table + + clauses = [] + values = [] + clauses.append( + "%s IN (%s)" % (column, ",".join("?" for _ in iterable)) + ) + values.extend(iterable) + + for key, value in keyvalues.iteritems(): + clauses.append("%s = ?" % (key,)) + values.append(value) + + if clauses: + sql = "%s WHERE %s" % ( + sql, + " AND ".join(clauses), + ) + return txn.execute(sql, values) + def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value, limit=100000): # Fetch a mapping of room_id -> max stream position for "recent" rooms. 
@@ -863,16 +969,16 @@ class SQLBaseStore(object): txn = db_conn.cursor() txn.execute(sql, (int(max_value),)) - rows = txn.fetchall() - txn.close() cache = { row[0]: int(row[1]) - for row in rows + for row in txn } + txn.close() + if cache: - min_val = min(cache.values()) + min_val = min(cache.itervalues()) else: min_val = max_value @@ -895,6 +1001,7 @@ class SQLBaseStore(object): # __exit__ called after the transaction finishes. ctx = self._cache_id_gen.get_next() stream_id = ctx.__enter__() + txn.call_on_exception(ctx.__exit__, None, None, None) txn.call_after(ctx.__exit__, None, None, None) txn.call_after(self.hs.get_notifier().on_new_replication_data) diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index 3fa226e92d..f83ff0454a 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,18 +14,46 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore from twisted.internet import defer +from synapse.storage._base import SQLBaseStore +from synapse.storage.util.id_generators import StreamIdGenerator + +from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks -import ujson as json +import abc +import simplejson as json import logging logger = logging.getLogger(__name__) -class AccountDataStore(SQLBaseStore): +class AccountDataWorkerStore(SQLBaseStore): + """This is an abstract base class where subclasses must implement + `get_max_account_data_stream_id` which can be called in the initializer. 
+ """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. + __metaclass__ = abc.ABCMeta + + def __init__(self, db_conn, hs): + account_max = self.get_max_account_data_stream_id() + self._account_data_stream_cache = StreamChangeCache( + "AccountDataAndTagsChangeCache", account_max, + ) + + super(AccountDataWorkerStore, self).__init__(db_conn, hs) + + @abc.abstractmethod + def get_max_account_data_stream_id(self): + """Get the current max stream ID for account data stream + + Returns: + int + """ + raise NotImplementedError() @cached() def get_account_data_for_user(self, user_id): @@ -63,7 +92,7 @@ class AccountDataStore(SQLBaseStore): "get_account_data_for_user", get_account_data_for_user_txn ) - @cachedInlineCallbacks(num_args=2) + @cachedInlineCallbacks(num_args=2, max_entries=5000) def get_global_account_data_by_type_for_user(self, data_type, user_id): """ Returns: @@ -104,6 +133,7 @@ class AccountDataStore(SQLBaseStore): for row in rows }) + @cached(num_args=2) def get_account_data_for_room(self, user_id, room_id): """Get all the client account_data for a user for a room. @@ -127,6 +157,38 @@ class AccountDataStore(SQLBaseStore): "get_account_data_for_room", get_account_data_for_room_txn ) + @cached(num_args=3, max_entries=5000) + def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type): + """Get the client account_data of given type for a user for a room. + + Args: + user_id(str): The user to get the account_data for. + room_id(str): The room to get the account_data for. + account_data_type (str): The account data type to get. + Returns: + A deferred of the room account_data for that type, or None if + there isn't any set. 
+ """ + def get_account_data_for_room_and_type_txn(txn): + content_json = self._simple_select_one_onecol_txn( + txn, + table="room_account_data", + keyvalues={ + "user_id": user_id, + "room_id": room_id, + "account_data_type": account_data_type, + }, + retcol="content", + allow_none=True + ) + + return json.loads(content_json) if content_json else None + + return self.runInteraction( + "get_account_data_for_room_and_type", + get_account_data_for_room_and_type_txn, + ) + def get_all_updated_account_data(self, last_global_id, last_room_id, current_id, limit): """Get all the client account_data that has changed on the server @@ -182,7 +244,7 @@ class AccountDataStore(SQLBaseStore): txn.execute(sql, (user_id, stream_id)) global_account_data = { - row[0]: json.loads(row[1]) for row in txn.fetchall() + row[0]: json.loads(row[1]) for row in txn } sql = ( @@ -193,7 +255,7 @@ class AccountDataStore(SQLBaseStore): txn.execute(sql, (user_id, stream_id)) account_data_by_room = {} - for row in txn.fetchall(): + for row in txn: room_account_data = account_data_by_room.setdefault(row[0], {}) room_account_data[row[1]] = json.loads(row[2]) @@ -209,6 +271,36 @@ class AccountDataStore(SQLBaseStore): "get_updated_account_data_for_user", get_updated_account_data_for_user_txn ) + @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000) + def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): + ignored_account_data = yield self.get_global_account_data_by_type_for_user( + "m.ignored_user_list", ignorer_user_id, + on_invalidate=cache_context.invalidate, + ) + if not ignored_account_data: + defer.returnValue(False) + + defer.returnValue( + ignored_user_id in ignored_account_data.get("ignored_users", {}) + ) + + +class AccountDataStore(AccountDataWorkerStore): + def __init__(self, db_conn, hs): + self._account_data_id_gen = StreamIdGenerator( + db_conn, "account_data_max_stream_id", "stream_id" + ) + + super(AccountDataStore, self).__init__(db_conn, hs) + + 
def get_max_account_data_stream_id(self): + """Get the current max stream id for the private user data stream + + Returns: + A deferred int. + """ + return self._account_data_id_gen.get_current_token() + @defer.inlineCallbacks def add_account_data_to_room(self, user_id, room_id, account_data_type, content): """Add some account_data to a room for a user. @@ -222,9 +314,12 @@ class AccountDataStore(SQLBaseStore): """ content_json = json.dumps(content) - def add_account_data_txn(txn, next_id): - self._simple_upsert_txn( - txn, + with self._account_data_id_gen.get_next() as next_id: + # no need to lock here as room_account_data has a unique constraint + # on (user_id, room_id, account_data_type) so _simple_upsert will + # retry if there is a conflict. + yield self._simple_upsert( + desc="add_room_account_data", table="room_account_data", keyvalues={ "user_id": user_id, @@ -234,18 +329,23 @@ class AccountDataStore(SQLBaseStore): values={ "stream_id": next_id, "content": content_json, - } - ) - txn.call_after( - self._account_data_stream_cache.entity_has_changed, - user_id, next_id, + }, + lock=False, ) - txn.call_after(self.get_account_data_for_user.invalidate, (user_id,)) - self._update_max_stream_id(txn, next_id) - with self._account_data_id_gen.get_next() as next_id: - yield self.runInteraction( - "add_room_account_data", add_account_data_txn, next_id + # it's theoretically possible for the above to succeed and the + # below to fail - in which case we might reuse a stream id on + # restart, and the above update might not get propagated. That + # doesn't sound any worse than the whole update getting lost, + # which is what would happen if we combined the two into one + # transaction. 
+ yield self._update_max_stream_id(next_id) + + self._account_data_stream_cache.entity_has_changed(user_id, next_id) + self.get_account_data_for_user.invalidate((user_id,)) + self.get_account_data_for_room.invalidate((user_id, room_id,)) + self.get_account_data_for_room_and_type.prefill( + (user_id, room_id, account_data_type,), content, ) result = self._account_data_id_gen.get_current_token() @@ -263,9 +363,12 @@ class AccountDataStore(SQLBaseStore): """ content_json = json.dumps(content) - def add_account_data_txn(txn, next_id): - self._simple_upsert_txn( - txn, + with self._account_data_id_gen.get_next() as next_id: + # no need to lock here as account_data has a unique constraint on + # (user_id, account_data_type) so _simple_upsert will retry if + # there is a conflict. + yield self._simple_upsert( + desc="add_user_account_data", table="account_data", keyvalues={ "user_id": user_id, @@ -274,37 +377,43 @@ class AccountDataStore(SQLBaseStore): values={ "stream_id": next_id, "content": content_json, - } + }, + lock=False, ) - txn.call_after( - self._account_data_stream_cache.entity_has_changed, + + # it's theoretically possible for the above to succeed and the + # below to fail - in which case we might reuse a stream id on + # restart, and the above update might not get propagated. That + # doesn't sound any worse than the whole update getting lost, + # which is what would happen if we combined the two into one + # transaction. 
+ yield self._update_max_stream_id(next_id) + + self._account_data_stream_cache.entity_has_changed( user_id, next_id, ) - txn.call_after(self.get_account_data_for_user.invalidate, (user_id,)) - txn.call_after( - self.get_global_account_data_by_type_for_user.invalidate, + self.get_account_data_for_user.invalidate((user_id,)) + self.get_global_account_data_by_type_for_user.invalidate( (account_data_type, user_id,) ) - self._update_max_stream_id(txn, next_id) - - with self._account_data_id_gen.get_next() as next_id: - yield self.runInteraction( - "add_user_account_data", add_account_data_txn, next_id - ) result = self._account_data_id_gen.get_current_token() defer.returnValue(result) - def _update_max_stream_id(self, txn, next_id): + def _update_max_stream_id(self, next_id): """Update the max stream_id Args: - txn: The database cursor next_id(int): The the revision to advance to. """ - update_max_id_sql = ( - "UPDATE account_data_max_stream_id" - " SET stream_id = ?" - " WHERE stream_id < ?" + def _update(txn): + update_max_id_sql = ( + "UPDATE account_data_max_stream_id" + " SET stream_id = ?" + " WHERE stream_id < ?" + ) + txn.execute(update_max_id_sql, (next_id, next_id)) + return self.runInteraction( + "update_account_data_max_stream_id", + _update, ) - txn.execute(update_max_id_sql, (next_id, next_id)) diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index 514570561f..12ea8a158c 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,39 +14,58 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging +import re import simplejson as json from twisted.internet import defer -from synapse.api.constants import Membership from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices -from synapse.storage.roommember import RoomsForUser +from synapse.storage.events import EventsWorkerStore from ._base import SQLBaseStore logger = logging.getLogger(__name__) -class ApplicationServiceStore(SQLBaseStore): +def _make_exclusive_regex(services_cache): + # We precompie a regex constructed from all the regexes that the AS's + # have registered for exclusive users. + exclusive_user_regexes = [ + regex.pattern + for service in services_cache + for regex in service.get_exlusive_user_regexes() + ] + if exclusive_user_regexes: + exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) + exclusive_user_regex = re.compile(exclusive_user_regex) + else: + # We handle this case specially otherwise the constructed regex + # will always match + exclusive_user_regex = None - def __init__(self, hs): - super(ApplicationServiceStore, self).__init__(hs) - self.hostname = hs.hostname + return exclusive_user_regex + + +class ApplicationServiceWorkerStore(SQLBaseStore): + def __init__(self, db_conn, hs): self.services_cache = load_appservices( hs.hostname, hs.config.app_service_config_files ) + self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) + + super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs) def get_app_services(self): return self.services_cache def get_if_app_services_interested_in_user(self, user_id): - """Check if the user is one associated with an app service + """Check if the user is one associated with an app service (exclusively) """ - for service in self.services_cache: - if service.is_interested_in_user(user_id): - return True - return False + if self.exclusive_user_regex: + return bool(self.exclusive_user_regex.match(user_id)) + else: + return False def 
get_app_service_by_user_id(self, user_id): """Retrieve an application service from their user ID. @@ -78,83 +98,30 @@ class ApplicationServiceStore(SQLBaseStore): return service return None - def get_app_service_rooms(self, service): - """Get a list of RoomsForUser for this application service. - - Application services may be "interested" in lots of rooms depending on - the room ID, the room aliases, or the members in the room. This function - takes all of these into account and returns a list of RoomsForUser which - represent the entire list of room IDs that this application service - wants to know about. + def get_app_service_by_id(self, as_id): + """Get the application service with the given appservice ID. Args: - service: The application service to get a room list for. + as_id (str): The application service ID. Returns: - A list of RoomsForUser. + synapse.appservice.ApplicationService or None. """ - return self.runInteraction( - "get_app_service_rooms", - self._get_app_service_rooms_txn, - service, - ) - - def _get_app_service_rooms_txn(self, txn, service): - # get all rooms matching the room ID regex. - room_entries = self._simple_select_list_txn( - txn=txn, table="rooms", keyvalues=None, retcols=["room_id"] - ) - matching_room_list = set([ - r["room_id"] for r in room_entries if - service.is_interested_in_room(r["room_id"]) - ]) - - # resolve room IDs for matching room alias regex. - room_alias_mappings = self._simple_select_list_txn( - txn=txn, table="room_aliases", keyvalues=None, - retcols=["room_id", "room_alias"] - ) - matching_room_list |= set([ - r["room_id"] for r in room_alias_mappings if - service.is_interested_in_alias(r["room_alias"]) - ]) - - # get all rooms for every user for this AS. This is scoped to users on - # this HS only. 
- user_list = self._simple_select_list_txn( - txn=txn, table="users", keyvalues=None, retcols=["name"] - ) - user_list = [ - u["name"] for u in user_list if - service.is_interested_in_user(u["name"]) - ] - rooms_for_user_matching_user_id = set() # RoomsForUser list - for user_id in user_list: - # FIXME: This assumes this store is linked with RoomMemberStore :( - rooms_for_user = self._get_rooms_for_user_where_membership_is_txn( - txn=txn, - user_id=user_id, - membership_list=[Membership.JOIN] - ) - rooms_for_user_matching_user_id |= set(rooms_for_user) - - # make RoomsForUser tuples for room ids and aliases which are not in the - # main rooms_for_user_list - e.g. they are rooms which do not have AS - # registered users in it. - known_room_ids = [r.room_id for r in rooms_for_user_matching_user_id] - missing_rooms_for_user = [ - RoomsForUser(r, service.sender, "join") for r in - matching_room_list if r not in known_room_ids - ] - rooms_for_user_matching_user_id |= set(missing_rooms_for_user) - - return rooms_for_user_matching_user_id + for service in self.services_cache: + if service.id == as_id: + return service + return None -class ApplicationServiceTransactionStore(SQLBaseStore): +class ApplicationServiceStore(ApplicationServiceWorkerStore): + # This is currently empty due to there not being any AS storage functions + # that can't be run on the workers. Since this may change in future, and + # to keep consistency with the other stores, we keep this empty class for + # now. + pass - def __init__(self, hs): - super(ApplicationServiceTransactionStore, self).__init__(hs) +class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore, + EventsWorkerStore): @defer.inlineCallbacks def get_appservices_by_state(self, state): """Get a list of application services based on their state. 
@@ -399,3 +366,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore): events = yield self._get_events(event_ids) defer.returnValue((upper_bound, events)) + + +class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore): + # This is currently empty due to there not being any AS storage functions + # that can't be run on the workers. Since this may change in future, and + # to keep consistency with the other stores, we keep this empty class for + # now. + pass diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 94b2bcc54a..8af325a9f5 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -12,13 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import synapse.util.async from ._base import SQLBaseStore from . import engines from twisted.internet import defer -import ujson as json +import simplejson as json import logging logger = logging.getLogger(__name__) @@ -79,35 +80,26 @@ class BackgroundUpdateStore(SQLBaseStore): BACKGROUND_UPDATE_INTERVAL_MS = 1000 BACKGROUND_UPDATE_DURATION_MS = 100 - def __init__(self, hs): - super(BackgroundUpdateStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(BackgroundUpdateStore, self).__init__(db_conn, hs) self._background_update_performance = {} self._background_update_queue = [] self._background_update_handlers = {} - self._background_update_timer = None + self._all_done = False @defer.inlineCallbacks def start_doing_background_updates(self): - assert self._background_update_timer is None, \ - "background updates already running" - logger.info("Starting background schema updates") while True: - sleep = defer.Deferred() - self._background_update_timer = self._clock.call_later( - self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None - ) - try: - yield 
sleep - finally: - self._background_update_timer = None + yield synapse.util.async.sleep( + self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.) try: result = yield self.do_next_background_update( self.BACKGROUND_UPDATE_DURATION_MS ) - except: + except Exception: logger.exception("Error doing update") else: if result is None: @@ -115,9 +107,41 @@ class BackgroundUpdateStore(SQLBaseStore): "No more background updates to do." " Unscheduling background update task." ) + self._all_done = True defer.returnValue(None) @defer.inlineCallbacks + def has_completed_background_updates(self): + """Check if all the background updates have completed + + Returns: + Deferred[bool]: True if all background updates have completed + """ + # if we've previously determined that there is nothing left to do, that + # is easy + if self._all_done: + defer.returnValue(True) + + # obviously, if we have things in our queue, we're not done. + if self._background_update_queue: + defer.returnValue(False) + + # otherwise, check if there are updates to be run. This is important, + # as we may be running on a worker which doesn't perform the bg updates + # itself, but still wants to wait for them to happen. + updates = yield self._simple_select_onecol( + "background_updates", + keyvalues=None, + retcol="1", + desc="check_background_updates", + ) + if not updates: + self._all_done = True + defer.returnValue(True) + + defer.returnValue(False) + + @defer.inlineCallbacks def do_next_background_update(self, desired_duration_ms): """Does some amount of work on the next queued background update @@ -218,8 +242,29 @@ class BackgroundUpdateStore(SQLBaseStore): """ self._background_update_handlers[update_name] = update_handler + def register_noop_background_update(self, update_name): + """Register a noop handler for a background update. + + This is useful when we previously did a background update, but no + longer wish to do the update. 
In this case the background update should + be removed from the schema delta files, but there may still be some + users who have the background update queued, so this method should + also be called to clear the update. + + Args: + update_name (str): Name of update + """ + @defer.inlineCallbacks + def noop_update(progress, batch_size): + yield self._end_background_update(update_name) + defer.returnValue(1) + + self.register_background_update_handler(update_name, noop_update) + def register_background_index_update(self, update_name, index_name, - table, columns, where_clause=None): + table, columns, where_clause=None, + unique=False, + psql_only=False): """Helper for store classes to do a background index addition To use: @@ -235,48 +280,80 @@ class BackgroundUpdateStore(SQLBaseStore): index_name (str): name of index to add table (str): table to add index to columns (list[str]): columns/expressions to include in index + unique (bool): true to make a UNIQUE index + psql_only: true to only create this index on psql databases (useful + for virtual sqlite tables) """ - # if this is postgres, we add the indexes concurrently. 
Otherwise - # we fall back to doing it inline - if isinstance(self.database_engine, engines.PostgresEngine): - conc = True - else: - conc = False - # We don't use partial indices on SQLite as it wasn't introduced - # until 3.8, and wheezy has 3.7 - where_clause = None - - sql = ( - "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)" - " %(where_clause)s" - ) % { - "conc": "CONCURRENTLY" if conc else "", - "name": index_name, - "table": table, - "columns": ", ".join(columns), - "where_clause": "WHERE " + where_clause if where_clause else "" - } - - def create_index_concurrently(conn): + def create_index_psql(conn): conn.rollback() # postgres insists on autocommit for the index conn.set_session(autocommit=True) - c = conn.cursor() - c.execute(sql) - conn.set_session(autocommit=False) - def create_index(conn): + try: + c = conn.cursor() + + # If a previous attempt to create the index was interrupted, + # we may already have a half-built index. Let's just drop it + # before trying to create it again. + + sql = "DROP INDEX IF EXISTS %s" % (index_name,) + logger.debug("[SQL] %s", sql) + c.execute(sql) + + sql = ( + "CREATE %(unique)s INDEX CONCURRENTLY %(name)s" + " ON %(table)s" + " (%(columns)s) %(where_clause)s" + ) % { + "unique": "UNIQUE" if unique else "", + "name": index_name, + "table": table, + "columns": ", ".join(columns), + "where_clause": "WHERE " + where_clause if where_clause else "" + } + logger.debug("[SQL] %s", sql) + c.execute(sql) + finally: + conn.set_session(autocommit=False) + + def create_index_sqlite(conn): + # Sqlite doesn't support concurrent creation of indexes. + # + # We don't use partial indices on SQLite as it wasn't introduced + # until 3.8, and wheezy and CentOS 7 have 3.7 + # + # We assume that sqlite doesn't give us invalid indices; however + # we may still end up with the index existing but the + # background_updates not having been recorded if synapse got shut + # down at the wrong moment - hance we use IF NOT EXISTS. 
(SQLite + # has supported CREATE TABLE|INDEX IF NOT EXISTS since 3.3.0.) + sql = ( + "CREATE %(unique)s INDEX IF NOT EXISTS %(name)s ON %(table)s" + " (%(columns)s)" + ) % { + "unique": "UNIQUE" if unique else "", + "name": index_name, + "table": table, + "columns": ", ".join(columns), + } + c = conn.cursor() + logger.debug("[SQL] %s", sql) c.execute(sql) + if isinstance(self.database_engine, engines.PostgresEngine): + runner = create_index_psql + elif psql_only: + runner = None + else: + runner = create_index_sqlite + @defer.inlineCallbacks def updater(progress, batch_size): - logger.info("Adding index %s to %s", index_name, table) - if conc: - yield self.runWithConnection(create_index_concurrently) - else: - yield self.runWithConnection(create_index) + if runner is not None: + logger.info("Adding index %s to %s", index_name, table) + yield self.runWithConnection(runner) yield self._end_background_update(update_name) defer.returnValue(1) diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 71e5ea112f..7b44dae0fc 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -15,11 +15,14 @@ import logging -from twisted.internet import defer +from twisted.internet import defer, reactor from ._base import Cache from . import background_updates +from synapse.util.caches import CACHE_SIZE_FACTOR + + logger = logging.getLogger(__name__) # Number of msec of granularity to store the user IP 'last seen' time. 
Smaller @@ -29,13 +32,14 @@ LAST_SEEN_GRANULARITY = 120 * 1000 class ClientIpStore(background_updates.BackgroundUpdateStore): - def __init__(self, hs): + def __init__(self, db_conn, hs): self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, + max_entries=50000 * CACHE_SIZE_FACTOR, ) - super(ClientIpStore, self).__init__(hs) + super(ClientIpStore, self).__init__(db_conn, hs) self.register_background_index_update( "user_ips_device_index", @@ -44,10 +48,26 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): columns=["user_id", "device_id", "last_seen"], ) - @defer.inlineCallbacks - def insert_client_ip(self, user, access_token, ip, user_agent, device_id): - now = int(self._clock.time_msec()) - key = (user.to_string(), access_token, ip) + self.register_background_index_update( + "user_ips_last_seen_index", + index_name="user_ips_last_seen", + table="user_ips", + columns=["user_id", "last_seen"], + ) + + # (user_id, access_token, ip) -> (user_agent, device_id, last_seen) + self._batch_row_update = {} + + self._client_ip_looper = self._clock.looping_call( + self._update_client_ips_batch, 5 * 1000 + ) + reactor.addSystemEventTrigger("before", "shutdown", self._update_client_ips_batch) + + def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id, + now=None): + if not now: + now = int(self._clock.time_msec()) + key = (user_id, access_token, ip) try: last_seen = self.client_ip_last_seen.get(key) @@ -56,34 +76,48 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): # Rate-limited inserts if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: - defer.returnValue(None) + return self.client_ip_last_seen.prefill(key, now) - # It's safe not to lock here: a) no unique constraint, - # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely - yield self._simple_upsert( - "user_ips", - keyvalues={ - "user_id": user.to_string(), - "access_token": access_token, - "ip": ip, - "user_agent": 
user_agent, - "device_id": device_id, - }, - values={ - "last_seen": now, - }, - desc="insert_client_ip", - lock=False, + self._batch_row_update[key] = (user_agent, device_id, now) + + def _update_client_ips_batch(self): + to_update = self._batch_row_update + self._batch_row_update = {} + return self.runInteraction( + "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update ) + def _update_client_ips_batch_txn(self, txn, to_update): + self.database_engine.lock_table(txn, "user_ips") + + for entry in to_update.iteritems(): + (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry + + self._simple_upsert_txn( + txn, + table="user_ips", + keyvalues={ + "user_id": user_id, + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "device_id": device_id, + }, + values={ + "last_seen": last_seen, + }, + lock=False, + ) + @defer.inlineCallbacks - def get_last_client_ip_by_device(self, devices): + def get_last_client_ip_by_device(self, user_id, device_id): """For each device_id listed, give the user_ip it was last seen on Args: - devices (iterable[(str, str)]): list of (user_id, device_id) pairs + user_id (str) + device_id (str): If None fetches all devices for the user Returns: defer.Deferred: resolves to a dict, where the keys @@ -94,6 +128,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): res = yield self.runInteraction( "get_last_client_ip_by_device", self._get_last_client_ip_by_device_txn, + user_id, device_id, retcols=( "user_id", "access_token", @@ -102,23 +137,37 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): "device_id", "last_seen", ), - devices=devices ) ret = {(d["user_id"], d["device_id"]): d for d in res} + for key in self._batch_row_update: + uid, access_token, ip = key + if uid == user_id: + user_agent, did, last_seen = self._batch_row_update[key] + if not device_id or did == device_id: + ret[(user_id, device_id)] = { + "user_id": user_id, + "access_token": access_token, + 
"ip": ip, + "user_agent": user_agent, + "device_id": did, + "last_seen": last_seen, + } defer.returnValue(ret) @classmethod - def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols): + def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols): where_clauses = [] bindings = [] - for (user_id, device_id) in devices: - if device_id is None: - where_clauses.append("(user_id = ? AND device_id IS NULL)") - bindings.extend((user_id, )) - else: - where_clauses.append("(user_id = ? AND device_id = ?)") - bindings.extend((user_id, device_id)) + if device_id is None: + where_clauses.append("user_id = ?") + bindings.extend((user_id, )) + else: + where_clauses.append("(user_id = ? AND device_id = ?)") + bindings.extend((user_id, device_id)) + + if not where_clauses: + return [] inner_select = ( "SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips " @@ -143,3 +192,37 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): txn.execute(sql, bindings) return cls.cursor_to_dict(txn) + + @defer.inlineCallbacks + def get_user_ip_and_agents(self, user): + user_id = user.to_string() + results = {} + + for key in self._batch_row_update: + uid, access_token, ip = key + if uid == user_id: + user_agent, _, last_seen = self._batch_row_update[key] + results[(access_token, ip)] = (user_agent, last_seen) + + rows = yield self._simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=[ + "access_token", "ip", "user_agent", "last_seen" + ], + desc="get_user_ip_and_agents", + ) + + results.update( + ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"])) + for row in rows + ) + defer.returnValue(list( + { + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "last_seen": last_seen, + } + for (access_token, ip), (user_agent, last_seen) in results.iteritems() + )) diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index bde3b5cbbc..a879e5bfc1 100644 --- 
a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -14,12 +14,14 @@ # limitations under the License. import logging -import ujson +import simplejson from twisted.internet import defer from .background_updates import BackgroundUpdateStore +from synapse.util.caches.expiringcache import ExpiringCache + logger = logging.getLogger(__name__) @@ -27,8 +29,8 @@ logger = logging.getLogger(__name__) class DeviceInboxStore(BackgroundUpdateStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - def __init__(self, hs): - super(DeviceInboxStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(DeviceInboxStore, self).__init__(db_conn, hs) self.register_background_index_update( "device_inbox_stream_index", @@ -42,6 +44,15 @@ class DeviceInboxStore(BackgroundUpdateStore): self._background_drop_index_device_inbox, ) + # Map of (user_id, device_id) to the last stream_id that has been + # deleted up to. This is so that we can no op deletions. + self._last_device_delete_cache = ExpiringCache( + cache_name="last_device_delete_cache", + clock=self._clock, + max_len=10000, + expiry_ms=30 * 60 * 1000, + ) + @defer.inlineCallbacks def add_messages_to_device_inbox(self, local_messages_by_user_then_device, remote_messages_by_destination): @@ -74,7 +85,7 @@ class DeviceInboxStore(BackgroundUpdateStore): ) rows = [] for destination, edu in remote_messages_by_destination.items(): - edu_json = ujson.dumps(edu) + edu_json = simplejson.dumps(edu) rows.append((destination, stream_id, now_ms, edu_json)) txn.executemany(sql, rows) @@ -166,8 +177,8 @@ class DeviceInboxStore(BackgroundUpdateStore): " WHERE user_id = ?" ) txn.execute(sql, (user_id,)) - message_json = ujson.dumps(messages_by_device["*"]) - for row in txn.fetchall(): + message_json = simplejson.dumps(messages_by_device["*"]) + for row in txn: # Add the message for all devices for this user on this # server. 
device = row[0] @@ -184,11 +195,11 @@ class DeviceInboxStore(BackgroundUpdateStore): # TODO: Maybe this needs to be done in batches if there are # too many local devices for a given user. txn.execute(sql, [user_id] + devices) - for row in txn.fetchall(): + for row in txn: # Only insert into the local inbox if the device exists on # this server device = row[0] - message_json = ujson.dumps(messages_by_device[device]) + message_json = simplejson.dumps(messages_by_device[device]) messages_json_for_user[device] = message_json if messages_json_for_user: @@ -240,9 +251,9 @@ class DeviceInboxStore(BackgroundUpdateStore): user_id, device_id, last_stream_id, current_stream_id, limit )) messages = [] - for row in txn.fetchall(): + for row in txn: stream_pos = row[0] - messages.append(ujson.loads(row[1])) + messages.append(simplejson.loads(row[1])) if len(messages) < limit: stream_pos = current_stream_id return (messages, stream_pos) @@ -251,6 +262,7 @@ class DeviceInboxStore(BackgroundUpdateStore): "get_new_messages_for_device", get_new_messages_for_device_txn, ) + @defer.inlineCallbacks def delete_messages_for_device(self, user_id, device_id, up_to_stream_id): """ Args: @@ -260,6 +272,18 @@ class DeviceInboxStore(BackgroundUpdateStore): Returns: A deferred that resolves to the number of messages deleted. 
""" + # If we have cached the last stream id we've deleted up to, we can + # check if there is likely to be anything that needs deleting + last_deleted_stream_id = self._last_device_delete_cache.get( + (user_id, device_id), None + ) + if last_deleted_stream_id: + has_changed = self._device_inbox_stream_cache.has_entity_changed( + user_id, last_deleted_stream_id + ) + if not has_changed: + defer.returnValue(0) + def delete_messages_for_device_txn(txn): sql = ( "DELETE FROM device_inbox" @@ -269,10 +293,20 @@ class DeviceInboxStore(BackgroundUpdateStore): txn.execute(sql, (user_id, device_id, up_to_stream_id)) return txn.rowcount - return self.runInteraction( + count = yield self.runInteraction( "delete_messages_for_device", delete_messages_for_device_txn ) + # Update the cache, ensuring that we only ever increase the value + last_deleted_stream_id = self._last_device_delete_cache.get( + (user_id, device_id), 0 + ) + self._last_device_delete_cache[(user_id, device_id)] = max( + last_deleted_stream_id, up_to_stream_id + ) + + defer.returnValue(count) + def get_all_new_device_messages(self, last_pos, current_pos, limit): """ Args: @@ -291,22 +325,25 @@ class DeviceInboxStore(BackgroundUpdateStore): # we return. upper_pos = min(current_pos, last_pos + limit) sql = ( - "SELECT stream_id, user_id" + "SELECT max(stream_id), user_id" " FROM device_inbox" " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" + " GROUP BY user_id" ) txn.execute(sql, (last_pos, upper_pos)) rows = txn.fetchall() sql = ( - "SELECT stream_id, destination" + "SELECT max(stream_id), destination" " FROM device_federation_outbox" " WHERE ? < stream_id AND stream_id <= ?" 
- " ORDER BY stream_id ASC" + " GROUP BY destination" ) txn.execute(sql, (last_pos, upper_pos)) - rows.extend(txn.fetchall()) + rows.extend(txn) + + # Order by ascending stream ordering + rows.sort() return rows @@ -323,12 +360,12 @@ class DeviceInboxStore(BackgroundUpdateStore): """ Args: destination(str): The name of the remote server. - last_stream_id(int): The last position of the device message stream + last_stream_id(int|long): The last position of the device message stream that the server sent up to. - current_stream_id(int): The current position of the device + current_stream_id(int|long): The current position of the device message stream. Returns: - Deferred ([dict], int): List of messages for the device and where + Deferred ([dict], int|long): List of messages for the device and where in the stream the messages got to. """ @@ -350,9 +387,9 @@ class DeviceInboxStore(BackgroundUpdateStore): destination, last_stream_id, current_stream_id, limit )) messages = [] - for row in txn.fetchall(): + for row in txn: stream_pos = row[0] - messages.append(ujson.loads(row[1])) + messages.append(simplejson.loads(row[1])) if len(messages) < limit: stream_pos = current_stream_id return (messages, stream_pos) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 8e17800364..712106b83a 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -13,24 +13,41 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging -import ujson as json +import simplejson as json from twisted.internet import defer from synapse.api.errors import StoreError -from ._base import SQLBaseStore +from ._base import SQLBaseStore, Cache +from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks + logger = logging.getLogger(__name__) class DeviceStore(SQLBaseStore): - def __init__(self, hs): - super(DeviceStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(DeviceStore, self).__init__(db_conn, hs) + + # Map of (user_id, device_id) -> bool. If there is an entry that implies + # the device exists. + self.device_id_exists_cache = Cache( + name="device_id_exists", + keylen=2, + max_entries=10000, + ) self._clock.looping_call( self._prune_old_outbound_device_pokes, 60 * 60 * 1000 ) + self.register_background_index_update( + "device_lists_stream_idx", + index_name="device_lists_stream_user_id", + table="device_lists_stream", + columns=["user_id", "device_id"], + ) + @defer.inlineCallbacks def store_device(self, user_id, device_id, initial_device_display_name): @@ -45,6 +62,10 @@ class DeviceStore(SQLBaseStore): defer.Deferred: boolean whether the device was inserted or an existing device existed with that ID. """ + key = (user_id, device_id) + if self.device_id_exists_cache.get(key, None): + defer.returnValue(False) + try: inserted = yield self._simple_insert( "devices", @@ -56,6 +77,7 @@ class DeviceStore(SQLBaseStore): desc="store_device", or_ignore=True, ) + self.device_id_exists_cache.prefill(key, True) defer.returnValue(inserted) except Exception as e: logger.error("store_device with device_id=%s(%r) user_id=%s(%r)" @@ -84,6 +106,7 @@ class DeviceStore(SQLBaseStore): desc="get_device", ) + @defer.inlineCallbacks def delete_device(self, user_id, device_id): """Delete a device. 
@@ -93,12 +116,34 @@ class DeviceStore(SQLBaseStore): Returns: defer.Deferred """ - return self._simple_delete_one( + yield self._simple_delete_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, desc="delete_device", ) + self.device_id_exists_cache.invalidate((user_id, device_id)) + + @defer.inlineCallbacks + def delete_devices(self, user_id, device_ids): + """Deletes several devices. + + Args: + user_id (str): The ID of the user which owns the devices + device_ids (list): The IDs of the devices to delete + Returns: + defer.Deferred + """ + yield self._simple_delete_many( + table="devices", + column="device_id", + iterable=device_ids, + keyvalues={"user_id": user_id}, + desc="delete_devices", + ) + for device_id in device_ids: + self.device_id_exists_cache.invalidate((user_id, device_id)) + def update_device(self, user_id, device_id, new_display_name=None): """Update a device. @@ -144,6 +189,7 @@ class DeviceStore(SQLBaseStore): defer.returnValue({d["device_id"]: d for d in devices}) + @cached(max_entries=10000) def get_device_list_last_stream_id_for_remote(self, user_id): """Get the last stream_id we got for a user. May be None if we haven't got any information for them. 
@@ -156,16 +202,36 @@ class DeviceStore(SQLBaseStore): allow_none=True, ) + @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote", + list_name="user_ids", inlineCallbacks=True) + def get_device_list_last_stream_id_for_remotes(self, user_ids): + rows = yield self._simple_select_many_batch( + table="device_lists_remote_extremeties", + column="user_id", + iterable=user_ids, + retcols=("user_id", "stream_id",), + desc="get_user_devices_from_cache", + ) + + results = {user_id: None for user_id in user_ids} + results.update({ + row["user_id"]: row["stream_id"] for row in rows + }) + + defer.returnValue(results) + + @defer.inlineCallbacks def mark_remote_user_device_list_as_unsubscribed(self, user_id): """Mark that we no longer track device lists for remote user. """ - return self._simple_delete( + yield self._simple_delete( table="device_lists_remote_extremeties", keyvalues={ "user_id": user_id, }, desc="mark_remote_user_device_list_as_unsubscribed", ) + self.get_device_list_last_stream_id_for_remote.invalidate((user_id,)) def update_remote_device_list_cache_entry(self, user_id, device_id, content, stream_id): @@ -191,6 +257,12 @@ class DeviceStore(SQLBaseStore): } ) + txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,)) + txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) + txn.call_after( + self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) + ) + self._simple_upsert_txn( txn, table="device_lists_remote_extremeties", @@ -234,6 +306,12 @@ class DeviceStore(SQLBaseStore): ] ) + txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) + txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,)) + txn.call_after( + self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) + ) + self._simple_upsert_txn( txn, table="device_lists_remote_extremeties", @@ -249,7 +327,7 @@ class DeviceStore(SQLBaseStore): """Get stream of updates to send to remote 
servers Returns: - (now_stream_id, [ { updates }, .. ]) + (int, list[dict]): current stream id and list of updates """ now_stream_id = self._device_list_id_gen.get_current_token() @@ -270,24 +348,27 @@ class DeviceStore(SQLBaseStore): SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? GROUP BY user_id, device_id + LIMIT 20 """ txn.execute( sql, (destination, from_stream_id, now_stream_id, False) ) - rows = txn.fetchall() - if not rows: + # maps (user_id, device_id) -> stream_id + query_map = {(r[0], r[1]): r[2] for r in txn} + if not query_map: return (now_stream_id, []) - # maps (user_id, device_id) -> stream_id - query_map = {(r[0], r[1]): r[2] for r in rows} + if len(query_map) >= 20: + now_stream_id = max(stream_id for stream_id in query_map.itervalues()) + devices = self._get_e2e_device_keys_txn( txn, query_map.keys(), include_all_devices=True ) prev_sent_id_sql = """ SELECT coalesce(max(stream_id), 0) as stream_id - FROM device_lists_outbound_pokes + FROM device_lists_outbound_last_success WHERE destination = ? AND user_id = ? AND stream_id <= ? """ @@ -320,6 +401,7 @@ class DeviceStore(SQLBaseStore): return (now_stream_id, results) + @defer.inlineCallbacks def get_user_devices_from_cache(self, query_list): """Get the devices (and keys if any) for remote users from the cache. 
@@ -332,27 +414,11 @@ class DeviceStore(SQLBaseStore): a set of user_ids and results_map is a mapping of user_id -> device_id -> device_info """ - return self.runInteraction( - "get_user_devices_from_cache", self._get_user_devices_from_cache_txn, - query_list, + user_ids = set(user_id for user_id, _ in query_list) + user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids)) + user_ids_in_cache = set( + user_id for user_id, stream_id in user_map.items() if stream_id ) - - def _get_user_devices_from_cache_txn(self, txn, query_list): - user_ids = {user_id for user_id, _ in query_list} - - user_ids_in_cache = set() - for user_id in user_ids: - stream_ids = self._simple_select_onecol_txn( - txn, - table="device_lists_remote_extremeties", - keyvalues={ - "user_id": user_id, - }, - retcol="stream_id", - ) - if stream_ids: - user_ids_in_cache.add(user_id) - user_ids_not_in_cache = user_ids - user_ids_in_cache results = {} @@ -361,32 +427,40 @@ class DeviceStore(SQLBaseStore): continue if device_id: - content = self._simple_select_one_onecol_txn( - txn, - table="device_lists_remote_cache", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, - retcol="content", - ) - results.setdefault(user_id, {})[device_id] = json.loads(content) + device = yield self._get_cached_user_device(user_id, device_id) + results.setdefault(user_id, {})[device_id] = device else: - devices = self._simple_select_list_txn( - txn, - table="device_lists_remote_cache", - keyvalues={ - "user_id": user_id, - }, - retcols=("device_id", "content"), - ) - results[user_id] = { - device["device_id"]: json.loads(device["content"]) - for device in devices - } - user_ids_in_cache.discard(user_id) + results[user_id] = yield self._get_cached_devices_for_user(user_id) - return user_ids_not_in_cache, results + defer.returnValue((user_ids_not_in_cache, results)) + + @cachedInlineCallbacks(num_args=2, tree=True) + def _get_cached_user_device(self, user_id, device_id): + content = 
yield self._simple_select_one_onecol( + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + retcol="content", + desc="_get_cached_user_device", + ) + defer.returnValue(json.loads(content)) + + @cachedInlineCallbacks() + def _get_cached_devices_for_user(self, user_id): + devices = yield self._simple_select_list( + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + }, + retcols=("device_id", "content"), + desc="_get_cached_devices_for_user", + ) + defer.returnValue({ + device["device_id"]: json.loads(device["content"]) + for device in devices + }) def get_devices_with_keys_by_user(self, user_id): """Get all devices (with any device keys) for a user @@ -436,32 +510,43 @@ class DeviceStore(SQLBaseStore): ) def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): - # First we DELETE all rows such that only the latest row for each - # (destination, user_id is left. We do this by selecting first and - # deleting. + # We update the device_lists_outbound_last_success with the successfully + # poked users. We do the join to see which users need to be inserted and + # which updated. sql = """ - SELECT user_id, coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes - WHERE destination = ? AND stream_id <= ? + SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL) + FROM device_lists_outbound_pokes as o + LEFT JOIN device_lists_outbound_last_success as s + USING (destination, user_id) + WHERE destination = ? AND o.stream_id <= ? GROUP BY user_id - HAVING count(*) > 1 """ txn.execute(sql, (destination, stream_id,)) rows = txn.fetchall() sql = """ - DELETE FROM device_lists_outbound_pokes - WHERE destination = ? AND user_id = ? AND stream_id < ? + UPDATE device_lists_outbound_last_success + SET stream_id = ? + WHERE destination = ? AND user_id = ? 
+ """ + txn.executemany( + sql, ((row[1], destination, row[0],) for row in rows if row[2]) + ) + + sql = """ + INSERT INTO device_lists_outbound_last_success + (destination, user_id, stream_id) VALUES (?, ?, ?) """ txn.executemany( - sql, ((destination, row[0], row[1],) for row in rows) + sql, ((destination, row[0], row[1],) for row in rows if not row[2]) ) - # Mark everything that is left as sent + # Delete all sent outbound pokes sql = """ - UPDATE device_lists_outbound_pokes SET sent = ? + DELETE FROM device_lists_outbound_pokes WHERE destination = ? AND stream_id <= ? """ - txn.execute(sql, (True, destination, stream_id,)) + txn.execute(sql, (destination, stream_id,)) @defer.inlineCallbacks def get_user_whose_devices_changed(self, from_key): @@ -473,12 +558,12 @@ class DeviceStore(SQLBaseStore): defer.returnValue(set(changed)) sql = """ - SELECT user_id FROM device_lists_stream WHERE stream_id > ? + SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ? """ rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) defer.returnValue(set(row[0] for row in rows)) - def get_all_device_list_changes_for_remotes(self, from_key): + def get_all_device_list_changes_for_remotes(self, from_key, to_key): """Return a list of `(stream_id, user_id, destination)` which is the combined list of changes to devices, and which destinations need to be poked. `destination` may be None if no destinations need to be poked. @@ -486,11 +571,11 @@ class DeviceStore(SQLBaseStore): sql = """ SELECT stream_id, user_id, destination FROM device_lists_stream LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id) - WHERE stream_id > ? + WHERE ? < stream_id AND stream_id <= ? 
""" return self._execute( - "get_users_and_hosts_device_list", None, - sql, from_key, + "get_all_device_list_changes_for_remotes", None, + sql, from_key, to_key ) @defer.inlineCallbacks @@ -518,6 +603,16 @@ class DeviceStore(SQLBaseStore): host, stream_id, ) + # Delete older entries in the table, as we really only care about + # when the latest change happened. + txn.executemany( + """ + DELETE FROM device_lists_stream + WHERE user_id = ? AND device_id = ? AND stream_id < ? + """, + [(user_id, device_id, stream_id) for device_id in device_ids] + ) + self._simple_insert_many_txn( txn, table="device_lists_stream", @@ -586,6 +681,14 @@ class DeviceStore(SQLBaseStore): ) ) + # Since we've deleted unsent deltas, we need to remove the entry + # of last successful sent so that the prev_ids are correctly set. + sql = """ + DELETE FROM device_lists_outbound_last_success + WHERE destination = ? AND user_id = ? + """ + txn.executemany(sql, ((row[0], row[1]) for row in rows)) + logger.info("Pruned %d device list outbound pokes", txn.rowcount) return self.runInteraction( diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 9caaf81f2c..d0c0059757 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -29,8 +29,7 @@ RoomAliasMapping = namedtuple( ) -class DirectoryStore(SQLBaseStore): - +class DirectoryWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_association_from_room_alias(self, room_alias): """ Get's the room_id and server list for a given room_alias @@ -69,6 +68,28 @@ class DirectoryStore(SQLBaseStore): RoomAliasMapping(room_id, room_alias.to_string(), servers) ) + def get_room_alias_creator(self, room_alias): + return self._simple_select_one_onecol( + table="room_aliases", + keyvalues={ + "room_alias": room_alias, + }, + retcol="creator", + desc="get_room_alias_creator", + allow_none=True + ) + + @cached(max_entries=5000) + def get_aliases_for_room(self, room_id): + return self._simple_select_onecol( + 
"room_aliases", + {"room_id": room_id}, + "room_alias", + desc="get_aliases_for_room", + ) + + +class DirectoryStore(DirectoryWorkerStore): @defer.inlineCallbacks def create_room_alias_association(self, room_alias, room_id, servers, creator=None): """ Creates an associatin between a room alias and room_id/servers @@ -116,17 +137,6 @@ class DirectoryStore(SQLBaseStore): ) defer.returnValue(ret) - def get_room_alias_creator(self, room_alias): - return self._simple_select_one_onecol( - table="room_aliases", - keyvalues={ - "room_alias": room_alias, - }, - retcol="creator", - desc="get_room_alias_creator", - allow_none=True - ) - @defer.inlineCallbacks def delete_room_alias(self, room_alias): room_id = yield self.runInteraction( @@ -135,7 +145,6 @@ class DirectoryStore(SQLBaseStore): room_alias, ) - self.get_aliases_for_room.invalidate((room_id,)) defer.returnValue(room_id) def _delete_room_alias_txn(self, txn, room_alias): @@ -160,13 +169,22 @@ class DirectoryStore(SQLBaseStore): (room_alias.to_string(),) ) + self._invalidate_cache_and_stream( + txn, self.get_aliases_for_room, (room_id,) + ) + return room_id - @cached(max_entries=5000) - def get_aliases_for_room(self, room_id): - return self._simple_select_onecol( - "room_aliases", - {"room_id": room_id}, - "room_alias", - desc="get_aliases_for_room", + def update_aliases_for_room(self, old_room_id, new_room_id, creator): + def _update_aliases_for_room_txn(txn): + sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?" 
+ txn.execute(sql, (new_room_id, creator, old_room_id,)) + self._invalidate_cache_and_stream( + txn, self.get_aliases_for_room, (old_room_id,) + ) + self._invalidate_cache_and_stream( + txn, self.get_aliases_for_room, (new_room_id,) + ) + return self.runInteraction( + "_update_aliases_for_room_txn", _update_aliases_for_room_txn ) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index b9f1365f92..ff8538ddf8 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -14,8 +14,10 @@ # limitations under the License. from twisted.internet import defer +from synapse.util.caches.descriptors import cached + from canonicaljson import encode_canonical_json -import ujson as json +import simplejson as json from ._base import SQLBaseStore @@ -120,26 +122,77 @@ class EndToEndKeyStore(SQLBaseStore): return result - def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list): + @defer.inlineCallbacks + def get_e2e_one_time_keys(self, user_id, device_id, key_ids): + """Retrieve a number of one-time keys for a user + + Args: + user_id(str): id of user to get keys for + device_id(str): id of device to get keys for + key_ids(list[str]): list of key ids (excluding algorithm) to + retrieve + + Returns: + deferred resolving to Dict[(str, str), str]: map from (algorithm, + key_id) to json string for key + """ + + rows = yield self._simple_select_many_batch( + table="e2e_one_time_keys_json", + column="key_id", + iterable=key_ids, + retcols=("algorithm", "key_id", "key_json",), + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + desc="add_e2e_one_time_keys_check", + ) + + defer.returnValue({ + (row["algorithm"], row["key_id"]): row["key_json"] for row in rows + }) + + @defer.inlineCallbacks + def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys): + """Insert some new one time keys for a device. Errors if any of the + keys already exist. 
+ + Args: + user_id(str): id of user to get keys for + device_id(str): id of device to get keys for + time_now(long): insertion time to record (ms since epoch) + new_keys(iterable[(str, str, str)]: keys to add - each a tuple of + (algorithm, key_id, key json) + """ + def _add_e2e_one_time_keys(txn): - for (algorithm, key_id, json_bytes) in key_list: - self._simple_upsert_txn( - txn, table="e2e_one_time_keys_json", - keyvalues={ + # We are protected from race between lookup and insertion due to + # a unique constraint. If there is a race of two calls to + # `add_e2e_one_time_keys` then they'll conflict and we will only + # insert one set. + self._simple_insert_many_txn( + txn, table="e2e_one_time_keys_json", + values=[ + { "user_id": user_id, "device_id": device_id, "algorithm": algorithm, "key_id": key_id, - }, - values={ "ts_added_ms": time_now, "key_json": json_bytes, } - ) - return self.runInteraction( - "add_e2e_one_time_keys", _add_e2e_one_time_keys + for algorithm, key_id, json_bytes in new_keys + ], + ) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id,) + ) + yield self.runInteraction( + "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys ) + @cached(max_entries=10000) def count_e2e_one_time_keys(self, user_id, device_id): """ Count the number of one time keys the server has for a device Returns: @@ -153,7 +206,7 @@ class EndToEndKeyStore(SQLBaseStore): ) txn.execute(sql, (user_id, device_id)) result = {} - for algorithm, key_count in txn.fetchall(): + for algorithm, key_count in txn: result[algorithm] = key_count return result return self.runInteraction( @@ -174,7 +227,7 @@ class EndToEndKeyStore(SQLBaseStore): user_result = result.setdefault(user_id, {}) device_result = user_result.setdefault(device_id, {}) txn.execute(sql, (user_id, device_id, algorithm)) - for key_id, key_json in txn.fetchall(): + for key_id, key_json in txn: device_result[algorithm + ":" + key_id] = key_json delete.append((user_id, 
device_id, algorithm, key_id)) sql = ( @@ -184,20 +237,29 @@ class EndToEndKeyStore(SQLBaseStore): ) for user_id, device_id, algorithm, key_id in delete: txn.execute(sql, (user_id, device_id, algorithm, key_id)) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id,) + ) return result return self.runInteraction( "claim_e2e_one_time_keys", _claim_e2e_one_time_keys ) - @defer.inlineCallbacks def delete_e2e_keys_by_device(self, user_id, device_id): - yield self._simple_delete( - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - desc="delete_e2e_device_keys_by_device" - ) - yield self._simple_delete( - table="e2e_one_time_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - desc="delete_e2e_one_time_keys_by_device" + def delete_e2e_keys_by_device_txn(txn): + self._simple_delete_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self._simple_delete_txn( + txn, + table="e2e_one_time_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id,) + ) + return self.runInteraction( + "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index 338b495611..8c868ece75 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -18,6 +18,7 @@ from .postgres import PostgresEngine from .sqlite3 import Sqlite3Engine import importlib +import platform SUPPORTED_MODULE = { @@ -31,6 +32,10 @@ def create_engine(database_config): engine_class = SUPPORTED_MODULE.get(name, None) if engine_class: + # pypy requires psycopg2cffi rather than psycopg2 + if (name == "psycopg2" and + platform.python_implementation() == "PyPy"): + name = "psycopg2cffi" module = importlib.import_module(name) return 
engine_class(module, database_config) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index a6ae79dfad..8a0386c1a4 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -62,3 +62,9 @@ class PostgresEngine(object): def lock_table(self, txn, table): txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,)) + + def get_next_state_group_id(self, txn): + """Returns an int that can be used as a new state_group ID + """ + txn.execute("SELECT nextval('state_group_id_seq')") + return txn.fetchone()[0] diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index 755c9a1f07..60f0fa7fb3 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -16,6 +16,7 @@ from synapse.storage.prepare_database import prepare_database import struct +import threading class Sqlite3Engine(object): @@ -24,6 +25,11 @@ class Sqlite3Engine(object): def __init__(self, database_module, database_config): self.module = database_module + # The current max state_group, or None if we haven't looked + # in the DB yet. + self._current_state_group_id = None + self._current_state_group_id_lock = threading.Lock() + def check_database(self, txn): pass @@ -43,6 +49,19 @@ class Sqlite3Engine(object): def lock_table(self, txn, table): return + def get_next_state_group_id(self, txn): + """Returns an int that can be used as a new state_group ID + """ + # We do application locking here since if we're using sqlite then + # we are a single process synapse. 
+ with self._current_state_group_id_lock: + if self._current_state_group_id is None: + txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") + self._current_state_group_id = txn.fetchone()[0] + + self._current_state_group_id += 1 + return self._current_state_group_id + # Following functions taken from: https://github.com/coleifer/peewee diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 256e50dc20..8fbf7ffba7 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -12,50 +12,64 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import random from twisted.internet import defer -from ._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore +from synapse.storage.events import EventsWorkerStore +from synapse.storage.signatures import SignatureWorkerStore + from synapse.api.errors import StoreError from synapse.util.caches.descriptors import cached from unpaddedbase64 import encode_base64 import logging -from Queue import PriorityQueue, Empty +from six.moves.queue import PriorityQueue, Empty + +from six.moves import range logger = logging.getLogger(__name__) -class EventFederationStore(SQLBaseStore): - """ Responsible for storing and serving up the various graphs associated - with an event. Including the main event graph and the auth chains for an - event. +class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, + SQLBaseStore): + def get_auth_chain(self, event_ids, include_given=False): + """Get auth events for given event_ids. The events *must* be state events. - Also has methods for getting the front (latest) and back (oldest) edges - of the event graphs. These are used to generate the parents for new events - and backfilling from another server respectively. 
- """ + Args: + event_ids (list): state events + include_given (bool): include the given events in result - def __init__(self, hs): - super(EventFederationStore, self).__init__(hs) + Returns: + list of events + """ + return self.get_auth_chain_ids( + event_ids, include_given=include_given, + ).addCallback(self._get_events) - hs.get_clock().looping_call( - self._delete_old_forward_extrem_cache, 60 * 60 * 1000 - ) + def get_auth_chain_ids(self, event_ids, include_given=False): + """Get auth events for given event_ids. The events *must* be state events. - def get_auth_chain(self, event_ids): - return self.get_auth_chain_ids(event_ids).addCallback(self._get_events) + Args: + event_ids (list): state events + include_given (bool): include the given events in result - def get_auth_chain_ids(self, event_ids): + Returns: + list of event_ids + """ return self.runInteraction( "get_auth_chain_ids", self._get_auth_chain_ids_txn, - event_ids + event_ids, include_given ) - def _get_auth_chain_ids_txn(self, txn, event_ids): - results = set() + def _get_auth_chain_ids_txn(self, txn, event_ids, include_given): + if include_given: + results = set(event_ids) + else: + results = set() base_sql = ( "SELECT auth_id FROM event_auth WHERE event_id IN (%s)" @@ -67,14 +81,14 @@ class EventFederationStore(SQLBaseStore): front_list = list(front) chunks = [ front_list[x:x + 100] - for x in xrange(0, len(front), 100) + for x in range(0, len(front), 100) ] for chunk in chunks: txn.execute( base_sql % (",".join(["?"] * len(chunk)),), chunk ) - new_front.update([r[0] for r in txn.fetchall()]) + new_front.update([r[0] for r in txn]) new_front -= results @@ -110,7 +124,7 @@ class EventFederationStore(SQLBaseStore): txn.execute(sql, (room_id, False,)) - return dict(txn.fetchall()) + return dict(txn) def _get_oldest_events_in_room_txn(self, txn, room_id): return self._simple_select_onecol_txn( @@ -122,7 +136,47 @@ class EventFederationStore(SQLBaseStore): retcol="event_id", ) + @defer.inlineCallbacks + 
def get_prev_events_for_room(self, room_id): + """ + Gets a subset of the current forward extremities in the given room. + + Limits the result to 10 extremities, so that we can avoid creating + events which refer to hundreds of prev_events. + + Args: + room_id (str): room_id + + Returns: + Deferred[list[(str, dict[str, str], int)]] + for each event, a tuple of (event_id, hashes, depth) + where *hashes* is a map from algorithm to hash. + """ + res = yield self.get_latest_event_ids_and_hashes_in_room(room_id) + if len(res) > 10: + # Sort by reverse depth, so we point to the most recent. + res.sort(key=lambda a: -a[2]) + + # we use half of the limit for the actual most recent events, and + # the other half to randomly point to some of the older events, to + # make sure that we don't completely ignore the older events. + res = res[0:5] + random.sample(res[5:], 5) + + defer.returnValue(res) + def get_latest_event_ids_and_hashes_in_room(self, room_id): + """ + Gets the current forward extremities in the given room + + Args: + room_id (str): room_id + + Returns: + Deferred[list[(str, dict[str, str], int)]] + for each event, a tuple of (event_id, hashes, depth) + where *hashes* is a map from algorithm to hash. 
+ """ + return self.runInteraction( "get_latest_event_ids_and_hashes_in_room", self._get_latest_event_ids_and_hashes_in_room, @@ -171,22 +225,6 @@ class EventFederationStore(SQLBaseStore): room_id, ) - @defer.inlineCallbacks - def get_max_depth_of_events(self, event_ids): - sql = ( - "SELECT MAX(depth) FROM events WHERE event_id IN (%s)" - ) % (",".join(["?"] * len(event_ids)),) - - rows = yield self._execute( - "get_max_depth_of_events", None, - sql, *event_ids - ) - - if rows: - defer.returnValue(rows[0][0]) - else: - defer.returnValue(1) - def _get_min_depth_interaction(self, txn, room_id): min_depth = self._simple_select_one_onecol_txn( txn, @@ -198,88 +236,6 @@ class EventFederationStore(SQLBaseStore): return int(min_depth) if min_depth is not None else None - def _update_min_depth_for_room_txn(self, txn, room_id, depth): - min_depth = self._get_min_depth_interaction(txn, room_id) - - do_insert = depth < min_depth if min_depth else True - - if do_insert: - self._simple_upsert_txn( - txn, - table="room_depth", - keyvalues={ - "room_id": room_id, - }, - values={ - "min_depth": depth, - }, - ) - - def _handle_mult_prev_events(self, txn, events): - """ - For the given event, update the event edges table and forward and - backward extremities tables. - """ - self._simple_insert_many_txn( - txn, - table="event_edges", - values=[ - { - "event_id": ev.event_id, - "prev_event_id": e_id, - "room_id": ev.room_id, - "is_state": False, - } - for ev in events - for e_id, _ in ev.prev_events - ], - ) - - self._update_backward_extremeties(txn, events) - - def _update_backward_extremeties(self, txn, events): - """Updates the event_backward_extremities tables based on the new/updated - events being persisted. - - This is called for new events *and* for events that were outliers, but - are now being persisted as non-outliers. - - Forward extremities are handled when we first start persisting the events. 
- """ - events_by_room = {} - for ev in events: - events_by_room.setdefault(ev.room_id, []).append(ev) - - query = ( - "INSERT INTO event_backward_extremities (event_id, room_id)" - " SELECT ?, ? WHERE NOT EXISTS (" - " SELECT 1 FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" - " )" - " AND NOT EXISTS (" - " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " - " AND outlier = ?" - " )" - ) - - txn.executemany(query, [ - (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) - for ev in events for e_id, _ in ev.prev_events - if not ev.internal_metadata.is_outlier() - ]) - - query = ( - "DELETE FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" - ) - txn.executemany( - query, - [ - (ev.event_id, ev.room_id) for ev in events - if not ev.internal_metadata.is_outlier() - ] - ) - def get_forward_extremeties_for_room(self, room_id, stream_ordering): """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". @@ -334,36 +290,13 @@ class EventFederationStore(SQLBaseStore): def get_forward_extremeties_for_room_txn(txn): txn.execute(sql, (stream_ordering, room_id)) - rows = txn.fetchall() - return [event_id for event_id, in rows] + return [event_id for event_id, in txn] return self.runInteraction( "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) - def _delete_old_forward_extrem_cache(self): - def _delete_old_forward_extrem_cache_txn(txn): - # Delete entries older than a month, while making sure we don't delete - # the only entries for a room. - sql = (""" - DELETE FROM stream_ordering_to_exterm - WHERE - room_id IN ( - SELECT room_id - FROM stream_ordering_to_exterm - WHERE stream_ordering > ? - ) AND stream_ordering < ? 
- """) - txn.execute( - sql, - (self.stream_ordering_month_ago, self.stream_ordering_month_ago,) - ) - return self.runInteraction( - "_delete_old_forward_extrem_cache", - _delete_old_forward_extrem_cache_txn - ) - def get_backfill_events(self, room_id, event_list, limit): """Get a list of Events for a given topic that occurred before (and including) the events in event_list. Return a list of max size `limit` @@ -436,7 +369,7 @@ class EventFederationStore(SQLBaseStore): (room_id, event_id, False, limit - len(event_results)) ) - for row in txn.fetchall(): + for row in txn: if row[1] not in event_results: queue.put((-row[0], row[1])) @@ -482,7 +415,7 @@ class EventFederationStore(SQLBaseStore): (room_id, event_id, False, limit - len(event_results)) ) - for e_id, in txn.fetchall(): + for e_id, in txn: new_front.add(e_id) new_front -= earliest_events @@ -493,6 +426,135 @@ class EventFederationStore(SQLBaseStore): return event_results + +class EventFederationStore(EventFederationWorkerStore): + """ Responsible for storing and serving up the various graphs associated + with an event. Including the main event graph and the auth chains for an + event. + + Also has methods for getting the front (latest) and back (oldest) edges + of the event graphs. These are used to generate the parents for new events + and backfilling from another server respectively. 
+ """ + + EVENT_AUTH_STATE_ONLY = "event_auth_state_only" + + def __init__(self, db_conn, hs): + super(EventFederationStore, self).__init__(db_conn, hs) + + self.register_background_update_handler( + self.EVENT_AUTH_STATE_ONLY, + self._background_delete_non_state_event_auth, + ) + + hs.get_clock().looping_call( + self._delete_old_forward_extrem_cache, 60 * 60 * 1000 + ) + + def _update_min_depth_for_room_txn(self, txn, room_id, depth): + min_depth = self._get_min_depth_interaction(txn, room_id) + + if min_depth and depth >= min_depth: + return + + self._simple_upsert_txn( + txn, + table="room_depth", + keyvalues={ + "room_id": room_id, + }, + values={ + "min_depth": depth, + }, + ) + + def _handle_mult_prev_events(self, txn, events): + """ + For the given event, update the event edges table and forward and + backward extremities tables. + """ + self._simple_insert_many_txn( + txn, + table="event_edges", + values=[ + { + "event_id": ev.event_id, + "prev_event_id": e_id, + "room_id": ev.room_id, + "is_state": False, + } + for ev in events + for e_id, _ in ev.prev_events + ], + ) + + self._update_backward_extremeties(txn, events) + + def _update_backward_extremeties(self, txn, events): + """Updates the event_backward_extremities tables based on the new/updated + events being persisted. + + This is called for new events *and* for events that were outliers, but + are now being persisted as non-outliers. + + Forward extremities are handled when we first start persisting the events. + """ + events_by_room = {} + for ev in events: + events_by_room.setdefault(ev.room_id, []).append(ev) + + query = ( + "INSERT INTO event_backward_extremities (event_id, room_id)" + " SELECT ?, ? WHERE NOT EXISTS (" + " SELECT 1 FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" + " )" + " AND NOT EXISTS (" + " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " + " AND outlier = ?" 
+ " )" + ) + + txn.executemany(query, [ + (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) + for ev in events for e_id, _ in ev.prev_events + if not ev.internal_metadata.is_outlier() + ]) + + query = ( + "DELETE FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" + ) + txn.executemany( + query, + [ + (ev.event_id, ev.room_id) for ev in events + if not ev.internal_metadata.is_outlier() + ] + ) + + def _delete_old_forward_extrem_cache(self): + def _delete_old_forward_extrem_cache_txn(txn): + # Delete entries older than a month, while making sure we don't delete + # the only entries for a room. + sql = (""" + DELETE FROM stream_ordering_to_exterm + WHERE + room_id IN ( + SELECT room_id + FROM stream_ordering_to_exterm + WHERE stream_ordering > ? + ) AND stream_ordering < ? + """) + txn.execute( + sql, + (self.stream_ordering_month_ago, self.stream_ordering_month_ago,) + ) + return self.runInteraction( + "_delete_old_forward_extrem_cache", + _delete_old_forward_extrem_cache_txn + ) + def clean_room_for_join(self, room_id): return self.runInteraction( "clean_room_for_join", @@ -505,3 +567,52 @@ class EventFederationStore(SQLBaseStore): txn.execute(query, (room_id,)) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) + + @defer.inlineCallbacks + def _background_delete_non_state_event_auth(self, progress, batch_size): + def delete_event_auth(txn): + target_min_stream_id = progress.get("target_min_stream_id_inclusive") + max_stream_id = progress.get("max_stream_id_exclusive") + + if not target_min_stream_id or not max_stream_id: + txn.execute("SELECT COALESCE(MIN(stream_ordering), 0) FROM events") + rows = txn.fetchall() + target_min_stream_id = rows[0][0] + + txn.execute("SELECT COALESCE(MAX(stream_ordering), 0) FROM events") + rows = txn.fetchall() + max_stream_id = rows[0][0] + + min_stream_id = max_stream_id - batch_size + + sql = """ + DELETE FROM event_auth + WHERE event_id IN ( + SELECT event_id FROM events + 
LEFT JOIN state_events USING (room_id, event_id) + WHERE ? <= stream_ordering AND stream_ordering < ? + AND state_key IS null + ) + """ + + txn.execute(sql, (min_stream_id, max_stream_id,)) + + new_progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + } + + self._background_update_progress_txn( + txn, self.EVENT_AUTH_STATE_ONLY, new_progress + ) + + return min_stream_id >= target_min_stream_id + + result = yield self.runInteraction( + self.EVENT_AUTH_STATE_ONLY, delete_event_auth + ) + + if not result: + yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY) + + defer.returnValue(batch_size) diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index 522d0114cb..c22762eb5c 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,131 +14,159 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from ._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, LoggingTransaction from twisted.internet import defer +from synapse.util.async import sleep from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.types import RoomStreamToken from .stream import lower_bound import logging -import ujson as json +import simplejson as json logger = logging.getLogger(__name__) -class EventPushActionsStore(SQLBaseStore): - EPA_HIGHLIGHT_INDEX = "epa_highlight_index" +DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}] +DEFAULT_HIGHLIGHT_ACTION = [ + "notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"} +] - def __init__(self, hs): - self.stream_ordering_month_ago = None - super(EventPushActionsStore, self).__init__(hs) - self.register_background_index_update( - self.EPA_HIGHLIGHT_INDEX, - index_name="event_push_actions_u_highlight", - table="event_push_actions", - columns=["user_id", "stream_ordering"], - ) +def _serialize_action(actions, is_highlight): + """Custom serializer for actions. This allows us to "compress" common actions. - self.register_background_index_update( - "event_push_actions_highlights_index", - index_name="event_push_actions_highlights_index", - table="event_push_actions", - columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], - where_clause="highlight=1" - ) + We use the fact that most users have the same actions for notifs (and for + highlights). + We store these default actions as the empty string rather than the full JSON. 
+ Since the empty string isn't valid JSON there is no risk of this clashing with + any real JSON actions + """ + if is_highlight: + if actions == DEFAULT_HIGHLIGHT_ACTION: + return "" # We use empty string as the column is non-NULL + else: + if actions == DEFAULT_NOTIF_ACTION: + return "" + return json.dumps(actions) - def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): - """ - Args: - event: the event set actions for - tuples: list of tuples of (user_id, actions) - """ - values = [] - for uid, actions in tuples: - values.append({ - 'room_id': event.room_id, - 'event_id': event.event_id, - 'user_id': uid, - 'actions': json.dumps(actions), - 'stream_ordering': event.internal_metadata.stream_ordering, - 'topological_ordering': event.depth, - 'notif': 1, - 'highlight': 1 if _action_has_highlight(actions) else 0, - }) - - for uid, __ in tuples: - txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (event.room_id, uid) - ) - self._simple_insert_many_txn(txn, "event_push_actions", values) - @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000) - def get_unread_event_push_actions_by_room_for_user( - self, room_id, user_id, last_read_event_id - ): - def _get_unread_event_push_actions_by_room(txn): - sql = ( - "SELECT stream_ordering, topological_ordering" - " FROM events" - " WHERE room_id = ? AND event_id = ?" - ) - txn.execute( - sql, (room_id, last_read_event_id) - ) - results = txn.fetchall() - if len(results) == 0: - return {"notify_count": 0, "highlight_count": 0} - - stream_ordering = results[0][0] - topological_ordering = results[0][1] - token = RoomStreamToken( - topological_ordering, stream_ordering - ) +def _deserialize_action(actions, is_highlight): + """Custom deserializer for actions. This allows us to "compress" common actions + """ + if actions: + return json.loads(actions) - # First get number of notifications. 
- # We don't need to put a notif=1 clause as all rows always have - # notif=1 - sql = ( - "SELECT count(*)" - " FROM event_push_actions ea" - " WHERE" - " user_id = ?" - " AND room_id = ?" - " AND %s" - ) % (lower_bound(token, self.database_engine, inclusive=False),) + if is_highlight: + return DEFAULT_HIGHLIGHT_ACTION + else: + return DEFAULT_NOTIF_ACTION - txn.execute(sql, (user_id, room_id)) - row = txn.fetchone() - notify_count = row[0] if row else 0 - # Now get the number of highlights - sql = ( - "SELECT count(*)" - " FROM event_push_actions ea" - " WHERE" - " highlight = 1" - " AND user_id = ?" - " AND room_id = ?" - " AND %s" - ) % (lower_bound(token, self.database_engine, inclusive=False),) +class EventPushActionsWorkerStore(SQLBaseStore): + def __init__(self, db_conn, hs): + super(EventPushActionsWorkerStore, self).__init__(db_conn, hs) - txn.execute(sql, (user_id, room_id)) - row = txn.fetchone() - highlight_count = row[0] if row else 0 + # These get correctly set by _find_stream_orderings_for_times_txn + self.stream_ordering_month_ago = None + self.stream_ordering_day_ago = None + + cur = LoggingTransaction( + db_conn.cursor(), + name="_find_stream_orderings_for_times_txn", + database_engine=self.database_engine, + after_callbacks=[], + exception_callbacks=[], + ) + self._find_stream_orderings_for_times_txn(cur) + cur.close() - return { - "notify_count": notify_count, - "highlight_count": highlight_count, - } + self.find_stream_orderings_looping_call = self._clock.looping_call( + self._find_stream_orderings_for_times, 10 * 60 * 1000 + ) + @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000) + def get_unread_event_push_actions_by_room_for_user( + self, room_id, user_id, last_read_event_id + ): ret = yield self.runInteraction( "get_unread_event_push_actions_by_room", - _get_unread_event_push_actions_by_room + self._get_unread_counts_by_receipt_txn, + room_id, user_id, last_read_event_id ) defer.returnValue(ret) + def 
_get_unread_counts_by_receipt_txn(self, txn, room_id, user_id, + last_read_event_id): + sql = ( + "SELECT stream_ordering, topological_ordering" + " FROM events" + " WHERE room_id = ? AND event_id = ?" + ) + txn.execute( + sql, (room_id, last_read_event_id) + ) + results = txn.fetchall() + if len(results) == 0: + return {"notify_count": 0, "highlight_count": 0} + + stream_ordering = results[0][0] + topological_ordering = results[0][1] + + return self._get_unread_counts_by_pos_txn( + txn, room_id, user_id, topological_ordering, stream_ordering + ) + + def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, topological_ordering, + stream_ordering): + token = RoomStreamToken( + topological_ordering, stream_ordering + ) + + # First get number of notifications. + # We don't need to put a notif=1 clause as all rows always have + # notif=1 + sql = ( + "SELECT count(*)" + " FROM event_push_actions ea" + " WHERE" + " user_id = ?" + " AND room_id = ?" + " AND %s" + ) % (lower_bound(token, self.database_engine, inclusive=False),) + + txn.execute(sql, (user_id, room_id)) + row = txn.fetchone() + notify_count = row[0] if row else 0 + + txn.execute(""" + SELECT notif_count FROM event_push_summary + WHERE room_id = ? AND user_id = ? AND stream_ordering > ? + """, (room_id, user_id, stream_ordering,)) + rows = txn.fetchall() + if rows: + notify_count += rows[0][0] + + # Now get the number of highlights + sql = ( + "SELECT count(*)" + " FROM event_push_actions ea" + " WHERE" + " highlight = 1" + " AND user_id = ?" + " AND room_id = ?" 
+ " AND %s" + ) % (lower_bound(token, self.database_engine, inclusive=False),) + + txn.execute(sql, (user_id, room_id)) + row = txn.fetchone() + highlight_count = row[0] if row else 0 + + return { + "notify_count": notify_count, + "highlight_count": highlight_count, + } + @defer.inlineCallbacks def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering): def f(txn): @@ -146,7 +175,7 @@ class EventPushActionsStore(SQLBaseStore): " stream_ordering >= ? AND stream_ordering <= ?" ) txn.execute(sql, (min_stream_ordering, max_stream_ordering)) - return [r[0] for r in txn.fetchall()] + return [r[0] for r in txn] ret = yield self.runInteraction("get_push_action_users_in_range", f) defer.returnValue(ret) @@ -176,7 +205,8 @@ class EventPushActionsStore(SQLBaseStore): # find rooms that have a read receipt in them and return the next # push actions sql = ( - "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions" + "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," + " ep.highlight " " FROM (" " SELECT room_id," " MAX(topological_ordering) as topological_ordering," @@ -217,7 +247,7 @@ class EventPushActionsStore(SQLBaseStore): def get_no_receipt(txn): sql = ( "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " e.received_ts" + " ep.highlight " " FROM event_push_actions AS ep" " INNER JOIN events AS e USING (room_id, event_id)" " WHERE" @@ -246,7 +276,7 @@ class EventPushActionsStore(SQLBaseStore): "event_id": row[0], "room_id": row[1], "stream_ordering": row[2], - "actions": json.loads(row[3]), + "actions": _deserialize_action(row[3], row[4]), } for row in after_read_receipt + no_read_receipt ] @@ -285,7 +315,7 @@ class EventPushActionsStore(SQLBaseStore): def get_after_receipt(txn): sql = ( "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " e.received_ts" + " ep.highlight, e.received_ts" " FROM (" " SELECT room_id," " MAX(topological_ordering) as topological_ordering," @@ -327,7 +357,7 @@ 
class EventPushActionsStore(SQLBaseStore): def get_no_receipt(txn): sql = ( "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " e.received_ts" + " ep.highlight, e.received_ts" " FROM event_push_actions AS ep" " INNER JOIN events AS e USING (room_id, event_id)" " WHERE" @@ -357,8 +387,8 @@ class EventPushActionsStore(SQLBaseStore): "event_id": row[0], "room_id": row[1], "stream_ordering": row[2], - "actions": json.loads(row[3]), - "received_ts": row[4], + "actions": _deserialize_action(row[3], row[4]), + "received_ts": row[5], } for row in after_read_receipt + no_read_receipt ] @@ -371,6 +401,290 @@ class EventPushActionsStore(SQLBaseStore): # Now return the first `limit` defer.returnValue(notifs[:limit]) + def add_push_actions_to_staging(self, event_id, user_id_actions): + """Add the push actions for the event to the push action staging area. + + Args: + event_id (str) + user_id_actions (dict[str, list[dict|str])]): A dictionary mapping + user_id to list of push actions, where an action can either be + a string or dict. + + Returns: + Deferred + """ + + if not user_id_actions: + return + + # This is a helper function for generating the necessary tuple that + # can be used to inert into the `event_push_actions_staging` table. + def _gen_entry(user_id, actions): + is_highlight = 1 if _action_has_highlight(actions) else 0 + return ( + event_id, # event_id column + user_id, # user_id column + _serialize_action(actions, is_highlight), # actions column + 1, # notif column + is_highlight, # highlight column + ) + + def _add_push_actions_to_staging_txn(txn): + # We don't use _simple_insert_many here to avoid the overhead + # of generating lists of dicts. + + sql = """ + INSERT INTO event_push_actions_staging + (event_id, user_id, actions, notif, highlight) + VALUES (?, ?, ?, ?, ?) 
+ """ + + txn.executemany(sql, ( + _gen_entry(user_id, actions) + for user_id, actions in user_id_actions.iteritems() + )) + + return self.runInteraction( + "add_push_actions_to_staging", _add_push_actions_to_staging_txn + ) + + @defer.inlineCallbacks + def remove_push_actions_from_staging(self, event_id): + """Called if we failed to persist the event to ensure that stale push + actions don't build up in the DB + + Args: + event_id (str) + """ + + try: + res = yield self._simple_delete( + table="event_push_actions_staging", + keyvalues={ + "event_id": event_id, + }, + desc="remove_push_actions_from_staging", + ) + defer.returnValue(res) + except Exception: + # this method is called from an exception handler, so propagating + # another exception here really isn't helpful - there's nothing + # the caller can do about it. Just log the exception and move on. + logger.exception( + "Error removing push actions after event persistence failure", + ) + + @defer.inlineCallbacks + def _find_stream_orderings_for_times(self): + yield self.runInteraction( + "_find_stream_orderings_for_times", + self._find_stream_orderings_for_times_txn + ) + + def _find_stream_orderings_for_times_txn(self, txn): + logger.info("Searching for stream ordering 1 month ago") + self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn( + txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000 + ) + logger.info( + "Found stream ordering 1 month ago: it's %d", + self.stream_ordering_month_ago + ) + logger.info("Searching for stream ordering 1 day ago") + self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn( + txn, self._clock.time_msec() - 24 * 60 * 60 * 1000 + ) + logger.info( + "Found stream ordering 1 day ago: it's %d", + self.stream_ordering_day_ago + ) + + def find_first_stream_ordering_after_ts(self, ts): + """Gets the stream ordering corresponding to a given timestamp. 
+ + Specifically, finds the stream_ordering of the first event that was + received on or after the timestamp. This is done by a binary search on + the events table, since there is no index on received_ts, so is + relatively slow. + + Args: + ts (int): timestamp in millis + + Returns: + Deferred[int]: stream ordering of the first event received on/after + the timestamp + """ + return self.runInteraction( + "_find_first_stream_ordering_after_ts_txn", + self._find_first_stream_ordering_after_ts_txn, + ts, + ) + + @staticmethod + def _find_first_stream_ordering_after_ts_txn(txn, ts): + """ + Find the stream_ordering of the first event that was received on or + after a given timestamp. This is relatively slow as there is no index + on received_ts but we can then use this to delete push actions before + this. + + received_ts must necessarily be in the same order as stream_ordering + and stream_ordering is indexed, so we manually binary search using + stream_ordering + + Args: + txn (twisted.enterprise.adbapi.Transaction): + ts (int): timestamp to search for + + Returns: + int: stream ordering + """ + txn.execute("SELECT MAX(stream_ordering) FROM events") + max_stream_ordering = txn.fetchone()[0] + + if max_stream_ordering is None: + return 0 + + # We want the first stream_ordering in which received_ts is greater + # than or equal to ts. Call this point X. + # + # We maintain the invariants: + # + # range_start <= X <= range_end + # + range_start = 0 + range_end = max_stream_ordering + 1 + + # Given a stream_ordering, look up the timestamp at that + # stream_ordering. + # + # The array may be sparse (we may be missing some stream_orderings). + # We treat the gaps as the same as having the same value as the + # preceding entry, because we will pick the lowest stream_ordering + # which satisfies our requirement of received_ts >= ts. 
+ # + # For example, if our array of events indexed by stream_ordering is + # [10, <none>, 20], we should treat this as being equivalent to + # [10, 10, 20]. + # + sql = ( + "SELECT received_ts FROM events" + " WHERE stream_ordering <= ?" + " ORDER BY stream_ordering DESC" + " LIMIT 1" + ) + + while range_end - range_start > 0: + middle = (range_end + range_start) // 2 + txn.execute(sql, (middle,)) + row = txn.fetchone() + if row is None: + # no rows with stream_ordering<=middle + range_start = middle + 1 + continue + + middle_ts = row[0] + if ts > middle_ts: + # we got a timestamp lower than the one we were looking for. + # definitely need to look higher: X > middle. + range_start = middle + 1 + else: + # we got a timestamp higher than (or the same as) the one we + # were looking for. We aren't yet sure about the point we + # looked up, but we can be sure that X <= middle. + range_end = middle + + return range_end + + +class EventPushActionsStore(EventPushActionsWorkerStore): + EPA_HIGHLIGHT_INDEX = "epa_highlight_index" + + def __init__(self, db_conn, hs): + super(EventPushActionsStore, self).__init__(db_conn, hs) + + self.register_background_index_update( + self.EPA_HIGHLIGHT_INDEX, + index_name="event_push_actions_u_highlight", + table="event_push_actions", + columns=["user_id", "stream_ordering"], + ) + + self.register_background_index_update( + "event_push_actions_highlights_index", + index_name="event_push_actions_highlights_index", + table="event_push_actions", + columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], + where_clause="highlight=1" + ) + + self._doing_notif_rotation = False + self._rotate_notif_loop = self._clock.looping_call( + self._rotate_notifs, 30 * 60 * 1000 + ) + + def _set_push_actions_for_event_and_users_txn(self, txn, events_and_contexts, + all_events_and_contexts): + """Handles moving push actions from staging table to main + event_push_actions table for all events in `events_and_contexts`. 
+ + Also ensures that all events in `all_events_and_contexts` are removed + from the push action staging area. + + Args: + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + all_events_and_contexts (list[(EventBase, EventContext)]): all + events that we were going to persist. This includes events + we've already persisted, etc, that wouldn't appear in + events_and_context. + """ + + sql = """ + INSERT INTO event_push_actions ( + room_id, event_id, user_id, actions, stream_ordering, + topological_ordering, notif, highlight + ) + SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight + FROM event_push_actions_staging + WHERE event_id = ? + """ + + if events_and_contexts: + txn.executemany(sql, ( + ( + event.room_id, event.internal_metadata.stream_ordering, + event.depth, event.event_id, + ) + for event, _ in events_and_contexts + )) + + for event, _ in events_and_contexts: + user_ids = self._simple_select_onecol_txn( + txn, + table="event_push_actions_staging", + keyvalues={ + "event_id": event.event_id, + }, + retcol="user_id", + ) + + for uid in user_ids: + txn.call_after( + self.get_unread_event_push_actions_by_room_for_user.invalidate_many, + (event.room_id, uid,) + ) + + # Now we delete the staging area for *all* events that were being + # persisted. + txn.executemany( + "DELETE FROM event_push_actions_staging WHERE event_id = ?", + ( + (event.event_id,) + for event, _ in all_events_and_contexts + ) + ) + @defer.inlineCallbacks def get_push_actions_for_user(self, user_id, before=None, limit=50, only_highlight=False): @@ -392,7 +706,7 @@ class EventPushActionsStore(SQLBaseStore): sql = ( "SELECT epa.event_id, epa.room_id," " epa.stream_ordering, epa.topological_ordering," - " epa.actions, epa.profile_tag, e.received_ts" + " epa.actions, epa.highlight, epa.profile_tag, e.received_ts" " FROM event_push_actions epa, events e" " WHERE epa.event_id = e.event_id" " AND epa.user_id = ? 
%s" @@ -407,7 +721,7 @@ class EventPushActionsStore(SQLBaseStore): "get_push_actions_for_user", f ) for pa in push_actions: - pa["actions"] = json.loads(pa["actions"]) + pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) defer.returnValue(push_actions) @defer.inlineCallbacks @@ -448,7 +762,7 @@ class EventPushActionsStore(SQLBaseStore): ) def _remove_old_push_actions_before_txn(self, txn, room_id, user_id, - topological_ordering): + topological_ordering, stream_ordering): """ Purges old push actions for a user and room before a given topological_ordering. @@ -479,65 +793,140 @@ class EventPushActionsStore(SQLBaseStore): txn.execute( "DELETE FROM event_push_actions " " WHERE user_id = ? AND room_id = ? AND " - " topological_ordering < ?" + " topological_ordering <= ?" " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)", (user_id, room_id, topological_ordering, self.stream_ordering_month_ago) ) + txn.execute(""" + DELETE FROM event_push_summary + WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? 
+ """, (room_id, user_id, stream_ordering)) + @defer.inlineCallbacks - def _find_stream_orderings_for_times(self): - yield self.runInteraction( - "_find_stream_orderings_for_times", - self._find_stream_orderings_for_times_txn - ) + def _rotate_notifs(self): + if self._doing_notif_rotation or self.stream_ordering_day_ago is None: + return + self._doing_notif_rotation = True - def _find_stream_orderings_for_times_txn(self, txn): - logger.info("Searching for stream ordering 1 month ago") - self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn( - txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000 - ) - logger.info( - "Found stream ordering 1 month ago: it's %d", - self.stream_ordering_month_ago + try: + while True: + logger.info("Rotating notifications") + + caught_up = yield self.runInteraction( + "_rotate_notifs", + self._rotate_notifs_txn + ) + if caught_up: + break + yield sleep(5) + finally: + self._doing_notif_rotation = False + + def _rotate_notifs_txn(self, txn): + """Archives older notifications into event_push_summary. Returns whether + the archiving process has caught up or not. + """ + + old_rotate_stream_ordering = self._simple_select_one_onecol_txn( + txn, + table="event_push_summary_stream_ordering", + keyvalues={}, + retcol="stream_ordering", ) - def _find_first_stream_ordering_after_ts_txn(self, txn, ts): - """ - Find the stream_ordering of the first event that was received after - a given timestamp. This is relatively slow as there is no index on - received_ts but we can then use this to delete push actions before - this. + # We don't to try and rotate millions of rows at once, so we cap the + # maximum stream ordering we'll rotate before. + txn.execute(""" + SELECT stream_ordering FROM event_push_actions + WHERE stream_ordering > ? 
+ ORDER BY stream_ordering ASC LIMIT 1 OFFSET 50000 + """, (old_rotate_stream_ordering,)) + stream_row = txn.fetchone() + if stream_row: + offset_stream_ordering, = stream_row + rotate_to_stream_ordering = min( + self.stream_ordering_day_ago, offset_stream_ordering + ) + caught_up = offset_stream_ordering >= self.stream_ordering_day_ago + else: + rotate_to_stream_ordering = self.stream_ordering_day_ago + caught_up = True - received_ts must necessarily be in the same order as stream_ordering - and stream_ordering is indexed, so we manually binary search using - stream_ordering + logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering) + + self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering) + + # We have caught up iff we were limited by `stream_ordering_day_ago` + return caught_up + + def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): + old_rotate_stream_ordering = self._simple_select_one_onecol_txn( + txn, + table="event_push_summary_stream_ordering", + keyvalues={}, + retcol="stream_ordering", + ) + + # Calculate the new counts that should be upserted into event_push_summary + sql = """ + SELECT user_id, room_id, + coalesce(old.notif_count, 0) + upd.notif_count, + upd.stream_ordering, + old.user_id + FROM ( + SELECT user_id, room_id, count(*) as notif_count, + max(stream_ordering) as stream_ordering + FROM event_push_actions + WHERE ? <= stream_ordering AND stream_ordering < ? 
+ AND highlight = 0 + GROUP BY user_id, room_id + ) AS upd + LEFT JOIN event_push_summary AS old USING (user_id, room_id) """ - txn.execute("SELECT MAX(stream_ordering) FROM events") - max_stream_ordering = txn.fetchone()[0] - if max_stream_ordering is None: - return 0 + txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering,)) + rows = txn.fetchall() + + logger.info("Rotating notifications, handling %d rows", len(rows)) + + # If the `old.user_id` above is NULL then we know there isn't already an + # entry in the table, so we simply insert it. Otherwise we update the + # existing table. + self._simple_insert_many_txn( + txn, + table="event_push_summary", + values=[ + { + "user_id": row[0], + "room_id": row[1], + "notif_count": row[2], + "stream_ordering": row[3], + } + for row in rows if row[4] is None + ] + ) - range_start = 0 - range_end = max_stream_ordering + txn.executemany( + """ + UPDATE event_push_summary SET notif_count = ?, stream_ordering = ? + WHERE user_id = ? AND room_id = ? + """, + ((row[2], row[3], row[0], row[1],) for row in rows if row[4] is not None) + ) - sql = ( - "SELECT received_ts FROM events" - " WHERE stream_ordering > ?" - " ORDER BY stream_ordering" - " LIMIT 1" + txn.execute( + "DELETE FROM event_push_actions" + " WHERE ? <= stream_ordering AND stream_ordering < ? 
AND highlight = 0", + (old_rotate_stream_ordering, rotate_to_stream_ordering,) ) - while range_end - range_start > 1: - middle = int((range_end + range_start) / 2) - txn.execute(sql, (middle,)) - middle_ts = txn.fetchone()[0] - if ts > middle_ts: - range_start = middle - else: - range_end = middle + logger.info("Rotating notifications, deleted %s push actions", txn.rowcount) - return range_end + txn.execute( + "UPDATE event_push_summary_stream_ordering SET stream_ordering = ?", + (rotate_to_stream_ordering,) + ) def _action_has_highlight(actions): diff --git a/synapse/storage/events.py b/synapse/storage/events.py index c88f689d3a..05cde96afc 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,59 +13,61 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from ._base import SQLBaseStore -from twisted.internet import defer, reactor +from collections import OrderedDict, deque, namedtuple +from functools import wraps +import itertools +import logging -from synapse.events import FrozenEvent, USE_FROZEN_DICTS -from synapse.events.utils import prune_event +import simplejson as json +from twisted.internet import defer +from synapse.storage.events_worker import EventsWorkerStore from synapse.util.async import ObservableDeferred +from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.logcontext import ( - preserve_fn, PreserveLoggingContext, preserve_context_over_deferred + PreserveLoggingContext, make_deferred_yieldable, ) from synapse.util.logutils import log_function from synapse.util.metrics import Measure from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError -from synapse.state import resolve_events -from synapse.util.caches.descriptors import cached - -from canonicaljson import encode_canonical_json -from collections import deque, namedtuple, OrderedDict -from functools import wraps - -import synapse +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.types import get_domain_from_id import synapse.metrics - -import logging -import math -import ujson as json +# these are only included to make the type annotations work +from synapse.events import EventBase # noqa: F401 +from synapse.events.snapshot import EventContext # noqa: F401 logger = logging.getLogger(__name__) metrics = synapse.metrics.get_metrics_for(__name__) persist_event_counter = metrics.register_counter("persisted_events") +event_counter = metrics.register_counter( + "persisted_events_sep", labels=["type", "origin_type", "origin_entity"] +) - -def encode_json(json_object): - if USE_FROZEN_DICTS: - # ujson doesn't like frozen_dicts - return encode_canonical_json(json_object) - else: - return json.dumps(json_object, ensure_ascii=False) +# The number of times we are 
recalculating the current state +state_delta_counter = metrics.register_counter( + "state_delta", +) +# The number of times we are recalculating state when there is only a +# single forward extremity +state_delta_single_event_counter = metrics.register_counter( + "state_delta_single_event", +) +# The number of times we are reculating state when we could have resonably +# calculated the delta when we calculated the state for an event we were +# persisting. +state_delta_reuse_delta_counter = metrics.register_counter( + "state_delta_reuse_delta", +) -# These values are used in the `enqueus_event` and `_do_fetch` methods to -# control how we batch/bulk fetch events from the database. -# The values are plucked out of thing air to make initial sync run faster -# on jki.re -# TODO: Make these configurable. -EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events -EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events -EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events +def encode_json(json_object): + return frozendict_json_encoder.encode(json_object) class _EventPeristenceQueue(object): @@ -82,15 +85,30 @@ class _EventPeristenceQueue(object): def add_to_queue(self, room_id, events_and_contexts, backfilled): """Add events to the queue, with the given persist_event options. + + NB: due to the normal usage pattern of this method, it does *not* + follow the synapse logcontext rules, and leaves the logcontext in + place whether or not the returned deferred is ready. + + Args: + room_id (str): + events_and_contexts (list[(EventBase, EventContext)]): + backfilled (bool): + + Returns: + defer.Deferred: a deferred which will resolve once the events are + persisted. Runs its callbacks *without* a logcontext. """ queue = self._event_persist_queues.setdefault(room_id, deque()) if queue: + # if the last item in the queue has the same `backfilled` setting, + # we can just add these new events to that item. 
end_item = queue[-1] if end_item.backfilled == backfilled: end_item.events_and_contexts.extend(events_and_contexts) return end_item.deferred.observe() - deferred = ObservableDeferred(defer.Deferred()) + deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) queue.append(self._EventPersistQueueItem( events_and_contexts=events_and_contexts, @@ -103,11 +121,11 @@ class _EventPeristenceQueue(object): def handle_queue(self, room_id, per_item_callback): """Attempts to handle the queue for a room if not already being handled. - The given callback will be invoked with for each item in the queue,1 + The given callback will be invoked with for each item in the queue, of type _EventPersistQueueItem. The per_item_callback will continuously be called with new items, unless the queue becomnes empty. The return value of the function will be given to the deferreds waiting on the item, - exceptions will be passed to the deferres as well. + exceptions will be passed to the deferreds as well. This function should therefore be called whenever anything is added to the queue. @@ -126,18 +144,25 @@ class _EventPeristenceQueue(object): try: queue = self._get_drainining_queue(room_id) for item in queue: + # handle_queue_loop runs in the sentinel logcontext, so + # there is no need to preserve_fn when running the + # callbacks on the deferred. try: ret = yield per_item_callback(item) item.deferred.callback(ret) - except Exception as e: - item.deferred.errback(e) + except Exception: + item.deferred.errback() finally: queue = self._event_persist_queues.pop(room_id, None) if queue: self._event_persist_queues[room_id] = queue self._currently_persisting_rooms.discard(room_id) - preserve_fn(handle_queue_loop)() + # set handle_queue_loop off on the background. We don't want to + # attribute work done in it to the current request, so we drop the + # logcontext altogether. 
+ with PreserveLoggingContext(): + handle_queue_loop() def _get_drainining_queue(self, room_id): queue = self._event_persist_queues.setdefault(room_id, deque()) @@ -173,13 +198,12 @@ def _retry_on_integrity_error(func): return f -class EventsStore(SQLBaseStore): +class EventsStore(EventsWorkerStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" - def __init__(self, hs): - super(EventsStore, self).__init__(hs) - self._clock = hs.get_clock() + def __init__(self, db_conn, hs): + super(EventsStore, self).__init__(db_conn, hs) self.register_background_update_handler( self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts ) @@ -196,8 +220,22 @@ class EventsStore(SQLBaseStore): where_clause="contains_url = true AND outlier = false", ) + # an event_id index on event_search is useful for the purge_history + # api. Plus it means we get to enforce some integrity with a UNIQUE + # clause + self.register_background_index_update( + "event_search_event_id_idx", + index_name="event_search_event_id_idx", + table="event_search", + columns=["event_id"], + unique=True, + psql_only=True, + ) + self._event_persist_queue = _EventPeristenceQueue() + self._state_resolution_handler = hs.get_state_resolution_handler() + def persist_events(self, events_and_contexts, backfilled=False): """ Write events to the database @@ -210,23 +248,34 @@ class EventsStore(SQLBaseStore): partitioned.setdefault(event.room_id, []).append((event, ctx)) deferreds = [] - for room_id, evs_ctxs in partitioned.items(): - d = preserve_fn(self._event_persist_queue.add_to_queue)( + for room_id, evs_ctxs in partitioned.iteritems(): + d = self._event_persist_queue.add_to_queue( room_id, evs_ctxs, backfilled=backfilled, ) deferreds.append(d) - for room_id in partitioned.keys(): + for room_id in partitioned: self._maybe_start_persisting(room_id) - return preserve_context_over_deferred( + return make_deferred_yieldable( 
defer.gatherResults(deferreds, consumeErrors=True) ) @defer.inlineCallbacks @log_function def persist_event(self, event, context, backfilled=False): + """ + + Args: + event (EventBase): + context (EventContext): + backfilled (bool): + + Returns: + Deferred: resolves to (int, int): the stream ordering of ``event``, + and the stream ordering of the latest persisted event + """ deferred = self._event_persist_queue.add_to_queue( event.room_id, [(event, context)], backfilled=backfilled, @@ -234,7 +283,7 @@ class EventsStore(SQLBaseStore): self._maybe_start_persisting(event.room_id) - yield preserve_context_over_deferred(deferred) + yield make_deferred_yieldable(deferred) max_persisted_id = yield self._stream_id_gen.get_current_token() defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id)) @@ -242,10 +291,11 @@ class EventsStore(SQLBaseStore): def _maybe_start_persisting(self, room_id): @defer.inlineCallbacks def persisting_queue(item): - yield self._persist_events( - item.events_and_contexts, - backfilled=item.backfilled, - ) + with Measure(self._clock, "persist_events"): + yield self._persist_events( + item.events_and_contexts, + backfilled=item.backfilled, + ) self._event_persist_queue.handle_queue(room_id, persisting_queue) @@ -253,6 +303,16 @@ class EventsStore(SQLBaseStore): @defer.inlineCallbacks def _persist_events(self, events_and_contexts, backfilled=False, delete_existing=False): + """Persist events to db + + Args: + events_and_contexts (list[(EventBase, EventContext)]): + backfilled (bool): + delete_existing (bool): + + Returns: + Deferred: resolves when the events have been persisted + """ if not events_and_contexts: return @@ -282,8 +342,20 @@ class EventsStore(SQLBaseStore): # NB: Assumes that we are only persisting events for one room # at a time. 
+ + # map room_id->list[event_ids] giving the new forward + # extremities in each room new_forward_extremeties = {} + + # map room_id->(type,state_key)->event_id tracking the full + # state in each room after adding these events current_state_for_room = {} + + # map room_id->(to_delete, to_insert) where each entry is + # a map (type,key)->event_id giving the state delta in each + # room + state_delta_for_room = {} + if not backfilled: with Measure(self._clock, "_calculate_state_and_extrem"): # Work out the new "current state" for each room. @@ -295,7 +367,7 @@ class EventsStore(SQLBaseStore): (event, context) ) - for room_id, ev_ctx_rm in events_by_room.items(): + for room_id, ev_ctx_rm in events_by_room.iteritems(): # Work out new extremities by recursively adding and removing # the new events. latest_event_ids = yield self.get_latest_event_ids_in_room( @@ -305,17 +377,64 @@ class EventsStore(SQLBaseStore): room_id, ev_ctx_rm, latest_event_ids ) - if new_latest_event_ids == set(latest_event_ids): + latest_event_ids = set(latest_event_ids) + if new_latest_event_ids == latest_event_ids: # No change in extremities, so no change in state continue new_forward_extremeties[room_id] = new_latest_event_ids - state = yield self._calculate_state_delta( - room_id, ev_ctx_rm, new_latest_event_ids + len_1 = ( + len(latest_event_ids) == 1 + and len(new_latest_event_ids) == 1 + ) + if len_1: + all_single_prev_not_state = all( + len(event.prev_events) == 1 + and not event.is_state() + for event, ctx in ev_ctx_rm + ) + # Don't bother calculating state if they're just + # a long chain of single ancestor non-state events. + if all_single_prev_not_state: + continue + + state_delta_counter.inc() + if len(new_latest_event_ids) == 1: + state_delta_single_event_counter.inc() + + # This is a fairly handwavey check to see if we could + # have guessed what the delta would have been when + # processing one of these events. 
+ # What we're interested in is if the latest extremities + # were the same when we created the event as they are + # now. When this server creates a new event (as opposed + # to receiving it over federation) it will use the + # forward extremities as the prev_events, so we can + # guess this by looking at the prev_events and checking + # if they match the current forward extremities. + for ev, _ in ev_ctx_rm: + prev_event_ids = set(e for e, _ in ev.prev_events) + if latest_event_ids == prev_event_ids: + state_delta_reuse_delta_counter.inc() + break + + logger.info( + "Calculating state delta for room %s", room_id, ) - if state: - current_state_for_room[room_id] = state + current_state = yield self._get_new_state_after_events( + room_id, + ev_ctx_rm, + latest_event_ids, + new_latest_event_ids, + ) + if current_state is not None: + current_state_for_room[room_id] = current_state + delta = yield self._calculate_state_delta( + room_id, current_state, + ) + if delta is not None: + state_delta_for_room[room_id] = delta yield self.runInteraction( "persist_events", @@ -323,10 +442,35 @@ class EventsStore(SQLBaseStore): events_and_contexts=chunk, backfilled=backfilled, delete_existing=delete_existing, - current_state_for_room=current_state_for_room, + state_delta_for_room=state_delta_for_room, new_forward_extremeties=new_forward_extremeties, ) persist_event_counter.inc_by(len(chunk)) + synapse.metrics.event_persisted_position.set( + chunk[-1][0].internal_metadata.stream_ordering, + ) + for event, context in chunk: + if context.app_service: + origin_type = "local" + origin_entity = context.app_service.id + elif self.hs.is_mine_id(event.sender): + origin_type = "local" + origin_entity = "*client*" + else: + origin_type = "remote" + origin_entity = get_domain_from_id(event.sender) + + event_counter.inc(event.type, origin_type, origin_entity) + + for room_id, new_state in current_state_for_room.iteritems(): + self.get_current_state_ids.prefill( + (room_id, ), new_state + ) + + 
for room_id, latest_event_ids in new_forward_extremeties.iteritems(): + self.get_latest_event_ids_in_room.prefill( + (room_id,), list(latest_event_ids) + ) @defer.inlineCallbacks def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids): @@ -370,71 +514,137 @@ class EventsStore(SQLBaseStore): defer.returnValue(new_latest_event_ids) @defer.inlineCallbacks - def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids): - """Calculate the new state deltas for a room. + def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids, + new_latest_event_ids): + """Calculate the current state dict after adding some new events to + a room - Assumes that we are only persisting events for one room at a time. + Args: + room_id (str): + room to which the events are being added. Used for logging etc + + events_context (list[(EventBase, EventContext)]): + events and contexts which are being added to the room + + old_latest_event_ids (iterable[str]): + the old forward extremities for the room. + + new_latest_event_ids (iterable[str]): + the new forward extremities for the room. Returns: - 2-tuple (to_delete, to_insert) where both are state dicts, i.e. - (type, state_key) -> event_id. `to_delete` are the entries to - first be deleted from current_state_events, `to_insert` are entries - to insert. - May return None if there are no changes to be applied. + Deferred[dict[(str,str), str]|None]: + None if there are no changes to the room state, or + a dict of (type, state_key) -> event_id]. 
""" - # Now we need to work out the different state sets for - # each state extremities - state_sets = [] - missing_event_ids = [] - was_updated = False + + if not new_latest_event_ids: + return + + # map from state_group to ((type, key) -> event_id) state map + state_groups_map = {} + for ev, ctx in events_context: + if ctx.state_group is None: + # I don't think this can happen, but let's double-check + raise Exception( + "Context for new extremity event %s has no state " + "group" % (ev.event_id, ), + ) + + if ctx.state_group in state_groups_map: + continue + + state_groups_map[ctx.state_group] = ctx.current_state_ids + + # We need to map the event_ids to their state groups. First, let's + # check if the event is one we're persisting, in which case we can + # pull the state group from its context. + # Otherwise we need to pull the state group from the database. + + # Set of events we need to fetch groups for. (We know none of the old + # extremities are going to be in events_context). + missing_event_ids = set(old_latest_event_ids) + + event_id_to_state_group = {} for event_id in new_latest_event_ids: - # First search in the list of new events we're adding, - # and then use the current state from that + # First search in the list of new events we're adding. 
for ev, ctx in events_context: if event_id == ev.event_id: - if ctx.current_state_ids is None: - raise Exception("Unknown current state") - state_sets.append(ctx.current_state_ids) - if ctx.delta_ids or hasattr(ev, "state_key"): - was_updated = True + event_id_to_state_group[event_id] = ctx.state_group break else: # If we couldn't find it, then we'll need to pull # the state from the database - was_updated = True - missing_event_ids.append(event_id) + missing_event_ids.add(event_id) if missing_event_ids: - # Now pull out the state for any missing events from DB + # Now pull out the state groups for any missing events from DB event_to_groups = yield self._get_state_group_for_events( missing_event_ids, ) + event_id_to_state_group.update(event_to_groups) - groups = set(event_to_groups.values()) - group_to_state = yield self._get_state_for_groups(groups) + # State groups of old_latest_event_ids + old_state_groups = set( + event_id_to_state_group[evid] for evid in old_latest_event_ids + ) - state_sets.extend(group_to_state.values()) + # State groups of new_latest_event_ids + new_state_groups = set( + event_id_to_state_group[evid] for evid in new_latest_event_ids + ) - if not new_latest_event_ids: - current_state = {} - elif was_updated: - current_state = yield resolve_events( - state_sets, - state_map_factory=lambda ev_ids: self.get_events( - ev_ids, get_prev_content=False, check_redacted=False, - ), - ) - else: + # If they old and new groups are the same then we don't need to do + # anything. + if old_state_groups == new_state_groups: return - existing_state_rows = yield self._simple_select_list( - table="current_state_events", - keyvalues={"room_id": room_id}, - retcols=["event_id", "type", "state_key"], - desc="_calculate_state_delta", + # Now that we have calculated new_state_groups we need to get + # their state IDs so we can resolve to a single state set. 
+ missing_state = new_state_groups - set(state_groups_map) + if missing_state: + group_to_state = yield self._get_state_for_groups(missing_state) + state_groups_map.update(group_to_state) + + if len(new_state_groups) == 1: + # If there is only one state group, then we know what the current + # state is. + defer.returnValue(state_groups_map[new_state_groups.pop()]) + + # Ok, we need to defer to the state handler to resolve our state sets. + + def get_events(ev_ids): + return self.get_events( + ev_ids, get_prev_content=False, check_redacted=False, + ) + + state_groups = { + sg: state_groups_map[sg] for sg in new_state_groups + } + + events_map = {ev.event_id: ev for ev, _ in events_context} + logger.debug("calling resolve_state_groups from preserve_events") + res = yield self._state_resolution_handler.resolve_state_groups( + room_id, state_groups, events_map, get_events ) - existing_events = set(row["event_id"] for row in existing_state_rows) + defer.returnValue(res.state) + + @defer.inlineCallbacks + def _calculate_state_delta(self, room_id, current_state): + """Calculate the new state deltas for a room. + + Assumes that we are only persisting events for one room at a time. + + Returns: + 2-tuple (to_delete, to_insert) where both are state dicts, + i.e. (type, state_key) -> event_id. `to_delete` are the entries to + first be deleted from current_state_events, `to_insert` are entries + to insert. 
+ """ + existing_state = yield self.get_current_state_ids(room_id) + + existing_events = set(existing_state.itervalues()) new_events = set(ev_id for ev_id in current_state.itervalues()) changed_events = existing_events ^ new_events @@ -442,9 +652,8 @@ class EventsStore(SQLBaseStore): return to_delete = { - (row["type"], row["state_key"]): row["event_id"] - for row in existing_state_rows - if row["event_id"] in changed_events + key: ev_id for key, ev_id in existing_state.iteritems() + if ev_id in changed_events } events_to_insert = (new_events - existing_events) to_insert = { @@ -454,77 +663,104 @@ class EventsStore(SQLBaseStore): defer.returnValue((to_delete, to_insert)) - @defer.inlineCallbacks - def get_event(self, event_id, check_redacted=True, - get_prev_content=False, allow_rejected=False, - allow_none=False): - """Get an event from the database by event_id. + @log_function + def _persist_events_txn(self, txn, events_and_contexts, backfilled, + delete_existing=False, state_delta_for_room={}, + new_forward_extremeties={}): + """Insert some number of room events into the necessary database tables. + + Rejected events are only inserted into the events table, the events_json table, + and the rejections table. Things reading from those table will need to check + whether the event was rejected. Args: - event_id (str): The event_id of the event to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, - include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. - allow_none (bool): If True, return None if no event found, if - False throw an exception. 
+ txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): + events to persist + backfilled (bool): True if the events were backfilled + delete_existing (bool): True to purge existing table rows for the + events from the database. This is useful when retrying due to + IntegrityError. + state_delta_for_room (dict[str, (list[str], list[str])]): + The current-state delta for each room. For each room, a tuple + (to_delete, to_insert), being a list of event ids to be removed + from the current state, and a list of event ids to be added to + the current state. + new_forward_extremeties (dict[str, list[str]]): + The new forward extremities for each room. For each room, a + list of the event ids which are the forward extremities. - Returns: - Deferred : A FrozenEvent. """ - events = yield self._get_events( - [event_id], - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, + all_events_and_contexts = events_and_contexts + + max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering + + self._update_current_state_txn(txn, state_delta_for_room, max_stream_order) + + self._update_forward_extremities_txn( + txn, + new_forward_extremities=new_forward_extremeties, + max_stream_order=max_stream_order, ) - if not events and not allow_none: - raise SynapseError(404, "Could not find event %s" % (event_id,)) + # Ensure that we don't have the same event twice. 
+ events_and_contexts = self._filter_events_and_contexts_for_duplicates( + events_and_contexts, + ) - defer.returnValue(events[0] if events else None) + self._update_room_depths_txn( + txn, + events_and_contexts=events_and_contexts, + backfilled=backfilled, + ) - @defer.inlineCallbacks - def get_events(self, event_ids, check_redacted=True, - get_prev_content=False, allow_rejected=False): - """Get events from the database + # _update_outliers_txn filters out any events which have already been + # persisted, and returns the filtered list. + events_and_contexts = self._update_outliers_txn( + txn, + events_and_contexts=events_and_contexts, + ) - Args: - event_ids (list): The event_ids of the events to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, - include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. + # From this point onwards the events are only events that we haven't + # seen before. - Returns: - Deferred : Dict from event_id to event. - """ - events = yield self._get_events( - event_ids, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, + if delete_existing: + # For paranoia reasons, we go and delete all the existing entries + # for these events so we can reinsert them. + # This gets around any problems with some tables already having + # entries. + self._delete_existing_rows_txn( + txn, + events_and_contexts=events_and_contexts, + ) + + self._store_event_txn( + txn, + events_and_contexts=events_and_contexts, ) - defer.returnValue({e.event_id: e for e in events}) + # Insert into event_to_state_groups. 
+ self._store_event_state_mappings_txn(txn, events_and_contexts) - @log_function - def _persist_events_txn(self, txn, events_and_contexts, backfilled, - delete_existing=False, current_state_for_room={}, - new_forward_extremeties={}): - """Insert some number of room events into the necessary database tables. + # _store_rejected_events_txn filters out any events which were + # rejected, and returns the filtered list. + events_and_contexts = self._store_rejected_events_txn( + txn, + events_and_contexts=events_and_contexts, + ) - Rejected events are only inserted into the events table, the events_json table, - and the rejections table. Things reading from those table will need to check - whether the event was rejected. + # From this point onwards the events are only ones that weren't + # rejected. - If delete_existing is True then existing events will be purged from the - database before insertion. This is useful when retrying due to IntegrityError. - """ - max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering - for room_id, current_state_tuple in current_state_for_room.iteritems(): + self._update_metadata_tables_txn( + txn, + events_and_contexts=events_and_contexts, + all_events_and_contexts=all_events_and_contexts, + backfilled=backfilled, + ) + + def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order): + for room_id, current_state_tuple in state_delta_by_room.iteritems(): to_delete, to_insert = current_state_tuple txn.executemany( "DELETE FROM current_state_events WHERE event_id = ?", @@ -545,6 +781,29 @@ class EventsStore(SQLBaseStore): ], ) + state_deltas = {key: None for key in to_delete} + state_deltas.update(to_insert) + + self._simple_insert_many_txn( + txn, + table="current_state_delta_stream", + values=[ + { + "stream_id": max_stream_order, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": ev_id, + "prev_event_id": to_delete.get(key, None), + } + for key, ev_id in 
state_deltas.iteritems() + ] + ) + + self._curr_state_delta_stream_cache.entity_has_changed( + room_id, max_stream_order, + ) + # Invalidate the various caches # Figure out the changes of membership to invalidate the @@ -553,24 +812,34 @@ class EventsStore(SQLBaseStore): # and which we have added, then we invlidate the caches for all # those users. members_changed = set( - state_key for ev_type, state_key in to_delete.iterkeys() - if ev_type == EventTypes.Member - ) - members_changed.update( - state_key for ev_type, state_key in to_insert.iterkeys() + state_key for ev_type, state_key in state_deltas if ev_type == EventTypes.Member ) for member in members_changed: self._invalidate_cache_and_stream( - txn, self.get_rooms_for_user, (member,) + txn, self.get_rooms_for_user_with_stream_ordering, (member,) + ) + + for host in set(get_domain_from_id(u) for u in members_changed): + self._invalidate_cache_and_stream( + txn, self.is_host_joined, (room_id, host) + ) + self._invalidate_cache_and_stream( + txn, self.was_host_joined, (room_id, host) ) self._invalidate_cache_and_stream( txn, self.get_users_in_room, (room_id,) ) - for room_id, new_extrem in new_forward_extremeties.items(): + self._invalidate_cache_and_stream( + txn, self.get_current_state_ids, (room_id,) + ) + + def _update_forward_extremities_txn(self, txn, new_forward_extremities, + max_stream_order): + for room_id, new_extrem in new_forward_extremities.iteritems(): self._simple_delete_txn( txn, table="event_forward_extremities", @@ -588,7 +857,7 @@ class EventsStore(SQLBaseStore): "event_id": ev_id, "room_id": room_id, } - for room_id, new_extrem in new_forward_extremeties.items() + for room_id, new_extrem in new_forward_extremities.iteritems() for ev_id in new_extrem ], ) @@ -605,13 +874,22 @@ class EventsStore(SQLBaseStore): "event_id": event_id, "stream_ordering": max_stream_order, } - for room_id, new_extrem in new_forward_extremeties.items() + for room_id, new_extrem in new_forward_extremities.iteritems() 
for event_id in new_extrem ] ) - # Ensure that we don't have the same event twice. - # Pick the earliest non-outlier if there is one, else the earliest one. + @classmethod + def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts): + """Ensure that we don't have the same event twice. + + Pick the earliest non-outlier if there is one, else the earliest one. + + Args: + events_and_contexts (list[(EventBase, EventContext)]): + Returns: + list[(EventBase, EventContext)]: filtered list + """ new_events_and_contexts = OrderedDict() for event, context in events_and_contexts: prev_event_context = new_events_and_contexts.get(event.event_id) @@ -624,9 +902,17 @@ class EventsStore(SQLBaseStore): new_events_and_contexts[event.event_id] = (event, context) else: new_events_and_contexts[event.event_id] = (event, context) + return new_events_and_contexts.values() - events_and_contexts = new_events_and_contexts.values() + def _update_room_depths_txn(self, txn, events_and_contexts, backfilled): + """Update min_depth for each room + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + backfilled (bool): True if the events were backfilled + """ depth_updates = {} for event, context in events_and_contexts: # Remove the any existing cache entries for the event_ids @@ -642,9 +928,24 @@ class EventsStore(SQLBaseStore): event.depth, depth_updates.get(event.room_id, event.depth) ) - for room_id, depth in depth_updates.items(): + for room_id, depth in depth_updates.iteritems(): self._update_min_depth_for_room_txn(txn, room_id, depth) + def _update_outliers_txn(self, txn, events_and_contexts): + """Update any outliers with new event info. + + This turns outliers into ex-outliers (unless the new event was + rejected). 
+ + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + + Returns: + list[(EventBase, EventContext)] new list, without events which + are already in the events table. + """ txn.execute( "SELECT event_id, outlier FROM events WHERE event_id in (%s)" % ( ",".join(["?"] * len(events_and_contexts)), @@ -654,34 +955,30 @@ class EventsStore(SQLBaseStore): have_persisted = { event_id: outlier - for event_id, outlier in txn.fetchall() + for event_id, outlier in txn } to_remove = set() for event, context in events_and_contexts: - if context.rejected: - # If the event is rejected then we don't care if the event - # was an outlier or not. - if event.event_id in have_persisted: - # If we have already seen the event then ignore it. - to_remove.add(event) - continue - if event.event_id not in have_persisted: continue to_remove.add(event) + if context.rejected: + # If the event is rejected then we don't care if the event + # was an outlier or not. + continue + outlier_persisted = have_persisted[event.event_id] if not event.internal_metadata.is_outlier() and outlier_persisted: # We received a copy of an event that we had already stored as # an outlier in the database. We now have some state at that # so we need to update the state_groups table with that state. - # insert into the state_group, state_groups_state and - # event_to_state_groups tables. + # insert into event_to_state_groups. try: - self._store_mult_state_groups_txn(txn, ((event, context),)) + self._store_event_state_mappings_txn(txn, ((event, context),)) except Exception: logger.exception("") raise @@ -726,37 +1023,19 @@ class EventsStore(SQLBaseStore): # event isn't an outlier any more. 
self._update_backward_extremeties(txn, [event]) - events_and_contexts = [ + return [ ec for ec in events_and_contexts if ec[0] not in to_remove ] + @classmethod + def _delete_existing_rows_txn(cls, txn, events_and_contexts): if not events_and_contexts: - # Make sure we don't pass an empty list to functions that expect to - # be storing at least one element. + # nothing to do here return - # From this point onwards the events are only events that we haven't - # seen before. - - def event_dict(event): - return { - k: v - for k, v in event.get_dict().items() - if k not in [ - "redacted", - "redacted_because", - ] - } - - if delete_existing: - # For paranoia reasons, we go and delete all the existing entries - # for these events so we can reinsert them. - # This gets around any problems with some tables already having - # entries. - - logger.info("Deleting existing") + logger.info("Deleting existing") - for table in ( + for table in ( "events", "event_auth", "event_json", @@ -779,11 +1058,30 @@ class EventsStore(SQLBaseStore): "redactions", "room_memberships", "topics" - ): - txn.executemany( - "DELETE FROM %s WHERE event_id = ?" % (table,), - [(ev.event_id,) for ev, _ in events_and_contexts] - ) + ): + txn.executemany( + "DELETE FROM %s WHERE event_id = ?" 
% (table,), + [(ev.event_id,) for ev, _ in events_and_contexts] + ) + + def _store_event_txn(self, txn, events_and_contexts): + """Insert new events into the event and event_json tables + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + """ + + if not events_and_contexts: + # nothing to do here + return + + def event_dict(event): + d = event.get_dict() + d.pop("redacted", None) + d.pop("redacted_because", None) + return d self._simple_insert_many_txn( txn, @@ -827,6 +1125,19 @@ class EventsStore(SQLBaseStore): ], ) + def _store_rejected_events_txn(self, txn, events_and_contexts): + """Add rows to the 'rejections' table for received events which were + rejected + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + + Returns: + list[(EventBase, EventContext)] new list, without the rejected + events. + """ # Remove the rejected events from the list now that we've added them # to the events table and the events_json table. to_remove = set() @@ -838,24 +1149,37 @@ class EventsStore(SQLBaseStore): ) to_remove.add(event) - events_and_contexts = [ + return [ ec for ec in events_and_contexts if ec[0] not in to_remove ] + def _update_metadata_tables_txn(self, txn, events_and_contexts, + all_events_and_contexts, backfilled): + """Update all the miscellaneous tables for new events + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + all_events_and_contexts (list[(EventBase, EventContext)]): all + events that we were going to persist. This includes events + we've already persisted, etc, that wouldn't appear in + events_and_context. + backfilled (bool): True if the events were backfilled + """ + + # Insert all the push actions into the event_push_actions table. 
+ self._set_push_actions_for_event_and_users_txn( + txn, + events_and_contexts=events_and_contexts, + all_events_and_contexts=all_events_and_contexts, + ) + if not events_and_contexts: - # Make sure we don't pass an empty list to functions that expect to - # be storing at least one element. + # nothing to do here return - # From this point onwards the events are only ones that weren't rejected. - for event, context in events_and_contexts: - # Insert all the push actions into the event_push_actions table. - if context.push_actions: - self._set_push_actions_for_event_and_users_txn( - txn, event, context.push_actions - ) - if event.type == EventTypes.Redaction and event.redacts is not None: # Remove the entries in the event_push_actions table for the # redacted event. @@ -874,13 +1198,10 @@ class EventsStore(SQLBaseStore): } for event, _ in events_and_contexts for auth_id, _ in event.auth_events + if event.is_state() ], ) - # Insert into the state_groups, state_groups_state, and - # event_to_state_groups tables. - self._store_mult_state_groups_txn(txn, events_and_contexts) - # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. self._handle_mult_prev_events( @@ -967,13 +1288,6 @@ class EventsStore(SQLBaseStore): # Prefill the event cache self._add_to_cache(txn, events_and_contexts) - if backfilled: - # Backfilled events come before the current state so we don't need - # to update the current state table - return - - return - def _add_to_cache(self, txn, events_and_contexts): to_prefill = [] @@ -1037,13 +1351,49 @@ class EventsStore(SQLBaseStore): defer.returnValue(set(r["event_id"] for r in rows)) - def have_events(self, event_ids): + @defer.inlineCallbacks + def have_seen_events(self, event_ids): """Given a list of event ids, check if we have already processed them. + Args: + event_ids (iterable[str]): + Returns: - dict: Has an entry for each event id we already have seen. 
Maps to - the rejected reason string if we rejected the event, else maps to - None. + Deferred[set[str]]: The events we have already seen. + """ + results = set() + + def have_seen_events_txn(txn, chunk): + sql = ( + "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" + % (",".join("?" * len(chunk)), ) + ) + txn.execute(sql, chunk) + for (event_id, ) in txn: + results.add(event_id) + + # break the input up into chunks of 100 + input_iterator = iter(event_ids) + for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), + []): + yield self.runInteraction( + "have_seen_events", + have_seen_events_txn, + chunk, + ) + defer.returnValue(results) + + def get_seen_events_with_rejections(self, event_ids): + """Given a list of event ids, check if we rejected them. + + Args: + event_ids (list[str]) + + Returns: + Deferred[dict[str, str|None): + Has an entry for each event id we already have seen. Maps to + the rejected reason string if we rejected the event, else maps + to None. 
""" if not event_ids: return defer.succeed({}) @@ -1065,280 +1415,7 @@ class EventsStore(SQLBaseStore): return res - return self.runInteraction( - "have_events", f, - ) - - @defer.inlineCallbacks - def _get_events(self, event_ids, check_redacted=True, - get_prev_content=False, allow_rejected=False): - if not event_ids: - defer.returnValue([]) - - event_id_list = event_ids - event_ids = set(event_ids) - - event_entry_map = self._get_events_from_cache( - event_ids, - allow_rejected=allow_rejected, - ) - - missing_events_ids = [e for e in event_ids if e not in event_entry_map] - - if missing_events_ids: - missing_events = yield self._enqueue_events( - missing_events_ids, - check_redacted=check_redacted, - allow_rejected=allow_rejected, - ) - - event_entry_map.update(missing_events) - - events = [] - for event_id in event_id_list: - entry = event_entry_map.get(event_id, None) - if not entry: - continue - - if allow_rejected or not entry.event.rejected_reason: - if check_redacted and entry.redacted_event: - event = entry.redacted_event - else: - event = entry.event - - events.append(event) - - if get_prev_content: - if "replaces_state" in event.unsigned: - prev = yield self.get_event( - event.unsigned["replaces_state"], - get_prev_content=False, - allow_none=True, - ) - if prev: - event.unsigned = dict(event.unsigned) - event.unsigned["prev_content"] = prev.content - event.unsigned["prev_sender"] = prev.sender - - defer.returnValue(events) - - def _invalidate_get_event_cache(self, event_id): - self._get_event_cache.invalidate((event_id,)) - - def _get_events_from_cache(self, events, allow_rejected): - event_map = {} - - for event_id in events: - ret = self._get_event_cache.get((event_id,), None) - if not ret: - continue - - if allow_rejected or not ret.event.rejected_reason: - event_map[event_id] = ret - else: - event_map[event_id] = None - - return event_map - - def _do_fetch(self, conn): - """Takes a database connection and waits for requests for events from - the 
_event_fetch_list queue. - """ - event_list = [] - i = 0 - while True: - try: - with self._event_fetch_lock: - event_list = self._event_fetch_list - self._event_fetch_list = [] - - if not event_list: - single_threaded = self.database_engine.single_threaded - if single_threaded or i > EVENT_QUEUE_ITERATIONS: - self._event_fetch_ongoing -= 1 - return - else: - self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) - i += 1 - continue - i = 0 - - event_id_lists = zip(*event_list)[0] - event_ids = [ - item for sublist in event_id_lists for item in sublist - ] - - rows = self._new_transaction( - conn, "do_fetch", [], None, self._fetch_event_rows, event_ids - ) - - row_dict = { - r["event_id"]: r - for r in rows - } - - # We only want to resolve deferreds from the main thread - def fire(lst, res): - for ids, d in lst: - if not d.called: - try: - with PreserveLoggingContext(): - d.callback([ - res[i] - for i in ids - if i in res - ]) - except: - logger.exception("Failed to callback") - with PreserveLoggingContext(): - reactor.callFromThread(fire, event_list, row_dict) - except Exception as e: - logger.exception("do_fetch") - - # We only want to resolve deferreds from the main thread - def fire(evs): - for _, d in evs: - if not d.called: - with PreserveLoggingContext(): - d.errback(e) - - if event_list: - with PreserveLoggingContext(): - reactor.callFromThread(fire, event_list) - - @defer.inlineCallbacks - def _enqueue_events(self, events, check_redacted=True, allow_rejected=False): - """Fetches events from the database using the _event_fetch_list. This - allows batch and bulk fetching of events - it allows us to fetch events - without having to create a new transaction for each request for events. 
- """ - if not events: - defer.returnValue({}) - - events_d = defer.Deferred() - with self._event_fetch_lock: - self._event_fetch_list.append( - (events, events_d) - ) - - self._event_fetch_lock.notify() - - if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: - self._event_fetch_ongoing += 1 - should_start = True - else: - should_start = False - - if should_start: - with PreserveLoggingContext(): - self.runWithConnection( - self._do_fetch - ) - - logger.debug("Loading %d events", len(events)) - with PreserveLoggingContext(): - rows = yield events_d - logger.debug("Loaded %d events (%d rows)", len(events), len(rows)) - - if not allow_rejected: - rows[:] = [r for r in rows if not r["rejects"]] - - res = yield preserve_context_over_deferred(defer.gatherResults( - [ - preserve_fn(self._get_event_from_row)( - row["internal_metadata"], row["json"], row["redacts"], - rejected_reason=row["rejects"], - ) - for row in rows - ], - consumeErrors=True - )) - - defer.returnValue({ - e.event.event_id: e - for e in res if e - }) - - def _fetch_event_rows(self, txn, events): - rows = [] - N = 200 - for i in range(1 + len(events) / N): - evs = events[i * N:(i + 1) * N] - if not evs: - break - - sql = ( - "SELECT " - " e.event_id as event_id, " - " e.internal_metadata," - " e.json," - " r.redacts as redacts," - " rej.event_id as rejects " - " FROM event_json as e" - " LEFT JOIN rejections as rej USING (event_id)" - " LEFT JOIN redactions as r ON e.event_id = r.redacts" - " WHERE e.event_id IN (%s)" - ) % (",".join(["?"] * len(evs)),) - - txn.execute(sql, evs) - rows.extend(self.cursor_to_dict(txn)) - - return rows - - @defer.inlineCallbacks - def _get_event_from_row(self, internal_metadata, js, redacted, - rejected_reason=None): - with Measure(self._clock, "_get_event_from_row"): - d = json.loads(js) - internal_metadata = json.loads(internal_metadata) - - if rejected_reason: - rejected_reason = yield self._simple_select_one_onecol( - table="rejections", - keyvalues={"event_id": 
rejected_reason}, - retcol="reason", - desc="_get_event_from_row_rejected_reason", - ) - - original_ev = FrozenEvent( - d, - internal_metadata_dict=internal_metadata, - rejected_reason=rejected_reason, - ) - - redacted_event = None - if redacted: - redacted_event = prune_event(original_ev) - - redaction_id = yield self._simple_select_one_onecol( - table="redactions", - keyvalues={"redacts": redacted_event.event_id}, - retcol="event_id", - desc="_get_event_from_row_redactions", - ) - - redacted_event.unsigned["redacted_by"] = redaction_id - # Get the redaction event. - - because = yield self.get_event( - redaction_id, - check_redacted=False, - allow_none=True, - ) - - if because: - # It's fine to do add the event directly, since get_pdu_json - # will serialise this field correctly - redacted_event.unsigned["redacted_because"] = because - - cache_entry = _EventCacheEntry( - event=original_ev, - redacted_event=redacted_event, - ) - - self._get_event_cache.prefill((original_ev.event_id,), cache_entry) - - defer.returnValue(cache_entry) + return self.runInteraction("get_rejection_reasons", f) @defer.inlineCallbacks def count_daily_messages(self): @@ -1349,66 +1426,52 @@ class EventsStore(SQLBaseStore): call to this function, it will return None. """ def _count_messages(txn): - now = self.hs.get_clock().time() - - txn.execute( - "SELECT reported_stream_token, reported_time FROM stats_reporting" - ) - last_reported = self.cursor_to_dict(txn) + sql = """ + SELECT COALESCE(COUNT(*), 0) FROM events + WHERE type = 'm.room.message' + AND stream_ordering > ? 
+ """ + txn.execute(sql, (self.stream_ordering_day_ago,)) + count, = txn.fetchone() + return count - txn.execute( - "SELECT stream_ordering" - " FROM events" - " ORDER BY stream_ordering DESC" - " LIMIT 1" - ) - now_reporting = self.cursor_to_dict(txn) - if not now_reporting: - logger.info("Calculating daily messages skipped; no now_reporting") - return None - now_reporting = now_reporting[0]["stream_ordering"] - - txn.execute("DELETE FROM stats_reporting") - txn.execute( - "INSERT INTO stats_reporting" - " (reported_stream_token, reported_time)" - " VALUES (?, ?)", - (now_reporting, now,) - ) - - if not last_reported: - logger.info("Calculating daily messages skipped; no last_reported") - return None - - # Close enough to correct for our purposes. - yesterday = (now - 24 * 60 * 60) - since_yesterday_seconds = yesterday - last_reported[0]["reported_time"] - any_since_yesterday = math.fabs(since_yesterday_seconds) > 60 * 60 - if any_since_yesterday: - logger.info( - "Calculating daily messages skipped; since_yesterday_seconds: %d" % - (since_yesterday_seconds,) - ) - return None + ret = yield self.runInteraction("count_messages", _count_messages) + defer.returnValue(ret) - txn.execute( - "SELECT COUNT(*) as messages" - " FROM events NATURAL JOIN event_json" - " WHERE json like '%m.room.message%'" - " AND stream_ordering > ?" - " AND stream_ordering <= ?", - ( - last_reported[0]["reported_stream_token"], - now_reporting, - ) - ) - rows = self.cursor_to_dict(txn) - if not rows: - logger.info("Calculating daily messages skipped; messages count missing") - return None - return rows[0]["messages"] + @defer.inlineCallbacks + def count_daily_sent_messages(self): + def _count_messages(txn): + # This is good enough as if you have silly characters in your own + # hostname then thats your own fault. + like_clause = "%:" + self.hs.hostname + + sql = """ + SELECT COALESCE(COUNT(*), 0) FROM events + WHERE type = 'm.room.message' + AND sender LIKE ? + AND stream_ordering > ? 
+ """ + + txn.execute(sql, (like_clause, self.stream_ordering_day_ago,)) + count, = txn.fetchone() + return count + + ret = yield self.runInteraction("count_daily_sent_messages", _count_messages) + defer.returnValue(ret) - ret = yield self.runInteraction("count_messages", _count_messages) + @defer.inlineCallbacks + def count_daily_active_rooms(self): + def _count(txn): + sql = """ + SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events + WHERE type = 'm.room.message' + AND stream_ordering > ? + """ + txn.execute(sql, (self.stream_ordering_day_ago,)) + count, = txn.fetchone() + return count + + ret = yield self.runInteraction("count_daily_active_rooms", _count) defer.returnValue(ret) @defer.inlineCallbacks @@ -1569,6 +1632,94 @@ class EventsStore(SQLBaseStore): """The current minimum token that backfilled events have reached""" return -self._backfill_id_gen.get_current_token() + def get_current_events_token(self): + """The current maximum token that events have reached""" + return self._stream_id_gen.get_current_token() + + def get_all_new_forward_event_rows(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_new_forward_event_rows(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? < stream_ordering AND stream_ordering <= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + new_event_updates = txn.fetchall() + + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + else: + upper_bound = current_id + + sql = ( + "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? 
< event_stream_ordering" + " AND event_stream_ordering <= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (last_id, upper_bound)) + new_event_updates.extend(txn) + + return new_event_updates + return self.runInteraction( + "get_all_new_forward_event_rows", get_all_new_forward_event_rows + ) + + def get_all_new_backfill_event_rows(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_new_backfill_event_rows(txn): + sql = ( + "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? > stream_ordering AND stream_ordering >= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (-last_id, -current_id, limit)) + new_event_updates = txn.fetchall() + + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + else: + upper_bound = current_id + + sql = ( + "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" 
+ " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (-last_id, -upper_bound)) + new_event_updates.extend(txn.fetchall()) + + return new_event_updates + return self.runInteraction( + "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows + ) + @cached(num_args=5, max_entries=10) def get_all_new_events(self, last_backfill_id, last_forward_id, current_backfill_id, current_forward_id, limit): @@ -1582,14 +1733,13 @@ class EventsStore(SQLBaseStore): def get_all_new_events_txn(txn): sql = ( - "SELECT e.stream_ordering, ej.internal_metadata, ej.json, eg.state_group" - " FROM events as e" - " JOIN event_json as ej" - " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" - " LEFT JOIN event_to_state_groups as eg" - " ON e.event_id = eg.event_id" - " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?" - " ORDER BY e.stream_ordering ASC" + "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? < stream_ordering AND stream_ordering <= ?" + " ORDER BY stream_ordering ASC" " LIMIT ?" ) if have_forward_events: @@ -1615,15 +1765,13 @@ class EventsStore(SQLBaseStore): forward_ex_outliers = [] sql = ( - "SELECT -e.stream_ordering, ej.internal_metadata, ej.json," - " eg.state_group" - " FROM events as e" - " JOIN event_json as ej" - " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" - " LEFT JOIN event_to_state_groups as eg" - " ON e.event_id = eg.event_id" - " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?" - " ORDER BY e.stream_ordering DESC" + "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? > stream_ordering AND stream_ordering >= ?" + " ORDER BY stream_ordering DESC" " LIMIT ?" 
) if have_backfill_events: @@ -1654,16 +1802,32 @@ class EventsStore(SQLBaseStore): ) return self.runInteraction("get_all_new_events", get_all_new_events_txn) - def delete_old_state(self, room_id, topological_ordering): - return self.runInteraction( - "delete_old_state", - self._delete_old_state_txn, room_id, topological_ordering - ) + def purge_history( + self, room_id, topological_ordering, delete_local_events, + ): + """Deletes room history before a certain point + + Args: + room_id (str): - def _delete_old_state_txn(self, txn, room_id, topological_ordering): - """Deletes old room state + topological_ordering (int): + minimum topo ordering to preserve + + delete_local_events (bool): + if True, we will delete local events as well as remote ones + (instead of just marking them as outliers and deleting their + state groups). """ + return self.runInteraction( + "purge_history", + self._purge_history_txn, room_id, topological_ordering, + delete_local_events, + ) + + def _purge_history_txn( + self, txn, room_id, topological_ordering, delete_local_events, + ): # Tables that should be pruned: # event_auth # event_backward_extremities @@ -1684,6 +1848,30 @@ class EventsStore(SQLBaseStore): # state_groups # state_groups_state + # we will build a temporary table listing the events so that we don't + # have to keep shovelling the list back and forth across the + # connection. Annoyingly the python sqlite driver commits the + # transaction on CREATE, so let's do this first. + # + # furthermore, we might already have the table from a previous (failed) + # purge attempt, so let's drop the table first. 
+ + txn.execute("DROP TABLE IF EXISTS events_to_purge") + + txn.execute( + "CREATE TEMPORARY TABLE events_to_purge (" + " event_id TEXT NOT NULL," + " should_delete BOOLEAN NOT NULL" + ")" + ) + + # create an index on should_delete because later we'll be looking for + # the should_delete / shouldn't_delete subsets + txn.execute( + "CREATE INDEX events_to_purge_should_delete" + " ON events_to_purge(should_delete)", + ) + # First ensure that we're not about to delete all the forward extremeties txn.execute( "SELECT e.event_id, e.depth FROM events as e " @@ -1704,29 +1892,49 @@ class EventsStore(SQLBaseStore): 400, "topological_ordering is greater than forward extremeties" ) + logger.info("[purge] looking for events to delete") + + should_delete_expr = "state_key IS NULL" + should_delete_params = () + if not delete_local_events: + should_delete_expr += " AND event_id NOT LIKE ?" + should_delete_params += ("%:" + self.hs.hostname, ) + + should_delete_params += (room_id, topological_ordering) + + txn.execute( + "INSERT INTO events_to_purge" + " SELECT event_id, %s" + " FROM events AS e LEFT JOIN state_events USING (event_id)" + " WHERE e.room_id = ? AND topological_ordering < ?" % ( + should_delete_expr, + ), + should_delete_params, + ) txn.execute( - "SELECT event_id, state_key FROM events" - " LEFT JOIN state_events USING (room_id, event_id)" - " WHERE room_id = ? 
AND topological_ordering < ?", - (room_id, topological_ordering,) + "SELECT event_id, should_delete FROM events_to_purge" ) event_rows = txn.fetchall() + logger.info( + "[purge] found %i events before cutoff, of which %i can be deleted", + len(event_rows), sum(1 for e in event_rows if e[1]), + ) - for event_id, state_key in event_rows: - txn.call_after(self._get_state_group_for_event.invalidate, (event_id,)) + logger.info("[purge] Finding new backward extremities") # We calculate the new entries for the backward extremeties by finding # all events that point to events that are to be purged txn.execute( - "SELECT DISTINCT e.event_id FROM events as e" - " INNER JOIN event_edges as ed ON e.event_id = ed.prev_event_id" - " INNER JOIN events as e2 ON e2.event_id = ed.event_id" - " WHERE e.room_id = ? AND e.topological_ordering < ?" - " AND e2.topological_ordering >= ?", - (room_id, topological_ordering, topological_ordering) + "SELECT DISTINCT e.event_id FROM events_to_purge AS e" + " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id" + " INNER JOIN events AS e2 ON e2.event_id = ed.event_id" + " WHERE e2.topological_ordering >= ?", + (topological_ordering, ) ) new_backwards_extrems = txn.fetchall() + logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems) + txn.execute( "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,) @@ -1741,30 +1949,36 @@ class EventsStore(SQLBaseStore): ] ) + logger.info("[purge] finding redundant state groups") + # Get all state groups that are only referenced by events that are # to be deleted. txn.execute( "SELECT state_group FROM event_to_state_groups" " INNER JOIN events USING (event_id)" " WHERE state_group IN (" - " SELECT DISTINCT state_group FROM events" + " SELECT DISTINCT state_group FROM events_to_purge" " INNER JOIN event_to_state_groups USING (event_id)" - " WHERE room_id = ? AND topological_ordering < ?" 
" )" " GROUP BY state_group HAVING MAX(topological_ordering) < ?", - (room_id, topological_ordering, topological_ordering) + (topological_ordering, ) ) state_rows = txn.fetchall() - state_groups_to_delete = [sg for sg, in state_rows] + logger.info("[purge] found %i redundant state groups", len(state_rows)) + + # make a set of the redundant state groups, so that we can look them up + # efficiently + state_groups_to_delete = set([sg for sg, in state_rows]) # Now we get all the state groups that rely on these state groups - new_state_edges = [] - chunks = [ - state_groups_to_delete[i:i + 100] - for i in xrange(0, len(state_groups_to_delete), 100) - ] - for chunk in chunks: + logger.info("[purge] finding state groups which depend on redundant" + " state groups") + remaining_state_groups = [] + for i in xrange(0, len(state_rows), 100): + chunk = [sg for sg, in state_rows[i:i + 100]] + # look for state groups whose prev_state_group is one we are about + # to delete rows = self._simple_select_many_txn( txn, table="state_group_edges", @@ -1773,21 +1987,28 @@ class EventsStore(SQLBaseStore): retcols=["state_group"], keyvalues={}, ) - new_state_edges.extend(row["state_group"] for row in rows) + remaining_state_groups.extend( + row["state_group"] for row in rows + + # exclude state groups we are about to delete: no point in + # updating them + if row["state_group"] not in state_groups_to_delete + ) - # Now we turn the state groups that reference to-be-deleted state groups - # to non delta versions. - for new_state_edge in new_state_edges: + # Now we turn the state groups that reference to-be-deleted state + # groups to non delta versions. 
+ for sg in remaining_state_groups: + logger.info("[purge] de-delta-ing remaining state group %s", sg) curr_state = self._get_state_groups_from_groups_txn( - txn, [new_state_edge], types=None + txn, [sg], types=None ) - curr_state = curr_state[new_state_edge] + curr_state = curr_state[sg] self._simple_delete_txn( txn, table="state_groups_state", keyvalues={ - "state_group": new_state_edge, + "state_group": sg, } ) @@ -1795,7 +2016,7 @@ class EventsStore(SQLBaseStore): txn, table="state_group_edges", keyvalues={ - "state_group": new_state_edge, + "state_group": sg, } ) @@ -1804,16 +2025,17 @@ class EventsStore(SQLBaseStore): table="state_groups_state", values=[ { - "state_group": new_state_edge, + "state_group": sg, "room_id": room_id, "type": key[0], "state_key": key[1], "event_id": state_id, } - for key, state_id in curr_state.items() + for key, state_id in curr_state.iteritems() ], ) + logger.info("[purge] removing redundant state groups") txn.executemany( "DELETE FROM state_groups_state WHERE state_group = ?", state_rows @@ -1822,22 +2044,18 @@ class EventsStore(SQLBaseStore): "DELETE FROM state_groups WHERE id = ?", state_rows ) - # Delete all non-state - txn.executemany( - "DELETE FROM event_to_state_groups WHERE event_id = ?", - [(event_id,) for event_id, _ in event_rows] - ) + logger.info("[purge] removing events from event_to_state_groups") txn.execute( - "UPDATE room_depth SET min_depth = ? 
WHERE room_id = ?", - (topological_ordering, room_id,) + "DELETE FROM event_to_state_groups " + "WHERE event_id IN (SELECT event_id from events_to_purge)" ) + for event_id, _ in event_rows: + txn.call_after(self._get_state_group_for_event.invalidate, ( + event_id, + )) # Delete all remote non-state events - to_delete = [ - (event_id,) for event_id, state_key in event_rows - if state_key is None and not self.hs.is_mine_id(event_id) - ] for table in ( "events", "event_json", @@ -1847,29 +2065,102 @@ class EventsStore(SQLBaseStore): "event_edge_hashes", "event_edges", "event_forward_extremities", - "event_push_actions", "event_reference_hashes", "event_search", "event_signatures", "rejections", ): - txn.executemany( - "DELETE FROM %s WHERE event_id = ?" % (table,), - to_delete + logger.info("[purge] removing events from %s", table) + + txn.execute( + "DELETE FROM %s WHERE event_id IN (" + " SELECT event_id FROM events_to_purge WHERE should_delete" + ")" % (table,), + ) + + # event_push_actions lacks an index on event_id, and has one on + # (room_id, event_id) instead. + for table in ( + "event_push_actions", + ): + logger.info("[purge] removing events from %s", table) + + txn.execute( + "DELETE FROM %s WHERE room_id = ? AND event_id IN (" + " SELECT event_id FROM events_to_purge WHERE should_delete" + ")" % (table,), + (room_id, ) ) - txn.executemany( - "DELETE FROM events WHERE event_id = ?", - to_delete - ) # Mark all state and own events as outliers - txn.executemany( + logger.info("[purge] marking remaining events as outliers") + txn.execute( "UPDATE events SET outlier = ?" 
- " WHERE event_id = ?", - [ - (True, event_id,) for event_id, state_key in event_rows - if state_key is not None or self.hs.is_mine_id(event_id) - ] + " WHERE event_id IN (" + " SELECT event_id FROM events_to_purge " + " WHERE NOT should_delete" + ")", + (True,), + ) + + # synapse tries to take out an exclusive lock on room_depth whenever it + # persists events (because upsert), and once we run this update, we + # will block that for the rest of our transaction. + # + # So, let's stick it at the end so that we don't block event + # persistence. + logger.info("[purge] updating room_depth") + txn.execute( + "UPDATE room_depth SET min_depth = ? WHERE room_id = ?", + (topological_ordering, room_id,) + ) + + # finally, drop the temp table. this will commit the txn in sqlite, + # so make sure to keep this actually last. + txn.execute( + "DROP TABLE events_to_purge" + ) + + logger.info("[purge] done") + + @defer.inlineCallbacks + def is_event_after(self, event_id1, event_id2): + """Returns True if event_id1 is after event_id2 in the stream + """ + to_1, so_1 = yield self._get_event_ordering(event_id1) + to_2, so_2 = yield self._get_event_ordering(event_id2) + defer.returnValue((to_1, so_1) > (to_2, so_2)) + + @cachedInlineCallbacks(max_entries=5000) + def _get_event_ordering(self, event_id): + res = yield self._simple_select_one( + table="events", + retcols=["topological_ordering", "stream_ordering"], + keyvalues={"event_id": event_id}, + allow_none=True + ) + + if not res: + raise SynapseError(404, "Could not find event %s" % (event_id,)) + + defer.returnValue((int(res["topological_ordering"]), int(res["stream_ordering"]))) + + def get_max_current_state_delta_stream_id(self): + return self._stream_id_gen.get_current_token() + + def get_all_updated_current_state_deltas(self, from_token, to_token, limit): + def get_all_updated_current_state_deltas_txn(txn): + sql = """ + SELECT stream_id, room_id, type, state_key, event_id + FROM current_state_delta_stream + WHERE ? 
< stream_id AND stream_id <= ? + ORDER BY stream_id ASC LIMIT ? + """ + txn.execute(sql, (from_token, to_token, limit)) + return txn.fetchall() + return self.runInteraction( + "get_all_updated_current_state_deltas", + get_all_updated_current_state_deltas_txn, ) diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py new file mode 100644 index 0000000000..ba834854e1 --- /dev/null +++ b/synapse/storage/events_worker.py @@ -0,0 +1,416 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ._base import SQLBaseStore + +from twisted.internet import defer, reactor + +from synapse.events import FrozenEvent +from synapse.events.utils import prune_event + +from synapse.util.logcontext import ( + PreserveLoggingContext, make_deferred_yieldable, run_in_background, +) +from synapse.util.metrics import Measure +from synapse.api.errors import SynapseError + +from collections import namedtuple + +import logging +import simplejson as json + +# these are only included to make the type annotations work +from synapse.events import EventBase # noqa: F401 +from synapse.events.snapshot import EventContext # noqa: F401 + +logger = logging.getLogger(__name__) + + +# These values are used in the `enqueus_event` and `_do_fetch` methods to +# control how we batch/bulk fetch events from the database. 
+# The values are plucked out of thing air to make initial sync run faster +# on jki.re +# TODO: Make these configurable. +EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events +EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events +EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events + + +_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) + + +class EventsWorkerStore(SQLBaseStore): + def get_received_ts(self, event_id): + """Get received_ts (when it was persisted) for the event. + + Raises an exception for unknown events. + + Args: + event_id (str) + + Returns: + Deferred[int|None]: Timestamp in milliseconds, or None for events + that were persisted before received_ts was implemented. + """ + return self._simple_select_one_onecol( + table="events", + keyvalues={ + "event_id": event_id, + }, + retcol="received_ts", + desc="get_received_ts", + ) + + @defer.inlineCallbacks + def get_event(self, event_id, check_redacted=True, + get_prev_content=False, allow_rejected=False, + allow_none=False): + """Get an event from the database by event_id. + + Args: + event_id (str): The event_id of the event to fetch + check_redacted (bool): If True, check if event has been redacted + and redact it. + get_prev_content (bool): If True and event is a state event, + include the previous states content in the unsigned field. + allow_rejected (bool): If True return rejected events. + allow_none (bool): If True, return None if no event found, if + False throw an exception. + + Returns: + Deferred : A FrozenEvent. 
+ """ + events = yield self._get_events( + [event_id], + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + if not events and not allow_none: + raise SynapseError(404, "Could not find event %s" % (event_id,)) + + defer.returnValue(events[0] if events else None) + + @defer.inlineCallbacks + def get_events(self, event_ids, check_redacted=True, + get_prev_content=False, allow_rejected=False): + """Get events from the database + + Args: + event_ids (list): The event_ids of the events to fetch + check_redacted (bool): If True, check if event has been redacted + and redact it. + get_prev_content (bool): If True and event is a state event, + include the previous states content in the unsigned field. + allow_rejected (bool): If True return rejected events. + + Returns: + Deferred : Dict from event_id to event. + """ + events = yield self._get_events( + event_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + defer.returnValue({e.event_id: e for e in events}) + + @defer.inlineCallbacks + def _get_events(self, event_ids, check_redacted=True, + get_prev_content=False, allow_rejected=False): + if not event_ids: + defer.returnValue([]) + + event_id_list = event_ids + event_ids = set(event_ids) + + event_entry_map = self._get_events_from_cache( + event_ids, + allow_rejected=allow_rejected, + ) + + missing_events_ids = [e for e in event_ids if e not in event_entry_map] + + if missing_events_ids: + missing_events = yield self._enqueue_events( + missing_events_ids, + check_redacted=check_redacted, + allow_rejected=allow_rejected, + ) + + event_entry_map.update(missing_events) + + events = [] + for event_id in event_id_list: + entry = event_entry_map.get(event_id, None) + if not entry: + continue + + if allow_rejected or not entry.event.rejected_reason: + if check_redacted and entry.redacted_event: + event = entry.redacted_event + else: + event = entry.event + + 
events.append(event) + + if get_prev_content: + if "replaces_state" in event.unsigned: + prev = yield self.get_event( + event.unsigned["replaces_state"], + get_prev_content=False, + allow_none=True, + ) + if prev: + event.unsigned = dict(event.unsigned) + event.unsigned["prev_content"] = prev.content + event.unsigned["prev_sender"] = prev.sender + + defer.returnValue(events) + + def _invalidate_get_event_cache(self, event_id): + self._get_event_cache.invalidate((event_id,)) + + def _get_events_from_cache(self, events, allow_rejected, update_metrics=True): + """Fetch events from the caches + + Args: + events (list(str)): list of event_ids to fetch + allow_rejected (bool): Whether to teturn events that were rejected + update_metrics (bool): Whether to update the cache hit ratio metrics + + Returns: + dict of event_id -> _EventCacheEntry for each event_id in cache. If + allow_rejected is `False` then there will still be an entry but it + will be `None` + """ + event_map = {} + + for event_id in events: + ret = self._get_event_cache.get( + (event_id,), None, + update_metrics=update_metrics, + ) + if not ret: + continue + + if allow_rejected or not ret.event.rejected_reason: + event_map[event_id] = ret + else: + event_map[event_id] = None + + return event_map + + def _do_fetch(self, conn): + """Takes a database connection and waits for requests for events from + the _event_fetch_list queue. 
+ """ + event_list = [] + i = 0 + while True: + try: + with self._event_fetch_lock: + event_list = self._event_fetch_list + self._event_fetch_list = [] + + if not event_list: + single_threaded = self.database_engine.single_threaded + if single_threaded or i > EVENT_QUEUE_ITERATIONS: + self._event_fetch_ongoing -= 1 + return + else: + self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) + i += 1 + continue + i = 0 + + event_id_lists = zip(*event_list)[0] + event_ids = [ + item for sublist in event_id_lists for item in sublist + ] + + rows = self._new_transaction( + conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids + ) + + row_dict = { + r["event_id"]: r + for r in rows + } + + # We only want to resolve deferreds from the main thread + def fire(lst, res): + for ids, d in lst: + if not d.called: + try: + with PreserveLoggingContext(): + d.callback([ + res[i] + for i in ids + if i in res + ]) + except Exception: + logger.exception("Failed to callback") + with PreserveLoggingContext(): + reactor.callFromThread(fire, event_list, row_dict) + except Exception as e: + logger.exception("do_fetch") + + # We only want to resolve deferreds from the main thread + def fire(evs): + for _, d in evs: + if not d.called: + with PreserveLoggingContext(): + d.errback(e) + + if event_list: + with PreserveLoggingContext(): + reactor.callFromThread(fire, event_list) + + @defer.inlineCallbacks + def _enqueue_events(self, events, check_redacted=True, allow_rejected=False): + """Fetches events from the database using the _event_fetch_list. This + allows batch and bulk fetching of events - it allows us to fetch events + without having to create a new transaction for each request for events. 
+ """ + if not events: + defer.returnValue({}) + + events_d = defer.Deferred() + with self._event_fetch_lock: + self._event_fetch_list.append( + (events, events_d) + ) + + self._event_fetch_lock.notify() + + if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: + self._event_fetch_ongoing += 1 + should_start = True + else: + should_start = False + + if should_start: + with PreserveLoggingContext(): + self.runWithConnection( + self._do_fetch + ) + + logger.debug("Loading %d events", len(events)) + with PreserveLoggingContext(): + rows = yield events_d + logger.debug("Loaded %d events (%d rows)", len(events), len(rows)) + + if not allow_rejected: + rows[:] = [r for r in rows if not r["rejects"]] + + res = yield make_deferred_yieldable(defer.gatherResults( + [ + run_in_background( + self._get_event_from_row, + row["internal_metadata"], row["json"], row["redacts"], + rejected_reason=row["rejects"], + ) + for row in rows + ], + consumeErrors=True + )) + + defer.returnValue({ + e.event.event_id: e + for e in res if e + }) + + def _fetch_event_rows(self, txn, events): + rows = [] + N = 200 + for i in range(1 + len(events) / N): + evs = events[i * N:(i + 1) * N] + if not evs: + break + + sql = ( + "SELECT " + " e.event_id as event_id, " + " e.internal_metadata," + " e.json," + " r.redacts as redacts," + " rej.event_id as rejects " + " FROM event_json as e" + " LEFT JOIN rejections as rej USING (event_id)" + " LEFT JOIN redactions as r ON e.event_id = r.redacts" + " WHERE e.event_id IN (%s)" + ) % (",".join(["?"] * len(evs)),) + + txn.execute(sql, evs) + rows.extend(self.cursor_to_dict(txn)) + + return rows + + @defer.inlineCallbacks + def _get_event_from_row(self, internal_metadata, js, redacted, + rejected_reason=None): + with Measure(self._clock, "_get_event_from_row"): + d = json.loads(js) + internal_metadata = json.loads(internal_metadata) + + if rejected_reason: + rejected_reason = yield self._simple_select_one_onecol( + table="rejections", + keyvalues={"event_id": 
rejected_reason}, + retcol="reason", + desc="_get_event_from_row_rejected_reason", + ) + + original_ev = FrozenEvent( + d, + internal_metadata_dict=internal_metadata, + rejected_reason=rejected_reason, + ) + + redacted_event = None + if redacted: + redacted_event = prune_event(original_ev) + + redaction_id = yield self._simple_select_one_onecol( + table="redactions", + keyvalues={"redacts": redacted_event.event_id}, + retcol="event_id", + desc="_get_event_from_row_redactions", + ) + + redacted_event.unsigned["redacted_by"] = redaction_id + # Get the redaction event. + + because = yield self.get_event( + redaction_id, + check_redacted=False, + allow_none=True, + ) + + if because: + # It's fine to add the event directly, since get_pdu_json + # will serialise this field correctly + redacted_event.unsigned["redacted_because"] = because + + cache_entry = _EventCacheEntry( + event=original_ev, + redacted_event=redacted_event, + ) + + self._get_event_cache.prefill((original_ev.event_id,), cache_entry) + + defer.returnValue(cache_entry) diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py index a2ccc66ea7..78b1e30945 100644 --- a/synapse/storage/filtering.py +++ b/synapse/storage/filtering.py @@ -19,6 +19,7 @@ from ._base import SQLBaseStore from synapse.api.errors import SynapseError, Codes from synapse.util.caches.descriptors import cachedInlineCallbacks +from canonicaljson import encode_canonical_json import simplejson as json @@ -46,12 +47,21 @@ class FilteringStore(SQLBaseStore): defer.returnValue(json.loads(str(def_json).decode("utf-8"))) def add_user_filter(self, user_localpart, user_filter): - def_json = json.dumps(user_filter).encode("utf-8") + def_json = encode_canonical_json(user_filter) # Need an atomic transaction to SELECT the maximal ID so far then # INSERT a new one def _do_txn(txn): sql = ( + "SELECT filter_id FROM user_filters " + "WHERE user_id = ? AND filter_json = ?" 
+ ) + txn.execute(sql, (user_localpart, def_json)) + filter_id_response = txn.fetchone() + if filter_id_response is not None: + return filter_id_response[0] + + sql = ( "SELECT MAX(filter_id) FROM user_filters " "WHERE user_id = ?" ) diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py new file mode 100644 index 0000000000..da05ccb027 --- /dev/null +++ b/synapse/storage/group_server.py @@ -0,0 +1,1253 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import SynapseError + +from ._base import SQLBaseStore + +import simplejson as json + + +# The category ID for the "default" category. We don't store as null in the +# database to avoid the fun of null != null +_DEFAULT_CATEGORY_ID = "" +_DEFAULT_ROLE_ID = "" + + +class GroupServerStore(SQLBaseStore): + def set_group_join_policy(self, group_id, join_policy): + """Set the join policy of a group. 
+ + join_policy can be one of: + * "invite" + * "open" + """ + return self._simple_update_one( + table="groups", + keyvalues={ + "group_id": group_id, + }, + updatevalues={ + "join_policy": join_policy, + }, + desc="set_group_join_policy", + ) + + def get_group(self, group_id): + return self._simple_select_one( + table="groups", + keyvalues={ + "group_id": group_id, + }, + retcols=( + "name", "short_description", "long_description", + "avatar_url", "is_public", "join_policy", + ), + allow_none=True, + desc="get_group", + ) + + def get_users_in_group(self, group_id, include_private=False): + # TODO: Pagination + + keyvalues = { + "group_id": group_id, + } + if not include_private: + keyvalues["is_public"] = True + + return self._simple_select_list( + table="group_users", + keyvalues=keyvalues, + retcols=("user_id", "is_public", "is_admin",), + desc="get_users_in_group", + ) + + def get_invited_users_in_group(self, group_id): + # TODO: Pagination + + return self._simple_select_onecol( + table="group_invites", + keyvalues={ + "group_id": group_id, + }, + retcol="user_id", + desc="get_invited_users_in_group", + ) + + def get_rooms_in_group(self, group_id, include_private=False): + # TODO: Pagination + + keyvalues = { + "group_id": group_id, + } + if not include_private: + keyvalues["is_public"] = True + + return self._simple_select_list( + table="group_rooms", + keyvalues=keyvalues, + retcols=("room_id", "is_public",), + desc="get_rooms_in_group", + ) + + def get_rooms_for_summary_by_category(self, group_id, include_private=False): + """Get the rooms and categories that should be included in a summary request + + Returns ([rooms], [categories]) + """ + def _get_rooms_for_summary_txn(txn): + keyvalues = { + "group_id": group_id, + } + if not include_private: + keyvalues["is_public"] = True + + sql = """ + SELECT room_id, is_public, category_id, room_order + FROM group_summary_rooms + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" 
+ txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + rooms = [ + { + "room_id": row[0], + "is_public": row[1], + "category_id": row[2] if row[2] != _DEFAULT_CATEGORY_ID else None, + "order": row[3], + } + for row in txn + ] + + sql = """ + SELECT category_id, is_public, profile, cat_order + FROM group_summary_room_categories + INNER JOIN group_room_categories USING (group_id, category_id) + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + categories = { + row[0]: { + "is_public": row[1], + "profile": json.loads(row[2]), + "order": row[3], + } + for row in txn + } + + return rooms, categories + return self.runInteraction( + "get_rooms_for_summary", _get_rooms_for_summary_txn + ) + + def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): + return self.runInteraction( + "add_room_to_summary", self._add_room_to_summary_txn, + group_id, room_id, category_id, order, is_public, + ) + + def _add_room_to_summary_txn(self, txn, group_id, room_id, category_id, order, + is_public): + """Add (or update) room's entry in summary. + + Args: + group_id (str) + room_id (str) + category_id (str): If not None then adds the category to the end of + the summary if its not already there. [Optional] + order (int): If not None inserts the room at that position, e.g. + an order of 1 will put the room first. Otherwise, the room gets + added to the end. 
+ """ + room_in_group = self._simple_select_one_onecol_txn( + txn, + table="group_rooms", + keyvalues={ + "group_id": group_id, + "room_id": room_id, + }, + retcol="room_id", + allow_none=True, + ) + if not room_in_group: + raise SynapseError(400, "room not in group") + + if category_id is None: + category_id = _DEFAULT_CATEGORY_ID + else: + cat_exists = self._simple_select_one_onecol_txn( + txn, + table="group_room_categories", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + }, + retcol="group_id", + allow_none=True, + ) + if not cat_exists: + raise SynapseError(400, "Category doesn't exist") + + # TODO: Check category is part of summary already + cat_exists = self._simple_select_one_onecol_txn( + txn, + table="group_summary_room_categories", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + }, + retcol="group_id", + allow_none=True, + ) + if not cat_exists: + # If not, add it with an order larger than all others + txn.execute(""" + INSERT INTO group_summary_room_categories + (group_id, category_id, cat_order) + SELECT ?, ?, COALESCE(MAX(cat_order), 0) + 1 + FROM group_summary_room_categories + WHERE group_id = ? AND category_id = ? + """, (group_id, category_id, group_id, category_id)) + + existing = self._simple_select_one_txn( + txn, + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "room_id": room_id, + "category_id": category_id, + }, + retcols=("room_order", "is_public",), + allow_none=True, + ) + + if order is not None: + # Shuffle other room orders that come after the given order + sql = """ + UPDATE group_summary_rooms SET room_order = room_order + 1 + WHERE group_id = ? AND category_id = ? AND room_order >= ? + """ + txn.execute(sql, (group_id, category_id, order,)) + elif not existing: + sql = """ + SELECT COALESCE(MAX(room_order), 0) + 1 FROM group_summary_rooms + WHERE group_id = ? AND category_id = ? 
+ """ + txn.execute(sql, (group_id, category_id,)) + order, = txn.fetchone() + + if existing: + to_update = {} + if order is not None: + to_update["room_order"] = order + if is_public is not None: + to_update["is_public"] = is_public + self._simple_update_txn( + txn, + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + "room_id": room_id, + }, + values=to_update, + ) + else: + if is_public is None: + is_public = True + + self._simple_insert_txn( + txn, + table="group_summary_rooms", + values={ + "group_id": group_id, + "category_id": category_id, + "room_id": room_id, + "room_order": order, + "is_public": is_public, + }, + ) + + def remove_room_from_summary(self, group_id, room_id, category_id): + if category_id is None: + category_id = _DEFAULT_CATEGORY_ID + + return self._simple_delete( + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + "room_id": room_id, + }, + desc="remove_room_from_summary", + ) + + @defer.inlineCallbacks + def get_group_categories(self, group_id): + rows = yield self._simple_select_list( + table="group_room_categories", + keyvalues={ + "group_id": group_id, + }, + retcols=("category_id", "is_public", "profile"), + desc="get_group_categories", + ) + + defer.returnValue({ + row["category_id"]: { + "is_public": row["is_public"], + "profile": json.loads(row["profile"]), + } + for row in rows + }) + + @defer.inlineCallbacks + def get_group_category(self, group_id, category_id): + category = yield self._simple_select_one( + table="group_room_categories", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + }, + retcols=("is_public", "profile"), + desc="get_group_category", + ) + + category["profile"] = json.loads(category["profile"]) + + defer.returnValue(category) + + def upsert_group_category(self, group_id, category_id, profile, is_public): + """Add/update room category for group + """ + insertion_values = {} + update_values = 
{"category_id": category_id} # This cannot be empty + + if profile is None: + insertion_values["profile"] = "{}" + else: + update_values["profile"] = json.dumps(profile) + + if is_public is None: + insertion_values["is_public"] = True + else: + update_values["is_public"] = is_public + + return self._simple_upsert( + table="group_room_categories", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + }, + values=update_values, + insertion_values=insertion_values, + desc="upsert_group_category", + ) + + def remove_group_category(self, group_id, category_id): + return self._simple_delete( + table="group_room_categories", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + }, + desc="remove_group_category", + ) + + @defer.inlineCallbacks + def get_group_roles(self, group_id): + rows = yield self._simple_select_list( + table="group_roles", + keyvalues={ + "group_id": group_id, + }, + retcols=("role_id", "is_public", "profile"), + desc="get_group_roles", + ) + + defer.returnValue({ + row["role_id"]: { + "is_public": row["is_public"], + "profile": json.loads(row["profile"]), + } + for row in rows + }) + + @defer.inlineCallbacks + def get_group_role(self, group_id, role_id): + role = yield self._simple_select_one( + table="group_roles", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + }, + retcols=("is_public", "profile"), + desc="get_group_role", + ) + + role["profile"] = json.loads(role["profile"]) + + defer.returnValue(role) + + def upsert_group_role(self, group_id, role_id, profile, is_public): + """Add/remove user role + """ + insertion_values = {} + update_values = {"role_id": role_id} # This cannot be empty + + if profile is None: + insertion_values["profile"] = "{}" + else: + update_values["profile"] = json.dumps(profile) + + if is_public is None: + insertion_values["is_public"] = True + else: + update_values["is_public"] = is_public + + return self._simple_upsert( + table="group_roles", + keyvalues={ + "group_id": 
group_id, + "role_id": role_id, + }, + values=update_values, + insertion_values=insertion_values, + desc="upsert_group_role", + ) + + def remove_group_role(self, group_id, role_id): + return self._simple_delete( + table="group_roles", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + }, + desc="remove_group_role", + ) + + def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): + return self.runInteraction( + "add_user_to_summary", self._add_user_to_summary_txn, + group_id, user_id, role_id, order, is_public, + ) + + def _add_user_to_summary_txn(self, txn, group_id, user_id, role_id, order, + is_public): + """Add (or update) user's entry in summary. + + Args: + group_id (str) + user_id (str) + role_id (str): If not None then adds the role to the end of + the summary if its not already there. [Optional] + order (int): If not None inserts the user at that position, e.g. + an order of 1 will put the user first. Otherwise, the user gets + added to the end. + """ + user_in_group = self._simple_select_one_onecol_txn( + txn, + table="group_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcol="user_id", + allow_none=True, + ) + if not user_in_group: + raise SynapseError(400, "user not in group") + + if role_id is None: + role_id = _DEFAULT_ROLE_ID + else: + role_exists = self._simple_select_one_onecol_txn( + txn, + table="group_roles", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + }, + retcol="group_id", + allow_none=True, + ) + if not role_exists: + raise SynapseError(400, "Role doesn't exist") + + # TODO: Check role is part of the summary already + role_exists = self._simple_select_one_onecol_txn( + txn, + table="group_summary_roles", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + }, + retcol="group_id", + allow_none=True, + ) + if not role_exists: + # If not, add it with an order larger than all others + txn.execute(""" + INSERT INTO group_summary_roles + (group_id, role_id, 
role_order) + SELECT ?, ?, COALESCE(MAX(role_order), 0) + 1 + FROM group_summary_roles + WHERE group_id = ? AND role_id = ? + """, (group_id, role_id, group_id, role_id)) + + existing = self._simple_select_one_txn( + txn, + table="group_summary_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + "role_id": role_id, + }, + retcols=("user_order", "is_public",), + allow_none=True, + ) + + if order is not None: + # Shuffle other users orders that come after the given order + sql = """ + UPDATE group_summary_users SET user_order = user_order + 1 + WHERE group_id = ? AND role_id = ? AND user_order >= ? + """ + txn.execute(sql, (group_id, role_id, order,)) + elif not existing: + sql = """ + SELECT COALESCE(MAX(user_order), 0) + 1 FROM group_summary_users + WHERE group_id = ? AND role_id = ? + """ + txn.execute(sql, (group_id, role_id,)) + order, = txn.fetchone() + + if existing: + to_update = {} + if order is not None: + to_update["user_order"] = order + if is_public is not None: + to_update["is_public"] = is_public + self._simple_update_txn( + txn, + table="group_summary_users", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + "user_id": user_id, + }, + values=to_update, + ) + else: + if is_public is None: + is_public = True + + self._simple_insert_txn( + txn, + table="group_summary_users", + values={ + "group_id": group_id, + "role_id": role_id, + "user_id": user_id, + "user_order": order, + "is_public": is_public, + }, + ) + + def remove_user_from_summary(self, group_id, user_id, role_id): + if role_id is None: + role_id = _DEFAULT_ROLE_ID + + return self._simple_delete( + table="group_summary_users", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + "user_id": user_id, + }, + desc="remove_user_from_summary", + ) + + def get_users_for_summary_by_role(self, group_id, include_private=False): + """Get the users and roles that should be included in a summary request + + Returns ([users], [roles]) + """ + def 
_get_users_for_summary_txn(txn): + keyvalues = { + "group_id": group_id, + } + if not include_private: + keyvalues["is_public"] = True + + sql = """ + SELECT user_id, is_public, role_id, user_order + FROM group_summary_users + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + users = [ + { + "user_id": row[0], + "is_public": row[1], + "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None, + "order": row[3], + } + for row in txn + ] + + sql = """ + SELECT role_id, is_public, profile, role_order + FROM group_summary_roles + INNER JOIN group_roles USING (group_id, role_id) + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + roles = { + row[0]: { + "is_public": row[1], + "profile": json.loads(row[2]), + "order": row[3], + } + for row in txn + } + + return users, roles + return self.runInteraction( + "get_users_for_summary_by_role", _get_users_for_summary_txn + ) + + def is_user_in_group(self, user_id, group_id): + return self._simple_select_one_onecol( + table="group_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcol="user_id", + allow_none=True, + desc="is_user_in_group", + ).addCallback(lambda r: bool(r)) + + def is_user_admin_in_group(self, group_id, user_id): + return self._simple_select_one_onecol( + table="group_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcol="is_admin", + allow_none=True, + desc="is_user_admin_in_group", + ) + + def add_group_invite(self, group_id, user_id): + """Record that the group server has invited a user + """ + return self._simple_insert( + table="group_invites", + values={ + "group_id": group_id, + "user_id": user_id, + }, + desc="add_group_invite", + ) + + def is_user_invited_to_local_group(self, group_id, user_id): + """Has the group server 
invited a user? + """ + return self._simple_select_one_onecol( + table="group_invites", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcol="user_id", + desc="is_user_invited_to_local_group", + allow_none=True, + ) + + def get_users_membership_info_in_group(self, group_id, user_id): + """Get a dict describing the membership of a user in a group. + + Example if joined: + + { + "membership": "join", + "is_public": True, + "is_privileged": False, + } + + Returns an empty dict if the user is not join/invite/etc + """ + def _get_users_membership_in_group_txn(txn): + row = self._simple_select_one_txn( + txn, + table="group_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcols=("is_admin", "is_public"), + allow_none=True, + ) + + if row: + return { + "membership": "join", + "is_public": row["is_public"], + "is_privileged": row["is_admin"], + } + + row = self._simple_select_one_onecol_txn( + txn, + table="group_invites", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcol="user_id", + allow_none=True, + ) + + if row: + return { + "membership": "invite", + } + + return {} + + return self.runInteraction( + "get_users_membership_info_in_group", _get_users_membership_in_group_txn, + ) + + def add_user_to_group(self, group_id, user_id, is_admin=False, is_public=True, + local_attestation=None, remote_attestation=None): + """Add a user to the group server. + + Args: + group_id (str) + user_id (str) + is_admin (bool) + is_public (bool) + local_attestation (dict): The attestation the GS created to give + to the remote server. Optional if the user and group are on the + same server + remote_attestation (dict): The attestation given to GS by remote + server. 
Optional if the user and group are on the same server + """ + def _add_user_to_group_txn(txn): + self._simple_insert_txn( + txn, + table="group_users", + values={ + "group_id": group_id, + "user_id": user_id, + "is_admin": is_admin, + "is_public": is_public, + }, + ) + + self._simple_delete_txn( + txn, + table="group_invites", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + + if local_attestation: + self._simple_insert_txn( + txn, + table="group_attestations_renewals", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": local_attestation["valid_until_ms"], + }, + ) + if remote_attestation: + self._simple_insert_txn( + txn, + table="group_attestations_remote", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": remote_attestation["valid_until_ms"], + "attestation_json": json.dumps(remote_attestation), + }, + ) + + return self.runInteraction( + "add_user_to_group", _add_user_to_group_txn + ) + + def remove_user_from_group(self, group_id, user_id): + def _remove_user_from_group_txn(txn): + self._simple_delete_txn( + txn, + table="group_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + self._simple_delete_txn( + txn, + table="group_invites", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + self._simple_delete_txn( + txn, + table="group_attestations_renewals", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + self._simple_delete_txn( + txn, + table="group_attestations_remote", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + self._simple_delete_txn( + txn, + table="group_summary_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + return self.runInteraction("remove_user_from_group", _remove_user_from_group_txn) + + def add_room_to_group(self, group_id, room_id, is_public): + return self._simple_insert( + table="group_rooms", + values={ + "group_id": group_id, + "room_id": room_id, 
+ "is_public": is_public, + }, + desc="add_room_to_group", + ) + + def update_room_in_group_visibility(self, group_id, room_id, is_public): + return self._simple_update( + table="group_rooms", + keyvalues={ + "group_id": group_id, + "room_id": room_id, + }, + updatevalues={ + "is_public": is_public, + }, + desc="update_room_in_group_visibility", + ) + + def remove_room_from_group(self, group_id, room_id): + def _remove_room_from_group_txn(txn): + self._simple_delete_txn( + txn, + table="group_rooms", + keyvalues={ + "group_id": group_id, + "room_id": room_id, + }, + ) + + self._simple_delete_txn( + txn, + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "room_id": room_id, + }, + ) + return self.runInteraction( + "remove_room_from_group", _remove_room_from_group_txn, + ) + + def get_publicised_groups_for_user(self, user_id): + """Get all groups a user is publicising + """ + return self._simple_select_onecol( + table="local_group_membership", + keyvalues={ + "user_id": user_id, + "membership": "join", + "is_publicised": True, + }, + retcol="group_id", + desc="get_publicised_groups_for_user", + ) + + def update_group_publicity(self, group_id, user_id, publicise): + """Update whether the user is publicising their membership of the group + """ + return self._simple_update_one( + table="local_group_membership", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + updatevalues={ + "is_publicised": publicise, + }, + desc="update_group_publicity" + ) + + @defer.inlineCallbacks + def register_user_group_membership(self, group_id, user_id, membership, + is_admin=False, content={}, + local_attestation=None, + remote_attestation=None, + is_publicised=False, + ): + """Registers that a local user is a member of a (local or remote) group. + + Args: + group_id (str) + user_id (str) + membership (str) + is_admin (bool) + content (dict): Content of the membership, e.g. includes the inviter + if the user has been invited. 
+ local_attestation (dict): If remote group then store the fact that we + have given out an attestation, else None. + remote_attestation (dict): If remote group then store the remote + attestation from the group, else None. + """ + def _register_user_group_membership_txn(txn, next_id): + # TODO: Upsert? + self._simple_delete_txn( + txn, + table="local_group_membership", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + self._simple_insert_txn( + txn, + table="local_group_membership", + values={ + "group_id": group_id, + "user_id": user_id, + "is_admin": is_admin, + "membership": membership, + "is_publicised": is_publicised, + "content": json.dumps(content), + }, + ) + + self._simple_insert_txn( + txn, + table="local_group_updates", + values={ + "stream_id": next_id, + "group_id": group_id, + "user_id": user_id, + "type": "membership", + "content": json.dumps({"membership": membership, "content": content}), + } + ) + self._group_updates_stream_cache.entity_has_changed(user_id, next_id) + + # TODO: Insert profile to ensure it comes down stream if its a join. 
+ + if membership == "join": + if local_attestation: + self._simple_insert_txn( + txn, + table="group_attestations_renewals", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": local_attestation["valid_until_ms"], + } + ) + if remote_attestation: + self._simple_insert_txn( + txn, + table="group_attestations_remote", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": remote_attestation["valid_until_ms"], + "attestation_json": json.dumps(remote_attestation), + } + ) + else: + self._simple_delete_txn( + txn, + table="group_attestations_renewals", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + self._simple_delete_txn( + txn, + table="group_attestations_remote", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + + return next_id + + with self._group_updates_id_gen.get_next() as next_id: + res = yield self.runInteraction( + "register_user_group_membership", + _register_user_group_membership_txn, next_id, + ) + defer.returnValue(res) + + @defer.inlineCallbacks + def create_group(self, group_id, user_id, name, avatar_url, short_description, + long_description,): + yield self._simple_insert( + table="groups", + values={ + "group_id": group_id, + "name": name, + "avatar_url": avatar_url, + "short_description": short_description, + "long_description": long_description, + "is_public": True, + }, + desc="create_group", + ) + + @defer.inlineCallbacks + def update_group_profile(self, group_id, profile,): + yield self._simple_update_one( + table="groups", + keyvalues={ + "group_id": group_id, + }, + updatevalues=profile, + desc="update_group_profile", + ) + + def get_attestations_need_renewals(self, valid_until_ms): + """Get all attestations that need to be renewed until given time + """ + def _get_attestations_need_renewals_txn(txn): + sql = """ + SELECT group_id, user_id FROM group_attestations_renewals + WHERE valid_until_ms <= ? 
+ """ + txn.execute(sql, (valid_until_ms,)) + return self.cursor_to_dict(txn) + return self.runInteraction( + "get_attestations_need_renewals", _get_attestations_need_renewals_txn + ) + + def update_attestation_renewal(self, group_id, user_id, attestation): + """Update an attestation that we have renewed + """ + return self._simple_update_one( + table="group_attestations_renewals", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + updatevalues={ + "valid_until_ms": attestation["valid_until_ms"], + }, + desc="update_attestation_renewal", + ) + + def update_remote_attestion(self, group_id, user_id, attestation): + """Update an attestation that a remote has renewed + """ + return self._simple_update_one( + table="group_attestations_remote", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + updatevalues={ + "valid_until_ms": attestation["valid_until_ms"], + "attestation_json": json.dumps(attestation) + }, + desc="update_remote_attestion", + ) + + def remove_attestation_renewal(self, group_id, user_id): + """Remove an attestation that we thought we should renew, but actually + shouldn't. Ideally this would never get called as we would never + incorrectly try and do attestations for local users on local groups. + + Args: + group_id (str) + user_id (str) + """ + return self._simple_delete( + table="group_attestations_renewals", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + desc="remove_attestation_renewal", + ) + + @defer.inlineCallbacks + def get_remote_attestation(self, group_id, user_id): + """Get the attestation that proves the remote agrees that the user is + in the group. 
+ """ + row = yield self._simple_select_one( + table="group_attestations_remote", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcols=("valid_until_ms", "attestation_json"), + desc="get_remote_attestation", + allow_none=True, + ) + + now = int(self._clock.time_msec()) + if row and now < row["valid_until_ms"]: + defer.returnValue(json.loads(row["attestation_json"])) + + defer.returnValue(None) + + def get_joined_groups(self, user_id): + return self._simple_select_onecol( + table="local_group_membership", + keyvalues={ + "user_id": user_id, + "membership": "join", + }, + retcol="group_id", + desc="get_joined_groups", + ) + + def get_all_groups_for_user(self, user_id, now_token): + def _get_all_groups_for_user_txn(txn): + sql = """ + SELECT group_id, type, membership, u.content + FROM local_group_updates AS u + INNER JOIN local_group_membership USING (group_id, user_id) + WHERE user_id = ? AND membership != 'leave' + AND stream_id <= ? + """ + txn.execute(sql, (user_id, now_token,)) + return [ + { + "group_id": row[0], + "type": row[1], + "membership": row[2], + "content": json.loads(row[3]), + } + for row in txn + ] + return self.runInteraction( + "get_all_groups_for_user", _get_all_groups_for_user_txn, + ) + + def get_groups_changes_for_user(self, user_id, from_token, to_token): + from_token = int(from_token) + has_changed = self._group_updates_stream_cache.has_entity_changed( + user_id, from_token, + ) + if not has_changed: + return [] + + def _get_groups_changes_for_user_txn(txn): + sql = """ + SELECT group_id, membership, type, u.content + FROM local_group_updates AS u + INNER JOIN local_group_membership USING (group_id, user_id) + WHERE user_id = ? AND ? < stream_id AND stream_id <= ? 
+ """ + txn.execute(sql, (user_id, from_token, to_token,)) + return [{ + "group_id": group_id, + "membership": membership, + "type": gtype, + "content": json.loads(content_json), + } for group_id, membership, gtype, content_json in txn] + return self.runInteraction( + "get_groups_changes_for_user", _get_groups_changes_for_user_txn, + ) + + def get_all_groups_changes(self, from_token, to_token, limit): + from_token = int(from_token) + has_changed = self._group_updates_stream_cache.has_any_entity_changed( + from_token, + ) + if not has_changed: + return [] + + def _get_all_groups_changes_txn(txn): + sql = """ + SELECT stream_id, group_id, user_id, type, content + FROM local_group_updates + WHERE ? < stream_id AND stream_id <= ? + LIMIT ? + """ + txn.execute(sql, (from_token, to_token, limit,)) + return [( + stream_id, + group_id, + user_id, + gtype, + json.loads(content_json), + ) for stream_id, group_id, user_id, gtype, content_json in txn] + return self.runInteraction( + "get_all_groups_changes", _get_all_groups_changes_txn, + ) + + def get_group_stream_token(self): + return self._group_updates_id_gen.get_current_token() diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 86b37b9ddd..87aeaf71d6 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -101,9 +101,10 @@ class KeyStore(SQLBaseStore): key_ids Args: server_name (str): The name of the server. - key_ids (list of str): List of key_ids to try and look up. + key_ids (iterable[str]): key_ids to try and look up. Returns: - (list of VerifyKey): The verification keys. + Deferred: resolves to dict[str, VerifyKey]: map from + key_id to verification key. """ keys = {} for key_id in key_ids: @@ -112,30 +113,37 @@ class KeyStore(SQLBaseStore): keys[key_id] = key defer.returnValue(keys) - @defer.inlineCallbacks def store_server_verify_key(self, server_name, from_server, time_now_ms, verify_key): """Stores a NACL verification key for the given server. 
Args: server_name (str): The name of the server. - key_id (str): The version of the key for the server. from_server (str): Where the verification key was looked up - ts_now_ms (int): The time now in milliseconds - verification_key (VerifyKey): The NACL verify key. + time_now_ms (int): The time now in milliseconds + verify_key (nacl.signing.VerifyKey): The NACL verify key. """ - yield self._simple_upsert( - table="server_signature_keys", - keyvalues={ - "server_name": server_name, - "key_id": "%s:%s" % (verify_key.alg, verify_key.version), - }, - values={ - "from_server": from_server, - "ts_added_ms": time_now_ms, - "verify_key": buffer(verify_key.encode()), - }, - desc="store_server_verify_key", - ) + key_id = "%s:%s" % (verify_key.alg, verify_key.version) + + def _txn(txn): + self._simple_upsert_txn( + txn, + table="server_signature_keys", + keyvalues={ + "server_name": server_name, + "key_id": key_id, + }, + values={ + "from_server": from_server, + "ts_added_ms": time_now_ms, + "verify_key": buffer(verify_key.encode()), + }, + ) + txn.call_after( + self._get_server_verify_key.invalidate, + (server_name, key_id) + ) + + return self.runInteraction("store_server_verify_key", _txn) def store_server_keys_json(self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes): diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py index 4c0f82353d..e6cdbb0545 100644 --- a/synapse/storage/media_repository.py +++ b/synapse/storage/media_repository.py @@ -12,15 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from synapse.storage.background_updates import BackgroundUpdateStore -from ._base import SQLBaseStore - -class MediaRepositoryStore(SQLBaseStore): +class MediaRepositoryStore(BackgroundUpdateStore): """Persistence for attachments and avatars""" - def get_default_thumbnails(self, top_level_type, sub_type): - return [] + def __init__(self, db_conn, hs): + super(MediaRepositoryStore, self).__init__(db_conn, hs) + + self.register_background_index_update( + update_name='local_media_repository_url_idx', + index_name='local_media_repository_url_idx', + table='local_media_repository', + columns=['created_ts'], + where_clause='url_cache IS NOT NULL', + ) def get_local_media(self, media_id): """Get the metadata for a local piece of media @@ -30,13 +37,16 @@ class MediaRepositoryStore(SQLBaseStore): return self._simple_select_one( "local_media_repository", {"media_id": media_id}, - ("media_type", "media_length", "upload_name", "created_ts"), + ( + "media_type", "media_length", "upload_name", "created_ts", + "quarantined_by", "url_cache", + ), allow_none=True, desc="get_local_media", ) def store_local_media(self, media_id, media_type, time_now_ms, upload_name, - media_length, user_id): + media_length, user_id, url_cache=None): return self._simple_insert( "local_media_repository", { @@ -46,6 +56,7 @@ class MediaRepositoryStore(SQLBaseStore): "upload_name": upload_name, "media_length": media_length, "user_id": user_id.to_string(), + "url_cache": url_cache, }, desc="store_local_media", ) @@ -58,7 +69,7 @@ class MediaRepositoryStore(SQLBaseStore): def get_url_cache_txn(txn): # get the most recently cached result (relative to the given ts) sql = ( - "SELECT response_code, etag, expires, og, media_id, download_ts" + "SELECT response_code, etag, expires_ts, og, media_id, download_ts" " FROM local_media_repository_url_cache" " WHERE url = ? AND download_ts <= ?" 
" ORDER BY download_ts DESC LIMIT 1" @@ -70,7 +81,7 @@ class MediaRepositoryStore(SQLBaseStore): # ...or if we've requested a timestamp older than the oldest # copy in the cache, return the oldest copy (if any) sql = ( - "SELECT response_code, etag, expires, og, media_id, download_ts" + "SELECT response_code, etag, expires_ts, og, media_id, download_ts" " FROM local_media_repository_url_cache" " WHERE url = ? AND download_ts > ?" " ORDER BY download_ts ASC LIMIT 1" @@ -82,14 +93,14 @@ class MediaRepositoryStore(SQLBaseStore): return None return dict(zip(( - 'response_code', 'etag', 'expires', 'og', 'media_id', 'download_ts' + 'response_code', 'etag', 'expires_ts', 'og', 'media_id', 'download_ts' ), row)) return self.runInteraction( "get_url_cache", get_url_cache_txn ) - def store_url_cache(self, url, response_code, etag, expires, og, media_id, + def store_url_cache(self, url, response_code, etag, expires_ts, og, media_id, download_ts): return self._simple_insert( "local_media_repository_url_cache", @@ -97,7 +108,7 @@ class MediaRepositoryStore(SQLBaseStore): "url": url, "response_code": response_code, "etag": etag, - "expires": expires, + "expires_ts": expires_ts, "og": og, "media_id": media_id, "download_ts": download_ts, @@ -138,7 +149,7 @@ class MediaRepositoryStore(SQLBaseStore): {"media_origin": origin, "media_id": media_id}, ( "media_type", "media_length", "upload_name", "created_ts", - "filesystem_id", + "filesystem_id", "quarantined_by", ), allow_none=True, desc="get_cached_remote_media", @@ -162,7 +173,14 @@ class MediaRepositoryStore(SQLBaseStore): desc="store_cached_remote_media", ) - def update_cached_last_access_time(self, origin_id_tuples, time_ts): + def update_cached_last_access_time(self, local_media, remote_media, time_ms): + """Updates the last access time of the given media + + Args: + local_media (iterable[str]): Set of media_ids + remote_media (iterable[(str, str)]): Set of (server_name, media_id) + time_ms: Current time in milliseconds + """ 
def update_cache_txn(txn): sql = ( "UPDATE remote_media_cache SET last_access_ts = ?" @@ -170,8 +188,18 @@ class MediaRepositoryStore(SQLBaseStore): ) txn.executemany(sql, ( - (time_ts, media_origin, media_id) - for media_origin, media_id in origin_id_tuples + (time_ms, media_origin, media_id) + for media_origin, media_id in remote_media + )) + + sql = ( + "UPDATE local_media_repository SET last_access_ts = ?" + " WHERE media_id = ?" + ) + + txn.executemany(sql, ( + (time_ms, media_id) + for media_id in local_media )) return self.runInteraction("update_cached_last_access_time", update_cache_txn) @@ -234,3 +262,70 @@ class MediaRepositoryStore(SQLBaseStore): }, ) return self.runInteraction("delete_remote_media", delete_remote_media_txn) + + def get_expired_url_cache(self, now_ts): + sql = ( + "SELECT media_id FROM local_media_repository_url_cache" + " WHERE expires_ts < ?" + " ORDER BY expires_ts ASC" + " LIMIT 500" + ) + + def _get_expired_url_cache_txn(txn): + txn.execute(sql, (now_ts,)) + return [row[0] for row in txn] + + return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn) + + def delete_url_cache(self, media_ids): + if len(media_ids) == 0: + return + + sql = ( + "DELETE FROM local_media_repository_url_cache" + " WHERE media_id = ?" + ) + + def _delete_url_cache_txn(txn): + txn.executemany(sql, [(media_id,) for media_id in media_ids]) + + return self.runInteraction("delete_url_cache", _delete_url_cache_txn) + + def get_url_cache_media_before(self, before_ts): + sql = ( + "SELECT media_id FROM local_media_repository" + " WHERE created_ts < ? 
AND url_cache IS NOT NULL" + " ORDER BY created_ts ASC" + " LIMIT 500" + ) + + def _get_url_cache_media_before_txn(txn): + txn.execute(sql, (before_ts,)) + return [row[0] for row in txn] + + return self.runInteraction( + "get_url_cache_media_before", _get_url_cache_media_before_txn, + ) + + def delete_url_cache_media(self, media_ids): + if len(media_ids) == 0: + return + + def _delete_url_cache_media_txn(txn): + sql = ( + "DELETE FROM local_media_repository" + " WHERE media_id = ?" + ) + + txn.executemany(sql, [(media_id,) for media_id in media_ids]) + + sql = ( + "DELETE FROM local_media_repository_thumbnails" + " WHERE media_id = ?" + ) + + txn.executemany(sql, [(media_id,) for media_id in media_ids]) + + return self.runInteraction( + "delete_url_cache_media", _delete_url_cache_media_txn, + ) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index b357f22be7..04411a665f 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,7 +26,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 40 +SCHEMA_VERSION = 48 dir_path = os.path.abspath(os.path.dirname(__file__)) @@ -44,6 +45,13 @@ def prepare_database(db_conn, database_engine, config): If `config` is None then prepare_database will assert that no upgrade is necessary, *or* will create a fresh database if the database is empty. 
+ + Args: + db_conn: + database_engine: + config (synapse.config.homeserver.HomeServerConfig|None): + application config, or None if we are connecting to an existing + database which we expect to be configured already """ try: cur = db_conn.cursor() @@ -64,9 +72,13 @@ def prepare_database(db_conn, database_engine, config): else: _setup_new_database(cur, database_engine) + # check if any of our configured dynamic modules want a database + if config is not None: + _apply_module_schemas(cur, database_engine, config) + cur.close() db_conn.commit() - except: + except Exception: db_conn.rollback() raise @@ -283,6 +295,65 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, ) +def _apply_module_schemas(txn, database_engine, config): + """Apply the module schemas for the dynamic modules, if any + + Args: + cur: database cursor + database_engine: synapse database engine class + config (synapse.config.homeserver.HomeServerConfig): + application config + """ + for (mod, _config) in config.password_providers: + if not hasattr(mod, 'get_db_schema_files'): + continue + modname = ".".join((mod.__module__, mod.__name__)) + _apply_module_schema_files( + txn, database_engine, modname, mod.get_db_schema_files(), + ) + + +def _apply_module_schema_files(cur, database_engine, modname, names_and_streams): + """Apply the module schemas for a single module + + Args: + cur: database cursor + database_engine: synapse database engine class + modname (str): fully qualified name of the module + names_and_streams (Iterable[(str, file)]): the names and streams of + schemas to be applied + """ + cur.execute( + database_engine.convert_param_style( + "SELECT file FROM applied_module_schemas WHERE module_name = ?" 
+ ), + (modname,) + ) + applied_deltas = set(d for d, in cur) + for (name, stream) in names_and_streams: + if name in applied_deltas: + continue + + root_name, ext = os.path.splitext(name) + if ext != '.sql': + raise PrepareDatabaseException( + "only .sql files are currently supported for module schemas", + ) + + logger.info("applying schema %s for %s", name, modname) + for statement in get_statements(stream): + cur.execute(statement) + + # Mark as done. + cur.execute( + database_engine.convert_param_style( + "INSERT INTO applied_module_schemas (module_name, file)" + " VALUES (?,?)", + ), + (modname, name) + ) + + def get_statements(f): statement_buffer = "" in_comment = False # If we're in a /* ... */ style comment @@ -356,7 +427,7 @@ def _get_or_create_schema_state(txn, database_engine): ), (current_version,) ) - applied_deltas = [d for d, in txn.fetchall()] + applied_deltas = [d for d, in txn] return current_version, applied_deltas, upgraded return None diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 4d1590d2b4..9e9d3c2591 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -85,8 +85,8 @@ class PresenceStore(SQLBaseStore): self.presence_stream_cache.entity_has_changed, state.user_id, stream_id, ) - self._invalidate_cache_and_stream( - txn, self._get_presence_for_user, (state.user_id,) + txn.call_after( + self._get_presence_for_user.invalidate, (state.user_id,) ) # Actually insert new rows diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py index 26a40905ae..8612bd5ecc 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py @@ -13,15 +13,37 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from twisted.internet import defer + +from synapse.storage.roommember import ProfileInfo +from synapse.api.errors import StoreError + from ._base import SQLBaseStore -class ProfileStore(SQLBaseStore): - def create_profile(self, user_localpart): - return self._simple_insert( - table="profiles", - values={"user_id": user_localpart}, - desc="create_profile", +class ProfileWorkerStore(SQLBaseStore): + @defer.inlineCallbacks + def get_profileinfo(self, user_localpart): + try: + profile = yield self._simple_select_one( + table="profiles", + keyvalues={"user_id": user_localpart}, + retcols=("displayname", "avatar_url"), + desc="get_profileinfo", + ) + except StoreError as e: + if e.code == 404: + # no match + defer.returnValue(ProfileInfo(None, None)) + return + else: + raise + + defer.returnValue( + ProfileInfo( + avatar_url=profile['avatar_url'], + display_name=profile['displayname'], + ) ) def get_profile_displayname(self, user_localpart): @@ -32,14 +54,6 @@ class ProfileStore(SQLBaseStore): desc="get_profile_displayname", ) - def set_profile_displayname(self, user_localpart, new_displayname): - return self._simple_update_one( - table="profiles", - keyvalues={"user_id": user_localpart}, - updatevalues={"displayname": new_displayname}, - desc="set_profile_displayname", - ) - def get_profile_avatar_url(self, user_localpart): return self._simple_select_one_onecol( table="profiles", @@ -48,6 +62,32 @@ class ProfileStore(SQLBaseStore): desc="get_profile_avatar_url", ) + def get_from_remote_profile_cache(self, user_id): + return self._simple_select_one( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + retcols=("displayname", "avatar_url",), + allow_none=True, + desc="get_from_remote_profile_cache", + ) + + +class ProfileStore(ProfileWorkerStore): + def create_profile(self, user_localpart): + return self._simple_insert( + table="profiles", + values={"user_id": user_localpart}, + desc="create_profile", + ) + + def set_profile_displayname(self, 
user_localpart, new_displayname): + return self._simple_update_one( + table="profiles", + keyvalues={"user_id": user_localpart}, + updatevalues={"displayname": new_displayname}, + desc="set_profile_displayname", + ) + def set_profile_avatar_url(self, user_localpart, new_avatar_url): return self._simple_update_one( table="profiles", @@ -55,3 +95,90 @@ class ProfileStore(SQLBaseStore): updatevalues={"avatar_url": new_avatar_url}, desc="set_profile_avatar_url", ) + + def add_remote_profile_cache(self, user_id, displayname, avatar_url): + """Ensure we are caching the remote user's profiles. + + This should only be called when `is_subscribed_remote_profile_for_user` + would return true for the user. + """ + return self._simple_upsert( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + values={ + "displayname": displayname, + "avatar_url": avatar_url, + "last_check": self._clock.time_msec(), + }, + desc="add_remote_profile_cache", + ) + + def update_remote_profile_cache(self, user_id, displayname, avatar_url): + return self._simple_update( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + values={ + "displayname": displayname, + "avatar_url": avatar_url, + "last_check": self._clock.time_msec(), + }, + desc="update_remote_profile_cache", + ) + + @defer.inlineCallbacks + def maybe_delete_remote_profile_cache(self, user_id): + """Check if we still care about the remote user's profile, and if we + don't then remove their profile from the cache + """ + subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) + if not subscribed: + yield self._simple_delete( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + desc="delete_remote_profile_cache", + ) + + def get_remote_profile_cache_entries_that_expire(self, last_checked): + """Get all users who haven't been checked since `last_checked` + """ + def _get_remote_profile_cache_entries_that_expire_txn(txn): + sql = """ + SELECT user_id, displayname, avatar_url + FROM 
remote_profile_cache + WHERE last_check < ? + """ + + txn.execute(sql, (last_checked,)) + + return self.cursor_to_dict(txn) + + return self.runInteraction( + "get_remote_profile_cache_entries_that_expire", + _get_remote_profile_cache_entries_that_expire_txn, + ) + + @defer.inlineCallbacks + def is_subscribed_remote_profile_for_user(self, user_id): + """Check whether we are interested in a remote user's profile. + """ + res = yield self._simple_select_one_onecol( + table="group_users", + keyvalues={"user_id": user_id}, + retcol="user_id", + allow_none=True, + desc="should_update_remote_profile_cache_for_user", + ) + + if res: + defer.returnValue(True) + + res = yield self._simple_select_one_onecol( + table="group_invites", + keyvalues={"user_id": user_id}, + retcol="user_id", + allow_none=True, + desc="should_update_remote_profile_cache_for_user", + ) + + if res: + defer.returnValue(True) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index cbec255966..04a0b59a39 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,10 +15,17 @@ # limitations under the License. 
from ._base import SQLBaseStore +from synapse.storage.appservice import ApplicationServiceWorkerStore +from synapse.storage.pusher import PusherWorkerStore +from synapse.storage.receipts import ReceiptsWorkerStore +from synapse.storage.roommember import RoomMemberWorkerStore from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList +from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.push.baserules import list_with_base_rules +from synapse.api.constants import EventTypes from twisted.internet import defer +import abc import logging import simplejson as json @@ -47,8 +55,44 @@ def _load_rules(rawrules, enabled_map): return rules -class PushRuleStore(SQLBaseStore): - @cachedInlineCallbacks() +class PushRulesWorkerStore(ApplicationServiceWorkerStore, + ReceiptsWorkerStore, + PusherWorkerStore, + RoomMemberWorkerStore, + SQLBaseStore): + """This is an abstract base class where subclasses must implement + `get_max_push_rules_stream_id` which can be called in the initializer. + """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. + __metaclass__ = abc.ABCMeta + + def __init__(self, db_conn, hs): + super(PushRulesWorkerStore, self).__init__(db_conn, hs) + + push_rules_prefill, push_rules_id = self._get_cache_dict( + db_conn, "push_rules_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=self.get_max_push_rules_stream_id(), + ) + + self.push_rules_stream_cache = StreamChangeCache( + "PushRulesStreamChangeCache", push_rules_id, + prefilled_cache=push_rules_prefill, + ) + + @abc.abstractmethod + def get_max_push_rules_stream_id(self): + """Get the position of the push rules stream. 
+ + Returns: + int + """ + raise NotImplementedError() + + @cachedInlineCallbacks(max_entries=5000) def get_push_rules_for_user(self, user_id): rows = yield self._simple_select_list( table="push_rules", @@ -72,7 +116,7 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(rules) - @cachedInlineCallbacks() + @cachedInlineCallbacks(max_entries=5000) def get_push_rules_enabled_for_user(self, user_id): results = yield self._simple_select_list( table="push_rules_enable", @@ -88,6 +132,22 @@ class PushRuleStore(SQLBaseStore): r['rule_id']: False if r['enabled'] == 0 else True for r in results }) + def have_push_rules_changed_for_user(self, user_id, last_id): + if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): + return defer.succeed(False) + else: + def have_push_rules_changed_txn(txn): + sql = ( + "SELECT COUNT(stream_id) FROM push_rules_stream" + " WHERE user_id = ? AND ? < stream_id" + ) + txn.execute(sql, (user_id, last_id)) + count, = txn.fetchone() + return bool(count) + return self.runInteraction( + "have_push_rules_changed", have_push_rules_changed_txn + ) + @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1, inlineCallbacks=True) def bulk_get_push_rules(self, user_ids): @@ -184,6 +244,18 @@ class PushRuleStore(SQLBaseStore): if uid in local_users_in_room: user_ids.add(uid) + forgotten = yield self.who_forgot_in_room( + event.room_id, on_invalidate=cache_context.invalidate, + ) + + for row in forgotten: + user_id = row["user_id"] + event_id = row["event_id"] + + mem_id = current_state_ids.get((EventTypes.Member, user_id), None) + if event_id == mem_id: + user_ids.discard(user_id) + rules_by_user = yield self.bulk_get_push_rules( user_ids, on_invalidate=cache_context.invalidate, ) @@ -215,6 +287,8 @@ class PushRuleStore(SQLBaseStore): results.setdefault(row['user_name'], {})[row['rule_id']] = enabled defer.returnValue(results) + +class PushRuleStore(PushRulesWorkerStore): @defer.inlineCallbacks def 
add_push_rule( self, user_id, rule_id, priority_class, conditions, actions, @@ -513,21 +587,8 @@ class PushRuleStore(SQLBaseStore): room stream ordering it corresponds to.""" return self._push_rules_stream_id_gen.get_current_token() - def have_push_rules_changed_for_user(self, user_id, last_id): - if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): - return defer.succeed(False) - else: - def have_push_rules_changed_txn(txn): - sql = ( - "SELECT COUNT(stream_id) FROM push_rules_stream" - " WHERE user_id = ? AND ? < stream_id" - ) - txn.execute(sql, (user_id, last_id)) - count, = txn.fetchone() - return bool(count) - return self.runInteraction( - "have_push_rules_changed", have_push_rules_changed_txn - ) + def get_max_push_rules_stream_id(self): + return self.get_push_rules_stream_token()[0] class RuleNotFoundException(Exception): diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 8cc9f0353b..307660b99a 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -27,7 +28,7 @@ import types logger = logging.getLogger(__name__) -class PusherStore(SQLBaseStore): +class PusherWorkerStore(SQLBaseStore): def _decode_pushers_rows(self, rows): for r in rows: dataJson = r['data'] @@ -102,9 +103,6 @@ class PusherStore(SQLBaseStore): rows = yield self.runInteraction("get_all_pushers", get_pushers) defer.returnValue(rows) - def get_pushers_stream_token(self): - return self._pushers_id_gen.get_current_token() - def get_all_updated_pushers(self, last_id, current_id, limit): if last_id == current_id: return defer.succeed(([], [])) @@ -135,6 +133,48 @@ class PusherStore(SQLBaseStore): "get_all_updated_pushers", get_all_updated_pushers_txn ) + def get_all_updated_pushers_rows(self, last_id, current_id, limit): + """Get all the pushers that have changed between the given tokens. + + Returns: + Deferred(list(tuple)): each tuple consists of: + stream_id (str) + user_id (str) + app_id (str) + pushkey (str) + was_deleted (bool): whether the pusher was added/updated (False) + or deleted (True) + """ + + if last_id == current_id: + return defer.succeed([]) + + def get_all_updated_pushers_rows_txn(txn): + sql = ( + "SELECT id, user_name, app_id, pushkey" + " FROM pushers" + " WHERE ? < id AND id <= ?" + " ORDER BY id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + results = [list(row) + [False] for row in txn] + + sql = ( + "SELECT stream_id, user_id, app_id, pushkey" + " FROM deleted_pushers" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" 
+ ) + txn.execute(sql, (last_id, current_id, limit)) + + results.extend(list(row) + [True] for row in txn) + results.sort() # Sort so that they're ordered by stream id + + return results + return self.runInteraction( + "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn + ) + @cachedInlineCallbacks(num_args=1, max_entries=15000) def get_if_user_has_pusher(self, user_id): # This only exists for the cachedList decorator @@ -156,56 +196,74 @@ class PusherStore(SQLBaseStore): defer.returnValue(result) + +class PusherStore(PusherWorkerStore): + def get_pushers_stream_token(self): + return self._pushers_id_gen.get_current_token() + @defer.inlineCallbacks def add_pusher(self, user_id, access_token, kind, app_id, app_display_name, device_display_name, pushkey, pushkey_ts, lang, data, last_stream_ordering, profile_tag=""): with self._pushers_id_gen.get_next() as stream_id: - def f(txn): - newly_inserted = self._simple_upsert_txn( - txn, - "pushers", - { - "app_id": app_id, - "pushkey": pushkey, - "user_name": user_id, - }, - { - "access_token": access_token, - "kind": kind, - "app_display_name": app_display_name, - "device_display_name": device_display_name, - "ts": pushkey_ts, - "lang": lang, - "data": encode_canonical_json(data), - "last_stream_ordering": last_stream_ordering, - "profile_tag": profile_tag, - "id": stream_id, - }, - ) - if newly_inserted: - # get_if_user_has_pusher only cares if the user has - # at least *one* pusher. 
- txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,)) + # no need to lock because `pushers` has a unique key on + # (app_id, pushkey, user_name) so _simple_upsert will retry + newly_inserted = yield self._simple_upsert( + table="pushers", + keyvalues={ + "app_id": app_id, + "pushkey": pushkey, + "user_name": user_id, + }, + values={ + "access_token": access_token, + "kind": kind, + "app_display_name": app_display_name, + "device_display_name": device_display_name, + "ts": pushkey_ts, + "lang": lang, + "data": encode_canonical_json(data), + "last_stream_ordering": last_stream_ordering, + "profile_tag": profile_tag, + "id": stream_id, + }, + desc="add_pusher", + lock=False, + ) - yield self.runInteraction("add_pusher", f) + if newly_inserted: + self.runInteraction( + "add_pusher", + self._invalidate_cache_and_stream, + self.get_if_user_has_pusher, (user_id,) + ) @defer.inlineCallbacks def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id): def delete_pusher_txn(txn, stream_id): - txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,)) + self._invalidate_cache_and_stream( + txn, self.get_if_user_has_pusher, (user_id,) + ) self._simple_delete_one_txn( txn, "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id} ) - self._simple_upsert_txn( + + # it's possible for us to end up with duplicate rows for + # (app_id, pushkey, user_id) at different stream_ids, but that + # doesn't really matter. 
+ self._simple_insert_txn( txn, - "deleted_pushers", - {"app_id": app_id, "pushkey": pushkey, "user_id": user_id}, - {"stream_id": stream_id}, + table="deleted_pushers", + values={ + "stream_id": stream_id, + "app_id": app_id, + "pushkey": pushkey, + "user_id": user_id, + }, ) with self._pushers_id_gen.get_next() as stream_id: @@ -268,9 +326,12 @@ class PusherStore(SQLBaseStore): @defer.inlineCallbacks def set_throttle_params(self, pusher_id, room_id, params): + # no need to lock because `pusher_throttle` has a primary key on + # (pusher, room_id) so _simple_upsert will retry yield self._simple_upsert( "pusher_throttle", {"pusher": pusher_id, "room_id": room_id}, params, - desc="set_throttle_params" + desc="set_throttle_params", + lock=False, ) diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index f72d15f5ed..63997ed449 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,46 +15,50 @@ # limitations under the License. from ._base import SQLBaseStore +from .util.id_generators import StreamIdGenerator from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached from synapse.util.caches.stream_change_cache import StreamChangeCache from twisted.internet import defer +import abc import logging -import ujson as json +import simplejson as json logger = logging.getLogger(__name__) -class ReceiptsStore(SQLBaseStore): - def __init__(self, hs): - super(ReceiptsStore, self).__init__(hs) +class ReceiptsWorkerStore(SQLBaseStore): + """This is an abstract base class where subclasses must implement + `get_max_receipt_stream_id` which can be called in the initializer. 
+ """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. + __metaclass__ = abc.ABCMeta + + def __init__(self, db_conn, hs): + super(ReceiptsWorkerStore, self).__init__(db_conn, hs) self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() + "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() ) + @abc.abstractmethod + def get_max_receipt_stream_id(self): + """Get the current max stream ID for receipts stream + + Returns: + int + """ + raise NotImplementedError() + @cachedInlineCallbacks() def get_users_with_read_receipts_in_room(self, room_id): receipts = yield self.get_receipts_for_room(room_id, "m.read") defer.returnValue(set(r['user_id'] for r in receipts)) - def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type, - user_id): - if receipt_type != "m.read": - return - - # Returns an ObservableDeferred - res = self.get_users_with_read_receipts_in_room.cache.get((room_id,), None) - - if res and res.called and user_id in res.result: - # We'd only be adding to the set, so no point invalidating if the - # user is already there - return - - self.get_users_with_read_receipts_in_room.invalidate((room_id,)) - @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): return self._simple_select_list( @@ -265,6 +270,59 @@ class ReceiptsStore(SQLBaseStore): } defer.returnValue(results) + def get_all_updated_receipts(self, last_id, current_id, limit=None): + if last_id == current_id: + return defer.succeed([]) + + def get_all_updated_receipts_txn(txn): + sql = ( + "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" + " FROM receipts_linearized" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + ) + args = [last_id, current_id] + if limit is not None: + sql += " LIMIT ?" 
+ args.append(limit) + txn.execute(sql, args) + + return txn.fetchall() + return self.runInteraction( + "get_all_updated_receipts", get_all_updated_receipts_txn + ) + + def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type, + user_id): + if receipt_type != "m.read": + return + + # Returns an ObservableDeferred + res = self.get_users_with_read_receipts_in_room.cache.get( + room_id, None, update_metrics=False, + ) + + if res: + if isinstance(res, defer.Deferred) and res.called: + res = res.result + if user_id in res: + # We'd only be adding to the set, so no point invalidating if the + # user is already there + return + + self.get_users_with_read_receipts_in_room.invalidate((room_id,)) + + +class ReceiptsStore(ReceiptsWorkerStore): + def __init__(self, db_conn, hs): + # We instantiate this first as the ReceiptsWorkerStore constructor + # needs to be able to call get_max_receipt_stream_id + self._receipts_id_gen = StreamIdGenerator( + db_conn, "receipts_linearized", "stream_id" + ) + + super(ReceiptsStore, self).__init__(db_conn, hs) + def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() @@ -313,10 +371,9 @@ class ReceiptsStore(SQLBaseStore): ) txn.execute(sql, (room_id, receipt_type, user_id)) - results = txn.fetchall() - if results and topological_ordering: - for to, so, _ in results: + if topological_ordering: + for to, so, _ in txn: if int(to) > topological_ordering: return False elif int(to) == topological_ordering and int(so) >= stream_ordering: @@ -351,6 +408,7 @@ class ReceiptsStore(SQLBaseStore): room_id=room_id, user_id=user_id, topological_ordering=topological_ordering, + stream_ordering=stream_ordering, ) return True @@ -452,25 +510,3 @@ class ReceiptsStore(SQLBaseStore): "data": json.dumps(data), } ) - - def get_all_updated_receipts(self, last_id, current_id, limit=None): - if last_id == current_id: - return defer.succeed([]) - - def get_all_updated_receipts_txn(txn): - sql = ( - "SELECT stream_id, 
room_id, receipt_type, user_id, event_id, data" - " FROM receipts_linearized" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" - ) - args = [last_id, current_id] - if limit is not None: - sql += " LIMIT ?" - args.append(limit) - txn.execute(sql, args) - - return txn.fetchall() - return self.runInteraction( - "get_all_updated_receipts", get_all_updated_receipts_txn - ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 26be6060c3..a50717db2d 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -19,13 +19,75 @@ from twisted.internet import defer from synapse.api.errors import StoreError, Codes from synapse.storage import background_updates +from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from six.moves import range -class RegistrationStore(background_updates.BackgroundUpdateStore): - def __init__(self, hs): - super(RegistrationStore, self).__init__(hs) +class RegistrationWorkerStore(SQLBaseStore): + @cached() + def get_user_by_id(self, user_id): + return self._simple_select_one( + table="users", + keyvalues={ + "name": user_id, + }, + retcols=["name", "password_hash", "is_guest"], + allow_none=True, + desc="get_user_by_id", + ) + + @cached() + def get_user_by_access_token(self, token): + """Get a user from the given access token. + + Args: + token (str): The access token of a user. + Returns: + defer.Deferred: None, if the token did not match, otherwise dict + including the keys `name`, `is_guest`, `device_id`, `token_id`. 
+ """ + return self.runInteraction( + "get_user_by_access_token", + self._query_for_auth, + token + ) + + @defer.inlineCallbacks + def is_server_admin(self, user): + res = yield self._simple_select_one_onecol( + table="users", + keyvalues={"name": user.to_string()}, + retcol="admin", + allow_none=True, + desc="is_server_admin", + ) + + defer.returnValue(res if res else False) + + def _query_for_auth(self, txn, token): + sql = ( + "SELECT users.name, users.is_guest, access_tokens.id as token_id," + " access_tokens.device_id" + " FROM users" + " INNER JOIN access_tokens on users.name = access_tokens.user_id" + " WHERE token = ?" + ) + + txn.execute(sql, (token,)) + rows = self.cursor_to_dict(txn) + if rows: + return rows[0] + + return None + + +class RegistrationStore(RegistrationWorkerStore, + background_updates.BackgroundUpdateStore): + + def __init__(self, db_conn, hs): + super(RegistrationStore, self).__init__(db_conn, hs) self.clock = hs.get_clock() @@ -36,12 +98,10 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): columns=["user_id", "device_id"], ) - self.register_background_index_update( - "refresh_tokens_device_index", - index_name="refresh_tokens_device_id", - table="refresh_tokens", - columns=["user_id", "device_id"], - ) + # we no longer use refresh tokens, but it's possible that some people + # might have a background update queued to build this index. Just + # clear the background update. 
+ self.register_noop_background_update("refresh_tokens_device_index") @defer.inlineCallbacks def add_access_token_to_user(self, user_id, token, device_id=None): @@ -177,9 +237,11 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): ) if create_profile_with_localpart: + # set a default displayname serverside to avoid ugly race + # between auto-joins and clients trying to set displaynames txn.execute( - "INSERT INTO profiles(user_id) VALUES (?)", - (create_profile_with_localpart,) + "INSERT INTO profiles(user_id, displayname) VALUES (?,?)", + (create_profile_with_localpart, create_profile_with_localpart) ) self._invalidate_cache_and_stream( @@ -187,18 +249,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): ) txn.call_after(self.is_guest.invalidate, (user_id,)) - @cached() - def get_user_by_id(self, user_id): - return self._simple_select_one( - table="users", - keyvalues={ - "name": user_id, - }, - retcols=["name", "password_hash", "is_guest"], - allow_none=True, - desc="get_user_by_id", - ) - def get_users_by_id_case_insensitive(self, user_id): """Gets users that match user_id case insensitively. Returns a mapping of user_id -> password_hash. 
@@ -209,7 +259,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): " WHERE lower(name) = lower(?)" ) txn.execute(sql, (user_id,)) - return dict(txn.fetchall()) + return dict(txn) return self.runInteraction("get_users_by_id_case_insensitive", f) @@ -236,12 +286,10 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): "user_set_password_hash", user_set_password_hash_txn ) - @defer.inlineCallbacks def user_delete_access_tokens(self, user_id, except_token_id=None, - device_id=None, - delete_refresh_tokens=False): + device_id=None): """ - Invalidate access/refresh tokens belonging to a user + Invalidate access tokens belonging to a user Args: user_id (str): ID of user the tokens belong to @@ -250,10 +298,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): device_id (str|None): ID of device the tokens are associated with. If None, tokens associated with any device (or no device) will be deleted - delete_refresh_tokens (bool): True to delete refresh tokens as - well as access tokens. Returns: - defer.Deferred: + defer.Deferred[list[str, int, str|None, int]]: a list of + (token, token id, device id) for each of the deleted tokens """ def f(txn): keyvalues = { @@ -262,13 +309,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): if device_id is not None: keyvalues["device_id"] = device_id - if delete_refresh_tokens: - self._simple_delete_txn( - txn, - table="refresh_tokens", - keyvalues=keyvalues, - ) - items = keyvalues.items() where_clause = " AND ".join(k + " = ?" 
for k, _ in items) values = [v for _, v in items] @@ -277,14 +317,14 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): values.append(except_token_id) txn.execute( - "SELECT token FROM access_tokens WHERE %s" % where_clause, + "SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause, values ) - rows = self.cursor_to_dict(txn) + tokens_and_devices = [(r[0], r[1], r[2]) for r in txn] - for row in rows: + for token, _, _ in tokens_and_devices: self._invalidate_cache_and_stream( - txn, self.get_user_by_access_token, (row["token"],) + txn, self.get_user_by_access_token, (token,) ) txn.execute( @@ -292,7 +332,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): values ) - yield self.runInteraction( + return tokens_and_devices + + return self.runInteraction( "user_delete_access_tokens", f, ) @@ -312,34 +354,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): return self.runInteraction("delete_access_token", f) - @cached() - def get_user_by_access_token(self, token): - """Get a user from the given access token. - - Args: - token (str): The access token of a user. - Returns: - defer.Deferred: None, if the token did not match, otherwise dict - including the keys `name`, `is_guest`, `device_id`, `token_id`. 
- """ - return self.runInteraction( - "get_user_by_access_token", - self._query_for_auth, - token - ) - - @defer.inlineCallbacks - def is_server_admin(self, user): - res = yield self._simple_select_one_onecol( - table="users", - keyvalues={"name": user.to_string()}, - retcol="admin", - allow_none=True, - desc="is_server_admin", - ) - - defer.returnValue(res if res else False) - @cachedInlineCallbacks() def is_guest(self, user_id): res = yield self._simple_select_one_onecol( @@ -352,22 +366,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): defer.returnValue(res if res else False) - def _query_for_auth(self, txn, token): - sql = ( - "SELECT users.name, users.is_guest, access_tokens.id as token_id," - " access_tokens.device_id" - " FROM users" - " INNER JOIN access_tokens on users.name = access_tokens.user_id" - " WHERE token = ?" - ) - - txn.execute(sql, (token,)) - rows = self.cursor_to_dict(txn) - if rows: - return rows[0] - - return None - @defer.inlineCallbacks def user_add_threepid(self, user_id, medium, address, validated_at, added_at): yield self._simple_upsert("user_threepids", { @@ -438,6 +436,19 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): defer.returnValue(ret) @defer.inlineCallbacks + def count_nonbridged_users(self): + def _count_users(txn): + txn.execute(""" + SELECT COALESCE(COUNT(*), 0) FROM users + WHERE appservice_id IS NULL + """) + count, = txn.fetchone() + return count + + ret = yield self.runInteraction("count_users", _count_users) + defer.returnValue(ret) + + @defer.inlineCallbacks def find_next_generated_user_id_localpart(self): """ Gets the localpart of the next generated user ID. 
@@ -451,18 +462,16 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): """ def _find_next_generated_user_id(txn): txn.execute("SELECT name FROM users") - rows = self.cursor_to_dict(txn) regex = re.compile("^@(\d+):") found = set() - for r in rows: - user_id = r["name"] + for user_id, in txn: match = regex.search(user_id) if match: found.add(int(match.group(1))) - for i in xrange(len(found) + 1): + for i in range(len(found) + 1): if i not in found: return i diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 8a2fe2fdf5..ea6a189185 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -16,14 +16,14 @@ from twisted.internet import defer from synapse.api.errors import StoreError -from synapse.util.caches.descriptors import cached - -from ._base import SQLBaseStore -from .engines import PostgresEngine, Sqlite3Engine +from synapse.storage._base import SQLBaseStore +from synapse.storage.search import SearchStore +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks import collections import logging -import ujson as json +import simplejson as json +import re logger = logging.getLogger(__name__) @@ -33,8 +33,144 @@ OpsLevel = collections.namedtuple( ("ban_level", "kick_level", "redact_level",) ) +RatelimitOverride = collections.namedtuple( + "RatelimitOverride", + ("messages_per_second", "burst_count",) +) + -class RoomStore(SQLBaseStore): +class RoomWorkerStore(SQLBaseStore): + def get_public_room_ids(self): + return self._simple_select_onecol( + table="rooms", + keyvalues={ + "is_public": True, + }, + retcol="room_id", + desc="get_public_room_ids", + ) + + @cached(num_args=2, max_entries=100) + def get_public_room_ids_at_stream_id(self, stream_id, network_tuple): + """Get pulbic rooms for a particular list, or across all lists. + + Args: + stream_id (int) + network_tuple (ThirdPartyInstanceID): The list to use (None, None) + means the main list, None means all lsits. 
+ """ + return self.runInteraction( + "get_public_room_ids_at_stream_id", + self.get_public_room_ids_at_stream_id_txn, + stream_id, network_tuple=network_tuple + ) + + def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, + network_tuple): + return { + rm + for rm, vis in self.get_published_at_stream_id_txn( + txn, stream_id, network_tuple=network_tuple + ).items() + if vis + } + + def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple): + if network_tuple: + # We want to get from a particular list. No aggregation required. + + sql = (""" + SELECT room_id, visibility FROM public_room_list_stream + INNER JOIN ( + SELECT room_id, max(stream_id) AS stream_id + FROM public_room_list_stream + WHERE stream_id <= ? %s + GROUP BY room_id + ) grouped USING (room_id, stream_id) + """) + + if network_tuple.appservice_id is not None: + txn.execute( + sql % ("AND appservice_id = ? AND network_id = ?",), + (stream_id, network_tuple.appservice_id, network_tuple.network_id,) + ) + else: + txn.execute( + sql % ("AND appservice_id IS NULL",), + (stream_id,) + ) + return dict(txn) + else: + # We want to get from all lists, so we need to aggregate the results + + logger.info("Executing full list") + + sql = (""" + SELECT room_id, visibility + FROM public_room_list_stream + INNER JOIN ( + SELECT + room_id, max(stream_id) AS stream_id, appservice_id, + network_id + FROM public_room_list_stream + WHERE stream_id <= ? + GROUP BY room_id, appservice_id, network_id + ) grouped USING (room_id, stream_id) + """) + + txn.execute( + sql, + (stream_id,) + ) + + results = {} + # A room is visible if its visible on any list. 
+ for room_id, visibility in txn: + results[room_id] = bool(visibility) or results.get(room_id, False) + + return results + + def get_public_room_changes(self, prev_stream_id, new_stream_id, + network_tuple): + def get_public_room_changes_txn(txn): + then_rooms = self.get_public_room_ids_at_stream_id_txn( + txn, prev_stream_id, network_tuple + ) + + now_rooms_dict = self.get_published_at_stream_id_txn( + txn, new_stream_id, network_tuple + ) + + now_rooms_visible = set( + rm for rm, vis in now_rooms_dict.items() if vis + ) + now_rooms_not_visible = set( + rm for rm, vis in now_rooms_dict.items() if not vis + ) + + newly_visible = now_rooms_visible - then_rooms + newly_unpublished = now_rooms_not_visible & then_rooms + + return newly_visible, newly_unpublished + + return self.runInteraction( + "get_public_room_changes", get_public_room_changes_txn + ) + + @cached(max_entries=10000) + def is_room_blocked(self, room_id): + return self._simple_select_one_onecol( + table="blocked_rooms", + keyvalues={ + "room_id": room_id, + }, + retcol="1", + allow_none=True, + desc="is_room_blocked", + ) + + +class RoomStore(RoomWorkerStore, SearchStore): @defer.inlineCallbacks def store_room(self, room_id, room_creator_user_id, is_public): @@ -221,16 +357,6 @@ class RoomStore(SQLBaseStore): ) self.hs.get_notifier().on_new_replication_data() - def get_public_room_ids(self): - return self._simple_select_onecol( - table="rooms", - keyvalues={ - "is_public": True, - }, - retcol="room_id", - desc="get_public_room_ids", - ) - def get_room_count(self): """Retrieve a list of all rooms """ @@ -257,8 +383,8 @@ class RoomStore(SQLBaseStore): }, ) - self._store_event_search_txn( - txn, event, "content.topic", event.content["topic"] + self.store_event_search_txn( + txn, event, "content.topic", event.content["topic"], ) def _store_room_name_txn(self, txn, event): @@ -273,14 +399,14 @@ class RoomStore(SQLBaseStore): } ) - self._store_event_search_txn( - txn, event, "content.name", 
event.content["name"] + self.store_event_search_txn( + txn, event, "content.name", event.content["name"], ) def _store_room_message_txn(self, txn, event): if hasattr(event, "content") and "body" in event.content: - self._store_event_search_txn( - txn, event, "content.body", event.content["body"] + self.store_event_search_txn( + txn, event, "content.body", event.content["body"], ) def _store_history_visibility_txn(self, txn, event): @@ -302,31 +428,6 @@ class RoomStore(SQLBaseStore): event.content[key] )) - def _store_event_search_txn(self, txn, event, key, value): - if isinstance(self.database_engine, PostgresEngine): - sql = ( - "INSERT INTO event_search" - " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" - " VALUES (?,?,?,to_tsvector('english', ?),?,?)" - ) - txn.execute( - sql, - ( - event.event_id, event.room_id, key, value, - event.internal_metadata.stream_ordering, - event.origin_server_ts, - ) - ) - elif isinstance(self.database_engine, Sqlite3Engine): - sql = ( - "INSERT INTO event_search (event_id, room_id, key, value)" - " VALUES (?,?,?,?)" - ) - txn.execute(sql, (event.event_id, event.room_id, key, value,)) - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") - def add_event_report(self, room_id, event_id, user_id, reason, content, received_ts): next_id = self._event_reports_id_gen.get_next() @@ -347,129 +448,180 @@ class RoomStore(SQLBaseStore): def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() - @cached(num_args=2, max_entries=100) - def get_public_room_ids_at_stream_id(self, stream_id, network_tuple): - """Get pulbic rooms for a particular list, or across all lists. + def get_all_new_public_rooms(self, prev_id, current_id, limit): + def get_all_new_public_rooms(txn): + sql = (""" + SELECT stream_id, room_id, visibility, appservice_id, network_id + FROM public_room_list_stream + WHERE stream_id > ? AND stream_id <= ? 
+ ORDER BY stream_id ASC + LIMIT ? + """) + + txn.execute(sql, (prev_id, current_id, limit,)) + return txn.fetchall() + + if prev_id == current_id: + return defer.succeed([]) - Args: - stream_id (int) - network_tuple (ThirdPartyInstanceID): The list to use (None, None) - means the main list, None means all lsits. - """ return self.runInteraction( - "get_public_room_ids_at_stream_id", - self.get_public_room_ids_at_stream_id_txn, - stream_id, network_tuple=network_tuple + "get_all_new_public_rooms", get_all_new_public_rooms ) - def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, - network_tuple): - return { - rm - for rm, vis in self.get_published_at_stream_id_txn( - txn, stream_id, network_tuple=network_tuple - ).items() - if vis - } + @cachedInlineCallbacks(max_entries=10000) + def get_ratelimit_for_user(self, user_id): + """Check if there are any overrides for ratelimiting for the given + user - def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple): - if network_tuple: - # We want to get from a particular list. No aggregation required. + Args: + user_id (str) - sql = (""" - SELECT room_id, visibility FROM public_room_list_stream - INNER JOIN ( - SELECT room_id, max(stream_id) AS stream_id - FROM public_room_list_stream - WHERE stream_id <= ? %s - GROUP BY room_id - ) grouped USING (room_id, stream_id) - """) + Returns: + RatelimitOverride if there is an override, else None. If the contents + of RatelimitOverride are None or 0 then ratelimitng has been + disabled for that user entirely. + """ + row = yield self._simple_select_one( + table="ratelimit_override", + keyvalues={"user_id": user_id}, + retcols=("messages_per_second", "burst_count"), + allow_none=True, + desc="get_ratelimit_for_user", + ) - if network_tuple.appservice_id is not None: - txn.execute( - sql % ("AND appservice_id = ? 
AND network_id = ?",), - (stream_id, network_tuple.appservice_id, network_tuple.network_id,) - ) - else: - txn.execute( - sql % ("AND appservice_id IS NULL",), - (stream_id,) - ) - return dict(txn.fetchall()) + if row: + defer.returnValue(RatelimitOverride( + messages_per_second=row["messages_per_second"], + burst_count=row["burst_count"], + )) else: - # We want to get from all lists, so we need to aggregate the results + defer.returnValue(None) - logger.info("Executing full list") - - sql = (""" - SELECT room_id, visibility - FROM public_room_list_stream - INNER JOIN ( - SELECT - room_id, max(stream_id) AS stream_id, appservice_id, - network_id - FROM public_room_list_stream - WHERE stream_id <= ? - GROUP BY room_id, appservice_id, network_id - ) grouped USING (room_id, stream_id) - """) - - txn.execute( - sql, - (stream_id,) - ) - - results = {} - # A room is visible if its visible on any list. - for room_id, visibility in txn.fetchall(): - results[room_id] = bool(visibility) or results.get(room_id, False) - - return results + @defer.inlineCallbacks + def block_room(self, room_id, user_id): + yield self._simple_insert( + table="blocked_rooms", + values={ + "room_id": room_id, + "user_id": user_id, + }, + desc="block_room", + ) + yield self.runInteraction( + "block_room_invalidation", + self._invalidate_cache_and_stream, + self.is_room_blocked, (room_id,), + ) - def get_public_room_changes(self, prev_stream_id, new_stream_id, - network_tuple): - def get_public_room_changes_txn(txn): - then_rooms = self.get_public_room_ids_at_stream_id_txn( - txn, prev_stream_id, network_tuple - ) + def get_media_mxcs_in_room(self, room_id): + """Retrieves all the local and remote media MXC URIs in a given room - now_rooms_dict = self.get_published_at_stream_id_txn( - txn, new_stream_id, network_tuple - ) + Args: + room_id (str) - now_rooms_visible = set( - rm for rm, vis in now_rooms_dict.items() if vis - ) - now_rooms_not_visible = set( - rm for rm, vis in now_rooms_dict.items() 
if not vis + Returns: + The local and remote media as a lists of tuples where the key is + the hostname and the value is the media ID. + """ + def _get_media_mxcs_in_room_txn(txn): + local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) + local_media_mxcs = [] + remote_media_mxcs = [] + + # Convert the IDs to MXC URIs + for media_id in local_mxcs: + local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id)) + for hostname, media_id in remote_mxcs: + remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id)) + + return local_media_mxcs, remote_media_mxcs + return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn) + + def quarantine_media_ids_in_room(self, room_id, quarantined_by): + """For a room loops through all events with media and quarantines + the associated media + """ + def _quarantine_media_in_room_txn(txn): + local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) + total_media_quarantined = 0 + + # Now update all the tables to set the quarantined_by flag + + txn.executemany(""" + UPDATE local_media_repository + SET quarantined_by = ? + WHERE media_id = ? + """, ((quarantined_by, media_id) for media_id in local_mxcs)) + + txn.executemany( + """ + UPDATE remote_media_cache + SET quarantined_by = ? + WHERE media_origin = ? AND media_id = ? 
+ """, + ( + (quarantined_by, origin, media_id) + for origin, media_id in remote_mxcs + ) ) - newly_visible = now_rooms_visible - then_rooms - newly_unpublished = now_rooms_not_visible & then_rooms + total_media_quarantined += len(local_mxcs) + total_media_quarantined += len(remote_mxcs) - return newly_visible, newly_unpublished + return total_media_quarantined return self.runInteraction( - "get_public_room_changes", get_public_room_changes_txn + "quarantine_media_in_room", + _quarantine_media_in_room_txn, ) - def get_all_new_public_rooms(self, prev_id, current_id, limit): - def get_all_new_public_rooms(txn): - sql = (""" - SELECT stream_id, room_id, visibility, appservice_id, network_id - FROM public_room_list_stream - WHERE stream_id > ? AND stream_id <= ? - ORDER BY stream_id ASC - LIMIT ? - """) - - txn.execute(sql, (prev_id, current_id, limit,)) - return txn.fetchall() + def _get_media_mxcs_in_room_txn(self, txn, room_id): + """Retrieves all the local and remote media MXC URIs in a given room - if prev_id == current_id: - return defer.succeed([]) + Args: + txn (cursor) + room_id (str) - return self.runInteraction( - "get_all_new_public_rooms", get_all_new_public_rooms - ) + Returns: + The local and remote media as a lists of tuples where the key is + the hostname and the value is the media ID. + """ + mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") + + next_token = self.get_current_events_token() + 1 + local_media_mxcs = [] + remote_media_mxcs = [] + + while next_token: + sql = """ + SELECT stream_ordering, json FROM events + JOIN event_json USING (room_id, event_id) + WHERE room_id = ? + AND stream_ordering < ? + AND contains_url = ? AND outlier = ? + ORDER BY stream_ordering DESC + LIMIT ? 
+ """ + txn.execute(sql, (room_id, next_token, True, False, 100)) + + next_token = None + for stream_ordering, content_json in txn: + next_token = stream_ordering + event_json = json.loads(content_json) + content = event_json["content"] + content_url = content.get("url") + thumbnail_url = content.get("info", {}).get("thumbnail_url") + + for url in (content_url, thumbnail_url): + if not url: + continue + matches = mxc_re.match(url) + if matches: + hostname = matches.group(1) + media_id = matches.group(2) + if hostname == self.hs.hostname: + local_media_mxcs.append(media_id) + else: + remote_media_mxcs.append((hostname, media_id)) + + return local_media_mxcs, remote_media_mxcs diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 545d3d3a99..6a861943a2 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -17,14 +18,17 @@ from twisted.internet import defer from collections import namedtuple -from ._base import SQLBaseStore +from synapse.storage.events import EventsWorkerStore +from synapse.util.async import Linearizer +from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.stringutils import to_ascii from synapse.api.constants import Membership, EventTypes from synapse.types import get_domain_from_id import logging -import ujson as json +import simplejson as json logger = logging.getLogger(__name__) @@ -34,112 +38,47 @@ RoomsForUser = namedtuple( ("room_id", "sender", "membership", "event_id", "stream_ordering") ) +GetRoomsForUserWithStreamOrdering = namedtuple( + "_GetRoomsForUserWithStreamOrdering", + ("room_id", "stream_ordering",) +) -_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" - - -class RoomMemberStore(SQLBaseStore): - def __init__(self, hs): - super(RoomMemberStore, self).__init__(hs) - self.register_background_update_handler( - _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile - ) - def _store_room_members_txn(self, txn, events, backfilled): - """Store a room member in the database. - """ - self._simple_insert_many_txn( - txn, - table="room_memberships", - values=[ - { - "event_id": event.event_id, - "user_id": event.state_key, - "sender": event.user_id, - "room_id": event.room_id, - "membership": event.membership, - "display_name": event.content.get("displayname", None), - "avatar_url": event.content.get("avatar_url", None), - } - for event in events - ] - ) +# We store this using a namedtuple so that we save about 3x space over using a +# dict. 
+ProfileInfo = namedtuple( + "ProfileInfo", ("avatar_url", "display_name") +) - for event in events: - txn.call_after( - self._membership_stream_cache.entity_has_changed, - event.state_key, event.internal_metadata.stream_ordering - ) - txn.call_after( - self.get_invited_rooms_for_user.invalidate, (event.state_key,) - ) - # We update the local_invites table only if the event is "current", - # i.e., its something that has just happened. - # The only current event that can also be an outlier is if its an - # invite that has come in across federation. - is_new_state = not backfilled and ( - not event.internal_metadata.is_outlier() - or event.internal_metadata.is_invite_from_remote() - ) - is_mine = self.hs.is_mine_id(event.state_key) - if is_new_state and is_mine: - if event.membership == Membership.INVITE: - self._simple_insert_txn( - txn, - table="local_invites", - values={ - "event_id": event.event_id, - "invitee": event.state_key, - "inviter": event.sender, - "room_id": event.room_id, - "stream_id": event.internal_metadata.stream_ordering, - } - ) - else: - sql = ( - "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" - " room_id = ? AND invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) +_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" - txn.execute(sql, ( - event.internal_metadata.stream_ordering, - event.event_id, - event.room_id, - event.state_key, - )) - @defer.inlineCallbacks - def locally_reject_invite(self, user_id, room_id): - sql = ( - "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" - " room_id = ? AND invitee = ? 
AND locally_rejected is NULL" - " AND replaced_by is NULL" +class RoomMemberWorkerStore(EventsWorkerStore): + @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True) + def get_hosts_in_room(self, room_id, cache_context): + """Returns the set of all hosts currently in the room + """ + user_ids = yield self.get_users_in_room( + room_id, on_invalidate=cache_context.invalidate, ) + hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids) + defer.returnValue(hosts) - def f(txn, stream_ordering): - txn.execute(sql, ( - stream_ordering, - True, - room_id, - user_id, - )) - - with self._stream_id_gen.get_next() as stream_ordering: - yield self.runInteraction("locally_reject_invite", f, stream_ordering) - - @cached(max_entries=500000, iterable=True) + @cached(max_entries=100000, iterable=True) def get_users_in_room(self, room_id): def f(txn): - - rows = self._get_members_rows_txn( - txn, - room_id=room_id, - membership=Membership.JOIN, + sql = ( + "SELECT m.user_id FROM room_memberships as m" + " INNER JOIN current_state_events as c" + " ON m.event_id = c.event_id " + " AND m.room_id = c.room_id " + " AND m.user_id = c.state_key" + " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?" ) - return [r["user_id"] for r in rows] + txn.execute(sql, (room_id, Membership.JOIN,)) + return [to_ascii(r[0]) for r in txn] return self.runInteraction("get_users_in_room", f) @cached() @@ -246,57 +185,382 @@ class RoomMemberStore(SQLBaseStore): return results - def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None): - where_clause = "c.room_id = ?" - where_values = [room_id] - - if membership: - where_clause += " AND m.membership = ?" - where_values.append(membership) + @cachedInlineCallbacks(max_entries=500000, iterable=True) + def get_rooms_for_user_with_stream_ordering(self, user_id): + """Returns a set of room_ids the user is currently joined to - if user_id: - where_clause += " AND m.user_id = ?" 
- where_values.append(user_id) + Args: + user_id (str) - sql = ( - "SELECT m.* FROM room_memberships as m" - " INNER JOIN current_state_events as c" - " ON m.event_id = c.event_id " - " AND m.room_id = c.room_id " - " AND m.user_id = c.state_key" - " WHERE c.type = 'm.room.member' AND %(where)s" - ) % { - "where": where_clause, - } - - txn.execute(sql, where_values) - rows = self.cursor_to_dict(txn) - - return rows - - @cached(max_entries=500000, iterable=True) - def get_rooms_for_user(self, user_id): - return self.get_rooms_for_user_where_membership_is( + Returns: + Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns + the rooms the user is in currently, along with the stream ordering + of the most recent join for that user and room. + """ + rooms = yield self.get_rooms_for_user_where_membership_is( user_id, membership_list=[Membership.JOIN], ) + defer.returnValue(frozenset( + GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering) + for r in rooms + )) + + @defer.inlineCallbacks + def get_rooms_for_user(self, user_id, on_invalidate=None): + """Returns a set of room_ids the user is currently joined to + """ + rooms = yield self.get_rooms_for_user_with_stream_ordering( + user_id, on_invalidate=on_invalidate, + ) + defer.returnValue(frozenset(r.room_id for r in rooms)) @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) def get_users_who_share_room_with_user(self, user_id, cache_context): """Returns the set of users who share a room with `user_id` """ - rooms = yield self.get_rooms_for_user( + room_ids = yield self.get_rooms_for_user( user_id, on_invalidate=cache_context.invalidate, ) user_who_share_room = set() - for room in rooms: + for room_id in room_ids: user_ids = yield self.get_users_in_room( - room.room_id, on_invalidate=cache_context.invalidate, + room_id, on_invalidate=cache_context.invalidate, ) user_who_share_room.update(user_ids) defer.returnValue(user_who_share_room) + def 
get_joined_users_from_context(self, event, context): + state_group = context.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_users_from_context( + event.room_id, state_group, context.current_state_ids, + event=event, + context=context, + ) + + def get_joined_users_from_state(self, room_id, state_entry): + state_group = state_entry.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_users_from_context( + room_id, state_group, state_entry.state, context=state_entry, + ) + + @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, + max_entries=100000) + def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, + cache_context, event=None, context=None): + # We don't use `state_group`, it's there so that we can cache based + # on it. However, it's important that it's never None, since two current_states + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + users_in_room = {} + member_event_ids = [ + e_id + for key, e_id in current_state_ids.iteritems() + if key[0] == EventTypes.Member + ] + + if context is not None: + # If we have a context with a delta from a previous state group, + # check if we also have the result from the previous group in cache. 
+ # If we do then we can reuse that result and simply update it with + # any membership changes in `delta_ids` + if context.prev_group and context.delta_ids: + prev_res = self._get_joined_users_from_context.cache.get( + (room_id, context.prev_group), None + ) + if prev_res and isinstance(prev_res, dict): + users_in_room = dict(prev_res) + member_event_ids = [ + e_id + for key, e_id in context.delta_ids.iteritems() + if key[0] == EventTypes.Member + ] + for etype, state_key in context.delta_ids: + users_in_room.pop(state_key, None) + + # We check if we have any of the member event ids in the event cache + # before we ask the DB + + # We don't update the event cache hit ratio as it completely throws off + # the hit ratio counts. After all, we don't populate the cache if we + # miss it here + event_map = self._get_events_from_cache( + member_event_ids, + allow_rejected=False, + update_metrics=False, + ) + + missing_member_event_ids = [] + for event_id in member_event_ids: + ev_entry = event_map.get(event_id) + if ev_entry: + if ev_entry.event.membership == Membership.JOIN: + users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo( + display_name=to_ascii( + ev_entry.event.content.get("displayname", None) + ), + avatar_url=to_ascii( + ev_entry.event.content.get("avatar_url", None) + ), + ) + else: + missing_member_event_ids.append(event_id) + + if missing_member_event_ids: + rows = yield self._simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=missing_member_event_ids, + retcols=('user_id', 'display_name', 'avatar_url',), + keyvalues={ + "membership": Membership.JOIN, + }, + batch_size=500, + desc="_get_joined_users_from_context", + ) + + users_in_room.update({ + to_ascii(row["user_id"]): ProfileInfo( + avatar_url=to_ascii(row["avatar_url"]), + display_name=to_ascii(row["display_name"]), + ) + for row in rows + }) + + if event is not None and event.type == EventTypes.Member: + if event.membership == Membership.JOIN: + if 
event.event_id in member_event_ids: + users_in_room[to_ascii(event.state_key)] = ProfileInfo( + display_name=to_ascii(event.content.get("displayname", None)), + avatar_url=to_ascii(event.content.get("avatar_url", None)), + ) + + defer.returnValue(users_in_room) + + @cachedInlineCallbacks(max_entries=10000) + def is_host_joined(self, room_id, host): + if '%' in host or '_' in host: + raise Exception("Invalid host name") + + sql = """ + SELECT state_key FROM current_state_events AS c + INNER JOIN room_memberships USING (event_id) + WHERE membership = 'join' + AND type = 'm.room.member' + AND c.room_id = ? + AND state_key LIKE ? + LIMIT 1 + """ + + # We do need to be careful to ensure that host doesn't have any wild cards + # in it, but we checked above for known ones and we'll check below that + # the returned user actually has the correct domain. + like_clause = "%:" + host + + rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause) + + if not rows: + defer.returnValue(False) + + user_id = rows[0][0] + if get_domain_from_id(user_id) != host: + # This can only happen if the host name has something funky in it + raise Exception("Invalid host name") + + defer.returnValue(True) + + @cachedInlineCallbacks() + def was_host_joined(self, room_id, host): + """Check whether the server is or ever was in the room. + + Args: + room_id (str) + host (str) + + Returns: + Deferred: Resolves to True if the host is/was in the room, otherwise + False. + """ + if '%' in host or '_' in host: + raise Exception("Invalid host name") + + sql = """ + SELECT user_id FROM room_memberships + WHERE room_id = ? + AND user_id LIKE ? + AND membership = 'join' + LIMIT 1 + """ + + # We do need to be careful to ensure that host doesn't have any wild cards + # in it, but we checked above for known ones and we'll check below that + # the returned user actually has the correct domain. 
+ like_clause = "%:" + host + + rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause) + + if not rows: + defer.returnValue(False) + + user_id = rows[0][0] + if get_domain_from_id(user_id) != host: + # This can only happen if the host name has something funky in it + raise Exception("Invalid host name") + + defer.returnValue(True) + + def get_joined_hosts(self, room_id, state_entry): + state_group = state_entry.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_hosts( + room_id, state_group, state_entry.state, state_entry=state_entry, + ) + + @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) + # @defer.inlineCallbacks + def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. 
+ assert state_group is not None + + cache = self._get_joined_hosts_cache(room_id) + joined_hosts = yield cache.get_destinations(state_entry) + + defer.returnValue(joined_hosts) + + @cached(max_entries=10000, iterable=True) + def _get_joined_hosts_cache(self, room_id): + return _JoinedHostsCache(self, room_id) + + @cached() + def who_forgot_in_room(self, room_id): + return self._simple_select_list( + table="room_memberships", + retcols=("user_id", "event_id"), + keyvalues={ + "room_id": room_id, + "forgotten": 1, + }, + desc="who_forgot" + ) + + +class RoomMemberStore(RoomMemberWorkerStore): + def __init__(self, db_conn, hs): + super(RoomMemberStore, self).__init__(db_conn, hs) + self.register_background_update_handler( + _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile + ) + + def _store_room_members_txn(self, txn, events, backfilled): + """Store a room member in the database. + """ + self._simple_insert_many_txn( + txn, + table="room_memberships", + values=[ + { + "event_id": event.event_id, + "user_id": event.state_key, + "sender": event.user_id, + "room_id": event.room_id, + "membership": event.membership, + "display_name": event.content.get("displayname", None), + "avatar_url": event.content.get("avatar_url", None), + } + for event in events + ] + ) + + for event in events: + txn.call_after( + self._membership_stream_cache.entity_has_changed, + event.state_key, event.internal_metadata.stream_ordering + ) + txn.call_after( + self.get_invited_rooms_for_user.invalidate, (event.state_key,) + ) + + # We update the local_invites table only if the event is "current", + # i.e., its something that has just happened. + # The only current event that can also be an outlier is if its an + # invite that has come in across federation. 
+ is_new_state = not backfilled and ( + not event.internal_metadata.is_outlier() + or event.internal_metadata.is_invite_from_remote() + ) + is_mine = self.hs.is_mine_id(event.state_key) + if is_new_state and is_mine: + if event.membership == Membership.INVITE: + self._simple_insert_txn( + txn, + table="local_invites", + values={ + "event_id": event.event_id, + "invitee": event.state_key, + "inviter": event.sender, + "room_id": event.room_id, + "stream_id": event.internal_metadata.stream_ordering, + } + ) + else: + sql = ( + "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + txn.execute(sql, ( + event.internal_metadata.stream_ordering, + event.event_id, + event.room_id, + event.state_key, + )) + + @defer.inlineCallbacks + def locally_reject_invite(self, user_id, room_id): + sql = ( + "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" + " room_id = ? AND invitee = ? 
AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + def f(txn, stream_ordering): + txn.execute(sql, ( + stream_ordering, + True, + room_id, + user_id, + )) + + with self._stream_id_gen.get_next() as stream_ordering: + yield self.runInteraction("locally_reject_invite", f, stream_ordering) + def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" def f(txn): @@ -368,124 +632,6 @@ class RoomMemberStore(SQLBaseStore): forgot = yield self.runInteraction("did_forget_membership_at", f) defer.returnValue(forgot == 1) - @cached() - def who_forgot_in_room(self, room_id): - return self._simple_select_list( - table="room_memberships", - retcols=("user_id", "event_id"), - keyvalues={ - "room_id": room_id, - "forgotten": 1, - }, - desc="who_forgot" - ) - - def get_joined_users_from_context(self, event, context): - state_group = context.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - return self._get_joined_users_from_context( - event.room_id, state_group, context.current_state_ids, event=event, - ) - - def get_joined_users_from_state(self, room_id, state_group, state_ids): - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. 
- # To do this we set the state_group to a new object as object() != object() - state_group = object() - - return self._get_joined_users_from_context( - room_id, state_group, state_ids, - ) - - @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, - max_entries=100000) - def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, - cache_context, event=None): - # We don't use `state_group`, it's there so that we can cache based - # on it. However, it's important that it's never None, since two current_states - # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. - assert state_group is not None - - member_event_ids = [ - e_id - for key, e_id in current_state_ids.iteritems() - if key[0] == EventTypes.Member - ] - - rows = yield self._simple_select_many_batch( - table="room_memberships", - column="event_id", - iterable=member_event_ids, - retcols=['user_id', 'display_name', 'avatar_url'], - keyvalues={ - "membership": Membership.JOIN, - }, - batch_size=500, - desc="_get_joined_users_from_context", - ) - - users_in_room = { - row["user_id"]: { - "display_name": row["display_name"], - "avatar_url": row["avatar_url"], - } - for row in rows - } - - if event is not None and event.type == EventTypes.Member: - if event.membership == Membership.JOIN: - if event.event_id in member_event_ids: - users_in_room[event.state_key] = { - "display_name": event.content.get("displayname", None), - "avatar_url": event.content.get("avatar_url", None), - } - - defer.returnValue(users_in_room) - - def is_host_joined(self, room_id, host, state_group, state_ids): - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. 
- # To do this we set the state_group to a new object as object() != object() - state_group = object() - - return self._is_host_joined( - room_id, host, state_group, state_ids - ) - - @cachedInlineCallbacks(num_args=3) - def _is_host_joined(self, room_id, host, state_group, current_state_ids): - # We don't use `state_group`, its there so that we can cache based - # on it. However, its important that its never None, since two current_state's - # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. - assert state_group is not None - - for (etype, state_key), event_id in current_state_ids.items(): - if etype == EventTypes.Member: - try: - if get_domain_from_id(state_key) != host: - continue - except: - logger.warn("state_key not user_id: %s", state_key) - continue - - event = yield self.get_event(event_id, allow_none=True) - if event and event.content["membership"] == Membership.JOIN: - defer.returnValue(True) - - defer.returnValue(False) - @defer.inlineCallbacks def _background_add_membership_profile(self, progress, batch_size): target_min_stream_id = progress.get( @@ -499,8 +645,9 @@ class RoomMemberStore(SQLBaseStore): def add_membership_profile_txn(txn): sql = (""" - SELECT stream_ordering, event_id, events.room_id, content + SELECT stream_ordering, event_id, events.room_id, event_json.json FROM events + INNER JOIN event_json USING (event_id) INNER JOIN room_memberships USING (event_id) WHERE ? <= stream_ordering AND stream_ordering < ? 
AND type = 'm.room.member' @@ -521,8 +668,9 @@ class RoomMemberStore(SQLBaseStore): event_id = row["event_id"] room_id = row["room_id"] try: - content = json.loads(row["content"]) - except: + event_json = json.loads(row["json"]) + content = event_json['content'] + except Exception: continue display_name = content.get("displayname", None) @@ -560,3 +708,71 @@ class RoomMemberStore(SQLBaseStore): yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME) defer.returnValue(result) + + +class _JoinedHostsCache(object): + """Cache for joined hosts in a room that is optimised to handle updates + via state deltas. + """ + + def __init__(self, store, room_id): + self.store = store + self.room_id = room_id + + self.hosts_to_joined_users = {} + + self.state_group = object() + + self.linearizer = Linearizer("_JoinedHostsCache") + + self._len = 0 + + @defer.inlineCallbacks + def get_destinations(self, state_entry): + """Get set of destinations for a state entry + + Args: + state_entry(synapse.state._StateCacheEntry) + """ + if state_entry.state_group == self.state_group: + defer.returnValue(frozenset(self.hosts_to_joined_users)) + + with (yield self.linearizer.queue(())): + if state_entry.state_group == self.state_group: + pass + elif state_entry.prev_group == self.state_group: + for (typ, state_key), event_id in state_entry.delta_ids.iteritems(): + if typ != EventTypes.Member: + continue + + host = intern_string(get_domain_from_id(state_key)) + user_id = state_key + known_joins = self.hosts_to_joined_users.setdefault(host, set()) + + event = yield self.store.get_event(event_id) + if event.membership == Membership.JOIN: + known_joins.add(user_id) + else: + known_joins.discard(user_id) + + if not known_joins: + self.hosts_to_joined_users.pop(host, None) + else: + joined_users = yield self.store.get_joined_users_from_state( + self.room_id, state_entry, + ) + + self.hosts_to_joined_users = {} + for user_id in joined_users: + host = 
intern_string(get_domain_from_id(user_id)) + self.hosts_to_joined_users.setdefault(host, set()).add(user_id) + + if state_entry.state_group: + self.state_group = state_entry.state_group + else: + self.state_group = object() + self._len = sum(len(v) for v in self.hosts_to_joined_users.itervalues()) + defer.returnValue(frozenset(self.hosts_to_joined_users)) + + def __len__(self): + return self._len diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py index 8755bb2e49..4d725b92fe 100644 --- a/synapse/storage/schema/delta/14/upgrade_appservice_db.py +++ b/synapse/storage/schema/delta/14/upgrade_appservice_db.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import logging +import simplejson as json + logger = logging.getLogger(__name__) diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py index 4269ac69ad..e7351c3ae6 100644 --- a/synapse/storage/schema/delta/25/fts.py +++ b/synapse/storage/schema/delta/25/fts.py @@ -17,7 +17,7 @@ import logging from synapse.storage.prepare_database import get_statements from synapse.storage.engines import PostgresEngine, Sqlite3Engine -import ujson +import simplejson logger = logging.getLogger(__name__) @@ -66,7 +66,7 @@ def run_create(cur, database_engine, *args, **kwargs): "max_stream_id_exclusive": max_stream_id + 1, "rows_inserted": 0, } - progress_json = ujson.dumps(progress) + progress_json = simplejson.dumps(progress) sql = ( "INSERT into background_updates (update_name, progress_json)" diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/schema/delta/27/ts.py index 71b12a2731..6df57b5206 100644 --- a/synapse/storage/schema/delta/27/ts.py +++ b/synapse/storage/schema/delta/27/ts.py @@ -16,7 +16,7 @@ import logging from synapse.storage.prepare_database import get_statements -import ujson +import simplejson logger 
= logging.getLogger(__name__) @@ -45,7 +45,7 @@ def run_create(cur, database_engine, *args, **kwargs): "max_stream_id_exclusive": max_stream_id + 1, "rows_inserted": 0, } - progress_json = ujson.dumps(progress) + progress_json = simplejson.dumps(progress) sql = ( "INSERT into background_updates (update_name, progress_json)" diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py index 5b7d8d1ab5..85bd1a2006 100644 --- a/synapse/storage/schema/delta/30/as_users.py +++ b/synapse/storage/schema/delta/30/as_users.py @@ -14,6 +14,8 @@ import logging from synapse.config.appservice import load_appservices +from six.moves import range + logger = logging.getLogger(__name__) @@ -22,7 +24,7 @@ def run_create(cur, database_engine, *args, **kwargs): # NULL indicates user was not registered by an appservice. try: cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT") - except: + except Exception: # Maybe we already added the column? Hope so... pass @@ -58,7 +60,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): for as_id, user_ids in owned.items(): n = 100 - user_chunks = (user_ids[i:i + 100] for i in xrange(0, len(user_ids), n)) + user_chunks = (user_ids[i:i + 100] for i in range(0, len(user_ids), n)) for chunk in user_chunks: cur.execute( database_engine.convert_param_style( diff --git a/synapse/storage/schema/delta/31/search_update.py b/synapse/storage/schema/delta/31/search_update.py index 470ae0c005..fe6b7d196d 100644 --- a/synapse/storage/schema/delta/31/search_update.py +++ b/synapse/storage/schema/delta/31/search_update.py @@ -16,7 +16,7 @@ from synapse.storage.engines import PostgresEngine from synapse.storage.prepare_database import get_statements import logging -import ujson +import simplejson logger = logging.getLogger(__name__) @@ -49,7 +49,7 @@ def run_create(cur, database_engine, *args, **kwargs): "rows_inserted": 0, "have_added_indexes": False, } - progress_json = ujson.dumps(progress) + 
progress_json = simplejson.dumps(progress) sql = ( "INSERT into background_updates (update_name, progress_json)" diff --git a/synapse/storage/schema/delta/33/event_fields.py b/synapse/storage/schema/delta/33/event_fields.py index 83066cccc9..1e002f9db2 100644 --- a/synapse/storage/schema/delta/33/event_fields.py +++ b/synapse/storage/schema/delta/33/event_fields.py @@ -15,7 +15,7 @@ from synapse.storage.prepare_database import get_statements import logging -import ujson +import simplejson logger = logging.getLogger(__name__) @@ -44,7 +44,7 @@ def run_create(cur, database_engine, *args, **kwargs): "max_stream_id_exclusive": max_stream_id + 1, "rows_inserted": 0, } - progress_json = ujson.dumps(progress) + progress_json = simplejson.dumps(progress) sql = ( "INSERT into background_updates (update_name, progress_json)" diff --git a/synapse/storage/schema/delta/37/remove_auth_idx.py b/synapse/storage/schema/delta/37/remove_auth_idx.py index 784f3b348f..20ad8bd5a6 100644 --- a/synapse/storage/schema/delta/37/remove_auth_idx.py +++ b/synapse/storage/schema/delta/37/remove_auth_idx.py @@ -36,6 +36,10 @@ DROP INDEX IF EXISTS transactions_have_ref; -- and is used incredibly rarely. DROP INDEX IF EXISTS events_order_topo_stream_room; +-- an equivalent index to this actually gets re-created in delta 41, because it +-- turned out that deleting it wasn't a great plan :/. In any case, let's +-- delete it here, and delta 41 will create a new one with an added UNIQUE +-- constraint DROP INDEX IF EXISTS event_search_ev_idx; """ diff --git a/synapse/storage/schema/delta/38/postgres_fts_gist.sql b/synapse/storage/schema/delta/38/postgres_fts_gist.sql index f090a7b75a..515e6b8e84 100644 --- a/synapse/storage/schema/delta/38/postgres_fts_gist.sql +++ b/synapse/storage/schema/delta/38/postgres_fts_gist.sql @@ -13,5 +13,7 @@ * limitations under the License. 
*/ - INSERT into background_updates (update_name, progress_json) - VALUES ('event_search_postgres_gist', '{}'); +-- We no longer do this given we back it out again in schema 47 + +-- INSERT into background_updates (update_name, progress_json) +-- VALUES ('event_search_postgres_gist', '{}'); diff --git a/synapse/storage/schema/delta/40/event_push_summary.sql b/synapse/storage/schema/delta/40/event_push_summary.sql new file mode 100644 index 0000000000..3918f0b794 --- /dev/null +++ b/synapse/storage/schema/delta/40/event_push_summary.sql @@ -0,0 +1,37 @@ +/* Copyright 2017 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Aggregate of old notification counts that have been deleted out of the +-- main event_push_actions table. This count does not include those that were +-- highlights, as they remain in the event_push_actions table. +CREATE TABLE event_push_summary ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + notif_count BIGINT NOT NULL, + stream_ordering BIGINT NOT NULL +); + +CREATE INDEX event_push_summary_user_rm ON event_push_summary(user_id, room_id); + + +-- The stream ordering up to which we have aggregated the event_push_actions +-- table into event_push_summary +CREATE TABLE event_push_summary_stream_ordering ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. 
+ stream_ordering BIGINT NOT NULL, + CHECK (Lock='X') +); + +INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0); diff --git a/synapse/storage/schema/delta/40/pushers.sql b/synapse/storage/schema/delta/40/pushers.sql new file mode 100644 index 0000000000..054a223f14 --- /dev/null +++ b/synapse/storage/schema/delta/40/pushers.sql @@ -0,0 +1,39 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS pushers2 ( + id BIGINT PRIMARY KEY, + user_name TEXT NOT NULL, + access_token BIGINT DEFAULT NULL, + profile_tag TEXT NOT NULL, + kind TEXT NOT NULL, + app_id TEXT NOT NULL, + app_display_name TEXT NOT NULL, + device_display_name TEXT NOT NULL, + pushkey TEXT NOT NULL, + ts BIGINT NOT NULL, + lang TEXT, + data TEXT, + last_stream_ordering INTEGER, + last_success BIGINT, + failing_since BIGINT, + UNIQUE (app_id, pushkey, user_name) +); + +INSERT INTO pushers2 SELECT * FROM PUSHERS; + +DROP TABLE PUSHERS; + +ALTER TABLE pushers2 RENAME TO pushers; diff --git a/synapse/storage/schema/delta/23/refresh_tokens.sql b/synapse/storage/schema/delta/41/device_list_stream_idx.sql index 34db0cf12b..b7bee8b692 100644 --- a/synapse/storage/schema/delta/23/refresh_tokens.sql +++ b/synapse/storage/schema/delta/41/device_list_stream_idx.sql @@ -1,4 +1,4 @@ -/* Copyright 2015, 2016 OpenMarket Ltd +/* Copyright 2017 Vector Creations Ltd * * Licensed under the Apache License, Version 2.0 (the 
"License"); * you may not use this file except in compliance with the License. @@ -13,9 +13,5 @@ * limitations under the License. */ -CREATE TABLE IF NOT EXISTS refresh_tokens( - id INTEGER PRIMARY KEY, - token TEXT NOT NULL, - user_id TEXT NOT NULL, - UNIQUE (token) -); +INSERT into background_updates (update_name, progress_json) + VALUES ('device_lists_stream_idx', '{}'); diff --git a/synapse/storage/schema/delta/41/device_outbound_index.sql b/synapse/storage/schema/delta/41/device_outbound_index.sql new file mode 100644 index 0000000000..62f0b9892b --- /dev/null +++ b/synapse/storage/schema/delta/41/device_outbound_index.sql @@ -0,0 +1,16 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE INDEX device_lists_outbound_pokes_stream ON device_lists_outbound_pokes(stream_id); diff --git a/synapse/storage/schema/delta/41/event_search_event_id_idx.sql b/synapse/storage/schema/delta/41/event_search_event_id_idx.sql new file mode 100644 index 0000000000..5d9cfecf36 --- /dev/null +++ b/synapse/storage/schema/delta/41/event_search_event_id_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('event_search_event_id_idx', '{}'); diff --git a/synapse/storage/schema/delta/41/ratelimit.sql b/synapse/storage/schema/delta/41/ratelimit.sql new file mode 100644 index 0000000000..a194bf0238 --- /dev/null +++ b/synapse/storage/schema/delta/41/ratelimit.sql @@ -0,0 +1,22 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +CREATE TABLE ratelimit_override ( + user_id TEXT NOT NULL, + messages_per_second BIGINT, + burst_count BIGINT +); + +CREATE UNIQUE INDEX ratelimit_override_idx ON ratelimit_override(user_id); diff --git a/synapse/storage/schema/delta/42/current_state_delta.sql b/synapse/storage/schema/delta/42/current_state_delta.sql new file mode 100644 index 0000000000..d28851aff8 --- /dev/null +++ b/synapse/storage/schema/delta/42/current_state_delta.sql @@ -0,0 +1,26 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE current_state_delta_stream ( + stream_id BIGINT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_id TEXT, -- Is null if the key was removed + prev_event_id TEXT -- Is null if the key was added +); + +CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream(stream_id); diff --git a/synapse/storage/schema/delta/42/device_list_last_id.sql b/synapse/storage/schema/delta/42/device_list_last_id.sql new file mode 100644 index 0000000000..9ab8c14fa3 --- /dev/null +++ b/synapse/storage/schema/delta/42/device_list_last_id.sql @@ -0,0 +1,33 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- Table of last stream_id that we sent to destination for user_id. This is +-- used to fill out the `prev_id` fields of outbound device list updates. +CREATE TABLE device_lists_outbound_last_success ( + destination TEXT NOT NULL, + user_id TEXT NOT NULL, + stream_id BIGINT NOT NULL +); + +INSERT INTO device_lists_outbound_last_success + SELECT destination, user_id, coalesce(max(stream_id), 0) as stream_id + FROM device_lists_outbound_pokes + WHERE sent = (1 = 1) -- sqlite doesn't have inbuilt boolean values + GROUP BY destination, user_id; + +CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_last_success( + destination, user_id, stream_id +); diff --git a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql b/synapse/storage/schema/delta/42/event_auth_state_only.sql index bb225dafbf..b8821ac759 100644 --- a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql +++ b/synapse/storage/schema/delta/42/event_auth_state_only.sql @@ -1,4 +1,4 @@ -/* Copyright 2016 OpenMarket Ltd +/* Copyright 2017 Vector Creations Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -14,4 +14,4 @@ */ INSERT INTO background_updates (update_name, progress_json) VALUES - ('refresh_tokens_device_index', '{}'); + ('event_auth_state_only', '{}'); diff --git a/synapse/storage/schema/delta/42/user_dir.py b/synapse/storage/schema/delta/42/user_dir.py new file mode 100644 index 0000000000..ea6a18196d --- /dev/null +++ b/synapse/storage/schema/delta/42/user_dir.py @@ -0,0 +1,84 @@ +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage.prepare_database import get_statements +from synapse.storage.engines import PostgresEngine, Sqlite3Engine + +logger = logging.getLogger(__name__) + + +BOTH_TABLES = """ +CREATE TABLE user_directory_stream_pos ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. 
+ stream_id BIGINT, + CHECK (Lock='X') +); + +INSERT INTO user_directory_stream_pos (stream_id) VALUES (null); + +CREATE TABLE user_directory ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, -- A room_id that we know the user is joined to + display_name TEXT, + avatar_url TEXT +); + +CREATE INDEX user_directory_room_idx ON user_directory(room_id); +CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id); + +CREATE TABLE users_in_pubic_room ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL -- A room_id that we know is public +); + +CREATE INDEX users_in_pubic_room_room_idx ON users_in_pubic_room(room_id); +CREATE UNIQUE INDEX users_in_pubic_room_user_idx ON users_in_pubic_room(user_id); +""" + + +POSTGRES_TABLE = """ +CREATE TABLE user_directory_search ( + user_id TEXT NOT NULL, + vector tsvector +); + +CREATE INDEX user_directory_search_fts_idx ON user_directory_search USING gin(vector); +CREATE UNIQUE INDEX user_directory_search_user_idx ON user_directory_search(user_id); +""" + + +SQLITE_TABLE = """ +CREATE VIRTUAL TABLE user_directory_search + USING fts4 ( user_id, value ); +""" + + +def run_create(cur, database_engine, *args, **kwargs): + for statement in get_statements(BOTH_TABLES.splitlines()): + cur.execute(statement) + + if isinstance(database_engine, PostgresEngine): + for statement in get_statements(POSTGRES_TABLE.splitlines()): + cur.execute(statement) + elif isinstance(database_engine, Sqlite3Engine): + for statement in get_statements(SQLITE_TABLE.splitlines()): + cur.execute(statement) + else: + raise Exception("Unrecognized database engine") + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/43/blocked_rooms.sql b/synapse/storage/schema/delta/43/blocked_rooms.sql new file mode 100644 index 0000000000..0e3cd143ff --- /dev/null +++ b/synapse/storage/schema/delta/43/blocked_rooms.sql @@ -0,0 +1,21 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the 
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE blocked_rooms ( + room_id TEXT NOT NULL, + user_id TEXT NOT NULL -- Admin who blocked the room +); + +CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms(room_id); diff --git a/synapse/storage/schema/delta/43/quarantine_media.sql b/synapse/storage/schema/delta/43/quarantine_media.sql new file mode 100644 index 0000000000..630907ec4f --- /dev/null +++ b/synapse/storage/schema/delta/43/quarantine_media.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +ALTER TABLE local_media_repository ADD COLUMN quarantined_by TEXT; +ALTER TABLE remote_media_cache ADD COLUMN quarantined_by TEXT; diff --git a/synapse/storage/schema/delta/33/refreshtoken_device.sql b/synapse/storage/schema/delta/43/url_cache.sql index 290bd6da86..45ebe020da 100644 --- a/synapse/storage/schema/delta/33/refreshtoken_device.sql +++ b/synapse/storage/schema/delta/43/url_cache.sql @@ -1,4 +1,4 @@ -/* Copyright 2016 OpenMarket Ltd +/* Copyright 2017 Vector Creations Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,4 +13,4 @@ * limitations under the License. */ -ALTER TABLE refresh_tokens ADD COLUMN device_id TEXT; +ALTER TABLE local_media_repository ADD COLUMN url_cache TEXT; diff --git a/synapse/storage/schema/delta/43/user_share.sql b/synapse/storage/schema/delta/43/user_share.sql new file mode 100644 index 0000000000..ee7062abe4 --- /dev/null +++ b/synapse/storage/schema/delta/43/user_share.sql @@ -0,0 +1,33 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Table keeping track of who shares a room with who. 
We only keep track +-- of this for local users, so `user_id` is local users only (but we do keep track +-- of which remote users share a room) +CREATE TABLE users_who_share_rooms ( + user_id TEXT NOT NULL, + other_user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + share_private BOOLEAN NOT NULL -- is the shared room private? i.e. they share a private room +); + + +CREATE UNIQUE INDEX users_who_share_rooms_u_idx ON users_who_share_rooms(user_id, other_user_id); +CREATE INDEX users_who_share_rooms_r_idx ON users_who_share_rooms(room_id); +CREATE INDEX users_who_share_rooms_o_idx ON users_who_share_rooms(other_user_id); + + +-- Make sure that we populate the table initially +UPDATE user_directory_stream_pos SET stream_id = NULL; diff --git a/synapse/storage/schema/delta/44/expire_url_cache.sql b/synapse/storage/schema/delta/44/expire_url_cache.sql new file mode 100644 index 0000000000..b12f9b2ebf --- /dev/null +++ b/synapse/storage/schema/delta/44/expire_url_cache.sql @@ -0,0 +1,41 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- this didn't work on SQLite 3.7 (because of lack of partial indexes), so was +-- removed and replaced with 46/local_media_repository_url_idx.sql. +-- +-- CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL; + +-- we need to change `expires` to `expires_ts` so that we can index on it. 
SQLite doesn't support +-- indices on expressions until 3.9. +CREATE TABLE local_media_repository_url_cache_new( + url TEXT, + response_code INTEGER, + etag TEXT, + expires_ts BIGINT, + og TEXT, + media_id TEXT, + download_ts BIGINT +); + +INSERT INTO local_media_repository_url_cache_new + SELECT url, response_code, etag, expires + download_ts, og, media_id, download_ts FROM local_media_repository_url_cache; + +DROP TABLE local_media_repository_url_cache; +ALTER TABLE local_media_repository_url_cache_new RENAME TO local_media_repository_url_cache; + +CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts); +CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts); +CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id); diff --git a/synapse/storage/schema/delta/45/group_server.sql b/synapse/storage/schema/delta/45/group_server.sql new file mode 100644 index 0000000000..b2333848a0 --- /dev/null +++ b/synapse/storage/schema/delta/45/group_server.sql @@ -0,0 +1,167 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +CREATE TABLE groups ( + group_id TEXT NOT NULL, + name TEXT, -- the display name of the room + avatar_url TEXT, + short_description TEXT, + long_description TEXT +); + +CREATE UNIQUE INDEX groups_idx ON groups(group_id); + + +-- list of users the group server thinks are joined +CREATE TABLE group_users ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + is_admin BOOLEAN NOT NULL, + is_public BOOLEAN NOT NULL -- whether the user's membership can be seen by everyone +); + + +CREATE INDEX groups_users_g_idx ON group_users(group_id, user_id); +CREATE INDEX groups_users_u_idx ON group_users(user_id); + +-- list of users the group server thinks are invited +CREATE TABLE group_invites ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL +); + +CREATE INDEX groups_invites_g_idx ON group_invites(group_id, user_id); +CREATE INDEX groups_invites_u_idx ON group_invites(user_id); + + +CREATE TABLE group_rooms ( + group_id TEXT NOT NULL, + room_id TEXT NOT NULL, + is_public BOOLEAN NOT NULL -- whether the room can be seen by everyone +); + +CREATE UNIQUE INDEX groups_rooms_g_idx ON group_rooms(group_id, room_id); +CREATE INDEX groups_rooms_r_idx ON group_rooms(room_id); + + +-- Rooms to include in the summary +CREATE TABLE group_summary_rooms ( + group_id TEXT NOT NULL, + room_id TEXT NOT NULL, + category_id TEXT NOT NULL, + room_order BIGINT NOT NULL, + is_public BOOLEAN NOT NULL, -- whether the room should be shown to everyone + UNIQUE (group_id, category_id, room_id, room_order), + CHECK (room_order > 0) +); + +CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms(group_id, room_id, category_id); + + +-- Categories to include in the summary +CREATE TABLE group_summary_room_categories ( + group_id TEXT NOT NULL, + category_id TEXT NOT NULL, + cat_order BIGINT NOT NULL, + UNIQUE (group_id, category_id, cat_order), + CHECK (cat_order > 0) +); + +-- The categories in the group +CREATE TABLE group_room_categories ( + group_id TEXT NOT NULL, + category_id 
TEXT NOT NULL, + profile TEXT NOT NULL, + is_public BOOLEAN NOT NULL, -- whether the category should be shown to everyone + UNIQUE (group_id, category_id) +); + +-- The users to include in the group summary +CREATE TABLE group_summary_users ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + role_id TEXT NOT NULL, + user_order BIGINT NOT NULL, + is_public BOOLEAN NOT NULL -- whether the user should be shown to everyone +); + +CREATE INDEX group_summary_users_g_idx ON group_summary_users(group_id); + +-- The roles to include in the group summary +CREATE TABLE group_summary_roles ( + group_id TEXT NOT NULL, + role_id TEXT NOT NULL, + role_order BIGINT NOT NULL, + UNIQUE (group_id, role_id, role_order), + CHECK (role_order > 0) +); + + +-- The roles in a group +CREATE TABLE group_roles ( + group_id TEXT NOT NULL, + role_id TEXT NOT NULL, + profile TEXT NOT NULL, + is_public BOOLEAN NOT NULL, -- whether the role should be shown to everyone + UNIQUE (group_id, role_id) +); + + +-- List of attestations we've given out and need to renew +CREATE TABLE group_attestations_renewals ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + valid_until_ms BIGINT NOT NULL +); + +CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals(group_id, user_id); +CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals(user_id); +CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals(valid_until_ms); + + +-- List of attestations we've received from remotes and are interested in. 
+CREATE TABLE group_attestations_remote ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + valid_until_ms BIGINT NOT NULL, + attestation_json TEXT NOT NULL +); + +CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote(group_id, user_id); +CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote(user_id); +CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote(valid_until_ms); + + +-- The group membership for the HS's users +CREATE TABLE local_group_membership ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + is_admin BOOLEAN NOT NULL, + membership TEXT NOT NULL, + is_publicised BOOLEAN NOT NULL, -- if the user is publicising their membership + content TEXT NOT NULL +); + +CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id); +CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id); + + +CREATE TABLE local_group_updates ( + stream_id BIGINT NOT NULL, + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + type TEXT NOT NULL, + content TEXT NOT NULL +); diff --git a/synapse/storage/schema/delta/45/profile_cache.sql b/synapse/storage/schema/delta/45/profile_cache.sql new file mode 100644 index 0000000000..e5ddc84df0 --- /dev/null +++ b/synapse/storage/schema/delta/45/profile_cache.sql @@ -0,0 +1,28 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +-- A subset of remote users whose profiles we have cached. +-- Whether a user is in this table or not is defined by the storage function +-- `is_subscribed_remote_profile_for_user` +CREATE TABLE remote_profile_cache ( + user_id TEXT NOT NULL, + displayname TEXT, + avatar_url TEXT, + last_check BIGINT NOT NULL +); + +CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache(user_id); +CREATE INDEX remote_profile_cache_time ON remote_profile_cache(last_check); diff --git a/synapse/storage/schema/delta/46/drop_refresh_tokens.sql b/synapse/storage/schema/delta/46/drop_refresh_tokens.sql new file mode 100644 index 0000000000..68c48a89a9 --- /dev/null +++ b/synapse/storage/schema/delta/46/drop_refresh_tokens.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* we no longer use (or create) the refresh_tokens table */ +DROP TABLE IF EXISTS refresh_tokens; diff --git a/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql b/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql new file mode 100644 index 0000000000..bb307889c1 --- /dev/null +++ b/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql @@ -0,0 +1,35 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- drop the unique constraint on deleted_pushers so that we can just insert +-- into it rather than upserting. + +CREATE TABLE deleted_pushers2 ( + stream_id BIGINT NOT NULL, + app_id TEXT NOT NULL, + pushkey TEXT NOT NULL, + user_id TEXT NOT NULL +); + +INSERT INTO deleted_pushers2 (stream_id, app_id, pushkey, user_id) + SELECT stream_id, app_id, pushkey, user_id from deleted_pushers; + +DROP TABLE deleted_pushers; +ALTER TABLE deleted_pushers2 RENAME TO deleted_pushers; + +-- create the index after doing the inserts because that's more efficient. +-- it also means we can give it the same name as the old one without renaming. +CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id); + diff --git a/synapse/storage/schema/delta/46/group_server.sql b/synapse/storage/schema/delta/46/group_server.sql new file mode 100644 index 0000000000..097679bc9a --- /dev/null +++ b/synapse/storage/schema/delta/46/group_server.sql @@ -0,0 +1,32 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +CREATE TABLE groups_new ( + group_id TEXT NOT NULL, + name TEXT, -- the display name of the room + avatar_url TEXT, + short_description TEXT, + long_description TEXT, + is_public BOOL NOT NULL -- whether non-members can access group APIs +); + +-- NB: awful hack to get the default to be true on postgres and 1 on sqlite +INSERT INTO groups_new + SELECT group_id, name, avatar_url, short_description, long_description, (1=1) FROM groups; + +DROP TABLE groups; +ALTER TABLE groups_new RENAME TO groups; + +CREATE UNIQUE INDEX groups_idx ON groups(group_id); diff --git a/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql b/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql new file mode 100644 index 0000000000..bbfc7f5d1a --- /dev/null +++ b/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql @@ -0,0 +1,24 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- register a background update which will recreate the +-- local_media_repository_url_idx index. +-- +-- We do this as a bg update not because it is a particularly onerous +-- operation, but because we'd like it to be a partial index if possible, and +-- the background_index_update code will understand whether we are on +-- postgres or sqlite and behave accordingly. 
+INSERT INTO background_updates (update_name, progress_json) VALUES + ('local_media_repository_url_idx', '{}'); diff --git a/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql new file mode 100644 index 0000000000..cb0d5a2576 --- /dev/null +++ b/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql @@ -0,0 +1,35 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- change the user_directory table to also cover global local user profiles +-- rather than just profiles within specific rooms. + +CREATE TABLE user_directory2 ( + user_id TEXT NOT NULL, + room_id TEXT, + display_name TEXT, + avatar_url TEXT +); + +INSERT INTO user_directory2(user_id, room_id, display_name, avatar_url) + SELECT user_id, room_id, display_name, avatar_url from user_directory; + +DROP TABLE user_directory; +ALTER TABLE user_directory2 RENAME TO user_directory; + +-- create indexes after doing the inserts because that's more efficient. +-- it also means we can give it the same name as the old one without renaming. 
+CREATE INDEX user_directory_room_idx ON user_directory(room_id); +CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id); diff --git a/synapse/storage/schema/delta/46/user_dir_typos.sql b/synapse/storage/schema/delta/46/user_dir_typos.sql new file mode 100644 index 0000000000..d9505f8da1 --- /dev/null +++ b/synapse/storage/schema/delta/46/user_dir_typos.sql @@ -0,0 +1,24 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- this is just embarrassing :| +ALTER TABLE users_in_pubic_room RENAME TO users_in_public_rooms; + +-- this is only 300K rows on matrix.org and takes ~3s to generate the index, +-- so is hopefully not going to block anyone else for that long... +CREATE INDEX users_in_public_rooms_room_idx ON users_in_public_rooms(room_id); +CREATE UNIQUE INDEX users_in_public_rooms_user_idx ON users_in_public_rooms(user_id); +DROP INDEX users_in_pubic_room_room_idx; +DROP INDEX users_in_pubic_room_user_idx; diff --git a/synapse/storage/schema/delta/47/last_access_media.sql b/synapse/storage/schema/delta/47/last_access_media.sql new file mode 100644 index 0000000000..f505fb22b5 --- /dev/null +++ b/synapse/storage/schema/delta/47/last_access_media.sql @@ -0,0 +1,16 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE local_media_repository ADD COLUMN last_access_ts BIGINT; diff --git a/synapse/storage/schema/delta/47/postgres_fts_gin.sql b/synapse/storage/schema/delta/47/postgres_fts_gin.sql new file mode 100644 index 0000000000..31d7a817eb --- /dev/null +++ b/synapse/storage/schema/delta/47/postgres_fts_gin.sql @@ -0,0 +1,17 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('event_search_postgres_gin', '{}'); diff --git a/synapse/storage/schema/delta/47/push_actions_staging.sql b/synapse/storage/schema/delta/47/push_actions_staging.sql new file mode 100644 index 0000000000..edccf4a96f --- /dev/null +++ b/synapse/storage/schema/delta/47/push_actions_staging.sql @@ -0,0 +1,28 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Temporary staging area for push actions that have been calculated for an +-- event, but the event hasn't yet been persisted. +-- When the event is persisted the rows are moved over to the +-- event_push_actions table. +CREATE TABLE event_push_actions_staging ( + event_id TEXT NOT NULL, + user_id TEXT NOT NULL, + actions TEXT NOT NULL, + notif SMALLINT NOT NULL, + highlight SMALLINT NOT NULL +); + +CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging(event_id); diff --git a/synapse/storage/schema/delta/47/state_group_seq.py b/synapse/storage/schema/delta/47/state_group_seq.py new file mode 100644 index 0000000000..f6766501d2 --- /dev/null +++ b/synapse/storage/schema/delta/47/state_group_seq.py @@ -0,0 +1,37 @@ +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from synapse.storage.engines import PostgresEngine + + +def run_create(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + # if we already have some state groups, we want to start making new + # ones with a higher id. + cur.execute("SELECT max(id) FROM state_groups") + row = cur.fetchone() + + if row[0] is None: + start_val = 1 + else: + start_val = row[0] + 1 + + cur.execute( + "CREATE SEQUENCE state_group_id_seq START WITH %s", + (start_val, ), + ) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql b/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql new file mode 100644 index 0000000000..9248b0b24a --- /dev/null +++ b/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('user_ips_last_seen_index', '{}'); diff --git a/synapse/storage/schema/delta/48/group_unique_indexes.py b/synapse/storage/schema/delta/48/group_unique_indexes.py new file mode 100644 index 0000000000..2233af87d7 --- /dev/null +++ b/synapse/storage/schema/delta/48/group_unique_indexes.py @@ -0,0 +1,57 @@ +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.engines import PostgresEngine +from synapse.storage.prepare_database import get_statements + +FIX_INDEXES = """ +-- rebuild indexes as uniques +DROP INDEX groups_invites_g_idx; +CREATE UNIQUE INDEX group_invites_g_idx ON group_invites(group_id, user_id); +DROP INDEX groups_users_g_idx; +CREATE UNIQUE INDEX group_users_g_idx ON group_users(group_id, user_id); + +-- rename other indexes to actually match their table names.. +DROP INDEX groups_users_u_idx; +CREATE INDEX group_users_u_idx ON group_users(user_id); +DROP INDEX groups_invites_u_idx; +CREATE INDEX group_invites_u_idx ON group_invites(user_id); +DROP INDEX groups_rooms_g_idx; +CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id); +DROP INDEX groups_rooms_r_idx; +CREATE INDEX group_rooms_r_idx ON group_rooms(room_id); +""" + + +def run_create(cur, database_engine, *args, **kwargs): + rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid" + + # remove duplicates from group_users & group_invites tables + cur.execute(""" + DELETE FROM group_users WHERE %s NOT IN ( + SELECT min(%s) FROM group_users GROUP BY group_id, user_id + ); + """ % (rowid, rowid)) + cur.execute(""" + DELETE FROM group_invites WHERE %s NOT IN ( + SELECT min(%s) FROM group_invites GROUP BY group_id, user_id + ); + """ % (rowid, rowid)) + + for statement in get_statements(FIX_INDEXES.splitlines()): + cur.execute(statement) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/48/groups_joinable.sql 
b/synapse/storage/schema/delta/48/groups_joinable.sql new file mode 100644 index 0000000000..ce26eaf0c9 --- /dev/null +++ b/synapse/storage/schema/delta/48/groups_joinable.sql @@ -0,0 +1,22 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This isn't a real ENUM because sqlite doesn't support it. + * The column is NOT NULL with a default of 'invite', so both + * existing rows and inserts that omit join_policy are given + * the 'invite' policy. + */ +ALTER TABLE groups ADD COLUMN join_policy TEXT NOT NULL DEFAULT 'invite'; diff --git a/synapse/storage/schema/schema_version.sql b/synapse/storage/schema/schema_version.sql index a7ade69986..42e5cb6df5 100644 --- a/synapse/storage/schema/schema_version.sql +++ b/synapse/storage/schema/schema_version.sql @@ -25,3 +25,10 @@ CREATE TABLE IF NOT EXISTS applied_schema_deltas( file TEXT NOT NULL, UNIQUE(version, file) ); + +-- a list of schema files we have loaded on behalf of dynamic modules +CREATE TABLE IF NOT EXISTS applied_module_schemas( + module_name TEXT NOT NULL, + file TEXT NOT NULL, + UNIQUE(module_name, file) +); diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 8f2b3c4435..6ba3e59889 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -13,28 +13,35 @@ # See the License for the specific language governing permissions and # limitations under the License.
+from collections import namedtuple +import logging +import re +import simplejson as json + from twisted.internet import defer from .background_updates import BackgroundUpdateStore from synapse.api.errors import SynapseError from synapse.storage.engines import PostgresEngine, Sqlite3Engine -import logging -import re -import ujson as json - logger = logging.getLogger(__name__) +SearchEntry = namedtuple('SearchEntry', [ + 'key', 'value', 'event_id', 'room_id', 'stream_ordering', + 'origin_server_ts', +]) + class SearchStore(BackgroundUpdateStore): EVENT_SEARCH_UPDATE_NAME = "event_search" EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" + EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" - def __init__(self, hs): - super(SearchStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(SearchStore, self).__init__(db_conn, hs) self.register_background_update_handler( self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search ) @@ -42,23 +49,35 @@ class SearchStore(BackgroundUpdateStore): self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order ) - self.register_background_update_handler( + + # we used to have a background update to turn the GIN index into a + # GIST one; we no longer do that (obviously) because we actually want + # a GIN index. However, it's possible that some people might still have + # the background update queued, so we register a handler to clear the + # background update. 
+ self.register_noop_background_update( self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME, - self._background_reindex_gist_search + ) + + self.register_background_update_handler( + self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, + self._background_reindex_gin_search ) @defer.inlineCallbacks def _background_reindex_search(self, progress, batch_size): + # we work through the events table from highest stream id to lowest target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) - INSERT_CLUMP_SIZE = 1000 TYPES = ["m.room.name", "m.room.message", "m.room.topic"] def reindex_search_txn(txn): sql = ( - "SELECT stream_ordering, event_id, room_id, type, content FROM events" + "SELECT stream_ordering, event_id, room_id, type, json, " + " origin_server_ts FROM events" + " JOIN event_json USING (room_id, event_id)" " WHERE ? <= stream_ordering AND stream_ordering < ?" " AND (%s)" " ORDER BY stream_ordering DESC" @@ -67,6 +86,10 @@ class SearchStore(BackgroundUpdateStore): txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) + # we could stream straight from the results into + # store_search_entries_txn with a generator function, but that + # would mean having two cursors open on the database at once. + # Instead we just build a list of results. 
rows = self.cursor_to_dict(txn) if not rows: return 0 @@ -79,9 +102,12 @@ class SearchStore(BackgroundUpdateStore): event_id = row["event_id"] room_id = row["room_id"] etype = row["type"] + stream_ordering = row["stream_ordering"] + origin_server_ts = row["origin_server_ts"] try: - content = json.loads(row["content"]) - except: + event_json = json.loads(row["json"]) + content = event_json["content"] + except Exception: continue if etype == "m.room.message": @@ -93,6 +119,8 @@ class SearchStore(BackgroundUpdateStore): elif etype == "m.room.name": key = "content.name" value = content["name"] + else: + raise Exception("unexpected event type %s" % etype) except (KeyError, AttributeError): # If the event is missing a necessary field then # skip over it. @@ -103,25 +131,16 @@ class SearchStore(BackgroundUpdateStore): # then skip over it continue - event_search_rows.append((event_id, room_id, key, value)) + event_search_rows.append(SearchEntry( + key=key, + value=value, + event_id=event_id, + room_id=room_id, + stream_ordering=stream_ordering, + origin_server_ts=origin_server_ts, + )) - if isinstance(self.database_engine, PostgresEngine): - sql = ( - "INSERT INTO event_search (event_id, room_id, key, vector)" - " VALUES (?,?,?,to_tsvector('english', ?))" - ) - elif isinstance(self.database_engine, Sqlite3Engine): - sql = ( - "INSERT INTO event_search (event_id, room_id, key, value)" - " VALUES (?,?,?,?)" - ) - else: - # This should be unreachable. 
- raise Exception("Unrecognized database engine") - - for index in range(0, len(event_search_rows), INSERT_CLUMP_SIZE): - clump = event_search_rows[index:index + INSERT_CLUMP_SIZE] - txn.executemany(sql, clump) + self.store_search_entries_txn(txn, event_search_rows) progress = { "target_min_stream_id_inclusive": target_min_stream_id, @@ -145,25 +164,48 @@ class SearchStore(BackgroundUpdateStore): defer.returnValue(result) @defer.inlineCallbacks - def _background_reindex_gist_search(self, progress, batch_size): + def _background_reindex_gin_search(self, progress, batch_size): + """This handles old synapses which used GIST indexes, if any; + converting them back to be GIN as per the actual schema. + """ + def create_index(conn): conn.rollback() - conn.set_session(autocommit=True) - c = conn.cursor() - c.execute( - "CREATE INDEX CONCURRENTLY event_search_fts_idx_gist" - " ON event_search USING GIST (vector)" - ) + # we have to set autocommit, because postgres refuses to + # CREATE INDEX CONCURRENTLY without it. + conn.set_session(autocommit=True) - c.execute("DROP INDEX event_search_fts_idx") + try: + c = conn.cursor() - conn.set_session(autocommit=False) + # if we skipped the conversion to GIST, we may already/still + # have an event_search_fts_idx; unfortunately postgres 9.4 + # doesn't support CREATE INDEX IF NOT EXISTS so we just catch the + # exception and ignore it. + import psycopg2 + try: + c.execute( + "CREATE INDEX CONCURRENTLY event_search_fts_idx" + " ON event_search USING GIN (vector)" + ) + except psycopg2.ProgrammingError as e: + logger.warn( + "Ignoring error %r when trying to switch from GIST to GIN", + e + ) + + # we should now be able to delete the GIST index.
+ c.execute( + "DROP INDEX IF EXISTS event_search_fts_idx_gist" + ) + finally: + conn.set_session(autocommit=False) if isinstance(self.database_engine, PostgresEngine): yield self.runWithConnection(create_index) - yield self._end_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME) + yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME) defer.returnValue(1) @defer.inlineCallbacks @@ -242,6 +284,85 @@ class SearchStore(BackgroundUpdateStore): defer.returnValue(num_rows) + def store_event_search_txn(self, txn, event, key, value): + """Add event to the search table + + Args: + txn (cursor): + event (EventBase): + key (str): + value (str): + """ + self.store_search_entries_txn( + txn, + (SearchEntry( + key=key, + value=value, + event_id=event.event_id, + room_id=event.room_id, + stream_ordering=event.internal_metadata.stream_ordering, + origin_server_ts=event.origin_server_ts, + ),), + ) + + def store_search_entries_txn(self, txn, entries): + """Add entries to the search table + + Args: + txn (cursor): + entries (iterable[SearchEntry]): + entries to be added to the table + """ + if isinstance(self.database_engine, PostgresEngine): + sql = ( + "INSERT INTO event_search" + " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" + " VALUES (?,?,?,to_tsvector('english', ?),?,?)" + ) + + args = (( + entry.event_id, entry.room_id, entry.key, entry.value, + entry.stream_ordering, entry.origin_server_ts, + ) for entry in entries) + + # inserts to a GIN index are normally batched up into a pending + # list, and then all committed together once the list gets to a + # certain size. The trouble with that is that postgres (pre-9.5) + # uses work_mem to determine the length of the list, and work_mem + # is typically very large. + # + # We therefore reduce work_mem while we do the insert. 
+ # + # (postgres 9.5 uses the separate gin_pending_list_limit setting, + # so doesn't suffer the same problem, but changing work_mem will + # be harmless) + # + # Note that we don't need to worry about restoring it on + # exception, because exceptions will cause the transaction to be + # rolled back, including the effects of the SET command. + # + # Also: we use SET rather than SET LOCAL because there's lots of + # other stuff going on in this transaction, which want to have the + # normal work_mem setting. + + txn.execute("SET work_mem='256kB'") + txn.executemany(sql, args) + txn.execute("RESET work_mem") + + elif isinstance(self.database_engine, Sqlite3Engine): + sql = ( + "INSERT INTO event_search (event_id, room_id, key, value)" + " VALUES (?,?,?,?)" + ) + args = (( + entry.event_id, entry.room_id, entry.key, entry.value, + ) for entry in entries) + + txn.executemany(sql, args) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + @defer.inlineCallbacks def search_msgs(self, room_ids, search_term, keys): """Performs a full text search over events with given keys. 
@@ -407,7 +528,7 @@ class SearchStore(BackgroundUpdateStore): origin_server_ts, stream = pagination_token.split(",") origin_server_ts = int(origin_server_ts) stream = int(stream) - except: + except Exception: raise SynapseError(400, "Invalid pagination token") clauses.append( diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index e1dca927d7..9e6eaaa532 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -22,12 +22,12 @@ from synapse.crypto.event_signing import compute_event_reference_hash from synapse.util.caches.descriptors import cached, cachedList -class SignatureStore(SQLBaseStore): - """Persistence for event signatures and hashes""" - +class SignatureWorkerStore(SQLBaseStore): @cached() def get_event_reference_hash(self, event_id): - return self._get_event_reference_hashes_txn(event_id) + # This is a dummy function to allow get_event_reference_hashes + # to use its cache + raise NotImplementedError() @cachedList(cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1) @@ -72,7 +72,11 @@ class SignatureStore(SQLBaseStore): " WHERE event_id = ?" ) txn.execute(query, (event_id, )) - return {k: v for k, v in txn.fetchall()} + return {k: v for k, v in txn} + + +class SignatureStore(SignatureWorkerStore): + """Persistence for event signatures and hashes""" def _store_event_reference_hashes_txn(self, txn, events): """Store a hash for a PDU diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 84482d8285..ffa4246031 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -13,14 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.caches import intern_string -from synapse.storage.engines import PostgresEngine +from collections import namedtuple +import logging from twisted.internet import defer -import logging +from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage.engines import PostgresEngine +from synapse.util.caches import intern_string, CACHE_SIZE_FACTOR +from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.caches.dictionary_cache import DictionaryCache +from synapse.util.stringutils import to_ascii +from ._base import SQLBaseStore logger = logging.getLogger(__name__) @@ -28,45 +32,97 @@ logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 -class StateStore(SQLBaseStore): - """ Keeps track of the state at a given event. +class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))): + """Return type of get_state_group_delta that implements __len__, which lets + us use the iterable flag when caching + """ + __slots__ = [] - This is done by the concept of `state groups`. Every event is a assigned - a state group (identified by an arbitrary string), which references a - collection of state events. The current state of an event is then the - collection of state events referenced by the event's state group. + def __len__(self): + return len(self.delta_ids) if self.delta_ids else 0 - Hence, every change in the current state causes a new state group to be - generated. However, if no change happens (e.g., if we get a message event - with only one parent it inherits the state group from its parent.) - There are three tables: - * `state_groups`: Stores group name, first event with in the group and - room id. - * `event_to_state_groups`: Maps events to state groups. - * `state_groups_state`: Maps state group to state events.
+class StateGroupWorkerStore(SQLBaseStore): + """The parts of StateGroupStore that can be called from workers. """ STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - def __init__(self, hs): - super(StateStore, self).__init__(hs) - self.register_background_update_handler( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, - self._background_deduplicate_state, + def __init__(self, db_conn, hs): + super(StateGroupWorkerStore, self).__init__(db_conn, hs) + + self._state_group_cache = DictionaryCache( + "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR ) - self.register_background_update_handler( - self.STATE_GROUP_INDEX_UPDATE_NAME, - self._background_index_state, + + @cached(max_entries=100000, iterable=True) + def get_current_state_ids(self, room_id): + """Get the current state event ids for a room based on the + current_state_events table. + + Args: + room_id (str) + + Returns: + deferred: dict of (type, state_key) -> event_id + """ + def _get_current_state_ids_txn(txn): + txn.execute( + """SELECT type, state_key, event_id FROM current_state_events + WHERE room_id = ? + """, + (room_id,) + ) + + return { + (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn + } + + return self.runInteraction( + "get_current_state_ids", + _get_current_state_ids_txn, ) - self.register_background_index_update( - self.CURRENT_STATE_INDEX_UPDATE_NAME, - index_name="current_state_events_member_index", - table="current_state_events", - columns=["state_key"], - where_clause="type='m.room.member'", + + @cached(max_entries=10000, iterable=True) + def get_state_group_delta(self, state_group): + """Given a state group try to return a previous group and a delta between + the old and the new. + + Returns: + (prev_group, delta_ids), where both may be None. 
+ """ + def _get_state_group_delta_txn(txn): + prev_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={ + "state_group": state_group, + }, + retcol="prev_state_group", + allow_none=True, + ) + + if not prev_group: + return _GetStateGroupDelta(None, None) + + delta_ids = self._simple_select_list_txn( + txn, + table="state_groups_state", + keyvalues={ + "state_group": state_group, + }, + retcols=("type", "state_key", "event_id",) + ) + + return _GetStateGroupDelta(prev_group, { + (row["type"], row["state_key"]): row["event_id"] + for row in delta_ids + }) + return self.runInteraction( + "get_state_group_delta", + _get_state_group_delta_txn, ) @defer.inlineCallbacks @@ -78,12 +134,26 @@ class StateStore(SQLBaseStore): event_ids, ) - groups = set(event_to_groups.values()) + groups = set(event_to_groups.itervalues()) group_to_state = yield self._get_state_for_groups(groups) defer.returnValue(group_to_state) @defer.inlineCallbacks + def get_state_ids_for_group(self, state_group): + """Get the state IDs for the given state group + + Args: + state_group (int) + + Returns: + Deferred[dict]: Resolves to a map of (type, state_key) -> event_id + """ + group_to_state = yield self._get_state_for_groups((state_group,)) + + defer.returnValue(group_to_state[state_group]) + + @defer.inlineCallbacks def get_state_groups(self, room_id, event_ids): """ Get the state groups for the given list of event_ids @@ -96,156 +166,21 @@ class StateStore(SQLBaseStore): state_event_map = yield self.get_events( [ - ev_id for group_ids in group_to_ids.values() - for ev_id in group_ids.values() + ev_id for group_ids in group_to_ids.itervalues() + for ev_id in group_ids.itervalues() ], get_prev_content=False ) defer.returnValue({ group: [ - state_event_map[v] for v in event_id_map.values() if v in state_event_map + state_event_map[v] for v in event_id_map.itervalues() + if v in state_event_map ] - for group, event_id_map in group_to_ids.items() + for group, 
event_id_map in group_to_ids.iteritems() }) - def _have_persisted_state_group_txn(self, txn, state_group): - txn.execute( - "SELECT count(*) FROM state_groups WHERE id = ?", - (state_group,) - ) - row = txn.fetchone() - return row and row[0] - - def _store_mult_state_groups_txn(self, txn, events_and_contexts): - state_groups = {} - for event, context in events_and_contexts: - if event.internal_metadata.is_outlier(): - continue - - if context.current_state_ids is None: - continue - - state_groups[event.event_id] = context.state_group - - if self._have_persisted_state_group_txn(txn, context.state_group): - continue - - self._simple_insert_txn( - txn, - table="state_groups", - values={ - "id": context.state_group, - "room_id": event.room_id, - "event_id": event.event_id, - }, - ) - - # We persist as a delta if we can, while also ensuring the chain - # of deltas isn't tooo long, as otherwise read performance degrades. - if context.prev_group: - potential_hops = self._count_state_group_hops_txn( - txn, context.prev_group - ) - if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - self._simple_insert_txn( - txn, - table="state_group_edges", - values={ - "state_group": context.state_group, - "prev_state_group": context.prev_group, - }, - ) - - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": context.state_group, - "room_id": event.room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in context.delta_ids.items() - ], - ) - else: - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": context.state_group, - "room_id": event.room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in context.current_state_ids.items() - ], - ) - - self._simple_insert_many_txn( - txn, - table="event_to_state_groups", - values=[ - { - "state_group": state_group_id, - "event_id": event_id, - } - for 
event_id, state_group_id in state_groups.items() - ], - ) - - def _count_state_group_hops_txn(self, txn, state_group): - """Given a state group, count how many hops there are in the tree. - - This is used to ensure the delta chains don't get too long. - """ - if isinstance(self.database_engine, PostgresEngine): - sql = (""" - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT count(*) FROM state; - """) - - txn.execute(sql, (state_group,)) - row = txn.fetchone() - if row and row[0]: - return row[0] - else: - return 0 - else: - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - next_group = state_group - count = 0 - - while next_group: - next_group = self._simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - if next_group: - count += 1 - - return count - - @cached(num_args=2, max_entries=100000, iterable=True) - def _get_state_group_from_group(self, group, types): - raise NotImplementedError() - - @cachedList(cached_method_name="_get_state_group_from_group", - list_name="groups", num_args=2, inlineCallbacks=True) + @defer.inlineCallbacks def _get_state_groups_from_groups(self, groups, types): """Returns dictionary state_group -> (dict of (type, state_key) -> event id) """ @@ -305,6 +240,9 @@ class StateStore(SQLBaseStore): ( "AND type = ? 
AND state_key = ?", (etype, state_key) + ) if state_key is not None else ( + "AND type = ?", + (etype,) ) for etype, state_key in types ] @@ -319,15 +257,24 @@ class StateStore(SQLBaseStore): args.extend(where_args) txn.execute(sql % (where_clause,), args) - rows = self.cursor_to_dict(txn) - for row in rows: - key = (row["type"], row["state_key"]) - results[group][key] = row["event_id"] + for row in txn: + typ, state_key, event_id = row + key = (typ, state_key) + results[group][key] = event_id else: + where_args = [] + where_clauses = [] + wildcard_types = False if types is not None: - where_clause = "AND (%s)" % ( - " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), - ) + for typ in types: + if typ[1] is None: + where_clauses.append("(type = ?)") + where_args.append(typ[0]) + wildcard_types = True + else: + where_clauses.append("(type = ? AND state_key = ?)") + where_args.extend([typ[0], typ[1]]) + where_clause = "AND (%s)" % (" OR ".join(where_clauses)) else: where_clause = "" @@ -344,23 +291,30 @@ class StateStore(SQLBaseStore): # after we finish deduping state, which requires this func) args = [next_group] if types: - args.extend(i for typ in types for i in typ) + args.extend(where_args) txn.execute( "SELECT type, state_key, event_id FROM state_groups_state" " WHERE state_group = ? %s" % (where_clause,), args ) - rows = txn.fetchall() - results[group].update({ - (typ, state_key): event_id - for typ, state_key, event_id in rows + results[group].update( + ((typ, state_key), event_id) + for typ, state_key, event_id in txn if (typ, state_key) not in results[group] - }) + ) - # If the lengths match then we must have all the types, - # so no need to go walk further down the tree.
- if types is not None and len(results[group]) == len(types): + # If the number of entries in the (type,state_key)->event_id dict + # matches the number of (type,state_keys) types we were searching + # for, then we must have found them all, so no need to go walk + # further down the tree... UNLESS our types filter contained + # wildcards (i.e. Nones) in which case we have to do an exhaustive + # search + if ( + types is not None and + not wildcard_types and + len(results[group]) == len(types) + ): break next_group = self._simple_select_one_onecol_txn( @@ -393,21 +347,21 @@ class StateStore(SQLBaseStore): event_ids, ) - groups = set(event_to_groups.values()) + groups = set(event_to_groups.itervalues()) group_to_state = yield self._get_state_for_groups(groups, types) state_event_map = yield self.get_events( - [ev_id for sd in group_to_state.values() for ev_id in sd.values()], + [ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()], get_prev_content=False ) event_to_state = { event_id: { k: state_event_map[v] - for k, v in group_to_state[group].items() + for k, v in group_to_state[group].iteritems() if v in state_event_map } - for event_id, group in event_to_groups.items() + for event_id, group in event_to_groups.iteritems() } defer.returnValue({event: event_to_state[event] for event in event_ids}) @@ -430,12 +384,12 @@ class StateStore(SQLBaseStore): event_ids, ) - groups = set(event_to_groups.values()) + groups = set(event_to_groups.itervalues()) group_to_state = yield self._get_state_for_groups(groups, types) event_to_state = { event_id: group_to_state[group] - for event_id, group in event_to_groups.items() + for event_id, group in event_to_groups.iteritems() } defer.returnValue({event: event_to_state[event] for event in event_ids}) @@ -474,8 +428,8 @@ class StateStore(SQLBaseStore): state_map = yield self.get_state_ids_for_events([event_id], types) defer.returnValue(state_map[event_id]) - @cached(num_args=2, max_entries=10000) - def 
_get_state_group_for_event(self, room_id, event_id): + @cached(max_entries=50000) + def _get_state_group_for_event(self, event_id): return self._simple_select_one_onecol( table="event_to_state_groups", keyvalues={ @@ -517,20 +471,22 @@ class StateStore(SQLBaseStore): where a `state_key` of `None` matches all state_keys for the `type`. """ - is_all, state_dict_ids = self._state_group_cache.get(group) + is_all, known_absent, state_dict_ids = self._state_group_cache.get(group) type_to_key = {} missing_types = set() + for typ, state_key in types: + key = (typ, state_key) if state_key is None: type_to_key[typ] = None - missing_types.add((typ, state_key)) + missing_types.add(key) else: if type_to_key.get(typ, object()) is not None: type_to_key.setdefault(typ, set()).add(state_key) - if (typ, state_key) not in state_dict_ids: - missing_types.add((typ, state_key)) + if key not in state_dict_ids and key not in known_absent: + missing_types.add(key) sentinel = object() @@ -544,10 +500,10 @@ class StateStore(SQLBaseStore): return True return False - got_all = not (missing_types or types is None) + got_all = is_all or not missing_types return { - k: v for k, v in state_dict_ids.items() + k: v for k, v in state_dict_ids.iteritems() if include(k[0], k[1]) }, missing_types, got_all @@ -561,7 +517,7 @@ class StateStore(SQLBaseStore): Args: group: The state group to lookup """ - is_all, state_dict_ids = self._state_group_cache.get(group) + is_all, _, state_dict_ids = self._state_group_cache.get(group) return state_dict_ids, is_all @@ -578,7 +534,7 @@ class StateStore(SQLBaseStore): missing_groups = [] if types is not None: for group in set(groups): - state_dict_ids, missing_types, got_all = self._get_some_state_from_cache( + state_dict_ids, _, got_all = self._get_some_state_from_cache( group, types ) results[group] = state_dict_ids @@ -606,46 +562,247 @@ class StateStore(SQLBaseStore): # Now we want to update the cache with all the things we fetched # from the database. 
- for group, group_state_dict in group_to_state_dict.items(): - if types: - # We delibrately put key -> None mappings into the cache to - # cache absence of the key, on the assumption that if we've - # explicitly asked for some types then we will probably ask - # for them again. - state_dict = { - (intern_string(etype), intern_string(state_key)): None - for (etype, state_key) in types - } - state_dict.update(results[group]) - results[group] = state_dict - else: - state_dict = results[group] - - state_dict.update({ - (intern_string(k[0]), intern_string(k[1])): v - for k, v in group_state_dict.items() - }) + for group, group_state_dict in group_to_state_dict.iteritems(): + state_dict = results[group] + + state_dict.update( + ((intern_string(k[0]), intern_string(k[1])), to_ascii(v)) + for k, v in group_state_dict.iteritems() + ) self._state_group_cache.update( cache_seq_num, key=group, value=state_dict, full=(types is None), + known_absent=types, ) - # Remove all the entries with None values. The None values were just - # used for bookkeeping in the cache. - for group, state_dict in results.items(): - results[group] = { - key: event_id - for key, event_id in state_dict.items() - if event_id - } - defer.returnValue(results) - def get_next_state_group(self): - return self._state_groups_id_gen.get_next() + def store_state_group(self, event_id, room_id, prev_group, delta_ids, + current_state_ids): + """Store a new set of state, returning a newly assigned state group. + + Args: + event_id (str): The event ID for which the state was calculated + room_id (str) + prev_group (int|None): A previous state group for the room, optional. + delta_ids (dict|None): The delta between state at `prev_group` and + `current_state_ids`, if `prev_group` was given. Same format as + `current_state_ids`. + current_state_ids (dict): The state to store. Map of (type, state_key) + to event_id. 
+ + Returns: + Deferred[int]: The state group ID + """ + def _store_state_group_txn(txn): + if current_state_ids is None: + # AFAIK, this can never happen + raise Exception("current_state_ids cannot be None") + + state_group = self.database_engine.get_next_state_group_id(txn) + + self._simple_insert_txn( + txn, + table="state_groups", + values={ + "id": state_group, + "room_id": room_id, + "event_id": event_id, + }, + ) + + # We persist as a delta if we can, while also ensuring the chain + # of deltas isn't tooo long, as otherwise read performance degrades. + if prev_group: + is_in_db = self._simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (prev_group,) + ) + + potential_hops = self._count_state_group_hops_txn( + txn, prev_group + ) + if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: + self._simple_insert_txn( + txn, + table="state_group_edges", + values={ + "state_group": state_group, + "prev_state_group": prev_group, + }, + ) + + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in delta_ids.iteritems() + ], + ) + else: + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in current_state_ids.iteritems() + ], + ) + + # Prefill the state group cache with this group. + # It's fine to use the sequence like this as the state group map + # is immutable. 
(If the map wasn't immutable then this prefill could + # race with another update) + txn.call_after( + self._state_group_cache.update, + self._state_group_cache.sequence, + key=state_group, + value=dict(current_state_ids), + full=True, + ) + + return state_group + + return self.runInteraction("store_state_group", _store_state_group_txn) + + def _count_state_group_hops_txn(self, txn, state_group): + """Given a state group, count how many hops there are in the tree. + + This is used to ensure the delta chains don't get too long. + """ + if isinstance(self.database_engine, PostgresEngine): + sql = (""" + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT count(*) FROM state; + """) + + txn.execute(sql, (state_group,)) + row = txn.fetchone() + if row and row[0]: + return row[0] + else: + return 0 + else: + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + next_group = state_group + count = 0 + + while next_group: + next_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + if next_group: + count += 1 + + return count + + +class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): + """ Keeps track of the state at a given event. + + This is done by the concept of `state groups`. Every event is a assigned + a state group (identified by an arbitrary string), which references a + collection of state events. The current state of an event is then the + collection of state events referenced by the event's state group. + + Hence, every change in the current state causes a new state group to be + generated. 
However, if no change happens (e.g., if we get a message event + with only one parent it inherits the state group from its parent.) + + There are three tables: + * `state_groups`: Stores group name, first event with in the group and + room id. + * `event_to_state_groups`: Maps events to state groups. + * `state_groups_state`: Maps state group to state events. + """ + + STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" + STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" + CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" + + def __init__(self, db_conn, hs): + super(StateStore, self).__init__(db_conn, hs) + self.register_background_update_handler( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, + self._background_deduplicate_state, + ) + self.register_background_update_handler( + self.STATE_GROUP_INDEX_UPDATE_NAME, + self._background_index_state, + ) + self.register_background_index_update( + self.CURRENT_STATE_INDEX_UPDATE_NAME, + index_name="current_state_events_member_index", + table="current_state_events", + columns=["state_key"], + where_clause="type='m.room.member'", + ) + + def _store_event_state_mappings_txn(self, txn, events_and_contexts): + state_groups = {} + for event, context in events_and_contexts: + if event.internal_metadata.is_outlier(): + continue + + # if the event was rejected, just give it the same state as its + # predecessor. 
+ if context.rejected: + state_groups[event.event_id] = context.prev_group + continue + + state_groups[event.event_id] = context.state_group + + self._simple_insert_many_txn( + txn, + table="event_to_state_groups", + values=[ + { + "state_group": state_group_id, + "event_id": event_id, + } + for event_id, state_group_id in state_groups.iteritems() + ], + ) + + for event_id, state_group_id in state_groups.iteritems(): + txn.call_after( + self._get_state_group_for_event.prefill, + (event_id,), state_group_id + ) @defer.inlineCallbacks def _background_deduplicate_state(self, progress, batch_size): @@ -727,7 +884,7 @@ class StateStore(SQLBaseStore): # of keys delta_state = { - key: value for key, value in curr_state.items() + key: value for key, value in curr_state.iteritems() if prev_state.get(key, None) != value } @@ -767,7 +924,7 @@ class StateStore(SQLBaseStore): "state_key": key[1], "event_id": state_id, } - for key, state_id in delta_state.items() + for key, state_id in delta_state.iteritems() ], ) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 200d124632..f0784ba137 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -35,15 +35,20 @@ what sort order was used: from twisted.internet import defer -from ._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore +from synapse.storage.events import EventsWorkerStore + from synapse.util.caches.descriptors import cached -from synapse.api.constants import EventTypes from synapse.types import RoomStreamToken -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.storage.engines import PostgresEngine, Sqlite3Engine +import abc import logging +from six.moves import range + logger = logging.getLogger(__name__) @@ -143,81 +148,41 @@ def filter_to_clause(event_filter): return " 
AND ".join(clauses), args -class StreamStore(SQLBaseStore): - @defer.inlineCallbacks - def get_appservice_room_stream(self, service, from_key, to_key, limit=0): - # NB this lives here instead of appservice.py so we can reuse the - # 'private' StreamToken class in this file. - if limit: - limit = max(limit, MAX_STREAM_SIZE) - else: - limit = MAX_STREAM_SIZE - - # From and to keys should be integers from ordering. - from_id = RoomStreamToken.parse_stream_token(from_key) - to_id = RoomStreamToken.parse_stream_token(to_key) - - if from_key == to_key: - defer.returnValue(([], to_key)) - return - - # select all the events between from/to with a sensible limit - sql = ( - "SELECT e.event_id, e.room_id, e.type, s.state_key, " - "e.stream_ordering FROM events AS e " - "LEFT JOIN state_events as s ON " - "e.event_id = s.event_id " - "WHERE e.stream_ordering > ? AND e.stream_ordering <= ? " - "ORDER BY stream_ordering ASC LIMIT %(limit)d " - ) % { - "limit": limit - } - - def f(txn): - # pull out all the events between the tokens - txn.execute(sql, (from_id.stream, to_id.stream,)) - rows = self.cursor_to_dict(txn) - - # Logic: - # - We want ALL events which match the AS room_id regex - # - We want ALL events which match the rooms represented by the AS - # room_alias regex - # - We want ALL events for rooms that AS users have joined. - # This is currently supported via get_app_service_rooms (which is - # used for the Notifier listener rooms). We can't reasonably make a - # SQL query for these room IDs, so we'll pull all the events between - # from/to and filter in python. 
- rooms_for_as = self._get_app_service_rooms_txn(txn, service) - room_ids_for_as = [r.room_id for r in rooms_for_as] - - def app_service_interested(row): - if row["room_id"] in room_ids_for_as: - return True +class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): + """This is an abstract base class where subclasses must implement + `get_room_max_stream_ordering` and `get_room_min_stream_ordering` + which can be called in the initializer. + """ - if row["type"] == EventTypes.Member: - if service.is_interested_in_user(row.get("state_key")): - return True - return False + __metaclass__ = abc.ABCMeta - return [r for r in rows if app_service_interested(r)] + def __init__(self, db_conn, hs): + super(StreamWorkerStore, self).__init__(db_conn, hs) - rows = yield self.runInteraction("get_appservice_room_stream", f) - - ret = yield self._get_events( - [r["event_id"] for r in rows], - get_prev_content=True + events_max = self.get_room_max_stream_ordering() + event_cache_prefill, min_event_val = self._get_cache_dict( + db_conn, "events", + entity_column="room_id", + stream_column="stream_ordering", + max_value=events_max, + ) + self._events_stream_cache = StreamChangeCache( + "EventsRoomStreamChangeCache", min_event_val, + prefilled_cache=event_cache_prefill, + ) + self._membership_stream_cache = StreamChangeCache( + "MembershipStreamChangeCache", events_max, ) - self._set_before_and_after(ret, rows, topo_order=from_id is None) + self._stream_order_on_start = self.get_room_max_stream_ordering() - if rows: - key = "s%d" % max(r["stream_ordering"] for r in rows) - else: - # Assume we didn't get anything because there was nothing to - # get. 
- key = to_key + @abc.abstractmethod + def get_room_max_stream_ordering(self): + raise NotImplementedError() - defer.returnValue((ret, key)) + @abc.abstractmethod + def get_room_min_stream_ordering(self): + raise NotImplementedError() @defer.inlineCallbacks def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0, @@ -233,13 +198,14 @@ class StreamStore(SQLBaseStore): results = {} room_ids = list(room_ids) - for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)): - res = yield preserve_context_over_deferred(defer.gatherResults([ - preserve_fn(self.get_room_events_stream_for_room)( + for rm_ids in (room_ids[i:i + 20] for i in range(0, len(room_ids), 20)): + res = yield make_deferred_yieldable(defer.gatherResults([ + run_in_background( + self.get_room_events_stream_for_room, room_id, from_key, to_key, limit, order=order, ) for room_id in rm_ids - ])) + ], consumeErrors=True)) results.update(dict(zip(rm_ids, res))) defer.returnValue(results) @@ -381,88 +347,6 @@ class StreamStore(SQLBaseStore): defer.returnValue(ret) @defer.inlineCallbacks - def paginate_room_events(self, room_id, from_key, to_key=None, - direction='b', limit=-1, event_filter=None): - # Tokens really represent positions between elements, but we use - # the convention of pointing to the event before the gap. Hence - # we have a bit of asymmetry when it comes to equalities. 
- args = [False, room_id] - if direction == 'b': - order = "DESC" - bounds = upper_bound( - RoomStreamToken.parse(from_key), self.database_engine - ) - if to_key: - bounds = "%s AND %s" % (bounds, lower_bound( - RoomStreamToken.parse(to_key), self.database_engine - )) - else: - order = "ASC" - bounds = lower_bound( - RoomStreamToken.parse(from_key), self.database_engine - ) - if to_key: - bounds = "%s AND %s" % (bounds, upper_bound( - RoomStreamToken.parse(to_key), self.database_engine - )) - - filter_clause, filter_args = filter_to_clause(event_filter) - - if filter_clause: - bounds += " AND " + filter_clause - args.extend(filter_args) - - if int(limit) > 0: - args.append(int(limit)) - limit_str = " LIMIT ?" - else: - limit_str = "" - - sql = ( - "SELECT * FROM events" - " WHERE outlier = ? AND room_id = ? AND %(bounds)s" - " ORDER BY topological_ordering %(order)s," - " stream_ordering %(order)s %(limit)s" - ) % { - "bounds": bounds, - "order": order, - "limit": limit_str - } - - def f(txn): - txn.execute(sql, args) - - rows = self.cursor_to_dict(txn) - - if rows: - topo = rows[-1]["topological_ordering"] - toke = rows[-1]["stream_ordering"] - if direction == 'b': - # Tokens are positions between events. - # This token points *after* the last event in the chunk. - # We need it to point to the event before it in the chunk - # when we are going backwards so we subtract one from the - # stream part. - toke -= 1 - next_token = str(RoomStreamToken(topo, toke)) - else: - # TODO (erikj): We should work out what to do here instead. 
- next_token = to_key if to_key else from_key - - return rows, next_token, - - rows, token = yield self.runInteraction("paginate_room_events", f) - - events = yield self._get_events( - [r["event_id"] for r in rows], - get_prev_content=True - ) - - self._set_before_and_after(events, rows) - - defer.returnValue((events, token)) - - @defer.inlineCallbacks def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None): rows, token = yield self.get_recent_event_ids_for_room( room_id, limit, end_token, from_token @@ -534,6 +418,33 @@ class StreamStore(SQLBaseStore): "get_recent_events_for_room", get_recent_events_for_room_txn ) + def get_room_event_after_stream_ordering(self, room_id, stream_ordering): + """Gets details of the first event in a room at or after a stream ordering + + Args: + room_id (str): + stream_ordering (int): + + Returns: + Deferred[(int, int, str)]: + (stream ordering, topological ordering, event_id) + """ + def _f(txn): + sql = ( + "SELECT stream_ordering, topological_ordering, event_id" + " FROM events" + " WHERE room_id = ? AND stream_ordering >= ?" + " AND NOT outlier" + " ORDER BY stream_ordering" + " LIMIT 1" + ) + txn.execute(sql, (room_id, stream_ordering, )) + return txn.fetchone() + + return self.runInteraction( + "get_room_event_after_stream_ordering", _f, + ) + @defer.inlineCallbacks def get_room_events_max_id(self, room_id=None): """Returns the current token for rooms stream. @@ -542,7 +453,7 @@ class StreamStore(SQLBaseStore): `room_id` causes it to return the current room specific topological token. 
""" - token = yield self._stream_id_gen.get_current_token() + token = yield self.get_room_max_stream_ordering() if room_id is None: defer.returnValue("s%d" % (token,)) else: @@ -552,12 +463,6 @@ class StreamStore(SQLBaseStore): ) defer.returnValue("t%d-%d" % (topo, token)) - def get_room_max_stream_ordering(self): - return self._stream_id_gen.get_current_token() - - def get_room_min_stream_ordering(self): - return self._backfill_id_gen.get_current_token() - def get_stream_token_for_event(self, event_id): """The stream token for an event Args: @@ -829,3 +734,96 @@ class StreamStore(SQLBaseStore): updatevalues={"stream_id": stream_id}, desc="update_federation_out_pos", ) + + def has_room_changed_since(self, room_id, stream_id): + return self._events_stream_cache.has_entity_changed(room_id, stream_id) + + +class StreamStore(StreamWorkerStore): + def get_room_max_stream_ordering(self): + return self._stream_id_gen.get_current_token() + + def get_room_min_stream_ordering(self): + return self._backfill_id_gen.get_current_token() + + @defer.inlineCallbacks + def paginate_room_events(self, room_id, from_key, to_key=None, + direction='b', limit=-1, event_filter=None): + # Tokens really represent positions between elements, but we use + # the convention of pointing to the event before the gap. Hence + # we have a bit of asymmetry when it comes to equalities. 
+ args = [False, room_id] + if direction == 'b': + order = "DESC" + bounds = upper_bound( + RoomStreamToken.parse(from_key), self.database_engine + ) + if to_key: + bounds = "%s AND %s" % (bounds, lower_bound( + RoomStreamToken.parse(to_key), self.database_engine + )) + else: + order = "ASC" + bounds = lower_bound( + RoomStreamToken.parse(from_key), self.database_engine + ) + if to_key: + bounds = "%s AND %s" % (bounds, upper_bound( + RoomStreamToken.parse(to_key), self.database_engine + )) + + filter_clause, filter_args = filter_to_clause(event_filter) + + if filter_clause: + bounds += " AND " + filter_clause + args.extend(filter_args) + + if int(limit) > 0: + args.append(int(limit)) + limit_str = " LIMIT ?" + else: + limit_str = "" + + sql = ( + "SELECT * FROM events" + " WHERE outlier = ? AND room_id = ? AND %(bounds)s" + " ORDER BY topological_ordering %(order)s," + " stream_ordering %(order)s %(limit)s" + ) % { + "bounds": bounds, + "order": order, + "limit": limit_str + } + + def f(txn): + txn.execute(sql, args) + + rows = self.cursor_to_dict(txn) + + if rows: + topo = rows[-1]["topological_ordering"] + toke = rows[-1]["stream_ordering"] + if direction == 'b': + # Tokens are positions between events. + # This token points *after* the last event in the chunk. + # We need it to point to the event before it in the chunk + # when we are going backwards so we subtract one from the + # stream part. + toke -= 1 + next_token = str(RoomStreamToken(topo, toke)) + else: + # TODO (erikj): We should work out what to do here instead. 
+ next_token = to_key if to_key else from_key + + return rows, next_token, + + rows, token = yield self.runInteraction("paginate_room_events", f) + + events = yield self._get_events( + [r["event_id"] for r in rows], + get_prev_content=True + ) + + self._set_before_and_after(events, rows) + + defer.returnValue((events, token)) diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index 5a2c1aa59b..6671d3cfca 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,25 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore +from synapse.storage.account_data import AccountDataWorkerStore + from synapse.util.caches.descriptors import cached from twisted.internet import defer -import ujson as json +import simplejson as json import logging -logger = logging.getLogger(__name__) +from six.moves import range +logger = logging.getLogger(__name__) -class TagsStore(SQLBaseStore): - def get_max_account_data_stream_id(self): - """Get the current max stream id for the private user data stream - - Returns: - A deferred int. - """ - return self._account_data_id_gen.get_current_token() +class TagsWorkerStore(AccountDataWorkerStore): @cached() def get_tags_for_user(self, user_id): """Get all the tags for a user. 
@@ -95,7 +91,7 @@ class TagsStore(SQLBaseStore): for stream_id, user_id, room_id in tag_ids: txn.execute(sql, (user_id, room_id)) tags = [] - for tag, content in txn.fetchall(): + for tag, content in txn: tags.append(json.dumps(tag) + ":" + content) tag_json = "{" + ",".join(tags) + "}" results.append((stream_id, user_id, room_id, tag_json)) @@ -104,7 +100,7 @@ class TagsStore(SQLBaseStore): batch_size = 50 results = [] - for i in xrange(0, len(tag_ids), batch_size): + for i in range(0, len(tag_ids), batch_size): tags = yield self.runInteraction( "get_all_updated_tag_content", get_tag_content, @@ -132,7 +128,7 @@ class TagsStore(SQLBaseStore): " WHERE user_id = ? AND stream_id > ?" ) txn.execute(sql, (user_id, stream_id)) - room_ids = [row[0] for row in txn.fetchall()] + room_ids = [row[0] for row in txn] return room_ids changed = self._account_data_stream_cache.has_entity_changed( @@ -170,6 +166,8 @@ class TagsStore(SQLBaseStore): row["tag"]: json.loads(row["content"]) for row in rows }) + +class TagsStore(TagsWorkerStore): @defer.inlineCallbacks def add_tag_to_room(self, user_id, room_id, tag, content): """Add a tag to a room for a user. diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 809fdd311f..f825264ea9 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -23,7 +23,7 @@ from canonicaljson import encode_canonical_json from collections import namedtuple import logging -import ujson as json +import simplejson as json logger = logging.getLogger(__name__) @@ -46,8 +46,8 @@ class TransactionStore(SQLBaseStore): """A collection of queries for handling PDUs. 
""" - def __init__(self, hs): - super(TransactionStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(TransactionStore, self).__init__(db_conn, hs) self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000) diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py new file mode 100644 index 0000000000..d6e289ffbe --- /dev/null +++ b/synapse/storage/user_directory.py @@ -0,0 +1,764 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from twisted.internet import defer + +from ._base import SQLBaseStore + +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.api.constants import EventTypes, JoinRules +from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.types import get_domain_from_id, get_localpart_from_id + +import re +import logging + +logger = logging.getLogger(__name__) + + +class UserDirectoryStore(SQLBaseStore): + @cachedInlineCallbacks(cache_context=True) + def is_room_world_readable_or_publicly_joinable(self, room_id, cache_context): + """Check if the room is either world_readable or publically joinable + """ + current_state_ids = yield self.get_current_state_ids( + room_id, on_invalidate=cache_context.invalidate + ) + + join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) + if join_rules_id: + join_rule_ev = yield self.get_event(join_rules_id, allow_none=True) + if join_rule_ev: + if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC: + defer.returnValue(True) + + hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, "")) + if hist_vis_id: + hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True) + if hist_vis_ev: + if hist_vis_ev.content.get("history_visibility") == "world_readable": + defer.returnValue(True) + + defer.returnValue(False) + + @defer.inlineCallbacks + def add_users_to_public_room(self, room_id, user_ids): + """Add user to the list of users in public rooms + + Args: + room_id (str): A room_id that all users are in that is world_readable + or publically joinable + user_ids (list(str)): Users to add + """ + yield self._simple_insert_many( + table="users_in_public_rooms", + values=[ + { + "user_id": user_id, + "room_id": room_id, + } + for user_id in user_ids + ], + desc="add_users_to_public_room" + ) + for user_id in user_ids: + self.get_user_in_public_room.invalidate((user_id,)) + + def add_profiles_to_user_dir(self, room_id, users_with_profile): + """Add 
profiles to the user directory + + Args: + room_id (str): A room_id that all users are joined to + users_with_profile (dict): Users to add to directory in the form of + mapping of user_id -> ProfileInfo + """ + if isinstance(self.database_engine, PostgresEngine): + # We weight the localpart most highly, then display name and finally + # server name + sql = """ + INSERT INTO user_directory_search(user_id, vector) + VALUES (?, + setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + ) + """ + args = ( + ( + user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id), + profile.display_name, + ) + for user_id, profile in users_with_profile.iteritems() + ) + elif isinstance(self.database_engine, Sqlite3Engine): + sql = """ + INSERT INTO user_directory_search(user_id, value) + VALUES (?,?) + """ + args = ( + ( + user_id, + "%s %s" % (user_id, p.display_name,) if p.display_name else user_id + ) + for user_id, p in users_with_profile.iteritems() + ) + else: + # This should be unreachable. 
+ raise Exception("Unrecognized database engine") + + def _add_profiles_to_user_dir_txn(txn): + txn.executemany(sql, args) + self._simple_insert_many_txn( + txn, + table="user_directory", + values=[ + { + "user_id": user_id, + "room_id": room_id, + "display_name": profile.display_name, + "avatar_url": profile.avatar_url, + } + for user_id, profile in users_with_profile.iteritems() + ] + ) + for user_id in users_with_profile: + txn.call_after( + self.get_user_in_directory.invalidate, (user_id,) + ) + + return self.runInteraction( + "add_profiles_to_user_dir", _add_profiles_to_user_dir_txn + ) + + @defer.inlineCallbacks + def update_user_in_user_dir(self, user_id, room_id): + yield self._simple_update_one( + table="user_directory", + keyvalues={"user_id": user_id}, + updatevalues={"room_id": room_id}, + desc="update_user_in_user_dir", + ) + self.get_user_in_directory.invalidate((user_id,)) + + def update_profile_in_user_dir(self, user_id, display_name, avatar_url, room_id): + def _update_profile_in_user_dir_txn(txn): + new_entry = self._simple_upsert_txn( + txn, + table="user_directory", + keyvalues={"user_id": user_id}, + insertion_values={"room_id": room_id}, + values={"display_name": display_name, "avatar_url": avatar_url}, + lock=False, # We're only inserter + ) + + if isinstance(self.database_engine, PostgresEngine): + # We weight the localpart most highly, then display name and finally + # server name + if new_entry: + sql = """ + INSERT INTO user_directory_search(user_id, vector) + VALUES (?, + setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + ) + """ + txn.execute( + sql, + ( + user_id, get_localpart_from_id(user_id), + get_domain_from_id(user_id), display_name, + ) + ) + else: + sql = """ + UPDATE user_directory_search + SET vector = setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', 
COALESCE(?, '')), 'B') + WHERE user_id = ? + """ + txn.execute( + sql, + ( + get_localpart_from_id(user_id), get_domain_from_id(user_id), + display_name, user_id, + ) + ) + elif isinstance(self.database_engine, Sqlite3Engine): + value = "%s %s" % (user_id, display_name,) if display_name else user_id + self._simple_upsert_txn( + txn, + table="user_directory_search", + keyvalues={"user_id": user_id}, + values={"value": value}, + lock=False, # We're only inserter + ) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) + + return self.runInteraction( + "update_profile_in_user_dir", _update_profile_in_user_dir_txn + ) + + @defer.inlineCallbacks + def update_user_in_public_user_list(self, user_id, room_id): + yield self._simple_update_one( + table="users_in_public_rooms", + keyvalues={"user_id": user_id}, + updatevalues={"room_id": room_id}, + desc="update_user_in_public_user_list", + ) + self.get_user_in_public_room.invalidate((user_id,)) + + def remove_from_user_dir(self, user_id): + def _remove_from_user_dir_txn(txn): + self._simple_delete_txn( + txn, + table="user_directory", + keyvalues={"user_id": user_id}, + ) + self._simple_delete_txn( + txn, + table="user_directory_search", + keyvalues={"user_id": user_id}, + ) + self._simple_delete_txn( + txn, + table="users_in_public_rooms", + keyvalues={"user_id": user_id}, + ) + txn.call_after( + self.get_user_in_directory.invalidate, (user_id,) + ) + txn.call_after( + self.get_user_in_public_room.invalidate, (user_id,) + ) + return self.runInteraction( + "remove_from_user_dir", _remove_from_user_dir_txn, + ) + + @defer.inlineCallbacks + def remove_from_user_in_public_room(self, user_id): + yield self._simple_delete( + table="users_in_public_rooms", + keyvalues={"user_id": user_id}, + desc="remove_from_user_in_public_room", + ) + self.get_user_in_public_room.invalidate((user_id,)) + + def 
get_users_in_public_due_to_room(self, room_id): + """Get all user_ids that are in the room directory because they're + in the given room_id + """ + return self._simple_select_onecol( + table="users_in_public_rooms", + keyvalues={"room_id": room_id}, + retcol="user_id", + desc="get_users_in_public_due_to_room", + ) + + @defer.inlineCallbacks + def get_users_in_dir_due_to_room(self, room_id): + """Get all user_ids that are in the room directory because they're + in the given room_id + """ + user_ids_dir = yield self._simple_select_onecol( + table="user_directory", + keyvalues={"room_id": room_id}, + retcol="user_id", + desc="get_users_in_dir_due_to_room", + ) + + user_ids_pub = yield self._simple_select_onecol( + table="users_in_public_rooms", + keyvalues={"room_id": room_id}, + retcol="user_id", + desc="get_users_in_dir_due_to_room", + ) + + user_ids_share = yield self._simple_select_onecol( + table="users_who_share_rooms", + keyvalues={"room_id": room_id}, + retcol="user_id", + desc="get_users_in_dir_due_to_room", + ) + + user_ids = set(user_ids_dir) + user_ids.update(user_ids_pub) + user_ids.update(user_ids_share) + + defer.returnValue(user_ids) + + @defer.inlineCallbacks + def get_all_rooms(self): + """Get all room_ids we've ever known about, in ascending order of "size" + """ + sql = """ + SELECT room_id FROM current_state_events + GROUP BY room_id + ORDER BY count(*) ASC + """ + rows = yield self._execute("get_all_rooms", None, sql) + defer.returnValue([room_id for room_id, in rows]) + + @defer.inlineCallbacks + def get_all_local_users(self): + """Get all local users + """ + sql = """ + SELECT name FROM users + """ + rows = yield self._execute("get_all_local_users", None, sql) + defer.returnValue([name for name, in rows]) + + def add_users_who_share_room(self, room_id, share_private, user_id_tuples): + """Insert entries into the users_who_share_rooms table. The first + user should be a local user. 
+ + Args: + room_id (str) + share_private (bool): Is the room private + user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs. + """ + def _add_users_who_share_room_txn(txn): + self._simple_insert_many_txn( + txn, + table="users_who_share_rooms", + values=[ + { + "user_id": user_id, + "other_user_id": other_user_id, + "room_id": room_id, + "share_private": share_private, + } + for user_id, other_user_id in user_id_tuples + ], + ) + for user_id, other_user_id in user_id_tuples: + txn.call_after( + self.get_users_who_share_room_from_dir.invalidate, + (user_id,), + ) + txn.call_after( + self.get_if_users_share_a_room.invalidate, + (user_id, other_user_id), + ) + return self.runInteraction( + "add_users_who_share_room", _add_users_who_share_room_txn + ) + + def update_users_who_share_room(self, room_id, share_private, user_id_sets): + """Updates entries in the users_who_share_rooms table. The first + user should be a local user. + + Args: + room_id (str) + share_private (bool): Is the room private + user_id_sets([(str, str)]): iterable of 2-tuple of user IDs. + """ + def _update_users_who_share_room_txn(txn): + sql = """ + UPDATE users_who_share_rooms + SET room_id = ?, share_private = ? + WHERE user_id = ? AND other_user_id = ? + """ + txn.executemany( + sql, + ( + (room_id, share_private, uid, oid) + for uid, oid in user_id_sets + ) + ) + for user_id, other_user_id in user_id_sets: + txn.call_after( + self.get_users_who_share_room_from_dir.invalidate, + (user_id,), + ) + txn.call_after( + self.get_if_users_share_a_room.invalidate, + (user_id, other_user_id), + ) + return self.runInteraction( + "update_users_who_share_room", _update_users_who_share_room_txn + ) + + def remove_user_who_share_room(self, user_id, other_user_id): + """Deletes entries in the users_who_share_rooms table. The first + user should be a local user. + + Args: + user_id (str) + other_user_id (str) 
+ """ + def _remove_user_who_share_room_txn(txn): + self._simple_delete_txn( + txn, + table="users_who_share_rooms", + keyvalues={ + "user_id": user_id, + "other_user_id": other_user_id, + }, + ) + txn.call_after( + self.get_users_who_share_room_from_dir.invalidate, + (user_id,), + ) + txn.call_after( + self.get_if_users_share_a_room.invalidate, + (user_id, other_user_id), + ) + + return self.runInteraction( + "remove_user_who_share_room", _remove_user_who_share_room_txn + ) + + @cached(max_entries=500000) + def get_if_users_share_a_room(self, user_id, other_user_id): + """Gets if users share a room. + + Args: + user_id (str): Must be a local user_id + other_user_id (str) + + Returns: + bool|None: None if they don't share a room, otherwise whether they + share a private room or not. + """ + return self._simple_select_one_onecol( + table="users_who_share_rooms", + keyvalues={ + "user_id": user_id, + "other_user_id": other_user_id, + }, + retcol="share_private", + allow_none=True, + desc="get_if_users_share_a_room", + ) + + @cachedInlineCallbacks(max_entries=500000, iterable=True) + def get_users_who_share_room_from_dir(self, user_id): + """Returns the set of users who share a room with `user_id` + + Args: + user_id(str): Must be a local user + + Returns: + dict: user_id -> share_private mapping + """ + rows = yield self._simple_select_list( + table="users_who_share_rooms", + keyvalues={ + "user_id": user_id, + }, + retcols=("other_user_id", "share_private",), + desc="get_users_who_share_room_with_user", + ) + + defer.returnValue({ + row["other_user_id"]: row["share_private"] + for row in rows + }) + + def get_users_in_share_dir_with_room_id(self, user_id, room_id): + """Get all user tuples that are in the users_who_share_rooms due to the + given room_id. + + Returns: + [(user_id, other_user_id)]: where one of the two will match the given + user_id. + """ + sql = """ + SELECT user_id, other_user_id FROM users_who_share_rooms + WHERE room_id = ? AND (user_id = ? 
OR other_user_id = ?) + """ + return self._execute( + "get_users_in_share_dir_with_room_id", None, sql, room_id, user_id, user_id + ) + + @defer.inlineCallbacks + def get_rooms_in_common_for_users(self, user_id, other_user_id): + """Given two user_ids find out the list of rooms they share. + """ + sql = """ + SELECT room_id FROM ( + SELECT c.room_id FROM current_state_events AS c + INNER JOIN room_memberships USING (event_id) + WHERE type = 'm.room.member' + AND membership = 'join' + AND state_key = ? + ) AS f1 INNER JOIN ( + SELECT c.room_id FROM current_state_events AS c + INNER JOIN room_memberships USING (event_id) + WHERE type = 'm.room.member' + AND membership = 'join' + AND state_key = ? + ) f2 USING (room_id) + """ + + rows = yield self._execute( + "get_rooms_in_common_for_users", None, sql, user_id, other_user_id + ) + + defer.returnValue([room_id for room_id, in rows]) + + def delete_all_from_user_dir(self): + """Delete the entire user directory + """ + def _delete_all_from_user_dir_txn(txn): + txn.execute("DELETE FROM user_directory") + txn.execute("DELETE FROM user_directory_search") + txn.execute("DELETE FROM users_in_public_rooms") + txn.execute("DELETE FROM users_who_share_rooms") + txn.call_after(self.get_user_in_directory.invalidate_all) + txn.call_after(self.get_user_in_public_room.invalidate_all) + txn.call_after(self.get_users_who_share_room_from_dir.invalidate_all) + txn.call_after(self.get_if_users_share_a_room.invalidate_all) + return self.runInteraction( + "delete_all_from_user_dir", _delete_all_from_user_dir_txn + ) + + @cached() + def get_user_in_directory(self, user_id): + return self._simple_select_one( + table="user_directory", + keyvalues={"user_id": user_id}, + retcols=("room_id", "display_name", "avatar_url",), + allow_none=True, + desc="get_user_in_directory", + ) + + @cached() + def get_user_in_public_room(self, user_id): + return self._simple_select_one( + table="users_in_public_rooms", + keyvalues={"user_id": user_id}, + 
retcols=("room_id",), + allow_none=True, + desc="get_user_in_public_room", + ) + + def get_user_directory_stream_pos(self): + return self._simple_select_one_onecol( + table="user_directory_stream_pos", + keyvalues={}, + retcol="stream_id", + desc="get_user_directory_stream_pos", + ) + + def update_user_directory_stream_pos(self, stream_id): + return self._simple_update_one( + table="user_directory_stream_pos", + keyvalues={}, + updatevalues={"stream_id": stream_id}, + desc="update_user_directory_stream_pos", + ) + + def get_current_state_deltas(self, prev_stream_id): + prev_stream_id = int(prev_stream_id) + if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id): + return [] + + def get_current_state_deltas_txn(txn): + # First we calculate the max stream id that will give us less than + # N results. + # We arbitarily limit to 100 stream_id entries to ensure we don't + # select toooo many. + sql = """ + SELECT stream_id, count(*) + FROM current_state_delta_stream + WHERE stream_id > ? + GROUP BY stream_id + ORDER BY stream_id ASC + LIMIT 100 + """ + txn.execute(sql, (prev_stream_id,)) + + total = 0 + max_stream_id = prev_stream_id + for max_stream_id, count in txn: + total += count + if total > 100: + # We arbitarily limit to 100 entries to ensure we don't + # select toooo many. + break + + # Now actually get the deltas + sql = """ + SELECT stream_id, room_id, type, state_key, event_id, prev_event_id + FROM current_state_delta_stream + WHERE ? < stream_id AND stream_id <= ? 
+ ORDER BY stream_id ASC + """ + txn.execute(sql, (prev_stream_id, max_stream_id,)) + return self.cursor_to_dict(txn) + + return self.runInteraction( + "get_current_state_deltas", get_current_state_deltas_txn + ) + + def get_max_stream_id_in_current_state_deltas(self): + return self._simple_select_one_onecol( + table="current_state_delta_stream", + keyvalues={}, + retcol="COALESCE(MAX(stream_id), -1)", + desc="get_max_stream_id_in_current_state_deltas", + ) + + @defer.inlineCallbacks + def search_user_dir(self, user_id, search_term, limit): + """Searches for users in directory + + Returns: + dict of the form:: + + { + "limited": <bool>, # whether there were more results or not + "results": [ # Ordered by best match first + { + "user_id": <user_id>, + "display_name": <display_name>, + "avatar_url": <avatar_url> + } + ] + } + """ + + if self.hs.config.user_directory_search_all_users: + # make s.user_id null to keep the ordering algorithm happy + join_clause = """ + CROSS JOIN (SELECT NULL as user_id) AS s + """ + join_args = () + where_clause = "1=1" + else: + join_clause = """ + LEFT JOIN users_in_public_rooms AS p USING (user_id) + LEFT JOIN ( + SELECT other_user_id AS user_id FROM users_who_share_rooms + WHERE user_id = ? AND share_private + ) AS s USING (user_id) + """ + join_args = (user_id,) + where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)" + + if isinstance(self.database_engine, PostgresEngine): + full_query, exact_query, prefix_query = _parse_query_postgres(search_term) + + # We order by rank and then if they have profile info + # The ranking algorithm is hand tweaked for "best" results. Broadly + # the idea is we give a higher weight to exact matches. 
+ # The array of numbers are the weights for the various part of the + # search: (domain, _, display name, localpart) + sql = """ + SELECT d.user_id AS user_id, display_name, avatar_url + FROM user_directory_search + INNER JOIN user_directory AS d USING (user_id) + %s + WHERE + %s + AND vector @@ to_tsquery('english', ?) + ORDER BY + (CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END) + * (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END) + * (CASE WHEN avatar_url IS NOT NULL THEN 1.2 ELSE 1.0 END) + * ( + 3 * ts_rank_cd( + '{0.1, 0.1, 0.9, 1.0}', + vector, + to_tsquery('english', ?), + 8 + ) + + ts_rank_cd( + '{0.1, 0.1, 0.9, 1.0}', + vector, + to_tsquery('english', ?), + 8 + ) + ) + DESC, + display_name IS NULL, + avatar_url IS NULL + LIMIT ? + """ % (join_clause, where_clause) + args = join_args + (full_query, exact_query, prefix_query, limit + 1,) + elif isinstance(self.database_engine, Sqlite3Engine): + search_query = _parse_query_sqlite(search_term) + + sql = """ + SELECT d.user_id AS user_id, display_name, avatar_url + FROM user_directory_search + INNER JOIN user_directory AS d USING (user_id) + %s + WHERE + %s + AND value MATCH ? + ORDER BY + rank(matchinfo(user_directory_search)) DESC, + display_name IS NULL, + avatar_url IS NULL + LIMIT ? + """ % (join_clause, where_clause) + args = join_args + (search_query, limit + 1) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + results = yield self._execute( + "search_user_dir", self.cursor_to_dict, sql, *args + ) + + limited = len(results) > limit + + defer.returnValue({ + "limited": limited, + "results": results, + }) + + +def _parse_query_sqlite(search_term): + """Takes a plain unicode string from the user and converts it into a form + that can be passed to database. + We use this so that we can add prefix matching, which isn't something + that is supported by default. 
+ + We specifically add both a prefix and non prefix matching term so that + exact matches get ranked higher. + """ + + # Pull out the individual words, discarding any non-word characters. + results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) + return " & ".join("(%s* OR %s)" % (result, result,) for result in results) + + +def _parse_query_postgres(search_term): + """Takes a plain unicode string from the user and converts it into a form + that can be passed to database. + We use this so that we can add prefix matching, which isn't something + that is supported by default. + """ + + # Pull out the individual words, discarding any non-word characters. + results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) + + both = " & ".join("(%s:* | %s)" % (result, result,) for result in results) + exact = " & ".join("%s" % (result,) for result in results) + prefix = " & ".join("%s:*" % (result,) for result in results) + + return both, exact, prefix diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 46cf93ff87..95031dc9ec 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -30,6 +30,17 @@ class IdGenerator(object): def _load_current_id(db_conn, table, column, step=1): + """ + + Args: + db_conn (object): + table (str): + column (str): + step (int): + + Returns: + int + """ cur = db_conn.cursor() if step == 1: cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) @@ -131,6 +142,9 @@ class StreamIdGenerator(object): def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. 
+ + Returns: + int """ with self._lock: if self._unfinished_ids: diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 4f089bfb94..ca78e551cb 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -80,13 +80,13 @@ class PaginationConfig(object): from_tok = None # For backwards compat. elif from_tok: from_tok = StreamToken.from_string(from_tok) - except: + except Exception: raise SynapseError(400, "'from' paramater is invalid") try: if to_tok: to_tok = StreamToken.from_string(to_tok) - except: + except Exception: raise SynapseError(400, "'to' paramater is invalid") limit = get_param("limit", None) @@ -98,7 +98,7 @@ class PaginationConfig(object): try: return PaginationConfig(from_tok, to_tok, direction, limit) - except: + except Exception: logger.exception("Failed to create pagination config") raise SynapseError(400, "Invalid request.") diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 91a59b0bae..f03ad99118 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -45,6 +45,7 @@ class EventSources(object): push_rules_key, _ = self.store.get_push_rules_stream_token() to_device_key = self.store.get_to_device_stream_token() device_list_key = self.store.get_device_stream_token() + groups_key = self.store.get_group_stream_token() token = StreamToken( room_key=( @@ -65,6 +66,7 @@ class EventSources(object): push_rules_key=push_rules_key, to_device_key=to_device_key, device_list_key=device_list_key, + groups_key=groups_key, ) defer.returnValue(token) @@ -73,6 +75,7 @@ class EventSources(object): push_rules_key, _ = self.store.get_push_rules_stream_token() to_device_key = self.store.get_to_device_stream_token() device_list_key = self.store.get_device_stream_token() + groups_key = self.store.get_group_stream_token() token = StreamToken( room_key=( @@ -93,5 +96,6 @@ class EventSources(object): push_rules_key=push_rules_key, to_device_key=to_device_key, device_list_key=device_list_key, + 
groups_key=groups_key, ) defer.returnValue(token) diff --git a/synapse/types.py b/synapse/types.py index 9666f9d73f..cc7c182a78 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -12,26 +12,66 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import string from synapse.api.errors import SynapseError from collections import namedtuple -Requester = namedtuple("Requester", [ +class Requester(namedtuple("Requester", [ "user", "access_token_id", "is_guest", "device_id", "app_service", -]) -""" -Represents the user making a request +])): + """ + Represents the user making a request -Attributes: - user (UserID): id of the user making the request - access_token_id (int|None): *ID* of the access token used for this - request, or None if it came via the appservice API or similar - is_guest (bool): True if the user making this request is a guest user - device_id (str|None): device_id which was set at authentication time - app_service (ApplicationService|None): the AS requesting on behalf of the user -""" + Attributes: + user (UserID): id of the user making the request + access_token_id (int|None): *ID* of the access token used for this + request, or None if it came via the appservice API or similar + is_guest (bool): True if the user making this request is a guest user + device_id (str|None): device_id which was set at authentication time + app_service (ApplicationService|None): the AS requesting on behalf of the user + """ + + def serialize(self): + """Converts self to a type that can be serialized as JSON, and then + deserialized by `deserialize` + + Returns: + dict + """ + return { + "user_id": self.user.to_string(), + "access_token_id": self.access_token_id, + "is_guest": self.is_guest, + "device_id": self.device_id, + "app_server_id": self.app_service.id if self.app_service else None, + } + + @staticmethod + def deserialize(store, 
input): + """Converts a dict that was produced by `serialize` back into a + Requester. + + Args: + store (DataStore): Used to convert AS ID to AS object + input (dict): A dict produced by `serialize` + + Returns: + Requester + """ + appservice = None + if input["app_server_id"]: + appservice = store.get_app_service_by_id(input["app_server_id"]) + + return Requester( + user=UserID.from_string(input["user_id"]), + access_token_id=input["access_token_id"], + is_guest=input["is_guest"], + device_id=input["device_id"], + app_service=appservice, + ) def create_requester(user_id, access_token_id=None, is_guest=False, @@ -56,10 +96,17 @@ def create_requester(user_id, access_token_id=None, is_guest=False, def get_domain_from_id(string): - try: - return string.split(":", 1)[1] - except IndexError: + idx = string.find(":") + if idx == -1: + raise SynapseError(400, "Invalid ID: %r" % (string,)) + return string[idx + 1:] + + +def get_localpart_from_id(string): + idx = string.find(":") + if idx == -1: raise SynapseError(400, "Invalid ID: %r" % (string,)) + return string[1:idx] class DomainSpecificString( @@ -119,14 +166,10 @@ class DomainSpecificString( try: cls.from_string(s) return True - except: + except Exception: return False - __str__ = to_string - - @classmethod - def create(cls, localpart, domain,): - return cls(localpart=localpart, domain=domain) + __repr__ = to_string class UserID(DomainSpecificString): @@ -149,6 +192,43 @@ class EventID(DomainSpecificString): SIGIL = "$" +class GroupID(DomainSpecificString): + """Structure representing a group ID.""" + SIGIL = "+" + + @classmethod + def from_string(cls, s): + group_id = super(GroupID, cls).from_string(s) + if not group_id.localpart: + raise SynapseError( + 400, + "Group ID cannot be empty", + ) + + if contains_invalid_mxid_characters(group_id.localpart): + raise SynapseError( + 400, + "Group ID can only contain characters a-z, 0-9, or '=_-./'", + ) + + return group_id + + +mxid_localpart_allowed_characters = 
set("_-./=" + string.ascii_lowercase + string.digits) + + +def contains_invalid_mxid_characters(localpart): + """Check for characters not allowed in an mxid or groupid localpart + + Args: + localpart (basestring): the localpart to be checked + + Returns: + bool: True if there are any naughty characters + """ + return any(c not in mxid_localpart_allowed_characters for c in localpart) + + class StreamToken( namedtuple("Token", ( "room_key", @@ -159,6 +239,7 @@ class StreamToken( "push_rules_key", "to_device_key", "device_list_key", + "groups_key", )) ): _SEPARATOR = "_" @@ -171,7 +252,7 @@ class StreamToken( # i.e. old token from before receipt_key keys.append("0") return cls(*keys) - except: + except Exception: raise SynapseError(400, "Invalid Token") def to_string(self): @@ -197,6 +278,7 @@ class StreamToken( or (int(other.push_rules_key) < int(self.push_rules_key)) or (int(other.to_device_key) < int(self.to_device_key)) or (int(other.device_list_key) < int(self.device_list_key)) + or (int(other.groups_key) < int(self.groups_key)) ) def copy_and_advance(self, key, new_value): @@ -216,9 +298,7 @@ class StreamToken( return self def copy_and_replace(self, key, new_value): - d = self._asdict() - d[key] = new_value - return StreamToken(**d) + return self._replace(**{key: new_value}) StreamToken.START = StreamToken( @@ -258,7 +338,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): if string[0] == 't': parts = string[1:].split('-', 1) return cls(topological=int(parts[0]), stream=int(parts[1])) - except: + except Exception: pass raise SynapseError(400, "Invalid token %r" % (string,)) @@ -267,7 +347,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): try: if string[0] == 's': return cls(topological=None, stream=int(string[1:])) - except: + except Exception: pass raise SynapseError(400, "Invalid token %r" % (string,)) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 30fc480108..814a7bf71b 100644 --- 
a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.api.errors import SynapseError from synapse.util.logcontext import PreserveLoggingContext from twisted.internet import defer, reactor, task @@ -24,11 +23,6 @@ import logging logger = logging.getLogger(__name__) -class DeferredTimedOutError(SynapseError): - def __init__(self): - super(SynapseError).__init__(504, "Timed out") - - def unwrapFirstError(failure): # defer.gatherResults and DeferredLists wrap failures. failure.trap(defer.FirstError) @@ -59,9 +53,9 @@ class Clock(object): f(function): The function to call repeatedly. msec(float): How long to wait between calls in milliseconds. """ - l = task.LoopingCall(f) - l.start(msec / 1000.0, now=False) - return l + call = task.LoopingCall(f) + call.start(msec / 1000.0, now=False) + return call def call_later(self, delay, callback, *args, **kwargs): """Call something later @@ -82,54 +76,6 @@ class Clock(object): def cancel_call_later(self, timer, ignore_errs=False): try: timer.cancel() - except: + except Exception: if not ignore_errs: raise - - def time_bound_deferred(self, given_deferred, time_out): - if given_deferred.called: - return given_deferred - - ret_deferred = defer.Deferred() - - def timed_out_fn(): - try: - ret_deferred.errback(DeferredTimedOutError()) - except: - pass - - try: - given_deferred.cancel() - except: - pass - - timer = None - - def cancel(res): - try: - self.cancel_call_later(timer) - except: - pass - return res - - ret_deferred.addBoth(cancel) - - def sucess(res): - try: - ret_deferred.callback(res) - except: - pass - - return res - - def err(res): - try: - ret_deferred.errback(res) - except: - pass - - given_deferred.addCallbacks(callback=sucess, errback=err) - - timer = self.call_later(time_out, timed_out_fn) - - return ret_deferred diff --git a/synapse/util/async.py b/synapse/util/async.py index 
35380bf8ed..9dd4e6b5bc 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -15,16 +15,20 @@ from twisted.internet import defer, reactor +from twisted.internet.defer import CancelledError +from twisted.python import failure from .logcontext import ( - PreserveLoggingContext, preserve_fn, preserve_context_over_deferred, + PreserveLoggingContext, make_deferred_yieldable, run_in_background ) -from synapse.util import unwrapFirstError +from synapse.util import logcontext, unwrapFirstError from contextlib import contextmanager import logging +from six.moves import range + logger = logging.getLogger(__name__) @@ -53,6 +57,11 @@ class ObservableDeferred(object): Cancelling or otherwise resolving an observer will not affect the original ObservableDeferred. + + NB that it does not attempt to do anything with logcontexts; in general + you should probably make_deferred_yieldable the deferreds + returned by `observe`, and ensure that the original deferred runs its + callbacks in the sentinel logcontext. """ __slots__ = ["_deferred", "_observers", "_result"] @@ -68,7 +77,7 @@ class ObservableDeferred(object): try: # TODO: Handle errors here. self._observers.pop().callback(r) - except: + except Exception: pass return r @@ -78,7 +87,7 @@ class ObservableDeferred(object): try: # TODO: Handle errors here. self._observers.pop().errback(f) - except: + except Exception: pass if consumeErrors: @@ -89,6 +98,11 @@ class ObservableDeferred(object): deferred.addCallbacks(callback, errback) def observe(self): + """Observe the underlying deferred. + + Can return either a deferred if the underlying deferred is still pending + (or has failed), or the actual value. Callers may need to use maybeDeferred. 
+ """ if not self._result: d = defer.Deferred() @@ -101,7 +115,7 @@ class ObservableDeferred(object): return d else: success, res = self._result - return defer.succeed(res) if success else defer.fail(res) + return res if success else defer.fail(res) def observers(self): return self._observers @@ -146,13 +160,13 @@ def concurrently_execute(func, args, limit): def _concurrently_execute_inner(): try: while True: - yield func(it.next()) + yield func(next(it)) except StopIteration: pass - return preserve_context_over_deferred(defer.gatherResults([ - preserve_fn(_concurrently_execute_inner)() - for _ in xrange(limit) + return logcontext.make_deferred_yieldable(defer.gatherResults([ + run_in_background(_concurrently_execute_inner) + for _ in range(limit) ], consumeErrors=True)).addErrback(unwrapFirstError) @@ -195,10 +209,29 @@ class Linearizer(object): try: with PreserveLoggingContext(): yield current_defer - except: + except Exception: logger.exception("Unexpected exception in Linearizer") - logger.info("Acquired linearizer lock %r for key %r", self.name, key) + logger.info("Acquired linearizer lock %r for key %r", self.name, + key) + + # if the code holding the lock completes synchronously, then it + # will recursively run the next claimant on the list. That can + # relatively rapidly lead to stack exhaustion. This is essentially + # the same problem as http://twistedmatrix.com/trac/ticket/9304. + # + # In order to break the cycle, we add a cheeky sleep(0) here to + # ensure that we fall back to the reactor between each iteration. + # + # (There's no particular need for it to happen before we return + # the context manager, but it needs to happen while we hold the + # lock, and the context manager's exit code must be synchronous, + # so actually this is the only sensible place. 
+ yield run_on_reactor() + + else: + logger.info("Acquired uncontended linearizer lock %r for key %r", + self.name, key) @contextmanager def _ctx_manager(): @@ -206,7 +239,8 @@ class Linearizer(object): yield finally: logger.info("Releasing linearizer lock %r for key %r", self.name, key) - new_defer.callback(None) + with PreserveLoggingContext(): + new_defer.callback(None) current_d = self.key_to_defer.get(key) if current_d is new_defer: self.key_to_defer.pop(key, None) @@ -248,8 +282,13 @@ class Limiter(object): if entry[0] >= self.max_count: new_defer = defer.Deferred() entry[1].append(new_defer) + + logger.info("Waiting to acquire limiter lock for key %r", key) with PreserveLoggingContext(): yield new_defer + logger.info("Acquired limiter lock for key %r", key) + else: + logger.info("Acquired uncontended limiter lock for key %r", key) entry[0] += 1 @@ -258,16 +297,21 @@ class Limiter(object): try: yield finally: + logger.info("Releasing limiter lock for key %r", key) + # We've finished executing so check if there are any things # blocked waiting to execute and start one of them entry[0] -= 1 - try: - entry[1].pop(0).callback(None) - except IndexError: - # If nothing else is executing for this key then remove it - # from the map - if entry[0] == 0: - self.key_to_defer.pop(key, None) + + if entry[1]: + next_def = entry[1].pop(0) + + with PreserveLoggingContext(): + next_def.callback(None) + elif entry[0] == 0: + # We were the last thing for this key: remove it from the + # map. + del self.key_to_defer[key] defer.returnValue(_ctx_manager()) @@ -311,7 +355,7 @@ class ReadWriteLock(object): # We wait for the latest writer to finish writing. We can safely ignore # any existing readers... as they're readers. 
- yield curr_writer + yield make_deferred_yieldable(curr_writer) @contextmanager def _ctx_manager(): @@ -340,7 +384,7 @@ class ReadWriteLock(object): curr_readers.clear() self.key_to_current_writer[key] = new_defer - yield preserve_context_over_deferred(defer.gatherResults(to_wait_on)) + yield make_deferred_yieldable(defer.gatherResults(to_wait_on)) @contextmanager def _ctx_manager(): @@ -352,3 +396,68 @@ class ReadWriteLock(object): self.key_to_current_writer.pop(key) defer.returnValue(_ctx_manager()) + + +class DeferredTimeoutError(Exception): + """ + This error is raised by default when a L{Deferred} times out. + """ + + +def add_timeout_to_deferred(deferred, timeout, on_timeout_cancel=None): + """ + Add a timeout to a deferred by scheduling it to be cancelled after + timeout seconds. + + This is essentially a backport of deferred.addTimeout, which was introduced + in twisted 16.5. + + If the deferred gets timed out, it errbacks with a DeferredTimeoutError, + unless a cancelable function was passed to its initialization or unless + a different on_timeout_cancel callable is provided. + + Args: + deferred (defer.Deferred): deferred to be timed out + timeout (Number): seconds to time out after + + on_timeout_cancel (callable): A callable which is called immediately + after the deferred times out, and not if this deferred is + otherwise cancelled before the timeout. + + It takes an arbitrary value, which is the value of the deferred at + that exact point in time (probably a CancelledError Failure), and + the timeout. + + The default callable (if none is provided) will translate a + CancelledError Failure into a DeferredTimeoutError. 
+ """ + timed_out = [False] + + def time_it_out(): + timed_out[0] = True + deferred.cancel() + + delayed_call = reactor.callLater(timeout, time_it_out) + + def convert_cancelled(value): + if timed_out[0]: + to_call = on_timeout_cancel or _cancelled_to_timed_out_error + return to_call(value, timeout) + return value + + deferred.addBoth(convert_cancelled) + + def cancel_timeout(result): + # stop the pending call to cancel the deferred if it's been fired + if delayed_call.active(): + delayed_call.cancel() + return result + + deferred.addBoth(cancel_timeout) + + +def _cancelled_to_timed_out_error(value, timeout): + if isinstance(value, failure.Failure): + value.trap(CancelledError) + raise DeferredTimeoutError(timeout, "Deferred") + return value diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 8a7774a88e..4adae96681 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -14,12 +14,9 @@ # limitations under the License. import synapse.metrics -from lrucache import LruCache import os -CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) - -DEBUG_CACHES = False +CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5)) metrics = synapse.metrics.get_metrics_for("synapse.util.caches") @@ -40,10 +37,6 @@ def register_cache(name, cache): ) -_string_cache = LruCache(int(100000 * CACHE_SIZE_FACTOR)) -_stirng_cache_metrics = register_cache("string_cache", _string_cache) - - KNOWN_KEYS = { key: key for key in ( @@ -67,14 +60,16 @@ KNOWN_KEYS = { def intern_string(string): - """Takes a (potentially) unicode string and interns using custom cache + """Takes a (potentially) unicode string and interns it if it's ascii """ - new_str = _string_cache.setdefault(string, string) - if new_str is string: - _stirng_cache_metrics.inc_hits() - else: - _stirng_cache_metrics.inc_misses() - return new_str + if string is None: + return None + + try: + string = string.encode("ascii") + return 
intern(string) + except UnicodeEncodeError: + return string def intern_dict(dictionary): @@ -87,13 +82,9 @@ def intern_dict(dictionary): def _intern_known_values(key, value): - intern_str_keys = ("event_id", "room_id") - intern_unicode_keys = ("sender", "user_id", "type", "state_key") - - if key in intern_str_keys: - return intern(value.encode('ascii')) + intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key",) - if key in intern_unicode_keys: + if key in intern_keys: return intern_string(value) return value diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 998de70d29..68285a7594 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,19 +16,17 @@ import logging from synapse.util.async import ObservableDeferred -from synapse.util import unwrapFirstError +from synapse.util import unwrapFirstError, logcontext +from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry -from synapse.util.logcontext import ( - PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn -) +from synapse.util.stringutils import to_ascii -from . import DEBUG_CACHES, register_cache +from . 
import register_cache from twisted.internet import defer from collections import namedtuple -import os import functools import inspect import threading @@ -39,17 +38,13 @@ logger = logging.getLogger(__name__) _CacheSentinel = object() -CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) - - class CacheEntry(object): __slots__ = [ - "deferred", "sequence", "callbacks", "invalidated" + "deferred", "callbacks", "invalidated" ] - def __init__(self, deferred, sequence, callbacks): + def __init__(self, deferred, callbacks): self.deferred = deferred - self.sequence = sequence self.callbacks = set(callbacks) self.invalidated = False @@ -67,7 +62,6 @@ class Cache(object): "max_entries", "name", "keylen", - "sequence", "thread", "metrics", "_pending_deferred_cache", @@ -79,15 +73,18 @@ class Cache(object): self.cache = LruCache( max_size=max_entries, keylen=keylen, cache_type=cache_type, - size_callback=(lambda d: len(d.result)) if iterable else None, + size_callback=(lambda d: len(d)) if iterable else None, + evicted_callback=self._on_evicted, ) self.name = name self.keylen = keylen - self.sequence = 0 self.thread = None self.metrics = register_cache(name, self.cache) + def _on_evicted(self, evicted_count): + self.metrics.inc_evictions(evicted_count) + def check_thread(self): expected_thread = self.thread if expected_thread is None: @@ -98,21 +95,34 @@ class Cache(object): "Cache objects can only be accessed from the main thread" ) - def get(self, key, default=_CacheSentinel, callback=None): + def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True): + """Looks the key up in the caches. + + Args: + key(tuple) + default: What is returned if key is not in the caches. 
If not + specified then function throws KeyError instead + callback(fn): Gets called when the entry in the cache is invalidated + update_metrics (bool): whether to update the cache hit rate metrics + + Returns: + Either a Deferred or the raw result + """ callbacks = [callback] if callback else [] val = self._pending_deferred_cache.get(key, _CacheSentinel) if val is not _CacheSentinel: - if val.sequence == self.sequence: - val.callbacks.update(callbacks) + val.callbacks.update(callbacks) + if update_metrics: self.metrics.inc_hits() - return val.deferred + return val.deferred val = self.cache.get(key, _CacheSentinel, callbacks=callbacks) if val is not _CacheSentinel: self.metrics.inc_hits() return val - self.metrics.inc_misses() + if update_metrics: + self.metrics.inc_misses() if default is _CacheSentinel: raise KeyError() @@ -124,12 +134,9 @@ class Cache(object): self.check_thread() entry = CacheEntry( deferred=value, - sequence=self.sequence, callbacks=callbacks, ) - entry.callbacks.update(callbacks) - existing_entry = self._pending_deferred_cache.pop(key, None) if existing_entry: existing_entry.invalidate() @@ -137,13 +144,25 @@ class Cache(object): self._pending_deferred_cache[key] = entry def shuffle(result): - if self.sequence == entry.sequence: - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry is entry: - self.cache.set(key, entry.deferred, entry.callbacks) - else: - entry.invalidate() + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry is entry: + self.cache.set(key, result, entry.callbacks) else: + # oops, the _pending_deferred_cache has been updated since + # we started our query, so we are out of date. + # + # Better put back whatever we took out. 
(We do it this way + # round, rather than peeking into the _pending_deferred_cache + # and then removing on a match, to make the common case faster) + if existing_entry is not None: + self._pending_deferred_cache[key] = existing_entry + + # we're not going to put this entry into the cache, so need + # to make sure that the invalidation callbacks are called. + # That was probably done when _pending_deferred_cache was + # updated, but it's possible that `set` was called without + # `invalidate` being previously called, in which case it may + # not have been. Either way, let's double-check now. entry.invalidate() return result @@ -155,29 +174,29 @@ class Cache(object): def invalidate(self, key): self.check_thread() - if not isinstance(key, tuple): - raise TypeError( - "The cache key must be a tuple not %r" % (type(key),) - ) + self.cache.pop(key, None) - # Increment the sequence number so that any SELECT statements that - # raced with the INSERT don't update the cache (SYN-369) - self.sequence += 1 + # if we have a pending lookup for this key, remove it from the + # _pending_deferred_cache, which will (a) stop it being returned + # for future queries and (b) stop it being persisted as a proper entry + # in self.cache. entry = self._pending_deferred_cache.pop(key, None) + + # run the invalidation callbacks now, rather than waiting for the + # deferred to resolve. 
if entry: entry.invalidate() - self.cache.pop(key, None) - def invalidate_many(self, key): self.check_thread() if not isinstance(key, tuple): raise TypeError( "The cache key must be a tuple not %r" % (type(key),) ) - self.sequence += 1 self.cache.del_multi(key) + # if we have a pending lookup for this key, remove it from the + # _pending_deferred_cache, as above entry_dict = self._pending_deferred_cache.pop(key, None) if entry_dict is not None: for entry in iterate_tree_cache_entry(entry_dict): @@ -185,11 +204,73 @@ class Cache(object): def invalidate_all(self): self.check_thread() - self.sequence += 1 self.cache.clear() + for entry in self._pending_deferred_cache.itervalues(): + entry.invalidate() + self._pending_deferred_cache.clear() + + +class _CacheDescriptorBase(object): + def __init__(self, orig, num_args, inlineCallbacks, cache_context=False): + self.orig = orig + + if inlineCallbacks: + self.function_to_call = defer.inlineCallbacks(orig) + else: + self.function_to_call = orig + arg_spec = inspect.getargspec(orig) + all_args = arg_spec.args -class CacheDescriptor(object): + if "cache_context" in all_args: + if not cache_context: + raise ValueError( + "Cannot have a 'cache_context' arg without setting" + " cache_context=True" + ) + elif cache_context: + raise ValueError( + "Cannot have cache_context=True without having an arg" + " named `cache_context`" + ) + + if num_args is None: + num_args = len(all_args) - 1 + if cache_context: + num_args -= 1 + + if len(all_args) < num_args + 1: + raise Exception( + "Not enough explicit positional arguments to key off for %r: " + "got %i args, but wanted %i. 
(@cached cannot key off *args or " + "**kwargs)" + % (orig.__name__, len(all_args), num_args) + ) + + self.num_args = num_args + + # list of the names of the args used as the cache key + self.arg_names = all_args[1:num_args + 1] + + # self.arg_defaults is a map of arg name to its default value for each + # argument that has a default value + if arg_spec.defaults: + self.arg_defaults = dict(zip( + all_args[-len(arg_spec.defaults):], + arg_spec.defaults + )) + else: + self.arg_defaults = {} + + if "cache_context" in self.arg_names: + raise Exception( + "cache_context arg cannot be included among the cache keys" + ) + + self.add_cache_context = cache_context + + +class CacheDescriptor(_CacheDescriptorBase): """ A method decorator that applies a memoizing cache around the function. This caches deferreds, rather than the results themselves. Deferreds that @@ -217,52 +298,24 @@ class CacheDescriptor(object): r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate) defer.returnValue(r1 + r2) + Args: + num_args (int): number of positional arguments (excluding ``self`` and + ``cache_context``) to use as cache keys. Defaults to all named + args of the function. 
""" - def __init__(self, orig, max_entries=1000, num_args=1, tree=False, + def __init__(self, orig, max_entries=1000, num_args=None, tree=False, inlineCallbacks=False, cache_context=False, iterable=False): - max_entries = int(max_entries * CACHE_SIZE_FACTOR) - self.orig = orig + super(CacheDescriptor, self).__init__( + orig, num_args=num_args, inlineCallbacks=inlineCallbacks, + cache_context=cache_context) - if inlineCallbacks: - self.function_to_call = defer.inlineCallbacks(orig) - else: - self.function_to_call = orig + max_entries = int(max_entries * CACHE_SIZE_FACTOR) self.max_entries = max_entries - self.num_args = num_args self.tree = tree - self.iterable = iterable - all_args = inspect.getargspec(orig) - self.arg_names = all_args.args[1:num_args + 1] - - if "cache_context" in all_args.args: - if not cache_context: - raise ValueError( - "Cannot have a 'cache_context' arg without setting" - " cache_context=True" - ) - try: - self.arg_names.remove("cache_context") - except ValueError: - pass - elif cache_context: - raise ValueError( - "Cannot have cache_context=True without having an arg" - " named `cache_context`" - ) - - self.add_cache_context = cache_context - - if len(self.arg_names) < self.num_args: - raise Exception( - "Not enough explicit positional arguments to key off of for %r." - " (@cached cannot key off of *args or **kwargs)" - % (orig.__name__,) - ) - def __get__(self, obj, objtype=None): cache = Cache( name=self.orig.__name__, @@ -272,18 +325,47 @@ class CacheDescriptor(object): iterable=self.iterable, ) + def get_cache_key_gen(args, kwargs): + """Given some args/kwargs return a generator that resolves into + the cache_key. + + We loop through each arg name, looking up if its in the `kwargs`, + otherwise using the next argument in `args`. 
If there are no more + args then we try looking the arg name up in the defaults + """ + pos = 0 + for nm in self.arg_names: + if nm in kwargs: + yield kwargs[nm] + elif pos < len(args): + yield args[pos] + pos += 1 + else: + yield self.arg_defaults[nm] + + # By default our cache key is a tuple, but if there is only one item + # then don't bother wrapping in a tuple. This is to save memory. + if self.num_args == 1: + nm = self.arg_names[0] + + def get_cache_key(args, kwargs): + if nm in kwargs: + return kwargs[nm] + elif len(args): + return args[0] + else: + return self.arg_defaults[nm] + else: + def get_cache_key(args, kwargs): + return tuple(get_cache_key_gen(args, kwargs)) + @functools.wraps(self.orig) def wrapped(*args, **kwargs): # If we're passed a cache_context then we'll want to call its invalidate() # whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) - # Add temp cache_context so inspect.getcallargs doesn't explode - if self.add_cache_context: - kwargs["cache_context"] = None - - arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) - cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) + cache_key = get_cache_key(args, kwargs) # Add our own `cache_context` to argument list if the wrapped function # has asked for one @@ -293,26 +375,14 @@ class CacheDescriptor(object): try: cached_result_d = cache.get(cache_key, callback=invalidate_callback) - observer = cached_result_d.observe() - if DEBUG_CACHES: - @defer.inlineCallbacks - def check_result(cached_result): - actual_result = yield self.function_to_call(obj, *args, **kwargs) - if actual_result != cached_result: - logger.error( - "Stale cache entry %s%r: cached: %r, actual %r", - self.orig.__name__, cache_key, - cached_result, actual_result, - ) - raise ValueError("Stale cache entry") - defer.returnValue(cached_result) - observer.addCallback(check_result) - - return preserve_context_over_deferred(observer) + if isinstance(cached_result_d, 
ObservableDeferred): + observer = cached_result_d.observe() + else: + observer = cached_result_d + except KeyError: ret = defer.maybeDeferred( - preserve_context_over_fn, - self.function_to_call, + logcontext.preserve_fn(self.function_to_call), obj, *args, **kwargs ) @@ -322,64 +392,72 @@ class CacheDescriptor(object): ret.addErrback(onErr) - ret = ObservableDeferred(ret, consumeErrors=True) - cache.set(cache_key, ret, callback=invalidate_callback) + # If our cache_key is a string, try to convert to ascii to save + # a bit of space in large caches + if isinstance(cache_key, basestring): + cache_key = to_ascii(cache_key) - return preserve_context_over_deferred(ret.observe()) + result_d = ObservableDeferred(ret, consumeErrors=True) + cache.set(cache_key, result_d, callback=invalidate_callback) + observer = result_d.observe() + + if isinstance(observer, defer.Deferred): + return logcontext.make_deferred_yieldable(observer) + else: + return observer + + if self.num_args == 1: + wrapped.invalidate = lambda key: cache.invalidate(key[0]) + wrapped.prefill = lambda key, val: cache.prefill(key[0], val) + else: + wrapped.invalidate = cache.invalidate + wrapped.invalidate_all = cache.invalidate_all + wrapped.invalidate_many = cache.invalidate_many + wrapped.prefill = cache.prefill - wrapped.invalidate = cache.invalidate wrapped.invalidate_all = cache.invalidate_all - wrapped.invalidate_many = cache.invalidate_many - wrapped.prefill = cache.prefill wrapped.cache = cache + wrapped.num_args = self.num_args obj.__dict__[self.orig.__name__] = wrapped return wrapped -class CacheListDescriptor(object): +class CacheListDescriptor(_CacheDescriptorBase): """Wraps an existing cache to support bulk fetching of keys. Given a list of keys it looks in the cache to find any hits, then passes - the list of missing keys to the wrapped fucntion. + the list of missing keys to the wrapped function. 
+ + Once wrapped, the function returns either a Deferred which resolves to + the list of results, or (if all results were cached), just the list of + results. """ - def __init__(self, orig, cached_method_name, list_name, num_args=1, + def __init__(self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False): """ Args: orig (function) - method_name (str); The name of the chached method. + cached_method_name (str): The name of the chached method. list_name (str): Name of the argument which is the bulk lookup list - num_args (int) + num_args (int): number of positional arguments (excluding ``self``, + but including list_name) to use as cache keys. Defaults to all + named args of the function. inlineCallbacks (bool): Whether orig is a generator that should be wrapped by defer.inlineCallbacks """ - self.orig = orig + super(CacheListDescriptor, self).__init__( + orig, num_args=num_args, inlineCallbacks=inlineCallbacks) - if inlineCallbacks: - self.function_to_call = defer.inlineCallbacks(orig) - else: - self.function_to_call = orig - - self.num_args = num_args self.list_name = list_name - self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] self.list_pos = self.arg_names.index(self.list_name) - self.cached_method_name = cached_method_name self.sentinel = object() - if len(self.arg_names) < self.num_args: - raise Exception( - "Not enough explicit positional arguments to key off of for %r." - " (@cached cannot key off of *args or **kwars)" - % (orig.__name__,) - ) - if self.list_name not in self.arg_names: raise Exception( "Couldn't see arguments %r for %r." 
@@ -387,8 +465,9 @@ class CacheListDescriptor(object): ) def __get__(self, obj, objtype=None): - - cache = getattr(obj, self.cached_method_name).cache + cached_method = getattr(obj, self.cached_method_name) + cache = cached_method.cache + num_args = cached_method.num_args @functools.wraps(self.orig) def wrapped(*args, **kwargs): @@ -405,13 +484,26 @@ class CacheListDescriptor(object): results = {} cached_defers = {} missing = [] - for arg in list_args: + + # If the cache takes a single arg then that is used as the key, + # otherwise a tuple is used. + if num_args == 1: + def cache_get(arg): + return cache.get(arg, callback=invalidate_callback) + else: key = list(keyargs) - key[self.list_pos] = arg + def cache_get(arg): + key[self.list_pos] = arg + return cache.get(tuple(key), callback=invalidate_callback) + + for arg in list_args: try: - res = cache.get(tuple(key), callback=invalidate_callback) - if not res.has_succeeded(): + res = cache_get(arg) + + if not isinstance(res, ObservableDeferred): + results[arg] = res + elif not res.has_succeeded(): res = res.observe() res.addCallback(lambda r, arg: (arg, r), arg) cached_defers[arg] = res @@ -425,8 +517,7 @@ class CacheListDescriptor(object): args_to_call[self.list_name] = missing ret_d = defer.maybeDeferred( - preserve_context_over_fn, - self.function_to_call, + logcontext.preserve_fn(self.function_to_call), **args_to_call ) @@ -435,23 +526,33 @@ class CacheListDescriptor(object): # We need to create deferreds for each arg in the list so that # we can insert the new deferred into the cache. 
for arg in missing: - with PreserveLoggingContext(): - observer = ret_d.observe() + observer = ret_d.observe() observer.addCallback(lambda r, arg: r.get(arg, None), arg) observer = ObservableDeferred(observer) - key = list(keyargs) - key[self.list_pos] = arg - cache.set( - tuple(key), observer, - callback=invalidate_callback - ) + if num_args == 1: + cache.set( + arg, observer, + callback=invalidate_callback + ) - def invalidate(f, key): - cache.invalidate(key) - return f - observer.addErrback(invalidate, tuple(key)) + def invalidate(f, key): + cache.invalidate(key) + return f + observer.addErrback(invalidate, arg) + else: + key = list(keyargs) + key[self.list_pos] = arg + cache.set( + tuple(key), observer, + callback=invalidate_callback + ) + + def invalidate(f, key): + cache.invalidate(key) + return f + observer.addErrback(invalidate, tuple(key)) res = observer.observe() res.addCallback(lambda r, arg: (arg, r), arg) @@ -463,7 +564,7 @@ class CacheListDescriptor(object): results.update(res) return results - return preserve_context_over_deferred(defer.gatherResults( + return logcontext.make_deferred_yieldable(defer.gatherResults( cached_defers.values(), consumeErrors=True, ).addCallback(update_results_dict).addErrback( @@ -487,7 +588,7 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): self.cache.invalidate(self.key) -def cached(max_entries=1000, num_args=1, tree=False, cache_context=False, +def cached(max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False): return lambda orig: CacheDescriptor( orig, @@ -499,8 +600,8 @@ def cached(max_entries=1000, num_args=1, tree=False, cache_context=False, ) -def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False, - iterable=False): +def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False, + cache_context=False, iterable=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, @@ -512,7 +613,7 @@ def 
cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex ) -def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False): +def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False): """Creates a descriptor that wraps a function in a `CacheListDescriptor`. Used to do batch lookups for an already created cache. A single argument @@ -525,7 +626,8 @@ def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False) cache (Cache): The underlying cache to use. list_name (str): The name of the argument that is the list to use to do batch lookups in the cache. - num_args (int): Number of arguments to use as the key in the cache. + num_args (int): Number of arguments to use as the key in the cache + (including list_name). Defaults to all named parameters. inlineCallbacks (bool): Should the function be wrapped in an `defer.inlineCallbacks`? diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index cb6933c61c..1709e8b429 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -23,7 +23,17 @@ import logging logger = logging.getLogger(__name__) -class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))): +class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "value"))): + """Returned when getting an entry from the cache + + Attributes: + full (bool): Whether the cache has the full or dict or just some keys. + If not full then not all requested keys will necessarily be present + in `value` + known_absent (set): Keys that were looked up in the dict and were not + there. 
+ value (dict): The full or partial dict value + """ def __len__(self): return len(self.value) @@ -58,21 +68,31 @@ class DictionaryCache(object): ) def get(self, key, dict_keys=None): + """Fetch an entry out of the cache + + Args: + key + dict_key(list): If given a set of keys then return only those keys + that exist in the cache. + + Returns: + DictionaryEntry + """ entry = self.cache.get(key, self.sentinel) if entry is not self.sentinel: self.metrics.inc_hits() if dict_keys is None: - return DictionaryEntry(entry.full, dict(entry.value)) + return DictionaryEntry(entry.full, entry.known_absent, dict(entry.value)) else: - return DictionaryEntry(entry.full, { + return DictionaryEntry(entry.full, entry.known_absent, { k: entry.value[k] for k in dict_keys if k in entry.value }) self.metrics.inc_misses() - return DictionaryEntry(False, {}) + return DictionaryEntry(False, set(), {}) def invalidate(self, key): self.check_thread() @@ -87,19 +107,38 @@ class DictionaryCache(object): self.sequence += 1 self.cache.clear() - def update(self, sequence, key, value, full=False): + def update(self, sequence, key, value, full=False, known_absent=None): + """Updates the entry in the cache + + Args: + sequence + key + value (dict): The value to update the cache with. + full (bool): Whether the given value is the full dict, or just a + partial subset there of. If not full then any existing entries + for the key will be updated. + known_absent (set): Set of keys that we know don't exist in the full + dict. 
+ """ self.check_thread() if self.sequence == sequence: # Only update the cache if the caches sequence number matches the # number that the cache had before the SELECT was started (SYN-369) + if known_absent is None: + known_absent = set() if full: - self._insert(key, value) + self._insert(key, value, known_absent) else: - self._update_or_insert(key, value) + self._update_or_insert(key, value, known_absent) + + def _update_or_insert(self, key, value, known_absent): + # We pop and reinsert as we need to tell the cache the size may have + # changed - def _update_or_insert(self, key, value): - entry = self.cache.setdefault(key, DictionaryEntry(False, {})) + entry = self.cache.pop(key, DictionaryEntry(False, set(), {})) entry.value.update(value) + entry.known_absent.update(known_absent) + self.cache[key] = entry - def _insert(self, key, value): - self.cache[key] = DictionaryEntry(True, value) + def _insert(self, key, value, known_absent): + self.cache[key] = DictionaryEntry(True, known_absent, value) diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 2987c38a2d..0aa103eecb 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -79,7 +79,11 @@ class ExpiringCache(object): while self._max_len and len(self) > self._max_len: _key, value = self._cache.popitem(last=False) if self.iterable: - self._size_estimate -= len(value.value) + removed_len = len(value.value) + self.metrics.inc_evictions(removed_len) + self._size_estimate -= removed_len + else: + self.metrics.inc_evictions() def __getitem__(self, key): try: @@ -94,12 +98,22 @@ class ExpiringCache(object): return entry.value + def __contains__(self, key): + return key in self._cache + def get(self, key, default=None): try: return self[key] except KeyError: return default + def setdefault(self, key, value): + try: + return self[key] + except KeyError: + self[key] = value + return value + def _prune_cache(self): if not self._expiry_ms: # 
zero expiry time means don't expire. This should never get called diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index cf5fbb679c..1c5a982094 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -49,7 +49,24 @@ class LruCache(object): Can also set callbacks on objects when getting/setting which are fired when that key gets invalidated/evicted. """ - def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None): + def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None, + evicted_callback=None): + """ + Args: + max_size (int): + + keylen (int): + + cache_type (type): + type of underlying cache to be used. Typically one of dict + or TreeCache. + + size_callback (func(V) -> int | None): + + evicted_callback (func(int)|None): + if not None, called on eviction with the size of the evicted + entry + """ cache = cache_type() self.cache = cache # Used for introspection. list_root = _Node(None, None, None, None) @@ -61,8 +78,10 @@ class LruCache(object): def evict(): while cache_len() > max_size: todelete = list_root.prev_node - delete_node(todelete) + evicted_len = delete_node(todelete) cache.pop(todelete.key, None) + if evicted_callback: + evicted_callback(evicted_len) def synchronized(f): @wraps(f) @@ -111,12 +130,15 @@ class LruCache(object): prev_node.next_node = next_node next_node.prev_node = prev_node + deleted_len = 1 if size_callback: - cached_cache_len[0] -= size_callback(node.value) + deleted_len = size_callback(node.value) + cached_cache_len[0] -= deleted_len for cb in node.callbacks: cb() node.callbacks.clear() + return deleted_len @synchronized def cache_get(key, default=None, callbacks=[]): @@ -132,14 +154,21 @@ class LruCache(object): def cache_set(key, value, callbacks=[]): node = cache.get(key, None) if node is not None: - if value != node.value: + # We sometimes store large objects, e.g. dicts, which cause + # the inequality check to take a long time. 
So let's only do + # the check if we have some callbacks to call. + if node.callbacks and value != node.value: for cb in node.callbacks: cb() node.callbacks.clear() - if size_callback: - cached_cache_len[0] -= size_callback(node.value) - cached_cache_len[0] += size_callback(value) + # We don't bother to protect this by value != node.value as + # generally size_callback will be cheap compared with equality + # checks. (For example, taking the size of two dicts is quicker + # than comparing them for equality.) + if size_callback: + cached_cache_len[0] -= size_callback(node.value) + cached_cache_len[0] += size_callback(value) node.callbacks.update(callbacks) diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index 00af539880..7f79333e96 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -12,8 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging + +from twisted.internet import defer from synapse.util.async import ObservableDeferred +from synapse.util.caches import metrics as cache_metrics +from synapse.util.logcontext import make_deferred_yieldable, run_in_background + +logger = logging.getLogger(__name__) class ResponseCache(object): @@ -24,20 +31,68 @@ class ResponseCache(object): used rather than trying to compute a new response. """ - def __init__(self, hs, timeout_ms=0): + def __init__(self, hs, name, timeout_ms=0): self.pending_result_cache = {} # Requests that haven't finished yet. self.clock = hs.get_clock() self.timeout_sec = timeout_ms / 1000. + self._name = name + self._metrics = cache_metrics.register_cache( + "response_cache", + size_callback=lambda: self.size(), + cache_name=name, + ) + + def size(self): + return len(self.pending_result_cache) + def get(self, key): + """Look up the given key. 
+ + Can return either a new Deferred (which also doesn't follow the synapse + logcontext rules), or, if the request has completed, the actual + result. You will probably want to make_deferred_yieldable the result. + + If there is no entry for the key, returns None. It is worth noting that + this means there is no way to distinguish a completed result of None + from an absent cache entry. + + Args: + key (hashable): + + Returns: + twisted.internet.defer.Deferred|None|E: None if there is no entry + for this key; otherwise either a deferred result or the result + itself. + """ result = self.pending_result_cache.get(key) if result is not None: + self._metrics.inc_hits() return result.observe() else: + self._metrics.inc_misses() return None def set(self, key, deferred): + """Set the entry for the given key to the given deferred. + + *deferred* should run its callbacks in the sentinel logcontext (ie, + you should wrap normal synapse deferreds with + logcontext.run_in_background). + + Can return either a new Deferred (which also doesn't follow the synapse + logcontext rules), or, if *deferred* was already complete, the actual + result. You will probably want to make_deferred_yieldable the result. + + Args: + key (hashable): + deferred (twisted.internet.defer.Deferred[T): + + Returns: + twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual + result. + """ result = ObservableDeferred(deferred, consumeErrors=True) self.pending_result_cache[key] = result @@ -53,3 +108,52 @@ class ResponseCache(object): result.addBoth(remove) return result.observe() + + def wrap(self, key, callback, *args, **kwargs): + """Wrap together a *get* and *set* call, taking care of logcontexts + + First looks up the key in the cache, and if it is present makes it + follow the synapse logcontext rules and returns it. + + Otherwise, makes a call to *callback(*args, **kwargs)*, which should + follow the synapse logcontext rules, and adds the result to the cache. 
+ + Example usage: + + @defer.inlineCallbacks + def handle_request(request): + # etc + defer.returnValue(result) + + result = yield response_cache.wrap( + key, + handle_request, + request, + ) + + Args: + key (hashable): key to get/set in the cache + + callback (callable): function to call if the key is not found in + the cache + + *args: positional parameters to pass to the callback, if it is used + + **kwargs: named paramters to pass to the callback, if it is used + + Returns: + twisted.internet.defer.Deferred: yieldable result + """ + result = self.get(key) + if not result: + logger.info("[%s]: no cached result for [%s], calculating new one", + self._name, key) + d = run_in_background(callback, *args, **kwargs) + result = self.set(key, d) + elif not isinstance(result, defer.Deferred) or result.called: + logger.info("[%s]: using completed cached result for [%s]", + self._name, key) + else: + logger.info("[%s]: using incomplete cached result for [%s]", + self._name, key) + return make_deferred_yieldable(result) diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index b72bb0ff02..941d873ab8 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -13,20 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.caches import register_cache +from synapse.util.caches import register_cache, CACHE_SIZE_FACTOR from blist import sorteddict import logging -import os logger = logging.getLogger(__name__) -CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) - - class StreamChangeCache(object): """Keeps track of the stream positions of the latest change in a set of entities. 
@@ -50,7 +46,7 @@ class StreamChangeCache(object): def has_entity_changed(self, entity, stream_pos): """Returns True if the entity may have been updated since stream_pos """ - assert type(stream_pos) is int + assert type(stream_pos) is int or type(stream_pos) is long if stream_pos < self._earliest_known_stream_pos: self.metrics.inc_misses() @@ -89,6 +85,21 @@ class StreamChangeCache(object): return result + def has_any_entity_changed(self, stream_pos): + """Returns if any entity has changed + """ + assert type(stream_pos) is int + + if stream_pos >= self._earliest_known_stream_pos: + self.metrics.inc_hits() + keys = self._cache.keys() + i = keys.bisect_right(stream_pos) + + return i < len(keys) + else: + self.metrics.inc_misses() + return True + def get_all_entities_changed(self, stream_pos): """Returns all entites that have had new things since the given position. If the position is too old it will return None. diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index e68f94ce77..734331caaa 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -13,32 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from twisted.internet import defer +import logging -from synapse.util.logcontext import ( - PreserveLoggingContext, preserve_context_over_fn -) +from twisted.internet import defer from synapse.util import unwrapFirstError - -import logging - +from synapse.util.logcontext import PreserveLoggingContext logger = logging.getLogger(__name__) def user_left_room(distributor, user, room_id): - return preserve_context_over_fn( - distributor.fire, - "user_left_room", user=user, room_id=room_id - ) + with PreserveLoggingContext(): + distributor.fire("user_left_room", user=user, room_id=room_id) def user_joined_room(distributor, user, room_id): - return preserve_context_over_fn( - distributor.fire, - "user_joined_room", user=user, room_id=room_id - ) + with PreserveLoggingContext(): + distributor.fire("user_joined_room", user=user, room_id=room_id) class Distributor(object): diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py new file mode 100644 index 0000000000..3380970e4e --- /dev/null +++ b/synapse/util/file_consumer.py @@ -0,0 +1,141 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import threads, reactor + +from synapse.util.logcontext import make_deferred_yieldable, run_in_background + +from six.moves import queue + + +class BackgroundFileConsumer(object): + """A consumer that writes to a file like object. 
Supports both push + and pull producers + + Args: + file_obj (file): The file like object to write to. Closed when + finished. + """ + + # For PushProducers pause if we have this many unwritten slices + _PAUSE_ON_QUEUE_SIZE = 5 + # And resume once the size of the queue is less than this + _RESUME_ON_QUEUE_SIZE = 2 + + def __init__(self, file_obj): + self._file_obj = file_obj + + # Producer we're registered with + self._producer = None + + # True if PushProducer, false if PullProducer + self.streaming = False + + # For PushProducers, indicates whether we've paused the producer and + # need to call resumeProducing before we get more data. + self._paused_producer = False + + # Queue of slices of bytes to be written. When producer calls + # unregister a final None is sent. + self._bytes_queue = queue.Queue() + + # Deferred that is resolved when finished writing + self._finished_deferred = None + + # If the _writer thread throws an exception it gets stored here. + self._write_exception = None + + def registerProducer(self, producer, streaming): + """Part of IConsumer interface + + Args: + producer (IProducer) + streaming (bool): True if push based producer, False if pull + based. + """ + if self._producer: + raise Exception("registerProducer called twice") + + self._producer = producer + self.streaming = streaming + self._finished_deferred = run_in_background( + threads.deferToThread, self._writer + ) + if not streaming: + self._producer.resumeProducing() + + def unregisterProducer(self): + """Part of IProducer interface + """ + self._producer = None + if not self._finished_deferred.called: + self._bytes_queue.put_nowait(None) + + def write(self, bytes): + """Part of IProducer interface + """ + if self._write_exception: + raise self._write_exception + + if self._finished_deferred.called: + raise Exception("consumer has closed") + + self._bytes_queue.put_nowait(bytes) + + # If this is a PushProducer and the queue is getting behind + # then we pause the producer. 
+ if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE: + self._paused_producer = True + self._producer.pauseProducing() + + def _writer(self): + """This is run in a background thread to write to the file. + """ + try: + while self._producer or not self._bytes_queue.empty(): + # If we've paused the producer check if we should resume the + # producer. + if self._producer and self._paused_producer: + if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE: + reactor.callFromThread(self._resume_paused_producer) + + bytes = self._bytes_queue.get() + + # If we get a None (or empty list) then that's a signal used + # to indicate we should check if we should stop. + if bytes: + self._file_obj.write(bytes) + + # If its a pull producer then we need to explicitly ask for + # more stuff. + if not self.streaming and self._producer: + reactor.callFromThread(self._producer.resumeProducing) + except Exception as e: + self._write_exception = e + raise + finally: + self._file_obj.close() + + def wait(self): + """Returns a deferred that resolves when finished writing to file + """ + return make_deferred_yieldable(self._finished_deferred) + + def _resume_paused_producer(self): + """Gets called if we should resume producing after being paused + """ + if self._paused_producer and self._producer: + self._paused_producer = False + self._producer.resumeProducing() diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 6322f0f55c..f497b51f4a 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -14,6 +14,7 @@ # limitations under the License. from frozendict import frozendict +import simplejson as json def freeze(o): @@ -49,3 +50,21 @@ def unfreeze(o): pass return o + + +def _handle_frozendict(obj): + """Helper for EventEncoder. 
Makes frozendicts serializable by returning + the underlying dict + """ + if type(obj) is frozendict: + # fishing the protected dict out of the object is a bit nasty, + # but we don't really want the overhead of copying the dict. + return obj._dict + raise TypeError('Object of type %s is not JSON serializable' % + obj.__class__.__name__) + + +# A JSONEncoder which is capable of encoding frozendics without barfing +frozendict_json_encoder = json.JSONEncoder( + default=_handle_frozendict, +) diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py index 45be47159a..e9f0f292ee 100644 --- a/synapse/util/httpresourcetree.py +++ b/synapse/util/httpresourcetree.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.web.resource import Resource +from twisted.web.resource import NoResource import logging @@ -40,12 +40,15 @@ def create_resource_tree(desired_tree, root_resource): # extra resources to existing nodes. See self._resource_id for the key. 
resource_mappings = {} for full_path, res in desired_tree.items(): + # twisted requires all resources to be bytes + full_path = full_path.encode("utf-8") + logger.info("Attaching %s to path %s", res, full_path) last_resource = root_resource - for path_seg in full_path.split('/')[1:-1]: + for path_seg in full_path.split(b'/')[1:-1]: if path_seg not in last_resource.listNames(): # resource doesn't exist, so make a "dummy resource" - child_resource = Resource() + child_resource = NoResource() last_resource.putChild(path_seg, child_resource) res_id = _resource_id(last_resource, path_seg) resource_mappings[res_id] = child_resource @@ -57,7 +60,7 @@ def create_resource_tree(desired_tree, root_resource): # =========================== # now attach the actual desired resource - last_path_seg = full_path.split('/')[-1] + last_path_seg = full_path.split(b'/')[-1] # if there is already a resource here, thieve its children and # replace it diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index 6c83eb213d..e086e12213 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -12,6 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" Thread-local-alike tracking of log contexts within synapse + +This module provides objects and utilities for tracking contexts through +synapse code, so that log lines can include a request identifier, and so that +CPU and database activity can be accounted for against the request that caused +them. + +See doc/log_contexts.rst for details on how this works. +""" + from twisted.internet import defer import threading @@ -32,7 +42,7 @@ try: def get_thread_resource_usage(): return resource.getrusage(RUSAGE_THREAD) -except: +except Exception: # If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we # won't track resource usage by returning None. 
def get_thread_resource_usage(): @@ -42,13 +52,17 @@ except: class LoggingContext(object): """Additional context for log formatting. Contexts are scoped within a "with" block. + Args: name (str): Name for the context for debugging. """ __slots__ = [ - "previous_context", "name", "usage_start", "usage_end", "main_thread", - "__dict__", "tag", "alive", + "previous_context", "name", "ru_stime", "ru_utime", + "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms", + "usage_start", "usage_end", + "main_thread", "alive", + "request", "tag", ] thread_local = threading.local() @@ -73,8 +87,12 @@ class LoggingContext(object): def add_database_transaction(self, duration_ms): pass + def add_database_scheduled(self, sched_ms): + pass + def __nonzero__(self): return False + __bool__ = __nonzero__ # python3 sentinel = Sentinel() @@ -84,9 +102,17 @@ class LoggingContext(object): self.ru_stime = 0. self.ru_utime = 0. self.db_txn_count = 0 - self.db_txn_duration = 0. + + # ms spent waiting for db txns, excluding scheduling time + self.db_txn_duration_ms = 0 + + # ms spent waiting for db txns to be scheduled + self.db_sched_duration_ms = 0 + self.usage_start = None + self.usage_end = None self.main_thread = threading.current_thread() + self.request = None self.tag = "" self.alive = True @@ -95,7 +121,11 @@ class LoggingContext(object): @classmethod def current_context(cls): - """Get the current logging context from thread local storage""" + """Get the current logging context from thread local storage + + Returns: + LoggingContext: the current logging context + """ return getattr(cls.thread_local, "current_context", cls.sentinel) @classmethod @@ -145,11 +175,13 @@ class LoggingContext(object): self.alive = False def copy_to(self, record): - """Copy fields from this context to the record""" - for key, value in self.__dict__.items(): - setattr(record, key, value) + """Copy logging fields from this context to a log record or + another LoggingContext + """ - record.ru_utime, 
record.ru_stime = self.get_resource_usage() + # 'request' is the only field we currently use in the logger, so that's + # all we need to copy + record.request = self.request def start(self): if threading.current_thread() is not self.main_thread: @@ -184,7 +216,16 @@ class LoggingContext(object): def add_database_transaction(self, duration_ms): self.db_txn_count += 1 - self.db_txn_duration += duration_ms / 1000. + self.db_txn_duration_ms += duration_ms + + def add_database_scheduled(self, sched_ms): + """Record a use of the database pool + + Args: + sched_ms (int): number of milliseconds it took us to get a + connection + """ + self.db_sched_duration_ms += sched_ms class LoggingContextFilter(logging.Filter): @@ -251,80 +292,94 @@ class PreserveLoggingContext(object): ) -class _PreservingContextDeferred(defer.Deferred): - """A deferred that ensures that all callbacks and errbacks are called with - the given logging context. - """ - def __init__(self, context): - self._log_context = context - defer.Deferred.__init__(self) - - def addCallbacks(self, callback, errback=None, - callbackArgs=None, callbackKeywords=None, - errbackArgs=None, errbackKeywords=None): - callback = self._wrap_callback(callback) - errback = self._wrap_callback(errback) - return defer.Deferred.addCallbacks( - self, callback, - errback=errback, - callbackArgs=callbackArgs, - callbackKeywords=callbackKeywords, - errbackArgs=errbackArgs, - errbackKeywords=errbackKeywords, - ) +def preserve_fn(f): + """Function decorator which wraps the function with run_in_background""" + def g(*args, **kwargs): + return run_in_background(f, *args, **kwargs) + return g - def _wrap_callback(self, f): - def g(res, *args, **kwargs): - with PreserveLoggingContext(self._log_context): - res = f(res, *args, **kwargs) - return res - return g +def run_in_background(f, *args, **kwargs): + """Calls a function, ensuring that the current context is restored after + return from the function, and that the sentinel context is set 
once the + deferred returned by the function completes. -def preserve_context_over_fn(fn, *args, **kwargs): - """Takes a function and invokes it with the given arguments, but removes - and restores the current logging context while doing so. + Useful for wrapping functions that return a deferred which you don't yield + on (for instance because you want to pass it to deferred.gatherResults()). - If the result is a deferred, call preserve_context_over_deferred before - returning it. + Note that if you completely discard the result, you should make sure that + `f` doesn't raise any deferred exceptions, otherwise a scary-looking + CRITICAL error about an unhandled error will be logged without much + indication about where it came from. """ - with PreserveLoggingContext(): - res = fn(*args, **kwargs) - - if isinstance(res, defer.Deferred): - return preserve_context_over_deferred(res) - else: + current = LoggingContext.current_context() + try: + res = f(*args, **kwargs) + except: # noqa: E722 + # the assumption here is that the caller doesn't want to be disturbed + # by synchronous exceptions, so let's turn them into Failures. + return defer.fail() + + if not isinstance(res, defer.Deferred): return res + if res.called and not res.paused: + # The function should have maintained the logcontext, so we can + # optimise out the messing about + return res -def preserve_context_over_deferred(deferred, context=None): - """Given a deferred wrap it such that any callbacks added later to it will - be invoked with the current context. + # The function may have reset the context before returning, so + # we need to restore it now. + ctx = LoggingContext.set_current_context(current) + + # The original context will be restored when the deferred + # completes, but there is nothing waiting for it, so it will + # get leaked into the reactor or some other function which + # wasn't expecting it. We therefore need to reset the context + # here. 
+ # + # (If this feels asymmetric, consider it this way: we are + # effectively forking a new thread of execution. We are + # probably currently within a ``with LoggingContext()`` block, + # which is supposed to have a single entry and exit point. But + # by spawning off another deferred, we are effectively + # adding a new exit point.) + res.addBoth(_set_context_cb, ctx) + return res + + +def make_deferred_yieldable(deferred): + """Given a deferred, make it follow the Synapse logcontext rules: + + If the deferred has completed (or is not actually a Deferred), essentially + does nothing (just returns another completed deferred with the + result/failure). + + If the deferred has not yet completed, resets the logcontext before + returning a deferred. Then, when the deferred completes, restores the + current logcontext before running callbacks/errbacks. + + (This is more-or-less the opposite operation to run_in_background.) """ - if context is None: - context = LoggingContext.current_context() - d = _PreservingContextDeferred(context) - deferred.chainDeferred(d) - return d + if not isinstance(deferred, defer.Deferred): + return deferred + if deferred.called and not deferred.paused: + # it looks like this deferred is ready to run any callbacks we give it + # immediately. We may as well optimise out the logcontext faffery. + return deferred -def preserve_fn(f): - """Ensures that function is called with correct context and that context is - restored after return. Useful for wrapping functions that return a deferred - which you don't yield on. - """ - current = LoggingContext.current_context() + # ok, we can't be sure that a yield won't block, so let's reset the + # logcontext, and add a callback to the deferred to restore it. 
+ prev_context = LoggingContext.set_current_context(LoggingContext.sentinel) + deferred.addBoth(_set_context_cb, prev_context) + return deferred - def g(*args, **kwargs): - with PreserveLoggingContext(current): - res = f(*args, **kwargs) - if isinstance(res, defer.Deferred): - return preserve_context_over_deferred( - res, context=LoggingContext.sentinel - ) - else: - return res - return g + +def _set_context_cb(result, context): + """A callback function which just sets the logging context""" + LoggingContext.set_current_context(context) + return result # modules to ignore in `logcontext_tracer` diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py new file mode 100644 index 0000000000..3e42868ea9 --- /dev/null +++ b/synapse/util/logformatter.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from six import StringIO +import logging +import traceback + + +class LogFormatter(logging.Formatter): + """Log formatter which gives more detail for exceptions + + This is the same as the standard log formatter, except that when logging + exceptions [typically via log.foo("msg", exc_info=1)], it prints the + sequence that led up to the point at which the exception was caught. + (Normally only stack frames between the point the exception was raised and + where it was caught are logged). 
+ """ + def __init__(self, *args, **kwargs): + super(LogFormatter, self).__init__(*args, **kwargs) + + def formatException(self, ei): + sio = StringIO() + (typ, val, tb) = ei + + # log the stack above the exception capture point if possible, but + # check that we actually have an f_back attribute to work around + # https://twistedmatrix.com/trac/ticket/9305 + + if tb and hasattr(tb.tb_frame, 'f_back'): + sio.write("Capture point (most recent call last):\n") + traceback.print_stack(tb.tb_frame.f_back, None, sio) + + traceback.print_exception(typ, val, tb, None, sio) + s = sio.getvalue() + sio.close() + if s[-1:] == "\n": + s = s[:-1] + return s diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 4ea930d3e8..e4b5687a4b 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -27,25 +27,62 @@ logger = logging.getLogger(__name__) metrics = synapse.metrics.get_metrics_for(__name__) -block_timer = metrics.register_distribution( - "block_timer", - labels=["block_name"] +# total number of times we have hit this block +block_counter = metrics.register_counter( + "block_count", + labels=["block_name"], + alternative_names=( + # the following are all deprecated aliases for the same metric + metrics.name_prefix + x for x in ( + "_block_timer:count", + "_block_ru_utime:count", + "_block_ru_stime:count", + "_block_db_txn_count:count", + "_block_db_txn_duration:count", + ) + ) +) + +block_timer = metrics.register_counter( + "block_time_seconds", + labels=["block_name"], + alternative_names=( + metrics.name_prefix + "_block_timer:total", + ), ) -block_ru_utime = metrics.register_distribution( - "block_ru_utime", labels=["block_name"] +block_ru_utime = metrics.register_counter( + "block_ru_utime_seconds", labels=["block_name"], + alternative_names=( + metrics.name_prefix + "_block_ru_utime:total", + ), ) -block_ru_stime = metrics.register_distribution( - "block_ru_stime", labels=["block_name"] +block_ru_stime = metrics.register_counter( + 
"block_ru_stime_seconds", labels=["block_name"], + alternative_names=( + metrics.name_prefix + "_block_ru_stime:total", + ), ) -block_db_txn_count = metrics.register_distribution( - "block_db_txn_count", labels=["block_name"] +block_db_txn_count = metrics.register_counter( + "block_db_txn_count", labels=["block_name"], + alternative_names=( + metrics.name_prefix + "_block_db_txn_count:total", + ), ) -block_db_txn_duration = metrics.register_distribution( - "block_db_txn_duration", labels=["block_name"] +# seconds spent waiting for db txns, excluding scheduling time, in this block +block_db_txn_duration = metrics.register_counter( + "block_db_txn_duration_seconds", labels=["block_name"], + alternative_names=( + metrics.name_prefix + "_block_db_txn_duration:total", + ), +) + +# seconds spent waiting for a db connection, in this block +block_db_sched_duration = metrics.register_counter( + "block_db_sched_duration_seconds", labels=["block_name"], ) @@ -64,7 +101,9 @@ def measure_func(name): class Measure(object): __slots__ = [ "clock", "name", "start_context", "start", "new_context", "ru_utime", - "ru_stime", "db_txn_count", "db_txn_duration", "created_context" + "ru_stime", + "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms", + "created_context", ] def __init__(self, clock, name): @@ -84,13 +123,16 @@ class Measure(object): self.ru_utime, self.ru_stime = self.start_context.get_resource_usage() self.db_txn_count = self.start_context.db_txn_count - self.db_txn_duration = self.start_context.db_txn_duration + self.db_txn_duration_ms = self.start_context.db_txn_duration_ms + self.db_sched_duration_ms = self.start_context.db_sched_duration_ms def __exit__(self, exc_type, exc_val, exc_tb): if isinstance(exc_type, Exception) or not self.start_context: return duration = self.clock.time_msec() - self.start + + block_counter.inc(self.name) block_timer.inc_by(duration, self.name) context = LoggingContext.current_context() @@ -114,7 +156,12 @@ class Measure(object): 
context.db_txn_count - self.db_txn_count, self.name ) block_db_txn_duration.inc_by( - context.db_txn_duration - self.db_txn_duration, self.name + (context.db_txn_duration_ms - self.db_txn_duration_ms) / 1000., + self.name + ) + block_db_sched_duration.inc_by( + (context.db_sched_duration_ms - self.db_sched_duration_ms) / 1000., + self.name ) if self.created_context: diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py new file mode 100644 index 0000000000..4288312b8a --- /dev/null +++ b/synapse/util/module_loader.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib + +from synapse.config._base import ConfigError + + +def load_module(provider): + """ Loads a module with its config + Take a dict with keys 'module' (the module name) and 'config' + (the config dict). + + Returns + Tuple of (provider class, parsed config object) + """ + # We need to import the module, and then pick the class out of + # that, so we split based on the last dot. 
+ module, clz = provider['module'].rsplit(".", 1) + module = importlib.import_module(module) + provider_class = getattr(module, clz) + + try: + provider_config = provider_class.parse_config(provider["config"]) + except Exception as e: + raise ConfigError( + "Failed to parse config for %r: %r" % (provider['module'], e) + ) + + return provider_class, provider_config diff --git a/synapse/util/msisdn.py b/synapse/util/msisdn.py new file mode 100644 index 0000000000..607161e7f0 --- /dev/null +++ b/synapse/util/msisdn.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import phonenumbers +from synapse.api.errors import SynapseError + + +def phone_number_to_msisdn(country, number): + """ + Takes an ISO-3166-1 2 letter country code and phone number and + returns an msisdn representing the canonical version of that + phone number. + Args: + country (str): ISO-3166-1 2 letter country code + number (str): Phone number in a national or international format + + Returns: + (str) The canonical form of the phone number, as an msisdn + Raises: + SynapseError if the number could not be parsed. 
+ """ + try: + phoneNumber = phonenumbers.parse(number, country) + except phonenumbers.NumberParseException: + raise SynapseError(400, "Unable to parse phone number") + return phonenumbers.format_number( + phoneNumber, phonenumbers.PhoneNumberFormat.E164 + )[1:] diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 1101881a2d..18424f6c36 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import LimitExceededError from synapse.util.async import sleep -from synapse.util.logcontext import preserve_fn +from synapse.util.logcontext import run_in_background import collections import contextlib @@ -150,7 +150,7 @@ class _PerHostRatelimiter(object): "Ratelimit [%s]: sleeping req", id(request_id), ) - ret_defer = preserve_fn(sleep)(self.sleep_msec / 1000.0) + ret_defer = run_in_background(sleep, self.sleep_msec / 1000.0) self.sleeping_requests.add(request_id) diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 153ef001ad..4e93f69d3a 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import synapse.util.logcontext from twisted.internet import defer from synapse.api.errors import CodeMessageException @@ -26,6 +26,18 @@ logger = logging.getLogger(__name__) class NotRetryingDestination(Exception): def __init__(self, retry_last_ts, retry_interval, destination): + """Raised by the limiter (and federation client) to indicate that we are + are deliberately not attempting to contact a given server. + + Args: + retry_last_ts (int): the unix ts in milliseconds of our last attempt + to contact the server. 0 indicates that the last attempt was + successful or that we've never actually attempted to connect. 
+ retry_interval (int): the time in milliseconds to wait until the next + attempt. + destination (str): the domain in question + """ + msg = "Not retrying server %s." % (destination,) super(NotRetryingDestination, self).__init__(msg) @@ -35,7 +47,8 @@ class NotRetryingDestination(Exception): @defer.inlineCallbacks -def get_retry_limiter(destination, clock, store, **kwargs): +def get_retry_limiter(destination, clock, store, ignore_backoff=False, + **kwargs): """For a given destination check if we have previously failed to send a request there and are waiting before retrying the destination. If we are not ready to retry the destination, this will raise a @@ -43,6 +56,14 @@ def get_retry_limiter(destination, clock, store, **kwargs): that will mark the destination as down if an exception is thrown (excluding CodeMessageException with code < 500) + Args: + destination (str): name of homeserver + clock (synapse.util.clock): timing source + store (synapse.storage.transactions.TransactionStore): datastore + ignore_backoff (bool): true to ignore the historical backoff data and + try the request anyway. We will still update the next + retry_interval on success/failure. + Example usage: try: @@ -66,7 +87,7 @@ def get_retry_limiter(destination, clock, store, **kwargs): now = int(clock.time_msec()) - if retry_last_ts + retry_interval > now: + if not ignore_backoff and retry_last_ts + retry_interval > now: raise NotRetryingDestination( retry_last_ts=retry_last_ts, retry_interval=retry_interval, @@ -124,7 +145,13 @@ class RetryDestinationLimiter(object): def __exit__(self, exc_type, exc_val, exc_tb): valid_err_code = False - if exc_type is not None and issubclass(exc_type, CodeMessageException): + if exc_type is None: + valid_err_code = True + elif not issubclass(exc_type, Exception): + # avoid treating exceptions which don't derive from Exception as + # failures; this is mostly so as not to catch defer._DefGen. 
+ valid_err_code = True + elif issubclass(exc_type, CodeMessageException): # Some error codes are perfectly fine for some APIs, whereas other # APIs may expect to never received e.g. a 404. It's important to # handle 404 as some remote servers will return a 404 when the HS @@ -142,11 +169,13 @@ class RetryDestinationLimiter(object): else: valid_err_code = False - if exc_type is None or valid_err_code: + if valid_err_code: # We connected successfully. if not self.retry_interval: return + logger.debug("Connection to %s was successful; clearing backoff", + self.destination) retry_last_ts = 0 self.retry_interval = 0 else: @@ -160,6 +189,10 @@ class RetryDestinationLimiter(object): else: self.retry_interval = self.min_retry_interval + logger.debug( + "Connection to %s was unsuccessful (%s(%s)); backoff now %i", + self.destination, exc_type, exc_val, self.retry_interval + ) retry_last_ts = int(self.clock.time_msec()) @defer.inlineCallbacks @@ -168,9 +201,10 @@ class RetryDestinationLimiter(object): yield self.store.set_destination_retry_timings( self.destination, retry_last_ts, self.retry_interval ) - except: + except Exception: logger.exception( - "Failed to store set_destination_retry_timings", + "Failed to store destination_retry_timings", ) - store_retry_timings() + # we deliberately do this in the background. 
+ synapse.util.logcontext.run_in_background(store_retry_timings) diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index a100f151d4..b98b9dc6e4 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -15,6 +15,7 @@ import random import string +from six.moves import range _string_with_symbols = ( string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" @@ -22,12 +23,12 @@ _string_with_symbols = ( def random_string(length): - return ''.join(random.choice(string.ascii_letters) for _ in xrange(length)) + return ''.join(random.choice(string.ascii_letters) for _ in range(length)) def random_string_with_symbols(length): return ''.join( - random.choice(_string_with_symbols) for _ in xrange(length) + random.choice(_string_with_symbols) for _ in range(length) ) @@ -40,3 +41,17 @@ def is_ascii(s): return False else: return True + + +def to_ascii(s): + """Converts a string to ascii if it is ascii, otherwise leave it alone. + + If given None then will return None. + """ + if s is None: + return None + + try: + return s.encode("ascii") + except UnicodeEncodeError: + return s diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py new file mode 100644 index 0000000000..75efa0117b --- /dev/null +++ b/synapse/util/threepids.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import re + +logger = logging.getLogger(__name__) + + +def check_3pid_allowed(hs, medium, address): + """Checks whether a given format of 3PID is allowed to be used on this HS + + Args: + hs (synapse.server.HomeServer): server + medium (str): 3pid medium - e.g. email, msisdn + address (str): address within that medium (e.g. "wotan@matrix.org") + msisdns need to first have been canonicalised + Returns: + bool: whether the 3PID medium/address is allowed to be added to this HS + """ + + if hs.config.allowed_local_3pids: + for constraint in hs.config.allowed_local_3pids: + logger.debug( + "Checking 3PID %s (%s) against %s (%s)", + address, medium, constraint['pattern'], constraint['medium'], + ) + if ( + medium == constraint['medium'] and + re.match(constraint['pattern'], address) + ): + return True + else: + return True + + return False diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index 7412fc57a4..7a9e45aca9 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from six.moves import range + class _Entry(object): __slots__ = ["end_key", "queue"] @@ -68,7 +70,7 @@ class WheelTimer(object): # Add empty entries between the end of the current list and when we want # to insert. This ensures there are no gaps. 
self.entries.extend( - _Entry(key) for key in xrange(last_key, then_key + 1) + _Entry(key) for key in range(last_key, then_key + 1) ) self.entries[-1].queue.append(obj) @@ -91,7 +93,4 @@ class WheelTimer(object): return ret def __len__(self): - l = 0 - for entry in self.entries: - l += len(entry.queue) - return l + return sum(len(entry.queue) for entry in self.entries) diff --git a/synapse/visibility.py b/synapse/visibility.py index 199b16d827..aaca2c584c 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -17,7 +17,7 @@ from twisted.internet import defer from synapse.api.constants import Membership, EventTypes -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.util.logcontext import make_deferred_yieldable, preserve_fn import logging @@ -43,7 +43,8 @@ MEMBERSHIP_PRIORITY = ( @defer.inlineCallbacks -def filter_events_for_clients(store, user_tuples, events, event_id_to_state): +def filter_events_for_clients(store, user_tuples, events, event_id_to_state, + always_include_ids=frozenset()): """ Returns dict of user_id -> list of events that user is allowed to see. 
@@ -54,9 +55,12 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state): * the user has not been a member of the room since the given events events ([synapse.events.EventBase]): list of events to filter + always_include_ids (set(event_id)): set of event ids to specifically + include (unless sender is ignored) """ - forgotten = yield preserve_context_over_deferred(defer.gatherResults([ - preserve_fn(store.who_forgot_in_room)( + forgotten = yield make_deferred_yieldable(defer.gatherResults([ + defer.maybeDeferred( + preserve_fn(store.who_forgot_in_room), room_id, ) for room_id in frozenset(e.room_id for e in events) @@ -90,6 +94,9 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state): if not event.is_state() and event.sender in ignore_list: return False + if event.event_id in always_include_ids: + return True + state = event_id_to_state[event.event_id] # get the room_visibility at the time of the event. @@ -134,6 +141,13 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state): if prev_membership not in MEMBERSHIP_PRIORITY: prev_membership = "leave" + # Always allow the user to see their own leave events, otherwise + # they won't see the room disappear if they reject the invite + if membership == "leave" and ( + prev_membership == "join" or prev_membership == "invite" + ): + return True + new_priority = MEMBERSHIP_PRIORITY.index(membership) old_priority = MEMBERSHIP_PRIORITY.index(prev_membership) if old_priority < new_priority: @@ -181,26 +195,8 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state): @defer.inlineCallbacks -def filter_events_for_clients_context(store, user_tuples, events, event_id_to_context): - user_ids = set(u[0] for u in user_tuples) - event_id_to_state = {} - for event_id, context in event_id_to_context.items(): - state = yield store.get_events([ - e_id - for key, e_id in context.current_state_ids.iteritems() - if key == 
(EventTypes.RoomHistoryVisibility, "") - or (key[0] == EventTypes.Member and key[1] in user_ids) - ]) - event_id_to_state[event_id] = state - - res = yield filter_events_for_clients( - store, user_tuples, events, event_id_to_state - ) - defer.returnValue(res) - - -@defer.inlineCallbacks -def filter_events_for_client(store, user_id, events, is_peeking=False): +def filter_events_for_client(store, user_id, events, is_peeking=False, + always_include_ids=frozenset()): """ Check which events a user is allowed to see @@ -224,6 +220,7 @@ def filter_events_for_client(store, user_id, events, is_peeking=False): types=types ) res = yield filter_events_for_clients( - store, [(user_id, is_peeking)], events, event_id_to_state + store, [(user_id, is_peeking)], events, event_id_to_state, + always_include_ids=always_include_ids, ) defer.returnValue(res.get(user_id, [])) |