diff options
Diffstat (limited to 'synapse')
332 files changed, 26535 insertions, 10409 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index 6b0a766391..5c0f2f83aa 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,4 +17,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.22.0-rc1" +__version__ = "0.33.0" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index f8266d1c81..073229b4c4 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -15,15 +15,19 @@ import logging +from six import itervalues + import pymacaroons +from netaddr import IPAddress + from twisted.internet import defer import synapse.types from synapse import event_auth -from synapse.api.constants import EventTypes, Membership, JoinRules +from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.errors import AuthError, Codes from synapse.types import UserID -from synapse.util.caches import register_cache, CACHE_SIZE_FACTOR +from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache from synapse.util.caches.lrucache import LruCache from synapse.util.metrics import Measure @@ -57,16 +61,17 @@ class Auth(object): self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000) - register_cache("token_cache", self.token_cache) + register_cache("cache", "token_cache", self.token_cache) @defer.inlineCallbacks def check_from_context(self, event, context, do_sig_check=True): + prev_state_ids = yield context.get_prev_state_ids(self.store) auth_events_ids = yield self.compute_auth_events( - event, context.prev_state_ids, for_verification=True, + event, prev_state_ids, for_verification=True, ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = { - (e.type, e.state_key): e for e in auth_events.values() + (e.type, e.state_key): e for e in itervalues(auth_events) } self.check(event, auth_events=auth_events, do_sig_check=do_sig_check) @@ -189,7 +194,7 @@ class Auth(object): synapse.types.create_requester(user_id, app_service=app_service) ) - access_token = get_access_token_from_request( + access_token = self.get_access_token_from_request( request, self.TOKEN_NOT_FOUND_HTTP_STATUS ) @@ -204,12 +209,12 @@ class Auth(object): ip_addr = self.hs.get_ip_from_request(request) user_agent = request.requestHeaders.getRawHeaders( - "User-Agent", - default=[""] + b"User-Agent", + default=[b""] )[0] if user and access_token and ip_addr: self.store.insert_client_ip( - user=user, + user_id=user.to_string(), access_token=access_token, ip=ip_addr, user_agent=user_agent, @@ -235,13 +240,18 @@ class Auth(object): @defer.inlineCallbacks def _get_appservice_user_id(self, request): app_service = self.store.get_app_service_by_token( - get_access_token_from_request( + self.get_access_token_from_request( request, self.TOKEN_NOT_FOUND_HTTP_STATUS ) ) if app_service is None: defer.returnValue((None, None)) + if app_service.ip_range_whitelist: + ip_address = IPAddress(self.hs.get_ip_from_request(request)) + if ip_address not in app_service.ip_range_whitelist: + defer.returnValue((None, None)) + if "user_id" not in request.args: defer.returnValue((app_service.sender, app_service)) @@ -270,7 +280,11 @@ class Auth(object): rights (str): The operation being performed; the access token must allow this. Returns: - dict : dict that includes the user and the ID of their access token. + Deferred[dict]: dict that includes: + `user` (UserID) + `is_guest` (bool) + `token_id` (int|None): access token id. May be None if guest + `device_id` (str|None): device corresponding to access token Raises: AuthError if no user by that token exists or the token is invalid. """ @@ -482,7 +496,7 @@ class Auth(object): def _look_up_user_by_access_token(self, token): ret = yield self.store.get_user_by_access_token(token) if not ret: - logger.warn("Unrecognised access token - not in store: %s" % (token,)) + logger.warn("Unrecognised access token - not in store.") raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.", errcode=Codes.UNKNOWN_TOKEN @@ -500,12 +514,12 @@ class Auth(object): def get_appservice_by_req(self, request): try: - token = get_access_token_from_request( + token = self.get_access_token_from_request( request, self.TOKEN_NOT_FOUND_HTTP_STATUS ) service = self.store.get_app_service_by_token(token) if not service: - logger.warn("Unrecognised appservice access token: %s" % (token,)) + logger.warn("Unrecognised appservice access token.") raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.", @@ -519,11 +533,20 @@ class Auth(object): ) def is_server_admin(self, user): + """ Check if the given user is a local server admin. + + Args: + user (str): mxid of user to check + + Returns: + bool: True if the user is an admin + """ return self.store.is_server_admin(user) @defer.inlineCallbacks def add_auth_events(self, builder, context): - auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids) + prev_state_ids = yield context.get_prev_state_ids(self.store) + auth_ids = yield self.compute_auth_events(builder, prev_state_ids) auth_events_entries = yield self.store.add_event_hashes( auth_ids @@ -641,7 +664,7 @@ class Auth(object): auth_events[(EventTypes.PowerLevels, "")] = power_level_event send_level = event_auth.get_send_level( - EventTypes.Aliases, "", auth_events + EventTypes.Aliases, "", power_level_event, ) user_level = event_auth.get_user_power_level(user_id, auth_events) @@ -652,67 +675,101 @@ class Auth(object): " edit its room list entry" ) + @staticmethod + def has_access_token(request): + """Checks if the request has an access_token. -def has_access_token(request): - """Checks if the request has an access_token. + Returns: + bool: False if no access_token was given, True otherwise. + """ + query_params = request.args.get("access_token") + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") + return bool(query_params) or bool(auth_headers) - Returns: - bool: False if no access_token was given, True otherwise. - """ - query_params = request.args.get("access_token") - auth_headers = request.requestHeaders.getRawHeaders("Authorization") - return bool(query_params) or bool(auth_headers) - - -def get_access_token_from_request(request, token_not_found_http_status=401): - """Extracts the access_token from the request. - - Args: - request: The http request. - token_not_found_http_status(int): The HTTP status code to set in the - AuthError if the token isn't found. This is used in some of the - legacy APIs to change the status code to 403 from the default of - 401 since some of the old clients depended on auth errors returning - 403. - Returns: - str: The access_token - Raises: - AuthError: If there isn't an access_token in the request. - """ + @staticmethod + def get_access_token_from_request(request, token_not_found_http_status=401): + """Extracts the access_token from the request. - auth_headers = request.requestHeaders.getRawHeaders("Authorization") - query_params = request.args.get("access_token") - if auth_headers: - # Try the get the access_token from a "Authorization: Bearer" - # header - if query_params is not None: - raise AuthError( - token_not_found_http_status, - "Mixing Authorization headers and access_token query parameters.", - errcode=Codes.MISSING_TOKEN, - ) - if len(auth_headers) > 1: - raise AuthError( - token_not_found_http_status, - "Too many Authorization headers.", - errcode=Codes.MISSING_TOKEN, - ) - parts = auth_headers[0].split(" ") - if parts[0] == "Bearer" and len(parts) == 2: - return parts[1] + Args: + request: The http request. + token_not_found_http_status(int): The HTTP status code to set in the + AuthError if the token isn't found. This is used in some of the + legacy APIs to change the status code to 403 from the default of + 401 since some of the old clients depended on auth errors returning + 403. + Returns: + str: The access_token + Raises: + AuthError: If there isn't an access_token in the request. + """ + + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") + query_params = request.args.get(b"access_token") + if auth_headers: + # Try the get the access_token from a "Authorization: Bearer" + # header + if query_params is not None: + raise AuthError( + token_not_found_http_status, + "Mixing Authorization headers and access_token query parameters.", + errcode=Codes.MISSING_TOKEN, + ) + if len(auth_headers) > 1: + raise AuthError( + token_not_found_http_status, + "Too many Authorization headers.", + errcode=Codes.MISSING_TOKEN, + ) + parts = auth_headers[0].split(" ") + if parts[0] == "Bearer" and len(parts) == 2: + return parts[1] + else: + raise AuthError( + token_not_found_http_status, + "Invalid Authorization header.", + errcode=Codes.MISSING_TOKEN, + ) else: - raise AuthError( - token_not_found_http_status, - "Invalid Authorization header.", - errcode=Codes.MISSING_TOKEN, + # Try to get the access_token from the query params. + if not query_params: + raise AuthError( + token_not_found_http_status, + "Missing access token.", + errcode=Codes.MISSING_TOKEN + ) + + return query_params[0] + + @defer.inlineCallbacks + def check_in_room_or_world_readable(self, room_id, user_id): + """Checks that the user is or was in the room or the room is world + readable. If it isn't then an exception is raised. + + Returns: + Deferred[tuple[str, str|None]]: Resolves to the current membership of + the user in the room and the membership event ID of the user. If + the user is not in the room and never has been, then + `(Membership.JOIN, None)` is returned. + """ + + try: + # check_user_was_in_room will return the most recent membership + # event for the user if: + # * The user is a non-guest user, and was ever in the room + # * The user is a guest user, and has joined the room + # else it will throw. + member_event = yield self.check_user_was_in_room(room_id, user_id) + defer.returnValue((member_event.membership, member_event.event_id)) + except AuthError: + visibility = yield self.state.get_current_state( + room_id, EventTypes.RoomHistoryVisibility, "" ) - else: - # Try to get the access_token from the query params. - if not query_params: + if ( + visibility and + visibility.content["history_visibility"] == "world_readable" + ): + defer.returnValue((Membership.JOIN, None)) + return raise AuthError( - token_not_found_http_status, - "Missing access token.", - errcode=Codes.MISSING_TOKEN + 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN ) - - return query_params[0] diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 489efb7f86..4df930c8d1 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -16,6 +16,9 @@ """Contains constants from the specification.""" +# the "depth" field on events is limited to 2**63 - 1 +MAX_DEPTH = 2**63 - 1 + class Membership(object): @@ -73,6 +76,8 @@ class EventTypes(object): Topic = "m.room.topic" Name = "m.room.name" + ServerACL = "m.room.server_acl" + class RejectedReason(object): AUTH_ERROR = "auth_error" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index d0dfa959dc..6074df292f 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -15,9 +15,13 @@ """Contains exceptions and error codes.""" -import json import logging +from six import iteritems +from six.moves import http_client + +from canonicaljson import json + logger = logging.getLogger(__name__) @@ -46,8 +50,11 @@ class Codes(object): THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED" THREEPID_IN_USE = "M_THREEPID_IN_USE" THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND" + THREEPID_DENIED = "M_THREEPID_DENIED" INVALID_USERNAME = "M_INVALID_USERNAME" SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED" + CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN" + CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM" class CodeMessageException(RuntimeError): @@ -135,11 +142,79 @@ class SynapseError(CodeMessageException): return res +class ConsentNotGivenError(SynapseError): + """The error returned to the client when the user has not consented to the + privacy policy. + """ + def __init__(self, msg, consent_uri): + """Constructs a ConsentNotGivenError + + Args: + msg (str): The human-readable error message + consent_url (str): The URL where the user can give their consent + """ + super(ConsentNotGivenError, self).__init__( + code=http_client.FORBIDDEN, + msg=msg, + errcode=Codes.CONSENT_NOT_GIVEN + ) + self._consent_uri = consent_uri + + def error_dict(self): + return cs_error( + self.msg, + self.errcode, + consent_uri=self._consent_uri + ) + + class RegistrationError(SynapseError): """An error raised when a registration event fails.""" pass +class FederationDeniedError(SynapseError): + """An error raised when the server tries to federate with a server which + is not on its federation whitelist. + + Attributes: + destination (str): The destination which has been denied + """ + + def __init__(self, destination): + """Raised by federation client or server to indicate that we are + are deliberately not attempting to contact a given server because it is + not on our federation whitelist. + + Args: + destination (str): the domain in question + """ + + self.destination = destination + + super(FederationDeniedError, self).__init__( + code=403, + msg="Federation denied with %s." % (self.destination,), + errcode=Codes.FORBIDDEN, + ) + + +class InteractiveAuthIncompleteError(Exception): + """An error raised when UI auth is not yet complete + + (This indicates we should return a 401 with 'result' as the body) + + Attributes: + result (dict): the server response to the request, which should be + passed back to the client + """ + def __init__(self, result): + super(InteractiveAuthIncompleteError, self).__init__( + "Interactive auth not yet complete", + ) + self.result = result + + class UnrecognizedRequestError(SynapseError): """An error indicating we don't understand the request you're trying to make""" def __init__(self, *args, **kwargs): @@ -247,13 +322,13 @@ def cs_error(msg, code=Codes.UNKNOWN, **kwargs): Args: msg (str): The error message. - code (int): The error code. + code (str): The error code. kwargs : Additional keys to add to the response. Returns: A dict representing the error response JSON. """ err = {"error": msg, "errcode": code} - for key, value in kwargs.iteritems(): + for key, value in iteritems(kwargs): err[key] = value return err diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 83206348e5..25346baa87 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -12,15 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.api.errors import SynapseError -from synapse.storage.presence import UserPresenceState -from synapse.types import UserID, RoomID -from twisted.internet import defer - -import ujson as json import jsonschema +from canonicaljson import json from jsonschema import FormatChecker +from twisted.internet import defer + +from synapse.api.errors import SynapseError +from synapse.storage.presence import UserPresenceState +from synapse.types import RoomID, UserID + FILTER_SCHEMA = { "additionalProperties": False, "type": "object", @@ -411,7 +412,7 @@ class Filter(object): return room_ids def filter(self, events): - return filter(self.check, events) + return list(filter(self.check, events)) def limit(self): return self.filter_json.get("limit", 10) diff --git a/synapse/api/urls.py b/synapse/api/urls.py index 91a33a3402..71347912f1 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +15,12 @@ # limitations under the License. """Contains the URL paths to prefix various aspects of the server with. """ +import hmac +from hashlib import sha256 + +from six.moves.urllib.parse import urlencode + +from synapse.config import ConfigError CLIENT_PREFIX = "/_matrix/client/api/v1" CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha" @@ -25,3 +32,46 @@ SERVER_KEY_PREFIX = "/_matrix/key/v1" SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" MEDIA_PREFIX = "/_matrix/media/r0" LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" + + +class ConsentURIBuilder(object): + def __init__(self, hs_config): + """ + Args: + hs_config (synapse.config.homeserver.HomeServerConfig): + """ + if hs_config.form_secret is None: + raise ConfigError( + "form_secret not set in config", + ) + if hs_config.public_baseurl is None: + raise ConfigError( + "public_baseurl not set in config", + ) + + self._hmac_secret = hs_config.form_secret.encode("utf-8") + self._public_baseurl = hs_config.public_baseurl + + def build_user_consent_uri(self, user_id): + """Build a URI which we can give to the user to do their privacy + policy consent + + Args: + user_id (str): mxid or username of user + + Returns + (str) the URI where the user can do consent + """ + mac = hmac.new( + key=self._hmac_secret, + msg=user_id, + digestmod=sha256, + ).hexdigest() + consent_uri = "%s_matrix/consent?%s" % ( + self._public_baseurl, + urlencode({ + "u": user_id, + "h": mac + }), + ) + return consent_uri diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index 9c2b627590..3b6b9368b8 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -14,9 +14,11 @@ # limitations under the License. import sys + +from synapse import python_dependencies # noqa: E402 + sys.dont_write_bytecode = True -from synapse import python_dependencies # noqa: E402 try: python_dependencies.check_requirements() diff --git a/synapse/app/_base.py b/synapse/app/_base.py new file mode 100644 index 0000000000..391bd14c5c --- /dev/null +++ b/synapse/app/_base.py @@ -0,0 +1,194 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import logging +import sys + +from daemonize import Daemonize + +from twisted.internet import error, reactor + +from synapse.util import PreserveLoggingContext +from synapse.util.rlimit import change_resource_limit + +try: + import affinity +except Exception: + affinity = None + + +logger = logging.getLogger(__name__) + + +def start_worker_reactor(appname, config): + """ Run the reactor in the main process + + Daemonizes if necessary, and then configures some resources, before starting + the reactor. Pulls configuration from the 'worker' settings in 'config'. + + Args: + appname (str): application name which will be sent to syslog + config (synapse.config.Config): config object + """ + + logger = logging.getLogger(config.worker_app) + + start_reactor( + appname, + config.soft_file_limit, + config.gc_thresholds, + config.worker_pid_file, + config.worker_daemonize, + config.worker_cpu_affinity, + logger, + ) + + +def start_reactor( + appname, + soft_file_limit, + gc_thresholds, + pid_file, + daemonize, + cpu_affinity, + logger, +): + """ Run the reactor in the main process + + Daemonizes if necessary, and then configures some resources, before starting + the reactor + + Args: + appname (str): application name which will be sent to syslog + soft_file_limit (int): + gc_thresholds: + pid_file (str): name of pid file to write to if daemonize is True + daemonize (bool): true to run the reactor in a background process + cpu_affinity (int|None): cpu affinity mask + logger (logging.Logger): logger instance to pass to Daemonize + """ + + def run(): + # make sure that we run the reactor with the sentinel log context, + # otherwise other PreserveLoggingContext instances will get confused + # and complain when they see the logcontext arbitrarily swapping + # between the sentinel and `run` logcontexts. + with PreserveLoggingContext(): + logger.info("Running") + if cpu_affinity is not None: + if not affinity: + quit_with_error( + "Missing package 'affinity' required for cpu_affinity\n" + "option\n\n" + "Install by running:\n\n" + " pip install affinity\n\n" + ) + logger.info("Setting CPU affinity to %s" % cpu_affinity) + affinity.set_process_affinity_mask(0, cpu_affinity) + change_resource_limit(soft_file_limit) + if gc_thresholds: + gc.set_threshold(*gc_thresholds) + reactor.run() + + if daemonize: + daemon = Daemonize( + app=appname, + pid=pid_file, + action=run, + auto_close_fds=False, + verbose=True, + logger=logger, + ) + daemon.start() + else: + run() + + +def quit_with_error(error_string): + message_lines = error_string.split("\n") + line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2 + sys.stderr.write("*" * line_length + '\n') + for line in message_lines: + sys.stderr.write(" %s\n" % (line.rstrip(),)) + sys.stderr.write("*" * line_length + '\n') + sys.exit(1) + + +def listen_metrics(bind_addresses, port): + """ + Start Prometheus metrics server. + """ + from synapse.metrics import RegistryProxy + from prometheus_client import start_http_server + + for host in bind_addresses: + reactor.callInThread(start_http_server, int(port), + addr=host, registry=RegistryProxy) + logger.info("Metrics now reporting on %s:%d", host, port) + + +def listen_tcp(bind_addresses, port, factory, backlog=50): + """ + Create a TCP socket for a port and several addresses + """ + for address in bind_addresses: + try: + reactor.listenTCP( + port, + factory, + backlog, + address + ) + except error.CannotListenError as e: + check_bind_error(e, address, bind_addresses) + + +def listen_ssl(bind_addresses, port, factory, context_factory, backlog=50): + """ + Create an SSL socket for a port and several addresses + """ + for address in bind_addresses: + try: + reactor.listenSSL( + port, + factory, + context_factory, + backlog, + address + ) + except error.CannotListenError as e: + check_bind_error(e, address, bind_addresses) + + +def check_bind_error(e, address, bind_addresses): + """ + This method checks an exception occurred while binding on 0.0.0.0. + If :: is specified in the bind addresses a warning is shown. + The exception is still raised otherwise. + + Binding on both 0.0.0.0 and :: causes an exception on Linux and macOS + because :: binds on both IPv4 and IPv6 (as per RFC 3493). + When binding on 0.0.0.0 after :: this can safely be ignored. + + Args: + e (Exception): Exception that was caught. + address (str): Address on which binding was attempted. + bind_addresses (list): Addresses on which the service listens. + """ + if address == '0.0.0.0' and '::' in bind_addresses: + logger.warn('Failed to listen on 0.0.0.0, continuing because listening on [::]') + else: + raise e diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 9a476efa63..9a37384fb7 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -13,38 +13,33 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +import sys -import synapse +from twisted.internet import defer, reactor +from twisted.web.resource import NoResource -from synapse.server import HomeServer +import synapse +from synapse import events +from synapse.app import _base from synapse.config._base import ConfigError -from synapse.config.logger import setup_logging from synapse.config.homeserver import HomeServerConfig +from synapse.config.logger import setup_logging from synapse.http.site import SynapseSite -from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.metrics import RegistryProxy +from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.server import HomeServer from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn +from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole -from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string -from synapse import events - -from twisted.internet import reactor -from twisted.web.resource import Resource - -from daemonize import Daemonize - -import sys -import logging -import gc - logger = logging.getLogger("synapse.app.appservice") @@ -56,19 +51,6 @@ class AppserviceSlaveStore( class AppserviceServer(HomeServer): - def get_db_conn(self, run_new_connection=True): - # Any param beginning with cp_ is a parameter for adbapi, and should - # not be passed to the database engine. - db_params = { - k: v for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.database_engine.module.connect(**db_params) - - if run_new_connection: - self.database_engine.on_new_connection(db_conn) - return db_conn - def setup(self): logger.info("Setting up.") self.datastore = AppserviceSlaveStore(self.get_db_conn(), self) @@ -82,21 +64,21 @@ class AppserviceServer(HomeServer): for res in listener_config["resources"]: for name in res["names"]: if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(self) - - root_resource = create_resource_tree(resources, Resource()) - - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) + + root_resource = create_resource_tree(resources, NoResource()) + + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + self.version_string, ) + ) logger.info("Synapse appservice now listening on port %d", port) @@ -105,18 +87,22 @@ class AppserviceServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) + elif listener["type"] == "metrics": + if not self.get_config().enable_metrics: + logger.warn(("Metrics listener configured, but " + "enable_metrics is not True!")) + else: + _base.listen_metrics(listener["bind_addresses"], + listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -136,9 +122,14 @@ class ASReplicationHandler(ReplicationClientHandler): if stream_name == "events": max_stream_id = self.store.get_room_max_stream_ordering() - preserve_fn( - self.appservice_handler.notify_interested_services - )(max_stream_id) + run_in_background(self._notify_app_services, max_stream_id) + + @defer.inlineCallbacks + def _notify_app_services(self, room_stream_id): + try: + yield self.appservice_handler.notify_interested_services(room_stream_id) + except Exception: + logger.exception("Error notifying application services of event") def start(config_options): @@ -181,36 +172,13 @@ def start(config_options): ps.setup() ps.start_listening(config.worker_listeners) - def run(): - # make sure that we run the reactor with the sentinel log context, - # otherwise other PreserveLoggingContext instances will get confused - # and complain when they see the logcontext arbitrarily swapping - # between the sentinel and `run` logcontexts. - with PreserveLoggingContext(): - logger.info("Running") - change_resource_limit(config.soft_file_limit) - if config.gc_thresholds: - gc.set_threshold(*config.gc_thresholds) - reactor.run() - def start(): ps.get_datastore().start_profiling() ps.get_state_handler().start_caching() reactor.callWhenRunning(start) - if config.worker_daemonize: - daemon = Daemonize( - app="synapse-appservice", - pid=config.worker_pid_file, - action=run, - auto_close_fds=False, - verbose=True, - logger=logger, - ) - daemon.start() - else: - run() + _base.start_worker_reactor("synapse-appservice", config) if __name__ == '__main__': diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index 09bc1935f1..398bb36602 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -13,46 +13,46 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +import sys -import synapse +from twisted.internet import reactor +from twisted.web.resource import NoResource +import synapse +from synapse import events +from synapse.app import _base from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging -from synapse.http.site import SynapseSite +from synapse.crypto import context_factory from synapse.http.server import JsonResource -from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.http.site import SynapseSite +from synapse.metrics import RegistryProxy +from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.client_ips import SlavedClientIpStore +from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.keys import SlavedKeyStore -from synapse.replication.slave.storage.room import RoomStore -from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.rest.client.v1.room import PublicRoomListRestServlet +from synapse.rest.client.v1.room import ( + JoinedRoomMemberListRestServlet, + PublicRoomListRestServlet, + RoomEventContextServlet, + RoomMemberListRestServlet, + RoomStateRestServlet, +) from synapse.server import HomeServer from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext +from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole -from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string -from synapse.crypto import context_factory - -from synapse import events - - -from twisted.internet import reactor -from twisted.web.resource import Resource - -from daemonize import Daemonize - -import sys -import logging -import gc logger = logging.getLogger("synapse.app.client_reader") @@ -72,19 +72,6 @@ class ClientReaderSlavedStore( class ClientReaderServer(HomeServer): - def get_db_conn(self, run_new_connection=True): - # Any param beginning with cp_ is a parameter for adbapi, and should - # not be passed to the database engine. - db_params = { - k: v for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.database_engine.module.connect(**db_params) - - if run_new_connection: - self.database_engine.on_new_connection(db_conn) - return db_conn - def setup(self): logger.info("Setting up.") self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self) @@ -98,10 +85,16 @@ class ClientReaderServer(HomeServer): for res in listener_config["resources"]: for name in res["names"]: if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(self) + resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) elif name == "client": resource = JsonResource(self, canonical_json=False) + PublicRoomListRestServlet(self).register(resource) + RoomMemberListRestServlet(self).register(resource) + JoinedRoomMemberListRestServlet(self).register(resource) + RoomStateRestServlet(self).register(resource) + RoomEventContextServlet(self).register(resource) + resources.update({ "/_matrix/client/r0": resource, "/_matrix/client/unstable": resource, @@ -109,19 +102,19 @@ class ClientReaderServer(HomeServer): "/_matrix/client/api/v1": resource, }) - root_resource = create_resource_tree(resources, Resource()) - - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + root_resource = create_resource_tree(resources, NoResource()) + + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + self.version_string, ) + ) logger.info("Synapse client reader now listening on port %d", port) @@ -130,18 +123,22 @@ class ClientReaderServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) + elif listener["type"] == "metrics": + if not self.get_config().enable_metrics: + logger.warn(("Metrics listener configured, but " + "enable_metrics is not True!")) + else: + _base.listen_metrics(listener["bind_addresses"], + listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -180,39 +177,15 @@ def start(config_options): ) ss.setup() - ss.get_handlers() ss.start_listening(config.worker_listeners) - def run(): - # make sure that we run the reactor with the sentinel log context, - # otherwise other PreserveLoggingContext instances will get confused - # and complain when they see the logcontext arbitrarily swapping - # between the sentinel and `run` logcontexts. - with PreserveLoggingContext(): - logger.info("Running") - change_resource_limit(config.soft_file_limit) - if config.gc_thresholds: - gc.set_threshold(*config.gc_thresholds) - reactor.run() - def start(): ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() reactor.callWhenRunning(start) - if config.worker_daemonize: - daemon = Daemonize( - app="synapse-client-reader", - pid=config.worker_pid_file, - action=run, - auto_close_fds=False, - verbose=True, - logger=logger, - ) - daemon.start() - else: - run() + _base.start_worker_reactor("synapse-client-reader", config) if __name__ == '__main__': diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py new file mode 100644 index 0000000000..374f115644 --- /dev/null +++ b/synapse/app/event_creator.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import sys + +from twisted.internet import reactor +from twisted.web.resource import NoResource + +import synapse +from synapse import events +from synapse.app import _base +from synapse.config._base import ConfigError +from synapse.config.homeserver import HomeServerConfig +from synapse.config.logger import setup_logging +from synapse.crypto import context_factory +from synapse.http.server import JsonResource +from synapse.http.site import SynapseSite +from synapse.metrics import RegistryProxy +from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage.account_data import SlavedAccountDataStore +from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore +from synapse.replication.slave.storage.client_ips import SlavedClientIpStore +from synapse.replication.slave.storage.devices import SlavedDeviceStore +from synapse.replication.slave.storage.directory import DirectoryStore +from synapse.replication.slave.storage.events import SlavedEventStore +from synapse.replication.slave.storage.profile import SlavedProfileStore +from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore +from synapse.replication.slave.storage.pushers import SlavedPusherStore +from synapse.replication.slave.storage.receipts import SlavedReceiptsStore +from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.replication.slave.storage.room import RoomStore +from synapse.replication.slave.storage.transactions import TransactionStore +from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.rest.client.v1.room import ( + JoinRoomAliasServlet, + RoomMembershipRestServlet, + RoomSendEventRestServlet, + RoomStateEventRestServlet, +) +from synapse.server import HomeServer +from synapse.storage.engines import create_engine +from synapse.util.httpresourcetree import create_resource_tree +from synapse.util.logcontext import LoggingContext +from synapse.util.manhole import manhole +from synapse.util.versionstring import get_version_string + +logger = logging.getLogger("synapse.app.event_creator") + + +class EventCreatorSlavedStore( + DirectoryStore, + TransactionStore, + SlavedProfileStore, + SlavedAccountDataStore, + SlavedPusherStore, + SlavedReceiptsStore, + SlavedPushRuleStore, + SlavedDeviceStore, + SlavedClientIpStore, + SlavedApplicationServiceStore, + SlavedEventStore, + SlavedRegistrationStore, + RoomStore, + BaseSlavedStore, +): + pass + + +class EventCreatorServer(HomeServer): + def setup(self): + logger.info("Setting up.") + self.datastore = EventCreatorSlavedStore(self.get_db_conn(), self) + logger.info("Finished setting up.") + + def _listen_http(self, listener_config): + port = listener_config["port"] + bind_addresses = listener_config["bind_addresses"] + site_tag = listener_config.get("tag", port) + resources = {} + for res in listener_config["resources"]: + for name in res["names"]: + if name == "metrics": + resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) + elif name == "client": + resource = JsonResource(self, canonical_json=False) + RoomSendEventRestServlet(self).register(resource) + RoomMembershipRestServlet(self).register(resource) + RoomStateEventRestServlet(self).register(resource) + JoinRoomAliasServlet(self).register(resource) + resources.update({ + "/_matrix/client/r0": resource, + "/_matrix/client/unstable": resource, + "/_matrix/client/v2_alpha": resource, + "/_matrix/client/api/v1": resource, + }) + + root_resource = create_resource_tree(resources, NoResource()) + + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + self.version_string, + ) + ) + + logger.info("Synapse event creator now listening on port %d", port) + + def start_listening(self, listeners): + for listener in listeners: + if listener["type"] == "http": + self._listen_http(listener) + elif listener["type"] == "manhole": + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ) + ) + elif listener["type"] == "metrics": + if not self.get_config().enable_metrics: + logger.warn(("Metrics listener configured, but " + "enable_metrics is not True!")) + else: + _base.listen_metrics(listener["bind_addresses"], + listener["port"]) + else: + logger.warn("Unrecognized listener type: %s", listener["type"]) + + self.get_tcp_replication().start_replication(self) + + def build_tcp_replication(self): + return ReplicationClientHandler(self.get_datastore()) + + +def start(config_options): + try: + config = HomeServerConfig.load_config( + "Synapse event creator", config_options + ) + except ConfigError as e: + sys.stderr.write("\n" + e.message + "\n") + sys.exit(1) + + assert config.worker_app == "synapse.app.event_creator" + + assert config.worker_replication_http_port is not None + + setup_logging(config, use_worker_options=True) + + events.USE_FROZEN_DICTS = config.use_frozen_dicts + + database_engine = create_engine(config.database_config) + + tls_server_context_factory = context_factory.ServerContextFactory(config) + + ss = EventCreatorServer( + config.server_name, + db_config=config.database_config, + tls_server_context_factory=tls_server_context_factory, + config=config, + version_string="Synapse/" + get_version_string(synapse), + database_engine=database_engine, + ) + + ss.setup() + ss.start_listening(config.worker_listeners) + + def start(): + ss.get_state_handler().start_caching() + ss.get_datastore().start_profiling() + + reactor.callWhenRunning(start) + + _base.start_worker_reactor("synapse-event-creator", config) + + +if __name__ == '__main__': + with LoggingContext("main"): + start(sys.argv[1:]) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index eb392e1c9d..7af00b8bcf 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -13,43 +13,37 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +import sys -import synapse +from twisted.internet import reactor +from twisted.web.resource import NoResource +import synapse +from synapse import events +from synapse.api.urls import FEDERATION_PREFIX +from synapse.app import _base from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging +from synapse.crypto import context_factory +from synapse.federation.transport.server import TransportLayerServer from synapse.http.site import SynapseSite -from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.metrics import RegistryProxy +from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.keys import SlavedKeyStore from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import TransactionStore -from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext +from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole -from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string -from synapse.api.urls import FEDERATION_PREFIX -from synapse.federation.transport.server import TransportLayerServer -from synapse.crypto import context_factory - -from synapse import events - - -from twisted.internet import reactor -from twisted.web.resource import Resource - -from daemonize import Daemonize - -import sys -import logging -import gc logger = logging.getLogger("synapse.app.federation_reader") @@ -66,19 +60,6 @@ class FederationReaderSlavedStore( class FederationReaderServer(HomeServer): - def get_db_conn(self, run_new_connection=True): - # Any param beginning with cp_ is a parameter for adbapi, and should - # not be passed to the database engine. - db_params = { - k: v for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.database_engine.module.connect(**db_params) - - if run_new_connection: - self.database_engine.on_new_connection(db_conn) - return db_conn - def setup(self): logger.info("Setting up.") self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self) @@ -92,25 +73,25 @@ class FederationReaderServer(HomeServer): for res in listener_config["resources"]: for name in res["names"]: if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(self) + resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) elif name == "federation": resources.update({ FEDERATION_PREFIX: TransportLayerServer(self), }) - root_resource = create_resource_tree(resources, Resource()) - - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + root_resource = create_resource_tree(resources, NoResource()) + + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + self.version_string, ) + ) logger.info("Synapse federation reader now listening on port %d", port) @@ -119,18 +100,22 @@ class FederationReaderServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) + elif listener["type"] == "metrics": + if not self.get_config().enable_metrics: + logger.warn(("Metrics listener configured, but " + "enable_metrics is not True!")) + else: + _base.listen_metrics(listener["bind_addresses"], + listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -169,39 +154,15 @@ def start(config_options): ) ss.setup() - ss.get_handlers() ss.start_listening(config.worker_listeners) - def run(): - # make sure that we run the reactor with the sentinel log context, - # otherwise other PreserveLoggingContext instances will get confused - # and complain when they see the logcontext arbitrarily swapping - # between the sentinel and `run` logcontexts. - with PreserveLoggingContext(): - logger.info("Running") - change_resource_limit(config.soft_file_limit) - if config.gc_thresholds: - gc.set_threshold(*config.gc_thresholds) - reactor.run() - def start(): ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() reactor.callWhenRunning(start) - if config.worker_daemonize: - daemon = Daemonize( - app="synapse-federation-reader", - pid=config.worker_pid_file, - action=run, - auto_close_fds=False, - verbose=True, - logger=logger, - ) - daemon.start() - else: - run() + _base.start_worker_reactor("synapse-federation-reader", config) if __name__ == '__main__': diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 03327dc47a..18469013fa 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -13,44 +13,39 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +import sys -import synapse +from twisted.internet import defer, reactor +from twisted.web.resource import NoResource -from synapse.server import HomeServer +import synapse +from synapse import events +from synapse.app import _base from synapse.config._base import ConfigError -from synapse.config.logger import setup_logging from synapse.config.homeserver import HomeServerConfig +from synapse.config.logger import setup_logging from synapse.crypto import context_factory -from synapse.http.site import SynapseSite from synapse.federation import send_queue -from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.http.site import SynapseSite +from synapse.metrics import RegistryProxy +from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore +from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.events import SlavedEventStore +from synapse.replication.slave.storage.presence import SlavedPresenceStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.slave.storage.presence import SlavedPresenceStore from synapse.replication.slave.storage.transactions import TransactionStore -from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.server import HomeServer from synapse.storage.engines import create_engine from synapse.util.async import Linearizer from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn +from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole -from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string -from synapse import events - -from twisted.internet import reactor, defer -from twisted.web.resource import Resource - -from daemonize import Daemonize - -import sys -import logging -import gc - logger = logging.getLogger("synapse.app.federation_sender") @@ -83,19 +78,6 @@ class FederationSenderSlaveStore( class FederationSenderServer(HomeServer): - def get_db_conn(self, run_new_connection=True): - # Any param beginning with cp_ is a parameter for adbapi, and should - # not be passed to the database engine. - db_params = { - k: v for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.database_engine.module.connect(**db_params) - - if run_new_connection: - self.database_engine.on_new_connection(db_conn) - return db_conn - def setup(self): logger.info("Setting up.") self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self) @@ -109,21 +91,21 @@ class FederationSenderServer(HomeServer): for res in listener_config["resources"]: for name in res["names"]: if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(self) - - root_resource = create_resource_tree(resources, Resource()) - - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) + + root_resource = create_resource_tree(resources, NoResource()) + + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + self.version_string, ) + ) logger.info("Synapse federation_sender now listening on port %d", port) @@ -132,18 +114,22 @@ class FederationSenderServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) + elif listener["type"] == "metrics": + if not self.get_config().enable_metrics: + logger.warn(("Metrics listener configured, but " + "enable_metrics is not True!")) + else: + _base.listen_metrics(listener["bind_addresses"], + listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -213,36 +199,12 @@ def start(config_options): ps.setup() ps.start_listening(config.worker_listeners) - def run(): - # make sure that we run the reactor with the sentinel log context, - # otherwise other PreserveLoggingContext instances will get confused - # and complain when they see the logcontext arbitrarily swapping - # between the sentinel and `run` logcontexts. - with PreserveLoggingContext(): - logger.info("Running") - change_resource_limit(config.soft_file_limit) - if config.gc_thresholds: - gc.set_threshold(*config.gc_thresholds) - reactor.run() - def start(): ps.get_datastore().start_profiling() ps.get_state_handler().start_caching() reactor.callWhenRunning(start) - - if config.worker_daemonize: - daemon = Daemonize( - app="synapse-federation-sender", - pid=config.worker_pid_file, - action=run, - auto_close_fds=False, - verbose=True, - logger=logger, - ) - daemon.start() - else: - run() + _base.start_worker_reactor("synapse-federation-sender", config) class FederationSenderHandler(object): @@ -277,7 +239,7 @@ class FederationSenderHandler(object): # presence, typing, etc. if stream_name == "federation": send_queue.process_rows_for_federation(self.federation_sender, rows) - preserve_fn(self.update_token)(token) + run_in_background(self.update_token, token) # We also need to poke the federation sender when new events happen elif stream_name == "events": @@ -285,19 +247,22 @@ class FederationSenderHandler(object): @defer.inlineCallbacks def update_token(self, token): - self.federation_position = token - - # We linearize here to ensure we don't have races updating the token - with (yield self._fed_position_linearizer.queue(None)): - if self._last_ack < self.federation_position: - yield self.store.update_federation_out_pos( - "federation", self.federation_position - ) + try: + self.federation_position = token + + # We linearize here to ensure we don't have races updating the token + with (yield self._fed_position_linearizer.queue(None)): + if self._last_ack < self.federation_position: + yield self.store.update_federation_out_pos( + "federation", self.federation_position + ) - # We ACK this token over replication so that the master can drop - # its in memory queues - self.replication_client.send_federation_ack(self.federation_position) - self._last_ack = self.federation_position + # We ACK this token over replication so that the master can drop + # its in memory queues + self.replication_client.send_federation_ack(self.federation_position) + self._last_ack = self.federation_position + except Exception: + logger.exception("Error updating federation stream position") if __name__ == '__main__': diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py new file mode 100644 index 0000000000..b5f78f4640 --- /dev/null +++ b/synapse/app/frontend_proxy.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import sys + +from twisted.internet import defer, reactor +from twisted.web.resource import NoResource + +import synapse +from synapse import events +from synapse.api.errors import SynapseError +from synapse.app import _base +from synapse.config._base import ConfigError +from synapse.config.homeserver import HomeServerConfig +from synapse.config.logger import setup_logging +from synapse.crypto import context_factory +from synapse.http.server import JsonResource +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseSite +from synapse.metrics import RegistryProxy +from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore +from synapse.replication.slave.storage.client_ips import SlavedClientIpStore +from synapse.replication.slave.storage.devices import SlavedDeviceStore +from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.rest.client.v2_alpha._base import client_v2_patterns +from synapse.server import HomeServer +from synapse.storage.engines import create_engine +from synapse.util.httpresourcetree import create_resource_tree +from synapse.util.logcontext import LoggingContext +from synapse.util.manhole import manhole +from synapse.util.versionstring import get_version_string + +logger = logging.getLogger("synapse.app.frontend_proxy") + + +class KeyUploadServlet(RestServlet): + PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(KeyUploadServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.http_client = hs.get_simple_http_client() + self.main_uri = hs.config.worker_main_http_uri + + @defer.inlineCallbacks + def on_POST(self, request, device_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + user_id = requester.user.to_string() + body = parse_json_object_from_request(request) + + if device_id is not None: + # passing the device_id here is deprecated; however, we allow it + # for now for compatibility with older clients. + if (requester.device_id is not None and + device_id != requester.device_id): + logger.warning("Client uploading keys for a different device " + "(logged in as %s, uploading for %s)", + requester.device_id, device_id) + else: + device_id = requester.device_id + + if device_id is None: + raise SynapseError( + 400, + "To upload keys, you must pass device_id when authenticating" + ) + + if body: + # They're actually trying to upload something, proxy to main synapse. + # Pass through the auth headers, if any, in case the access token + # is there. + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", []) + headers = { + "Authorization": auth_headers, + } + result = yield self.http_client.post_json_get_json( + self.main_uri + request.uri, + body, + headers=headers, + ) + + defer.returnValue((200, result)) + else: + # Just interested in counts. + result = yield self.store.count_e2e_one_time_keys(user_id, device_id) + defer.returnValue((200, {"one_time_key_counts": result})) + + +class FrontendProxySlavedStore( + SlavedDeviceStore, + SlavedClientIpStore, + SlavedApplicationServiceStore, + SlavedRegistrationStore, + BaseSlavedStore, +): + pass + + +class FrontendProxyServer(HomeServer): + def setup(self): + logger.info("Setting up.") + self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self) + logger.info("Finished setting up.") + + def _listen_http(self, listener_config): + port = listener_config["port"] + bind_addresses = listener_config["bind_addresses"] + site_tag = listener_config.get("tag", port) + resources = {} + for res in listener_config["resources"]: + for name in res["names"]: + if name == "metrics": + resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) + elif name == "client": + resource = JsonResource(self, canonical_json=False) + KeyUploadServlet(self).register(resource) + resources.update({ + "/_matrix/client/r0": resource, + "/_matrix/client/unstable": resource, + "/_matrix/client/v2_alpha": resource, + "/_matrix/client/api/v1": resource, + }) + + root_resource = create_resource_tree(resources, NoResource()) + + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + self.version_string, + ) + ) + + logger.info("Synapse client reader now listening on port %d", port) + + def start_listening(self, listeners): + for listener in listeners: + if listener["type"] == "http": + self._listen_http(listener) + elif listener["type"] == "manhole": + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ) + ) + elif listener["type"] == "metrics": + if not self.get_config().enable_metrics: + logger.warn(("Metrics listener configured, but " + "enable_metrics is not True!")) + else: + _base.listen_metrics(listener["bind_addresses"], + listener["port"]) + else: + logger.warn("Unrecognized listener type: %s", listener["type"]) + + self.get_tcp_replication().start_replication(self) + + def build_tcp_replication(self): + return ReplicationClientHandler(self.get_datastore()) + + +def start(config_options): + try: + config = HomeServerConfig.load_config( + "Synapse frontend proxy", config_options + ) + except ConfigError as e: + sys.stderr.write("\n" + e.message + "\n") + sys.exit(1) + + assert config.worker_app == "synapse.app.frontend_proxy" + + assert config.worker_main_http_uri is not None + + setup_logging(config, use_worker_options=True) + + events.USE_FROZEN_DICTS = config.use_frozen_dicts + + database_engine = create_engine(config.database_config) + + tls_server_context_factory = context_factory.ServerContextFactory(config) + + ss = FrontendProxyServer( + config.server_name, + db_config=config.database_config, + tls_server_context_factory=tls_server_context_factory, + config=config, + version_string="Synapse/" + get_version_string(synapse), + database_engine=database_engine, + ) + + ss.setup() + ss.start_listening(config.worker_listeners) + + def start(): + ss.get_state_handler().start_caching() + ss.get_datastore().start_profiling() + + reactor.callWhenRunning(start) + + _base.start_worker_reactor("synapse-frontend-proxy", config) + + +if __name__ == '__main__': + with LoggingContext("main"): + start(sys.argv[1:]) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 081e7cce59..2ad1beb8d8 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -13,61 +13,62 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import synapse - import gc import logging import os import sys -import synapse.config.logger -from synapse.config._base import ConfigError - -from synapse.python_dependencies import ( - check_requirements, CONDITIONAL_REQUIREMENTS -) - -from synapse.rest import ClientRestResource -from synapse.storage.engines import create_engine, IncorrectDatabaseSetup -from synapse.storage import are_all_users_on_domain -from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database - -from synapse.server import HomeServer +from six import iteritems -from twisted.internet import reactor, defer from twisted.application import service -from twisted.web.resource import Resource, EncodingResourceWrapper -from twisted.web.static import File +from twisted.internet import defer, reactor +from twisted.web.resource import EncodingResourceWrapper, NoResource from twisted.web.server import GzipEncoderFactory -from synapse.http.server import RootRedirect -from synapse.rest.media.v0.content_repository import ContentRepoResource -from synapse.rest.media.v1.media_repository import MediaRepositoryResource -from synapse.rest.key.v1.server_key_resource import LocalKey -from synapse.rest.key.v2 import KeyApiV2Resource +from twisted.web.static import File + +import synapse +import synapse.config.logger +from synapse import events from synapse.api.urls import ( - FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX, - SERVER_KEY_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, STATIC_PREFIX, + CONTENT_REPO_PREFIX, + FEDERATION_PREFIX, + LEGACY_MEDIA_PREFIX, + MEDIA_PREFIX, + SERVER_KEY_PREFIX, SERVER_KEY_V2_PREFIX, + STATIC_PREFIX, + WEB_CLIENT_PREFIX, ) +from synapse.app import _base +from synapse.app._base import listen_ssl, listen_tcp, quit_with_error +from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext -from synapse.metrics import register_memory_metrics -from synapse.metrics.resource import MetricsResource, METRICS_PREFIX -from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.federation.transport.server import TransportLayerServer - -from synapse.util.rlimit import change_resource_limit -from synapse.util.versionstring import get_version_string +from synapse.http.additional_resource import AdditionalResource +from synapse.http.server import RootRedirect +from synapse.http.site import SynapseSite +from synapse.metrics import RegistryProxy +from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.module_api import ModuleApi +from synapse.python_dependencies import CONDITIONAL_REQUIREMENTS, check_requirements +from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource +from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from synapse.rest import ClientRestResource +from synapse.rest.key.v1.server_key_resource import LocalKey +from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.rest.media.v0.content_repository import ContentRepoResource +from synapse.server import HomeServer +from synapse.storage import are_all_users_on_domain +from synapse.storage.engines import IncorrectDatabaseSetup, create_engine +from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database +from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.httpresourcetree import create_resource_tree +from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole - -from synapse.http.site import SynapseSite - -from synapse import events - -from daemonize import Daemonize +from synapse.util.module_loader import load_module +from synapse.util.rlimit import change_resource_limit +from synapse.util.versionstring import get_version_string logger = logging.getLogger("synapse.app.homeserver") @@ -119,87 +120,132 @@ class SynapseHomeServer(HomeServer): resources = {} for res in listener_config["resources"]: for name in res["names"]: - if name == "client": - client_resource = ClientRestResource(self) - if res["compress"]: - client_resource = gz_wrap(client_resource) - - resources.update({ - "/_matrix/client/api/v1": client_resource, - "/_matrix/client/r0": client_resource, - "/_matrix/client/unstable": client_resource, - "/_matrix/client/v2_alpha": client_resource, - "/_matrix/client/versions": client_resource, - }) - - if name == "federation": - resources.update({ - FEDERATION_PREFIX: TransportLayerServer(self), - }) - - if name in ["static", "client"]: - resources.update({ - STATIC_PREFIX: File( - os.path.join(os.path.dirname(synapse.__file__), "static") - ), - }) - - if name in ["media", "federation", "client"]: - media_repo = MediaRepositoryResource(self) - resources.update({ - MEDIA_PREFIX: media_repo, - LEGACY_MEDIA_PREFIX: media_repo, - CONTENT_REPO_PREFIX: ContentRepoResource( - self, self.config.uploads_path - ), - }) - - if name in ["keys", "federation"]: - resources.update({ - SERVER_KEY_PREFIX: LocalKey(self), - SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self), - }) - - if name == "webclient": - resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self) - - if name == "metrics" and self.get_config().enable_metrics: - resources[METRICS_PREFIX] = MetricsResource(self) + resources.update(self._configure_named_resource( + name, res.get("compress", False), + )) + + additional_resources = listener_config.get("additional_resources", {}) + logger.debug("Configuring additional resources: %r", + additional_resources) + module_api = ModuleApi(self, self.get_auth_handler()) + for path, resmodule in additional_resources.items(): + handler_cls, config = load_module(resmodule) + handler = handler_cls(config, module_api) + resources[path] = AdditionalResource(self, handler.handle_request) if WEB_CLIENT_PREFIX in resources: root_resource = RootRedirect(WEB_CLIENT_PREFIX) else: - root_resource = Resource() + root_resource = NoResource() root_resource = create_resource_tree(resources, root_resource) if tls: - for address in bind_addresses: - reactor.listenSSL( - port, - SynapseSite( - "synapse.access.https.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - self.tls_server_context_factory, - interface=address - ) + listen_ssl( + bind_addresses, + port, + SynapseSite( + "synapse.access.https.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + self.version_string, + ), + self.tls_server_context_factory, + ) + else: - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + self.version_string, ) + ) logger.info("Synapse now listening on port %d", port) + def _configure_named_resource(self, name, compress=False): + """Build a resource map for a named resource + + Args: + name (str): named resource: one of "client", "federation", etc + compress (bool): whether to enable gzip compression for this + resource + + Returns: + dict[str, Resource]: map from path to HTTP resource + """ + resources = {} + if name == "client": + client_resource = ClientRestResource(self) + if compress: + client_resource = gz_wrap(client_resource) + + resources.update({ + "/_matrix/client/api/v1": client_resource, + "/_matrix/client/r0": client_resource, + "/_matrix/client/unstable": client_resource, + "/_matrix/client/v2_alpha": client_resource, + "/_matrix/client/versions": client_resource, + }) + + if name == "consent": + from synapse.rest.consent.consent_resource import ConsentResource + consent_resource = ConsentResource(self) + if compress: + consent_resource = gz_wrap(consent_resource) + resources.update({ + "/_matrix/consent": consent_resource, + }) + + if name == "federation": + resources.update({ + FEDERATION_PREFIX: TransportLayerServer(self), + }) + + if name in ["static", "client"]: + resources.update({ + STATIC_PREFIX: File( + os.path.join(os.path.dirname(synapse.__file__), "static") + ), + }) + + if name in ["media", "federation", "client"]: + if self.get_config().enable_media_repo: + media_repo = self.get_media_repository_resource() + resources.update({ + MEDIA_PREFIX: media_repo, + LEGACY_MEDIA_PREFIX: media_repo, + CONTENT_REPO_PREFIX: ContentRepoResource( + self, self.config.uploads_path + ), + }) + elif name == "media": + raise ConfigError( + "'media' resource conflicts with enable_media_repo=False", + ) + + if name in ["keys", "federation"]: + resources.update({ + SERVER_KEY_PREFIX: LocalKey(self), + SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self), + }) + + if name == "webclient": + resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self) + + if name == "metrics" and self.get_config().enable_metrics: + resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) + + if name == "replication": + resources[REPLICATION_PREFIX] = ReplicationRestResource(self) + + return resources + def start_listening(self): config = self.get_config() @@ -207,18 +253,15 @@ class SynapseHomeServer(HomeServer): if listener["type"] == "http": self._listener_http(config, listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) elif listener["type"] == "replication": bind_addresses = listener["bind_addresses"] for address in bind_addresses: @@ -229,6 +272,13 @@ class SynapseHomeServer(HomeServer): reactor.addSystemEventTrigger( "before", "shutdown", server_listener.stopListening, ) + elif listener["type"] == "metrics": + if not self.get_config().enable_metrics: + logger.warn(("Metrics listener configured, but " + "enable_metrics is not True!")) + else: + _base.listen_metrics(listener["bind_addresses"], + listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -248,29 +298,6 @@ class SynapseHomeServer(HomeServer): except IncorrectDatabaseSetup as e: quit_with_error(e.message) - def get_db_conn(self, run_new_connection=True): - # Any param beginning with cp_ is a parameter for adbapi, and should - # not be passed to the database engine. - db_params = { - k: v for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.database_engine.module.connect(**db_params) - - if run_new_connection: - self.database_engine.on_new_connection(db_conn) - return db_conn - - -def quit_with_error(error_string): - message_lines = error_string.split("\n") - line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2 - sys.stderr.write("*" * line_length + '\n') - for line in message_lines: - sys.stderr.write(" %s\n" % (line.rstrip(),)) - sys.stderr.write("*" * line_length + '\n') - sys.exit(1) - def setup(config_options): """ @@ -300,11 +327,6 @@ def setup(config_options): # check any extra requirements we have now we have a config check_requirements(config) - version_string = "Synapse/" + get_version_string(synapse) - - logger.info("Server hostname: %s", config.server_name) - logger.info("Server version: %s", version_string) - events.USE_FROZEN_DICTS = config.use_frozen_dicts tls_server_context_factory = context_factory.ServerContextFactory(config) @@ -317,7 +339,7 @@ def setup(config_options): db_config=config.database_config, tls_server_context_factory=tls_server_context_factory, config=config, - version_string=version_string, + version_string="Synapse/" + get_version_string(synapse), database_engine=database_engine, ) @@ -349,9 +371,7 @@ def setup(config_options): hs.get_state_handler().start_caching() hs.get_datastore().start_profiling() hs.get_datastore().start_doing_background_updates() - hs.get_replication_layer().start_get_pdu_cache() - - register_memory_metrics(hs) + hs.get_federation_client().start_get_pdu_cache() reactor.callWhenRunning(start) @@ -403,6 +423,10 @@ def run(hs): stats = {} + # Contains the list of processes we will be monitoring + # currently either 0 or 1 + stats_process = [] + @defer.inlineCallbacks def phone_stats_home(): logger.info("Gathering stats for reporting") @@ -419,6 +443,10 @@ def run(hs): total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users() stats["total_nonbridged_users"] = total_nonbridged_users + daily_user_type_results = yield hs.get_datastore().count_daily_user_type() + for name, count in iteritems(daily_user_type_results): + stats["daily_user_type_" + name] = count + room_count = yield hs.get_datastore().get_room_count() stats["total_room_count"] = room_count @@ -426,8 +454,21 @@ def run(hs): stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms() stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() + r30_results = yield hs.get_datastore().count_r30_users() + for name, count in iteritems(r30_results): + stats["r30_users_" + name] = count + daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() stats["daily_sent_messages"] = daily_sent_messages + stats["cache_factor"] = CACHE_SIZE_FACTOR + stats["event_cache_size"] = hs.config.event_cache_size + + if len(stats_process) > 0: + stats["memory_rss"] = 0 + stats["cpu_average"] = 0 + for process in stats_process: + stats["memory_rss"] += process.memory_info().rss + stats["cpu_average"] += int(process.cpu_percent(interval=None)) logger.info("Reporting stats to matrix.org: %s" % (stats,)) try: @@ -438,45 +479,56 @@ def run(hs): except Exception as e: logger.warn("Error reporting stats: %s", e) + def performance_stats_init(): + try: + import psutil + process = psutil.Process() + # Ensure we can fetch both, and make the initial request for cpu_percent + # so the next request will use this as the initial point. + process.memory_info().rss + process.cpu_percent(interval=None) + logger.info("report_stats can use psutil") + stats_process.append(process) + except (ImportError, AttributeError): + logger.warn( + "report_stats enabled but psutil is not installed or incorrect version." + " Disabling reporting of memory/cpu stats." + " Ensuring psutil is available will help matrix.org track performance" + " changes across releases." + ) + + def generate_user_daily_visit_stats(): + hs.get_datastore().generate_user_daily_visits() + + # Rather than update on per session basis, batch up the requests. + # If you increase the loop period, the accuracy of user_daily_visits + # table will decrease + clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000) + if hs.config.report_stats: logger.info("Scheduling stats reporting for 3 hour intervals") clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000) + # We need to defer this init for the cases that we daemonize + # otherwise the process ID we get is that of the non-daemon process + clock.call_later(0, performance_stats_init) + # We wait 5 minutes to send the first set of stats as the server can # be quite busy the first few minutes clock.call_later(5 * 60, phone_stats_home) - def in_thread(): - # Uncomment to enable tracing of log context changes. - # sys.settrace(logcontext_tracer) - - # make sure that we run the reactor with the sentinel log context, - # otherwise other PreserveLoggingContext instances will get confused - # and complain when they see the logcontext arbitrarily swapping - # between the sentinel and `run` logcontexts. - with PreserveLoggingContext(): - change_resource_limit(hs.config.soft_file_limit) - if hs.config.gc_thresholds: - gc.set_threshold(*hs.config.gc_thresholds) - reactor.run() - - if hs.config.daemonize: - - if hs.config.print_pidfile: - print (hs.config.pid_file) - - daemon = Daemonize( - app="synapse-homeserver", - pid=hs.config.pid_file, - action=lambda: in_thread(), - auto_close_fds=False, - verbose=True, - logger=logger, - ) - - daemon.start() - else: - in_thread() + if hs.config.daemonize and hs.config.print_pidfile: + print (hs.config.pid_file) + + _base.start_reactor( + "synapse-homeserver", + hs.config.soft_file_limit, + hs.config.gc_thresholds, + hs.config.pid_file, + hs.config.daemonize, + hs.config.cpu_affinity, + logger, + ) def main(): diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index f57ec784fe..749bbf37d0 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -13,14 +13,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +import sys -import synapse +from twisted.internet import reactor +from twisted.web.resource import NoResource +import synapse +from synapse import events +from synapse.api.urls import CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX +from synapse.app import _base from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging +from synapse.crypto import context_factory from synapse.http.site import SynapseSite -from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.metrics import RegistryProxy +from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.client_ips import SlavedClientIpStore @@ -28,31 +37,13 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.media.v0.content_repository import ContentRepoResource -from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.server import HomeServer from synapse.storage.engines import create_engine from synapse.storage.media_repository import MediaRepositoryStore from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext +from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole -from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string -from synapse.api.urls import ( - CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX -) -from synapse.crypto import context_factory - -from synapse import events - - -from twisted.internet import reactor -from twisted.web.resource import Resource - -from daemonize import Daemonize - -import sys -import logging -import gc logger = logging.getLogger("synapse.app.media_repository") @@ -69,19 +60,6 @@ class MediaRepositorySlavedStore( class MediaRepositoryServer(HomeServer): - def get_db_conn(self, run_new_connection=True): - # Any param beginning with cp_ is a parameter for adbapi, and should - # not be passed to the database engine. - db_params = { - k: v for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.database_engine.module.connect(**db_params) - - if run_new_connection: - self.database_engine.on_new_connection(db_conn) - return db_conn - def setup(self): logger.info("Setting up.") self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self) @@ -95,9 +73,9 @@ class MediaRepositoryServer(HomeServer): for res in listener_config["resources"]: for name in res["names"]: if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(self) + resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) elif name == "media": - media_repo = MediaRepositoryResource(self) + media_repo = self.get_media_repository_resource() resources.update({ MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo, @@ -106,19 +84,19 @@ class MediaRepositoryServer(HomeServer): ), }) - root_resource = create_resource_tree(resources, Resource()) - - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + root_resource = create_resource_tree(resources, NoResource()) + + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + self.version_string, ) + ) logger.info("Synapse media repository now listening on port %d", port) @@ -127,18 +105,22 @@ class MediaRepositoryServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) + elif listener["type"] == "metrics": + if not self.get_config().enable_metrics: + logger.warn(("Metrics listener configured, but " + "enable_metrics is not True!")) + else: + _base.listen_metrics(listener["bind_addresses"], + listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -159,6 +141,13 @@ def start(config_options): assert config.worker_app == "synapse.app.media_repository" + if config.enable_media_repo: + _base.quit_with_error( + "enable_media_repo must be disabled in the main synapse process\n" + "before the media repo can be run in a separate worker.\n" + "Please add ``enable_media_repo: false`` to the main config\n" + ) + setup_logging(config, use_worker_options=True) events.USE_FROZEN_DICTS = config.use_frozen_dicts @@ -177,39 +166,15 @@ def start(config_options): ) ss.setup() - ss.get_handlers() ss.start_listening(config.worker_listeners) - def run(): - # make sure that we run the reactor with the sentinel log context, - # otherwise other PreserveLoggingContext instances will get confused - # and complain when they see the logcontext arbitrarily swapping - # between the sentinel and `run` logcontexts. - with PreserveLoggingContext(): - logger.info("Running") - change_resource_limit(config.soft_file_limit) - if config.gc_thresholds: - gc.set_threshold(*config.gc_thresholds) - reactor.run() - def start(): ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() reactor.callWhenRunning(start) - if config.worker_daemonize: - daemon = Daemonize( - app="synapse-media-repository", - pid=config.worker_pid_file, - action=run, - auto_close_fds=False, - verbose=True, - logger=logger, - ) - daemon.start() - else: - run() + _base.start_worker_reactor("synapse-media-repository", config) if __name__ == '__main__': diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index f9114acfcb..9295a51d5b 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -13,41 +13,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +import sys -import synapse +from twisted.internet import defer, reactor +from twisted.web.resource import NoResource -from synapse.server import HomeServer +import synapse +from synapse import events +from synapse.app import _base from synapse.config._base import ConfigError -from synapse.config.logger import setup_logging from synapse.config.homeserver import HomeServerConfig +from synapse.config.logger import setup_logging from synapse.http.site import SynapseSite -from synapse.metrics.resource import MetricsResource, METRICS_PREFIX -from synapse.storage.roommember import RoomMemberStore +from synapse.metrics import RegistryProxy +from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.pushers import SlavedPusherStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore -from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.storage.engines import create_engine +from synapse.server import HomeServer from synapse.storage import DataStore +from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext, preserve_fn, \ - PreserveLoggingContext +from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole -from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string -from synapse import events - -from twisted.internet import reactor, defer -from twisted.web.resource import Resource - -from daemonize import Daemonize - -import sys -import logging -import gc - logger = logging.getLogger("synapse.app.pusher") @@ -83,25 +76,8 @@ class PusherSlaveStore( DataStore.get_profile_displayname.__func__ ) - who_forgot_in_room = ( - RoomMemberStore.__dict__["who_forgot_in_room"] - ) - class PusherServer(HomeServer): - def get_db_conn(self, run_new_connection=True): - # Any param beginning with cp_ is a parameter for adbapi, and should - # not be passed to the database engine. - db_params = { - k: v for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.database_engine.module.connect(**db_params) - - if run_new_connection: - self.database_engine.on_new_connection(db_conn) - return db_conn - def setup(self): logger.info("Setting up.") self.datastore = PusherSlaveStore(self.get_db_conn(), self) @@ -118,21 +94,21 @@ class PusherServer(HomeServer): for res in listener_config["resources"]: for name in res["names"]: if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(self) - - root_resource = create_resource_tree(resources, Resource()) - - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) + + root_resource = create_resource_tree(resources, NoResource()) + + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + self.version_string, ) + ) logger.info("Synapse pusher now listening on port %d", port) @@ -141,18 +117,22 @@ class PusherServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) + elif listener["type"] == "metrics": + if not self.get_config().enable_metrics: + logger.warn(("Metrics listener configured, but " + "enable_metrics is not True!")) + else: + _base.listen_metrics(listener["bind_addresses"], + listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -170,24 +150,27 @@ class PusherReplicationHandler(ReplicationClientHandler): def on_rdata(self, stream_name, token, rows): super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows) - preserve_fn(self.poke_pushers)(stream_name, token, rows) + run_in_background(self.poke_pushers, stream_name, token, rows) @defer.inlineCallbacks def poke_pushers(self, stream_name, token, rows): - if stream_name == "pushers": - for row in rows: - if row.deleted: - yield self.stop_pusher(row.user_id, row.app_id, row.pushkey) - else: - yield self.start_pusher(row.user_id, row.app_id, row.pushkey) - elif stream_name == "events": - yield self.pusher_pool.on_new_notifications( - token, token, - ) - elif stream_name == "receipts": - yield self.pusher_pool.on_new_receipts( - token, token, set(row.room_id for row in rows) - ) + try: + if stream_name == "pushers": + for row in rows: + if row.deleted: + yield self.stop_pusher(row.user_id, row.app_id, row.pushkey) + else: + yield self.start_pusher(row.user_id, row.app_id, row.pushkey) + elif stream_name == "events": + yield self.pusher_pool.on_new_notifications( + token, token, + ) + elif stream_name == "receipts": + yield self.pusher_pool.on_new_receipts( + token, token, set(row.room_id for row in rows) + ) + except Exception: + logger.exception("Error poking pushers") def stop_pusher(self, user_id, app_id, pushkey): key = "%s:%s" % (app_id, pushkey) @@ -244,18 +227,6 @@ def start(config_options): ps.setup() ps.start_listening(config.worker_listeners) - def run(): - # make sure that we run the reactor with the sentinel log context, - # otherwise other PreserveLoggingContext instances will get confused - # and complain when they see the logcontext arbitrarily swapping - # between the sentinel and `run` logcontexts. - with PreserveLoggingContext(): - logger.info("Running") - change_resource_limit(config.soft_file_limit) - if config.gc_thresholds: - gc.set_threshold(*config.gc_thresholds) - reactor.run() - def start(): ps.get_pusherpool().start() ps.get_datastore().start_profiling() @@ -263,18 +234,7 @@ def start(config_options): reactor.callWhenRunning(start) - if config.worker_daemonize: - daemon = Daemonize( - app="synapse-pusher", - pid=config.worker_pid_file, - action=run, - auto_close_fds=False, - verbose=True, - logger=logger, - ) - daemon.start() - else: - run() + _base.start_worker_reactor("synapse-pusher", config) if __name__ == '__main__': diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 4bdd99a966..26b9ec85f2 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -13,78 +13,74 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib +import logging +import sys -import synapse +from six import iteritems + +from twisted.internet import defer, reactor +from twisted.web.resource import NoResource +import synapse from synapse.api.constants import EventTypes +from synapse.app import _base from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.handlers.presence import PresenceHandler, get_interested_parties -from synapse.http.site import SynapseSite from synapse.http.server import JsonResource -from synapse.metrics.resource import MetricsResource, METRICS_PREFIX -from synapse.rest.client.v2_alpha import sync -from synapse.rest.client.v1 import events -from synapse.rest.client.v1.room import RoomInitialSyncRestServlet -from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet +from synapse.http.site import SynapseSite +from synapse.metrics import RegistryProxy +from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.slave.storage.client_ips import SlavedClientIpStore -from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore -from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.slave.storage.filtering import SlavedFilteringStore -from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore -from synapse.replication.slave.storage.presence import SlavedPresenceStore +from synapse.replication.slave.storage.client_ips import SlavedClientIpStore from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore from synapse.replication.slave.storage.devices import SlavedDeviceStore +from synapse.replication.slave.storage.events import SlavedEventStore +from synapse.replication.slave.storage.filtering import SlavedFilteringStore +from synapse.replication.slave.storage.groups import SlavedGroupServerStore +from synapse.replication.slave.storage.presence import SlavedPresenceStore +from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore +from synapse.replication.slave.storage.receipts import SlavedReceiptsStore +from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.room import RoomStore from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.rest.client.v1 import events +from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet +from synapse.rest.client.v1.room import RoomInitialSyncRestServlet +from synapse.rest.client.v2_alpha import sync from synapse.server import HomeServer from synapse.storage.engines import create_engine from synapse.storage.presence import UserPresenceState from synapse.storage.roommember import RoomMemberStore from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn +from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole -from synapse.util.rlimit import change_resource_limit from synapse.util.stringutils import random_string from synapse.util.versionstring import get_version_string -from twisted.internet import reactor, defer -from twisted.web.resource import Resource - -from daemonize import Daemonize - -import sys -import logging -import contextlib -import gc - logger = logging.getLogger("synapse.app.synchrotron") class SynchrotronSlavedStore( - SlavedPushRuleStore, - SlavedEventStore, SlavedReceiptsStore, SlavedAccountDataStore, SlavedApplicationServiceStore, SlavedRegistrationStore, SlavedFilteringStore, SlavedPresenceStore, + SlavedGroupServerStore, SlavedDeviceInboxStore, SlavedDeviceStore, + SlavedPushRuleStore, + SlavedEventStore, SlavedClientIpStore, RoomStore, BaseSlavedStore, ): - who_forgot_in_room = ( - RoomMemberStore.__dict__["who_forgot_in_room"] - ) - did_forget = ( RoomMemberStore.__dict__["did_forget"] ) @@ -219,7 +215,7 @@ class SynchrotronPresence(object): def get_currently_syncing_users(self): return [ - user_id for user_id, count in self.user_to_num_current_syncs.iteritems() + user_id for user_id, count in iteritems(self.user_to_num_current_syncs) if count > 0 ] @@ -250,19 +246,6 @@ class SynchrotronApplicationService(object): class SynchrotronServer(HomeServer): - def get_db_conn(self, run_new_connection=True): - # Any param beginning with cp_ is a parameter for adbapi, and should - # not be passed to the database engine. - db_params = { - k: v for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.database_engine.module.connect(**db_params) - - if run_new_connection: - self.database_engine.on_new_connection(db_conn) - return db_conn - def setup(self): logger.info("Setting up.") self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self) @@ -276,7 +259,7 @@ class SynchrotronServer(HomeServer): for res in listener_config["resources"]: for name in res["names"]: if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(self) + resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) elif name == "client": resource = JsonResource(self, canonical_json=False) sync.register_servlets(self, resource) @@ -290,19 +273,19 @@ class SynchrotronServer(HomeServer): "/_matrix/client/api/v1": resource, }) - root_resource = create_resource_tree(resources, Resource()) - - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + root_resource = create_resource_tree(resources, NoResource()) + + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + self.version_string, ) + ) logger.info("Synapse synchrotron now listening on port %d", port) @@ -311,18 +294,22 @@ class SynchrotronServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) + elif listener["type"] == "metrics": + if not self.get_config().enable_metrics: + logger.warn(("Metrics listener configured, but " + "enable_metrics is not True!")) + else: + _base.listen_metrics(listener["bind_addresses"], + listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -344,15 +331,13 @@ class SyncReplicationHandler(ReplicationClientHandler): self.store = hs.get_datastore() self.typing_handler = hs.get_typing_handler() + # NB this is a SynchrotronPresence, not a normal PresenceHandler self.presence_handler = hs.get_presence_handler() self.notifier = hs.get_notifier() - self.presence_handler.sync_callback = self.send_user_sync - def on_rdata(self, stream_name, token, rows): super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) - - preserve_fn(self.process_and_notify)(stream_name, token, rows) + run_in_background(self.process_and_notify, stream_name, token, rows) def get_streams_to_replicate(self): args = super(SyncReplicationHandler, self).get_streams_to_replicate() @@ -364,51 +349,58 @@ class SyncReplicationHandler(ReplicationClientHandler): @defer.inlineCallbacks def process_and_notify(self, stream_name, token, rows): - if stream_name == "events": - # We shouldn't get multiple rows per token for events stream, so - # we don't need to optimise this for multiple rows. - for row in rows: - event = yield self.store.get_event(row.event_id) - extra_users = () - if event.type == EventTypes.Member: - extra_users = (event.state_key,) - max_token = self.store.get_room_max_stream_ordering() - self.notifier.on_new_room_event( - event, token, max_token, extra_users + try: + if stream_name == "events": + # We shouldn't get multiple rows per token for events stream, so + # we don't need to optimise this for multiple rows. + for row in rows: + event = yield self.store.get_event(row.event_id) + extra_users = () + if event.type == EventTypes.Member: + extra_users = (event.state_key,) + max_token = self.store.get_room_max_stream_ordering() + self.notifier.on_new_room_event( + event, token, max_token, extra_users + ) + elif stream_name == "push_rules": + self.notifier.on_new_event( + "push_rules_key", token, users=[row.user_id for row in rows], ) - elif stream_name == "push_rules": - self.notifier.on_new_event( - "push_rules_key", token, users=[row.user_id for row in rows], - ) - elif stream_name in ("account_data", "tag_account_data",): - self.notifier.on_new_event( - "account_data_key", token, users=[row.user_id for row in rows], - ) - elif stream_name == "receipts": - self.notifier.on_new_event( - "receipt_key", token, rooms=[row.room_id for row in rows], - ) - elif stream_name == "typing": - self.typing_handler.process_replication_rows(token, rows) - self.notifier.on_new_event( - "typing_key", token, rooms=[row.room_id for row in rows], - ) - elif stream_name == "to_device": - entities = [row.entity for row in rows if row.entity.startswith("@")] - if entities: + elif stream_name in ("account_data", "tag_account_data",): self.notifier.on_new_event( - "to_device_key", token, users=entities, + "account_data_key", token, users=[row.user_id for row in rows], ) - elif stream_name == "device_lists": - all_room_ids = set() - for row in rows: - room_ids = yield self.store.get_rooms_for_user(row.user_id) - all_room_ids.update(room_ids) - self.notifier.on_new_event( - "device_list_key", token, rooms=all_room_ids, - ) - elif stream_name == "presence": - yield self.presence_handler.process_replication_rows(token, rows) + elif stream_name == "receipts": + self.notifier.on_new_event( + "receipt_key", token, rooms=[row.room_id for row in rows], + ) + elif stream_name == "typing": + self.typing_handler.process_replication_rows(token, rows) + self.notifier.on_new_event( + "typing_key", token, rooms=[row.room_id for row in rows], + ) + elif stream_name == "to_device": + entities = [row.entity for row in rows if row.entity.startswith("@")] + if entities: + self.notifier.on_new_event( + "to_device_key", token, users=entities, + ) + elif stream_name == "device_lists": + all_room_ids = set() + for row in rows: + room_ids = yield self.store.get_rooms_for_user(row.user_id) + all_room_ids.update(room_ids) + self.notifier.on_new_event( + "device_list_key", token, rooms=all_room_ids, + ) + elif stream_name == "presence": + yield self.presence_handler.process_replication_rows(token, rows) + elif stream_name == "receipts": + self.notifier.on_new_event( + "groups_key", token, users=[row.user_id for row in rows], + ) + except Exception: + logger.exception("Error processing replication") def start(config_options): @@ -440,36 +432,13 @@ def start(config_options): ss.setup() ss.start_listening(config.worker_listeners) - def run(): - # make sure that we run the reactor with the sentinel log context, - # otherwise other PreserveLoggingContext instances will get confused - # and complain when they see the logcontext arbitrarily swapping - # between the sentinel and `run` logcontexts. - with PreserveLoggingContext(): - logger.info("Running") - change_resource_limit(config.soft_file_limit) - if config.gc_thresholds: - gc.set_threshold(*config.gc_thresholds) - reactor.run() - def start(): ss.get_datastore().start_profiling() ss.get_state_handler().start_caching() reactor.callWhenRunning(start) - if config.worker_daemonize: - daemon = Daemonize( - app="synapse-synchrotron", - pid=config.worker_pid_file, - action=run, - auto_close_fds=False, - verbose=True, - logger=logger, - ) - daemon.start() - else: - run() + _base.start_worker_reactor("synapse-synchrotron", config) if __name__ == '__main__': diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py index 3bd7ef7bba..d658f967ba 100755 --- a/synapse/app/synctl.py +++ b/synapse/app/synctl.py @@ -16,16 +16,19 @@ import argparse import collections +import errno import glob import os import os.path import signal import subprocess import sys -import yaml -import errno import time +from six import iteritems + +import yaml + SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"] GREEN = "\x1b[1;32m" @@ -38,7 +41,7 @@ def pid_running(pid): try: os.kill(pid, 0) return True - except OSError, err: + except OSError as err: if err.errno == errno.EPERM: return True return False @@ -98,7 +101,7 @@ def stop(pidfile, app): try: os.kill(pid, signal.SIGTERM) write("stopped %s" % (app,), colour=GREEN) - except OSError, err: + except OSError as err: if err.errno == errno.ESRCH: write("%s not running" % (app,), colour=YELLOW) elif err.errno == errno.EPERM: @@ -171,6 +174,10 @@ def main(): if cache_factor: os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor) + cache_factors = config.get("synctl_cache_factors", {}) + for cache_name, factor in iteritems(cache_factors): + os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor) + worker_configfiles = [] if options.worker: start_stop_synapse = False @@ -184,6 +191,9 @@ def main(): worker_configfiles.append(worker_configfile) if options.all_processes: + # To start the main synapse with -a you need to add a worker file + # with worker_app == "synapse.app.homeserver" + start_stop_synapse = False worker_configdir = options.all_processes if not os.path.isdir(worker_configdir): write( @@ -200,11 +210,29 @@ def main(): with open(worker_configfile) as stream: worker_config = yaml.load(stream) worker_app = worker_config["worker_app"] - worker_pidfile = worker_config["worker_pid_file"] - worker_daemonize = worker_config["worker_daemonize"] - assert worker_daemonize, "In config %r: expected '%s' to be True" % ( - worker_configfile, "worker_daemonize") - worker_cache_factor = worker_config.get("synctl_cache_factor") + if worker_app == "synapse.app.homeserver": + # We need to special case all of this to pick up options that may + # be set in the main config file or in this worker config file. + worker_pidfile = ( + worker_config.get("pid_file") + or pidfile + ) + worker_cache_factor = worker_config.get("synctl_cache_factor") or cache_factor + daemonize = worker_config.get("daemonize") or config.get("daemonize") + assert daemonize, "Main process must have daemonize set to true" + + # The master process doesn't support using worker_* config. + for key in worker_config: + if key == "worker_app": # But we allow worker_app + continue + assert not key.startswith("worker_"), \ + "Main process cannot use worker_* config" + else: + worker_pidfile = worker_config["worker_pid_file"] + worker_daemonize = worker_config["worker_daemonize"] + assert worker_daemonize, "In config %r: expected '%s' to be True" % ( + worker_configfile, "worker_daemonize") + worker_cache_factor = worker_config.get("synctl_cache_factor") workers.append(Worker( worker_app, worker_configfile, worker_pidfile, worker_cache_factor, )) @@ -231,6 +259,7 @@ def main(): for running_pid in running_pids: while pid_running(running_pid): time.sleep(0.2) + write("All processes exited; now restarting...") if action == "start" or action == "restart": if start_stop_synapse: diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 8c6300db9d..637a89530a 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -14,16 +14,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -import synapse +import logging +import sys -from synapse.server import HomeServer +from twisted.internet import defer, reactor +from twisted.web.resource import NoResource + +import synapse +from synapse import events +from synapse.app import _base from synapse.config._base import ConfigError -from synapse.config.logger import setup_logging from synapse.config.homeserver import HomeServerConfig +from synapse.config.logger import setup_logging from synapse.crypto import context_factory -from synapse.http.site import SynapseSite from synapse.http.server import JsonResource -from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.http.site import SynapseSite +from synapse.metrics import RegistryProxy +from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.client_ips import SlavedClientIpStore @@ -31,25 +38,14 @@ from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.client.v2_alpha import user_directory +from synapse.server import HomeServer from synapse.storage.engines import create_engine from synapse.storage.user_directory import UserDirectoryStore +from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn +from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole -from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string -from synapse.util.caches.stream_change_cache import StreamChangeCache - -from synapse import events - -from twisted.internet import reactor -from twisted.web.resource import Resource - -from daemonize import Daemonize - -import sys -import logging -import gc logger = logging.getLogger("synapse.app.user_dir") @@ -98,19 +94,6 @@ class UserDirectorySlaveStore( class UserDirectoryServer(HomeServer): - def get_db_conn(self, run_new_connection=True): - # Any param beginning with cp_ is a parameter for adbapi, and should - # not be passed to the database engine. - db_params = { - k: v for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.database_engine.module.connect(**db_params) - - if run_new_connection: - self.database_engine.on_new_connection(db_conn) - return db_conn - def setup(self): logger.info("Setting up.") self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self) @@ -124,7 +107,7 @@ class UserDirectoryServer(HomeServer): for res in listener_config["resources"]: for name in res["names"]: if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(self) + resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) elif name == "client": resource = JsonResource(self, canonical_json=False) user_directory.register_servlets(self, resource) @@ -135,19 +118,19 @@ class UserDirectoryServer(HomeServer): "/_matrix/client/api/v1": resource, }) - root_resource = create_resource_tree(resources, Resource()) - - for address in bind_addresses: - reactor.listenTCP( - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - ), - interface=address + root_resource = create_resource_tree(resources, NoResource()) + + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + self.version_string, ) + ) logger.info("Synapse user_dir now listening on port %d", port) @@ -156,18 +139,22 @@ class UserDirectoryServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_addresses = listener["bind_addresses"] - - for address in bind_addresses: - reactor.listenTCP( - listener["port"], - manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ), - interface=address + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, ) + ) + elif listener["type"] == "metrics": + if not self.get_config().enable_metrics: + logger.warn(("Metrics listener configured, but " + "enable_metrics is not True!")) + else: + _base.listen_metrics(listener["bind_addresses"], + listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -187,7 +174,14 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler): stream_name, token, rows ) if stream_name == "current_state_deltas": - preserve_fn(self.user_directory.notify_new_event)() + run_in_background(self._notify_directory) + + @defer.inlineCallbacks + def _notify_directory(self): + try: + yield self.user_directory.notify_new_event() + except Exception: + logger.exception("Error notifiying user directory of state update") def start(config_options): @@ -233,36 +227,13 @@ def start(config_options): ps.setup() ps.start_listening(config.worker_listeners) - def run(): - # make sure that we run the reactor with the sentinel log context, - # otherwise other PreserveLoggingContext instances will get confused - # and complain when they see the logcontext arbitrarily swapping - # between the sentinel and `run` logcontexts. - with PreserveLoggingContext(): - logger.info("Running") - change_resource_limit(config.soft_file_limit) - if config.gc_thresholds: - gc.set_threshold(*config.gc_thresholds) - reactor.run() - def start(): ps.get_datastore().start_profiling() ps.get_state_handler().start_caching() reactor.callWhenRunning(start) - if config.worker_daemonize: - daemon = Daemonize( - app="synapse-user-dir", - pid=config.worker_pid_file, - action=run, - auto_close_fds=False, - verbose=True, - logger=logger, - ) - daemon.start() - else: - run() + _base.start_worker_reactor("synapse-user-dir", config) if __name__ == '__main__': diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index b989007314..57ed8a3ca2 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -12,13 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.api.constants import EventTypes -from synapse.util.caches.descriptors import cachedInlineCallbacks +import logging +import re + +from six import string_types from twisted.internet import defer -import logging -import re +from synapse.api.constants import EventTypes +from synapse.types import GroupID, get_domain_from_id +from synapse.util.caches.descriptors import cachedInlineCallbacks logger = logging.getLogger(__name__) @@ -81,14 +84,17 @@ class ApplicationService(object): # values. NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS] - def __init__(self, token, url=None, namespaces=None, hs_token=None, - sender=None, id=None, protocols=None, rate_limited=True): + def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None, + sender=None, id=None, protocols=None, rate_limited=True, + ip_range_whitelist=None): self.token = token self.url = url self.hs_token = hs_token self.sender = sender + self.server_name = hostname self.namespaces = self._check_namespaces(namespaces) self.id = id + self.ip_range_whitelist = ip_range_whitelist if "|" in self.id: raise Exception("application service ID cannot contain '|' character") @@ -125,8 +131,26 @@ class ApplicationService(object): raise ValueError( "Expected bool for 'exclusive' in ns '%s'" % ns ) + group_id = regex_obj.get("group_id") + if group_id: + if not isinstance(group_id, str): + raise ValueError( + "Expected string for 'group_id' in ns '%s'" % ns + ) + try: + GroupID.from_string(group_id) + except Exception: + raise ValueError( + "Expected valid group ID for 'group_id' in ns '%s'" % ns + ) + + if get_domain_from_id(group_id) != self.server_name: + raise ValueError( + "Expected 'group_id' to be this host in ns '%s'" % ns + ) + regex = regex_obj.get("regex") - if isinstance(regex, basestring): + if isinstance(regex, string_types): regex_obj["regex"] = re.compile(regex) # Pre-compile regex else: raise ValueError( @@ -251,8 +275,27 @@ class ApplicationService(object): if regex_obj["exclusive"] ] + def get_groups_for_user(self, user_id): + """Get the groups that this user is associated with by this AS + + Args: + user_id (str): The ID of the user. + + Returns: + iterable[str]: an iterable that yields group_id strings. + """ + return ( + regex_obj["group_id"] + for regex_obj in self.namespaces[ApplicationService.NS_USERS] + if "group_id" in regex_obj and regex_obj["regex"].match(user_id) + ) + def is_rate_limited(self): return self.rate_limited def __str__(self): - return "ApplicationService: %s" % (self.__dict__,) + # copy dictionary and redact token fields so they don't get logged + dict_copy = self.__dict__.copy() + dict_copy["token"] = "<redacted>" + dict_copy["hs_token"] = "<redacted>" + return "ApplicationService: %s" % (dict_copy,) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 6893610e71..6980e5890e 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -12,20 +12,39 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +import urllib + +from prometheus_client import Counter + from twisted.internet import defer from synapse.api.constants import ThirdPartyEntityKind from synapse.api.errors import CodeMessageException -from synapse.http.client import SimpleHttpClient from synapse.events.utils import serialize_event -from synapse.util.caches.response_cache import ResponseCache +from synapse.http.client import SimpleHttpClient from synapse.types import ThirdPartyInstanceID - -import logging -import urllib +from synapse.util.caches.response_cache import ResponseCache logger = logging.getLogger(__name__) +sent_transactions_counter = Counter( + "synapse_appservice_api_sent_transactions", + "Number of /transactions/ requests sent", + ["service"] +) + +failed_transactions_counter = Counter( + "synapse_appservice_api_failed_transactions", + "Number of /transactions/ requests that failed to send", + ["service"] +) + +sent_events_counter = Counter( + "synapse_appservice_api_sent_events", + "Number of events sent to the AS", + ["service"] +) HOUR_IN_MS = 60 * 60 * 1000 @@ -72,7 +91,8 @@ class ApplicationServiceApi(SimpleHttpClient): super(ApplicationServiceApi, self).__init__(hs) self.clock = hs.get_clock() - self.protocol_meta_cache = ResponseCache(hs, timeout_ms=HOUR_IN_MS) + self.protocol_meta_cache = ResponseCache(hs, "as_protocol_meta", + timeout_ms=HOUR_IN_MS) @defer.inlineCallbacks def query_user(self, service, user_id): @@ -192,9 +212,7 @@ class ApplicationServiceApi(SimpleHttpClient): defer.returnValue(None) key = (service.id, protocol) - return self.protocol_meta_cache.get(key) or ( - self.protocol_meta_cache.set(key, _get()) - ) + return self.protocol_meta_cache.wrap(key, _get) @defer.inlineCallbacks def push_bulk(self, service, events, txn_id=None): @@ -220,12 +238,15 @@ class ApplicationServiceApi(SimpleHttpClient): args={ "access_token": service.hs_token }) + sent_transactions_counter.labels(service.id).inc() + sent_events_counter.labels(service.id).inc(len(events)) defer.returnValue(True) return except CodeMessageException as e: logger.warning("push_bulk to %s received %s", uri, e.code) except Exception as ex: logger.warning("push_bulk to %s threw exception %s", uri, ex) + failed_transactions_counter.labels(service.id).inc() defer.returnValue(False) def _serialize(self, events): diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 68a9de17b8..2430814796 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -48,14 +48,14 @@ UP & quit +---------- YES SUCCESS This is all tied together by the AppServiceScheduler which DIs the required components. """ +import logging + from twisted.internet import defer from synapse.appservice import ApplicationServiceState -from synapse.util.logcontext import preserve_fn +from synapse.util.logcontext import run_in_background from synapse.util.metrics import Measure -import logging - logger = logging.getLogger(__name__) @@ -106,7 +106,7 @@ class _ServiceQueuer(object): def enqueue(self, service, event): # if this service isn't being sent something self.queued_events.setdefault(service.id, []).append(event) - preserve_fn(self._send_request)(service) + run_in_background(self._send_request, service) @defer.inlineCallbacks def _send_request(self, service): @@ -123,7 +123,7 @@ class _ServiceQueuer(object): with Measure(self.clock, "servicequeuer.send"): try: yield self.txn_ctrl.send(service, events) - except: + except Exception: logger.exception("AS request failed") finally: self.requests_in_flight.discard(service.id) @@ -152,10 +152,10 @@ class _TransactionController(object): if sent: yield txn.complete(self.store) else: - preserve_fn(self._start_recoverer)(service) - except Exception as e: - logger.exception(e) - preserve_fn(self._start_recoverer)(service) + run_in_background(self._start_recoverer, service) + except Exception: + logger.exception("Error creating appservice transaction") + run_in_background(self._start_recoverer, service) @defer.inlineCallbacks def on_recovered(self, recoverer): @@ -176,17 +176,20 @@ class _TransactionController(object): @defer.inlineCallbacks def _start_recoverer(self, service): - yield self.store.set_appservice_state( - service, - ApplicationServiceState.DOWN - ) - logger.info( - "Application service falling behind. Starting recoverer. AS ID %s", - service.id - ) - recoverer = self.recoverer_fn(service, self.on_recovered) - self.add_recoverers([recoverer]) - recoverer.recover() + try: + yield self.store.set_appservice_state( + service, + ApplicationServiceState.DOWN + ) + logger.info( + "Application service falling behind. Starting recoverer. AS ID %s", + service.id + ) + recoverer = self.recoverer_fn(service, self.on_recovered) + self.add_recoverers([recoverer]) + recoverer.recover() + except Exception: + logger.exception("Error starting AS recoverer") @defer.inlineCallbacks def _is_service_up(self, service): diff --git a/synapse/config/__init__.py b/synapse/config/__init__.py index bfebb0f644..f2a5a41e92 100644 --- a/synapse/config/__init__.py +++ b/synapse/config/__init__.py @@ -12,3 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from ._base import ConfigError + +# export ConfigError if somebody does import * +# this is largely a fudge to stop PEP8 moaning about the import +__all__ = ["ConfigError"] diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 1ab5593c6e..3d2e90dd5b 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -16,9 +16,12 @@ import argparse import errno import os -import yaml from textwrap import dedent +from six import integer_types + +import yaml + class ConfigError(Exception): pass @@ -49,7 +52,7 @@ Missing mandatory `server_name` config option. class Config(object): @staticmethod def parse_size(value): - if isinstance(value, int) or isinstance(value, long): + if isinstance(value, integer_types): return value sizes = {"K": 1024, "M": 1024 * 1024} size = 1 @@ -61,7 +64,7 @@ class Config(object): @staticmethod def parse_duration(value): - if isinstance(value, int) or isinstance(value, long): + if isinstance(value, integer_types): return value second = 1000 minute = 60 * second @@ -82,21 +85,37 @@ class Config(object): return os.path.abspath(file_path) if file_path else file_path @classmethod + def path_exists(cls, file_path): + """Check if a file exists + + Unlike os.path.exists, this throws an exception if there is an error + checking if the file exists (for example, if there is a perms error on + the parent dir). + + Returns: + bool: True if the file exists; False if not. + """ + try: + os.stat(file_path) + return True + except OSError as e: + if e.errno != errno.ENOENT: + raise e + return False + + @classmethod def check_file(cls, file_path, config_name): if file_path is None: raise ConfigError( "Missing config for %s." - " You must specify a path for the config file. You can " - "do this with the -c or --config-path option. " - "Adding --generate-config along with --server-name " - "<server name> will generate a config file at the given path." % (config_name,) ) - if not os.path.exists(file_path): + try: + os.stat(file_path) + except OSError as e: raise ConfigError( - "File %s config for %s doesn't exist." - " Try running again with --generate-config" - % (file_path, config_name,) + "Error accessing file '%s' (config for %s): %s" + % (file_path, config_name, e.strerror) ) return cls.abspath(file_path) @@ -248,7 +267,7 @@ class Config(object): " -c CONFIG-FILE\"" ) (config_path,) = config_files - if not os.path.exists(config_path): + if not cls.path_exists(config_path): if config_args.keys_directory: config_dir_path = config_args.keys_directory else: @@ -261,33 +280,33 @@ class Config(object): "Must specify a server_name to a generate config for." " Pass -H server.name." ) - if not os.path.exists(config_dir_path): + if not cls.path_exists(config_dir_path): os.makedirs(config_dir_path) - with open(config_path, "wb") as config_file: - config_bytes, config = obj.generate_config( + with open(config_path, "w") as config_file: + config_str, config = obj.generate_config( config_dir_path=config_dir_path, server_name=server_name, report_stats=(config_args.report_stats == "yes"), is_generating_file=True ) obj.invoke_all("generate_files", config) - config_file.write(config_bytes) - print ( + config_file.write(config_str) + print(( "A config file has been generated in %r for server name" " %r with corresponding SSL keys and self-signed" " certificates. Please review this file and customise it" " to your needs." - ) % (config_path, server_name) - print ( + ) % (config_path, server_name)) + print( "If this server name is incorrect, you will need to" " regenerate the SSL certificates" ) return else: - print ( + print(( "Config file %r already exists. Generating any missing key" " files." - ) % (config_path,) + ) % (config_path,)) generate_keys = True parser = argparse.ArgumentParser( diff --git a/synapse/config/api.py b/synapse/config/api.py index 20ba33226a..403d96ba76 100644 --- a/synapse/config/api.py +++ b/synapse/config/api.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config - from synapse.api.constants import EventTypes +from ._base import Config + class ApiConfig(Config): diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 82c50b8240..3b161d708a 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -12,14 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config, ConfigError +import logging + +from six import string_types +from six.moves.urllib import parse as urlparse + +import yaml +from netaddr import IPSet from synapse.appservice import ApplicationService from synapse.types import UserID -import urllib -import yaml -import logging +from ._base import Config, ConfigError logger = logging.getLogger(__name__) @@ -89,21 +93,21 @@ def _load_appservice(hostname, as_info, config_filename): "id", "as_token", "hs_token", "sender_localpart" ] for field in required_string_fields: - if not isinstance(as_info.get(field), basestring): + if not isinstance(as_info.get(field), string_types): raise KeyError("Required string field: '%s' (%s)" % ( field, config_filename, )) # 'url' must either be a string or explicitly null, not missing # to avoid accidentally turning off push for ASes. - if (not isinstance(as_info.get("url"), basestring) and + if (not isinstance(as_info.get("url"), string_types) and as_info.get("url", "") is not None): raise KeyError( "Required string field or explicit null: 'url' (%s)" % (config_filename,) ) localpart = as_info["sender_localpart"] - if urllib.quote(localpart) != localpart: + if urlparse.quote(localpart) != localpart: raise ValueError( "sender_localpart needs characters which are not URL encoded." ) @@ -128,7 +132,7 @@ def _load_appservice(hostname, as_info, config_filename): "Expected namespace entry in %s to be an object," " but got %s", ns, regex_obj ) - if not isinstance(regex_obj.get("regex"), basestring): + if not isinstance(regex_obj.get("regex"), string_types): raise ValueError( "Missing/bad type 'regex' key in %s", regex_obj ) @@ -152,13 +156,22 @@ def _load_appservice(hostname, as_info, config_filename): " will not receive events or queries.", config_filename, ) + + ip_range_whitelist = None + if as_info.get('ip_range_whitelist'): + ip_range_whitelist = IPSet( + as_info.get('ip_range_whitelist') + ) + return ApplicationService( token=as_info["as_token"], + hostname=hostname, url=as_info["url"], namespaces=as_info["namespaces"], hs_token=as_info["hs_token"], sender=user_id, id=as_info["id"], protocols=protocols, - rate_limited=rate_limited + rate_limited=rate_limited, + ip_range_whitelist=ip_range_whitelist, ) diff --git a/synapse/config/cas.py b/synapse/config/cas.py index 938f6f25f8..8109e5f95e 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -41,7 +41,7 @@ class CasConfig(Config): #cas_config: # enabled: true # server_url: "https://cas-server.com" - # service_url: "https://homesever.domain.com:8448" + # service_url: "https://homeserver.domain.com:8448" # #required_attributes: # # name: value """ diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py new file mode 100644 index 0000000000..e22c731aad --- /dev/null +++ b/synapse/config/consent_config.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + +DEFAULT_CONFIG = """\ +# User Consent configuration +# +# for detailed instructions, see +# https://github.com/matrix-org/synapse/blob/master/docs/consent_tracking.md +# +# Parts of this section are required if enabling the 'consent' resource under +# 'listeners', in particular 'template_dir' and 'version'. +# +# 'template_dir' gives the location of the templates for the HTML forms. +# This directory should contain one subdirectory per language (eg, 'en', 'fr'), +# and each language directory should contain the policy document (named as +# '<version>.html') and a success page (success.html). +# +# 'version' specifies the 'current' version of the policy document. It defines +# the version to be served by the consent resource if there is no 'v' +# parameter. +# +# 'server_notice_content', if enabled, will send a user a "Server Notice" +# asking them to consent to the privacy policy. The 'server_notices' section +# must also be configured for this to work. Notices will *not* be sent to +# guest users unless 'send_server_notice_to_guests' is set to true. +# +# 'block_events_error', if set, will block any attempts to send events +# until the user consents to the privacy policy. The value of the setting is +# used as the text of the error. +# +# user_consent: +# template_dir: res/templates/privacy +# version: 1.0 +# server_notice_content: +# msgtype: m.text +# body: >- +# To continue using this homeserver you must review and agree to the +# terms and conditions at %(consent_uri)s +# send_server_notice_to_guests: True +# block_events_error: >- +# To continue using this homeserver you must review and agree to the +# terms and conditions at %(consent_uri)s +# +""" + + +class ConsentConfig(Config): + def __init__(self): + super(ConsentConfig, self).__init__() + + self.user_consent_version = None + self.user_consent_template_dir = None + self.user_consent_server_notice_content = None + self.user_consent_server_notice_to_guests = False + self.block_events_without_consent_error = None + + def read_config(self, config): + consent_config = config.get("user_consent") + if consent_config is None: + return + self.user_consent_version = str(consent_config["version"]) + self.user_consent_template_dir = consent_config["template_dir"] + self.user_consent_server_notice_content = consent_config.get( + "server_notice_content", + ) + self.block_events_without_consent_error = consent_config.get( + "block_events_error", + ) + self.user_consent_server_notice_to_guests = bool(consent_config.get( + "send_server_notice_to_guests", False, + )) + + def default_config(self, **kwargs): + return DEFAULT_CONFIG diff --git a/synapse/config/groups.py b/synapse/config/groups.py new file mode 100644 index 0000000000..997fa2881f --- /dev/null +++ b/synapse/config/groups.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + + +class GroupsConfig(Config): + def read_config(self, config): + self.enable_group_creation = config.get("enable_group_creation", False) + self.group_creation_prefix = config.get("group_creation_prefix", "") + + def default_config(self, **kwargs): + return """\ + # Whether to allow non server admins to create groups on this server + enable_group_creation: false + + # If enabled, non server admins can only create groups with local parts + # starting with this prefix + # group_creation_prefix: "unofficial/" + """ diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index b22cacf8dc..2fd9c48abf 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,28 +13,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from .tls import TlsConfig -from .server import ServerConfig -from .logger import LoggingConfig -from .database import DatabaseConfig -from .ratelimiting import RatelimitConfig -from .repository import ContentRepositoryConfig -from .captcha import CaptchaConfig -from .voip import VoipConfig -from .registration import RegistrationConfig -from .metrics import MetricsConfig from .api import ApiConfig from .appservice import AppServiceConfig -from .key import KeyConfig -from .saml2 import SAML2Config +from .captcha import CaptchaConfig from .cas import CasConfig -from .password import PasswordConfig +from .consent_config import ConsentConfig +from .database import DatabaseConfig +from .emailconfig import EmailConfig +from .groups import GroupsConfig from .jwt import JWTConfig +from .key import KeyConfig +from .logger import LoggingConfig +from .metrics import MetricsConfig +from .password import PasswordConfig from .password_auth_providers import PasswordAuthProviderConfig -from .emailconfig import EmailConfig -from .workers import WorkerConfig from .push import PushConfig +from .ratelimiting import RatelimitConfig +from .registration import RegistrationConfig +from .repository import ContentRepositoryConfig +from .saml2 import SAML2Config +from .server import ServerConfig +from .server_notices_config import ServerNoticesConfig +from .spam_checker import SpamCheckerConfig +from .tls import TlsConfig +from .user_directory import UserDirectoryConfig +from .voip import VoipConfig +from .workers import WorkerConfig class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, @@ -41,12 +46,16 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig, JWTConfig, PasswordConfig, EmailConfig, - WorkerConfig, PasswordAuthProviderConfig, PushConfig,): + WorkerConfig, PasswordAuthProviderConfig, PushConfig, + SpamCheckerConfig, GroupsConfig, UserDirectoryConfig, + ConsentConfig, + ServerNoticesConfig, + ): pass if __name__ == '__main__': import sys sys.stdout.write( - HomeServerConfig().generate_config(sys.argv[1], sys.argv[2])[0] + HomeServerConfig().generate_config(sys.argv[1], sys.argv[2], True)[0] ) diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py index 47f145c589..51e7f7e003 100644 --- a/synapse/config/jwt.py +++ b/synapse/config/jwt.py @@ -15,7 +15,6 @@ from ._base import Config, ConfigError - MISSING_JWT = ( """Missing jwt library. This is required for jwt login. diff --git a/synapse/config/key.py b/synapse/config/key.py index 6ee643793e..279c47bb48 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -13,21 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config, ConfigError +import hashlib +import logging +import os -from synapse.util.stringutils import random_string from signedjson.key import ( - generate_signing_key, is_signing_algorithm_supported, - decode_signing_key_base64, decode_verify_key_bytes, - read_signing_keys, write_signing_keys, NACL_ED25519 + NACL_ED25519, + decode_signing_key_base64, + decode_verify_key_bytes, + generate_signing_key, + is_signing_algorithm_supported, + read_signing_keys, + write_signing_keys, ) from unpaddedbase64 import decode_base64 -from synapse.util.stringutils import random_string_with_symbols -import os -import hashlib -import logging +from synapse.util.stringutils import random_string, random_string_with_symbols +from ._base import Config, ConfigError logger = logging.getLogger(__name__) @@ -59,14 +62,20 @@ class KeyConfig(Config): self.expire_access_token = config.get("expire_access_token", False) + # a secret which is used to calculate HMACs for form values, to stop + # falsification of values + self.form_secret = config.get("form_secret", None) + def default_config(self, config_dir_path, server_name, is_generating_file=False, **kwargs): base_key_name = os.path.join(config_dir_path, server_name) if is_generating_file: macaroon_secret_key = random_string_with_symbols(50) + form_secret = '"%s"' % random_string_with_symbols(50) else: macaroon_secret_key = None + form_secret = 'null' return """\ macaroon_secret_key: "%(macaroon_secret_key)s" @@ -74,6 +83,10 @@ class KeyConfig(Config): # Used to enable access token expiration. expire_access_token: False + # a secret which is used to calculate HMACs for form values, to stop + # falsification of values + form_secret: %(form_secret)s + ## Signing Keys ## # Path to the signing key to sign messages with @@ -118,10 +131,9 @@ class KeyConfig(Config): signing_keys = self.read_file(signing_key_path, "signing_key") try: return read_signing_keys(signing_keys.splitlines(True)) - except Exception: + except Exception as e: raise ConfigError( - "Error reading signing_key." - " Try running again with --generate-config" + "Error reading signing_key: %s" % (str(e)) ) def read_old_signing_keys(self, old_signing_keys): @@ -141,7 +153,8 @@ class KeyConfig(Config): def generate_files(self, config): signing_key_path = config["signing_key_path"] - if not os.path.exists(signing_key_path): + + if not self.path_exists(signing_key_path): with open(signing_key_path, "w") as signing_key_file: key_id = "a_" + random_string(4) write_signing_keys( diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 2dbeafa9dd..a87b11a1df 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -12,43 +12,48 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from ._base import Config -from synapse.util.logcontext import LoggingContextFilter -from twisted.logger import globalLogBeginner, STDLibLogObserver import logging import logging.config -import yaml -from string import Template import os import signal +import sys +from string import Template +import yaml + +from twisted.logger import STDLibLogObserver, globalLogBeginner + +import synapse +from synapse.util.logcontext import LoggingContextFilter +from synapse.util.versionstring import get_version_string + +from ._base import Config DEFAULT_LOG_CONFIG = Template(""" version: 1 formatters: - precise: - format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s\ -- %(message)s' + precise: + format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - \ +%(request)s - %(message)s' filters: - context: - (): synapse.util.logcontext.LoggingContextFilter - request: "" + context: + (): synapse.util.logcontext.LoggingContextFilter + request: "" handlers: - file: - class: logging.handlers.RotatingFileHandler - formatter: precise - filename: ${log_file} - maxBytes: 104857600 - backupCount: 10 - filters: [context] - console: - class: logging.StreamHandler - formatter: precise - filters: [context] + file: + class: logging.handlers.RotatingFileHandler + formatter: precise + filename: ${log_file} + maxBytes: 104857600 + backupCount: 10 + filters: [context] + console: + class: logging.StreamHandler + formatter: precise + filters: [context] loggers: synapse: @@ -74,17 +79,10 @@ class LoggingConfig(Config): self.log_file = self.abspath(config.get("log_file")) def default_config(self, config_dir_path, server_name, **kwargs): - log_file = self.abspath("homeserver.log") log_config = self.abspath( os.path.join(config_dir_path, server_name + ".log.config") ) return """ - # Logging verbosity level. Ignored if log_config is specified. - verbose: 0 - - # File to write logging to. Ignored if log_config is specified. - log_file: "%(log_file)s" - # A yaml python logging config file log_config: "%(log_config)s" """ % locals() @@ -123,9 +121,10 @@ class LoggingConfig(Config): def generate_files(self, config): log_config = config.get("log_config") if log_config and not os.path.exists(log_config): - with open(log_config, "wb") as log_config_file: + log_file = self.abspath("homeserver.log") + with open(log_config, "w") as log_config_file: log_config_file.write( - DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"]) + DEFAULT_LOG_CONFIG.substitute(log_file=log_file) ) @@ -148,8 +147,11 @@ def setup_logging(config, use_worker_options=False): "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s" " - %(message)s" ) - if log_config is None: + if log_config is None: + # We don't have a logfile, so fall back to the 'verbosity' param from + # the config or cmdline. (Note that we generate a log config for new + # installs, so this will be an unusual case) level = logging.INFO level_for_storage = logging.INFO if config.verbosity: @@ -157,11 +159,10 @@ def setup_logging(config, use_worker_options=False): if config.verbosity > 1: level_for_storage = logging.DEBUG - # FIXME: we need a logging.WARN for a -q quiet option logger = logging.getLogger('') logger.setLevel(level) - logging.getLogger('synapse.storage').setLevel(level_for_storage) + logging.getLogger('synapse.storage.SQL').setLevel(level_for_storage) formatter = logging.Formatter(log_format) if log_file: @@ -176,6 +177,10 @@ def setup_logging(config, use_worker_options=False): logger.info("Opened new log file due to SIGHUP") else: handler = logging.StreamHandler() + + def sighup(signum, stack): + pass + handler.setFormatter(formatter) handler.addFilter(LoggingContextFilter(request="")) @@ -202,6 +207,15 @@ def setup_logging(config, use_worker_options=False): if getattr(signal, "SIGHUP"): signal.signal(signal.SIGHUP, sighup) + # make sure that the first thing we log is a thing we can grep backwards + # for + logging.warn("***** STARTING SERVER *****") + logging.warn( + "Server %s version %s", + sys.argv[0], get_version_string(synapse), + ) + logging.info("Server hostname: %s", config.server_name) + # It's critical to point twisted's internal logging somewhere, otherwise it # stacks up and leaks kup to 64K object; # see: https://twistedmatrix.com/trac/ticket/8164 diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index 83762d089a..f4066abc28 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -13,44 +13,41 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config, ConfigError +from synapse.util.module_loader import load_module -import importlib +from ._base import Config + +LDAP_PROVIDER = 'ldap_auth_provider.LdapAuthProvider' class PasswordAuthProviderConfig(Config): def read_config(self, config): self.password_providers = [] + providers = [] # We want to be backwards compatible with the old `ldap_config` # param. ldap_config = config.get("ldap_config", {}) - self.ldap_enabled = ldap_config.get("enabled", False) - if self.ldap_enabled: - from ldap_auth_provider import LdapAuthProvider - parsed_config = LdapAuthProvider.parse_config(ldap_config) - self.password_providers.append((LdapAuthProvider, parsed_config)) + if ldap_config.get("enabled", False): + providers.append({ + 'module': LDAP_PROVIDER, + 'config': ldap_config, + }) - providers = config.get("password_providers", []) + providers.extend(config.get("password_providers", [])) for provider in providers: + mod_name = provider['module'] + # This is for backwards compat when the ldap auth provider resided # in this package. - if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider": - from ldap_auth_provider import LdapAuthProvider - provider_class = LdapAuthProvider - else: - # We need to import the module, and then pick the class out of - # that, so we split based on the last dot. - module, clz = provider['module'].rsplit(".", 1) - module = importlib.import_module(module) - provider_class = getattr(module, clz) + if mod_name == "synapse.util.ldap_auth_provider.LdapAuthProvider": + mod_name = LDAP_PROVIDER + + (provider_class, provider_config) = load_module({ + "module": mod_name, + "config": provider['config'], + }) - try: - provider_config = provider_class.parse_config(provider["config"]) - except Exception as e: - raise ConfigError( - "Failed to parse config for %r: %r" % (provider['module'], e) - ) self.password_providers.append((provider_class, provider_config)) def default_config(self, **kwargs): diff --git a/synapse/config/push.py b/synapse/config/push.py index 9c68318b40..b7e0d46afa 100644 --- a/synapse/config/push.py +++ b/synapse/config/push.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,28 +19,43 @@ from ._base import Config class PushConfig(Config): def read_config(self, config): - self.push_redact_content = False + push_config = config.get("push", {}) + self.push_include_content = push_config.get("include_content", True) + # There was a a 'redact_content' setting but mistakenly read from the + # 'email'section'. Check for the flag in the 'push' section, and log, + # but do not honour it to avoid nasty surprises when people upgrade. + if push_config.get("redact_content") is not None: + print( + "The push.redact_content content option has never worked. " + "Please set push.include_content if you want this behaviour" + ) + + # Now check for the one in the 'email' section and honour it, + # with a warning. push_config = config.get("email", {}) - self.push_redact_content = push_config.get("redact_content", False) + redact_content = push_config.get("redact_content") + if redact_content is not None: + print( + "The 'email.redact_content' option is deprecated: " + "please set push.include_content instead" + ) + self.push_include_content = not redact_content def default_config(self, config_dir_path, server_name, **kwargs): return """ - # Control how push messages are sent to google/apple to notifications. - # Normally every message said in a room with one or more people using - # mobile devices will be posted to a push server hosted by matrix.org - # which is registered with google and apple in order to allow push - # notifications to be sent to these mobile devices. - # - # Setting redact_content to true will make the push messages contain no - # message content which will provide increased privacy. This is a - # temporary solution pending improvements to Android and iPhone apps - # to get content from the app rather than the notification. - # + # Clients requesting push notifications can either have the body of + # the message sent in the notification poke along with other details + # like the sender, or just the event ID and room ID (`event_id_only`). + # If clients choose the former, this option controls whether the + # notification request includes the content of the event (other details + # like the sender are still included). For `event_id_only` push, it + # has no effect. + # For modern android devices the notification content will still appear # because it is loaded by the app. iPhone, however will send a # notification saying only that a message arrived and who it came from. # #push: - # redact_content: false + # include_content: true """ diff --git a/synapse/config/registration.py b/synapse/config/registration.py index f7e03c4cde..0fb964eb67 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config +from distutils.util import strtobool from synapse.util.stringutils import random_string_with_symbols -from distutils.util import strtobool +from ._base import Config class RegistrationConfig(Config): @@ -31,6 +31,8 @@ class RegistrationConfig(Config): strtobool(str(config["disable_registration"])) ) + self.registrations_require_3pid = config.get("registrations_require_3pid", []) + self.allowed_local_3pids = config.get("allowed_local_3pids", []) self.registration_shared_secret = config.get("registration_shared_secret") self.bcrypt_rounds = config.get("bcrypt_rounds", 12) @@ -41,6 +43,8 @@ class RegistrationConfig(Config): self.allow_guest_access and config.get("invite_3pid_guest", False) ) + self.auto_join_rooms = config.get("auto_join_rooms", []) + def default_config(self, **kwargs): registration_shared_secret = random_string_with_symbols(50) @@ -50,13 +54,32 @@ class RegistrationConfig(Config): # Enable registration for new users. enable_registration: False + # The user must provide all of the below types of 3PID when registering. + # + # registrations_require_3pid: + # - email + # - msisdn + + # Mandate that users are only allowed to associate certain formats of + # 3PIDs with accounts on this server. + # + # allowed_local_3pids: + # - medium: email + # pattern: ".*@matrix\\.org" + # - medium: email + # pattern: ".*@vector\\.im" + # - medium: msisdn + # pattern: "\\+44" + # If set, allows registration by anyone who also has the shared # secret, even if registration is otherwise disabled. registration_shared_secret: "%(registration_shared_secret)s" # Set the number of bcrypt rounds used to generate password hash. # Larger numbers increase the work factor needed to generate the hash. - # The default number of rounds is 12. + # The default number is 12 (which equates to 2^12 rounds). + # N.B. that increasing this will exponentially increase the time required + # to register or login - e.g. 24 => 2^24 rounds which will take >20 mins. bcrypt_rounds: 12 # Allows users to register as guests without a password/email/etc, and @@ -70,6 +93,11 @@ class RegistrationConfig(Config): - matrix.org - vector.im - riot.im + + # Users who register on this homeserver will automatically be joined + # to these rooms + #auto_join_rooms: + # - "#example:example.com" """ % locals() def add_arguments(self, parser): diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 2c6f57168e..fc909c1fac 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -13,9 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config, ConfigError from collections import namedtuple +from synapse.util.module_loader import load_module + +from ._base import Config, ConfigError MISSING_NETADDR = ( "Missing netaddr library. This is required for URL preview API." @@ -36,6 +38,14 @@ ThumbnailRequirement = namedtuple( "ThumbnailRequirement", ["width", "height", "method", "media_type"] ) +MediaStorageProviderConfig = namedtuple( + "MediaStorageProviderConfig", ( + "store_local", # Whether to store newly uploaded local files + "store_remote", # Whether to store newly downloaded remote files + "store_synchronous", # Whether to wait for successful storage for local uploads + ), +) + def parse_thumbnail_requirements(thumbnail_sizes): """ Takes a list of dictionaries with "width", "height", and "method" keys @@ -70,7 +80,64 @@ class ContentRepositoryConfig(Config): self.max_upload_size = self.parse_size(config["max_upload_size"]) self.max_image_pixels = self.parse_size(config["max_image_pixels"]) self.max_spider_size = self.parse_size(config["max_spider_size"]) + self.media_store_path = self.ensure_directory(config["media_store_path"]) + + backup_media_store_path = config.get("backup_media_store_path") + + synchronous_backup_media_store = config.get( + "synchronous_backup_media_store", False + ) + + storage_providers = config.get("media_storage_providers", []) + + if backup_media_store_path: + if storage_providers: + raise ConfigError( + "Cannot use both 'backup_media_store_path' and 'storage_providers'" + ) + + storage_providers = [{ + "module": "file_system", + "store_local": True, + "store_synchronous": synchronous_backup_media_store, + "store_remote": True, + "config": { + "directory": backup_media_store_path, + } + }] + + # This is a list of config that can be used to create the storage + # providers. The entries are tuples of (Class, class_config, + # MediaStorageProviderConfig), where Class is the class of the provider, + # the class_config the config to pass to it, and + # MediaStorageProviderConfig are options for StorageProviderWrapper. + # + # We don't create the storage providers here as not all workers need + # them to be started. + self.media_storage_providers = [] + + for provider_config in storage_providers: + # We special case the module "file_system" so as not to need to + # expose FileStorageProviderBackend + if provider_config["module"] == "file_system": + provider_config["module"] = ( + "synapse.rest.media.v1.storage_provider" + ".FileStorageProviderBackend" + ) + + provider_class, parsed_config = load_module(provider_config) + + wrapper_config = MediaStorageProviderConfig( + provider_config.get("store_local", False), + provider_config.get("store_remote", False), + provider_config.get("store_synchronous", False), + ) + + self.media_storage_providers.append( + (provider_class, parsed_config, wrapper_config,) + ) + self.uploads_path = self.ensure_directory(config["uploads_path"]) self.dynamic_thumbnails = config["dynamic_thumbnails"] self.thumbnail_requirements = parse_thumbnail_requirements( @@ -115,6 +182,20 @@ class ContentRepositoryConfig(Config): # Directory where uploaded images and attachments are stored. media_store_path: "%(media_store)s" + # Media storage providers allow media to be stored in different + # locations. + # media_storage_providers: + # - module: file_system + # # Whether to write new local files. + # store_local: false + # # Whether to write new remote media + # store_remote: false + # # Whether to block upload requests waiting for write to this + # # provider to complete + # store_synchronous: false + # config: + # directory: /mnt/some/other/directory + # Directory where in-progress uploads are stored. uploads_path: "%(uploads_path)s" @@ -169,6 +250,9 @@ class ContentRepositoryConfig(Config): # - '192.168.0.0/16' # - '100.64.0.0/10' # - '169.254.0.0/16' + # - '::1/128' + # - 'fe80::/64' + # - 'fc00::/7' # # List of IP address CIDR ranges that the URL preview spider is allowed # to access even if they are specified in url_preview_ip_range_blacklist. diff --git a/synapse/config/server.py b/synapse/config/server.py index 28b4e5f50c..18102656b0 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,13 +14,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + +from synapse.http.endpoint import parse_and_validate_server_name + from ._base import Config, ConfigError +logger = logging.Logger(__name__) + class ServerConfig(Config): def read_config(self, config): self.server_name = config["server_name"] + + try: + parse_and_validate_server_name(self.server_name) + except ValueError as e: + raise ConfigError(str(e)) + self.pid_file = self.abspath(config.get("pid_file")) self.web_client = config["web_client"] self.web_client_location = config.get("web_client_location", None) @@ -29,6 +42,7 @@ class ServerConfig(Config): self.user_agent_suffix = config.get("user_agent_suffix") self.use_frozen_dicts = config.get("use_frozen_dicts", False) self.public_baseurl = config.get("public_baseurl") + self.cpu_affinity = config.get("cpu_affinity") # Whether to send federation traffic out in this process. This only # applies to some federation traffic, and so shouldn't be used to @@ -39,8 +53,31 @@ class ServerConfig(Config): # false only if we are updating the user directory in a worker self.update_user_directory = config.get("update_user_directory", True) + # whether to enable the media repository endpoints. This should be set + # to false if the media repository is running as a separate endpoint; + # doing so ensures that we will not run cache cleanup jobs on the + # master, potentially causing inconsistency. + self.enable_media_repo = config.get("enable_media_repo", True) + self.filter_timeline_limit = config.get("filter_timeline_limit", -1) + # Whether we should block invites sent to users on this server + # (other than those sent by local server admins) + self.block_non_admin_invites = config.get( + "block_non_admin_invites", False, + ) + + # FIXME: federation_domain_whitelist needs sytests + self.federation_domain_whitelist = None + federation_domain_whitelist = config.get( + "federation_domain_whitelist", None + ) + # turn the whitelist into a hash for speed of lookup + if federation_domain_whitelist is not None: + self.federation_domain_whitelist = {} + for domain in federation_domain_whitelist: + self.federation_domain_whitelist[domain] = True + if self.public_baseurl is not None: if self.public_baseurl[-1] != '/': self.public_baseurl += '/' @@ -113,6 +150,12 @@ class ServerConfig(Config): metrics_port = config.get("metrics_port") if metrics_port: + logger.warn( + ("The metrics_port configuration option is deprecated in Synapse 0.31 " + "in favour of a listener. Please see " + "http://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst" + " on how to configure the new listener.")) + self.listeners.append({ "port": metrics_port, "bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")], @@ -127,8 +170,8 @@ class ServerConfig(Config): }) def default_config(self, server_name, **kwargs): - if ":" in server_name: - bind_port = int(server_name.split(":")[1]) + _, bind_port = parse_and_validate_server_name(server_name) + if bind_port is not None: unsecure_port = bind_port - 400 else: bind_port = 8448 @@ -147,6 +190,27 @@ class ServerConfig(Config): # When running as a daemon, the file to store the pid in pid_file: %(pid_file)s + # CPU affinity mask. Setting this restricts the CPUs on which the + # process will be scheduled. It is represented as a bitmask, with the + # lowest order bit corresponding to the first logical CPU and the + # highest order bit corresponding to the last logical CPU. Not all CPUs + # may exist on a given system but a mask may specify more CPUs than are + # present. + # + # For example: + # 0x00000001 is processor #0, + # 0x00000003 is processors #0 and #1, + # 0xFFFFFFFF is all processors (#0 through #31). + # + # Pinning a Python process to a single CPU is desirable, because Python + # is inherently single-threaded due to the GIL, and can suffer a + # 30-40%% slowdown due to cache blow-out and thread context switching + # if the scheduler happens to schedule the underlying threads across + # different cores. See + # https://www.mirantis.com/blog/improve-performance-python-programs-restricting-single-cpu/. + # + # cpu_affinity: 0xFFFFFFFF + # Whether to serve a web client from the HTTP/HTTPS root resource. web_client: True @@ -171,6 +235,21 @@ class ServerConfig(Config): # and sync operations. The default value is -1, means no upper limit. # filter_timeline_limit: 5000 + # Whether room invites to users on this server should be blocked + # (except those sent by local server admins). The default is False. + # block_non_admin_invites: True + + # Restrict federation to the following whitelist of domains. + # N.B. we recommend also firewalling your federation listener to limit + # inbound federation traffic as early as possible, rather than relying + # purely on this application-layer restriction. If not specified, the + # default is to whitelist everything. + # + # federation_domain_whitelist: + # - lon.example.com + # - nyc.example.com + # - syd.example.com + # List of ports that Synapse should listen on, their purpose and their # configuration. listeners: @@ -181,13 +260,12 @@ class ServerConfig(Config): port: %(bind_port)s # Local addresses to listen on. - # This will listen on all IPv4 addresses by default. + # On Linux and Mac OS, `::` will listen on all IPv4 and IPv6 + # addresses by default. For most other OSes, this will only listen + # on IPv6. bind_addresses: + - '::' - '0.0.0.0' - # Uncomment to listen on all IPv6 interfaces - # N.B: On at least Linux this will also listen on all IPv4 - # addresses, so you will need to comment out the line above. - # - '::' # This is a 'http' listener, allows us to specify 'resources'. type: http @@ -214,11 +292,18 @@ class ServerConfig(Config): - names: [federation] # Federation APIs compress: false + # optional list of additional endpoints which can be loaded via + # dynamic modules + # additional_resources: + # "/_matrix/my/custom/endpoint": + # module: my_module.CustomRequestHandler + # config: {} + # Unsecure HTTP listener, # For when matrix traffic passes through loadbalancer that unwraps TLS. - port: %(unsecure_port)s tls: false - bind_addresses: ['0.0.0.0'] + bind_addresses: ['::', '0.0.0.0'] type: http x_forwarded: false @@ -232,7 +317,7 @@ class ServerConfig(Config): # Turn on the twisted ssh manhole service on localhost on the given # port. # - port: 9000 - # bind_address: 127.0.0.1 + # bind_addresses: ['::1', '127.0.0.1'] # type: manhole """ % locals() @@ -270,7 +355,7 @@ def read_gc_thresholds(thresholds): return ( int(thresholds[0]), int(thresholds[1]), int(thresholds[2]), ) - except: + except Exception: raise ConfigError( "Value of `gc_threshold` must be a list of three integers if set" ) diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices_config.py new file mode 100644 index 0000000000..3c39850ac6 --- /dev/null +++ b/synapse/config/server_notices_config.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.types import UserID + +from ._base import Config + +DEFAULT_CONFIG = """\ +# Server Notices room configuration +# +# Uncomment this section to enable a room which can be used to send notices +# from the server to users. It is a special room which cannot be left; notices +# come from a special "notices" user id. +# +# If you uncomment this section, you *must* define the system_mxid_localpart +# setting, which defines the id of the user which will be used to send the +# notices. +# +# It's also possible to override the room name, the display name of the +# "notices" user, and the avatar for the user. +# +# server_notices: +# system_mxid_localpart: notices +# system_mxid_display_name: "Server Notices" +# system_mxid_avatar_url: "mxc://server.com/oumMVlgDnLYFaPVkExemNVVZ" +# room_name: "Server Notices" +""" + + +class ServerNoticesConfig(Config): + """Configuration for the server notices room. + + Attributes: + server_notices_mxid (str|None): + The MXID to use for server notices. + None if server notices are not enabled. + + server_notices_mxid_display_name (str|None): + The display name to use for the server notices user. + None if server notices are not enabled. + + server_notices_mxid_avatar_url (str|None): + The display name to use for the server notices user. + None if server notices are not enabled. + + server_notices_room_name (str|None): + The name to use for the server notices room. + None if server notices are not enabled. + """ + def __init__(self): + super(ServerNoticesConfig, self).__init__() + self.server_notices_mxid = None + self.server_notices_mxid_display_name = None + self.server_notices_mxid_avatar_url = None + self.server_notices_room_name = None + + def read_config(self, config): + c = config.get("server_notices") + if c is None: + return + + mxid_localpart = c['system_mxid_localpart'] + self.server_notices_mxid = UserID( + mxid_localpart, self.server_name, + ).to_string() + self.server_notices_mxid_display_name = c.get( + 'system_mxid_display_name', None, + ) + self.server_notices_mxid_avatar_url = c.get( + 'system_mxid_avatar_url', None, + ) + # todo: i18n + self.server_notices_room_name = c.get('room_name', "Server Notices") + + def default_config(self, **kwargs): + return DEFAULT_CONFIG diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py new file mode 100644 index 0000000000..3fec42bdb0 --- /dev/null +++ b/synapse/config/spam_checker.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.util.module_loader import load_module + +from ._base import Config + + +class SpamCheckerConfig(Config): + def read_config(self, config): + self.spam_checker = None + + provider = config.get("spam_checker", None) + if provider is not None: + self.spam_checker = load_module(provider) + + def default_config(self, **kwargs): + return """\ + # spam_checker: + # module: "my_custom_project.SuperSpamChecker" + # config: + # example_option: 'things' + """ diff --git a/synapse/config/tls.py b/synapse/config/tls.py index e081840a83..fef1ea99cb 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -13,15 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config - -from OpenSSL import crypto -import subprocess import os - +import subprocess from hashlib import sha256 + from unpaddedbase64 import encode_base64 +from OpenSSL import crypto + +from ._base import Config + GENERATE_DH_PARAMS = False @@ -96,7 +97,7 @@ class TlsConfig(Config): # certificates returned by this server match one of the fingerprints. # # Synapse automatically adds the fingerprint of its own certificate - # to the list. So if federation traffic is handle directly by synapse + # to the list. So if federation traffic is handled directly by synapse # then no modification to the list is required. # # If synapse is run behind a load balancer that handles the TLS then it @@ -109,6 +110,12 @@ class TlsConfig(Config): # key. It may be necessary to publish the fingerprints of a new # certificate and wait until the "valid_until_ts" of the previous key # responses have passed before deploying it. + # + # You can calculate a fingerprint from a given TLS listener via: + # openssl s_client -connect $host:$port < /dev/null 2> /dev/null | + # openssl x509 -outform DER | openssl sha256 -binary | base64 | tr -d '=' + # or by checking matrix.org/federationtester/api/report?server_name=$host + # tls_fingerprints: [] # tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}] """ % locals() @@ -126,8 +133,8 @@ class TlsConfig(Config): tls_private_key_path = config["tls_private_key_path"] tls_dh_params_path = config["tls_dh_params_path"] - if not os.path.exists(tls_private_key_path): - with open(tls_private_key_path, "w") as private_key_file: + if not self.path_exists(tls_private_key_path): + with open(tls_private_key_path, "wb") as private_key_file: tls_private_key = crypto.PKey() tls_private_key.generate_key(crypto.TYPE_RSA, 2048) private_key_pem = crypto.dump_privatekey( @@ -141,8 +148,8 @@ class TlsConfig(Config): crypto.FILETYPE_PEM, private_key_pem ) - if not os.path.exists(tls_certificate_path): - with open(tls_certificate_path, "w") as certificate_file: + if not self.path_exists(tls_certificate_path): + with open(tls_certificate_path, "wb") as certificate_file: cert = crypto.X509() subject = cert.get_subject() subject.CN = config["server_name"] @@ -159,7 +166,7 @@ class TlsConfig(Config): certificate_file.write(cert_pem) - if not os.path.exists(tls_dh_params_path): + if not self.path_exists(tls_dh_params_path): if GENERATE_DH_PARAMS: subprocess.check_call([ "openssl", "dhparam", diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py new file mode 100644 index 0000000000..38e8947843 --- /dev/null +++ b/synapse/config/user_directory.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + + +class UserDirectoryConfig(Config): + """User Directory Configuration + Configuration for the behaviour of the /user_directory API + """ + + def read_config(self, config): + self.user_directory_search_all_users = False + user_directory_config = config.get("user_directory", None) + if user_directory_config: + self.user_directory_search_all_users = ( + user_directory_config.get("search_all_users", False) + ) + + def default_config(self, config_dir_path, server_name, **kwargs): + return """ + # User Directory configuration + # + # 'search_all_users' defines whether to search all users visible to your HS + # when searching the user directory, rather than limiting to users visible + # in public rooms. Defaults to false. If you set it True, you'll have to run + # UPDATE user_directory_stream_pos SET stream_id = NULL; + # on your database to tell it to rebuild the user_directory search indexes. + # + #user_directory: + # search_all_users: false + """ diff --git a/synapse/config/voip.py b/synapse/config/voip.py index 3a4e16fa96..d07bd24ffd 100644 --- a/synapse/config/voip.py +++ b/synapse/config/voip.py @@ -30,10 +30,10 @@ class VoipConfig(Config): ## Turn ## # The public URIs of the TURN server to give to clients - turn_uris: [] + #turn_uris: [] # The shared secret used to compute passwords for the TURN server - turn_shared_secret: "YOUR_SHARED_SECRET" + #turn_shared_secret: "YOUR_SHARED_SECRET" # The Username and password if the TURN server needs them and # does not use a token diff --git a/synapse/config/workers.py b/synapse/config/workers.py index ea48d931a1..80baf0ce0e 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -23,15 +23,31 @@ class WorkerConfig(Config): def read_config(self, config): self.worker_app = config.get("worker_app") + + # Canonicalise worker_app so that master always has None + if self.worker_app == "synapse.app.homeserver": + self.worker_app = None + self.worker_listeners = config.get("worker_listeners") self.worker_daemonize = config.get("worker_daemonize") self.worker_pid_file = config.get("worker_pid_file") self.worker_log_file = config.get("worker_log_file") self.worker_log_config = config.get("worker_log_config") + + # The host used to connect to the main synapse self.worker_replication_host = config.get("worker_replication_host", None) + + # The port on the main synapse for TCP replication self.worker_replication_port = config.get("worker_replication_port", None) + + # The port on the main synapse for HTTP replication endpoint + self.worker_replication_http_port = config.get("worker_replication_http_port") + self.worker_name = config.get("worker_name", self.worker_app) + self.worker_main_http_uri = config.get("worker_main_http_uri", None) + self.worker_cpu_affinity = config.get("worker_cpu_affinity") + if self.worker_listeners: for listener in self.worker_listeners: bind_address = listener.pop("bind_address", None) diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index aad4752fe7..a1e1d0d33a 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import ssl -from OpenSSL import SSL -from twisted.internet._sslverify import _OpenSSLECCurve, _defaultCurveName - import logging +from OpenSSL import SSL, crypto +from twisted.internet import ssl +from twisted.internet._sslverify import _defaultCurveName + logger = logging.getLogger(__name__) @@ -32,9 +32,10 @@ class ServerContextFactory(ssl.ContextFactory): @staticmethod def configure_context(context, config): try: - _ecCurve = _OpenSSLECCurve(_defaultCurveName) - _ecCurve.addECKeyToContext(context) - except: + _ecCurve = crypto.get_elliptic_curve(_defaultCurveName) + context.set_tmp_ecdh(_ecCurve) + + except Exception: logger.exception("Failed to enable elliptic curve for TLS") context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3) context.use_certificate_chain_file(config.tls_certificate_file) diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index ec7711ba7d..8774b28967 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -15,15 +15,15 @@ # limitations under the License. -from synapse.api.errors import SynapseError, Codes -from synapse.events.utils import prune_event +import hashlib +import logging from canonicaljson import encode_canonical_json -from unpaddedbase64 import encode_base64, decode_base64 from signedjson.sign import sign_json +from unpaddedbase64 import decode_base64, encode_base64 -import hashlib -import logging +from synapse.api.errors import Codes, SynapseError +from synapse.events.utils import prune_event logger = logging.getLogger(__name__) @@ -32,18 +32,25 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256): """Check whether the hash for this PDU matches the contents""" name, expected_hash = compute_content_hash(event, hash_algorithm) logger.debug("Expecting hash: %s", encode_base64(expected_hash)) - if name not in event.hashes: + + # some malformed events lack a 'hashes'. Protect against it being missing + # or a weird type by basically treating it the same as an unhashed event. + hashes = event.get("hashes") + if not isinstance(hashes, dict): + raise SynapseError(400, "Malformed 'hashes'", Codes.UNAUTHORIZED) + + if name not in hashes: raise SynapseError( 400, "Algorithm %s not in hashes %s" % ( - name, list(event.hashes), + name, list(hashes), ), Codes.UNAUTHORIZED, ) - message_hash_base64 = event.hashes[name] + message_hash_base64 = hashes[name] try: message_hash_bytes = decode_base64(message_hash_base64) - except: + except Exception: raise SynapseError( 400, "Invalid base64: %s" % (message_hash_base64,), diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py index c2bd64d6c2..668b4f517d 100644 --- a/synapse/crypto/keyclient.py +++ b/synapse/crypto/keyclient.py @@ -13,17 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + +from canonicaljson import json -from twisted.web.http import HTTPClient -from twisted.internet.protocol import Factory from twisted.internet import defer, reactor -from synapse.http.endpoint import matrix_federation_endpoint -from synapse.util.logcontext import ( - preserve_context_over_fn, preserve_context_over_deferred -) -import simplejson as json -import logging +from twisted.internet.protocol import Factory +from twisted.web.http import HTTPClient +from synapse.http.endpoint import matrix_federation_endpoint +from synapse.util import logcontext logger = logging.getLogger(__name__) @@ -43,14 +42,10 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1): for i in range(5): try: - protocol = yield preserve_context_over_fn( - endpoint.connect, factory - ) - server_response, server_certificate = yield preserve_context_over_deferred( - protocol.remote_key - ) - defer.returnValue((server_response, server_certificate)) - return + with logcontext.PreserveLoggingContext(): + protocol = yield endpoint.connect(factory) + server_response, server_certificate = yield protocol.remote_key + defer.returnValue((server_response, server_certificate)) except SynapseKeyClientError as e: logger.exception("Error getting key for %r" % (server_name,)) if e.status.startswith("4"): diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 1bb27edc0f..e95b9fb43e 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 New Vector Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,33 +14,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.crypto.keyclient import fetch_server_key -from synapse.api.errors import SynapseError, Codes -from synapse.util import unwrapFirstError -from synapse.util.async import ObservableDeferred -from synapse.util.logcontext import ( - preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext, - preserve_fn -) -from synapse.util.metrics import Measure - -from twisted.internet import defer +import hashlib +import logging +import urllib +from collections import namedtuple -from signedjson.sign import ( - verify_signed_json, signature_ids, sign_json, encode_canonical_json -) from signedjson.key import ( - is_signing_algorithm_supported, decode_verify_key_bytes + decode_verify_key_bytes, + encode_verify_key_base64, + is_signing_algorithm_supported, +) +from signedjson.sign import ( + SignatureVerifyException, + encode_canonical_json, + sign_json, + signature_ids, + verify_signed_json, ) from unpaddedbase64 import decode_base64, encode_base64 from OpenSSL import crypto +from twisted.internet import defer -from collections import namedtuple -import urllib -import hashlib -import logging - +from synapse.api.errors import Codes, SynapseError +from synapse.crypto.keyclient import fetch_server_key +from synapse.util import logcontext, unwrapFirstError +from synapse.util.logcontext import ( + PreserveLoggingContext, + preserve_fn, + run_in_background, +) +from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -55,9 +60,10 @@ Attributes: key_ids(set(str)): The set of key_ids to that could be used to verify the JSON object json_object(dict): The JSON object to verify. - deferred(twisted.internet.defer.Deferred): + deferred(Deferred[str, str, nacl.signing.VerifyKey]): A deferred (server_name, key_id, verify_key) tuple that resolves when - a verify key has been fetched + a verify key has been fetched. The deferreds' callbacks are run with no + logcontext. """ @@ -74,23 +80,32 @@ class Keyring(object): self.perspective_servers = self.config.perspectives self.hs = hs + # map from server name to Deferred. Has an entry for each server with + # an ongoing key download; the Deferred completes once the download + # completes. + # + # These are regular, logcontext-agnostic Deferreds. self.key_downloads = {} def verify_json_for_server(self, server_name, json_object): - return self.verify_json_objects_for_server( - [(server_name, json_object)] - )[0] + return logcontext.make_deferred_yieldable( + self.verify_json_objects_for_server( + [(server_name, json_object)] + )[0] + ) def verify_json_objects_for_server(self, server_and_json): - """Bulk verfies signatures of json objects, bulk fetching keys as + """Bulk verifies signatures of json objects, bulk fetching keys as necessary. Args: server_and_json (list): List of pairs of (server_name, json_object) Returns: - list of deferreds indicating success or failure to verify each - json object's signature for the given server_name. + List<Deferred>: for each input pair, a deferred indicating success + or failure to verify each json object's signature for the given + server_name. The deferreds run their callbacks in the sentinel + logcontext. """ verify_requests = [] @@ -117,73 +132,60 @@ class Keyring(object): verify_requests.append(verify_request) - @defer.inlineCallbacks - def handle_key_deferred(verify_request): - server_name = verify_request.server_name - try: - _, key_id, verify_key = yield verify_request.deferred - except IOError as e: - logger.warn( - "Got IOError when downloading keys for %s: %s %s", - server_name, type(e).__name__, str(e.message), - ) - raise SynapseError( - 502, - "Error downloading keys for %s" % (server_name,), - Codes.UNAUTHORIZED, - ) - except Exception as e: - logger.exception( - "Got Exception when downloading keys for %s: %s %s", - server_name, type(e).__name__, str(e.message), - ) - raise SynapseError( - 401, - "No key for %s with id %s" % (server_name, key_ids), - Codes.UNAUTHORIZED, - ) + run_in_background(self._start_key_lookups, verify_requests) - json_object = verify_request.json_object + # Pass those keys to handle_key_deferred so that the json object + # signatures can be verified + handle = preserve_fn(_handle_key_deferred) + return [ + handle(rq) for rq in verify_requests + ] - logger.debug("Got key %s %s:%s for server %s, verifying" % ( - key_id, verify_key.alg, verify_key.version, server_name, - )) - try: - verify_signed_json(json_object, server_name, verify_key) - except: - raise SynapseError( - 401, - "Invalid signature for server %s with key %s:%s" % ( - server_name, verify_key.alg, verify_key.version - ), - Codes.UNAUTHORIZED, - ) + @defer.inlineCallbacks + def _start_key_lookups(self, verify_requests): + """Sets off the key fetches for each verify request - server_to_deferred = { - server_name: defer.Deferred() - for server_name, _ in server_and_json - } + Once each fetch completes, verify_request.deferred will be resolved. - with PreserveLoggingContext(): + Args: + verify_requests (List[VerifyKeyRequest]): + """ + + try: + # create a deferred for each server we're going to look up the keys + # for; we'll resolve them once we have completed our lookups. + # These will be passed into wait_for_previous_lookups to block + # any other lookups until we have finished. + # The deferreds are called with no logcontext. + server_to_deferred = { + rq.server_name: defer.Deferred() + for rq in verify_requests + } # We want to wait for any previous lookups to complete before # proceeding. - wait_on_deferred = self.wait_for_previous_lookups( - [server_name for server_name, _ in server_and_json], + yield self.wait_for_previous_lookups( + [rq.server_name for rq in verify_requests], server_to_deferred, ) # Actually start fetching keys. - wait_on_deferred.addBoth( - lambda _: self.get_server_verify_keys(verify_requests) - ) + self._get_server_verify_keys(verify_requests) # When we've finished fetching all the keys for a given server_name, # resolve the deferred passed to `wait_for_previous_lookups` so that # any lookups waiting will proceed. + # + # map from server name to a set of request ids server_to_request_ids = {} - def remove_deferreds(res, server_name, verify_request): + for verify_request in verify_requests: + server_name = verify_request.server_name + request_id = id(verify_request) + server_to_request_ids.setdefault(server_name, set()).add(request_id) + + def remove_deferreds(res, verify_request): + server_name = verify_request.server_name request_id = id(verify_request) server_to_request_ids[server_name].discard(request_id) if not server_to_request_ids[server_name]: @@ -193,17 +195,11 @@ class Keyring(object): return res for verify_request in verify_requests: - server_name = verify_request.server_name - request_id = id(verify_request) - server_to_request_ids.setdefault(server_name, set()).add(request_id) - deferred.addBoth(remove_deferreds, server_name, verify_request) - - # Pass those keys to handle_key_deferred so that the json object - # signatures can be verified - return [ - preserve_context_over_fn(handle_key_deferred, verify_request) - for verify_request in verify_requests - ] + verify_request.deferred.addBoth( + remove_deferreds, verify_request, + ) + except Exception: + logger.exception("Error starting key lookups") @defer.inlineCallbacks def wait_for_previous_lookups(self, server_names, server_to_deferred): @@ -212,7 +208,13 @@ class Keyring(object): Args: server_names (list): list of server_names we want to lookup server_to_deferred (dict): server_name to deferred which gets - resolved once we've finished looking up keys for that server + resolved once we've finished looking up keys for that server. + The Deferreds should be regular twisted ones which call their + callbacks with no logcontext. + + Returns: a Deferred which resolves once all key lookups for the given + servers have completed. Follows the synapse rules of logcontext + preservation. """ while True: wait_on = [ @@ -226,17 +228,15 @@ class Keyring(object): else: break - for server_name, deferred in server_to_deferred.items(): - d = ObservableDeferred(preserve_context_over_deferred(deferred)) - self.key_downloads[server_name] = d - - def rm(r, server_name): - self.key_downloads.pop(server_name, None) - return r + def rm(r, server_name_): + self.key_downloads.pop(server_name_, None) + return r - d.addBoth(rm, server_name) + for server_name, deferred in server_to_deferred.items(): + self.key_downloads[server_name] = deferred + deferred.addBoth(rm, server_name) - def get_server_verify_keys(self, verify_requests): + def _get_server_verify_keys(self, verify_requests): """Tries to find at least one key for each verify request For each verify_request, verify_request.deferred is called back with @@ -305,21 +305,23 @@ class Keyring(object): if not missing_keys: break - for verify_request in requests_missing_keys.values(): - verify_request.deferred.errback(SynapseError( - 401, - "No key for %s with id %s" % ( - verify_request.server_name, verify_request.key_ids, - ), - Codes.UNAUTHORIZED, - )) + with PreserveLoggingContext(): + for verify_request in requests_missing_keys: + verify_request.deferred.errback(SynapseError( + 401, + "No key for %s with id %s" % ( + verify_request.server_name, verify_request.key_ids, + ), + Codes.UNAUTHORIZED, + )) def on_err(err): - for verify_request in verify_requests: - if not verify_request.deferred.called: - verify_request.deferred.errback(err) + with PreserveLoggingContext(): + for verify_request in verify_requests: + if not verify_request.deferred.called: + verify_request.deferred.errback(err) - do_iterations().addErrback(on_err) + run_in_background(do_iterations).addErrback(on_err) @defer.inlineCallbacks def get_keys_from_store(self, server_name_and_key_ids): @@ -333,15 +335,16 @@ class Keyring(object): Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from server_name -> key_id -> VerifyKey """ - res = yield preserve_context_over_deferred(defer.gatherResults( + res = yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(self.store.get_server_verify_keys)( - server_name, key_ids + run_in_background( + self.store.get_server_verify_keys, + server_name, key_ids, ).addCallback(lambda ks, server: (server, ks), server_name) for server_name, key_ids in server_name_and_key_ids ], consumeErrors=True, - )).addErrback(unwrapFirstError) + ).addErrback(unwrapFirstError)) defer.returnValue(dict(res)) @@ -358,17 +361,17 @@ class Keyring(object): logger.exception( "Unable to get key from %r: %s %s", perspective_name, - type(e).__name__, str(e.message), + type(e).__name__, str(e), ) defer.returnValue({}) - results = yield preserve_context_over_deferred(defer.gatherResults( + results = yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(get_key)(p_name, p_keys) + run_in_background(get_key, p_name, p_keys) for p_name, p_keys in self.perspective_servers.items() ], consumeErrors=True, - )).addErrback(unwrapFirstError) + ).addErrback(unwrapFirstError)) union_of_keys = {} for result in results: @@ -390,7 +393,7 @@ class Keyring(object): logger.info( "Unable to get key %r for %r directly: %s %s", key_ids, server_name, - type(e).__name__, str(e.message), + type(e).__name__, str(e), ) if not keys: @@ -402,13 +405,13 @@ class Keyring(object): defer.returnValue(keys) - results = yield preserve_context_over_deferred(defer.gatherResults( + results = yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(get_key)(server_name, key_ids) + run_in_background(get_key, server_name, key_ids) for server_name, key_ids in server_name_and_key_ids ], consumeErrors=True, - )).addErrback(unwrapFirstError) + ).addErrback(unwrapFirstError)) merged = {} for result in results: @@ -485,9 +488,10 @@ class Keyring(object): for server_name, response_keys in processed_response.items(): keys.setdefault(server_name, {}).update(response_keys) - yield preserve_context_over_deferred(defer.gatherResults( + yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(self.store_keys)( + run_in_background( + self.store_keys, server_name=server_name, from_server=perspective_name, verify_keys=response_keys, @@ -495,7 +499,7 @@ class Keyring(object): for server_name, response_keys in keys.items() ], consumeErrors=True - )).addErrback(unwrapFirstError) + ).addErrback(unwrapFirstError)) defer.returnValue(keys) @@ -543,9 +547,10 @@ class Keyring(object): keys.update(response_keys) - yield preserve_context_over_deferred(defer.gatherResults( + yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(self.store_keys)( + run_in_background( + self.store_keys, server_name=key_server_name, from_server=server_name, verify_keys=verify_keys, @@ -553,7 +558,7 @@ class Keyring(object): for key_server_name, verify_keys in keys.items() ], consumeErrors=True - )).addErrback(unwrapFirstError) + ).addErrback(unwrapFirstError)) defer.returnValue(keys) @@ -619,9 +624,10 @@ class Keyring(object): response_keys.update(verify_keys) response_keys.update(old_verify_keys) - yield preserve_context_over_deferred(defer.gatherResults( + yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(self.store.store_server_keys_json)( + run_in_background( + self.store.store_server_keys_json, server_name=server_name, key_id=key_id, from_server=server_name, @@ -632,7 +638,7 @@ class Keyring(object): for key_id in updated_key_ids ], consumeErrors=True, - )).addErrback(unwrapFirstError) + ).addErrback(unwrapFirstError)) results[server_name] = response_keys @@ -710,7 +716,6 @@ class Keyring(object): defer.returnValue(verify_keys) - @defer.inlineCallbacks def store_keys(self, server_name, from_server, verify_keys): """Store a collection of verify keys for a given server Args: @@ -721,12 +726,74 @@ class Keyring(object): A deferred that completes when the keys are stored. """ # TODO(markjh): Store whether the keys have expired. - yield preserve_context_over_deferred(defer.gatherResults( + return logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(self.store.store_server_verify_key)( + run_in_background( + self.store.store_server_verify_key, server_name, server_name, key.time_added, key ) for key_id, key in verify_keys.items() ], consumeErrors=True, - )).addErrback(unwrapFirstError) + ).addErrback(unwrapFirstError)) + + +@defer.inlineCallbacks +def _handle_key_deferred(verify_request): + """Waits for the key to become available, and then performs a verification + + Args: + verify_request (VerifyKeyRequest): + + Returns: + Deferred[None] + + Raises: + SynapseError if there was a problem performing the verification + """ + server_name = verify_request.server_name + try: + with PreserveLoggingContext(): + _, key_id, verify_key = yield verify_request.deferred + except IOError as e: + logger.warn( + "Got IOError when downloading keys for %s: %s %s", + server_name, type(e).__name__, str(e), + ) + raise SynapseError( + 502, + "Error downloading keys for %s" % (server_name,), + Codes.UNAUTHORIZED, + ) + except Exception as e: + logger.exception( + "Got Exception when downloading keys for %s: %s %s", + server_name, type(e).__name__, str(e), + ) + raise SynapseError( + 401, + "No key for %s with id %s" % (server_name, verify_request.key_ids), + Codes.UNAUTHORIZED, + ) + + json_object = verify_request.json_object + + logger.debug("Got key %s %s:%s for server %s, verifying" % ( + key_id, verify_key.alg, verify_key.version, server_name, + )) + try: + verify_signed_json(json_object, server_name, verify_key) + except SignatureVerifyException as e: + logger.debug( + "Error verifying signature for %s:%s:%s with key %s: %s", + server_name, verify_key.alg, verify_key.version, + encode_verify_key_base64(verify_key), + str(e), + ) + raise SynapseError( + 401, + "Invalid signature for server %s with key %s:%s: %s" % ( + server_name, verify_key.alg, verify_key.version, str(e), + ), + Codes.UNAUTHORIZED, + ) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 4096c606f1..b32f64e729 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -17,11 +17,11 @@ import logging from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes -from signedjson.sign import verify_signed_json, SignatureVerifyException +from signedjson.sign import SignatureVerifyException, verify_signed_json from unpaddedbase64 import decode_base64 -from synapse.api.constants import EventTypes, Membership, JoinRules -from synapse.api.errors import AuthError, SynapseError, EventSizeError +from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.errors import AuthError, EventSizeError, SynapseError from synapse.types import UserID, get_domain_from_id logger = logging.getLogger(__name__) @@ -34,9 +34,11 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True): event: the event being checked. auth_events (dict: event-key -> event): the existing room state. + Raises: + AuthError if the checks fail Returns: - True if the auth checks pass. + if the auth checks pass. """ if do_size_check: _check_size_limits(event) @@ -71,9 +73,10 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True): # Oh, we don't know what the state of the room was, so we # are trusting that this is allowed (at least for now) logger.warn("Trusting event: %s", event.event_id) - return True + return if event.type == EventTypes.Create: + sender_domain = get_domain_from_id(event.sender) room_id_domain = get_domain_from_id(event.room_id) if room_id_domain != sender_domain: raise AuthError( @@ -81,7 +84,8 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True): "Creation event's room_id domain does not match sender's" ) # FIXME - return True + logger.debug("Allowing! %s", event) + return creation_event = auth_events.get((EventTypes.Create, ""), None) @@ -118,7 +122,8 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True): 403, "Alias event's state_key does not match sender's domain" ) - return True + logger.debug("Allowing! %s", event) + return if logger.isEnabledFor(logging.DEBUG): logger.debug( @@ -127,14 +132,9 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True): ) if event.type == EventTypes.Member: - allowed = _is_membership_change_allowed( - event, auth_events - ) - if allowed: - logger.debug("Allowing! %s", event) - else: - logger.debug("Denying! %s", event) - return allowed + _is_membership_change_allowed(event, auth_events) + logger.debug("Allowing! %s", event) + return _check_event_sender_in_room(event, auth_events) @@ -153,7 +153,8 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True): ) ) else: - return True + logger.debug("Allowing! %s", event) + return _can_send_event(event, auth_events) @@ -200,7 +201,7 @@ def _is_membership_change_allowed(event, auth_events): create = auth_events.get(key) if create and event.prev_events[0][0] == create.event_id: if create.content["creator"] == event.state_key: - return True + return target_user_id = event.state_key @@ -265,13 +266,13 @@ def _is_membership_change_allowed(event, auth_events): raise AuthError( 403, "%s is banned from the room" % (target_user_id,) ) - return True + return if Membership.JOIN != membership: if (caller_invited and Membership.LEAVE == membership and target_user_id == event.user_id): - return True + return if not caller_in_room: # caller isn't joined raise AuthError( @@ -319,7 +320,7 @@ def _is_membership_change_allowed(event, auth_events): # TODO (erikj): Implement kicks. if target_banned and user_level < ban_level: raise AuthError( - 403, "You cannot unban user &s." % (target_user_id,) + 403, "You cannot unban user %s." % (target_user_id,) ) elif target_user_id != event.user_id: kick_level = _get_named_level(auth_events, "kick", 50) @@ -334,8 +335,6 @@ def _is_membership_change_allowed(event, auth_events): else: raise AuthError(500, "Unknown membership %s" % membership) - return True - def _check_event_sender_in_room(event, auth_events): key = (EventTypes.Member, event.user_id, ) @@ -355,35 +354,46 @@ def _check_joined_room(member, user_id, room_id): )) -def get_send_level(etype, state_key, auth_events): - key = (EventTypes.PowerLevels, "", ) - send_level_event = auth_events.get(key) - send_level = None - if send_level_event: - send_level = send_level_event.content.get("events", {}).get( - etype - ) - if send_level is None: - if state_key is not None: - send_level = send_level_event.content.get( - "state_default", 50 - ) - else: - send_level = send_level_event.content.get( - "events_default", 0 - ) +def get_send_level(etype, state_key, power_levels_event): + """Get the power level required to send an event of a given type + + The federation spec [1] refers to this as "Required Power Level". + + https://matrix.org/docs/spec/server_server/unstable.html#definitions - if send_level: - send_level = int(send_level) + Args: + etype (str): type of event + state_key (str|None): state_key of state event, or None if it is not + a state event. + power_levels_event (synapse.events.EventBase|None): power levels event + in force at this point in the room + Returns: + int: power level required to send this event. + """ + + if power_levels_event: + power_levels_content = power_levels_event.content else: - send_level = 0 + power_levels_content = {} + + # see if we have a custom level for this event type + send_level = power_levels_content.get("events", {}).get(etype) + + # otherwise, fall back to the state_default/events_default. + if send_level is None: + if state_key is not None: + send_level = power_levels_content.get("state_default", 50) + else: + send_level = power_levels_content.get("events_default", 0) - return send_level + return int(send_level) def _can_send_event(event, auth_events): + power_levels_event = _get_power_level_event(auth_events) + send_level = get_send_level( - event.type, event.get("state_key", None), auth_events + event.type, event.get("state_key"), power_levels_event, ) user_level = get_user_power_level(event.user_id, auth_events) @@ -443,12 +453,12 @@ def _check_power_levels(event, auth_events): for k, v in user_list.items(): try: UserID.from_string(k) - except: + except Exception: raise SynapseError(400, "Not a valid user_id: %s" % (k,)) try: int(v) - except: + except Exception: raise SynapseError(400, "Not a valid power level: %s" % (v,)) key = (event.type, event.state_key, ) @@ -470,15 +480,15 @@ def _check_power_levels(event, auth_events): ("invite", None), ] - old_list = current_state.content.get("users") - for user in set(old_list.keys() + user_list.keys()): + old_list = current_state.content.get("users", {}) + for user in set(list(old_list) + list(user_list)): levels_to_check.append( (user, "users") ) - old_list = current_state.content.get("events") - new_list = event.content.get("events") - for ev_id in set(old_list.keys() + new_list.keys()): + old_list = current_state.content.get("events", {}) + new_list = event.content.get("events", {}) + for ev_id in set(list(old_list) + list(new_list)): levels_to_check.append( (ev_id, "events") ) @@ -515,7 +525,11 @@ def _check_power_levels(event, auth_events): "to your own" ) - if old_level > user_level or new_level > user_level: + # Check if the old and new levels are greater than the user level + # (if defined) + old_level_too_big = old_level is not None and old_level > user_level + new_level_too_big = new_level is not None and new_level > user_level + if old_level_too_big or new_level_too_big: raise AuthError( 403, "You don't have permission to add ops level greater " @@ -524,13 +538,22 @@ def _check_power_levels(event, auth_events): def _get_power_level_event(auth_events): - key = (EventTypes.PowerLevels, "", ) - return auth_events.get(key) + return auth_events.get((EventTypes.PowerLevels, "")) def get_user_power_level(user_id, auth_events): - power_level_event = _get_power_level_event(auth_events) + """Get a user's power level + + Args: + user_id (str): user's id to look up in power_levels + auth_events (dict[(str, str), synapse.events.EventBase]): + state in force at this point in the room (or rather, a subset of + it including at least the create event and power levels event. + Returns: + int: the user's power level in this room. + """ + power_level_event = _get_power_level_event(auth_events) if power_level_event: level = power_level_event.content.get("users", {}).get(user_id) if not level: @@ -541,6 +564,11 @@ def get_user_power_level(user_id, auth_events): else: return int(level) else: + # if there is no power levels event, the creator gets 100 and everyone + # else gets 0. + + # some things which call this don't pass the create event: hack around + # that. key = (EventTypes.Create, "", ) create_event = auth_events.get(key) if (create_event is not None and diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index e673e96cc0..51f9084b90 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -13,9 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.frozenutils import freeze from synapse.util.caches import intern_dict - +from synapse.util.frozenutils import freeze # Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents # bugs where we accidentally share e.g. signature dicts. However, converting @@ -47,14 +46,26 @@ class _EventInternalMetadata(object): def _event_dict_property(key): + # We want to be able to use hasattr with the event dict properties. + # However, (on python3) hasattr expects AttributeError to be raised. Hence, + # we need to transform the KeyError into an AttributeError def getter(self): - return self._event_dict[key] + try: + return self._event_dict[key] + except KeyError: + raise AttributeError(key) def setter(self, v): - self._event_dict[key] = v + try: + self._event_dict[key] = v + except KeyError: + raise AttributeError(key) def delete(self): - del self._event_dict[key] + try: + del self._event_dict[key] + except KeyError: + raise AttributeError(key) return property( getter, @@ -134,7 +145,7 @@ class EventBase(object): return field in self._event_dict def items(self): - return self._event_dict.items() + return list(self._event_dict.items()) class FrozenEvent(EventBase): diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 365fd96bd2..e662eaef10 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -13,13 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import EventBase, FrozenEvent, _event_dict_property +import copy from synapse.types import EventID - from synapse.util.stringutils import random_string -import copy +from . import EventBase, FrozenEvent, _event_dict_property class EventBuilder(EventBase): @@ -55,7 +54,7 @@ class EventBuilderFactory(object): local_part = str(int(self.clock.time())) + i + random_string(5) - e_id = EventID.create(local_part, self.hostname) + e_id = EventID(local_part, self.hostname) return e_id.to_string() diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index e9a732ff03..368b5f6ae4 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -13,19 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +from six import iteritems + +from frozendict import frozendict + +from twisted.internet import defer + +from synapse.util.logcontext import make_deferred_yieldable, run_in_background + class EventContext(object): """ Attributes: - current_state_ids (dict[(str, str), str]): - The current state map including the current event. - (type, state_key) -> event_id - - prev_state_ids (dict[(str, str), str]): - The current state map excluding the current event. - (type, state_key) -> event_id - - state_group (int): state group id + state_group (int|None): state group id, if the state has been stored + as a state group. This is usually only None if e.g. the event is + an outlier. rejected (bool|str): A rejection reason if the event was rejected, else False @@ -39,35 +41,250 @@ class EventContext(object): prev_state_events (?): XXX: is this ever set to anything other than the empty list? + + _current_state_ids (dict[(str, str), str]|None): + The current state map including the current event. None if outlier + or we haven't fetched the state from DB yet. + (type, state_key) -> event_id + + _prev_state_ids (dict[(str, str), str]|None): + The current state map excluding the current event. None if outlier + or we haven't fetched the state from DB yet. + (type, state_key) -> event_id + + _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have + been calculated. None if we haven't started calculating yet + + _event_type (str): The type of the event the context is associated with. + Only set when state has not been fetched yet. + + _event_state_key (str|None): The state_key of the event the context is + associated with. Only set when state has not been fetched yet. + + _prev_state_id (str|None): If the event associated with the context is + a state event, then `_prev_state_id` is the event_id of the state + that was replaced. + Only set when state has not been fetched yet. """ __slots__ = [ - "current_state_ids", - "prev_state_ids", "state_group", "rejected", - "push_actions", "prev_group", "delta_ids", "prev_state_events", "app_service", + "_current_state_ids", + "_prev_state_ids", + "_prev_state_id", + "_event_type", + "_event_state_key", + "_fetching_state_deferred", ] def __init__(self): + self.prev_state_events = [] + self.rejected = False + self.app_service = None + + @staticmethod + def with_state(state_group, current_state_ids, prev_state_ids, + prev_group=None, delta_ids=None): + context = EventContext() + # The current state including the current event - self.current_state_ids = None + context._current_state_ids = current_state_ids # The current state excluding the current event - self.prev_state_ids = None - self.state_group = None + context._prev_state_ids = prev_state_ids + context.state_group = state_group - self.rejected = False - self.push_actions = [] + context._prev_state_id = None + context._event_type = None + context._event_state_key = None + context._fetching_state_deferred = defer.succeed(None) # A previously persisted state group and a delta between that # and this state. - self.prev_group = None - self.delta_ids = None + context.prev_group = prev_group + context.delta_ids = delta_ids - self.prev_state_events = None + return context - self.app_service = None + @defer.inlineCallbacks + def serialize(self, event, store): + """Converts self to a type that can be serialized as JSON, and then + deserialized by `deserialize` + + Args: + event (FrozenEvent): The event that this context relates to + + Returns: + dict + """ + + # We don't serialize the full state dicts, instead they get pulled out + # of the DB on the other side. However, the other side can't figure out + # the prev_state_ids, so if we're a state event we include the event + # id that we replaced in the state. + if event.is_state(): + prev_state_ids = yield self.get_prev_state_ids(store) + prev_state_id = prev_state_ids.get((event.type, event.state_key)) + else: + prev_state_id = None + + defer.returnValue({ + "prev_state_id": prev_state_id, + "event_type": event.type, + "event_state_key": event.state_key if event.is_state() else None, + "state_group": self.state_group, + "rejected": self.rejected, + "prev_group": self.prev_group, + "delta_ids": _encode_state_dict(self.delta_ids), + "prev_state_events": self.prev_state_events, + "app_service_id": self.app_service.id if self.app_service else None + }) + + @staticmethod + def deserialize(store, input): + """Converts a dict that was produced by `serialize` back into a + EventContext. + + Args: + store (DataStore): Used to convert AS ID to AS object + input (dict): A dict produced by `serialize` + + Returns: + EventContext + """ + context = EventContext() + + # We use the state_group and prev_state_id stuff to pull the + # current_state_ids out of the DB and construct prev_state_ids. + context._prev_state_id = input["prev_state_id"] + context._event_type = input["event_type"] + context._event_state_key = input["event_state_key"] + + context._current_state_ids = None + context._prev_state_ids = None + context._fetching_state_deferred = None + + context.state_group = input["state_group"] + context.prev_group = input["prev_group"] + context.delta_ids = _decode_state_dict(input["delta_ids"]) + + context.rejected = input["rejected"] + context.prev_state_events = input["prev_state_events"] + + app_service_id = input["app_service_id"] + if app_service_id: + context.app_service = store.get_app_service_by_id(app_service_id) + + return context + + @defer.inlineCallbacks + def get_current_state_ids(self, store): + """Gets the current state IDs + + Returns: + Deferred[dict[(str, str), str]|None]: Returns None if state_group + is None, which happens when the associated event is an outlier. + """ + + if not self._fetching_state_deferred: + self._fetching_state_deferred = run_in_background( + self._fill_out_state, store, + ) + + yield make_deferred_yieldable(self._fetching_state_deferred) + + defer.returnValue(self._current_state_ids) + + @defer.inlineCallbacks + def get_prev_state_ids(self, store): + """Gets the prev state IDs + + Returns: + Deferred[dict[(str, str), str]|None]: Returns None if state_group + is None, which happens when the associated event is an outlier. + """ + + if not self._fetching_state_deferred: + self._fetching_state_deferred = run_in_background( + self._fill_out_state, store, + ) + + yield make_deferred_yieldable(self._fetching_state_deferred) + + defer.returnValue(self._prev_state_ids) + + def get_cached_current_state_ids(self): + """Gets the current state IDs if we have them already cached. + + Returns: + dict[(str, str), str]|None: Returns None if we haven't cached the + state or if state_group is None, which happens when the associated + event is an outlier. + """ + + return self._current_state_ids + + @defer.inlineCallbacks + def _fill_out_state(self, store): + """Called to populate the _current_state_ids and _prev_state_ids + attributes by loading from the database. + """ + if self.state_group is None: + return + + self._current_state_ids = yield store.get_state_ids_for_group( + self.state_group, + ) + if self._prev_state_id and self._event_state_key is not None: + self._prev_state_ids = dict(self._current_state_ids) + + key = (self._event_type, self._event_state_key) + self._prev_state_ids[key] = self._prev_state_id + else: + self._prev_state_ids = self._current_state_ids + + @defer.inlineCallbacks + def update_state(self, state_group, prev_state_ids, current_state_ids, + prev_group, delta_ids): + """Replace the state in the context + """ + + # We need to make sure we wait for any ongoing fetching of state + # to complete so that the updated state doesn't get clobbered + if self._fetching_state_deferred: + yield make_deferred_yieldable(self._fetching_state_deferred) + + self.state_group = state_group + self._prev_state_ids = prev_state_ids + self.prev_group = prev_group + self._current_state_ids = current_state_ids + self.delta_ids = delta_ids + + # We need to ensure that that we've marked as having fetched the state + self._fetching_state_deferred = defer.succeed(None) + + +def _encode_state_dict(state_dict): + """Since dicts of (type, state_key) -> event_id cannot be serialized in + JSON we need to convert them to a form that can. + """ + if state_dict is None: + return None + + return [ + (etype, state_key, v) + for (etype, state_key), v in iteritems(state_dict) + ] + + +def _decode_state_dict(input): + """Decodes a state dict encoded using `_encode_state_dict` above + """ + if input is None: + return None + + return frozendict({(etype, state_key,): v for etype, state_key, v in input}) diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py new file mode 100644 index 0000000000..633e068eb8 --- /dev/null +++ b/synapse/events/spamcheck.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class SpamChecker(object): + def __init__(self, hs): + self.spam_checker = None + + module = None + config = None + try: + module, config = hs.config.spam_checker + except Exception: + pass + + if module is not None: + self.spam_checker = module(config=config) + + def check_event_for_spam(self, event): + """Checks if a given event is considered "spammy" by this server. + + If the server considers an event spammy, then it will be rejected if + sent by a local user. If it is sent by a user on another server, then + users receive a blank event. + + Args: + event (synapse.events.EventBase): the event to be checked + + Returns: + bool: True if the event is spammy. + """ + if self.spam_checker is None: + return False + + return self.spam_checker.check_event_for_spam(event) + + def user_may_invite(self, inviter_userid, invitee_userid, room_id): + """Checks if a given user may send an invite + + If this method returns false, the invite will be rejected. + + Args: + userid (string): The sender's user ID + + Returns: + bool: True if the user may send an invite, otherwise False + """ + if self.spam_checker is None: + return True + + return self.spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id) + + def user_may_create_room(self, userid): + """Checks if a given user may create a room + + If this method returns false, the creation request will be rejected. + + Args: + userid (string): The sender's user ID + + Returns: + bool: True if the user may create a room, otherwise False + """ + if self.spam_checker is None: + return True + + return self.spam_checker.user_may_create_room(userid) + + def user_may_create_room_alias(self, userid, room_alias): + """Checks if a given user may create a room alias + + If this method returns false, the association request will be rejected. + + Args: + userid (string): The sender's user ID + room_alias (string): The alias to be created + + Returns: + bool: True if the user may create a room alias, otherwise False + """ + if self.spam_checker is None: + return True + + return self.spam_checker.user_may_create_room_alias(userid, room_alias) + + def user_may_publish_room(self, userid, room_id): + """Checks if a given user may publish a room to the directory + + If this method returns false, the publish request will be rejected. + + Args: + userid (string): The sender's user ID + room_id (string): The ID of the room that would be published + + Returns: + bool: True if the user may publish the room, otherwise False + """ + if self.spam_checker is None: + return True + + return self.spam_checker.user_may_publish_room(userid, room_id) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 824f4a42e3..652941ca0d 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -13,12 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.api.constants import EventTypes -from . import EventBase +import re + +from six import string_types from frozendict import frozendict -import re +from synapse.api.constants import EventTypes + +from . import EventBase # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' # (?<!stuff) matches if the current position in the string is not preceded @@ -277,7 +280,7 @@ def serialize_event(e, time_now_ms, as_client_event=True, if only_event_fields: if (not isinstance(only_event_fields, list) or - not all(isinstance(f, basestring) for f in only_event_fields)): + not all(isinstance(f, string_types) for f in only_event_fields)): raise TypeError("only_event_fields must be a list of strings") d = only_fields(d, only_event_fields) diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 2f4c8a1018..cf184748a1 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -13,9 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.types import EventID, RoomID, UserID -from synapse.api.errors import SynapseError +from six import string_types + from synapse.api.constants import EventTypes, Membership +from synapse.api.errors import SynapseError +from synapse.types import EventID, RoomID, UserID class EventValidator(object): @@ -49,7 +51,7 @@ class EventValidator(object): strings.append("state_key") for s in strings: - if not isinstance(getattr(event, s), basestring): + if not isinstance(getattr(event, s), string_types): raise SynapseError(400, "Not '%s' a string type" % (s,)) if event.type == EventTypes.Member: @@ -88,5 +90,5 @@ class EventValidator(object): for s in keys: if s not in d: raise SynapseError(400, "'%s' not in content" % (s,)) - if not isinstance(d[s], basestring): + if not isinstance(d[s], string_types): raise SynapseError(400, "Not '%s' a string type" % (s,)) diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py index 2e32d245ba..f5f0bdfca3 100644 --- a/synapse/federation/__init__.py +++ b/synapse/federation/__init__.py @@ -15,11 +15,3 @@ """ This package includes all the federation specific logic. """ - -from .replication import ReplicationLayer - - -def initialize_http_replication(hs): - transport = hs.get_federation_transport_client() - - return ReplicationLayer(hs, transport) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 2339cc9034..c11798093d 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -12,28 +12,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +import six from twisted.internet import defer -from synapse.events.utils import prune_event - +from synapse.api.constants import MAX_DEPTH +from synapse.api.errors import Codes, SynapseError from synapse.crypto.event_signing import check_event_content_hash - -from synapse.api.errors import SynapseError - -from synapse.util import unwrapFirstError -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred - -import logging - +from synapse.events import FrozenEvent +from synapse.events.utils import prune_event +from synapse.http.servlet import assert_params_in_dict +from synapse.util import logcontext, unwrapFirstError logger = logging.getLogger(__name__) class FederationBase(object): def __init__(self, hs): - pass + self.hs = hs + + self.server_name = hs.hostname + self.keyring = hs.get_keyring() + self.spam_checker = hs.get_spam_checker() + self.store = hs.get_datastore() + self._clock = hs.get_clock() @defer.inlineCallbacks def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False, @@ -57,56 +61,52 @@ class FederationBase(object): """ deferreds = self._check_sigs_and_hashes(pdus) - def callback(pdu): - return pdu + @defer.inlineCallbacks + def handle_check_result(pdu, deferred): + try: + res = yield logcontext.make_deferred_yieldable(deferred) + except SynapseError: + res = None - def errback(failure, pdu): - failure.trap(SynapseError) - return None - - def try_local_db(res, pdu): if not res: # Check local db. - return self.store.get_event( + res = yield self.store.get_event( pdu.event_id, allow_rejected=True, allow_none=True, ) - return res - def try_remote(res, pdu): if not res and pdu.origin != origin: - return self.get_pdu( - destinations=[pdu.origin], - event_id=pdu.event_id, - outlier=outlier, - timeout=10000, - ).addErrback(lambda e: None) - return res - - def warn(res, pdu): + try: + res = yield self.get_pdu( + destinations=[pdu.origin], + event_id=pdu.event_id, + outlier=outlier, + timeout=10000, + ) + except SynapseError: + pass + if not res: logger.warn( "Failed to find copy of %s with valid signature", pdu.event_id, ) - return res - for pdu, deferred in zip(pdus, deferreds): - deferred.addCallbacks( - callback, errback, errbackArgs=[pdu] - ).addCallback( - try_local_db, pdu - ).addCallback( - try_remote, pdu - ).addCallback( - warn, pdu - ) + defer.returnValue(res) + + handle = logcontext.preserve_fn(handle_check_result) + deferreds2 = [ + handle(pdu, deferred) + for pdu, deferred in zip(pdus, deferreds) + ] - valid_pdus = yield preserve_context_over_deferred(defer.gatherResults( - deferreds, - consumeErrors=True - )).addErrback(unwrapFirstError) + valid_pdus = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + deferreds2, + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) if include_none: defer.returnValue(valid_pdus) @@ -114,15 +114,24 @@ class FederationBase(object): defer.returnValue([p for p in valid_pdus if p]) def _check_sigs_and_hash(self, pdu): - return self._check_sigs_and_hashes([pdu])[0] + return logcontext.make_deferred_yieldable( + self._check_sigs_and_hashes([pdu])[0], + ) def _check_sigs_and_hashes(self, pdus): - """Throws a SynapseError if a PDU does not have the correct - signatures. + """Checks that each of the received events is correctly signed by the + sending server. + + Args: + pdus (list[FrozenEvent]): the events to be checked Returns: - FrozenEvent: Either the given event or it redacted if it failed the - content hash check. + list[Deferred]: for each input event, a deferred which: + * returns the original event if the checks pass + * returns a redacted version of the event (if the signature + matched but the hash did not) + * throws a SynapseError if the signature check failed. + The deferreds run their callbacks in the sentinel logcontext. """ redacted_pdus = [ @@ -130,26 +139,38 @@ class FederationBase(object): for pdu in pdus ] - deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([ + deferreds = self.keyring.verify_json_objects_for_server([ (p.origin, p.get_pdu_json()) for p in redacted_pdus ]) + ctx = logcontext.LoggingContext.current_context() + def callback(_, pdu, redacted): - if not check_event_content_hash(pdu): - logger.warn( - "Event content has been tampered, redacting %s: %s", - pdu.event_id, pdu.get_pdu_json() - ) - return redacted - return pdu + with logcontext.PreserveLoggingContext(ctx): + if not check_event_content_hash(pdu): + logger.warn( + "Event content has been tampered, redacting %s: %s", + pdu.event_id, pdu.get_pdu_json() + ) + return redacted + + if self.spam_checker.check_event_for_spam(pdu): + logger.warn( + "Event contains spam, redacting %s: %s", + pdu.event_id, pdu.get_pdu_json() + ) + return redacted + + return pdu def errback(failure, pdu): failure.trap(SynapseError) - logger.warn( - "Signature check failed for %s", - pdu.event_id, - ) + with logcontext.PreserveLoggingContext(ctx): + logger.warn( + "Signature check failed for %s", + pdu.event_id, + ) return failure for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus): @@ -160,3 +181,40 @@ class FederationBase(object): ) return deferreds + + +def event_from_pdu_json(pdu_json, outlier=False): + """Construct a FrozenEvent from an event json received over federation + + Args: + pdu_json (object): pdu as received over federation + outlier (bool): True to mark this event as an outlier + + Returns: + FrozenEvent + + Raises: + SynapseError: if the pdu is missing required fields or is otherwise + not a valid matrix event + """ + # we could probably enforce a bunch of other fields here (room_id, sender, + # origin, etc etc) + assert_params_in_dict(pdu_json, ('event_id', 'type', 'depth')) + + depth = pdu_json['depth'] + if not isinstance(depth, six.integer_types): + raise SynapseError(400, "Depth %r not an intger" % (depth, ), + Codes.BAD_JSON) + + if depth < 0: + raise SynapseError(400, "Depth too small", Codes.BAD_JSON) + elif depth > MAX_DEPTH: + raise SynapseError(400, "Depth too large", Codes.BAD_JSON) + + event = FrozenEvent( + pdu_json + ) + + event.internal_metadata.outlier = outlier + + return event diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 861441708b..62d7ed13cf 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -14,36 +14,35 @@ # limitations under the License. +import copy +import itertools +import logging +import random + +from six.moves import range + +from prometheus_client import Counter + from twisted.internet import defer -from .federation_base import FederationBase from synapse.api.constants import Membership - from synapse.api.errors import ( - CodeMessageException, HttpResponseException, SynapseError, + CodeMessageException, + FederationDeniedError, + HttpResponseException, + SynapseError, ) -from synapse.util import unwrapFirstError +from synapse.events import builder +from synapse.federation.federation_base import FederationBase, event_from_pdu_json +from synapse.util import logcontext, unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.logutils import log_function -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred -from synapse.events import FrozenEvent, builder -import synapse.metrics - from synapse.util.retryutils import NotRetryingDestination -import copy -import itertools -import logging -import random - - logger = logging.getLogger(__name__) - -# synapse.federation.federation_client is a silly name -metrics = synapse.metrics.get_metrics_for("synapse.federation.client") - -sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"]) +sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"]) PDU_RETRY_TIME_MS = 1 * 60 * 1000 @@ -58,6 +57,7 @@ class FederationClient(FederationBase): self._clear_tried_cache, 60 * 1000, ) self.state = hs.get_state_handler() + self.transport_layer = hs.get_federation_transport_client() def _clear_tried_cache(self): """Clear pdu_destination_tried cache""" @@ -105,7 +105,7 @@ class FederationClient(FederationBase): a Deferred which will eventually yield a JSON object from the response """ - sent_queries_counter.inc(query_type) + sent_queries_counter.labels(query_type).inc() return self.transport_layer.make_query( destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail, @@ -124,7 +124,7 @@ class FederationClient(FederationBase): a Deferred which will eventually yield a JSON object from the response """ - sent_queries_counter.inc("client_device_keys") + sent_queries_counter.labels("client_device_keys").inc() return self.transport_layer.query_client_keys( destination, content, timeout ) @@ -134,7 +134,7 @@ class FederationClient(FederationBase): """Query the device keys for a list of user ids hosted on a remote server. """ - sent_queries_counter.inc("user_devices") + sent_queries_counter.labels("user_devices").inc() return self.transport_layer.query_user_devices( destination, user_id, timeout ) @@ -151,7 +151,7 @@ class FederationClient(FederationBase): a Deferred which will eventually yield a JSON object from the response """ - sent_queries_counter.inc("client_one_time_keys") + sent_queries_counter.labels("client_one_time_keys").inc() return self.transport_layer.claim_client_keys( destination, content, timeout ) @@ -184,15 +184,15 @@ class FederationClient(FederationBase): logger.debug("backfill transaction_data=%s", repr(transaction_data)) pdus = [ - self.event_from_pdu_json(p, outlier=False) + event_from_pdu_json(p, outlier=False) for p in transaction_data["pdus"] ] # FIXME: We should handle signature failures more gracefully. - pdus[:] = yield preserve_context_over_deferred(defer.gatherResults( + pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults( self._check_sigs_and_hashes(pdus), consumeErrors=True, - )).addErrback(unwrapFirstError) + ).addErrback(unwrapFirstError)) defer.returnValue(pdus) @@ -244,7 +244,7 @@ class FederationClient(FederationBase): logger.debug("transaction_data %r", transaction_data) pdu_list = [ - self.event_from_pdu_json(p, outlier=outlier) + event_from_pdu_json(p, outlier=outlier) for p in transaction_data["pdus"] ] @@ -252,7 +252,7 @@ class FederationClient(FederationBase): pdu = pdu_list[0] # Check signatures are correct. - signed_pdu = yield self._check_sigs_and_hashes([pdu])[0] + signed_pdu = yield self._check_sigs_and_hash(pdu) break @@ -266,6 +266,9 @@ class FederationClient(FederationBase): except NotRetryingDestination as e: logger.info(e.message) continue + except FederationDeniedError as e: + logger.info(e.message) + continue except Exception as e: pdu_attempts[destination] = now @@ -336,11 +339,11 @@ class FederationClient(FederationBase): ) pdus = [ - self.event_from_pdu_json(p, outlier=True) for p in result["pdus"] + event_from_pdu_json(p, outlier=True) for p in result["pdus"] ] auth_chain = [ - self.event_from_pdu_json(p, outlier=True) + event_from_pdu_json(p, outlier=True) for p in result.get("auth_chain", []) ] @@ -388,9 +391,9 @@ class FederationClient(FederationBase): """ if return_local: seen_events = yield self.store.get_events(event_ids, allow_rejected=True) - signed_events = seen_events.values() + signed_events = list(seen_events.values()) else: - seen_events = yield self.store.have_events(event_ids) + seen_events = yield self.store.have_seen_events(event_ids) signed_events = [] failed_to_fetch = set() @@ -409,18 +412,19 @@ class FederationClient(FederationBase): batch_size = 20 missing_events = list(missing_events) - for i in xrange(0, len(missing_events), batch_size): + for i in range(0, len(missing_events), batch_size): batch = set(missing_events[i:i + batch_size]) deferreds = [ - preserve_fn(self.get_pdu)( + run_in_background( + self.get_pdu, destinations=random_server_list(), event_id=e_id, ) for e_id in batch ] - res = yield preserve_context_over_deferred( + res = yield make_deferred_yieldable( defer.DeferredList(deferreds, consumeErrors=True) ) for success, result in res: @@ -441,7 +445,7 @@ class FederationClient(FederationBase): ) auth_chain = [ - self.event_from_pdu_json(p, outlier=True) + event_from_pdu_json(p, outlier=True) for p in res["auth_chain"] ] @@ -570,12 +574,12 @@ class FederationClient(FederationBase): logger.debug("Got content: %s", content) state = [ - self.event_from_pdu_json(p, outlier=True) + event_from_pdu_json(p, outlier=True) for p in content.get("state", []) ] auth_chain = [ - self.event_from_pdu_json(p, outlier=True) + event_from_pdu_json(p, outlier=True) for p in content.get("auth_chain", []) ] @@ -585,7 +589,7 @@ class FederationClient(FederationBase): } valid_pdus = yield self._check_sigs_and_hash_and_fetch( - destination, pdus.values(), + destination, list(pdus.values()), outlier=True, ) @@ -650,7 +654,7 @@ class FederationClient(FederationBase): logger.debug("Got response to send_invite: %s", pdu_dict) - pdu = self.event_from_pdu_json(pdu_dict) + pdu = event_from_pdu_json(pdu_dict) # Check signatures are correct. pdu = yield self._check_sigs_and_hash(pdu) @@ -740,7 +744,7 @@ class FederationClient(FederationBase): ) auth_chain = [ - self.event_from_pdu_json(e) + event_from_pdu_json(e) for e in content["auth_chain"] ] @@ -788,7 +792,7 @@ class FederationClient(FederationBase): ) events = [ - self.event_from_pdu_json(e) + event_from_pdu_json(e) for e in content.get("events", []) ] @@ -805,15 +809,6 @@ class FederationClient(FederationBase): defer.returnValue(signed_events) - def event_from_pdu_json(self, pdu_json, outlier=False): - event = FrozenEvent( - pdu_json - ) - - event.internal_metadata.outlier = outlier - - return event - @defer.inlineCallbacks def forward_third_party_invite(self, destinations, room_id, event_dict): for destination in destinations: diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 51e3fdea06..e501251b6e 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,92 +13,72 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +import re +import six +from six import iteritems -from twisted.internet import defer - -from .federation_base import FederationBase -from .units import Transaction, Edu - -from synapse.util.async import Linearizer -from synapse.util.logutils import log_function -from synapse.util.caches.response_cache import ResponseCache -from synapse.events import FrozenEvent -from synapse.types import get_domain_from_id -import synapse.metrics +from canonicaljson import json +from prometheus_client import Counter -from synapse.api.errors import AuthError, FederationError, SynapseError +from twisted.internet import defer +from twisted.internet.abstract import isIPAddress +from twisted.python import failure +from synapse.api.constants import EventTypes +from synapse.api.errors import AuthError, FederationError, NotFoundError, SynapseError from synapse.crypto.event_signing import compute_event_signature +from synapse.federation.federation_base import FederationBase, event_from_pdu_json +from synapse.federation.persistence import TransactionActions +from synapse.federation.units import Edu, Transaction +from synapse.http.endpoint import parse_server_name +from synapse.types import get_domain_from_id +from synapse.util import async +from synapse.util.caches.response_cache import ResponseCache +from synapse.util.logutils import log_function -import simplejson as json -import logging - +# when processing incoming transactions, we try to handle multiple rooms in +# parallel, up to this limit. +TRANSACTION_CONCURRENCY_LIMIT = 10 logger = logging.getLogger(__name__) -# synapse.federation.federation_server is a silly name -metrics = synapse.metrics.get_metrics_for("synapse.federation.server") - -received_pdus_counter = metrics.register_counter("received_pdus") +received_pdus_counter = Counter("synapse_federation_server_received_pdus", "") -received_edus_counter = metrics.register_counter("received_edus") +received_edus_counter = Counter("synapse_federation_server_received_edus", "") -received_queries_counter = metrics.register_counter("received_queries", labels=["type"]) +received_queries_counter = Counter( + "synapse_federation_server_received_queries", "", ["type"] +) class FederationServer(FederationBase): + def __init__(self, hs): super(FederationServer, self).__init__(hs) self.auth = hs.get_auth() + self.handler = hs.get_handlers().federation_handler - self._server_linearizer = Linearizer("fed_server") + self._server_linearizer = async.Linearizer("fed_server") + self._transaction_linearizer = async.Linearizer("fed_txn_handler") - # We cache responses to state queries, as they take a while and often - # come in waves. - self._state_resp_cache = ResponseCache(hs, timeout_ms=30000) + self.transaction_actions = TransactionActions(self.store) - def set_handler(self, handler): - """Sets the handler that the replication layer will use to communicate - receipt of new PDUs from other home servers. The required methods are - documented on :py:class:`.ReplicationHandler`. - """ - self.handler = handler + self.registry = hs.get_federation_registry() - def register_edu_handler(self, edu_type, handler): - if edu_type in self.edu_handlers: - raise KeyError("Already have an EDU handler for %s" % (edu_type,)) - - self.edu_handlers[edu_type] = handler - - def register_query_handler(self, query_type, handler): - """Sets the handler callable that will be used to handle an incoming - federation Query of the given type. - - Args: - query_type (str): Category name of the query, which should match - the string used by make_query. - handler (callable): Invoked to handle incoming queries of this type - - handler is invoked as: - result = handler(args) - - where 'args' is a dict mapping strings to strings of the query - arguments. It should return a Deferred that will eventually yield an - object to encode as JSON. - """ - if query_type in self.query_handlers: - raise KeyError( - "Already have a Query handler for %s" % (query_type,) - ) - - self.query_handlers[query_type] = handler + # We cache responses to state queries, as they take a while and often + # come in waves. + self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000) @defer.inlineCallbacks @log_function def on_backfill_request(self, origin, room_id, versions, limit): with (yield self._server_linearizer.queue((origin, room_id))): + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, room_id) + pdus = yield self.handler.on_backfill_request( origin, room_id, versions, limit ) @@ -109,25 +90,41 @@ class FederationServer(FederationBase): @defer.inlineCallbacks @log_function def on_incoming_transaction(self, transaction_data): + # keep this as early as possible to make the calculated origin ts as + # accurate as possible. + request_time = self._clock.time_msec() + transaction = Transaction(**transaction_data) - received_pdus_counter.inc_by(len(transaction.pdus)) + if not transaction.transaction_id: + raise Exception("Transaction missing transaction_id") + if not transaction.origin: + raise Exception("Transaction missing origin") - for p in transaction.pdus: - if "unsigned" in p: - unsigned = p["unsigned"] - if "age" in unsigned: - p["age"] = unsigned["age"] - if "age" in p: - p["age_ts"] = int(self._clock.time_msec()) - int(p["age"]) - del p["age"] + logger.debug("[%s] Got transaction", transaction.transaction_id) - pdu_list = [ - self.event_from_pdu_json(p) for p in transaction.pdus - ] + # use a linearizer to ensure that we don't process the same transaction + # multiple times in parallel. + with (yield self._transaction_linearizer.queue( + (transaction.origin, transaction.transaction_id), + )): + result = yield self._handle_incoming_transaction( + transaction, request_time, + ) - logger.debug("[%s] Got transaction", transaction.transaction_id) + defer.returnValue(result) + + @defer.inlineCallbacks + def _handle_incoming_transaction(self, transaction, request_time): + """ Process an incoming transaction and return the HTTP response + + Args: + transaction (Transaction): incoming transaction + request_time (int): timestamp that the HTTP request arrived at + Returns: + Deferred[(int, object)]: http response code and body + """ response = yield self.transaction_actions.have_responded(transaction) if response: @@ -140,42 +137,67 @@ class FederationServer(FederationBase): logger.debug("[%s] Transaction is new", transaction.transaction_id) - results = [] - - for pdu in pdu_list: - # check that it's actually being sent from a valid destination to - # workaround bug #1753 in 0.18.5 and 0.18.6 - if transaction.origin != get_domain_from_id(pdu.event_id): - # We continue to accept join events from any server; this is - # necessary for the federation join dance to work correctly. - # (When we join over federation, the "helper" server is - # responsible for sending out the join event, rather than the - # origin. See bug #1893). - if not ( - pdu.type == 'm.room.member' and - pdu.content and - pdu.content.get("membership", None) == 'join' - ): - logger.info( - "Discarding PDU %s from invalid origin %s", - pdu.event_id, transaction.origin + received_pdus_counter.inc(len(transaction.pdus)) + + origin_host, _ = parse_server_name(transaction.origin) + + pdus_by_room = {} + + for p in transaction.pdus: + if "unsigned" in p: + unsigned = p["unsigned"] + if "age" in unsigned: + p["age"] = unsigned["age"] + if "age" in p: + p["age_ts"] = request_time - int(p["age"]) + del p["age"] + + event = event_from_pdu_json(p) + room_id = event.room_id + pdus_by_room.setdefault(room_id, []).append(event) + + pdu_results = {} + + # we can process different rooms in parallel (which is useful if they + # require callouts to other servers to fetch missing events), but + # impose a limit to avoid going too crazy with ram/cpu. + + @defer.inlineCallbacks + def process_pdus_for_room(room_id): + logger.debug("Processing PDUs for %s", room_id) + try: + yield self.check_server_matches_acl(origin_host, room_id) + except AuthError as e: + logger.warn( + "Ignoring PDUs for room %s from banned server", room_id, + ) + for pdu in pdus_by_room[room_id]: + event_id = pdu.event_id + pdu_results[event_id] = e.error_dict() + return + + for pdu in pdus_by_room[room_id]: + event_id = pdu.event_id + try: + yield self._handle_received_pdu( + transaction.origin, pdu ) - continue - else: - logger.info( - "Accepting join PDU %s from %s", - pdu.event_id, transaction.origin + pdu_results[event_id] = {} + except FederationError as e: + logger.warn("Error handling PDU %s: %s", event_id, e) + pdu_results[event_id] = {"error": str(e)} + except Exception as e: + f = failure.Failure() + pdu_results[event_id] = {"error": str(e)} + logger.error( + "Failed to handle PDU %s: %s", + event_id, f.getTraceback().rstrip(), ) - try: - yield self._handle_received_pdu(transaction.origin, pdu) - results.append({}) - except FederationError as e: - self.send_failure(e, transaction.origin) - results.append({"error": str(e)}) - except Exception as e: - results.append({"error": str(e)}) - logger.exception("Failed to handle PDU") + yield async.concurrently_execute( + process_pdus_for_room, pdus_by_room.keys(), + TRANSACTION_CONCURRENCY_LIMIT, + ) if hasattr(transaction, "edus"): for edu in (Edu(**x) for x in transaction.edus): @@ -185,17 +207,16 @@ class FederationServer(FederationBase): edu.content ) - for failure in getattr(transaction, "pdu_failures", []): - logger.info("Got failure %r", failure) - - logger.debug("Returning: %s", str(results)) + pdu_failures = getattr(transaction, "pdu_failures", []) + for fail in pdu_failures: + logger.info("Got failure %r", fail) response = { - "pdus": dict(zip( - (p.event_id for p in pdu_list), results - )), + "pdus": pdu_results, } + logger.debug("Returning: %s", str(response)) + yield self.transaction_actions.set_response( transaction, 200, response @@ -205,16 +226,7 @@ class FederationServer(FederationBase): @defer.inlineCallbacks def received_edu(self, origin, edu_type, content): received_edus_counter.inc() - - if edu_type in self.edu_handlers: - try: - yield self.edu_handlers[edu_type](origin, content) - except SynapseError as e: - logger.info("Failed to handle edu %r: %r", edu_type, e) - except Exception as e: - logger.exception("Failed to handle edu %r", edu_type) - else: - logger.warn("Received EDU of type %s with no handler", edu_type) + yield self.registry.on_edu(edu_type, origin, content) @defer.inlineCallbacks @log_function @@ -222,19 +234,24 @@ class FederationServer(FederationBase): if not event_id: raise NotImplementedError("Specify an event") + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, room_id) + in_room = yield self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") - result = self._state_resp_cache.get((room_id, event_id)) - if not result: - with (yield self._server_linearizer.queue((origin, room_id))): - resp = yield self._state_resp_cache.set( - (room_id, event_id), - self._on_context_state_request_compute(room_id, event_id) - ) - else: - resp = yield result + # we grab the linearizer to protect ourselves from servers which hammer + # us. In theory we might already have the response to this query + # in the cache so we could return it without waiting for the linearizer + # - but that's non-trivial to get right, and anyway somewhat defeats + # the point of the linearizer. + with (yield self._server_linearizer.queue((origin, room_id))): + resp = yield self._state_resp_cache.wrap( + (room_id, event_id), + self._on_context_state_request_compute, + room_id, event_id, + ) defer.returnValue((200, resp)) @@ -243,6 +260,9 @@ class FederationServer(FederationBase): if not event_id: raise NotImplementedError("Specify an event") + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, room_id) + in_room = yield self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -286,7 +306,7 @@ class FederationServer(FederationBase): @defer.inlineCallbacks @log_function def on_pdu_request(self, origin, event_id): - pdu = yield self._get_persisted_pdu(origin, event_id) + pdu = yield self.handler.get_persisted_pdu(origin, event_id) if pdu: defer.returnValue( @@ -302,25 +322,23 @@ class FederationServer(FederationBase): @defer.inlineCallbacks def on_query_request(self, query_type, args): - received_queries_counter.inc(query_type) - - if query_type in self.query_handlers: - response = yield self.query_handlers[query_type](args) - defer.returnValue((200, response)) - else: - defer.returnValue( - (404, "No handler for Query type '%s'" % (query_type,)) - ) + received_queries_counter.labels(query_type).inc() + resp = yield self.registry.on_query(query_type, args) + defer.returnValue((200, resp)) @defer.inlineCallbacks - def on_make_join_request(self, room_id, user_id): + def on_make_join_request(self, origin, room_id, user_id): + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, room_id) pdu = yield self.handler.on_make_join_request(room_id, user_id) time_now = self._clock.time_msec() defer.returnValue({"event": pdu.get_pdu_json(time_now)}) @defer.inlineCallbacks def on_invite_request(self, origin, content): - pdu = self.event_from_pdu_json(content) + pdu = event_from_pdu_json(content) + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, pdu.room_id) ret_pdu = yield self.handler.on_invite_request(origin, pdu) time_now = self._clock.time_msec() defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)})) @@ -328,7 +346,11 @@ class FederationServer(FederationBase): @defer.inlineCallbacks def on_send_join_request(self, origin, content): logger.debug("on_send_join_request: content: %s", content) - pdu = self.event_from_pdu_json(content) + pdu = event_from_pdu_json(content) + + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, pdu.room_id) + logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) res_pdus = yield self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() @@ -340,7 +362,9 @@ class FederationServer(FederationBase): })) @defer.inlineCallbacks - def on_make_leave_request(self, room_id, user_id): + def on_make_leave_request(self, origin, room_id, user_id): + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, room_id) pdu = yield self.handler.on_make_leave_request(room_id, user_id) time_now = self._clock.time_msec() defer.returnValue({"event": pdu.get_pdu_json(time_now)}) @@ -348,7 +372,11 @@ class FederationServer(FederationBase): @defer.inlineCallbacks def on_send_leave_request(self, origin, content): logger.debug("on_send_leave_request: content: %s", content) - pdu = self.event_from_pdu_json(content) + pdu = event_from_pdu_json(content) + + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, pdu.room_id) + logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) yield self.handler.on_send_leave_request(origin, pdu) defer.returnValue((200, {})) @@ -356,6 +384,9 @@ class FederationServer(FederationBase): @defer.inlineCallbacks def on_event_auth(self, origin, room_id, event_id): with (yield self._server_linearizer.queue((origin, room_id))): + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, room_id) + time_now = self._clock.time_msec() auth_pdus = yield self.handler.on_event_auth(event_id) res = { @@ -384,8 +415,11 @@ class FederationServer(FederationBase): Deferred: Results in `dict` with the same format as `content` """ with (yield self._server_linearizer.queue((origin, room_id))): + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, room_id) + auth_chain = [ - self.event_from_pdu_json(e) + event_from_pdu_json(e) for e in content["auth_chain"] ] @@ -444,9 +478,9 @@ class FederationServer(FederationBase): "Claimed one-time-keys: %s", ",".join(( "%s for %s:%s" % (key_id, user_id, device_id) - for user_id, user_keys in json_result.iteritems() - for device_id, device_keys in user_keys.iteritems() - for key_id, _ in device_keys.iteritems() + for user_id, user_keys in iteritems(json_result) + for device_id, device_keys in iteritems(user_keys) + for key_id, _ in iteritems(device_keys) )), ) @@ -457,6 +491,9 @@ class FederationServer(FederationBase): def on_get_missing_events(self, origin, room_id, earliest_events, latest_events, limit, min_depth): with (yield self._server_linearizer.queue((origin, room_id))): + origin_host, _ = parse_server_name(origin) + yield self.check_server_matches_acl(origin_host, room_id) + logger.info( "on_get_missing_events: earliest_events: %r, latest_events: %r," " limit: %d, min_depth: %d", @@ -485,17 +522,6 @@ class FederationServer(FederationBase): ts_now_ms = self._clock.time_msec() return self.store.get_user_id_for_open_id_token(token, ts_now_ms) - @log_function - def _get_persisted_pdu(self, origin, event_id, do_auth=True): - """ Get a PDU from the database with given origin and id. - - Returns: - Deferred: Results in a `Pdu`. - """ - return self.handler.get_persisted_pdu( - origin, event_id, do_auth=do_auth - ) - def _transaction_from_pdus(self, pdu_list): """Returns a new Transaction containing the given PDUs suitable for transmission. @@ -513,13 +539,57 @@ class FederationServer(FederationBase): def _handle_received_pdu(self, origin, pdu): """ Process a PDU received in a federation /send/ transaction. + If the event is invalid, then this method throws a FederationError. + (The error will then be logged and sent back to the sender (which + probably won't do anything with it), and other events in the + transaction will be processed as normal). + + It is likely that we'll then receive other events which refer to + this rejected_event in their prev_events, etc. When that happens, + we'll attempt to fetch the rejected event again, which will presumably + fail, so those second-generation events will also get rejected. + + Eventually, we get to the point where there are more than 10 events + between any new events and the original rejected event. Since we + only try to backfill 10 events deep on received pdu, we then accept the + new event, possibly introducing a discontinuity in the DAG, with new + forward extremities, so normal service is approximately returned, + until we try to backfill across the discontinuity. + Args: origin (str): server which sent the pdu pdu (FrozenEvent): received pdu Returns (Deferred): completes with None - Raises: FederationError if the signatures / hash do not match - """ + + Raises: FederationError if the signatures / hash do not match, or + if the event was unacceptable for any other reason (eg, too large, + too many prev_events, couldn't find the prev_events) + """ + # check that it's actually being sent from a valid destination to + # workaround bug #1753 in 0.18.5 and 0.18.6 + if origin != get_domain_from_id(pdu.event_id): + # We continue to accept join events from any server; this is + # necessary for the federation join dance to work correctly. + # (When we join over federation, the "helper" server is + # responsible for sending out the join event, rather than the + # origin. See bug #1893). + if not ( + pdu.type == 'm.room.member' and + pdu.content and + pdu.content.get("membership", None) == 'join' + ): + logger.info( + "Discarding PDU %s from invalid origin %s", + pdu.event_id, origin + ) + return + else: + logger.info( + "Accepting join PDU %s from %s", + pdu.event_id, origin + ) + # Check signature. try: pdu = yield self._check_sigs_and_hash(pdu) @@ -531,20 +601,13 @@ class FederationServer(FederationBase): affected=pdu.event_id, ) - yield self.handler.on_receive_pdu(origin, pdu, get_missing=True) + yield self.handler.on_receive_pdu( + origin, pdu, get_missing=True, sent_to_us_directly=True, + ) def __str__(self): return "<ReplicationLayer(%s)>" % self.server_name - def event_from_pdu_json(self, pdu_json, outlier=False): - event = FrozenEvent( - pdu_json - ) - - event.internal_metadata.outlier = outlier - - return event - @defer.inlineCallbacks def exchange_third_party_invite( self, @@ -567,3 +630,161 @@ class FederationServer(FederationBase): origin, room_id, event_dict ) defer.returnValue(ret) + + @defer.inlineCallbacks + def check_server_matches_acl(self, server_name, room_id): + """Check if the given server is allowed by the server ACLs in the room + + Args: + server_name (str): name of server, *without any port part* + room_id (str): ID of the room to check + + Raises: + AuthError if the server does not match the ACL + """ + state_ids = yield self.store.get_current_state_ids(room_id) + acl_event_id = state_ids.get((EventTypes.ServerACL, "")) + + if not acl_event_id: + return + + acl_event = yield self.store.get_event(acl_event_id) + if server_matches_acl_event(server_name, acl_event): + return + + raise AuthError(code=403, msg="Server is banned from room") + + +def server_matches_acl_event(server_name, acl_event): + """Check if the given server is allowed by the ACL event + + Args: + server_name (str): name of server, without any port part + acl_event (EventBase): m.room.server_acl event + + Returns: + bool: True if this server is allowed by the ACLs + """ + logger.debug("Checking %s against acl %s", server_name, acl_event.content) + + # first of all, check if literal IPs are blocked, and if so, whether the + # server name is a literal IP + allow_ip_literals = acl_event.content.get("allow_ip_literals", True) + if not isinstance(allow_ip_literals, bool): + logger.warn("Ignorning non-bool allow_ip_literals flag") + allow_ip_literals = True + if not allow_ip_literals: + # check for ipv6 literals. These start with '['. + if server_name[0] == '[': + return False + + # check for ipv4 literals. We can just lift the routine from twisted. + if isIPAddress(server_name): + return False + + # next, check the deny list + deny = acl_event.content.get("deny", []) + if not isinstance(deny, (list, tuple)): + logger.warn("Ignorning non-list deny ACL %s", deny) + deny = [] + for e in deny: + if _acl_entry_matches(server_name, e): + # logger.info("%s matched deny rule %s", server_name, e) + return False + + # then the allow list. + allow = acl_event.content.get("allow", []) + if not isinstance(allow, (list, tuple)): + logger.warn("Ignorning non-list allow ACL %s", allow) + allow = [] + for e in allow: + if _acl_entry_matches(server_name, e): + # logger.info("%s matched allow rule %s", server_name, e) + return True + + # everything else should be rejected. + # logger.info("%s fell through", server_name) + return False + + +def _acl_entry_matches(server_name, acl_entry): + if not isinstance(acl_entry, six.string_types): + logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)) + return False + regex = _glob_to_regex(acl_entry) + return regex.match(server_name) + + +def _glob_to_regex(glob): + res = '' + for c in glob: + if c == '*': + res = res + '.*' + elif c == '?': + res = res + '.' + else: + res = res + re.escape(c) + return re.compile(res + "\\Z", re.IGNORECASE) + + +class FederationHandlerRegistry(object): + """Allows classes to register themselves as handlers for a given EDU or + query type for incoming federation traffic. + """ + def __init__(self): + self.edu_handlers = {} + self.query_handlers = {} + + def register_edu_handler(self, edu_type, handler): + """Sets the handler callable that will be used to handle an incoming + federation EDU of the given type. + + Args: + edu_type (str): The type of the incoming EDU to register handler for + handler (Callable[[str, dict]]): A callable invoked on incoming EDU + of the given type. The arguments are the origin server name and + the EDU contents. + """ + if edu_type in self.edu_handlers: + raise KeyError("Already have an EDU handler for %s" % (edu_type,)) + + self.edu_handlers[edu_type] = handler + + def register_query_handler(self, query_type, handler): + """Sets the handler callable that will be used to handle an incoming + federation query of the given type. + + Args: + query_type (str): Category name of the query, which should match + the string used by make_query. + handler (Callable[[dict], Deferred[dict]]): Invoked to handle + incoming queries of this type. The return will be yielded + on and the result used as the response to the query request. + """ + if query_type in self.query_handlers: + raise KeyError( + "Already have a Query handler for %s" % (query_type,) + ) + + self.query_handlers[query_type] = handler + + @defer.inlineCallbacks + def on_edu(self, edu_type, origin, content): + handler = self.edu_handlers.get(edu_type) + if not handler: + logger.warn("No handler registered for EDU type %s", edu_type) + + try: + yield handler(origin, content) + except SynapseError as e: + logger.info("Failed to handle edu %r: %r", edu_type, e) + except Exception as e: + logger.exception("Failed to handle edu %r", edu_type) + + def on_query(self, query_type, args): + handler = self.query_handlers.get(query_type) + if not handler: + logger.warn("No handler registered for query type %s", query_type) + raise NotFoundError("No handler for Query type '%s'" % (query_type,)) + + return handler(args) diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 84dc606673..9146215c21 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -19,13 +19,12 @@ package. These actions are mostly only used by the :py:mod:`.replication` module. """ +import logging + from twisted.internet import defer from synapse.util.logutils import log_function -import logging - - logger = logging.getLogger(__name__) diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py deleted file mode 100644 index 62d865ec4b..0000000000 --- a/synapse/federation/replication.py +++ /dev/null @@ -1,73 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This layer is responsible for replicating with remote home servers using -a given transport. -""" - -from .federation_client import FederationClient -from .federation_server import FederationServer - -from .persistence import TransactionActions - -import logging - - -logger = logging.getLogger(__name__) - - -class ReplicationLayer(FederationClient, FederationServer): - """This layer is responsible for replicating with remote home servers over - the given transport. I.e., does the sending and receiving of PDUs to - remote home servers. - - The layer communicates with the rest of the server via a registered - ReplicationHandler. - - In more detail, the layer: - * Receives incoming data and processes it into transactions and pdus. - * Fetches any PDUs it thinks it might have missed. - * Keeps the current state for contexts up to date by applying the - suitable conflict resolution. - * Sends outgoing pdus wrapped in transactions. - * Fills out the references to previous pdus/transactions appropriately - for outgoing data. - """ - - def __init__(self, hs, transport_layer): - self.server_name = hs.hostname - - self.keyring = hs.get_keyring() - - self.transport_layer = transport_layer - - self.federation_client = self - - self.store = hs.get_datastore() - - self.handler = None - self.edu_handlers = {} - self.query_handlers = {} - - self._clock = hs.get_clock() - - self.transaction_actions = TransactionActions(self.store) - - self.hs = hs - - super(ReplicationLayer, self).__init__(hs) - - def __str__(self): - return "<ReplicationLayer(%s)>" % self.server_name diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 93e5acebc1..5157c3860d 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -29,23 +29,22 @@ dead worker doesn't cause the queues to grow limitlessly. Events are replicated via a separate events stream. """ -from .units import Edu +import logging +from collections import namedtuple +from six import iteritems, itervalues + +from sortedcontainers import SortedDict + +from synapse.metrics import LaterGauge from synapse.storage.presence import UserPresenceState from synapse.util.metrics import Measure -import synapse.metrics - -from blist import sorteddict -from collections import namedtuple -import logging +from .units import Edu logger = logging.getLogger(__name__) -metrics = synapse.metrics.get_metrics_for(__name__) - - class FederationRemoteSendQueue(object): """A drop in replacement for TransactionQueue""" @@ -56,29 +55,27 @@ class FederationRemoteSendQueue(object): self.is_mine_id = hs.is_mine_id self.presence_map = {} # Pending presence map user_id -> UserPresenceState - self.presence_changed = sorteddict() # Stream position -> user_id + self.presence_changed = SortedDict() # Stream position -> user_id self.keyed_edu = {} # (destination, key) -> EDU - self.keyed_edu_changed = sorteddict() # stream position -> (destination, key) + self.keyed_edu_changed = SortedDict() # stream position -> (destination, key) - self.edus = sorteddict() # stream position -> Edu + self.edus = SortedDict() # stream position -> Edu - self.failures = sorteddict() # stream position -> (destination, Failure) + self.failures = SortedDict() # stream position -> (destination, Failure) - self.device_messages = sorteddict() # stream position -> destination + self.device_messages = SortedDict() # stream position -> destination self.pos = 1 - self.pos_time = sorteddict() + self.pos_time = SortedDict() # EVERYTHING IS SAD. In particular, python only makes new scopes when # we make a new function, so we need to make a new function so the inner # lambda binds to the queue rather than to the name of the queue which # changes. ARGH. def register(name, queue): - metrics.register_callback( - queue_name + "_size", - lambda: len(queue), - ) + LaterGauge("synapse_federation_send_queue_%s_size" % (queue_name,), + "", [], lambda: len(queue)) for queue_name in [ "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed", @@ -101,7 +98,7 @@ class FederationRemoteSendQueue(object): now = self.clock.time_msec() keys = self.pos_time.keys() - time = keys.bisect_left(now - FIVE_MINUTES_AGO) + time = self.pos_time.bisect_left(now - FIVE_MINUTES_AGO) if not keys[:time]: return @@ -116,13 +113,13 @@ class FederationRemoteSendQueue(object): with Measure(self.clock, "send_queue._clear"): # Delete things out of presence maps keys = self.presence_changed.keys() - i = keys.bisect_left(position_to_delete) + i = self.presence_changed.bisect_left(position_to_delete) for key in keys[:i]: del self.presence_changed[key] user_ids = set( user_id - for uids in self.presence_changed.itervalues() + for uids in itervalues(self.presence_changed) for user_id in uids ) @@ -134,7 +131,7 @@ class FederationRemoteSendQueue(object): # Delete things out of keyed edus keys = self.keyed_edu_changed.keys() - i = keys.bisect_left(position_to_delete) + i = self.keyed_edu_changed.bisect_left(position_to_delete) for key in keys[:i]: del self.keyed_edu_changed[key] @@ -148,19 +145,19 @@ class FederationRemoteSendQueue(object): # Delete things out of edu map keys = self.edus.keys() - i = keys.bisect_left(position_to_delete) + i = self.edus.bisect_left(position_to_delete) for key in keys[:i]: del self.edus[key] # Delete things out of failure map keys = self.failures.keys() - i = keys.bisect_left(position_to_delete) + i = self.failures.bisect_left(position_to_delete) for key in keys[:i]: del self.failures[key] # Delete things out of device map keys = self.device_messages.keys() - i = keys.bisect_left(position_to_delete) + i = self.device_messages.bisect_left(position_to_delete) for key in keys[:i]: del self.device_messages[key] @@ -200,7 +197,7 @@ class FederationRemoteSendQueue(object): # We only want to send presence for our own users, so lets always just # filter here just in case. - local_states = filter(lambda s: self.is_mine_id(s.user_id), states) + local_states = list(filter(lambda s: self.is_mine_id(s.user_id), states)) self.presence_map.update({state.user_id: state for state in local_states}) self.presence_changed[pos] = [state.user_id for state in local_states] @@ -253,13 +250,12 @@ class FederationRemoteSendQueue(object): self._clear_queue_before_pos(federation_ack) # Fetch changed presence - keys = self.presence_changed.keys() - i = keys.bisect_right(from_token) - j = keys.bisect_right(to_token) + 1 + i = self.presence_changed.bisect_right(from_token) + j = self.presence_changed.bisect_right(to_token) + 1 dest_user_ids = [ (pos, user_id) - for pos in keys[i:j] - for user_id in self.presence_changed[pos] + for pos, user_id_list in self.presence_changed.items()[i:j] + for user_id in user_id_list ] for (key, user_id) in dest_user_ids: @@ -268,34 +264,31 @@ class FederationRemoteSendQueue(object): ))) # Fetch changes keyed edus - keys = self.keyed_edu_changed.keys() - i = keys.bisect_right(from_token) - j = keys.bisect_right(to_token) + 1 + i = self.keyed_edu_changed.bisect_right(from_token) + j = self.keyed_edu_changed.bisect_right(to_token) + 1 # We purposefully clobber based on the key here, python dict comprehensions # always use the last value, so this will correctly point to the last # stream position. - keyed_edus = {self.keyed_edu_changed[k]: k for k in keys[i:j]} + keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]} - for ((destination, edu_key), pos) in keyed_edus.iteritems(): + for ((destination, edu_key), pos) in iteritems(keyed_edus): rows.append((pos, KeyedEduRow( key=edu_key, edu=self.keyed_edu[(destination, edu_key)], ))) # Fetch changed edus - keys = self.edus.keys() - i = keys.bisect_right(from_token) - j = keys.bisect_right(to_token) + 1 - edus = ((k, self.edus[k]) for k in keys[i:j]) + i = self.edus.bisect_right(from_token) + j = self.edus.bisect_right(to_token) + 1 + edus = self.edus.items()[i:j] for (pos, edu) in edus: rows.append((pos, EduRow(edu))) # Fetch changed failures - keys = self.failures.keys() - i = keys.bisect_right(from_token) - j = keys.bisect_right(to_token) + 1 - failures = ((k, self.failures[k]) for k in keys[i:j]) + i = self.failures.bisect_right(from_token) + j = self.failures.bisect_right(to_token) + 1 + failures = self.failures.items()[i:j] for (pos, (destination, failure)) in failures: rows.append((pos, FailureRow( @@ -304,12 +297,11 @@ class FederationRemoteSendQueue(object): ))) # Fetch changed device messages - keys = self.device_messages.keys() - i = keys.bisect_right(from_token) - j = keys.bisect_right(to_token) + 1 - device_messages = {self.device_messages[k]: k for k in keys[i:j]} + i = self.device_messages.bisect_right(from_token) + j = self.device_messages.bisect_right(to_token) + 1 + device_messages = {v: k for k, v in self.device_messages.items()[i:j]} - for (destination, pos) in device_messages.iteritems(): + for (destination, pos) in iteritems(device_messages): rows.append((pos, DeviceRow( destination=destination, ))) @@ -528,19 +520,19 @@ def process_rows_for_federation(transaction_queue, rows): if buff.presence: transaction_queue.send_presence(buff.presence) - for destination, edu_map in buff.keyed_edus.iteritems(): + for destination, edu_map in iteritems(buff.keyed_edus): for key, edu in edu_map.items(): transaction_queue.send_edu( edu.destination, edu.edu_type, edu.content, key=key, ) - for destination, edu_list in buff.edus.iteritems(): + for destination, edu_list in iteritems(buff.edus): for edu in edu_list: transaction_queue.send_edu( edu.destination, edu.edu_type, edu.content, key=None, ) - for destination, failure_list in buff.failures.iteritems(): + for destination, failure_list in iteritems(buff.failures): for failure in failure_list: transaction_queue.send_failure(destination, failure) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 003eaba893..6996d6b695 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -13,34 +13,39 @@ # See the License for the specific language governing permissions and # limitations under the License. import datetime +import logging -from twisted.internet import defer +from six import itervalues -from .persistence import TransactionActions -from .units import Transaction, Edu +from prometheus_client import Counter -from synapse.api.errors import HttpResponseException -from synapse.util.async import run_on_reactor -from synapse.util.logcontext import preserve_context_over_fn, preserve_fn -from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter -from synapse.util.metrics import measure_func -from synapse.handlers.presence import format_user_presence_state, get_interested_remotes -import synapse.metrics +from twisted.internet import defer -import logging +import synapse.metrics +from synapse.api.errors import FederationDeniedError, HttpResponseException +from synapse.handlers.presence import format_user_presence_state, get_interested_remotes +from synapse.metrics import ( + LaterGauge, + events_processed_counter, + sent_edus_counter, + sent_transactions_counter, +) +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util import logcontext +from synapse.util.metrics import measure_func +from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter +from .persistence import TransactionActions +from .units import Edu, Transaction logger = logging.getLogger(__name__) -metrics = synapse.metrics.get_metrics_for(__name__) - -client_metrics = synapse.metrics.get_metrics_for("synapse.federation.client") -sent_pdus_destination_dist = client_metrics.register_distribution( - "sent_pdu_destinations" +sent_pdus_destination_dist_count = Counter( + "synapse_federation_client_sent_pdu_destinations:count", "" +) +sent_pdus_destination_dist_total = Counter( + "synapse_federation_client_sent_pdu_destinations:total", "" ) -sent_edus_counter = client_metrics.register_counter("sent_edus") - -sent_transactions_counter = client_metrics.register_counter("sent_transactions") class TransactionQueue(object): @@ -67,8 +72,10 @@ class TransactionQueue(object): # done self.pending_transactions = {} - metrics.register_callback( - "pending_destinations", + LaterGauge( + "synapse_federation_transaction_queue_pending_destinations", + "", + [], lambda: len(self.pending_transactions), ) @@ -92,12 +99,16 @@ class TransactionQueue(object): # Map of destination -> (edu_type, key) -> Edu self.pending_edus_keyed_by_dest = edus_keyed = {} - metrics.register_callback( - "pending_pdus", + LaterGauge( + "synapse_federation_transaction_queue_pending_pdus", + "", + [], lambda: sum(map(len, pdus.values())), ) - metrics.register_callback( - "pending_edus", + LaterGauge( + "synapse_federation_transaction_queue_pending_edus", + "", + [], lambda: ( sum(map(len, edus.values())) + sum(map(len, presence.values())) @@ -146,7 +157,6 @@ class TransactionQueue(object): else: return not destination.startswith("localhost") - @defer.inlineCallbacks def notify_new_events(self, current_id): """This gets called when we have some new events we might want to send out to other servers. @@ -156,12 +166,20 @@ class TransactionQueue(object): if self._is_processing: return + # fire off a processing loop in the background + run_as_background_process( + "process_event_queue_for_federation", + self._process_event_queue_loop, + ) + + @defer.inlineCallbacks + def _process_event_queue_loop(self): try: self._is_processing = True while True: last_token = yield self.store.get_federation_out_pos("events") next_token, events = yield self.store.get_all_new_events_stream( - last_token, self._last_poked_id, limit=20, + last_token, self._last_poked_id, limit=100, ) logger.debug("Handling %s -> %s", last_token, next_token) @@ -169,24 +187,33 @@ class TransactionQueue(object): if not events and next_token >= self._last_poked_id: break - for event in events: + @defer.inlineCallbacks + def handle_event(event): # Only send events for this server. send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of() is_mine = self.is_mine_id(event.event_id) if not is_mine and send_on_behalf_of is None: - continue - - # Get the state from before the event. - # We need to make sure that this is the state from before - # the event and not from after it. - # Otherwise if the last member on a server in a room is - # banned then it won't receive the event because it won't - # be in the room after the ban. - destinations = yield self.state.get_current_hosts_in_room( - event.room_id, latest_event_ids=[ - prev_id for prev_id, _ in event.prev_events - ], - ) + return + + try: + # Get the state from before the event. + # We need to make sure that this is the state from before + # the event and not from after it. + # Otherwise if the last member on a server in a room is + # banned then it won't receive the event because it won't + # be in the room after the ban. + destinations = yield self.state.get_current_hosts_in_room( + event.room_id, latest_event_ids=[ + prev_id for prev_id, _ in event.prev_events + ], + ) + except Exception: + logger.exception( + "Failed to calculate hosts in room for event: %s", + event.event_id, + ) + return + destinations = set(destinations) if send_on_behalf_of is not None: @@ -199,10 +226,41 @@ class TransactionQueue(object): self._send_pdu(event, destinations) + @defer.inlineCallbacks + def handle_room_events(events): + for event in events: + yield handle_event(event) + + events_by_room = {} + for event in events: + events_by_room.setdefault(event.room_id, []).append(event) + + yield logcontext.make_deferred_yieldable(defer.gatherResults( + [ + logcontext.run_in_background(handle_room_events, evs) + for evs in itervalues(events_by_room) + ], + consumeErrors=True + )) + yield self.store.update_federation_out_pos( "events", next_token ) + if events: + now = self.clock.time_msec() + ts = yield self.store.get_received_ts(events[-1].event_id) + + synapse.metrics.event_processing_lag.labels( + "federation_sender").set(now - ts) + synapse.metrics.event_processing_last_ts.labels( + "federation_sender").set(ts) + + events_processed_counter.inc(len(events)) + + synapse.metrics.event_processing_positions.labels( + "federation_sender").set(next_token) + finally: self._is_processing = False @@ -224,18 +282,17 @@ class TransactionQueue(object): if not destinations: return - sent_pdus_destination_dist.inc_by(len(destinations)) + sent_pdus_destination_dist_total.inc(len(destinations)) + sent_pdus_destination_dist_count.inc() for destination in destinations: self.pending_pdus_by_dest.setdefault(destination, []).append( (pdu, order) ) - preserve_context_over_fn( - self._attempt_new_transaction, destination - ) + self._attempt_new_transaction(destination) - @preserve_fn # the caller should not yield on this + @logcontext.preserve_fn # the caller should not yield on this @defer.inlineCallbacks def send_presence(self, states): """Send the new presence states to the appropriate destinations. @@ -273,7 +330,9 @@ class TransactionQueue(object): if not states_map: break - yield self._process_presence_inner(states_map.values()) + yield self._process_presence_inner(list(states_map.values())) + except Exception: + logger.exception("Error sending presence states to servers") finally: self._processing_pending_presence = False @@ -299,7 +358,7 @@ class TransactionQueue(object): state.user_id: state for state in states }) - preserve_fn(self._attempt_new_transaction)(destination) + self._attempt_new_transaction(destination) def send_edu(self, destination, edu_type, content, key=None): edu = Edu( @@ -321,9 +380,7 @@ class TransactionQueue(object): else: self.pending_edus_by_dest.setdefault(destination, []).append(edu) - preserve_context_over_fn( - self._attempt_new_transaction, destination - ) + self._attempt_new_transaction(destination) def send_failure(self, failure, destination): if destination == self.server_name or destination == "localhost": @@ -336,9 +393,7 @@ class TransactionQueue(object): destination, [] ).append(failure) - preserve_context_over_fn( - self._attempt_new_transaction, destination - ) + self._attempt_new_transaction(destination) def send_device_messages(self, destination): if destination == self.server_name or destination == "localhost": @@ -347,15 +402,24 @@ class TransactionQueue(object): if not self.can_send_to(destination): return - preserve_context_over_fn( - self._attempt_new_transaction, destination - ) + self._attempt_new_transaction(destination) def get_current_token(self): return 0 - @defer.inlineCallbacks def _attempt_new_transaction(self, destination): + """Try to start a new transaction to this destination + + If there is already a transaction in progress to this destination, + returns immediately. Otherwise kicks off the process of sending a + transaction in the background. + + Args: + destination (str): + + Returns: + None + """ # list of (pending_pdu, deferred, order) if destination in self.pending_transactions: # XXX: pending_transactions can get stuck on by a never-ending @@ -368,6 +432,16 @@ class TransactionQueue(object): ) return + logger.debug("TX [%s] Starting transaction loop", destination) + + run_as_background_process( + "federation_transaction_transmission_loop", + self._transaction_transmission_loop, + destination, + ) + + @defer.inlineCallbacks + def _transaction_transmission_loop(self, destination): pending_pdus = [] try: self.pending_transactions[destination] = 1 @@ -377,9 +451,6 @@ class TransactionQueue(object): # hence why we throw the result away. yield get_retry_limiter(destination, self.clock, self.store) - # XXX: what's this for? - yield run_on_reactor() - pending_pdus = [] while True: device_message_edus, device_stream_id, dev_list_id = ( @@ -464,6 +535,8 @@ class TransactionQueue(object): (e.retry_last_ts + e.retry_interval) / 1000.0 ), ) + except FederationDeniedError as e: + logger.info(e) except Exception as e: logger.warn( "TX [%s] Failed to send transaction: %s", diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 52b2a717d2..4529d454af 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,15 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +import urllib + from twisted.internet import defer -from synapse.api.constants import Membership +from synapse.api.constants import Membership from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.util.logutils import log_function -import logging - - logger = logging.getLogger(__name__) @@ -49,7 +50,7 @@ class TransportLayerClient(object): logger.debug("get_room_state dest=%s, room=%s", destination, room_id) - path = PREFIX + "/state/%s/" % room_id + path = _create_path(PREFIX, "/state/%s/", room_id) return self.client.get_json( destination, path=path, args={"event_id": event_id}, ) @@ -71,7 +72,7 @@ class TransportLayerClient(object): logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id) - path = PREFIX + "/state_ids/%s/" % room_id + path = _create_path(PREFIX, "/state_ids/%s/", room_id) return self.client.get_json( destination, path=path, args={"event_id": event_id}, ) @@ -93,7 +94,7 @@ class TransportLayerClient(object): logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id) - path = PREFIX + "/event/%s/" % (event_id, ) + path = _create_path(PREFIX, "/event/%s/", event_id) return self.client.get_json(destination, path=path, timeout=timeout) @log_function @@ -119,7 +120,7 @@ class TransportLayerClient(object): # TODO: raise? return - path = PREFIX + "/backfill/%s/" % (room_id,) + path = _create_path(PREFIX, "/backfill/%s/", room_id) args = { "v": event_tuples, @@ -157,9 +158,11 @@ class TransportLayerClient(object): # generated by the json_data_callback. json_data = transaction.get_dict() + path = _create_path(PREFIX, "/send/%s/", transaction.transaction_id) + response = yield self.client.put_json( transaction.destination, - path=PREFIX + "/send/%s/" % transaction.transaction_id, + path=path, data=json_data, json_data_callback=json_data_callback, long_retries=True, @@ -177,7 +180,7 @@ class TransportLayerClient(object): @log_function def make_query(self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False): - path = PREFIX + "/query/%s" % query_type + path = _create_path(PREFIX, "/query/%s", query_type) content = yield self.client.get_json( destination=destination, @@ -212,6 +215,9 @@ class TransportLayerClient(object): Fails with ``NotRetryingDestination`` if we are not yet ready to retry this server. + + Fails with ``FederationDeniedError`` if the remote destination + is not in our federation whitelist """ valid_memberships = {Membership.JOIN, Membership.LEAVE} if membership not in valid_memberships: @@ -219,7 +225,7 @@ class TransportLayerClient(object): "make_membership_event called with membership='%s', must be one of %s" % (membership, ",".join(valid_memberships)) ) - path = PREFIX + "/make_%s/%s/%s" % (membership, room_id, user_id) + path = _create_path(PREFIX, "/make_%s/%s/%s", membership, room_id, user_id) ignore_backoff = False retry_on_dns_fail = False @@ -245,7 +251,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def send_join(self, destination, room_id, event_id, content): - path = PREFIX + "/send_join/%s/%s" % (room_id, event_id) + path = _create_path(PREFIX, "/send_join/%s/%s", room_id, event_id) response = yield self.client.put_json( destination=destination, @@ -258,7 +264,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def send_leave(self, destination, room_id, event_id, content): - path = PREFIX + "/send_leave/%s/%s" % (room_id, event_id) + path = _create_path(PREFIX, "/send_leave/%s/%s", room_id, event_id) response = yield self.client.put_json( destination=destination, @@ -277,7 +283,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def send_invite(self, destination, room_id, event_id, content): - path = PREFIX + "/invite/%s/%s" % (room_id, event_id) + path = _create_path(PREFIX, "/invite/%s/%s", room_id, event_id) response = yield self.client.put_json( destination=destination, @@ -319,7 +325,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def exchange_third_party_invite(self, destination, room_id, event_dict): - path = PREFIX + "/exchange_third_party_invite/%s" % (room_id,) + path = _create_path(PREFIX, "/exchange_third_party_invite/%s", room_id,) response = yield self.client.put_json( destination=destination, @@ -332,7 +338,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def get_event_auth(self, destination, room_id, event_id): - path = PREFIX + "/event_auth/%s/%s" % (room_id, event_id) + path = _create_path(PREFIX, "/event_auth/%s/%s", room_id, event_id) content = yield self.client.get_json( destination=destination, @@ -344,7 +350,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def send_query_auth(self, destination, room_id, event_id, content): - path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id) + path = _create_path(PREFIX, "/query_auth/%s/%s", room_id, event_id) content = yield self.client.post_json( destination=destination, @@ -406,7 +412,7 @@ class TransportLayerClient(object): Returns: A dict containg the device keys. """ - path = PREFIX + "/user/devices/" + user_id + path = _create_path(PREFIX, "/user/devices/%s", user_id) content = yield self.client.get_json( destination=destination, @@ -456,7 +462,7 @@ class TransportLayerClient(object): @log_function def get_missing_events(self, destination, room_id, earliest_events, latest_events, limit, min_depth, timeout): - path = PREFIX + "/get_missing_events/%s" % (room_id,) + path = _create_path(PREFIX, "/get_missing_events/%s", room_id,) content = yield self.client.post_json( destination=destination, @@ -471,3 +477,475 @@ class TransportLayerClient(object): ) defer.returnValue(content) + + @log_function + def get_group_profile(self, destination, group_id, requester_user_id): + """Get a group profile + """ + path = _create_path(PREFIX, "/groups/%s/profile", group_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def update_group_profile(self, destination, group_id, requester_user_id, content): + """Update a remote group profile + + Args: + destination (str) + group_id (str) + requester_user_id (str) + content (dict): The new profile of the group + """ + path = _create_path(PREFIX, "/groups/%s/profile", group_id,) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + @log_function + def get_group_summary(self, destination, group_id, requester_user_id): + """Get a group summary + """ + path = _create_path(PREFIX, "/groups/%s/summary", group_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def get_rooms_in_group(self, destination, group_id, requester_user_id): + """Get all rooms in a group + """ + path = _create_path(PREFIX, "/groups/%s/rooms", group_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + def add_room_to_group(self, destination, group_id, requester_user_id, room_id, + content): + """Add a room to a group + """ + path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + def update_room_in_group(self, destination, group_id, requester_user_id, room_id, + config_key, content): + """Update room in group + """ + path = _create_path( + PREFIX, "/groups/%s/room/%s/config/%s", + group_id, room_id, config_key, + ) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + def remove_room_from_group(self, destination, group_id, requester_user_id, room_id): + """Remove a room from a group + """ + path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,) + + return self.client.delete_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def get_users_in_group(self, destination, group_id, requester_user_id): + """Get users in a group + """ + path = _create_path(PREFIX, "/groups/%s/users", group_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def get_invited_users_in_group(self, destination, group_id, requester_user_id): + """Get users that have been invited to a group + """ + path = _create_path(PREFIX, "/groups/%s/invited_users", group_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def accept_group_invite(self, destination, group_id, user_id, content): + """Accept a group invite + """ + path = _create_path( + PREFIX, "/groups/%s/users/%s/accept_invite", + group_id, user_id, + ) + + return self.client.post_json( + destination=destination, + path=path, + data=content, + ignore_backoff=True, + ) + + @log_function + def join_group(self, destination, group_id, user_id, content): + """Attempts to join a group + """ + path = _create_path(PREFIX, "/groups/%s/users/%s/join", group_id, user_id) + + return self.client.post_json( + destination=destination, + path=path, + data=content, + ignore_backoff=True, + ) + + @log_function + def invite_to_group(self, destination, group_id, user_id, requester_user_id, content): + """Invite a user to a group + """ + path = _create_path(PREFIX, "/groups/%s/users/%s/invite", group_id, user_id) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + @log_function + def invite_to_group_notification(self, destination, group_id, user_id, content): + """Sent by group server to inform a user's server that they have been + invited. + """ + + path = _create_path(PREFIX, "/groups/local/%s/users/%s/invite", group_id, user_id) + + return self.client.post_json( + destination=destination, + path=path, + data=content, + ignore_backoff=True, + ) + + @log_function + def remove_user_from_group(self, destination, group_id, requester_user_id, + user_id, content): + """Remove a user fron a group + """ + path = _create_path(PREFIX, "/groups/%s/users/%s/remove", group_id, user_id) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + @log_function + def remove_user_from_group_notification(self, destination, group_id, user_id, + content): + """Sent by group server to inform a user's server that they have been + kicked from the group. + """ + + path = _create_path(PREFIX, "/groups/local/%s/users/%s/remove", group_id, user_id) + + return self.client.post_json( + destination=destination, + path=path, + data=content, + ignore_backoff=True, + ) + + @log_function + def renew_group_attestation(self, destination, group_id, user_id, content): + """Sent by either a group server or a user's server to periodically update + the attestations + """ + + path = _create_path(PREFIX, "/groups/%s/renew_attestation/%s", group_id, user_id) + + return self.client.post_json( + destination=destination, + path=path, + data=content, + ignore_backoff=True, + ) + + @log_function + def update_group_summary_room(self, destination, group_id, user_id, room_id, + category_id, content): + """Update a room entry in a group summary + """ + if category_id: + path = _create_path( + PREFIX, "/groups/%s/summary/categories/%s/rooms/%s", + group_id, category_id, room_id, + ) + else: + path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": user_id}, + data=content, + ignore_backoff=True, + ) + + @log_function + def delete_group_summary_room(self, destination, group_id, user_id, room_id, + category_id): + """Delete a room entry in a group summary + """ + if category_id: + path = _create_path( + PREFIX + "/groups/%s/summary/categories/%s/rooms/%s", + group_id, category_id, room_id, + ) + else: + path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,) + + return self.client.delete_json( + destination=destination, + path=path, + args={"requester_user_id": user_id}, + ignore_backoff=True, + ) + + @log_function + def get_group_categories(self, destination, group_id, requester_user_id): + """Get all categories in a group + """ + path = _create_path(PREFIX, "/groups/%s/categories", group_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def get_group_category(self, destination, group_id, requester_user_id, category_id): + """Get category info in a group + """ + path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def update_group_category(self, destination, group_id, requester_user_id, category_id, + content): + """Update a category in a group + """ + path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + @log_function + def delete_group_category(self, destination, group_id, requester_user_id, + category_id): + """Delete a category in a group + """ + path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,) + + return self.client.delete_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def get_group_roles(self, destination, group_id, requester_user_id): + """Get all roles in a group + """ + path = _create_path(PREFIX, "/groups/%s/roles", group_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def get_group_role(self, destination, group_id, requester_user_id, role_id): + """Get a roles info + """ + path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def update_group_role(self, destination, group_id, requester_user_id, role_id, + content): + """Update a role in a group + """ + path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + @log_function + def delete_group_role(self, destination, group_id, requester_user_id, role_id): + """Delete a role in a group + """ + path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,) + + return self.client.delete_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + @log_function + def update_group_summary_user(self, destination, group_id, requester_user_id, + user_id, role_id, content): + """Update a users entry in a group + """ + if role_id: + path = _create_path( + PREFIX, "/groups/%s/summary/roles/%s/users/%s", + group_id, role_id, user_id, + ) + else: + path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,) + + return self.client.post_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + @log_function + def set_group_join_policy(self, destination, group_id, requester_user_id, + content): + """Sets the join policy for a group + """ + path = _create_path(PREFIX, "/groups/%s/settings/m.join_policy", group_id,) + + return self.client.put_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + @log_function + def delete_group_summary_user(self, destination, group_id, requester_user_id, + user_id, role_id): + """Delete a users entry in a group + """ + if role_id: + path = _create_path( + PREFIX, "/groups/%s/summary/roles/%s/users/%s", + group_id, role_id, user_id, + ) + else: + path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,) + + return self.client.delete_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + + def bulk_get_publicised_groups(self, destination, user_ids): + """Get the groups a list of users are publicising + """ + + path = PREFIX + "/get_groups_publicised" + + content = {"user_ids": user_ids} + + return self.client.post_json( + destination=destination, + path=path, + data=content, + ignore_backoff=True, + ) + + +def _create_path(prefix, path, *args): + """Creates a path from the prefix, path template and args. Ensures that + all args are url encoded. + + Example: + + _create_path(PREFIX, "/event/%s/", event_id) + + Args: + prefix (str) + path (str): String template for the path + args: ([str]): Args to insert into path. Each arg will be url encoded + + Returns: + str + """ + return prefix + path % tuple(urllib.quote(arg, "") for arg in args) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index a78f01e442..c9beca27c2 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,25 +14,27 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools +import logging +import re + from twisted.internet import defer +import synapse +from synapse.api.errors import Codes, FederationDeniedError, SynapseError from synapse.api.urls import FEDERATION_PREFIX as PREFIX -from synapse.api.errors import Codes, SynapseError +from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.server import JsonResource from synapse.http.servlet import ( - parse_json_object_from_request, parse_integer_from_args, parse_string_from_args, parse_boolean_from_args, + parse_integer_from_args, + parse_json_object_from_request, + parse_string_from_args, ) +from synapse.types import ThirdPartyInstanceID, get_domain_from_id +from synapse.util.logcontext import run_in_background from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.versionstring import get_version_string -from synapse.util.logcontext import preserve_fn -from synapse.types import ThirdPartyInstanceID - -import functools -import logging -import re -import synapse - logger = logging.getLogger(__name__) @@ -81,6 +84,7 @@ class Authenticator(object): self.keyring = hs.get_keyring() self.server_name = hs.hostname self.store = hs.get_datastore() + self.federation_domain_whitelist = hs.config.federation_domain_whitelist # A method just so we can pass 'self' as the authenticator to the Servlets @defer.inlineCallbacks @@ -97,26 +101,6 @@ class Authenticator(object): origin = None - def parse_auth_header(header_str): - try: - params = auth.split(" ")[1].split(",") - param_dict = dict(kv.split("=") for kv in params) - - def strip_quotes(value): - if value.startswith("\""): - return value[1:-1] - else: - return value - - origin = strip_quotes(param_dict["origin"]) - key = strip_quotes(param_dict["key"]) - sig = strip_quotes(param_dict["sig"]) - return (origin, key, sig) - except: - raise AuthenticationError( - 400, "Malformed Authorization header", Codes.UNAUTHORIZED - ) - auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") if not auth_headers: @@ -125,11 +109,17 @@ class Authenticator(object): ) for auth in auth_headers: - if auth.startswith("X-Matrix"): - (origin, key, sig) = parse_auth_header(auth) + if auth.startswith(b"X-Matrix"): + (origin, key, sig) = _parse_auth_header(auth) json_request["origin"] = origin json_request["signatures"].setdefault(origin, {})[key] = sig + if ( + self.federation_domain_whitelist is not None and + origin not in self.federation_domain_whitelist + ): + raise FederationDeniedError(origin) + if not json_request["signatures"]: raise NoAuthenticationError( 401, "Missing Authorization headers", Codes.UNAUTHORIZED, @@ -144,11 +134,60 @@ class Authenticator(object): # alive retry_timings = yield self.store.get_destination_retry_timings(origin) if retry_timings and retry_timings["retry_last_ts"]: - logger.info("Marking origin %r as up", origin) - preserve_fn(self.store.set_destination_retry_timings)(origin, 0, 0) + run_in_background(self._reset_retry_timings, origin) defer.returnValue(origin) + @defer.inlineCallbacks + def _reset_retry_timings(self, origin): + try: + logger.info("Marking origin %r as up", origin) + yield self.store.set_destination_retry_timings(origin, 0, 0) + except Exception: + logger.exception("Error resetting retry timings on %s", origin) + + +def _parse_auth_header(header_bytes): + """Parse an X-Matrix auth header + + Args: + header_bytes (bytes): header value + + Returns: + Tuple[str, str, str]: origin, key id, signature. + + Raises: + AuthenticationError if the header could not be parsed + """ + try: + header_str = header_bytes.decode('utf-8') + params = header_str.split(" ")[1].split(",") + param_dict = dict(kv.split("=") for kv in params) + + def strip_quotes(value): + if value.startswith(b"\""): + return value[1:-1] + else: + return value + + origin = strip_quotes(param_dict["origin"]) + + # ensure that the origin is a valid server name + parse_and_validate_server_name(origin) + + key = strip_quotes(param_dict["key"]) + sig = strip_quotes(param_dict["sig"]) + return origin, key, sig + except Exception as e: + logger.warn( + "Error parsing auth header '%s': %s", + header_bytes.decode('ascii', 'replace'), + e, + ) + raise AuthenticationError( + 400, "Malformed Authorization header", Codes.UNAUTHORIZED, + ) + class BaseFederationServlet(object): REQUIRE_AUTH = True @@ -177,7 +216,7 @@ class BaseFederationServlet(object): if self.REQUIRE_AUTH: logger.exception("authenticate_request failed") raise - except: + except Exception: logger.exception("authenticate_request failed") raise @@ -270,7 +309,7 @@ class FederationSendServlet(BaseFederationServlet): code, response = yield self.handler.on_incoming_transaction( transaction_data ) - except: + except Exception: logger.exception("on_incoming_transaction failed") raise @@ -347,7 +386,9 @@ class FederationMakeJoinServlet(BaseFederationServlet): @defer.inlineCallbacks def on_GET(self, origin, content, query, context, user_id): - content = yield self.handler.on_make_join_request(context, user_id) + content = yield self.handler.on_make_join_request( + origin, context, user_id, + ) defer.returnValue((200, content)) @@ -356,7 +397,9 @@ class FederationMakeLeaveServlet(BaseFederationServlet): @defer.inlineCallbacks def on_GET(self, origin, content, query, context, user_id): - content = yield self.handler.on_make_leave_request(context, user_id) + content = yield self.handler.on_make_leave_request( + origin, context, user_id, + ) defer.returnValue((200, content)) @@ -609,6 +652,549 @@ class FederationVersionServlet(BaseFederationServlet): })) +class FederationGroupsProfileServlet(BaseFederationServlet): + """Get/set the basic profile of a group on behalf of a user + """ + PATH = "/groups/(?P<group_id>[^/]*)/profile$" + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.get_group_profile( + group_id, requester_user_id + ) + + defer.returnValue((200, new_content)) + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.update_group_profile( + group_id, requester_user_id, content + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsSummaryServlet(BaseFederationServlet): + PATH = "/groups/(?P<group_id>[^/]*)/summary$" + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.get_group_summary( + group_id, requester_user_id + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsRoomsServlet(BaseFederationServlet): + """Get the rooms in a group on behalf of a user + """ + PATH = "/groups/(?P<group_id>[^/]*)/rooms$" + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.get_rooms_in_group( + group_id, requester_user_id + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsAddRoomsServlet(BaseFederationServlet): + """Add/remove room from group + """ + PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)$" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, room_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.add_room_to_group( + group_id, requester_user_id, room_id, content + ) + + defer.returnValue((200, new_content)) + + @defer.inlineCallbacks + def on_DELETE(self, origin, content, query, group_id, room_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.remove_room_from_group( + group_id, requester_user_id, room_id, + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): + """Update room config in group + """ + PATH = ( + "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)" + "/config/(?P<config_key>[^/]*)$" + ) + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, room_id, config_key): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + result = yield self.groups_handler.update_room_in_group( + group_id, requester_user_id, room_id, config_key, content, + ) + + defer.returnValue((200, result)) + + +class FederationGroupsUsersServlet(BaseFederationServlet): + """Get the users in a group on behalf of a user + """ + PATH = "/groups/(?P<group_id>[^/]*)/users$" + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.get_users_in_group( + group_id, requester_user_id + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsInvitedUsersServlet(BaseFederationServlet): + """Get the users that have been invited to a group + """ + PATH = "/groups/(?P<group_id>[^/]*)/invited_users$" + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.get_invited_users_in_group( + group_id, requester_user_id + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsInviteServlet(BaseFederationServlet): + """Ask a group server to invite someone to the group + """ + PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, user_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.invite_to_group( + group_id, user_id, requester_user_id, content, + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsAcceptInviteServlet(BaseFederationServlet): + """Accept an invitation from the group server + """ + PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite$" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, user_id): + if get_domain_from_id(user_id) != origin: + raise SynapseError(403, "user_id doesn't match origin") + + new_content = yield self.handler.accept_invite( + group_id, user_id, content, + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsJoinServlet(BaseFederationServlet): + """Attempt to join a group + """ + PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join$" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, user_id): + if get_domain_from_id(user_id) != origin: + raise SynapseError(403, "user_id doesn't match origin") + + new_content = yield self.handler.join_group( + group_id, user_id, content, + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsRemoveUserServlet(BaseFederationServlet): + """Leave or kick a user from the group + """ + PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, user_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.remove_user_from_group( + group_id, user_id, requester_user_id, content, + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsLocalInviteServlet(BaseFederationServlet): + """A group server has invited a local user + """ + PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, user_id): + if get_domain_from_id(group_id) != origin: + raise SynapseError(403, "group_id doesn't match origin") + + new_content = yield self.handler.on_invite( + group_id, user_id, content, + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): + """A group server has removed a local user + """ + PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, user_id): + if get_domain_from_id(group_id) != origin: + raise SynapseError(403, "user_id doesn't match origin") + + new_content = yield self.handler.user_removed_from_group( + group_id, user_id, content, + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): + """A group or user's server renews their attestation + """ + PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)$" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, user_id): + # We don't need to check auth here as we check the attestation signatures + + new_content = yield self.handler.on_renew_attestation( + group_id, user_id, content + ) + + defer.returnValue((200, new_content)) + + +class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): + """Add/remove a room from the group summary, with optional category. + + Matches both: + - /groups/:group/summary/rooms/:room_id + - /groups/:group/summary/categories/:category/rooms/:room_id + """ + PATH = ( + "/groups/(?P<group_id>[^/]*)/summary" + "(/categories/(?P<category_id>[^/]+))?" + "/rooms/(?P<room_id>[^/]*)$" + ) + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, category_id, room_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if category_id == "": + raise SynapseError(400, "category_id cannot be empty string") + + resp = yield self.handler.update_group_summary_room( + group_id, requester_user_id, + room_id=room_id, + category_id=category_id, + content=content, + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_DELETE(self, origin, content, query, group_id, category_id, room_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if category_id == "": + raise SynapseError(400, "category_id cannot be empty string") + + resp = yield self.handler.delete_group_summary_room( + group_id, requester_user_id, + room_id=room_id, + category_id=category_id, + ) + + defer.returnValue((200, resp)) + + +class FederationGroupsCategoriesServlet(BaseFederationServlet): + """Get all categories for a group + """ + PATH = ( + "/groups/(?P<group_id>[^/]*)/categories/$" + ) + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + resp = yield self.handler.get_group_categories( + group_id, requester_user_id, + ) + + defer.returnValue((200, resp)) + + +class FederationGroupsCategoryServlet(BaseFederationServlet): + """Add/remove/get a category in a group + """ + PATH = ( + "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$" + ) + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id, category_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + resp = yield self.handler.get_group_category( + group_id, requester_user_id, category_id + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, category_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if category_id == "": + raise SynapseError(400, "category_id cannot be empty string") + + resp = yield self.handler.upsert_group_category( + group_id, requester_user_id, category_id, content, + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_DELETE(self, origin, content, query, group_id, category_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if category_id == "": + raise SynapseError(400, "category_id cannot be empty string") + + resp = yield self.handler.delete_group_category( + group_id, requester_user_id, category_id, + ) + + defer.returnValue((200, resp)) + + +class FederationGroupsRolesServlet(BaseFederationServlet): + """Get roles in a group + """ + PATH = ( + "/groups/(?P<group_id>[^/]*)/roles/$" + ) + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + resp = yield self.handler.get_group_roles( + group_id, requester_user_id, + ) + + defer.returnValue((200, resp)) + + +class FederationGroupsRoleServlet(BaseFederationServlet): + """Add/remove/get a role in a group + """ + PATH = ( + "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$" + ) + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id, role_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + resp = yield self.handler.get_group_role( + group_id, requester_user_id, role_id + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, role_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty string") + + resp = yield self.handler.update_group_role( + group_id, requester_user_id, role_id, content, + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_DELETE(self, origin, content, query, group_id, role_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty string") + + resp = yield self.handler.delete_group_role( + group_id, requester_user_id, role_id, + ) + + defer.returnValue((200, resp)) + + +class FederationGroupsSummaryUsersServlet(BaseFederationServlet): + """Add/remove a user from the group summary, with optional role. + + Matches both: + - /groups/:group/summary/users/:user_id + - /groups/:group/summary/roles/:role/users/:user_id + """ + PATH = ( + "/groups/(?P<group_id>[^/]*)/summary" + "(/roles/(?P<role_id>[^/]+))?" + "/users/(?P<user_id>[^/]*)$" + ) + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, role_id, user_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty string") + + resp = yield self.handler.update_group_summary_user( + group_id, requester_user_id, + user_id=user_id, + role_id=role_id, + content=content, + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_DELETE(self, origin, content, query, group_id, role_id, user_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty string") + + resp = yield self.handler.delete_group_summary_user( + group_id, requester_user_id, + user_id=user_id, + role_id=role_id, + ) + + defer.returnValue((200, resp)) + + +class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): + """Get roles in a group + """ + PATH = ( + "/get_groups_publicised$" + ) + + @defer.inlineCallbacks + def on_POST(self, origin, content, query): + resp = yield self.handler.bulk_get_publicised_groups( + content["user_ids"], proxy=False, + ) + + defer.returnValue((200, resp)) + + +class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): + """Sets whether a group is joinable without an invite or knock + """ + PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy$" + + @defer.inlineCallbacks + def on_PUT(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.set_group_join_policy( + group_id, requester_user_id, content + ) + + defer.returnValue((200, new_content)) + + FEDERATION_SERVLET_CLASSES = ( FederationSendServlet, FederationPullServlet, @@ -635,15 +1221,49 @@ FEDERATION_SERVLET_CLASSES = ( FederationVersionServlet, ) + ROOM_LIST_CLASSES = ( PublicRoomList, ) +GROUP_SERVER_SERVLET_CLASSES = ( + FederationGroupsProfileServlet, + FederationGroupsSummaryServlet, + FederationGroupsRoomsServlet, + FederationGroupsUsersServlet, + FederationGroupsInvitedUsersServlet, + FederationGroupsInviteServlet, + FederationGroupsAcceptInviteServlet, + FederationGroupsJoinServlet, + FederationGroupsRemoveUserServlet, + FederationGroupsSummaryRoomsServlet, + FederationGroupsCategoriesServlet, + FederationGroupsCategoryServlet, + FederationGroupsRolesServlet, + FederationGroupsRoleServlet, + FederationGroupsSummaryUsersServlet, + FederationGroupsAddRoomsServlet, + FederationGroupsAddRoomsConfigServlet, + FederationGroupsSettingJoinPolicyServlet, +) + + +GROUP_LOCAL_SERVLET_CLASSES = ( + FederationGroupsLocalInviteServlet, + FederationGroupsRemoveLocalUserServlet, + FederationGroupsBulkPublicisedServlet, +) + + +GROUP_ATTESTATION_SERVLET_CLASSES = ( + FederationGroupsRenewAttestaionServlet, +) + def register_servlets(hs, resource, authenticator, ratelimiter): for servletclass in FEDERATION_SERVLET_CLASSES: servletclass( - handler=hs.get_replication_layer(), + handler=hs.get_federation_server(), authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, @@ -656,3 +1276,27 @@ def register_servlets(hs, resource, authenticator, ratelimiter): ratelimiter=ratelimiter, server_name=hs.hostname, ).register(resource) + + for servletclass in GROUP_SERVER_SERVLET_CLASSES: + servletclass( + handler=hs.get_groups_server_handler(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) + + for servletclass in GROUP_LOCAL_SERVLET_CLASSES: + servletclass( + handler=hs.get_groups_local_handler(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) + + for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES: + servletclass( + handler=hs.get_groups_attestation_renewer(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) diff --git a/synapse/federation/units.py b/synapse/federation/units.py index 3f645acc43..bb1b3b13f7 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -17,10 +17,9 @@ server protocol. """ -from synapse.util.jsonobject import JsonEncodedObject - import logging +from synapse.util.jsonobject import JsonEncodedObject logger = logging.getLogger(__name__) @@ -74,8 +73,6 @@ class Transaction(JsonEncodedObject): "previous_ids", "pdus", "edus", - "transaction_id", - "destination", "pdu_failures", ] diff --git a/synapse/groups/__init__.py b/synapse/groups/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/synapse/groups/__init__.py diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py new file mode 100644 index 0000000000..47452700a8 --- /dev/null +++ b/synapse/groups/attestations.py @@ -0,0 +1,198 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Attestations ensure that users and groups can't lie about their memberships. + +When a user joins a group the HS and GS swap attestations, which allow them +both to independently prove to third parties their membership.These +attestations have a validity period so need to be periodically renewed. + +If a user leaves (or gets kicked out of) a group, either side can still use +their attestation to "prove" their membership, until the attestation expires. +Therefore attestations shouldn't be relied on to prove membership in important +cases, but can for less important situtations, e.g. showing a users membership +of groups on their profile, showing flairs, etc. + +An attestation is a signed blob of json that looks like: + + { + "user_id": "@foo:a.example.com", + "group_id": "+bar:b.example.com", + "valid_until_ms": 1507994728530, + "signatures":{"matrix.org":{"ed25519:auto":"..."}} + } +""" + +import logging +import random + +from signedjson.sign import sign_json + +from twisted.internet import defer + +from synapse.api.errors import SynapseError +from synapse.types import get_domain_from_id +from synapse.util.logcontext import run_in_background + +logger = logging.getLogger(__name__) + + +# Default validity duration for new attestations we create +DEFAULT_ATTESTATION_LENGTH_MS = 3 * 24 * 60 * 60 * 1000 + +# We add some jitter to the validity duration of attestations so that if we +# add lots of users at once we don't need to renew them all at once. +# The jitter is a multiplier picked randomly between the first and second number +DEFAULT_ATTESTATION_JITTER = (0.9, 1.3) + +# Start trying to update our attestations when they come this close to expiring +UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000 + + +class GroupAttestationSigning(object): + """Creates and verifies group attestations. + """ + def __init__(self, hs): + self.keyring = hs.get_keyring() + self.clock = hs.get_clock() + self.server_name = hs.hostname + self.signing_key = hs.config.signing_key[0] + + @defer.inlineCallbacks + def verify_attestation(self, attestation, group_id, user_id, server_name=None): + """Verifies that the given attestation matches the given parameters. + + An optional server_name can be supplied to explicitly set which server's + signature is expected. Otherwise assumes that either the group_id or user_id + is local and uses the other's server as the one to check. + """ + + if not server_name: + if get_domain_from_id(group_id) == self.server_name: + server_name = get_domain_from_id(user_id) + elif get_domain_from_id(user_id) == self.server_name: + server_name = get_domain_from_id(group_id) + else: + raise Exception("Expected either group_id or user_id to be local") + + if user_id != attestation["user_id"]: + raise SynapseError(400, "Attestation has incorrect user_id") + + if group_id != attestation["group_id"]: + raise SynapseError(400, "Attestation has incorrect group_id") + valid_until_ms = attestation["valid_until_ms"] + + # TODO: We also want to check that *new* attestations that people give + # us to store are valid for at least a little while. + if valid_until_ms < self.clock.time_msec(): + raise SynapseError(400, "Attestation expired") + + yield self.keyring.verify_json_for_server(server_name, attestation) + + def create_attestation(self, group_id, user_id): + """Create an attestation for the group_id and user_id with default + validity length. + """ + validity_period = DEFAULT_ATTESTATION_LENGTH_MS + validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER) + valid_until_ms = int(self.clock.time_msec() + validity_period) + + return sign_json({ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": valid_until_ms, + }, self.server_name, self.signing_key) + + +class GroupAttestionRenewer(object): + """Responsible for sending and receiving attestation updates. + """ + + def __init__(self, hs): + self.clock = hs.get_clock() + self.store = hs.get_datastore() + self.assestations = hs.get_groups_attestation_signing() + self.transport_client = hs.get_federation_transport_client() + self.is_mine_id = hs.is_mine_id + self.attestations = hs.get_groups_attestation_signing() + + self._renew_attestations_loop = self.clock.looping_call( + self._renew_attestations, 30 * 60 * 1000, + ) + + @defer.inlineCallbacks + def on_renew_attestation(self, group_id, user_id, content): + """When a remote updates an attestation + """ + attestation = content["attestation"] + + if not self.is_mine_id(group_id) and not self.is_mine_id(user_id): + raise SynapseError(400, "Neither user not group are on this server") + + yield self.attestations.verify_attestation( + attestation, + user_id=user_id, + group_id=group_id, + ) + + yield self.store.update_remote_attestion(group_id, user_id, attestation) + + defer.returnValue({}) + + @defer.inlineCallbacks + def _renew_attestations(self): + """Called periodically to check if we need to update any of our attestations + """ + + now = self.clock.time_msec() + + rows = yield self.store.get_attestations_need_renewals( + now + UPDATE_ATTESTATION_TIME_MS + ) + + @defer.inlineCallbacks + def _renew_attestation(group_id, user_id): + try: + if not self.is_mine_id(group_id): + destination = get_domain_from_id(group_id) + elif not self.is_mine_id(user_id): + destination = get_domain_from_id(user_id) + else: + logger.warn( + "Incorrectly trying to do attestations for user: %r in %r", + user_id, group_id, + ) + yield self.store.remove_attestation_renewal(group_id, user_id) + return + + attestation = self.attestations.create_attestation(group_id, user_id) + + yield self.transport_client.renew_group_attestation( + destination, group_id, user_id, + content={"attestation": attestation}, + ) + + yield self.store.update_attestation_renewal( + group_id, user_id, attestation + ) + except Exception: + logger.exception("Error renewing attestation of %r in %r", + user_id, group_id) + + for row in rows: + group_id = row["group_id"] + user_id = row["user_id"] + + run_in_background(_renew_attestation, group_id, user_id) diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py new file mode 100644 index 0000000000..633c865ed8 --- /dev/null +++ b/synapse/groups/groups_server.py @@ -0,0 +1,953 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from six import string_types + +from twisted.internet import defer + +from synapse.api.errors import SynapseError +from synapse.types import GroupID, RoomID, UserID, get_domain_from_id + +logger = logging.getLogger(__name__) + + +# TODO: Allow users to "knock" or simpkly join depending on rules +# TODO: Federation admin APIs +# TODO: is_priveged flag to users and is_public to users and rooms +# TODO: Audit log for admins (profile updates, membership changes, users who tried +# to join but were rejected, etc) +# TODO: Flairs + + +class GroupsServerHandler(object): + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + self.room_list_handler = hs.get_room_list_handler() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.keyring = hs.get_keyring() + self.is_mine_id = hs.is_mine_id + self.signing_key = hs.config.signing_key[0] + self.server_name = hs.hostname + self.attestations = hs.get_groups_attestation_signing() + self.transport_client = hs.get_federation_transport_client() + self.profile_handler = hs.get_profile_handler() + + # Ensure attestations get renewed + hs.get_groups_attestation_renewer() + + @defer.inlineCallbacks + def check_group_is_ours(self, group_id, requester_user_id, + and_exists=False, and_is_admin=None): + """Check that the group is ours, and optionally if it exists. + + If group does exist then return group. + + Args: + group_id (str) + and_exists (bool): whether to also check if group exists + and_is_admin (str): whether to also check if given str is a user_id + that is an admin + """ + if not self.is_mine_id(group_id): + raise SynapseError(400, "Group not on this server") + + group = yield self.store.get_group(group_id) + if and_exists and not group: + raise SynapseError(404, "Unknown group") + + is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + if group and not is_user_in_group and not group["is_public"]: + raise SynapseError(404, "Unknown group") + + if and_is_admin: + is_admin = yield self.store.is_user_admin_in_group(group_id, and_is_admin) + if not is_admin: + raise SynapseError(403, "User is not admin in group") + + defer.returnValue(group) + + @defer.inlineCallbacks + def get_group_summary(self, group_id, requester_user_id): + """Get the summary for a group as seen by requester_user_id. + + The group summary consists of the profile of the room, and a curated + list of users and rooms. These list *may* be organised by role/category. + The roles/categories are ordered, and so are the users/rooms within them. + + A user/room may appear in multiple roles/categories. + """ + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + + profile = yield self.get_group_profile(group_id, requester_user_id) + + users, roles = yield self.store.get_users_for_summary_by_role( + group_id, include_private=is_user_in_group, + ) + + # TODO: Add profiles to users + + rooms, categories = yield self.store.get_rooms_for_summary_by_category( + group_id, include_private=is_user_in_group, + ) + + for room_entry in rooms: + room_id = room_entry["room_id"] + joined_users = yield self.store.get_users_in_room(room_id) + entry = yield self.room_list_handler.generate_room_entry( + room_id, len(joined_users), + with_alias=False, allow_private=True, + ) + entry = dict(entry) # so we don't change whats cached + entry.pop("room_id", None) + + room_entry["profile"] = entry + + rooms.sort(key=lambda e: e.get("order", 0)) + + for entry in users: + user_id = entry["user_id"] + + if not self.is_mine_id(requester_user_id): + attestation = yield self.store.get_remote_attestation(group_id, user_id) + if not attestation: + continue + + entry["attestation"] = attestation + else: + entry["attestation"] = self.attestations.create_attestation( + group_id, user_id, + ) + + user_profile = yield self.profile_handler.get_profile_from_cache(user_id) + entry.update(user_profile) + + users.sort(key=lambda e: e.get("order", 0)) + + membership_info = yield self.store.get_users_membership_info_in_group( + group_id, requester_user_id, + ) + + defer.returnValue({ + "profile": profile, + "users_section": { + "users": users, + "roles": roles, + "total_user_count_estimate": 0, # TODO + }, + "rooms_section": { + "rooms": rooms, + "categories": categories, + "total_room_count_estimate": 0, # TODO + }, + "user": membership_info, + }) + + @defer.inlineCallbacks + def update_group_summary_room(self, group_id, requester_user_id, + room_id, category_id, content): + """Add/update a room to the group summary + """ + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) + + RoomID.from_string(room_id) # Ensure valid room id + + order = content.get("order", None) + + is_public = _parse_visibility_from_contents(content) + + yield self.store.add_room_to_summary( + group_id=group_id, + room_id=room_id, + category_id=category_id, + order=order, + is_public=is_public, + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def delete_group_summary_room(self, group_id, requester_user_id, + room_id, category_id): + """Remove a room from the summary + """ + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) + + yield self.store.remove_room_from_summary( + group_id=group_id, + room_id=room_id, + category_id=category_id, + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def set_group_join_policy(self, group_id, requester_user_id, content): + """Sets the group join policy. + + Currently supported policies are: + - "invite": an invite must be received and accepted in order to join. + - "open": anyone can join. + """ + yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + join_policy = _parse_join_policy_from_contents(content) + if join_policy is None: + raise SynapseError( + 400, "No value specified for 'm.join_policy'" + ) + + yield self.store.set_group_join_policy(group_id, join_policy=join_policy) + + defer.returnValue({}) + + @defer.inlineCallbacks + def get_group_categories(self, group_id, requester_user_id): + """Get all categories in a group (as seen by user) + """ + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + categories = yield self.store.get_group_categories( + group_id=group_id, + ) + defer.returnValue({"categories": categories}) + + @defer.inlineCallbacks + def get_group_category(self, group_id, requester_user_id, category_id): + """Get a specific category in a group (as seen by user) + """ + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + res = yield self.store.get_group_category( + group_id=group_id, + category_id=category_id, + ) + + defer.returnValue(res) + + @defer.inlineCallbacks + def update_group_category(self, group_id, requester_user_id, category_id, content): + """Add/Update a group category + """ + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) + + is_public = _parse_visibility_from_contents(content) + profile = content.get("profile") + + yield self.store.upsert_group_category( + group_id=group_id, + category_id=category_id, + is_public=is_public, + profile=profile, + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def delete_group_category(self, group_id, requester_user_id, category_id): + """Delete a group category + """ + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id + ) + + yield self.store.remove_group_category( + group_id=group_id, + category_id=category_id, + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def get_group_roles(self, group_id, requester_user_id): + """Get all roles in a group (as seen by user) + """ + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + roles = yield self.store.get_group_roles( + group_id=group_id, + ) + defer.returnValue({"roles": roles}) + + @defer.inlineCallbacks + def get_group_role(self, group_id, requester_user_id, role_id): + """Get a specific role in a group (as seen by user) + """ + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + res = yield self.store.get_group_role( + group_id=group_id, + role_id=role_id, + ) + defer.returnValue(res) + + @defer.inlineCallbacks + def update_group_role(self, group_id, requester_user_id, role_id, content): + """Add/update a role in a group + """ + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) + + is_public = _parse_visibility_from_contents(content) + + profile = content.get("profile") + + yield self.store.upsert_group_role( + group_id=group_id, + role_id=role_id, + is_public=is_public, + profile=profile, + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def delete_group_role(self, group_id, requester_user_id, role_id): + """Remove role from group + """ + yield self.check_group_is_ours( + group_id, + requester_user_id, + and_exists=True, + and_is_admin=requester_user_id, + ) + + yield self.store.remove_group_role( + group_id=group_id, + role_id=role_id, + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def update_group_summary_user(self, group_id, requester_user_id, user_id, role_id, + content): + """Add/update a users entry in the group summary + """ + yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id, + ) + + order = content.get("order", None) + + is_public = _parse_visibility_from_contents(content) + + yield self.store.add_user_to_summary( + group_id=group_id, + user_id=user_id, + role_id=role_id, + order=order, + is_public=is_public, + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def delete_group_summary_user(self, group_id, requester_user_id, user_id, role_id): + """Remove a user from the group summary + """ + yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id, + ) + + yield self.store.remove_user_from_summary( + group_id=group_id, + user_id=user_id, + role_id=role_id, + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def get_group_profile(self, group_id, requester_user_id): + """Get the group profile as seen by requester_user_id + """ + + yield self.check_group_is_ours(group_id, requester_user_id) + + group = yield self.store.get_group(group_id) + + if group: + cols = [ + "name", "short_description", "long_description", + "avatar_url", "is_public", + ] + group_description = {key: group[key] for key in cols} + group_description["is_openly_joinable"] = group["join_policy"] == "open" + + defer.returnValue(group_description) + else: + raise SynapseError(404, "Unknown group") + + @defer.inlineCallbacks + def update_group_profile(self, group_id, requester_user_id, content): + """Update the group profile + """ + yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id, + ) + + profile = {} + for keyname in ("name", "avatar_url", "short_description", + "long_description"): + if keyname in content: + value = content[keyname] + if not isinstance(value, string_types): + raise SynapseError(400, "%r value is not a string" % (keyname,)) + profile[keyname] = value + + yield self.store.update_group_profile(group_id, profile) + + @defer.inlineCallbacks + def get_users_in_group(self, group_id, requester_user_id): + """Get the users in group as seen by requester_user_id. + + The ordering is arbitrary at the moment + """ + + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + + user_results = yield self.store.get_users_in_group( + group_id, include_private=is_user_in_group, + ) + + chunk = [] + for user_result in user_results: + g_user_id = user_result["user_id"] + is_public = user_result["is_public"] + is_privileged = user_result["is_admin"] + + entry = {"user_id": g_user_id} + + profile = yield self.profile_handler.get_profile_from_cache(g_user_id) + entry.update(profile) + + entry["is_public"] = bool(is_public) + entry["is_privileged"] = bool(is_privileged) + + if not self.is_mine_id(g_user_id): + attestation = yield self.store.get_remote_attestation(group_id, g_user_id) + if not attestation: + continue + + entry["attestation"] = attestation + else: + entry["attestation"] = self.attestations.create_attestation( + group_id, g_user_id, + ) + + chunk.append(entry) + + # TODO: If admin add lists of users whose attestations have timed out + + defer.returnValue({ + "chunk": chunk, + "total_user_count_estimate": len(user_results), + }) + + @defer.inlineCallbacks + def get_invited_users_in_group(self, group_id, requester_user_id): + """Get the users that have been invited to a group as seen by requester_user_id. + + The ordering is arbitrary at the moment + """ + + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + + if not is_user_in_group: + raise SynapseError(403, "User not in group") + + invited_users = yield self.store.get_invited_users_in_group(group_id) + + user_profiles = [] + + for user_id in invited_users: + user_profile = { + "user_id": user_id + } + try: + profile = yield self.profile_handler.get_profile_from_cache(user_id) + user_profile.update(profile) + except Exception as e: + logger.warn("Error getting profile for %s: %s", user_id, e) + user_profiles.append(user_profile) + + defer.returnValue({ + "chunk": user_profiles, + "total_user_count_estimate": len(invited_users), + }) + + @defer.inlineCallbacks + def get_rooms_in_group(self, group_id, requester_user_id): + """Get the rooms in group as seen by requester_user_id + + This returns rooms in order of decreasing number of joined users + """ + + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + + room_results = yield self.store.get_rooms_in_group( + group_id, include_private=is_user_in_group, + ) + + chunk = [] + for room_result in room_results: + room_id = room_result["room_id"] + + joined_users = yield self.store.get_users_in_room(room_id) + entry = yield self.room_list_handler.generate_room_entry( + room_id, len(joined_users), + with_alias=False, allow_private=True, + ) + + if not entry: + continue + + entry["is_public"] = bool(room_result["is_public"]) + + chunk.append(entry) + + chunk.sort(key=lambda e: -e["num_joined_members"]) + + defer.returnValue({ + "chunk": chunk, + "total_room_count_estimate": len(room_results), + }) + + @defer.inlineCallbacks + def add_room_to_group(self, group_id, requester_user_id, room_id, content): + """Add room to group + """ + RoomID.from_string(room_id) # Ensure valid room id + + yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + is_public = _parse_visibility_from_contents(content) + + yield self.store.add_room_to_group(group_id, room_id, is_public=is_public) + + defer.returnValue({}) + + @defer.inlineCallbacks + def update_room_in_group(self, group_id, requester_user_id, room_id, config_key, + content): + """Update room in group + """ + RoomID.from_string(room_id) # Ensure valid room id + + yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + if config_key == "m.visibility": + is_public = _parse_visibility_dict(content) + + yield self.store.update_room_in_group_visibility( + group_id, room_id, + is_public=is_public, + ) + else: + raise SynapseError(400, "Uknown config option") + + defer.returnValue({}) + + @defer.inlineCallbacks + def remove_room_from_group(self, group_id, requester_user_id, room_id): + """Remove room from group + """ + yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + yield self.store.remove_room_from_group(group_id, room_id) + + defer.returnValue({}) + + @defer.inlineCallbacks + def invite_to_group(self, group_id, user_id, requester_user_id, content): + """Invite user to group + """ + + group = yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + # TODO: Check if user knocked + # TODO: Check if user is already invited + + content = { + "profile": { + "name": group["name"], + "avatar_url": group["avatar_url"], + }, + "inviter": requester_user_id, + } + + if self.hs.is_mine_id(user_id): + groups_local = self.hs.get_groups_local_handler() + res = yield groups_local.on_invite(group_id, user_id, content) + local_attestation = None + else: + local_attestation = self.attestations.create_attestation(group_id, user_id) + content.update({ + "attestation": local_attestation, + }) + + res = yield self.transport_client.invite_to_group_notification( + get_domain_from_id(user_id), group_id, user_id, content + ) + + user_profile = res.get("user_profile", {}) + yield self.store.add_remote_profile_cache( + user_id, + displayname=user_profile.get("displayname"), + avatar_url=user_profile.get("avatar_url"), + ) + + if res["state"] == "join": + if not self.hs.is_mine_id(user_id): + remote_attestation = res["attestation"] + + yield self.attestations.verify_attestation( + remote_attestation, + user_id=user_id, + group_id=group_id, + ) + else: + remote_attestation = None + + yield self.store.add_user_to_group( + group_id, user_id, + is_admin=False, + is_public=False, # TODO + local_attestation=local_attestation, + remote_attestation=remote_attestation, + ) + elif res["state"] == "invite": + yield self.store.add_group_invite( + group_id, user_id, + ) + defer.returnValue({ + "state": "invite" + }) + elif res["state"] == "reject": + defer.returnValue({ + "state": "reject" + }) + else: + raise SynapseError(502, "Unknown state returned by HS") + + @defer.inlineCallbacks + def _add_user(self, group_id, user_id, content): + """Add a user to a group based on a content dict. + + See accept_invite, join_group. + """ + if not self.hs.is_mine_id(user_id): + local_attestation = self.attestations.create_attestation( + group_id, user_id, + ) + + remote_attestation = content["attestation"] + + yield self.attestations.verify_attestation( + remote_attestation, + user_id=user_id, + group_id=group_id, + ) + else: + local_attestation = None + remote_attestation = None + + is_public = _parse_visibility_from_contents(content) + + yield self.store.add_user_to_group( + group_id, user_id, + is_admin=False, + is_public=is_public, + local_attestation=local_attestation, + remote_attestation=remote_attestation, + ) + + defer.returnValue(local_attestation) + + @defer.inlineCallbacks + def accept_invite(self, group_id, requester_user_id, content): + """User tries to accept an invite to the group. + + This is different from them asking to join, and so should error if no + invite exists (and they're not a member of the group) + """ + + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_invited = yield self.store.is_user_invited_to_local_group( + group_id, requester_user_id, + ) + if not is_invited: + raise SynapseError(403, "User not invited to group") + + local_attestation = yield self._add_user(group_id, requester_user_id, content) + + defer.returnValue({ + "state": "join", + "attestation": local_attestation, + }) + + @defer.inlineCallbacks + def join_group(self, group_id, requester_user_id, content): + """User tries to join the group. + + This will error if the group requires an invite/knock to join + """ + + group_info = yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True + ) + if group_info['join_policy'] != "open": + raise SynapseError(403, "Group is not publicly joinable") + + local_attestation = yield self._add_user(group_id, requester_user_id, content) + + defer.returnValue({ + "state": "join", + "attestation": local_attestation, + }) + + @defer.inlineCallbacks + def knock(self, group_id, requester_user_id, content): + """A user requests becoming a member of the group + """ + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + raise NotImplementedError() + + @defer.inlineCallbacks + def accept_knock(self, group_id, requester_user_id, content): + """Accept a users knock to the room. + + Errors if the user hasn't knocked, rather than inviting them. + """ + + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + raise NotImplementedError() + + @defer.inlineCallbacks + def remove_user_from_group(self, group_id, user_id, requester_user_id, content): + """Remove a user from the group; either a user is leaving or an admin + kicked them. + """ + + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_kick = False + if requester_user_id != user_id: + is_admin = yield self.store.is_user_admin_in_group( + group_id, requester_user_id + ) + if not is_admin: + raise SynapseError(403, "User is not admin in group") + + is_kick = True + + yield self.store.remove_user_from_group( + group_id, user_id, + ) + + if is_kick: + if self.hs.is_mine_id(user_id): + groups_local = self.hs.get_groups_local_handler() + yield groups_local.user_removed_from_group(group_id, user_id, {}) + else: + yield self.transport_client.remove_user_from_group_notification( + get_domain_from_id(user_id), group_id, user_id, {} + ) + + if not self.hs.is_mine_id(user_id): + yield self.store.maybe_delete_remote_profile_cache(user_id) + + defer.returnValue({}) + + @defer.inlineCallbacks + def create_group(self, group_id, requester_user_id, content): + group = yield self.check_group_is_ours(group_id, requester_user_id) + + logger.info("Attempting to create group with ID: %r", group_id) + + # parsing the id into a GroupID validates it. + group_id_obj = GroupID.from_string(group_id) + + if group: + raise SynapseError(400, "Group already exists") + + is_admin = yield self.auth.is_server_admin(UserID.from_string(requester_user_id)) + if not is_admin: + if not self.hs.config.enable_group_creation: + raise SynapseError( + 403, "Only a server admin can create groups on this server", + ) + localpart = group_id_obj.localpart + if not localpart.startswith(self.hs.config.group_creation_prefix): + raise SynapseError( + 400, + "Can only create groups with prefix %r on this server" % ( + self.hs.config.group_creation_prefix, + ), + ) + + profile = content.get("profile", {}) + name = profile.get("name") + avatar_url = profile.get("avatar_url") + short_description = profile.get("short_description") + long_description = profile.get("long_description") + user_profile = content.get("user_profile", {}) + + yield self.store.create_group( + group_id, + requester_user_id, + name=name, + avatar_url=avatar_url, + short_description=short_description, + long_description=long_description, + ) + + if not self.hs.is_mine_id(requester_user_id): + remote_attestation = content["attestation"] + + yield self.attestations.verify_attestation( + remote_attestation, + user_id=requester_user_id, + group_id=group_id, + ) + + local_attestation = self.attestations.create_attestation( + group_id, + requester_user_id, + ) + else: + local_attestation = None + remote_attestation = None + + yield self.store.add_user_to_group( + group_id, requester_user_id, + is_admin=True, + is_public=True, # TODO + local_attestation=local_attestation, + remote_attestation=remote_attestation, + ) + + if not self.hs.is_mine_id(requester_user_id): + yield self.store.add_remote_profile_cache( + requester_user_id, + displayname=user_profile.get("displayname"), + avatar_url=user_profile.get("avatar_url"), + ) + + defer.returnValue({ + "group_id": group_id, + }) + + +def _parse_join_policy_from_contents(content): + """Given a content for a request, return the specified join policy or None + """ + + join_policy_dict = content.get("m.join_policy") + if join_policy_dict: + return _parse_join_policy_dict(join_policy_dict) + else: + return None + + +def _parse_join_policy_dict(join_policy_dict): + """Given a dict for the "m.join_policy" config return the join policy specified + """ + join_policy_type = join_policy_dict.get("type") + if not join_policy_type: + return "invite" + + if join_policy_type not in ("invite", "open"): + raise SynapseError( + 400, "Synapse only supports 'invite'/'open' join rule" + ) + return join_policy_type + + +def _parse_visibility_from_contents(content): + """Given a content for a request parse out whether the entity should be + public or not + """ + + visibility = content.get("m.visibility") + if visibility: + return _parse_visibility_dict(visibility) + else: + is_public = True + + return is_public + + +def _parse_visibility_dict(visibility): + """Given a dict for the "m.visibility" config return if the entity should + be public or not + """ + vis_type = visibility.get("type") + if not vis_type: + return True + + if vis_type not in ("public", "private"): + raise SynapseError( + 400, "Synapse only supports 'public'/'private' visibility" + ) + return vis_type == "public" diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 5ad408f549..413425fed1 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -13,17 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .register import RegistrationHandler -from .room import ( - RoomCreationHandler, RoomContextHandler, -) -from .room_member import RoomMemberHandler -from .message import MessageHandler -from .federation import FederationHandler -from .profile import ProfileHandler -from .directory import DirectoryHandler from .admin import AdminHandler +from .directory import DirectoryHandler +from .federation import FederationHandler from .identity import IdentityHandler +from .register import RegistrationHandler from .search import SearchHandler @@ -48,13 +42,8 @@ class Handlers(object): def __init__(self, hs): self.registration_handler = RegistrationHandler(hs) - self.message_handler = MessageHandler(hs) - self.room_creation_handler = RoomCreationHandler(hs) - self.room_member_handler = RoomMemberHandler(hs) self.federation_handler = FederationHandler(hs) - self.profile_handler = ProfileHandler(hs) self.directory_handler = DirectoryHandler(hs) self.admin_handler = AdminHandler(hs) self.identity_handler = IdentityHandler(hs) self.search_handler = SearchHandler(hs) - self.room_context_handler = RoomContextHandler(hs) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index faa5609c0c..704181d2d3 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -18,11 +18,10 @@ import logging from twisted.internet import defer import synapse.types -from synapse.api.constants import Membership, EventTypes +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import LimitExceededError from synapse.types import UserID - logger = logging.getLogger(__name__) @@ -113,15 +112,16 @@ class BaseHandler(object): guest_access = event.content.get("guest_access", "forbidden") if guest_access != "can_join": if context: + current_state_ids = yield context.get_current_state_ids(self.store) current_state = yield self.store.get_events( - context.current_state_ids.values() + list(current_state_ids.values()) ) else: current_state = yield self.state_handler.get_current_state( event.room_id ) - current_state = current_state.values() + current_state = list(current_state.values()) logger.info("maybe_kick_guest_users %r", current_state) yield self.kick_guest_users(current_state) @@ -158,7 +158,7 @@ class BaseHandler(object): # homeserver. requester = synapse.types.create_requester( target_user, is_guest=True) - handler = self.hs.get_handlers().room_member_handler + handler = self.hs.get_room_member_handler() yield handler.update_membership( requester, target_user, diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index f36b358b45..5d629126fc 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -13,12 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer from ._base import BaseHandler -import logging - logger = logging.getLogger(__name__) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 05af54d31b..ee41aed69e 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -13,16 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + +from six import itervalues + +from prometheus_client import Counter + from twisted.internet import defer +import synapse from synapse.api.constants import EventTypes +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.metrics import Measure -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred - -import logging logger = logging.getLogger(__name__) +events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "") + def log_failure(failure): logger.error( @@ -70,21 +78,25 @@ class ApplicationServicesHandler(object): with Measure(self.clock, "notify_interested_services"): self.is_processing = True try: - upper_bound = self.current_max limit = 100 while True: upper_bound, events = yield self.store.get_new_events_for_appservice( - upper_bound, limit + self.current_max, limit ) if not events: break + events_by_room = {} for event in events: + events_by_room.setdefault(event.room_id, []).append(event) + + @defer.inlineCallbacks + def handle_event(event): # Gather interested services services = yield self._get_services_for_event(event) if len(services) == 0: - continue # no services need notifying + return # no services need notifying # Do we know this user exists? If not, poke the user # query API for all services which match that user regex. @@ -95,19 +107,39 @@ class ApplicationServicesHandler(object): yield self._check_user_exists(event.state_key) if not self.started_scheduler: - self.scheduler.start().addErrback(log_failure) + def start_scheduler(): + return self.scheduler.start().addErrback(log_failure) + run_as_background_process("as_scheduler", start_scheduler) self.started_scheduler = True # Fork off pushes to these services for service in services: - preserve_fn(self.scheduler.submit_event_for_as)( - service, event - ) + self.scheduler.submit_event_for_as(service, event) + + @defer.inlineCallbacks + def handle_room_events(events): + for event in events: + yield handle_event(event) + + yield make_deferred_yieldable(defer.gatherResults([ + run_in_background(handle_room_events, evs) + for evs in itervalues(events_by_room) + ], consumeErrors=True)) yield self.store.set_appservice_last_pos(upper_bound) - if len(events) < limit: - break + now = self.clock.time_msec() + ts = yield self.store.get_received_ts(events[-1].event_id) + + synapse.metrics.event_processing_positions.labels( + "appservice_sender").set(upper_bound) + + events_processed_counter.inc(len(events)) + + synapse.metrics.event_processing_lag.labels( + "appservice_sender").set(now - ts) + synapse.metrics.event_processing_last_ts.labels( + "appservice_sender").set(ts) finally: self.is_processing = False @@ -163,8 +195,11 @@ class ApplicationServicesHandler(object): def query_3pe(self, kind, protocol, fields): services = yield self._get_services_for_3pn(protocol) - results = yield preserve_context_over_deferred(defer.DeferredList([ - preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields) + results = yield make_deferred_yieldable(defer.DeferredList([ + run_in_background( + self.appservice_api.query_3pe, + service, kind, protocol, fields, + ) for service in services ], consumeErrors=True)) @@ -225,11 +260,15 @@ class ApplicationServicesHandler(object): event based on the service regex. """ services = self.store.get_app_services() - interested_list = [ - s for s in services if ( - yield s.is_interested(event, self.store) - ) - ] + + # we can't use a list comprehension here. Since python 3, list + # comprehensions use a generator internally. This means you can't yield + # inside of a list comprehension anymore. + interested_list = [] + for s in services: + if (yield s.is_interested(event, self.store)): + interested_list.append(s) + defer.returnValue(interested_list) def _get_services_for_user(self, user_id): diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index b00446bec0..402e44cdef 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -14,24 +14,32 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - -from ._base import BaseHandler -from synapse.api.constants import LoginType -from synapse.types import UserID -from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError -from synapse.util.async import run_on_reactor -from synapse.util.caches.expiringcache import ExpiringCache - -from twisted.web.client import PartialDownloadError - import logging + +import attr import bcrypt import pymacaroons -import simplejson +from canonicaljson import json + +from twisted.internet import defer, threads +from twisted.web.client import PartialDownloadError import synapse.util.stringutils as stringutils +from synapse.api.constants import LoginType +from synapse.api.errors import ( + AuthError, + Codes, + InteractiveAuthIncompleteError, + LoginError, + StoreError, + SynapseError, +) +from synapse.module_api import ModuleApi +from synapse.types import UserID +from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.logcontext import make_deferred_yieldable +from ._base import BaseHandler logger = logging.getLogger(__name__) @@ -46,7 +54,6 @@ class AuthHandler(BaseHandler): """ super(AuthHandler, self).__init__(hs) self.checkers = { - LoginType.PASSWORD: self._check_password_auth, LoginType.RECAPTCHA: self._check_recaptcha, LoginType.EMAIL_IDENTITY: self._check_email_identity, LoginType.MSISDN: self._check_msisdn, @@ -63,10 +70,7 @@ class AuthHandler(BaseHandler): reset_expiry_on_get=True, ) - account_handler = _AccountHandler( - hs, check_user_exists=self.check_user_exists - ) - + account_handler = ModuleApi(hs, self) self.password_providers = [ module(config=config, account_handler=account_handler) for module, config in hs.config.password_providers @@ -75,39 +79,120 @@ class AuthHandler(BaseHandler): logger.info("Extra password_providers: %r", self.password_providers) self.hs = hs # FIXME better possibility to access registrationHandler later? - self.device_handler = hs.get_device_handler() self.macaroon_gen = hs.get_macaroon_generator() + self._password_enabled = hs.config.password_enabled + + # we keep this as a list despite the O(N^2) implication so that we can + # keep PASSWORD first and avoid confusing clients which pick the first + # type in the list. (NB that the spec doesn't require us to do so and + # clients which favour types that they don't understand over those that + # they do are technically broken) + login_types = [] + if self._password_enabled: + login_types.append(LoginType.PASSWORD) + for provider in self.password_providers: + if hasattr(provider, "get_supported_login_types"): + for t in provider.get_supported_login_types().keys(): + if t not in login_types: + login_types.append(t) + self._supported_login_types = login_types + + @defer.inlineCallbacks + def validate_user_via_ui_auth(self, requester, request_body, clientip): + """ + Checks that the user is who they claim to be, via a UI auth. + + This is used for things like device deletion and password reset where + the user already has a valid access token, but we want to double-check + that it isn't stolen by re-authenticating them. + + Args: + requester (Requester): The user, as given by the access token + + request_body (dict): The body of the request sent by the client + + clientip (str): The IP address of the client. + + Returns: + defer.Deferred[dict]: the parameters for this request (which may + have been given only in a previous call). + + Raises: + InteractiveAuthIncompleteError if the client has not yet completed + any of the permitted login flows + + AuthError if the client has completed a login flow, and it gives + a different user to `requester` + """ + + # build a list of supported flows + flows = [ + [login_type] for login_type in self._supported_login_types + ] + + result, params, _ = yield self.check_auth( + flows, request_body, clientip, + ) + + # find the completed login type + for login_type in self._supported_login_types: + if login_type not in result: + continue + + user_id = result[login_type] + break + else: + # this can't happen + raise Exception( + "check_auth returned True but no successful login type", + ) + + # check that the UI auth matched the access token + if user_id != requester.user.to_string(): + raise AuthError(403, "Invalid auth") + + defer.returnValue(params) @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): """ Takes a dictionary sent by the client in the login / registration - protocol and handles the login flow. + protocol and handles the User-Interactive Auth flow. As a side effect, this function fills in the 'creds' key on the user's session with a map, which maps each auth-type (str) to the relevant identity authenticated by that auth-type (mostly str, but for captcha, bool). + If no auth flows have been completed successfully, raises an + InteractiveAuthIncompleteError. To handle this, you can use + synapse.rest.client.v2_alpha._base.interactive_auth_handler as a + decorator. + Args: flows (list): A list of login flows. Each flow is an ordered list of strings representing auth-types. At least one full flow must be completed in order for auth to be successful. + clientdict: The dictionary from the client root level, not the 'auth' key: this method prompts for auth if none is sent. + clientip (str): The IP address of the client. + Returns: - A tuple of (authed, dict, dict, session_id) where authed is true if - the client has successfully completed an auth flow. If it is true - the first dict contains the authenticated credentials of each stage. + defer.Deferred[dict, dict, str]: a deferred tuple of + (creds, params, session_id). - If authed is false, the first dictionary is the server response to - the login request and should be passed back to the client. + 'creds' contains the authenticated credentials of each stage. - In either case, the second dict contains the parameters for this - request (which may have been given only in a previous call). + 'params' contains the parameters for this request (which may + have been given only in a previous call). - session_id is the ID of this session, either passed in by the client - or assigned by the call to check_auth + 'session_id' is the ID of this session, either passed in by the + client or assigned by this call + + Raises: + InteractiveAuthIncompleteError if the client has not yet completed + all the stages in any of the permitted flows. """ authdict = None @@ -135,11 +220,8 @@ class AuthHandler(BaseHandler): clientdict = session['clientdict'] if not authdict: - defer.returnValue( - ( - False, self._auth_dict_for_flows(flows, session), - clientdict, session['id'] - ) + raise InteractiveAuthIncompleteError( + self._auth_dict_for_flows(flows, session), ) if 'creds' not in session: @@ -150,14 +232,12 @@ class AuthHandler(BaseHandler): errordict = {} if 'type' in authdict: login_type = authdict['type'] - if login_type not in self.checkers: - raise LoginError(400, "", Codes.UNRECOGNIZED) try: - result = yield self.checkers[login_type](authdict, clientip) + result = yield self._check_auth_dict(authdict, clientip) if result: creds[login_type] = result self._save_session(session) - except LoginError, e: + except LoginError as e: if login_type == LoginType.EMAIL_IDENTITY: # riot used to have a bug where it would request a new # validation token (thus sending a new email) each time it @@ -166,14 +246,14 @@ class AuthHandler(BaseHandler): # # Grandfather in the old behaviour for now to avoid # breaking old riot deployments. - raise e + raise # this step failed. Merge the error dict into the response # so that the client can have another go. errordict = e.error_dict() for f in flows: - if len(set(f) - set(creds.keys())) == 0: + if len(set(f) - set(creds)) == 0: # it's very useful to know what args are stored, but this can # include the password in the case of registering, so only log # the keys (confusingly, clientdict may contain a password @@ -181,14 +261,16 @@ class AuthHandler(BaseHandler): # and is not sensitive). logger.info( "Auth completed with creds: %r. Client dict has keys: %r", - creds, clientdict.keys() + creds, list(clientdict) ) - defer.returnValue((True, creds, clientdict, session['id'])) + defer.returnValue((creds, clientdict, session['id'])) ret = self._auth_dict_for_flows(flows, session) - ret['completed'] = creds.keys() + ret['completed'] = list(creds) ret.update(errordict) - defer.returnValue((False, ret, clientdict, session['id'])) + raise InteractiveAuthIncompleteError( + ret, + ) @defer.inlineCallbacks def add_oob_auth(self, stagetype, authdict, clientip): @@ -260,16 +342,37 @@ class AuthHandler(BaseHandler): sess = self._get_session_info(session_id) return sess.setdefault('serverdict', {}).get(key, default) - def _check_password_auth(self, authdict, _): - if "user" not in authdict or "password" not in authdict: - raise LoginError(400, "", Codes.MISSING_PARAM) + @defer.inlineCallbacks + def _check_auth_dict(self, authdict, clientip): + """Attempt to validate the auth dict provided by a client - user_id = authdict["user"] - password = authdict["password"] - if not user_id.startswith('@'): - user_id = UserID.create(user_id, self.hs.hostname).to_string() + Args: + authdict (object): auth dict provided by the client + clientip (str): IP address of the client + + Returns: + Deferred: result of the stage verification. + + Raises: + StoreError if there was a problem accessing the database + SynapseError if there was a problem with the request + LoginError if there was an authentication problem. + """ + login_type = authdict['type'] + checker = self.checkers.get(login_type) + if checker is not None: + res = yield checker(authdict, clientip) + defer.returnValue(res) + + # build a v1-login-style dict out of the authdict and fall back to the + # v1 code + user_id = authdict.get("user") - return self._check_password(user_id, password) + if user_id is None: + raise SynapseError(400, "", Codes.MISSING_PARAM) + + (canonical_id, callback) = yield self.validate_login(user_id, authdict) + defer.returnValue(canonical_id) @defer.inlineCallbacks def _check_recaptcha(self, authdict, clientip): @@ -303,7 +406,7 @@ class AuthHandler(BaseHandler): except PartialDownloadError as pde: # Twisted is silly data = pde.response - resp_body = simplejson.loads(data) + resp_body = json.loads(data) if 'success' in resp_body: # Note that we do NOT check the hostname here: we explicitly @@ -324,15 +427,11 @@ class AuthHandler(BaseHandler): def _check_msisdn(self, authdict, _): return self._check_threepid('msisdn', authdict) - @defer.inlineCallbacks def _check_dummy_auth(self, authdict, _): - yield run_on_reactor() - defer.returnValue(True) + return defer.succeed(True) @defer.inlineCallbacks def _check_threepid(self, medium, authdict): - yield run_on_reactor() - if 'threepid_creds' not in authdict: raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) @@ -398,26 +497,8 @@ class AuthHandler(BaseHandler): return self.sessions[session_id] - def validate_password_login(self, user_id, password): - """ - Authenticates the user with their username and password. - - Used only by the v1 login API. - - Args: - user_id (str): complete @user:id - password (str): Password - Returns: - defer.Deferred: (str) canonical user id - Raises: - StoreError if there was a problem accessing the database - LoginError if there was an authentication problem. - """ - return self._check_password(user_id, password) - @defer.inlineCallbacks - def get_access_token_for_user_id(self, user_id, device_id=None, - initial_display_name=None): + def get_access_token_for_user_id(self, user_id, device_id=None): """ Creates a new access token for the user with the given user ID. @@ -431,13 +512,10 @@ class AuthHandler(BaseHandler): device_id (str|None): the device ID to associate with the tokens. None to leave the tokens unassociated with a device (deprecated: we should always have a device ID) - initial_display_name (str): display name to associate with the - device if it needs re-registering Returns: The access token for the user's session. Raises: StoreError if there was a problem storing the token. - LoginError if there was an authentication problem. """ logger.info("Logging in user %s on device %s", user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id) @@ -447,9 +525,11 @@ class AuthHandler(BaseHandler): # really don't want is active access_tokens without a record of the # device, so we double-check it here. if device_id is not None: - yield self.device_handler.check_device_registered( - user_id, device_id, initial_display_name - ) + try: + yield self.store.get_device(user_id, device_id) + except StoreError: + yield self.store.delete_access_token(access_token) + raise StoreError(400, "Login raced against device deletion") defer.returnValue(access_token) @@ -501,29 +581,115 @@ class AuthHandler(BaseHandler): ) defer.returnValue(result) + def get_supported_login_types(self): + """Get a the login types supported for the /login API + + By default this is just 'm.login.password' (unless password_enabled is + False in the config file), but password auth providers can provide + other login types. + + Returns: + Iterable[str]: login types + """ + return self._supported_login_types + @defer.inlineCallbacks - def _check_password(self, user_id, password): - """Authenticate a user against the LDAP and local databases. + def validate_login(self, username, login_submission): + """Authenticates the user for the /login API - user_id is checked case insensitively against the local database, but - will throw if there are multiple inexact matches. + Also used by the user-interactive auth flow to validate + m.login.password auth types. Args: - user_id (str): complete @user:id + username (str): username supplied by the user + login_submission (dict): the whole of the login submission + (including 'type' and other relevant fields) Returns: - (str) the canonical_user_id + Deferred[str, func]: canonical user id, and optional callback + to be called once the access token and device id are issued Raises: - LoginError if login fails + StoreError if there was a problem accessing the database + SynapseError if there was a problem with the request + LoginError if there was an authentication problem. """ + + if username.startswith('@'): + qualified_user_id = username + else: + qualified_user_id = UserID( + username, self.hs.hostname + ).to_string() + + login_type = login_submission.get("type") + known_login_type = False + + # special case to check for "password" for the check_password interface + # for the auth providers + password = login_submission.get("password") + if login_type == LoginType.PASSWORD: + if not self._password_enabled: + raise SynapseError(400, "Password login has been disabled.") + if not password: + raise SynapseError(400, "Missing parameter: password") + for provider in self.password_providers: - is_valid = yield provider.check_password(user_id, password) - if is_valid: - defer.returnValue(user_id) + if (hasattr(provider, "check_password") + and login_type == LoginType.PASSWORD): + known_login_type = True + is_valid = yield provider.check_password( + qualified_user_id, password, + ) + if is_valid: + defer.returnValue((qualified_user_id, None)) + + if (not hasattr(provider, "get_supported_login_types") + or not hasattr(provider, "check_auth")): + # this password provider doesn't understand custom login types + continue + + supported_login_types = provider.get_supported_login_types() + if login_type not in supported_login_types: + # this password provider doesn't understand this login type + continue + + known_login_type = True + login_fields = supported_login_types[login_type] + + missing_fields = [] + login_dict = {} + for f in login_fields: + if f not in login_submission: + missing_fields.append(f) + else: + login_dict[f] = login_submission[f] + if missing_fields: + raise SynapseError( + 400, "Missing parameters for login type %s: %s" % ( + login_type, + missing_fields, + ), + ) + + result = yield provider.check_auth( + username, login_type, login_dict, + ) + if result: + if isinstance(result, str): + result = (result, None) + defer.returnValue(result) + + if login_type == LoginType.PASSWORD: + known_login_type = True + + canonical_user_id = yield self._check_local_password( + qualified_user_id, password, + ) - canonical_user_id = yield self._check_local_password(user_id, password) + if canonical_user_id: + defer.returnValue((canonical_user_id, None)) - if canonical_user_id: - defer.returnValue(canonical_user_id) + if not known_login_type: + raise SynapseError(400, "Unknown login type %s" % login_type) # unknown username or invalid password. We raise a 403 here, but note # that if we're doing user-interactive login, it turns all LoginErrors @@ -549,7 +715,7 @@ class AuthHandler(BaseHandler): if not lookupres: defer.returnValue(None) (user_id, password_hash) = lookupres - result = self.validate_hash(password, password_hash) + result = yield self.validate_hash(password, password_hash) if not result: logger.warn("Failed password login for user %s", user_id) defer.returnValue(None) @@ -573,22 +739,65 @@ class AuthHandler(BaseHandler): raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) @defer.inlineCallbacks - def set_password(self, user_id, newpassword, requester=None): - password_hash = self.hash(newpassword) + def delete_access_token(self, access_token): + """Invalidate a single access token - except_access_token_id = requester.access_token_id if requester else None + Args: + access_token (str): access token to be deleted - try: - yield self.store.user_set_password_hash(user_id, password_hash) - except StoreError as e: - if e.code == 404: - raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) - raise e - yield self.store.user_delete_access_tokens( - user_id, except_access_token_id + Returns: + Deferred + """ + user_info = yield self.auth.get_user_by_access_token(access_token) + yield self.store.delete_access_token(access_token) + + # see if any of our auth providers want to know about this + for provider in self.password_providers: + if hasattr(provider, "on_logged_out"): + yield provider.on_logged_out( + user_id=str(user_info["user"]), + device_id=user_info["device_id"], + access_token=access_token, + ) + + # delete pushers associated with this access token + if user_info["token_id"] is not None: + yield self.hs.get_pusherpool().remove_pushers_by_access_token( + str(user_info["user"]), (user_info["token_id"], ) + ) + + @defer.inlineCallbacks + def delete_access_tokens_for_user(self, user_id, except_token_id=None, + device_id=None): + """Invalidate access tokens belonging to a user + + Args: + user_id (str): ID of user the tokens belong to + except_token_id (str|None): access_token ID which should *not* be + deleted + device_id (str|None): ID of device the tokens are associated with. + If None, tokens associated with any device (or no device) will + be deleted + Returns: + Deferred + """ + tokens_and_devices = yield self.store.user_delete_access_tokens( + user_id, except_token_id=except_token_id, device_id=device_id, ) - yield self.hs.get_pusherpool().remove_pushers_by_user( - user_id, except_access_token_id + + # see if any of our auth providers want to know about this + for provider in self.password_providers: + if hasattr(provider, "on_logged_out"): + for token, token_id, device_id in tokens_and_devices: + yield provider.on_logged_out( + user_id=user_id, + device_id=device_id, + access_token=token, + ) + + # delete pushers associated with the access tokens + yield self.hs.get_pusherpool().remove_pushers_by_access_token( + user_id, (token_id for _, token_id, _ in tokens_and_devices), ) @defer.inlineCallbacks @@ -616,6 +825,15 @@ class AuthHandler(BaseHandler): if medium == 'email': address = address.lower() + identity_handler = self.hs.get_handlers().identity_handler + yield identity_handler.unbind_threepid( + user_id, + { + 'medium': medium, + 'address': address, + }, + ) + ret = yield self.store.user_delete_threepid( user_id, medium, address, ) @@ -634,10 +852,17 @@ class AuthHandler(BaseHandler): password (str): Password to hash. Returns: - Hashed password (str). + Deferred(str): Hashed password. """ - return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, - bcrypt.gensalt(self.bcrypt_rounds)) + def _do_hash(): + return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, + bcrypt.gensalt(self.bcrypt_rounds)) + + return make_deferred_yieldable( + threads.deferToThreadPool( + self.hs.get_reactor(), self.hs.get_reactor().getThreadPool(), _do_hash + ), + ) def validate_hash(self, password, stored_hash): """Validates that self.hash(password) == stored_hash. @@ -647,20 +872,31 @@ class AuthHandler(BaseHandler): stored_hash (str): Expected hash value. Returns: - Whether self.hash(password) == stored_hash (bool). + Deferred(bool): Whether self.hash(password) == stored_hash. """ + + def _do_validate_hash(): + return bcrypt.checkpw( + password.encode('utf8') + self.hs.config.password_pepper, + stored_hash.encode('utf8') + ) + if stored_hash: - return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, - stored_hash.encode('utf8')) == stored_hash + return make_deferred_yieldable( + threads.deferToThreadPool( + self.hs.get_reactor(), + self.hs.get_reactor().getThreadPool(), + _do_validate_hash, + ), + ) else: - return False + return defer.succeed(False) -class MacaroonGeneartor(object): - def __init__(self, hs): - self.clock = hs.get_clock() - self.server_name = hs.config.server_name - self.macaroon_secret_key = hs.config.macaroon_secret_key +@attr.s +class MacaroonGenerator(object): + + hs = attr.ib() def generate_access_token(self, user_id, extra_caveats=None): extra_caveats = extra_caveats or [] @@ -678,7 +914,7 @@ class MacaroonGeneartor(object): def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = login") - now = self.clock.time_msec() + now = self.hs.get_clock().time_msec() expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) return macaroon.serialize() @@ -690,36 +926,9 @@ class MacaroonGeneartor(object): def _generate_base_macaroon(self, user_id): macaroon = pymacaroons.Macaroon( - location=self.server_name, + location=self.hs.config.server_name, identifier="key", - key=self.macaroon_secret_key) + key=self.hs.config.macaroon_secret_key) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon - - -class _AccountHandler(object): - """A proxy object that gets passed to password auth providers so they - can register new users etc if necessary. - """ - def __init__(self, hs, check_user_exists): - self.hs = hs - - self._check_user_exists = check_user_exists - - def check_user_exists(self, user_id): - """Check if user exissts. - - Returns: - Deferred(bool) - """ - return self._check_user_exists(user_id) - - def register(self, localpart): - """Registers a new user with given localpart - - Returns: - Deferred: a 2-tuple of (user_id, access_token) - """ - reg = self.hs.get_handlers().registration_handler - return reg.register(localpart=localpart) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py new file mode 100644 index 0000000000..b3c5a9ee64 --- /dev/null +++ b/synapse/handlers/deactivate_account.py @@ -0,0 +1,163 @@ +# -*- coding: utf-8 -*- +# Copyright 2017, 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.internet import defer + +from synapse.api.errors import SynapseError +from synapse.types import UserID, create_requester +from synapse.util.logcontext import run_in_background + +from ._base import BaseHandler + +logger = logging.getLogger(__name__) + + +class DeactivateAccountHandler(BaseHandler): + """Handler which deals with deactivating user accounts.""" + def __init__(self, hs): + super(DeactivateAccountHandler, self).__init__(hs) + self._auth_handler = hs.get_auth_handler() + self._device_handler = hs.get_device_handler() + self._room_member_handler = hs.get_room_member_handler() + self._identity_handler = hs.get_handlers().identity_handler + self.user_directory_handler = hs.get_user_directory_handler() + + # Flag that indicates whether the process to part users from rooms is running + self._user_parter_running = False + + # Start the user parter loop so it can resume parting users from rooms where + # it left off (if it has work left to do). + hs.get_reactor().callWhenRunning(self._start_user_parting) + + @defer.inlineCallbacks + def deactivate_account(self, user_id, erase_data): + """Deactivate a user's account + + Args: + user_id (str): ID of user to be deactivated + erase_data (bool): whether to GDPR-erase the user's data + + Returns: + Deferred + """ + # FIXME: Theoretically there is a race here wherein user resets + # password using threepid. + + # delete threepids first. We remove these from the IS so if this fails, + # leave the user still active so they can try again. + # Ideally we would prevent password resets and then do this in the + # background thread. + threepids = yield self.store.user_get_threepids(user_id) + for threepid in threepids: + try: + yield self._identity_handler.unbind_threepid( + user_id, + { + 'medium': threepid['medium'], + 'address': threepid['address'], + }, + ) + except Exception: + # Do we want this to be a fatal error or should we carry on? + logger.exception("Failed to remove threepid from ID server") + raise SynapseError(400, "Failed to remove threepid from ID server") + yield self.store.user_delete_threepid( + user_id, threepid['medium'], threepid['address'], + ) + + # delete any devices belonging to the user, which will also + # delete corresponding access tokens. + yield self._device_handler.delete_all_devices_for_user(user_id) + # then delete any remaining access tokens which weren't associated with + # a device. + yield self._auth_handler.delete_access_tokens_for_user(user_id) + + yield self.store.user_set_password_hash(user_id, None) + + # Add the user to a table of users pending deactivation (ie. + # removal from all the rooms they're a member of) + yield self.store.add_user_pending_deactivation(user_id) + + # delete from user directory + yield self.user_directory_handler.handle_user_deactivated(user_id) + + # Mark the user as erased, if they asked for that + if erase_data: + logger.info("Marking %s as erased", user_id) + yield self.store.mark_user_erased(user_id) + + # Now start the process that goes through that list and + # parts users from rooms (if it isn't already running) + self._start_user_parting() + + def _start_user_parting(self): + """ + Start the process that goes through the table of users + pending deactivation, if it isn't already running. + + Returns: + None + """ + if not self._user_parter_running: + run_in_background(self._user_parter_loop) + + @defer.inlineCallbacks + def _user_parter_loop(self): + """Loop that parts deactivated users from rooms + + Returns: + None + """ + self._user_parter_running = True + logger.info("Starting user parter") + try: + while True: + user_id = yield self.store.get_user_pending_deactivation() + if user_id is None: + break + logger.info("User parter parting %r", user_id) + yield self._part_user(user_id) + yield self.store.del_user_pending_deactivation(user_id) + logger.info("User parter finished parting %r", user_id) + logger.info("User parter finished: stopping") + finally: + self._user_parter_running = False + + @defer.inlineCallbacks + def _part_user(self, user_id): + """Causes the given user_id to leave all the rooms they're joined to + + Returns: + None + """ + user = UserID.from_string(user_id) + + rooms_for_user = yield self.store.get_rooms_for_user(user_id) + for room_id in rooms_for_user: + logger.info("User parter parting %r from %r", user_id, room_id) + try: + yield self._room_member_handler.update_membership( + create_requester(user), + user, + room_id, + "leave", + ratelimit=False, + ) + except Exception: + logger.exception( + "Failed to part user %r from room %r: ignoring and continuing", + user_id, room_id, + ) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index ed60d494ff..2d44f15da3 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -12,18 +12,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging + +from six import iteritems, itervalues + +from twisted.internet import defer + from synapse.api import errors from synapse.api.constants import EventTypes +from synapse.api.errors import FederationDeniedError +from synapse.types import RoomStreamToken, get_domain_from_id from synapse.util import stringutils from synapse.util.async import Linearizer from synapse.util.caches.expiringcache import ExpiringCache -from synapse.util.retryutils import NotRetryingDestination from synapse.util.metrics import measure_func -from synapse.types import get_domain_from_id, RoomStreamToken -from twisted.internet import defer -from ._base import BaseHandler +from synapse.util.retryutils import NotRetryingDestination -import logging +from ._base import BaseHandler logger = logging.getLogger(__name__) @@ -34,15 +39,17 @@ class DeviceHandler(BaseHandler): self.hs = hs self.state = hs.get_state_handler() + self._auth_handler = hs.get_auth_handler() self.federation_sender = hs.get_federation_sender() - self.federation = hs.get_replication_layer() self._edu_updater = DeviceListEduUpdater(hs, self) - self.federation.register_edu_handler( + federation_registry = hs.get_federation_registry() + + federation_registry.register_edu_handler( "m.device_list_update", self._edu_updater.incoming_device_list_update, ) - self.federation.register_query_handler( + federation_registry.register_query_handler( "user_devices", self.on_federation_query_user_devices, ) @@ -109,7 +116,7 @@ class DeviceHandler(BaseHandler): user_id, device_id=None ) - devices = device_map.values() + devices = list(device_map.values()) for device in devices: _update_device_from_client_ips(device, ips) @@ -152,16 +159,15 @@ class DeviceHandler(BaseHandler): try: yield self.store.delete_device(user_id, device_id) - except errors.StoreError, e: + except errors.StoreError as e: if e.code == 404: # no match pass else: raise - yield self.store.user_delete_access_tokens( + yield self._auth_handler.delete_access_tokens_for_user( user_id, device_id=device_id, - delete_refresh_tokens=True, ) yield self.store.delete_e2e_keys_by_device( @@ -171,12 +177,30 @@ class DeviceHandler(BaseHandler): yield self.notify_device_update(user_id, [device_id]) @defer.inlineCallbacks + def delete_all_devices_for_user(self, user_id, except_device_id=None): + """Delete all of the user's devices + + Args: + user_id (str): + except_device_id (str|None): optional device id which should not + be deleted + + Returns: + defer.Deferred: + """ + device_map = yield self.store.get_devices_by_user(user_id) + device_ids = list(device_map) + if except_device_id is not None: + device_ids = [d for d in device_ids if d != except_device_id] + yield self.delete_devices(user_id, device_ids) + + @defer.inlineCallbacks def delete_devices(self, user_id, device_ids): """ Delete several devices Args: user_id (str): - device_ids (str): The list of device IDs to delete + device_ids (List[str]): The list of device IDs to delete Returns: defer.Deferred: @@ -184,7 +208,7 @@ class DeviceHandler(BaseHandler): try: yield self.store.delete_devices(user_id, device_ids) - except errors.StoreError, e: + except errors.StoreError as e: if e.code == 404: # no match pass @@ -194,9 +218,8 @@ class DeviceHandler(BaseHandler): # Delete access tokens and e2e keys for each device. Not optimised as it is not # considered as part of a critical path. for device_id in device_ids: - yield self.store.user_delete_access_tokens( + yield self._auth_handler.delete_access_tokens_for_user( user_id, device_id=device_id, - delete_refresh_tokens=True, ) yield self.store.delete_e2e_keys_by_device( user_id=user_id, device_id=device_id @@ -224,7 +247,7 @@ class DeviceHandler(BaseHandler): new_display_name=content.get("display_name") ) yield self.notify_device_update(user_id, [device_id]) - except errors.StoreError, e: + except errors.StoreError as e: if e.code == 404: raise errors.NotFoundError() else: @@ -270,6 +293,8 @@ class DeviceHandler(BaseHandler): user_id (str) from_token (StreamToken) """ + now_token = yield self.hs.get_event_sources().get_current_token() + room_ids = yield self.store.get_rooms_for_user(user_id) # First we check if any devices have changed @@ -280,11 +305,30 @@ class DeviceHandler(BaseHandler): # Then work out if any users have since joined rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key) + member_events = yield self.store.get_membership_changes_for_user( + user_id, from_token.room_key, now_token.room_key + ) + rooms_changed.update(event.room_id for event in member_events) + stream_ordering = RoomStreamToken.parse_stream_token( - from_token.room_key).stream + from_token.room_key + ).stream possibly_changed = set(changed) + possibly_left = set() for room_id in rooms_changed: + current_state_ids = yield self.store.get_current_state_ids(room_id) + + # The user may have left the room + # TODO: Check if they actually did or if we were just invited. + if room_id not in room_ids: + for key, event_id in iteritems(current_state_ids): + etype, state_key = key + if etype != EventTypes.Member: + continue + possibly_left.add(state_key) + continue + # Fetch the current state at the time. try: event_ids = yield self.store.get_forward_extremeties_for_room( @@ -295,44 +339,69 @@ class DeviceHandler(BaseHandler): # ordering: treat it the same as a new room event_ids = [] - current_state_ids = yield self.store.get_current_state_ids(room_id) - # special-case for an empty prev state: include all members # in the changed list if not event_ids: - for key, event_id in current_state_ids.iteritems(): + for key, event_id in iteritems(current_state_ids): etype, state_key = key if etype != EventTypes.Member: continue possibly_changed.add(state_key) continue + current_member_id = current_state_ids.get((EventTypes.Member, user_id)) + if not current_member_id: + continue + # mapping from event_id -> state_dict prev_state_ids = yield self.store.get_state_ids_for_events(event_ids) + # Check if we've joined the room? If so we just blindly add all the users to + # the "possibly changed" users. + for state_dict in itervalues(prev_state_ids): + member_event = state_dict.get((EventTypes.Member, user_id), None) + if not member_event or member_event != current_member_id: + for key, event_id in iteritems(current_state_ids): + etype, state_key = key + if etype != EventTypes.Member: + continue + possibly_changed.add(state_key) + break + # If there has been any change in membership, include them in the # possibly changed list. We'll check if they are joined below, # and we're not toooo worried about spuriously adding users. - for key, event_id in current_state_ids.iteritems(): + for key, event_id in iteritems(current_state_ids): etype, state_key = key if etype != EventTypes.Member: continue # check if this member has changed since any of the extremities # at the stream_ordering, and add them to the list if so. - for state_dict in prev_state_ids.values(): + for state_dict in itervalues(prev_state_ids): prev_event_id = state_dict.get(key, None) if not prev_event_id or prev_event_id != event_id: - possibly_changed.add(state_key) + if state_key != user_id: + possibly_changed.add(state_key) break - users_who_share_room = yield self.store.get_users_who_share_room_with_user( - user_id - ) + if possibly_changed or possibly_left: + users_who_share_room = yield self.store.get_users_who_share_room_with_user( + user_id + ) - # Take the intersection of the users whose devices may have changed - # and those that actually still share a room with the user - defer.returnValue(users_who_share_room & possibly_changed) + # Take the intersection of the users whose devices may have changed + # and those that actually still share a room with the user + possibly_joined = possibly_changed & users_who_share_room + possibly_left = (possibly_changed | possibly_left) - users_who_share_room + else: + possibly_joined = [] + possibly_left = [] + + defer.returnValue({ + "changed": list(possibly_joined), + "left": list(possibly_left), + }) @defer.inlineCallbacks def on_federation_query_user_devices(self, user_id): @@ -366,7 +435,7 @@ class DeviceListEduUpdater(object): def __init__(self, hs, device_handler): self.store = hs.get_datastore() - self.federation = hs.get_replication_layer() + self.federation = hs.get_federation_client() self.clock = hs.get_clock() self.device_handler = device_handler @@ -450,6 +519,9 @@ class DeviceListEduUpdater(object): # This makes it more likely that the device lists will # eventually become consistent. return + except FederationDeniedError as e: + logger.info(e) + return except Exception: # TODO: Remember that we are now out of sync and try again # later @@ -467,7 +539,7 @@ class DeviceListEduUpdater(object): yield self.device_handler.notify_device_update(user_id, device_ids) else: # Simply update the single device, since we know that is the only - # change (becuase of the single prev_id matching the current cache) + # change (because of the single prev_id matching the current cache) for device_id, stream_id, prev_ids, content in pending_updates: yield self.store.update_remote_device_list_cache_entry( user_id, device_id, content, stream_id, diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index f7fad15c62..2e2e5261de 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -17,10 +17,10 @@ import logging from twisted.internet import defer -from synapse.types import get_domain_from_id +from synapse.api.errors import SynapseError +from synapse.types import UserID, get_domain_from_id from synapse.util.stringutils import random_string - logger = logging.getLogger(__name__) @@ -33,10 +33,10 @@ class DeviceMessageHandler(object): """ self.store = hs.get_datastore() self.notifier = hs.get_notifier() - self.is_mine_id = hs.is_mine_id + self.is_mine = hs.is_mine self.federation = hs.get_federation_sender() - hs.get_replication_layer().register_edu_handler( + hs.get_federation_registry().register_edu_handler( "m.direct_to_device", self.on_direct_to_device_edu ) @@ -52,6 +52,12 @@ class DeviceMessageHandler(object): message_type = content["type"] message_id = content["message_id"] for user_id, by_device in content["messages"].items(): + # we use UserID.from_string to catch invalid user ids + if not self.is_mine(UserID.from_string(user_id)): + logger.warning("Request for keys for non-local user %s", + user_id) + raise SynapseError(400, "Not a user here") + messages_by_device = { device_id: { "content": message_content, @@ -77,7 +83,8 @@ class DeviceMessageHandler(object): local_messages = {} remote_messages = {} for user_id, by_device in messages.items(): - if self.is_mine_id(user_id): + # we use UserID.from_string to catch invalid user ids + if self.is_mine(UserID.from_string(user_id)): messages_by_device = { device_id: { "content": message_content, diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 943554ce98..ef866da1b6 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -14,15 +14,16 @@ # limitations under the License. +import logging +import string + from twisted.internet import defer -from ._base import BaseHandler -from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError from synapse.api.constants import EventTypes +from synapse.api.errors import AuthError, CodeMessageException, Codes, SynapseError from synapse.types import RoomAlias, UserID, get_domain_from_id -import logging -import string +from ._base import BaseHandler logger = logging.getLogger(__name__) @@ -34,12 +35,15 @@ class DirectoryHandler(BaseHandler): self.state = hs.get_state_handler() self.appservice_handler = hs.get_application_service_handler() + self.event_creation_handler = hs.get_event_creation_handler() - self.federation = hs.get_replication_layer() - self.federation.register_query_handler( + self.federation = hs.get_federation_client() + hs.get_federation_registry().register_query_handler( "directory", self.on_directory_query ) + self.spam_checker = hs.get_spam_checker() + @defer.inlineCallbacks def _create_association(self, room_alias, room_id, servers=None, creator=None): # general association creation for both human users and app services @@ -73,6 +77,11 @@ class DirectoryHandler(BaseHandler): # association creation for human users # TODO(erikj): Do user auth. + if not self.spam_checker.user_may_create_room_alias(user_id, room_alias): + raise SynapseError( + 403, "This user is not permitted to create this alias", + ) + can_create = yield self.can_modify_alias( room_alias, user_id=user_id @@ -242,8 +251,7 @@ class DirectoryHandler(BaseHandler): def send_room_alias_update_event(self, requester, user_id, room_id): aliases = yield self.store.get_aliases_for_room(room_id) - msg_handler = self.hs.get_handlers().message_handler - yield msg_handler.create_and_send_nonmember_event( + yield self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Aliases, @@ -265,8 +273,7 @@ class DirectoryHandler(BaseHandler): if not alias_event or alias_event.content.get("alias", "") != alias_str: return - msg_handler = self.hs.get_handlers().message_handler - yield msg_handler.create_and_send_nonmember_event( + yield self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.CanonicalAlias, @@ -327,6 +334,14 @@ class DirectoryHandler(BaseHandler): room_id (str) visibility (str): "public" or "private" """ + if not self.spam_checker.user_may_publish_room( + requester.user.to_string(), room_id + ): + raise AuthError( + 403, + "This user is not permitted to publish rooms to the room list" + ) + if requester.is_guest: raise AuthError(403, "Guests cannot edit the published room list") diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 668a90e495..5816bf8b4f 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,15 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import ujson as json import logging -from canonicaljson import encode_canonical_json +from six import iteritems + +from canonicaljson import encode_canonical_json, json + from twisted.internet import defer -from synapse.api.errors import SynapseError, CodeMessageException -from synapse.types import get_domain_from_id -from synapse.util.logcontext import preserve_fn, make_deferred_yieldable +from synapse.api.errors import CodeMessageException, FederationDeniedError, SynapseError +from synapse.types import UserID, get_domain_from_id +from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.retryutils import NotRetryingDestination logger = logging.getLogger(__name__) @@ -30,15 +33,15 @@ logger = logging.getLogger(__name__) class E2eKeysHandler(object): def __init__(self, hs): self.store = hs.get_datastore() - self.federation = hs.get_replication_layer() + self.federation = hs.get_federation_client() self.device_handler = hs.get_device_handler() - self.is_mine_id = hs.is_mine_id + self.is_mine = hs.is_mine self.clock = hs.get_clock() # doesn't really work as part of the generic query API, because the # query request requires an object POST, but we abuse the # "query handler" interface. - self.federation.register_query_handler( + hs.get_federation_registry().register_query_handler( "client_keys", self.on_federation_query_client_keys ) @@ -70,12 +73,13 @@ class E2eKeysHandler(object): remote_queries = {} for user_id, device_ids in device_keys_query.items(): - if self.is_mine_id(user_id): + # we use UserID.from_string to catch invalid user ids + if self.is_mine(UserID.from_string(user_id)): local_query[user_id] = device_ids else: remote_queries[user_id] = device_ids - # Firt get local devices. + # First get local devices. failures = {} results = {} if local_query: @@ -88,7 +92,7 @@ class E2eKeysHandler(object): remote_queries_not_in_cache = {} if remote_queries: query_list = [] - for user_id, device_ids in remote_queries.iteritems(): + for user_id, device_ids in iteritems(remote_queries): if device_ids: query_list.extend((user_id, device_id) for device_id in device_ids) else: @@ -99,9 +103,9 @@ class E2eKeysHandler(object): query_list ) ) - for user_id, devices in remote_results.iteritems(): + for user_id, devices in iteritems(remote_results): user_devices = results.setdefault(user_id, {}) - for device_id, device in devices.iteritems(): + for device_id, device in iteritems(devices): keys = device.get("keys", None) device_display_name = device.get("device_display_name", None) if keys: @@ -131,24 +135,13 @@ class E2eKeysHandler(object): if user_id in destination_query: results[user_id] = keys - except CodeMessageException as e: - failures[destination] = { - "status": e.code, "message": e.message - } - except NotRetryingDestination as e: - failures[destination] = { - "status": 503, "message": "Not ready for retry", - } except Exception as e: - # include ConnectionRefused and other errors - failures[destination] = { - "status": 503, "message": e.message - } + failures[destination] = _exception_to_failure(e) yield make_deferred_yieldable(defer.gatherResults([ - preserve_fn(do_remote_query)(destination) + run_in_background(do_remote_query, destination) for destination in remote_queries_not_in_cache - ])) + ], consumeErrors=True)) defer.returnValue({ "device_keys": results, "failures": failures, @@ -170,7 +163,8 @@ class E2eKeysHandler(object): result_dict = {} for user_id, device_ids in query.items(): - if not self.is_mine_id(user_id): + # we use UserID.from_string to catch invalid user ids + if not self.is_mine(UserID.from_string(user_id)): logger.warning("Request for keys for non-local user %s", user_id) raise SynapseError(400, "Not a user here") @@ -213,7 +207,8 @@ class E2eKeysHandler(object): remote_queries = {} for user_id, device_keys in query.get("one_time_keys", {}).items(): - if self.is_mine_id(user_id): + # we use UserID.from_string to catch invalid user ids + if self.is_mine(UserID.from_string(user_id)): for device_id, algorithm in device_keys.items(): local_query.append((user_id, device_id, algorithm)) else: @@ -243,32 +238,21 @@ class E2eKeysHandler(object): for user_id, keys in remote_result["one_time_keys"].items(): if user_id in device_keys: json_result[user_id] = keys - except CodeMessageException as e: - failures[destination] = { - "status": e.code, "message": e.message - } - except NotRetryingDestination as e: - failures[destination] = { - "status": 503, "message": "Not ready for retry", - } except Exception as e: - # include ConnectionRefused and other errors - failures[destination] = { - "status": 503, "message": e.message - } + failures[destination] = _exception_to_failure(e) yield make_deferred_yieldable(defer.gatherResults([ - preserve_fn(claim_client_keys)(destination) + run_in_background(claim_client_keys, destination) for destination in remote_queries - ])) + ], consumeErrors=True)) logger.info( "Claimed one-time-keys: %s", ",".join(( "%s for %s:%s" % (key_id, user_id, device_id) - for user_id, user_keys in json_result.iteritems() - for device_id, device_keys in user_keys.iteritems() - for key_id, _ in device_keys.iteritems() + for user_id, user_keys in iteritems(json_result) + for device_id, device_keys in iteritems(user_keys) + for key_id, _ in iteritems(device_keys) )), ) @@ -353,6 +337,31 @@ class E2eKeysHandler(object): ) +def _exception_to_failure(e): + if isinstance(e, CodeMessageException): + return { + "status": e.code, "message": e.message, + } + + if isinstance(e, NotRetryingDestination): + return { + "status": 503, "message": "Not ready for retry", + } + + if isinstance(e, FederationDeniedError): + return { + "status": 403, "message": "Federation Denied", + } + + # include ConnectionRefused and other errors + # + # Note that some Exceptions (notably twisted's ResponseFailed etc) don't + # give a string for e.message, which json then fails to serialize. + return { + "status": 503, "message": str(e.message), + } + + def _one_time_keys_match(old_key_json, new_key): old_key = json.loads(old_key_json) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index d3685fb12a..c3f2d7feff 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -13,20 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +import random + from twisted.internet import defer -from synapse.util.logutils import log_function -from synapse.types import UserID -from synapse.events.utils import serialize_event -from synapse.api.constants import Membership, EventTypes +from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase +from synapse.events.utils import serialize_event +from synapse.types import UserID +from synapse.util.logutils import log_function from ._base import BaseHandler -import logging -import random - - logger = logging.getLogger(__name__) @@ -48,6 +47,7 @@ class EventStreamHandler(BaseHandler): self.notifier = hs.get_notifier() self.state = hs.get_state_handler() + self._server_notices_sender = hs.get_server_notices_sender() @defer.inlineCallbacks @log_function @@ -58,6 +58,10 @@ class EventStreamHandler(BaseHandler): If `only_keys` is not None, events from keys will be sent down. """ + + # send any outstanding server notices to the user. + yield self._server_notices_sender.on_user_syncing(auth_user_id) + auth_user = UserID.from_string(auth_user_id) presence_handler = self.hs.get_presence_handler() diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 483cb8eac6..145c1a21d4 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,41 +15,46 @@ # limitations under the License. """Contains handlers for federation events.""" -import synapse.util.logcontext + +import itertools +import logging +import sys + +import six +from six import iteritems, itervalues +from six.moves import http_client, zip + from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json from unpaddedbase64 import decode_base64 -from ._base import BaseHandler +from twisted.internet import defer -from synapse.api.errors import ( - AuthError, FederationError, StoreError, CodeMessageException, SynapseError, -) from synapse.api.constants import EventTypes, Membership, RejectedReason -from synapse.events.validator import EventValidator -from synapse.util import unwrapFirstError -from synapse.util.logcontext import ( - preserve_fn, preserve_context_over_deferred +from synapse.api.errors import ( + AuthError, + CodeMessageException, + FederationDeniedError, + FederationError, + StoreError, + SynapseError, ) -from synapse.util.metrics import measure_func -from synapse.util.logutils import log_function -from synapse.util.async import run_on_reactor, Linearizer -from synapse.util.frozenutils import unfreeze from synapse.crypto.event_signing import ( - compute_event_signature, add_hashes_and_signatures, + add_hashes_and_signatures, + compute_event_signature, ) +from synapse.events.validator import EventValidator +from synapse.state import resolve_events_with_factory from synapse.types import UserID, get_domain_from_id - -from synapse.events.utils import prune_event - -from synapse.util.retryutils import NotRetryingDestination - +from synapse.util import logcontext, unwrapFirstError +from synapse.util.async import Linearizer from synapse.util.distributor import user_joined_room +from synapse.util.frozenutils import unfreeze +from synapse.util.logutils import log_function +from synapse.util.retryutils import NotRetryingDestination +from synapse.visibility import filter_events_for_server -from twisted.internet import defer - -import itertools -import logging +from ._base import BaseHandler logger = logging.getLogger(__name__) @@ -70,14 +76,16 @@ class FederationHandler(BaseHandler): self.hs = hs self.store = hs.get_datastore() - self.replication_layer = hs.get_replication_layer() + self.replication_layer = hs.get_federation_client() self.state_handler = hs.get_state_handler() self.server_name = hs.hostname self.keyring = hs.get_keyring() self.action_generator = hs.get_action_generator() self.is_mine_id = hs.is_mine_id - - self.replication_layer.set_handler(self) + self.pusher_pool = hs.get_pusherpool() + self.spam_checker = hs.get_spam_checker() + self.event_creation_handler = hs.get_event_creation_handler() + self._server_notices_mxid = hs.config.server_notices_mxid # When joining a room we need to queue any events for that room up self.room_queues = {} @@ -85,7 +93,9 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks @log_function - def on_receive_pdu(self, origin, pdu, get_missing=True): + def on_receive_pdu( + self, origin, pdu, get_missing=True, sent_to_us_directly=False, + ): """ Process a PDU received via a federation /send/ transaction, or via backfill of missing prev_events @@ -99,8 +109,10 @@ class FederationHandler(BaseHandler): """ # We reprocess pdus when we have seen them only as outliers - existing = yield self.get_persisted_pdu( - origin, pdu.event_id, do_auth=False + existing = yield self.store.get_event( + pdu.event_id, + allow_none=True, + allow_rejected=True, ) # FIXME: Currently we fetch an event again when we already have it @@ -116,6 +128,19 @@ class FederationHandler(BaseHandler): logger.debug("Already seen pdu %s", pdu.event_id) return + # do some initial sanity-checking of the event. In particular, make + # sure it doesn't have hundreds of prev_events or auth_events, which + # could cause a huge state resolution or cascade of event fetches. + try: + self._sanity_check_event(pdu) + except SynapseError as err: + raise FederationError( + "ERROR", + err.code, + err.msg, + affected=pdu.event_id, + ) + # If we are currently in the process of joining this room, then we # queue up events for later processing. if pdu.room_id in self.room_queues: @@ -124,15 +149,30 @@ class FederationHandler(BaseHandler): self.room_queues[pdu.room_id].append((pdu, origin)) return - state = None - - auth_chain = [] - - have_seen = yield self.store.have_events( - [ev for ev, _ in pdu.prev_events] + # If we're no longer in the room just ditch the event entirely. This + # is probably an old server that has come back and thinks we're still + # in the room (or we've been rejoined to the room by a state reset). + # + # If we were never in the room then maybe our database got vaped and + # we should check if we *are* in fact in the room. If we are then we + # can magically rejoin the room. + is_in_room = yield self.auth.check_host_in_room( + pdu.room_id, + self.server_name ) + if not is_in_room: + was_in_room = yield self.store.was_host_joined( + pdu.room_id, self.server_name, + ) + if was_in_room: + logger.info( + "Ignoring PDU %s for room %s from %s as we've left the room!", + pdu.event_id, pdu.room_id, origin, + ) + defer.returnValue(None) - fetch_state = False + state = None + auth_chain = [] # Get missing pdus if necessary. if not pdu.internal_metadata.is_outlier(): @@ -147,7 +187,7 @@ class FederationHandler(BaseHandler): ) prevs = {e_id for e_id, _ in pdu.prev_events} - seen = set(have_seen.keys()) + seen = yield self.store.have_seen_events(prevs) if min_depth and pdu.depth < min_depth: # This is so that we don't notify the user about this @@ -175,8 +215,7 @@ class FederationHandler(BaseHandler): # Update the set of things we've seen after trying to # fetch the missing stuff - have_seen = yield self.store.have_events(prevs) - seen = set(have_seen.iterkeys()) + seen = yield self.store.have_seen_events(prevs) if not prevs - seen: logger.info( @@ -189,26 +228,60 @@ class FederationHandler(BaseHandler): list(prevs - seen)[:5], ) - if prevs - seen: - logger.info( - "Still missing %d events for room %r: %r...", - len(prevs - seen), pdu.room_id, list(prevs - seen)[:5] + if sent_to_us_directly and prevs - seen: + # If they have sent it to us directly, and the server + # isn't telling us about the auth events that it's + # made a message referencing, we explode + raise FederationError( + "ERROR", + 403, + ( + "Your server isn't divulging details about prev_events " + "referenced in this event." + ), + affected=pdu.event_id, ) - fetch_state = True + elif prevs - seen: + # Calculate the state of the previous events, and + # de-conflict them to find the current state. + state_groups = [] + auth_chains = set() + try: + # Get the state of the events we know about + ours = yield self.store.get_state_groups(pdu.room_id, list(seen)) + state_groups.append(ours) + + # Ask the remote server for the states we don't + # know about + for p in prevs - seen: + state, got_auth_chain = ( + yield self.replication_layer.get_state_for_room( + origin, pdu.room_id, p + ) + ) + auth_chains.update(got_auth_chain) + state_group = {(x.type, x.state_key): x.event_id for x in state} + state_groups.append(state_group) + + # Resolve any conflicting state + def fetch(ev_ids): + return self.store.get_events( + ev_ids, get_prev_content=False, check_redacted=False + ) - if fetch_state: - # We need to get the state at this event, since we haven't - # processed all the prev events. - logger.debug( - "_handle_new_pdu getting state for %s", - pdu.room_id - ) - try: - state, auth_chain = yield self.replication_layer.get_state_for_room( - origin, pdu.room_id, pdu.event_id, - ) - except: - logger.exception("Failed to get state for event: %s", pdu.event_id) + state_map = yield resolve_events_with_factory( + state_groups, {pdu.event_id: pdu}, fetch + ) + + state = (yield self.store.get_events(state_map.values())).values() + auth_chain = list(auth_chains) + except Exception: + raise FederationError( + "ERROR", + 403, + "We can't get valid state history.", + affected=pdu.event_id, + ) yield self._process_received_pdu( origin, @@ -227,8 +300,7 @@ class FederationHandler(BaseHandler): min_depth (int): Minimum depth of events to return. """ # We recalculate seen, since it may have changed. - have_seen = yield self.store.have_events(prevs) - seen = set(have_seen.keys()) + seen = yield self.store.have_seen_events(prevs) if not prevs - seen: return @@ -287,11 +359,17 @@ class FederationHandler(BaseHandler): for e in missing_events: logger.info("Handling found event %s", e.event_id) - yield self.on_receive_pdu( - origin, - e, - get_missing=False - ) + try: + yield self.on_receive_pdu( + origin, + e, + get_missing=False + ) + except FederationError as e: + if e.code == 403: + logger.warn("Event %s failed history check.") + else: + raise @log_function @defer.inlineCallbacks @@ -340,9 +418,7 @@ class FederationHandler(BaseHandler): if auth_chain: event_ids |= {e.event_id for e in auth_chain} - seen_ids = set( - (yield self.store.have_events(event_ids)).keys() - ) + seen_ids = yield self.store.have_seen_events(event_ids) if state and auth_chain is not None: # If we have any state or auth_chain given to us by the replication @@ -410,7 +486,10 @@ class FederationHandler(BaseHandler): # joined the room. Don't bother if the user is just # changing their profile info. newly_joined = True - prev_state_id = context.prev_state_ids.get( + + prev_state_ids = yield context.get_prev_state_ids(self.store) + + prev_state_id = prev_state_ids.get( (event.type, event.state_key) ) if prev_state_id: @@ -424,91 +503,21 @@ class FederationHandler(BaseHandler): user = UserID.from_string(event.state_key) yield user_joined_room(self.distributor, user, event.room_id) - @measure_func("_filter_events_for_server") - @defer.inlineCallbacks - def _filter_events_for_server(self, server_name, room_id, events): - event_to_state_ids = yield self.store.get_state_ids_for_events( - frozenset(e.event_id for e in events), - types=( - (EventTypes.RoomHistoryVisibility, ""), - (EventTypes.Member, None), - ) - ) - - # We only want to pull out member events that correspond to the - # server's domain. - - def check_match(id): - try: - return server_name == get_domain_from_id(id) - except: - return False - - # Parses mapping `event_id -> (type, state_key) -> state event_id` - # to get all state ids that we're interested in. - event_map = yield self.store.get_events([ - e_id - for key_to_eid in event_to_state_ids.values() - for key, e_id in key_to_eid.items() - if key[0] != EventTypes.Member or check_match(key[1]) - ]) - - event_to_state = { - e_id: { - key: event_map[inner_e_id] - for key, inner_e_id in key_to_eid.items() - if inner_e_id in event_map - } - for e_id, key_to_eid in event_to_state_ids.items() - } - - def redact_disallowed(event, state): - if not state: - return event - - history = state.get((EventTypes.RoomHistoryVisibility, ''), None) - if history: - visibility = history.content.get("history_visibility", "shared") - if visibility in ["invited", "joined"]: - # We now loop through all state events looking for - # membership states for the requesting server to determine - # if the server is either in the room or has been invited - # into the room. - for ev in state.values(): - if ev.type != EventTypes.Member: - continue - try: - domain = get_domain_from_id(ev.state_key) - except: - continue - - if domain != server_name: - continue - - memtype = ev.membership - if memtype == Membership.JOIN: - return event - elif memtype == Membership.INVITE: - if visibility == "invited": - return event - else: - return prune_event(event) - - return event - - defer.returnValue([ - redact_disallowed(e, event_to_state[e.event_id]) - for e in events - ]) - @log_function @defer.inlineCallbacks def backfill(self, dest, room_id, limit, extremities): """ Trigger a backfill request to `dest` for the given `room_id` - This will attempt to get more events from the remote. This may return - be successfull and still return no events if the other side has no new - events to offer. + This will attempt to get more events from the remote. If the other side + has no new events to offer, this will return an empty list. + + As the events are received, we check their signatures, and also do some + sanity-checking on them. If any of the backfilled events are invalid, + this method throws a SynapseError. + + TODO: make this more useful to distinguish failures of the remote + server from invalid events (there is probably no point in trying to + re-fetch invalid events from every other HS in the room.) """ if dest == self.server_name: raise SynapseError(400, "Can't backfill from self.") @@ -520,6 +529,16 @@ class FederationHandler(BaseHandler): extremities=extremities, ) + # ideally we'd sanity check the events here for excess prev_events etc, + # but it's hard to reject events at this point without completely + # breaking backfill in the same way that it is currently broken by + # events whose signature we cannot verify (#3121). + # + # So for now we accept the events anyway. #3124 tracks this. + # + # for ev in events: + # self._sanity_check_event(ev) + # Don't bother processing events we already have. seen_events = yield self.store.have_events_in_timeline( set(e.event_id for e in events) @@ -590,9 +609,10 @@ class FederationHandler(BaseHandler): missing_auth - failed_to_fetch ) - results = yield preserve_context_over_deferred(defer.gatherResults( + results = yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(self.replication_layer.get_pdu)( + logcontext.run_in_background( + self.replication_layer.get_pdu, [dest], event_id, outlier=True, @@ -612,7 +632,7 @@ class FederationHandler(BaseHandler): failed_to_fetch = missing_auth - set(auth_events) - seen_events = yield self.store.have_events( + seen_events = yield self.store.have_seen_events( set(auth_events.keys()) | set(state_events.keys()) ) @@ -702,9 +722,19 @@ class FederationHandler(BaseHandler): curr_state = yield self.state_handler.get_current_state(room_id) def get_domains_from_state(state): + """Get joined domains from state + + Args: + state (dict[tuple, FrozenEvent]): State map from type/state + key to event. + + Returns: + list[tuple[str, int]]: Returns a list of servers with the + lowest depth of their joins. Sorted by lowest depth first. + """ joined_users = [ (state_key, int(event.depth)) - for (e_type, state_key), event in state.items() + for (e_type, state_key), event in iteritems(state) if e_type == EventTypes.Member and event.membership == Membership.JOIN ] @@ -718,7 +748,7 @@ class FederationHandler(BaseHandler): joined_domains[dom] = min(d, old_d) else: joined_domains[dom] = d - except: + except Exception: pass return sorted(joined_domains.items(), key=lambda d: d[1]) @@ -738,7 +768,7 @@ class FederationHandler(BaseHandler): yield self.backfill( dom, room_id, limit=100, - extremities=[e for e in extremities.keys()] + extremities=extremities, ) # If this succeeded then we probably already have the # appropriate stuff. @@ -762,6 +792,9 @@ class FederationHandler(BaseHandler): except NotRetryingDestination as e: logger.info(e.message) continue + except FederationDeniedError as e: + logger.info(e) + continue except Exception as e: logger.exception( "Failed to backfill from %s because %s", @@ -784,38 +817,76 @@ class FederationHandler(BaseHandler): event_ids = list(extremities.keys()) logger.debug("calling resolve_state_groups in _maybe_backfill") - states = yield preserve_context_over_deferred(defer.gatherResults([ - preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e]) - for e in event_ids - ])) + resolve = logcontext.preserve_fn( + self.state_handler.resolve_state_groups_for_events + ) + states = yield logcontext.make_deferred_yieldable(defer.gatherResults( + [resolve(room_id, [e]) for e in event_ids], + consumeErrors=True, + )) + + # dict[str, dict[tuple, str]], a map from event_id to state map of + # event_ids. states = dict(zip(event_ids, [s.state for s in states])) state_map = yield self.store.get_events( - [e_id for ids in states.values() for e_id in ids], + [e_id for ids in itervalues(states) for e_id in itervalues(ids)], get_prev_content=False ) states = { key: { k: state_map[e_id] - for k, e_id in state_dict.items() + for k, e_id in iteritems(state_dict) if e_id in state_map - } for key, state_dict in states.items() + } for key, state_dict in iteritems(states) } for e_id, _ in sorted_extremeties_tuple: likely_domains = get_domains_from_state(states[e_id]) success = yield try_backfill([ - dom for dom in likely_domains + dom for dom, _ in likely_domains if dom not in tried_domains ]) if success: defer.returnValue(True) - tried_domains.update(likely_domains) + tried_domains.update(dom for dom, _ in likely_domains) defer.returnValue(False) + def _sanity_check_event(self, ev): + """ + Do some early sanity checks of a received event + + In particular, checks it doesn't have an excessive number of + prev_events or auth_events, which could cause a huge state resolution + or cascade of event fetches. + + Args: + ev (synapse.events.EventBase): event to be checked + + Returns: None + + Raises: + SynapseError if the event does not pass muster + """ + if len(ev.prev_events) > 20: + logger.warn("Rejecting event %s which has %i prev_events", + ev.event_id, len(ev.prev_events)) + raise SynapseError( + http_client.BAD_REQUEST, + "Too many prev_events", + ) + + if len(ev.auth_events) > 10: + logger.warn("Rejecting event %s which has %i auth_events", + ev.event_id, len(ev.auth_events)) + raise SynapseError( + http_client.BAD_REQUEST, + "Too many auth_events", + ) + @defer.inlineCallbacks def send_invite(self, target_host, event): """ Sends the invite to the remote server for signing. @@ -838,16 +909,6 @@ class FederationHandler(BaseHandler): [auth_id for auth_id, _ in event.auth_events], include_given=True ) - - for event in auth: - event.signatures.update( - compute_event_signature( - event, - self.hs.hostname, - self.hs.config.signing_key[0] - ) - ) - defer.returnValue([e for e in auth]) @log_function @@ -916,7 +977,7 @@ class FederationHandler(BaseHandler): room_creator_user_id="", is_public=False ) - except: + except Exception: # FIXME pass @@ -940,9 +1001,7 @@ class FederationHandler(BaseHandler): # lots of requests for missing prev_events which we do actually # have. Hence we fire off the deferred, but don't wait for it. - synapse.util.logcontext.preserve_fn(self._handle_queued_pdus)( - room_queue - ) + logcontext.run_in_background(self._handle_queued_pdus, room_queue) defer.returnValue(True) @@ -982,8 +1041,7 @@ class FederationHandler(BaseHandler): }) try: - message_handler = self.hs.get_handlers().message_handler - event, context = yield message_handler._create_new_client_event( + event, context = yield self.event_creation_handler.create_new_client_event( builder=builder, ) except AuthError as e: @@ -1051,13 +1109,15 @@ class FederationHandler(BaseHandler): user = UserID.from_string(event.state_key) yield user_joined_room(self.distributor, user, event.room_id) - state_ids = context.prev_state_ids.values() + prev_state_ids = yield context.get_prev_state_ids(self.store) + + state_ids = list(prev_state_ids.values()) auth_chain = yield self.store.get_auth_chain(state_ids) - state = yield self.store.get_events(context.prev_state_ids.values()) + state = yield self.store.get_events(list(prev_state_ids.values())) defer.returnValue({ - "state": state.values(), + "state": list(state.values()), "auth_chain": auth_chain, }) @@ -1069,10 +1129,23 @@ class FederationHandler(BaseHandler): """ event = pdu + if event.state_key is None: + raise SynapseError(400, "The invite event did not have a state key") + is_blocked = yield self.store.is_room_blocked(event.room_id) if is_blocked: raise SynapseError(403, "This room has been blocked on this server") + if self.hs.config.block_non_admin_invites: + raise SynapseError(403, "This server does not accept room invites") + + if not self.spam_checker.user_may_invite( + event.sender, event.state_key, event.room_id, + ): + raise SynapseError( + 403, "This user is not permitted to send invites to this server/user" + ) + membership = event.content.get("membership") if event.type != EventTypes.Member or membership != Membership.INVITE: raise SynapseError(400, "The event was not an m.room.member invite event") @@ -1081,12 +1154,16 @@ class FederationHandler(BaseHandler): if sender_domain != origin: raise SynapseError(400, "The invite event was not from the server sending it") - if event.state_key is None: - raise SynapseError(400, "The invite event did not have a state key") - if not self.is_mine_id(event.state_key): raise SynapseError(400, "The invite event must be for this server") + # block any attempts to invite the server notices mxid + if event.state_key == self._server_notices_mxid: + raise SynapseError( + http_client.FORBIDDEN, + "Cannot invite this user", + ) + event.internal_metadata.outlier = True event.internal_metadata.invite_from_remote = True @@ -1213,8 +1290,7 @@ class FederationHandler(BaseHandler): "state_key": user_id, }) - message_handler = self.hs.get_handlers().message_handler - event, context = yield message_handler._create_new_client_event( + event, context = yield self.event_creation_handler.create_new_client_event( builder=builder, ) @@ -1268,14 +1344,12 @@ class FederationHandler(BaseHandler): def get_state_for_pdu(self, room_id, event_id): """Returns the state at the event. i.e. not including said event. """ - yield run_on_reactor() - state_groups = yield self.store.get_state_groups( room_id, [event_id] ) if state_groups: - _, state = state_groups.items().pop() + _, state = list(iteritems(state_groups)).pop() results = { (e.type, e.state_key): e for e in state } @@ -1291,19 +1365,7 @@ class FederationHandler(BaseHandler): else: del results[(event.type, event.state_key)] - res = results.values() - for event in res: - # We sign these again because there was a bug where we - # incorrectly signed things the first time round - if self.is_mine_id(event.event_id): - event.signatures.update( - compute_event_signature( - event, - self.hs.hostname, - self.hs.config.signing_key[0] - ) - ) - + res = list(results.values()) defer.returnValue(res) else: defer.returnValue([]) @@ -1312,8 +1374,6 @@ class FederationHandler(BaseHandler): def get_state_ids_for_pdu(self, room_id, event_id): """Returns the state at the event. i.e. not including said event. """ - yield run_on_reactor() - state_groups = yield self.store.get_state_groups_ids( room_id, [event_id] ) @@ -1332,7 +1392,7 @@ class FederationHandler(BaseHandler): else: results.pop((event.type, event.state_key), None) - defer.returnValue(results.values()) + defer.returnValue(list(results.values())) else: defer.returnValue([]) @@ -1349,17 +1409,26 @@ class FederationHandler(BaseHandler): limit ) - events = yield self._filter_events_for_server(origin, room_id, events) + events = yield filter_events_for_server(self.store, origin, events) defer.returnValue(events) @defer.inlineCallbacks @log_function - def get_persisted_pdu(self, origin, event_id, do_auth=True): - """ Get a PDU from the database with given origin and id. + def get_persisted_pdu(self, origin, event_id): + """Get an event from the database for the given server. + + Args: + origin [str]: hostname of server which is requesting the event; we + will check that the server is allowed to see it. + event_id [str]: id of the event being requested Returns: - Deferred: Results in a `Pdu`. + Deferred[EventBase|None]: None if we know nothing about the event; + otherwise the (possibly-redacted) event. + + Raises: + AuthError if the server is not currently in the room """ event = yield self.store.get_event( event_id, @@ -1368,32 +1437,17 @@ class FederationHandler(BaseHandler): ) if event: - if self.is_mine_id(event.event_id): - # FIXME: This is a temporary work around where we occasionally - # return events slightly differently than when they were - # originally signed - event.signatures.update( - compute_event_signature( - event, - self.hs.hostname, - self.hs.config.signing_key[0] - ) - ) - - if do_auth: - in_room = yield self.auth.check_host_in_room( - event.room_id, - origin - ) - if not in_room: - raise AuthError(403, "Host not in room.") - - events = yield self._filter_events_for_server( - origin, event.room_id, [event] - ) - - event = events[0] + in_room = yield self.auth.check_host_in_room( + event.room_id, + origin + ) + if not in_room: + raise AuthError(403, "Host not in room.") + events = yield filter_events_for_server( + self.store, origin, [event], + ) + event = events[0] defer.returnValue(event) else: defer.returnValue(None) @@ -1412,22 +1466,33 @@ class FederationHandler(BaseHandler): auth_events=auth_events, ) - if not event.internal_metadata.is_outlier(): - yield self.action_generator.handle_push_actions_for_event( - event, context + try: + if not event.internal_metadata.is_outlier() and not backfilled: + yield self.action_generator.handle_push_actions_for_event( + event, context + ) + + event_stream_id, max_stream_id = yield self.store.persist_event( + event, + context=context, + backfilled=backfilled, ) + except: # noqa: E722, as we reraise the exception this is fine. + tp, value, tb = sys.exc_info() - event_stream_id, max_stream_id = yield self.store.persist_event( - event, - context=context, - backfilled=backfilled, - ) + logcontext.run_in_background( + self.store.remove_push_actions_from_staging, + event.event_id, + ) + + six.reraise(tp, value, tb) if not backfilled: # this intentionally does not yield: we don't care about the result # and don't need to wait for it. - preserve_fn(self.hs.get_pusherpool().on_new_notifications)( - event_stream_id, max_stream_id + logcontext.run_in_background( + self.pusher_pool.on_new_notifications, + event_stream_id, max_stream_id, ) defer.returnValue((context, event_stream_id, max_stream_id)) @@ -1439,22 +1504,23 @@ class FederationHandler(BaseHandler): a bunch of outliers, but not a chunk of individual events that depend on each other for state calculations. """ - contexts = yield preserve_context_over_deferred(defer.gatherResults( + contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(self._prep_event)( + logcontext.run_in_background( + self._prep_event, origin, ev_info["event"], state=ev_info.get("state"), auth_events=ev_info.get("auth_events"), ) for ev_info in event_infos - ] + ], consumeErrors=True, )) yield self.store.persist_events( [ (ev_info["event"], context) - for ev_info, context in itertools.izip(event_infos, contexts) + for ev_info, context in zip(event_infos, contexts) ], backfilled=backfilled, ) @@ -1574,8 +1640,9 @@ class FederationHandler(BaseHandler): ) if not auth_events: + prev_state_ids = yield context.get_prev_state_ids(self.store) auth_events_ids = yield self.auth.compute_auth_events( - event, context.prev_state_ids, for_verification=True, + event, prev_state_ids, for_verification=True, ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = { @@ -1605,7 +1672,7 @@ class FederationHandler(BaseHandler): context.rejected = RejectedReason.AUTH_ERROR - if event.type == EventTypes.GuestAccess: + if event.type == EventTypes.GuestAccess and not context.rejected: yield self.maybe_kick_guest_users(event) defer.returnValue(context) @@ -1635,15 +1702,6 @@ class FederationHandler(BaseHandler): local_auth_chain, remote_auth_chain ) - for event in ret["auth_chain"]: - event.signatures.update( - compute_event_signature( - event, - self.hs.hostname, - self.hs.config.signing_key[0] - ) - ) - logger.debug("on_query_auth returning: %s", ret) defer.returnValue(ret) @@ -1669,11 +1727,26 @@ class FederationHandler(BaseHandler): min_depth=min_depth, ) + missing_events = yield filter_events_for_server( + self.store, origin, missing_events, + ) + defer.returnValue(missing_events) @defer.inlineCallbacks @log_function def do_auth(self, origin, event, context, auth_events): + """ + + Args: + origin (str): + event (synapse.events.FrozenEvent): + context (synapse.events.snapshot.EventContext): + auth_events (dict[(str, str)->str]): + + Returns: + defer.Deferred[None] + """ # Check if we have all the auth events. current_state = set(e.event_id for e in auth_events.values()) event_auth_events = set(e_id for e_id, _ in event.auth_events) @@ -1684,7 +1757,8 @@ class FederationHandler(BaseHandler): event_key = None if event_auth_events - current_state: - have_events = yield self.store.have_events( + # TODO: can we use store.have_seen_events here instead? + have_events = yield self.store.get_seen_events_with_rejections( event_auth_events - current_state ) else: @@ -1707,12 +1781,12 @@ class FederationHandler(BaseHandler): origin, event.room_id, event.event_id ) - seen_remotes = yield self.store.have_events( + seen_remotes = yield self.store.have_seen_events( [e.event_id for e in remote_auth_chain] ) for e in remote_auth_chain: - if e.event_id in seen_remotes.keys(): + if e.event_id in seen_remotes: continue if e.event_id == event.event_id: @@ -1739,11 +1813,11 @@ class FederationHandler(BaseHandler): except AuthError: pass - have_events = yield self.store.have_events( + have_events = yield self.store.get_seen_events_with_rejections( [e_id for e_id, _ in event.auth_events] ) seen_events = set(have_events.keys()) - except: + except Exception: # FIXME: logger.exception("Failed to get auth chain") @@ -1756,18 +1830,18 @@ class FederationHandler(BaseHandler): # Do auth conflict res. logger.info("Different auth: %s", different_auth) - different_events = yield preserve_context_over_deferred(defer.gatherResults( - [ - preserve_fn(self.store.get_event)( + different_events = yield logcontext.make_deferred_yieldable( + defer.gatherResults([ + logcontext.run_in_background( + self.store.get_event, d, allow_none=True, allow_rejected=False, ) for d in different_auth if d in have_events and not have_events[d] - ], - consumeErrors=True - )).addErrback(unwrapFirstError) + ], consumeErrors=True) + ).addErrback(unwrapFirstError) if different_events: local_view = dict(auth_events) @@ -1777,7 +1851,7 @@ class FederationHandler(BaseHandler): }) new_state = self.state_handler.resolve_events( - [local_view.values(), remote_view.values()], + [list(local_view.values()), list(remote_view.values())], event ) @@ -1786,16 +1860,9 @@ class FederationHandler(BaseHandler): current_state = set(e.event_id for e in auth_events.values()) different_auth = event_auth_events - current_state - context.current_state_ids = dict(context.current_state_ids) - context.current_state_ids.update({ - k: a.event_id for k, a in auth_events.items() - if k != event_key - }) - context.prev_state_ids = dict(context.prev_state_ids) - context.prev_state_ids.update({ - k: a.event_id for k, a in auth_events.items() - }) - context.state_group = self.store.get_next_state_group() + yield self._update_context_for_auth_events( + event, context, auth_events, event_key, + ) if different_auth and not event.internal_metadata.is_outlier(): logger.info("Different auth after resolution: %s", different_auth) @@ -1815,9 +1882,10 @@ class FederationHandler(BaseHandler): break if do_resolution: + prev_state_ids = yield context.get_prev_state_ids(self.store) # 1. Get what we think is the auth chain. auth_ids = yield self.auth.compute_auth_events( - event, context.prev_state_ids + event, prev_state_ids ) local_auth_chain = yield self.store.get_auth_chain( auth_ids, include_given=True @@ -1832,13 +1900,13 @@ class FederationHandler(BaseHandler): local_auth_chain, ) - seen_remotes = yield self.store.have_events( + seen_remotes = yield self.store.have_seen_events( [e.event_id for e in result["auth_chain"]] ) # 3. Process any remote auth chain events we haven't seen. for ev in result["auth_chain"]: - if ev.event_id in seen_remotes.keys(): + if ev.event_id in seen_remotes: continue if ev.event_id == event.event_id: @@ -1868,23 +1936,16 @@ class FederationHandler(BaseHandler): except AuthError: pass - except: + except Exception: # FIXME: logger.exception("Failed to query auth chain") # 4. Look at rejects and their proofs. # TODO. - context.current_state_ids = dict(context.current_state_ids) - context.current_state_ids.update({ - k: a.event_id for k, a in auth_events.items() - if k != event_key - }) - context.prev_state_ids = dict(context.prev_state_ids) - context.prev_state_ids.update({ - k: a.event_id for k, a in auth_events.items() - }) - context.state_group = self.store.get_next_state_group() + yield self._update_context_for_auth_events( + event, context, auth_events, event_key, + ) try: self.auth.check(event, auth_events=auth_events) @@ -1893,6 +1954,58 @@ class FederationHandler(BaseHandler): raise e @defer.inlineCallbacks + def _update_context_for_auth_events(self, event, context, auth_events, + event_key): + """Update the state_ids in an event context after auth event resolution, + storing the changes as a new state group. + + Args: + event (Event): The event we're handling the context for + + context (synapse.events.snapshot.EventContext): event context + to be updated + + auth_events (dict[(str, str)->str]): Events to update in the event + context. + + event_key ((str, str)): (type, state_key) for the current event. + this will not be included in the current_state in the context. + """ + state_updates = { + k: a.event_id for k, a in iteritems(auth_events) + if k != event_key + } + current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = dict(current_state_ids) + + current_state_ids.update(state_updates) + + prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = dict(prev_state_ids) + + prev_state_ids.update({ + k: a.event_id for k, a in iteritems(auth_events) + }) + + # create a new state group as a delta from the existing one. + prev_group = context.state_group + state_group = yield self.store.store_state_group( + event.event_id, + event.room_id, + prev_group=prev_group, + delta_ids=state_updates, + current_state_ids=current_state_ids, + ) + + yield context.update_state( + state_group=state_group, + current_state_ids=current_state_ids, + prev_state_ids=prev_state_ids, + prev_group=prev_group, + delta_ids=state_updates, + ) + + @defer.inlineCallbacks def construct_auth_difference(self, local_auth, remote_auth): """ Given a local and remote auth chain, find the differences. This assumes that we have already processed all events in remote_auth @@ -1934,8 +2047,8 @@ class FederationHandler(BaseHandler): def get_next(it, opt=None): try: - return it.next() - except: + return next(it) + except Exception: return opt current_local = get_next(local_iter) @@ -2060,8 +2173,7 @@ class FederationHandler(BaseHandler): if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)): builder = self.event_builder_factory.new(event_dict) EventValidator().validate_new(builder) - message_handler = self.hs.get_handlers().message_handler - event, context = yield message_handler._create_new_client_event( + event, context = yield self.event_creation_handler.create_new_client_event( builder=builder ) @@ -2076,7 +2188,7 @@ class FederationHandler(BaseHandler): raise e yield self._check_signature(event, context) - member_handler = self.hs.get_handlers().room_member_handler + member_handler = self.hs.get_room_member_handler() yield member_handler.send_membership_event(None, event, context) else: destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id)) @@ -2089,10 +2201,17 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks @log_function def on_exchange_third_party_invite_request(self, origin, room_id, event_dict): + """Handle an exchange_third_party_invite request from a remote server + + The remote server will call this when it wants to turn a 3pid invite + into a normal m.room.member invite. + + Returns: + Deferred: resolves (to None) + """ builder = self.event_builder_factory.new(event_dict) - message_handler = self.hs.get_handlers().message_handler - event, context = yield message_handler._create_new_client_event( + event, context = yield self.event_creation_handler.create_new_client_event( builder=builder, ) @@ -2107,10 +2226,13 @@ class FederationHandler(BaseHandler): raise e yield self._check_signature(event, context) + # XXX we send the invite here, but send_membership_event also sends it, + # so we end up making two requests. I think this is redundant. returned_invite = yield self.send_invite(origin, event) # TODO: Make sure the signatures actually are correct. event.signatures.update(returned_invite.signatures) - member_handler = self.hs.get_handlers().room_member_handler + + member_handler = self.hs.get_room_member_handler() yield member_handler.send_membership_event(None, event, context) @defer.inlineCallbacks @@ -2120,7 +2242,8 @@ class FederationHandler(BaseHandler): event.content["third_party_invite"]["signed"]["token"] ) original_invite = None - original_invite_id = context.prev_state_ids.get(key) + prev_state_ids = yield context.get_prev_state_ids(self.store) + original_invite_id = prev_state_ids.get(key) if original_invite_id: original_invite = yield self.store.get_event( original_invite_id, allow_none=True @@ -2139,8 +2262,9 @@ class FederationHandler(BaseHandler): builder = self.event_builder_factory.new(event_dict) EventValidator().validate_new(builder) - message_handler = self.hs.get_handlers().message_handler - event, context = yield message_handler._create_new_client_event(builder=builder) + event, context = yield self.event_creation_handler.create_new_client_event( + builder=builder, + ) defer.returnValue((event, context)) @defer.inlineCallbacks @@ -2161,7 +2285,8 @@ class FederationHandler(BaseHandler): signed = event.content["third_party_invite"]["signed"] token = signed["token"] - invite_event_id = context.prev_state_ids.get( + prev_state_ids = yield context.get_prev_state_ids(self.store) + invite_event_id = prev_state_ids.get( (EventTypes.ThirdPartyInvite, token,) ) diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py new file mode 100644 index 0000000000..53e5e2648b --- /dev/null +++ b/synapse/handlers/groups_local.py @@ -0,0 +1,473 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from six import iteritems + +from twisted.internet import defer + +from synapse.api.errors import SynapseError +from synapse.types import get_domain_from_id + +logger = logging.getLogger(__name__) + + +def _create_rerouter(func_name): + """Returns a function that looks at the group id and calls the function + on federation or the local group server if the group is local + """ + def f(self, group_id, *args, **kwargs): + if self.is_mine_id(group_id): + return getattr(self.groups_server_handler, func_name)( + group_id, *args, **kwargs + ) + else: + destination = get_domain_from_id(group_id) + return getattr(self.transport_client, func_name)( + destination, group_id, *args, **kwargs + ) + return f + + +class GroupsLocalHandler(object): + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + self.room_list_handler = hs.get_room_list_handler() + self.groups_server_handler = hs.get_groups_server_handler() + self.transport_client = hs.get_federation_transport_client() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.keyring = hs.get_keyring() + self.is_mine_id = hs.is_mine_id + self.signing_key = hs.config.signing_key[0] + self.server_name = hs.hostname + self.notifier = hs.get_notifier() + self.attestations = hs.get_groups_attestation_signing() + + self.profile_handler = hs.get_profile_handler() + + # Ensure attestations get renewed + hs.get_groups_attestation_renewer() + + # The following functions merely route the query to the local groups server + # or federation depending on if the group is local or remote + + get_group_profile = _create_rerouter("get_group_profile") + update_group_profile = _create_rerouter("update_group_profile") + get_rooms_in_group = _create_rerouter("get_rooms_in_group") + + get_invited_users_in_group = _create_rerouter("get_invited_users_in_group") + + add_room_to_group = _create_rerouter("add_room_to_group") + update_room_in_group = _create_rerouter("update_room_in_group") + remove_room_from_group = _create_rerouter("remove_room_from_group") + + update_group_summary_room = _create_rerouter("update_group_summary_room") + delete_group_summary_room = _create_rerouter("delete_group_summary_room") + + update_group_category = _create_rerouter("update_group_category") + delete_group_category = _create_rerouter("delete_group_category") + get_group_category = _create_rerouter("get_group_category") + get_group_categories = _create_rerouter("get_group_categories") + + update_group_summary_user = _create_rerouter("update_group_summary_user") + delete_group_summary_user = _create_rerouter("delete_group_summary_user") + + update_group_role = _create_rerouter("update_group_role") + delete_group_role = _create_rerouter("delete_group_role") + get_group_role = _create_rerouter("get_group_role") + get_group_roles = _create_rerouter("get_group_roles") + + set_group_join_policy = _create_rerouter("set_group_join_policy") + + @defer.inlineCallbacks + def get_group_summary(self, group_id, requester_user_id): + """Get the group summary for a group. + + If the group is remote we check that the users have valid attestations. + """ + if self.is_mine_id(group_id): + res = yield self.groups_server_handler.get_group_summary( + group_id, requester_user_id + ) + else: + res = yield self.transport_client.get_group_summary( + get_domain_from_id(group_id), group_id, requester_user_id, + ) + + group_server_name = get_domain_from_id(group_id) + + # Loop through the users and validate the attestations. + chunk = res["users_section"]["users"] + valid_users = [] + for entry in chunk: + g_user_id = entry["user_id"] + attestation = entry.pop("attestation", {}) + try: + if get_domain_from_id(g_user_id) != group_server_name: + yield self.attestations.verify_attestation( + attestation, + group_id=group_id, + user_id=g_user_id, + server_name=get_domain_from_id(g_user_id), + ) + valid_users.append(entry) + except Exception as e: + logger.info("Failed to verify user is in group: %s", e) + + res["users_section"]["users"] = valid_users + + res["users_section"]["users"].sort(key=lambda e: e.get("order", 0)) + res["rooms_section"]["rooms"].sort(key=lambda e: e.get("order", 0)) + + # Add `is_publicised` flag to indicate whether the user has publicised their + # membership of the group on their profile + result = yield self.store.get_publicised_groups_for_user(requester_user_id) + is_publicised = group_id in result + + res.setdefault("user", {})["is_publicised"] = is_publicised + + defer.returnValue(res) + + @defer.inlineCallbacks + def create_group(self, group_id, user_id, content): + """Create a group + """ + + logger.info("Asking to create group with ID: %r", group_id) + + if self.is_mine_id(group_id): + res = yield self.groups_server_handler.create_group( + group_id, user_id, content + ) + local_attestation = None + remote_attestation = None + else: + local_attestation = self.attestations.create_attestation(group_id, user_id) + content["attestation"] = local_attestation + + content["user_profile"] = yield self.profile_handler.get_profile(user_id) + + res = yield self.transport_client.create_group( + get_domain_from_id(group_id), group_id, user_id, content, + ) + + remote_attestation = res["attestation"] + yield self.attestations.verify_attestation( + remote_attestation, + group_id=group_id, + user_id=user_id, + server_name=get_domain_from_id(group_id), + ) + + is_publicised = content.get("publicise", False) + token = yield self.store.register_user_group_membership( + group_id, user_id, + membership="join", + is_admin=True, + local_attestation=local_attestation, + remote_attestation=remote_attestation, + is_publicised=is_publicised, + ) + self.notifier.on_new_event( + "groups_key", token, users=[user_id], + ) + + defer.returnValue(res) + + @defer.inlineCallbacks + def get_users_in_group(self, group_id, requester_user_id): + """Get users in a group + """ + if self.is_mine_id(group_id): + res = yield self.groups_server_handler.get_users_in_group( + group_id, requester_user_id + ) + defer.returnValue(res) + + group_server_name = get_domain_from_id(group_id) + + res = yield self.transport_client.get_users_in_group( + get_domain_from_id(group_id), group_id, requester_user_id, + ) + + chunk = res["chunk"] + valid_entries = [] + for entry in chunk: + g_user_id = entry["user_id"] + attestation = entry.pop("attestation", {}) + try: + if get_domain_from_id(g_user_id) != group_server_name: + yield self.attestations.verify_attestation( + attestation, + group_id=group_id, + user_id=g_user_id, + server_name=get_domain_from_id(g_user_id), + ) + valid_entries.append(entry) + except Exception as e: + logger.info("Failed to verify user is in group: %s", e) + + res["chunk"] = valid_entries + + defer.returnValue(res) + + @defer.inlineCallbacks + def join_group(self, group_id, user_id, content): + """Request to join a group + """ + if self.is_mine_id(group_id): + yield self.groups_server_handler.join_group( + group_id, user_id, content + ) + local_attestation = None + remote_attestation = None + else: + local_attestation = self.attestations.create_attestation(group_id, user_id) + content["attestation"] = local_attestation + + res = yield self.transport_client.join_group( + get_domain_from_id(group_id), group_id, user_id, content, + ) + + remote_attestation = res["attestation"] + + yield self.attestations.verify_attestation( + remote_attestation, + group_id=group_id, + user_id=user_id, + server_name=get_domain_from_id(group_id), + ) + + # TODO: Check that the group is public and we're being added publically + is_publicised = content.get("publicise", False) + + token = yield self.store.register_user_group_membership( + group_id, user_id, + membership="join", + is_admin=False, + local_attestation=local_attestation, + remote_attestation=remote_attestation, + is_publicised=is_publicised, + ) + self.notifier.on_new_event( + "groups_key", token, users=[user_id], + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def accept_invite(self, group_id, user_id, content): + """Accept an invite to a group + """ + if self.is_mine_id(group_id): + yield self.groups_server_handler.accept_invite( + group_id, user_id, content + ) + local_attestation = None + remote_attestation = None + else: + local_attestation = self.attestations.create_attestation(group_id, user_id) + content["attestation"] = local_attestation + + res = yield self.transport_client.accept_group_invite( + get_domain_from_id(group_id), group_id, user_id, content, + ) + + remote_attestation = res["attestation"] + + yield self.attestations.verify_attestation( + remote_attestation, + group_id=group_id, + user_id=user_id, + server_name=get_domain_from_id(group_id), + ) + + # TODO: Check that the group is public and we're being added publically + is_publicised = content.get("publicise", False) + + token = yield self.store.register_user_group_membership( + group_id, user_id, + membership="join", + is_admin=False, + local_attestation=local_attestation, + remote_attestation=remote_attestation, + is_publicised=is_publicised, + ) + self.notifier.on_new_event( + "groups_key", token, users=[user_id], + ) + + defer.returnValue({}) + + @defer.inlineCallbacks + def invite(self, group_id, user_id, requester_user_id, config): + """Invite a user to a group + """ + content = { + "requester_user_id": requester_user_id, + "config": config, + } + if self.is_mine_id(group_id): + res = yield self.groups_server_handler.invite_to_group( + group_id, user_id, requester_user_id, content, + ) + else: + res = yield self.transport_client.invite_to_group( + get_domain_from_id(group_id), group_id, user_id, requester_user_id, + content, + ) + + defer.returnValue(res) + + @defer.inlineCallbacks + def on_invite(self, group_id, user_id, content): + """One of our users were invited to a group + """ + # TODO: Support auto join and rejection + + if not self.is_mine_id(user_id): + raise SynapseError(400, "User not on this server") + + local_profile = {} + if "profile" in content: + if "name" in content["profile"]: + local_profile["name"] = content["profile"]["name"] + if "avatar_url" in content["profile"]: + local_profile["avatar_url"] = content["profile"]["avatar_url"] + + token = yield self.store.register_user_group_membership( + group_id, user_id, + membership="invite", + content={"profile": local_profile, "inviter": content["inviter"]}, + ) + self.notifier.on_new_event( + "groups_key", token, users=[user_id], + ) + try: + user_profile = yield self.profile_handler.get_profile(user_id) + except Exception as e: + logger.warn("No profile for user %s: %s", user_id, e) + user_profile = {} + + defer.returnValue({"state": "invite", "user_profile": user_profile}) + + @defer.inlineCallbacks + def remove_user_from_group(self, group_id, user_id, requester_user_id, content): + """Remove a user from a group + """ + if user_id == requester_user_id: + token = yield self.store.register_user_group_membership( + group_id, user_id, + membership="leave", + ) + self.notifier.on_new_event( + "groups_key", token, users=[user_id], + ) + + # TODO: Should probably remember that we tried to leave so that we can + # retry if the group server is currently down. + + if self.is_mine_id(group_id): + res = yield self.groups_server_handler.remove_user_from_group( + group_id, user_id, requester_user_id, content, + ) + else: + content["requester_user_id"] = requester_user_id + res = yield self.transport_client.remove_user_from_group( + get_domain_from_id(group_id), group_id, requester_user_id, + user_id, content, + ) + + defer.returnValue(res) + + @defer.inlineCallbacks + def user_removed_from_group(self, group_id, user_id, content): + """One of our users was removed/kicked from a group + """ + # TODO: Check if user in group + token = yield self.store.register_user_group_membership( + group_id, user_id, + membership="leave", + ) + self.notifier.on_new_event( + "groups_key", token, users=[user_id], + ) + + @defer.inlineCallbacks + def get_joined_groups(self, user_id): + group_ids = yield self.store.get_joined_groups(user_id) + defer.returnValue({"groups": group_ids}) + + @defer.inlineCallbacks + def get_publicised_groups_for_user(self, user_id): + if self.hs.is_mine_id(user_id): + result = yield self.store.get_publicised_groups_for_user(user_id) + + # Check AS associated groups for this user - this depends on the + # RegExps in the AS registration file (under `users`) + for app_service in self.store.get_app_services(): + result.extend(app_service.get_groups_for_user(user_id)) + + defer.returnValue({"groups": result}) + else: + bulk_result = yield self.transport_client.bulk_get_publicised_groups( + get_domain_from_id(user_id), [user_id], + ) + result = bulk_result.get("users", {}).get(user_id) + # TODO: Verify attestations + defer.returnValue({"groups": result}) + + @defer.inlineCallbacks + def bulk_get_publicised_groups(self, user_ids, proxy=True): + destinations = {} + local_users = set() + + for user_id in user_ids: + if self.hs.is_mine_id(user_id): + local_users.add(user_id) + else: + destinations.setdefault( + get_domain_from_id(user_id), set() + ).add(user_id) + + if not proxy and destinations: + raise SynapseError(400, "Some user_ids are not local") + + results = {} + failed_results = [] + for destination, dest_user_ids in iteritems(destinations): + try: + r = yield self.transport_client.bulk_get_publicised_groups( + destination, list(dest_user_ids), + ) + results.update(r["users"]) + except Exception: + failed_results.extend(dest_user_ids) + + for uid in local_users: + results[uid] = yield self.store.get_publicised_groups_for_user( + uid + ) + + # Check AS associated groups for this user - this depends on the + # RegExps in the AS registration file (under `users`) + for app_service in self.store.get_app_services(): + results[uid].extend(app_service.get_groups_for_user(uid)) + + defer.returnValue({"users": results}) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 9efcdff1d6..8c8aedb2b8 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,17 +16,21 @@ # limitations under the License. """Utilities for interacting with Identity Servers""" + +import logging + +from canonicaljson import json + from twisted.internet import defer from synapse.api.errors import ( - MatrixCodeMessageException, CodeMessageException + CodeMessageException, + Codes, + MatrixCodeMessageException, + SynapseError, ) -from ._base import BaseHandler -from synapse.util.async import run_on_reactor -from synapse.api.errors import SynapseError, Codes -import json -import logging +from ._base import BaseHandler logger = logging.getLogger(__name__) @@ -36,6 +41,7 @@ class IdentityHandler(BaseHandler): super(IdentityHandler, self).__init__(hs) self.http_client = hs.get_simple_http_client() + self.federation_http_client = hs.get_http_client() self.trusted_id_servers = set(hs.config.trusted_third_party_id_servers) self.trust_any_id_server_just_for_testing_do_not_use = ( @@ -58,8 +64,6 @@ class IdentityHandler(BaseHandler): @defer.inlineCallbacks def threepid_from_creds(self, creds): - yield run_on_reactor() - if 'id_server' in creds: id_server = creds['id_server'] elif 'idServer' in creds: @@ -102,7 +106,6 @@ class IdentityHandler(BaseHandler): @defer.inlineCallbacks def bind_threepid(self, creds, mxid): - yield run_on_reactor() logger.debug("binding threepid %r to %s", creds, mxid) data = None @@ -137,9 +140,53 @@ class IdentityHandler(BaseHandler): defer.returnValue(data) @defer.inlineCallbacks - def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs): - yield run_on_reactor() + def unbind_threepid(self, mxid, threepid): + """ + Removes a binding from an identity server + Args: + mxid (str): Matrix user ID of binding to be removed + threepid (dict): Dict with medium & address of binding to be removed + + Returns: + Deferred[bool]: True on success, otherwise False + """ + logger.debug("unbinding threepid %r from %s", threepid, mxid) + if not self.trusted_id_servers: + logger.warn("Can't unbind threepid: no trusted ID servers set in config") + defer.returnValue(False) + + # We don't track what ID server we added 3pids on (perhaps we ought to) + # but we assume that any of the servers in the trusted list are in the + # same ID server federation, so we can pick any one of them to send the + # deletion request to. + id_server = next(iter(self.trusted_id_servers)) + + url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,) + content = { + "mxid": mxid, + "threepid": threepid, + } + headers = {} + # we abuse the federation http client to sign the request, but we have to send it + # using the normal http client since we don't want the SRV lookup and want normal + # 'browser-like' HTTPS. + self.federation_http_client.sign_request( + destination=None, + method='POST', + url_bytes='/_matrix/identity/api/v1/3pid/unbind'.encode('ascii'), + headers_dict=headers, + content=content, + destination_is=id_server, + ) + yield self.http_client.post_json_get_json( + url, + content, + headers, + ) + defer.returnValue(True) + @defer.inlineCallbacks + def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs): if not self._should_trust_id_server(id_server): raise SynapseError( 400, "Untrusted ID server '%s'" % id_server, @@ -174,8 +221,6 @@ class IdentityHandler(BaseHandler): self, id_server, country, phone_number, client_secret, send_attempt, **kwargs ): - yield run_on_reactor() - if not self._should_trust_id_server(id_server): raise SynapseError( 400, "Untrusted ID server '%s'" % id_server, diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 10f5f35a69..40e7580a61 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer from synapse.api.constants import EventTypes, Membership @@ -21,20 +23,15 @@ from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.handlers.presence import format_user_presence_state from synapse.streams.config import PaginationConfig -from synapse.types import ( - UserID, StreamToken, -) +from synapse.types import StreamToken, UserID from synapse.util import unwrapFirstError from synapse.util.async import concurrently_execute from synapse.util.caches.snapshot_cache import SnapshotCache -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.visibility import filter_events_for_client from ._base import BaseHandler -import logging - - logger = logging.getLogger(__name__) @@ -151,22 +148,25 @@ class InitialSyncHandler(BaseHandler): try: if event.membership == Membership.JOIN: room_end_token = now_token.room_key - deferred_room_state = self.state_handler.get_current_state( - event.room_id + deferred_room_state = run_in_background( + self.state_handler.get_current_state, + event.room_id, ) elif event.membership == Membership.LEAVE: room_end_token = "s%d" % (event.stream_ordering,) - deferred_room_state = self.store.get_state_for_events( - [event.event_id], None + deferred_room_state = run_in_background( + self.store.get_state_for_events, + [event.event_id], None, ) deferred_room_state.addCallback( lambda states: states[event.event_id] ) - (messages, token), current_state = yield preserve_context_over_deferred( + (messages, token), current_state = yield make_deferred_yieldable( defer.gatherResults( [ - preserve_fn(self.store.get_recent_events_for_room)( + run_in_background( + self.store.get_recent_events_for_room, event.room_id, limit=limit, end_token=room_end_token, @@ -180,8 +180,8 @@ class InitialSyncHandler(BaseHandler): self.store, user_id, messages ) - start_token = now_token.copy_and_replace("room_key", token[0]) - end_token = now_token.copy_and_replace("room_key", token[1]) + start_token = now_token.copy_and_replace("room_key", token) + end_token = now_token.copy_and_replace("room_key", room_end_token) time_now = self.clock.time_msec() d["messages"] = { @@ -214,7 +214,7 @@ class InitialSyncHandler(BaseHandler): }) d["account_data"] = account_data_events - except: + except Exception: logger.exception("Failed to get snapshot") yield concurrently_execute(handle_room, room_list, 10) @@ -324,8 +324,8 @@ class InitialSyncHandler(BaseHandler): self.store, user_id, messages, is_peeking=is_peeking ) - start_token = StreamToken.START.copy_and_replace("room_key", token[0]) - end_token = StreamToken.START.copy_and_replace("room_key", token[1]) + start_token = StreamToken.START.copy_and_replace("room_key", token) + end_token = StreamToken.START.copy_and_replace("room_key", stream_token) time_now = self.clock.time_msec() @@ -389,25 +389,28 @@ class InitialSyncHandler(BaseHandler): receipts = [] defer.returnValue(receipts) - presence, receipts, (messages, token) = yield defer.gatherResults( - [ - preserve_fn(get_presence)(), - preserve_fn(get_receipts)(), - preserve_fn(self.store.get_recent_events_for_room)( - room_id, - limit=limit, - end_token=now_token.room_key, - ) - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) + presence, receipts, (messages, token) = yield make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background(get_presence), + run_in_background(get_receipts), + run_in_background( + self.store.get_recent_events_for_room, + room_id, + limit=limit, + end_token=now_token.room_key, + ) + ], + consumeErrors=True, + ).addErrback(unwrapFirstError), + ) messages = yield filter_events_for_client( self.store, user_id, messages, is_peeking=is_peeking, ) - start_token = now_token.copy_and_replace("room_key", token[0]) - end_token = now_token.copy_and_replace("room_key", token[1]) + start_token = now_token.copy_and_replace("room_key", token) + end_token = now_token time_now = self.clock.time_msec() diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 24c9ffdb20..39d7724778 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2017 - 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,173 +13,185 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +import sys + +import six +from six import iteritems, itervalues, string_types + +from canonicaljson import encode_canonical_json, json from twisted.internet import defer +from twisted.internet.defer import succeed -from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.api.constants import MAX_DEPTH, EventTypes, Membership +from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError +from synapse.api.urls import ConsentURIBuilder from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator -from synapse.types import ( - UserID, RoomAlias, RoomStreamToken, -) -from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter -from synapse.util.logcontext import preserve_fn +from synapse.replication.http.send_event import send_event_to_master +from synapse.types import RoomAlias, UserID +from synapse.util.async import Linearizer +from synapse.util.frozenutils import frozendict_json_encoder +from synapse.util.logcontext import run_in_background from synapse.util.metrics import measure_func -from synapse.visibility import filter_events_for_client from ._base import BaseHandler -from canonicaljson import encode_canonical_json - -import logging -import random -import ujson - logger = logging.getLogger(__name__) -class MessageHandler(BaseHandler): +class MessageHandler(object): + """Contains some read only APIs to get state about a room + """ def __init__(self, hs): - super(MessageHandler, self).__init__(hs) - self.hs = hs - self.state = hs.get_state_handler() + self.auth = hs.get_auth() self.clock = hs.get_clock() - self.validator = EventValidator() - - self.pagination_lock = ReadWriteLock() - - # We arbitrarily limit concurrent event creation for a room to 5. - # This is to stop us from diverging history *too* much. - self.limiter = Limiter(max_count=5) - - self.action_generator = hs.get_action_generator() + self.state = hs.get_state_handler() + self.store = hs.get_datastore() @defer.inlineCallbacks - def purge_history(self, room_id, event_id): - event = yield self.store.get_event(event_id) + def get_room_data(self, user_id=None, room_id=None, + event_type=None, state_key="", is_guest=False): + """ Get data from a room. - if event.room_id != room_id: - raise SynapseError(400, "Event is for wrong room.") + Args: + event : The room path event + Returns: + The path data content. + Raises: + SynapseError if something went wrong. + """ + membership, membership_event_id = yield self.auth.check_in_room_or_world_readable( + room_id, user_id + ) - depth = event.depth + if membership == Membership.JOIN: + data = yield self.state.get_current_state( + room_id, event_type, state_key + ) + elif membership == Membership.LEAVE: + key = (event_type, state_key) + room_state = yield self.store.get_state_for_events( + [membership_event_id], [key] + ) + data = room_state[membership_event_id].get(key) - with (yield self.pagination_lock.write(room_id)): - yield self.store.delete_old_state(room_id, depth) + defer.returnValue(data) @defer.inlineCallbacks - def get_messages(self, requester, room_id=None, pagin_config=None, - as_client_event=True, event_filter=None): - """Get messages in a room. + def get_state_events(self, user_id, room_id, is_guest=False): + """Retrieve all state events for a given room. If the user is + joined to the room then return the current state. If the user has + left the room return the state events from when they left. Args: - requester (Requester): The user requesting messages. - room_id (str): The room they want messages from. - pagin_config (synapse.api.streams.PaginationConfig): The pagination - config rules to apply, if any. - as_client_event (bool): True to get events in client-server format. - event_filter (Filter): Filter to apply to results or None + user_id(str): The user requesting state events. + room_id(str): The room ID to get all state events from. Returns: - dict: Pagination API results + A list of dicts representing state events. [{}, {}, {}] """ - user_id = requester.user.to_string() + membership, membership_event_id = yield self.auth.check_in_room_or_world_readable( + room_id, user_id + ) - if pagin_config.from_token: - room_token = pagin_config.from_token.room_key - else: - pagin_config.from_token = ( - yield self.hs.get_event_sources().get_current_token_for_room( - room_id=room_id - ) + if membership == Membership.JOIN: + room_state = yield self.state.get_current_state(room_id) + elif membership == Membership.LEAVE: + room_state = yield self.store.get_state_for_events( + [membership_event_id], None ) - room_token = pagin_config.from_token.room_key - - room_token = RoomStreamToken.parse(room_token) + room_state = room_state[membership_event_id] - pagin_config.from_token = pagin_config.from_token.copy_and_replace( - "room_key", str(room_token) + now = self.clock.time_msec() + defer.returnValue( + [serialize_event(c, now) for c in room_state.values()] ) - source_config = pagin_config.get_source_config("room") + @defer.inlineCallbacks + def get_joined_members(self, requester, room_id): + """Get all the joined members in the room and their profile information. + + If the user has left the room return the state events from when they left. - with (yield self.pagination_lock.read(room_id)): - membership, member_event_id = yield self._check_in_room_or_world_readable( + Args: + requester(Requester): The user requesting state events. + room_id(str): The room ID to get all state events from. + Returns: + A dict of user_id to profile info + """ + user_id = requester.user.to_string() + if not requester.app_service: + # We check AS auth after fetching the room membership, as it + # requires us to pull out all joined members anyway. + membership, _ = yield self.auth.check_in_room_or_world_readable( room_id, user_id ) + if membership != Membership.JOIN: + raise NotImplementedError( + "Getting joined members after leaving is not implemented" + ) - if source_config.direction == 'b': - # if we're going backwards, we might need to backfill. This - # requires that we have a topo token. - if room_token.topological: - max_topo = room_token.topological - else: - max_topo = yield self.store.get_max_topological_token( - room_id, room_token.stream - ) + users_with_profile = yield self.state.get_current_user_in_room(room_id) - if membership == Membership.LEAVE: - # If they have left the room then clamp the token to be before - # they left the room, to save the effort of loading from the - # database. - leave_token = yield self.store.get_topological_token_for_event( - member_event_id - ) - leave_token = RoomStreamToken.parse(leave_token) - if leave_token.topological < max_topo: - source_config.from_key = str(leave_token) + # If this is an AS, double check that they are allowed to see the members. + # This can either be because the AS user is in the room or because there + # is a user in the room that the AS is "interested in" + if requester.app_service and user_id not in users_with_profile: + for uid in users_with_profile: + if requester.app_service.is_interested_in_user(uid): + break + else: + # Loop fell through, AS has no interested users in room + raise AuthError(403, "Appservice not in room") - yield self.hs.get_handlers().federation_handler.maybe_backfill( - room_id, max_topo - ) + defer.returnValue({ + user_id: { + "avatar_url": profile.avatar_url, + "display_name": profile.display_name, + } + for user_id, profile in iteritems(users_with_profile) + }) - events, next_key = yield self.store.paginate_room_events( - room_id=room_id, - from_key=source_config.from_key, - to_key=source_config.to_key, - direction=source_config.direction, - limit=source_config.limit, - event_filter=event_filter, - ) - next_token = pagin_config.from_token.copy_and_replace( - "room_key", next_key - ) +class EventCreationHandler(object): + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.state = hs.get_state_handler() + self.clock = hs.get_clock() + self.validator = EventValidator() + self.profile_handler = hs.get_profile_handler() + self.event_builder_factory = hs.get_event_builder_factory() + self.server_name = hs.hostname + self.ratelimiter = hs.get_ratelimiter() + self.notifier = hs.get_notifier() + self.config = hs.config - if not events: - defer.returnValue({ - "chunk": [], - "start": pagin_config.from_token.to_string(), - "end": next_token.to_string(), - }) - - if event_filter: - events = event_filter.filter(events) - - events = yield filter_events_for_client( - self.store, - user_id, - events, - is_peeking=(member_event_id is None), - ) + self.http_client = hs.get_simple_http_client() - time_now = self.clock.time_msec() + # This is only used to get at ratelimit function, and maybe_kick_guest_users + self.base_handler = BaseHandler(hs) - chunk = { - "chunk": [ - serialize_event(e, time_now, as_client_event) - for e in events - ], - "start": pagin_config.from_token.to_string(), - "end": next_token.to_string(), - } + self.pusher_pool = hs.get_pusherpool() - defer.returnValue(chunk) + # We arbitrarily limit concurrent event creation for a room to 5. + # This is to stop us from diverging history *too* much. + self.limiter = Linearizer(max_count=5, name="room_event_creation_limit") + + self.action_generator = hs.get_action_generator() + + self.spam_checker = hs.get_spam_checker() + + if self.config.block_events_without_consent_error is not None: + self._consent_uri_builder = ConsentURIBuilder(self.config) @defer.inlineCallbacks def create_event(self, requester, event_dict, token_id=None, txn_id=None, - prev_event_ids=None): + prev_events_and_hashes=None): """ Given a dict from a client, create a new event. @@ -192,50 +205,143 @@ class MessageHandler(BaseHandler): event_dict (dict): An entire event token_id (str) txn_id (str) - prev_event_ids (list): The prev event ids to use when creating the event + + prev_events_and_hashes (list[(str, dict[str, str], int)]|None): + the forward extremities to use as the prev_events for the + new event. For each event, a tuple of (event_id, hashes, depth) + where *hashes* is a map from algorithm to hash. + + If None, they will be requested from the database. Returns: Tuple of created event (FrozenEvent), Context """ builder = self.event_builder_factory.new(event_dict) - with (yield self.limiter.queue(builder.room_id)): - self.validator.validate_new(builder) - - if builder.type == EventTypes.Member: - membership = builder.content.get("membership", None) - target = UserID.from_string(builder.state_key) - - if membership in {Membership.JOIN, Membership.INVITE}: - # If event doesn't include a display name, add one. - profile = self.hs.get_handlers().profile_handler - content = builder.content - - try: - if "displayname" not in content: - content["displayname"] = yield profile.get_displayname(target) - if "avatar_url" not in content: - content["avatar_url"] = yield profile.get_avatar_url(target) - except Exception as e: - logger.info( - "Failed to get profile information for %r: %s", - target, e - ) + self.validator.validate_new(builder) + + if builder.type == EventTypes.Member: + membership = builder.content.get("membership", None) + target = UserID.from_string(builder.state_key) + + if membership in {Membership.JOIN, Membership.INVITE}: + # If event doesn't include a display name, add one. + profile = self.profile_handler + content = builder.content + + try: + if "displayname" not in content: + content["displayname"] = yield profile.get_displayname(target) + if "avatar_url" not in content: + content["avatar_url"] = yield profile.get_avatar_url(target) + except Exception as e: + logger.info( + "Failed to get profile information for %r: %s", + target, e + ) - if token_id is not None: - builder.internal_metadata.token_id = token_id + is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester) + if not is_exempt: + yield self.assert_accepted_privacy_policy(requester) - if txn_id is not None: - builder.internal_metadata.txn_id = txn_id + if token_id is not None: + builder.internal_metadata.token_id = token_id - event, context = yield self._create_new_client_event( - builder=builder, - requester=requester, - prev_event_ids=prev_event_ids, - ) + if txn_id is not None: + builder.internal_metadata.txn_id = txn_id + + event, context = yield self.create_new_client_event( + builder=builder, + requester=requester, + prev_events_and_hashes=prev_events_and_hashes, + ) defer.returnValue((event, context)) + def _is_exempt_from_privacy_policy(self, builder, requester): + """"Determine if an event to be sent is exempt from having to consent + to the privacy policy + + Args: + builder (synapse.events.builder.EventBuilder): event being created + requester (Requster): user requesting this event + + Returns: + Deferred[bool]: true if the event can be sent without the user + consenting + """ + # the only thing the user can do is join the server notices room. + if builder.type == EventTypes.Member: + membership = builder.content.get("membership", None) + if membership == Membership.JOIN: + return self._is_server_notices_room(builder.room_id) + elif membership == Membership.LEAVE: + # the user is always allowed to leave (but not kick people) + return builder.state_key == requester.user.to_string() + return succeed(False) + + @defer.inlineCallbacks + def _is_server_notices_room(self, room_id): + if self.config.server_notices_mxid is None: + defer.returnValue(False) + user_ids = yield self.store.get_users_in_room(room_id) + defer.returnValue(self.config.server_notices_mxid in user_ids) + + @defer.inlineCallbacks + def assert_accepted_privacy_policy(self, requester): + """Check if a user has accepted the privacy policy + + Called when the given user is about to do something that requires + privacy consent. We see if the user is exempt and otherwise check that + they have given consent. If they have not, a ConsentNotGiven error is + raised. + + Args: + requester (synapse.types.Requester): + The user making the request + + Returns: + Deferred[None]: returns normally if the user has consented or is + exempt + + Raises: + ConsentNotGivenError: if the user has not given consent yet + """ + if self.config.block_events_without_consent_error is None: + return + + # exempt AS users from needing consent + if requester.app_service is not None: + return + + user_id = requester.user.to_string() + + # exempt the system notices user + if ( + self.config.server_notices_mxid is not None and + user_id == self.config.server_notices_mxid + ): + return + + u = yield self.store.get_user_by_id(user_id) + assert u is not None + if u["appservice_id"] is not None: + # users registered by an appservice are exempt + return + if u["consent_version"] == self.config.user_consent_version: + return + + consent_uri = self._consent_uri_builder.build_user_consent_uri( + requester.user.localpart, + ) + msg = self.config.block_events_without_consent_error % { + 'consent_uri': consent_uri, + } + raise ConsentNotGivenError( + msg=msg, + consent_uri=consent_uri, + ) + @defer.inlineCallbacks def send_nonmember_event(self, requester, event, context, ratelimit=True): """ @@ -253,11 +359,6 @@ class MessageHandler(BaseHandler): "Tried to send member event through non-member codepath" ) - # We check here if we are currently being rate limited, so that we - # don't do unnecessary work. We check again just before we actually - # send the event. - yield self.ratelimit(requester, update=False) - user = UserID.from_string(event.sender) assert self.hs.is_mine(user), "User must be our own: %s" % (user,) @@ -274,12 +375,6 @@ class MessageHandler(BaseHandler): ratelimit=ratelimit, ) - if event.type == EventTypes.Message: - presence = self.hs.get_presence_handler() - # We don't want to block sending messages on any presence code. This - # matters as sometimes presence code can take a while. - preserve_fn(presence.bump_presence_active_time)(user) - @defer.inlineCallbacks def deduplicate_state_event(self, event, context): """ @@ -288,7 +383,8 @@ class MessageHandler(BaseHandler): If so, returns the version of the event in context. Otherwise, returns None. """ - prev_event_id = context.prev_state_ids.get((event.type, event.state_key)) + prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_event_id = prev_state_ids.get((event.type, event.state_key)) prev_event = yield self.store.get_event(prev_event_id, allow_none=True) if not prev_event: return @@ -313,145 +409,85 @@ class MessageHandler(BaseHandler): See self.create_event and self.send_nonmember_event. """ - event, context = yield self.create_event( - requester, - event_dict, - token_id=requester.access_token_id, - txn_id=txn_id - ) - yield self.send_nonmember_event( - requester, - event, - context, - ratelimit=ratelimit, - ) - defer.returnValue(event) - @defer.inlineCallbacks - def get_room_data(self, user_id=None, room_id=None, - event_type=None, state_key="", is_guest=False): - """ Get data from a room. - - Args: - event : The room path event - Returns: - The path data content. - Raises: - SynapseError if something went wrong. - """ - membership, membership_event_id = yield self._check_in_room_or_world_readable( - room_id, user_id - ) - - if membership == Membership.JOIN: - data = yield self.state_handler.get_current_state( - room_id, event_type, state_key - ) - elif membership == Membership.LEAVE: - key = (event_type, state_key) - room_state = yield self.store.get_state_for_events( - [membership_event_id], [key] + # We limit the number of concurrent event sends in a room so that we + # don't fork the DAG too much. If we don't limit then we can end up in + # a situation where event persistence can't keep up, causing + # extremities to pile up, which in turn leads to state resolution + # taking longer. + with (yield self.limiter.queue(event_dict["room_id"])): + event, context = yield self.create_event( + requester, + event_dict, + token_id=requester.access_token_id, + txn_id=txn_id ) - data = room_state[membership_event_id].get(key) - defer.returnValue(data) + spam_error = self.spam_checker.check_event_for_spam(event) + if spam_error: + if not isinstance(spam_error, string_types): + spam_error = "Spam is not permitted here" + raise SynapseError( + 403, spam_error, Codes.FORBIDDEN + ) - @defer.inlineCallbacks - def _check_in_room_or_world_readable(self, room_id, user_id): - try: - # check_user_was_in_room will return the most recent membership - # event for the user if: - # * The user is a non-guest user, and was ever in the room - # * The user is a guest user, and has joined the room - # else it will throw. - member_event = yield self.auth.check_user_was_in_room(room_id, user_id) - defer.returnValue((member_event.membership, member_event.event_id)) - return - except AuthError: - visibility = yield self.state_handler.get_current_state( - room_id, EventTypes.RoomHistoryVisibility, "" - ) - if ( - visibility and - visibility.content["history_visibility"] == "world_readable" - ): - defer.returnValue((Membership.JOIN, None)) - return - raise AuthError( - 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN + yield self.send_nonmember_event( + requester, + event, + context, + ratelimit=ratelimit, ) + defer.returnValue(event) + @measure_func("create_new_client_event") @defer.inlineCallbacks - def get_state_events(self, user_id, room_id, is_guest=False): - """Retrieve all state events for a given room. If the user is - joined to the room then return the current state. If the user has - left the room return the state events from when they left. + def create_new_client_event(self, builder, requester=None, + prev_events_and_hashes=None): + """Create a new event for a local client Args: - user_id(str): The user requesting state events. - room_id(str): The room ID to get all state events from. + builder (EventBuilder): + + requester (synapse.types.Requester|None): + + prev_events_and_hashes (list[(str, dict[str, str], int)]|None): + the forward extremities to use as the prev_events for the + new event. For each event, a tuple of (event_id, hashes, depth) + where *hashes* is a map from algorithm to hash. + + If None, they will be requested from the database. + Returns: - A list of dicts representing state events. [{}, {}, {}] + Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)] """ - membership, membership_event_id = yield self._check_in_room_or_world_readable( - room_id, user_id - ) - if membership == Membership.JOIN: - room_state = yield self.state_handler.get_current_state(room_id) - elif membership == Membership.LEAVE: - room_state = yield self.store.get_state_for_events( - [membership_event_id], None + if prev_events_and_hashes is not None: + assert len(prev_events_and_hashes) <= 10, \ + "Attempting to create an event with %i prev_events" % ( + len(prev_events_and_hashes), ) - room_state = room_state[membership_event_id] - - now = self.clock.time_msec() - defer.returnValue( - [serialize_event(c, now) for c in room_state.values()] - ) - - @measure_func("_create_new_client_event") - @defer.inlineCallbacks - def _create_new_client_event(self, builder, requester=None, prev_event_ids=None): - if prev_event_ids: - prev_events = yield self.store.add_event_hashes(prev_event_ids) - prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids) - depth = prev_max_depth + 1 else: - latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room( - builder.room_id, - ) - - # We want to limit the max number of prev events we point to in our - # new event - if len(latest_ret) > 10: - # Sort by reverse depth, so we point to the most recent. - latest_ret.sort(key=lambda a: -a[2]) - new_latest_ret = latest_ret[:5] - - # We also randomly point to some of the older events, to make - # sure that we don't completely ignore the older events. - if latest_ret[5:]: - sample_size = min(5, len(latest_ret[5:])) - new_latest_ret.extend(random.sample(latest_ret[5:], sample_size)) - latest_ret = new_latest_ret - - if latest_ret: - depth = max([d for _, _, d in latest_ret]) + 1 - else: - depth = 1 + prev_events_and_hashes = \ + yield self.store.get_prev_events_for_room(builder.room_id) + + if prev_events_and_hashes: + depth = max([d for _, _, d in prev_events_and_hashes]) + 1 + # we cap depth of generated events, to ensure that they are not + # rejected by other servers (and so that they can be persisted in + # the db) + depth = min(depth, MAX_DEPTH) + else: + depth = 1 - prev_events = [ - (event_id, prev_hashes) - for event_id, prev_hashes, _ in latest_ret - ] + prev_events = [ + (event_id, prev_hashes) + for event_id, prev_hashes, _ in prev_events_and_hashes + ] builder.prev_events = prev_events builder.depth = depth - state_handler = self.state_handler - - context = yield state_handler.compute_event_context(builder) + context = yield self.state.compute_event_context(builder) if requester: context.app_service = requester.app_service @@ -470,8 +506,8 @@ class MessageHandler(BaseHandler): event = builder.build() logger.debug( - "Created event %s with state: %s", - event.event_id, context.prev_state_ids, + "Created event %s", + event.event_id, ) defer.returnValue( @@ -486,12 +522,21 @@ class MessageHandler(BaseHandler): event, context, ratelimit=True, - extra_users=[] + extra_users=[], ): - # We now need to go and hit out to wherever we need to hit out to. + """Processes a new event. This includes checking auth, persisting it, + notifying users, sending to remote servers, etc. - if ratelimit: - yield self.ratelimit(requester) + If called from a worker will hit out to the master process for final + processing. + + Args: + requester (Requester) + event (FrozenEvent) + context (EventContext) + ratelimit (bool) + extra_users (list(UserID)): Any extra users to notify about event + """ try: yield self.auth.check_from_context(event, context) @@ -501,13 +546,72 @@ class MessageHandler(BaseHandler): # Ensure that we can round trip before trying to persist in db try: - dump = ujson.dumps(event.content) - ujson.loads(dump) - except: + dump = frozendict_json_encoder.encode(event.content) + json.loads(dump) + except Exception: logger.exception("Failed to encode content: %r", event.content) raise - yield self.maybe_kick_guest_users(event, context) + yield self.action_generator.handle_push_actions_for_event( + event, context + ) + + try: + # If we're a worker we need to hit out to the master. + if self.config.worker_app: + yield send_event_to_master( + clock=self.hs.get_clock(), + store=self.store, + client=self.http_client, + host=self.config.worker_replication_host, + port=self.config.worker_replication_http_port, + requester=requester, + event=event, + context=context, + ratelimit=ratelimit, + extra_users=extra_users, + ) + return + + yield self.persist_and_notify_client_event( + requester, + event, + context, + ratelimit=ratelimit, + extra_users=extra_users, + ) + except: # noqa: E722, as we reraise the exception this is fine. + # Ensure that we actually remove the entries in the push actions + # staging area, if we calculated them. + tp, value, tb = sys.exc_info() + + run_in_background( + self.store.remove_push_actions_from_staging, + event.event_id, + ) + + six.reraise(tp, value, tb) + + @defer.inlineCallbacks + def persist_and_notify_client_event( + self, + requester, + event, + context, + ratelimit=True, + extra_users=[], + ): + """Called when we have fully built the event, have already + calculated the push actions for the event, and checked auth. + + This should only be run on master. + """ + assert not self.config.worker_app + + if ratelimit: + yield self.base_handler.ratelimit(requester) + + yield self.base_handler.maybe_kick_guest_users(event, context) if event.type == EventTypes.CanonicalAlias: # Check the alias is acually valid (at this time at least) @@ -535,9 +639,11 @@ class MessageHandler(BaseHandler): e.sender == event.sender ) + current_state_ids = yield context.get_current_state_ids(self.store) + state_to_include_ids = [ e_id - for k, e_id in context.current_state_ids.iteritems() + for k, e_id in iteritems(current_state_ids) if k[0] in self.hs.config.room_invite_state_types or k == (EventTypes.Member, event.sender) ] @@ -551,7 +657,7 @@ class MessageHandler(BaseHandler): "content": e.content, "sender": e.sender, } - for e in state_to_include.itervalues() + for e in itervalues(state_to_include) ] invitee = UserID.from_string(event.state_key) @@ -573,8 +679,9 @@ class MessageHandler(BaseHandler): ) if event.type == EventTypes.Redaction: + prev_state_ids = yield context.get_prev_state_ids(self.store) auth_events_ids = yield self.auth.compute_auth_events( - event, context.prev_state_ids, for_verification=True, + event, prev_state_ids, for_verification=True, ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = { @@ -594,15 +701,13 @@ class MessageHandler(BaseHandler): "You don't have permission to redact events" ) - if event.type == EventTypes.Create and context.prev_state_ids: - raise AuthError( - 403, - "Changing the room create event is forbidden", - ) - - yield self.action_generator.handle_push_actions_for_event( - event, context - ) + if event.type == EventTypes.Create: + prev_state_ids = yield context.get_prev_state_ids(self.store) + if prev_state_ids: + raise AuthError( + 403, + "Changing the room create event is forbidden", + ) (event_stream_id, max_stream_id) = yield self.store.persist_event( event, context=context @@ -610,16 +715,31 @@ class MessageHandler(BaseHandler): # this intentionally does not yield: we don't care about the result # and don't need to wait for it. - preserve_fn(self.hs.get_pusherpool().on_new_notifications)( + run_in_background( + self.pusher_pool.on_new_notifications, event_stream_id, max_stream_id ) - @defer.inlineCallbacks def _notify(): - yield run_on_reactor() - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=extra_users - ) + try: + self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, + extra_users=extra_users + ) + except Exception: + logger.exception("Error notifying about new room event") + + run_in_background(_notify) - preserve_fn(_notify)() + if event.type == EventTypes.Message: + # We don't want to block sending messages on any presence code. This + # matters as sometimes presence code can take a while. + run_in_background(self._bump_active_time, requester.user) + + @defer.inlineCallbacks + def _bump_active_time(self, user): + try: + presence = self.hs.get_presence_handler() + yield presence.bump_presence_active_time(user) + except Exception: + logger.exception("Error bumping presence active time") diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py new file mode 100644 index 0000000000..b2849783ed --- /dev/null +++ b/synapse/handlers/pagination.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2017 - 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.internet import defer +from twisted.python.failure import Failure + +from synapse.api.constants import Membership +from synapse.api.errors import SynapseError +from synapse.events.utils import serialize_event +from synapse.types import RoomStreamToken +from synapse.util.async import ReadWriteLock +from synapse.util.logcontext import run_in_background +from synapse.util.stringutils import random_string +from synapse.visibility import filter_events_for_client + +logger = logging.getLogger(__name__) + + +class PurgeStatus(object): + """Object tracking the status of a purge request + + This class contains information on the progress of a purge request, for + return by get_purge_status. + + Attributes: + status (int): Tracks whether this request has completed. One of + STATUS_{ACTIVE,COMPLETE,FAILED} + """ + + STATUS_ACTIVE = 0 + STATUS_COMPLETE = 1 + STATUS_FAILED = 2 + + STATUS_TEXT = { + STATUS_ACTIVE: "active", + STATUS_COMPLETE: "complete", + STATUS_FAILED: "failed", + } + + def __init__(self): + self.status = PurgeStatus.STATUS_ACTIVE + + def asdict(self): + return { + "status": PurgeStatus.STATUS_TEXT[self.status] + } + + +class PaginationHandler(object): + """Handles pagination and purge history requests. + + These are in the same handler due to the fact we need to block clients + paginating during a purge. + """ + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + self.pagination_lock = ReadWriteLock() + self._purges_in_progress_by_room = set() + # map from purge id to PurgeStatus + self._purges_by_id = {} + + def start_purge_history(self, room_id, token, + delete_local_events=False): + """Start off a history purge on a room. + + Args: + room_id (str): The room to purge from + + token (str): topological token to delete events before + delete_local_events (bool): True to delete local events as well as + remote ones + + Returns: + str: unique ID for this purge transaction. + """ + if room_id in self._purges_in_progress_by_room: + raise SynapseError( + 400, + "History purge already in progress for %s" % (room_id, ), + ) + + purge_id = random_string(16) + + # we log the purge_id here so that it can be tied back to the + # request id in the log lines. + logger.info("[purge] starting purge_id %s", purge_id) + + self._purges_by_id[purge_id] = PurgeStatus() + run_in_background( + self._purge_history, + purge_id, room_id, token, delete_local_events, + ) + return purge_id + + @defer.inlineCallbacks + def _purge_history(self, purge_id, room_id, token, + delete_local_events): + """Carry out a history purge on a room. + + Args: + purge_id (str): The id for this purge + room_id (str): The room to purge from + token (str): topological token to delete events before + delete_local_events (bool): True to delete local events as well as + remote ones + + Returns: + Deferred + """ + self._purges_in_progress_by_room.add(room_id) + try: + with (yield self.pagination_lock.write(room_id)): + yield self.store.purge_history( + room_id, token, delete_local_events, + ) + logger.info("[purge] complete") + self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE + except Exception: + logger.error("[purge] failed: %s", Failure().getTraceback().rstrip()) + self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED + finally: + self._purges_in_progress_by_room.discard(room_id) + + # remove the purge from the list 24 hours after it completes + def clear_purge(): + del self._purges_by_id[purge_id] + self.hs.get_reactor().callLater(24 * 3600, clear_purge) + + def get_purge_status(self, purge_id): + """Get the current status of an active purge + + Args: + purge_id (str): purge_id returned by start_purge_history + + Returns: + PurgeStatus|None + """ + return self._purges_by_id.get(purge_id) + + @defer.inlineCallbacks + def get_messages(self, requester, room_id=None, pagin_config=None, + as_client_event=True, event_filter=None): + """Get messages in a room. + + Args: + requester (Requester): The user requesting messages. + room_id (str): The room they want messages from. + pagin_config (synapse.api.streams.PaginationConfig): The pagination + config rules to apply, if any. + as_client_event (bool): True to get events in client-server format. + event_filter (Filter): Filter to apply to results or None + Returns: + dict: Pagination API results + """ + user_id = requester.user.to_string() + + if pagin_config.from_token: + room_token = pagin_config.from_token.room_key + else: + pagin_config.from_token = ( + yield self.hs.get_event_sources().get_current_token_for_room( + room_id=room_id + ) + ) + room_token = pagin_config.from_token.room_key + + room_token = RoomStreamToken.parse(room_token) + + pagin_config.from_token = pagin_config.from_token.copy_and_replace( + "room_key", str(room_token) + ) + + source_config = pagin_config.get_source_config("room") + + with (yield self.pagination_lock.read(room_id)): + membership, member_event_id = yield self.auth.check_in_room_or_world_readable( + room_id, user_id + ) + + if source_config.direction == 'b': + # if we're going backwards, we might need to backfill. This + # requires that we have a topo token. + if room_token.topological: + max_topo = room_token.topological + else: + max_topo = yield self.store.get_max_topological_token( + room_id, room_token.stream + ) + + if membership == Membership.LEAVE: + # If they have left the room then clamp the token to be before + # they left the room, to save the effort of loading from the + # database. + leave_token = yield self.store.get_topological_token_for_event( + member_event_id + ) + leave_token = RoomStreamToken.parse(leave_token) + if leave_token.topological < max_topo: + source_config.from_key = str(leave_token) + + yield self.hs.get_handlers().federation_handler.maybe_backfill( + room_id, max_topo + ) + + events, next_key = yield self.store.paginate_room_events( + room_id=room_id, + from_key=source_config.from_key, + to_key=source_config.to_key, + direction=source_config.direction, + limit=source_config.limit, + event_filter=event_filter, + ) + + next_token = pagin_config.from_token.copy_and_replace( + "room_key", next_key + ) + + if not events: + defer.returnValue({ + "chunk": [], + "start": pagin_config.from_token.to_string(), + "end": next_token.to_string(), + }) + + if event_filter: + events = event_filter.filter(events) + + events = yield filter_events_for_client( + self.store, + user_id, + events, + is_peeking=(member_event_id is None), + ) + + time_now = self.clock.time_msec() + + chunk = { + "chunk": [ + serialize_event(e, time_now, as_client_event) + for e in events + ], + "start": pagin_config.from_token.to_string(), + "end": next_token.to_string(), + } + + defer.returnValue(chunk) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index c7c0b0a1e2..3732830194 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -22,41 +22,44 @@ The methods that define policy are: - should_notify """ -from twisted.internet import defer, reactor +import logging from contextlib import contextmanager -from synapse.api.errors import SynapseError +from six import iteritems, itervalues + +from prometheus_client import Counter + +from twisted.internet import defer + from synapse.api.constants import PresenceState +from synapse.api.errors import SynapseError +from synapse.metrics import LaterGauge from synapse.storage.presence import UserPresenceState - -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.types import UserID, get_domain_from_id from synapse.util.async import Linearizer -from synapse.util.logcontext import preserve_fn +from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.logcontext import run_in_background from synapse.util.logutils import log_function from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer -from synapse.types import UserID, get_domain_from_id -import synapse.metrics - -import logging - logger = logging.getLogger(__name__) -metrics = synapse.metrics.get_metrics_for(__name__) -notified_presence_counter = metrics.register_counter("notified_presence") -federation_presence_out_counter = metrics.register_counter("federation_presence_out") -presence_updates_counter = metrics.register_counter("presence_updates") -timers_fired_counter = metrics.register_counter("timers_fired") -federation_presence_counter = metrics.register_counter("federation_presence") -bump_active_time_counter = metrics.register_counter("bump_active_time") +notified_presence_counter = Counter("synapse_handler_presence_notified_presence", "") +federation_presence_out_counter = Counter( + "synapse_handler_presence_federation_presence_out", "") +presence_updates_counter = Counter("synapse_handler_presence_presence_updates", "") +timers_fired_counter = Counter("synapse_handler_presence_timers_fired", "") +federation_presence_counter = Counter("synapse_handler_presence_federation_presence", "") +bump_active_time_counter = Counter("synapse_handler_presence_bump_active_time", "") -get_updates_counter = metrics.register_counter("get_updates", labels=["type"]) +get_updates_counter = Counter("synapse_handler_presence_get_updates", "", ["type"]) -notify_reason_counter = metrics.register_counter("notify_reason", labels=["reason"]) -state_transition_counter = metrics.register_counter( - "state_transition", labels=["from", "to"] +notify_reason_counter = Counter( + "synapse_handler_presence_notify_reason", "", ["reason"]) +state_transition_counter = Counter( + "synapse_handler_presence_state_transition", "", ["from", "to"] ) @@ -87,35 +90,40 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER class PresenceHandler(object): def __init__(self, hs): + """ + + Args: + hs (synapse.server.HomeServer): + """ self.is_mine = hs.is_mine self.is_mine_id = hs.is_mine_id self.clock = hs.get_clock() self.store = hs.get_datastore() self.wheel_timer = WheelTimer() self.notifier = hs.get_notifier() - self.replication = hs.get_replication_layer() self.federation = hs.get_federation_sender() - self.state = hs.get_state_handler() - self.replication.register_edu_handler( + federation_registry = hs.get_federation_registry() + + federation_registry.register_edu_handler( "m.presence", self.incoming_presence ) - self.replication.register_edu_handler( + federation_registry.register_edu_handler( "m.presence_invite", lambda origin, content: self.invite_presence( observed_user=UserID.from_string(content["observed_user"]), observer_user=UserID.from_string(content["observer_user"]), ) ) - self.replication.register_edu_handler( + federation_registry.register_edu_handler( "m.presence_accept", lambda origin, content: self.accept_presence( observed_user=UserID.from_string(content["observed_user"]), observer_user=UserID.from_string(content["observer_user"]), ) ) - self.replication.register_edu_handler( + federation_registry.register_edu_handler( "m.presence_deny", lambda origin, content: self.deny_presence( observed_user=UserID.from_string(content["observed_user"]), @@ -136,8 +144,9 @@ class PresenceHandler(object): for state in active_presence } - metrics.register_callback( - "user_to_current_state_size", lambda: len(self.user_to_current_state) + LaterGauge( + "synapse_handlers_presence_user_to_current_state_size", "", [], + lambda: len(self.user_to_current_state) ) now = self.clock.time_msec() @@ -169,7 +178,7 @@ class PresenceHandler(object): # have not yet been persisted self.unpersisted_users_changes = set() - reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown) + hs.get_reactor().addSystemEventTrigger("before", "shutdown", self._on_shutdown) self.serial_to_user = {} self._next_serial = 1 @@ -207,7 +216,8 @@ class PresenceHandler(object): 60 * 1000, ) - metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer)) + LaterGauge("synapse_handlers_presence_wheel_timer_size", "", [], + lambda: len(self.wheel_timer)) @defer.inlineCallbacks def _on_shutdown(self): @@ -254,6 +264,14 @@ class PresenceHandler(object): logger.info("Finished _persist_unpersisted_changes") @defer.inlineCallbacks + def _update_states_and_catch_exception(self, new_states): + try: + res = yield self._update_states(new_states) + defer.returnValue(res) + except Exception: + logger.exception("Error updating presence") + + @defer.inlineCallbacks def _update_states(self, new_states): """Updates presence of users. Sets the appropriate timeouts. Pokes the notifier and federation if and only if the changed presence state @@ -302,11 +320,11 @@ class PresenceHandler(object): # TODO: We should probably ensure there are no races hereafter - presence_updates_counter.inc_by(len(new_states)) + presence_updates_counter.inc(len(new_states)) if to_notify: - notified_presence_counter.inc_by(len(to_notify)) - yield self._persist_and_notify(to_notify.values()) + notified_presence_counter.inc(len(to_notify)) + yield self._persist_and_notify(list(to_notify.values())) self.unpersisted_users_changes |= set(s.user_id for s in new_states) self.unpersisted_users_changes -= set(to_notify.keys()) @@ -316,7 +334,7 @@ class PresenceHandler(object): if user_id not in to_notify } if to_federation_ping: - federation_presence_out_counter.inc_by(len(to_federation_ping)) + federation_presence_out_counter.inc(len(to_federation_ping)) self._push_to_remotes(to_federation_ping.values()) @@ -354,7 +372,7 @@ class PresenceHandler(object): for user_id in users_to_check ] - timers_fired_counter.inc_by(len(states)) + timers_fired_counter.inc(len(states)) changes = handle_timeouts( states, @@ -363,8 +381,8 @@ class PresenceHandler(object): now=now, ) - preserve_fn(self._update_states)(changes) - except: + run_in_background(self._update_states_and_catch_exception, changes) + except Exception: logger.exception("Exception in _handle_timeouts loop") @defer.inlineCallbacks @@ -421,20 +439,23 @@ class PresenceHandler(object): @defer.inlineCallbacks def _end(): - if affect_presence: + try: self.user_to_num_current_syncs[user_id] -= 1 prev_state = yield self.current_state_for_user(user_id) yield self._update_states([prev_state.copy_and_replace( last_user_sync_ts=self.clock.time_msec(), )]) + except Exception: + logger.exception("Error updating presence after sync") @contextmanager def _user_syncing(): try: yield finally: - preserve_fn(_end)() + if affect_presence: + run_in_background(_end) defer.returnValue(_user_syncing()) @@ -452,61 +473,6 @@ class PresenceHandler(object): return syncing_user_ids @defer.inlineCallbacks - def update_external_syncs(self, process_id, syncing_user_ids): - """Update the syncing users for an external process - - Args: - process_id(str): An identifier for the process the users are - syncing against. This allows synapse to process updates - as user start and stop syncing against a given process. - syncing_user_ids(set(str)): The set of user_ids that are - currently syncing on that server. - """ - - # Grab the previous list of user_ids that were syncing on that process - prev_syncing_user_ids = ( - self.external_process_to_current_syncs.get(process_id, set()) - ) - # Grab the current presence state for both the users that are syncing - # now and the users that were syncing before this update. - prev_states = yield self.current_state_for_users( - syncing_user_ids | prev_syncing_user_ids - ) - updates = [] - time_now_ms = self.clock.time_msec() - - # For each new user that is syncing check if we need to mark them as - # being online. - for new_user_id in syncing_user_ids - prev_syncing_user_ids: - prev_state = prev_states[new_user_id] - if prev_state.state == PresenceState.OFFLINE: - updates.append(prev_state.copy_and_replace( - state=PresenceState.ONLINE, - last_active_ts=time_now_ms, - last_user_sync_ts=time_now_ms, - )) - else: - updates.append(prev_state.copy_and_replace( - last_user_sync_ts=time_now_ms, - )) - - # For each user that is still syncing or stopped syncing update the - # last sync time so that we will correctly apply the grace period when - # they stop syncing. - for old_user_id in prev_syncing_user_ids: - prev_state = prev_states[old_user_id] - updates.append(prev_state.copy_and_replace( - last_user_sync_ts=time_now_ms, - )) - - yield self._update_states(updates) - - # Update the last updated time for the process. We expire the entries - # if we don't receive an update in the given timeframe. - self.external_process_last_updated_ms[process_id] = self.clock.time_msec() - self.external_process_to_current_syncs[process_id] = syncing_user_ids - - @defer.inlineCallbacks def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec): """Update the syncing users for an external process as a delta. @@ -569,7 +535,7 @@ class PresenceHandler(object): prev_state.copy_and_replace( last_user_sync_ts=time_now_ms, ) - for prev_state in prev_states.itervalues() + for prev_state in itervalues(prev_states) ]) self.external_process_last_updated_ms.pop(process_id, None) @@ -592,14 +558,14 @@ class PresenceHandler(object): for user_id in user_ids } - missing = [user_id for user_id, state in states.iteritems() if not state] + missing = [user_id for user_id, state in iteritems(states) if not state] if missing: # There are things not in our in memory cache. Lets pull them out of # the database. res = yield self.store.get_presence_for_users(missing) states.update(res) - missing = [user_id for user_id, state in states.iteritems() if not state] + missing = [user_id for user_id, state in iteritems(states) if not state] if missing: new = { user_id: UserPresenceState.default(user_id) @@ -695,7 +661,7 @@ class PresenceHandler(object): updates.append(prev_state.copy_and_replace(**new_fields)) if updates: - federation_presence_counter.inc_by(len(updates)) + federation_presence_counter.inc(len(updates)) yield self._update_states(updates) @defer.inlineCallbacks @@ -720,7 +686,7 @@ class PresenceHandler(object): """ updates = yield self.current_state_for_users(target_user_ids) - updates = updates.values() + updates = list(updates.values()) for user_id in set(target_user_ids) - set(u.user_id for u in updates): updates.append(UserPresenceState.default(user_id)) @@ -786,11 +752,11 @@ class PresenceHandler(object): self._push_to_remotes([state]) else: user_ids = yield self.store.get_users_in_room(room_id) - user_ids = filter(self.is_mine_id, user_ids) + user_ids = list(filter(self.is_mine_id, user_ids)) states = yield self.current_state_for_users(user_ids) - self._push_to_remotes(states.values()) + self._push_to_remotes(list(states.values())) @defer.inlineCallbacks def get_presence_list(self, observer_user, accepted=None): @@ -970,28 +936,28 @@ def should_notify(old_state, new_state): return False if old_state.status_msg != new_state.status_msg: - notify_reason_counter.inc("status_msg_change") + notify_reason_counter.labels("status_msg_change").inc() return True if old_state.state != new_state.state: - notify_reason_counter.inc("state_change") - state_transition_counter.inc(old_state.state, new_state.state) + notify_reason_counter.labels("state_change").inc() + state_transition_counter.labels(old_state.state, new_state.state).inc() return True if old_state.state == PresenceState.ONLINE: if new_state.currently_active != old_state.currently_active: - notify_reason_counter.inc("current_active_change") + notify_reason_counter.labels("current_active_change").inc() return True if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: # Only notify about last active bumps if we're not currently acive if not new_state.currently_active: - notify_reason_counter.inc("last_active_change_online") + notify_reason_counter.labels("last_active_change_online").inc() return True elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: # Always notify for a transition where last active gets bumped. - notify_reason_counter.inc("last_active_change_not_online") + notify_reason_counter.labels("last_active_change_not_online").inc() return True return False @@ -1065,14 +1031,14 @@ class PresenceEventSource(object): if changed is not None and len(changed) < 500: # For small deltas, its quicker to get all changes and then # work out if we share a room or they're in our presence list - get_updates_counter.inc("stream") + get_updates_counter.labels("stream").inc() for other_user_id in changed: if other_user_id in users_interested_in: user_ids_changed.add(other_user_id) else: # Too many possible updates. Find all users we can see and check # if any of them have changed. - get_updates_counter.inc("full") + get_updates_counter.labels("full").inc() if from_key: user_ids_changed = stream_change_cache.get_entities_changed( @@ -1084,10 +1050,10 @@ class PresenceEventSource(object): updates = yield presence.current_state_for_users(user_ids_changed) if include_offline: - defer.returnValue((updates.values(), max_token)) + defer.returnValue((list(updates.values()), max_token)) else: defer.returnValue(([ - s for s in updates.itervalues() + s for s in itervalues(updates) if s.state != PresenceState.OFFLINE ], max_token)) @@ -1145,7 +1111,7 @@ def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now): if new_state: changes[state.user_id] = new_state - return changes.values() + return list(changes.values()) def handle_timeout(state, is_mine, syncing_user_ids, now): @@ -1199,7 +1165,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now): ) changed = True else: - # We expect to be poked occaisonally by the other side. + # We expect to be poked occasionally by the other side. # This is to protect against forgetful/buggy servers, so that # no one gets stuck online forever. if now - state.last_federation_update_ts > FEDERATION_TIMEOUT: @@ -1344,11 +1310,11 @@ def get_interested_remotes(store, states, state_handler): # hosts in those rooms. room_ids_to_states, users_to_states = yield get_interested_parties(store, states) - for room_id, states in room_ids_to_states.iteritems(): + for room_id, states in iteritems(room_ids_to_states): hosts = yield state_handler.get_current_hosts_in_room(room_id) hosts_and_states.append((hosts, states)) - for user_id, states in users_to_states.iteritems(): + for user_id, states in iteritems(users_to_states): host = get_domain_from_id(user_id) hosts_and_states.append(([host], states)) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 7abee98dea..859f6d2b2e 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -17,25 +17,88 @@ import logging from twisted.internet import defer -import synapse.types -from synapse.api.errors import SynapseError, AuthError, CodeMessageException -from synapse.types import UserID -from ._base import BaseHandler +from synapse.api.errors import AuthError, CodeMessageException, SynapseError +from synapse.types import UserID, get_domain_from_id +from ._base import BaseHandler logger = logging.getLogger(__name__) class ProfileHandler(BaseHandler): + PROFILE_UPDATE_MS = 60 * 1000 + PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 def __init__(self, hs): super(ProfileHandler, self).__init__(hs) - self.federation = hs.get_replication_layer() - self.federation.register_query_handler( + self.federation = hs.get_federation_client() + hs.get_federation_registry().register_query_handler( "profile", self.on_profile_query ) + self.user_directory_handler = hs.get_user_directory_handler() + + if hs.config.worker_app is None: + self.clock.looping_call( + self._update_remote_profile_cache, self.PROFILE_UPDATE_MS, + ) + + @defer.inlineCallbacks + def get_profile(self, user_id): + target_user = UserID.from_string(user_id) + if self.hs.is_mine(target_user): + displayname = yield self.store.get_profile_displayname( + target_user.localpart + ) + avatar_url = yield self.store.get_profile_avatar_url( + target_user.localpart + ) + + defer.returnValue({ + "displayname": displayname, + "avatar_url": avatar_url, + }) + else: + try: + result = yield self.federation.make_query( + destination=target_user.domain, + query_type="profile", + args={ + "user_id": user_id, + }, + ignore_backoff=True, + ) + defer.returnValue(result) + except CodeMessageException as e: + if e.code != 404: + logger.exception("Failed to get displayname") + + raise + + @defer.inlineCallbacks + def get_profile_from_cache(self, user_id): + """Get the profile information from our local cache. If the user is + ours then the profile information will always be corect. Otherwise, + it may be out of date/missing. + """ + target_user = UserID.from_string(user_id) + if self.hs.is_mine(target_user): + displayname = yield self.store.get_profile_displayname( + target_user.localpart + ) + avatar_url = yield self.store.get_profile_avatar_url( + target_user.localpart + ) + + defer.returnValue({ + "displayname": displayname, + "avatar_url": avatar_url, + }) + else: + profile = yield self.store.get_from_remote_profile_cache(user_id) + defer.returnValue(profile or {}) + @defer.inlineCallbacks def get_displayname(self, target_user): if self.hs.is_mine(target_user): @@ -60,7 +123,7 @@ class ProfileHandler(BaseHandler): logger.exception("Failed to get displayname") raise - except: + except Exception: logger.exception("Failed to get displayname") else: defer.returnValue(result["displayname"]) @@ -82,7 +145,13 @@ class ProfileHandler(BaseHandler): target_user.localpart, new_displayname ) - yield self._update_join_states(requester) + if self.hs.config.user_directory_search_all_users: + profile = yield self.store.get_profileinfo(target_user.localpart) + yield self.user_directory_handler.handle_local_profile_change( + target_user.to_string(), profile + ) + + yield self._update_join_states(requester, target_user) @defer.inlineCallbacks def get_avatar_url(self, target_user): @@ -107,7 +176,7 @@ class ProfileHandler(BaseHandler): if e.code != 404: logger.exception("Failed to get avatar_url") raise - except: + except Exception: logger.exception("Failed to get avatar_url") defer.returnValue(result["avatar_url"]) @@ -126,7 +195,13 @@ class ProfileHandler(BaseHandler): target_user.localpart, new_avatar_url ) - yield self._update_join_states(requester) + if self.hs.config.user_directory_search_all_users: + profile = yield self.store.get_profileinfo(target_user.localpart) + yield self.user_directory_handler.handle_local_profile_change( + target_user.to_string(), profile + ) + + yield self._update_join_states(requester, target_user) @defer.inlineCallbacks def on_profile_query(self, args): @@ -151,28 +226,24 @@ class ProfileHandler(BaseHandler): defer.returnValue(response) @defer.inlineCallbacks - def _update_join_states(self, requester): - user = requester.user - if not self.hs.is_mine(user): + def _update_join_states(self, requester, target_user): + if not self.hs.is_mine(target_user): return yield self.ratelimit(requester) room_ids = yield self.store.get_rooms_for_user( - user.to_string(), + target_user.to_string(), ) for room_id in room_ids: - handler = self.hs.get_handlers().room_member_handler + handler = self.hs.get_room_member_handler() try: - # Assume the user isn't a guest because we don't let guests set - # profile or avatar data. - # XXX why are we recreating `requester` here for each room? - # what was wrong with the `requester` we were passed? - requester = synapse.types.create_requester(user) + # Assume the target_user isn't a guest, + # because we don't let guests set profile or avatar data. yield handler.update_membership( requester, - user, + target_user, room_id, "join", # We treat a profile update like a join. ratelimit=False, # Try to hide that these events aren't atomic. @@ -182,3 +253,44 @@ class ProfileHandler(BaseHandler): "Failed to update join event for room %s - %s", room_id, str(e.message) ) + + def _update_remote_profile_cache(self): + """Called periodically to check profiles of remote users we haven't + checked in a while. + """ + entries = yield self.store.get_remote_profile_cache_entries_that_expire( + last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS + ) + + for user_id, displayname, avatar_url in entries: + is_subscribed = yield self.store.is_subscribed_remote_profile_for_user( + user_id, + ) + if not is_subscribed: + yield self.store.maybe_delete_remote_profile_cache(user_id) + continue + + try: + profile = yield self.federation.make_query( + destination=get_domain_from_id(user_id), + query_type="profile", + args={ + "user_id": user_id, + }, + ignore_backoff=True, + ) + except Exception: + logger.exception("Failed to get avatar_url") + + yield self.store.update_remote_profile_cache( + user_id, displayname, avatar_url + ) + continue + + new_name = profile.get("displayname") + new_avatar = profile.get("avatar_url") + + # We always hit update to update the last_check timestamp + yield self.store.update_remote_profile_cache( + user_id, new_name, new_avatar + ) diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index b5b0303d54..995460f82a 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -13,13 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseHandler +import logging from twisted.internet import defer from synapse.util.async import Linearizer -import logging +from ._base import BaseHandler + logger = logging.getLogger(__name__) @@ -41,9 +42,9 @@ class ReadMarkerHandler(BaseHandler): """ with (yield self.read_marker_linearizer.queue((room_id, user_id))): - account_data = yield self.store.get_account_data_for_room(user_id, room_id) - - existing_read_marker = account_data.get("m.fully_read", None) + existing_read_marker = yield self.store.get_account_data_for_room_and_type( + user_id, room_id, "m.fully_read", + ) should_update = True diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index e1cd3a48e9..cb905a3903 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -12,16 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from ._base import BaseHandler +import logging from twisted.internet import defer -from synapse.util.logcontext import PreserveLoggingContext from synapse.types import get_domain_from_id +from synapse.util import logcontext +from synapse.util.logcontext import PreserveLoggingContext -import logging - +from ._base import BaseHandler logger = logging.getLogger(__name__) @@ -34,7 +33,7 @@ class ReceiptsHandler(BaseHandler): self.store = hs.get_datastore() self.hs = hs self.federation = hs.get_federation_sender() - hs.get_replication_layer().register_edu_handler( + hs.get_federation_registry().register_edu_handler( "m.receipt", self._received_remote_receipt ) self.clock = self.hs.get_clock() @@ -59,6 +58,8 @@ class ReceiptsHandler(BaseHandler): is_new = yield self._handle_new_receipts([receipt]) if is_new: + # fire off a process in the background to send the receipt to + # remote servers self._push_remotes([receipt]) @defer.inlineCallbacks @@ -126,42 +127,46 @@ class ReceiptsHandler(BaseHandler): defer.returnValue(True) + @logcontext.preserve_fn # caller should not yield on this @defer.inlineCallbacks def _push_remotes(self, receipts): """Given a list of receipts, works out which remote servers should be poked and pokes them. """ - # TODO: Some of this stuff should be coallesced. - for receipt in receipts: - room_id = receipt["room_id"] - receipt_type = receipt["receipt_type"] - user_id = receipt["user_id"] - event_ids = receipt["event_ids"] - data = receipt["data"] - - users = yield self.state.get_current_user_in_room(room_id) - remotedomains = set(get_domain_from_id(u) for u in users) - remotedomains = remotedomains.copy() - remotedomains.discard(self.server_name) - - logger.debug("Sending receipt to: %r", remotedomains) - - for domain in remotedomains: - self.federation.send_edu( - destination=domain, - edu_type="m.receipt", - content={ - room_id: { - receipt_type: { - user_id: { - "event_ids": event_ids, - "data": data, + try: + # TODO: Some of this stuff should be coallesced. + for receipt in receipts: + room_id = receipt["room_id"] + receipt_type = receipt["receipt_type"] + user_id = receipt["user_id"] + event_ids = receipt["event_ids"] + data = receipt["data"] + + users = yield self.state.get_current_user_in_room(room_id) + remotedomains = set(get_domain_from_id(u) for u in users) + remotedomains = remotedomains.copy() + remotedomains.discard(self.server_name) + + logger.debug("Sending receipt to: %r", remotedomains) + + for domain in remotedomains: + self.federation.send_edu( + destination=domain, + edu_type="m.receipt", + content={ + room_id: { + receipt_type: { + user_id: { + "event_ids": event_ids, + "data": data, + } } - } + }, }, - }, - key=(room_id, receipt_type, user_id), - ) + key=(room_id, receipt_type, user_id), + ) + except Exception: + logger.exception("Error pushing receipts to remote servers") @defer.inlineCallbacks def get_receipts_for_room(self, room_id, to_key): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ee3a2269a8..7caff0cbc8 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -15,16 +15,22 @@ """Contains functions for registering clients.""" import logging -import urllib from twisted.internet import defer +from synapse import types from synapse.api.errors import ( - AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError + AuthError, + Codes, + InvalidCaptchaError, + RegistrationError, + SynapseError, ) from synapse.http.client import CaptchaServerHttpClient -from synapse.types import UserID -from synapse.util.async import run_on_reactor +from synapse.types import RoomAlias, RoomID, UserID, create_requester +from synapse.util.async import Linearizer +from synapse.util.threepids import check_3pid_allowed + from ._base import BaseHandler logger = logging.getLogger(__name__) @@ -33,24 +39,35 @@ logger = logging.getLogger(__name__) class RegistrationHandler(BaseHandler): def __init__(self, hs): + """ + + Args: + hs (synapse.server.HomeServer): + """ super(RegistrationHandler, self).__init__(hs) self.auth = hs.get_auth() + self._auth_handler = hs.get_auth_handler() + self.profile_handler = hs.get_profile_handler() + self.user_directory_handler = hs.get_user_directory_handler() self.captcha_client = CaptchaServerHttpClient(hs) self._next_generated_user_id = None self.macaroon_gen = hs.get_macaroon_generator() + self._generate_user_id_linearizer = Linearizer( + name="_generate_user_id_linearizer", + ) + self._server_notices_mxid = hs.config.server_notices_mxid + @defer.inlineCallbacks def check_username(self, localpart, guest_access_token=None, assigned_user_id=None): - yield run_on_reactor() - - if urllib.quote(localpart.encode('utf-8')) != localpart: + if types.contains_invalid_mxid_characters(localpart): raise SynapseError( 400, - "User ID can only contain characters a-z, 0-9, or '_-./'", + "User ID can only contain characters a-z, 0-9, or '=_-./'", Codes.INVALID_USERNAME ) @@ -80,7 +97,7 @@ class RegistrationHandler(BaseHandler): "A different user ID has already been registered for this session", ) - yield self.check_user_id_not_appservice_exclusive(user_id) + self.check_user_id_not_appservice_exclusive(user_id) users = yield self.store.get_users_by_id_case_insensitive(user_id) if users: @@ -127,10 +144,9 @@ class RegistrationHandler(BaseHandler): Raises: RegistrationError if there was a problem registering. """ - yield run_on_reactor() password_hash = None if password: - password_hash = self.auth_handler().hash(password) + password_hash = yield self.auth_handler().hash(password) if localpart: yield self.check_username(localpart, guest_access_token=guest_access_token) @@ -165,6 +181,13 @@ class RegistrationHandler(BaseHandler): ), admin=admin, ) + + if self.hs.config.user_directory_search_all_users: + profile = yield self.store.get_profileinfo(localpart) + yield self.user_directory_handler.handle_local_profile_change( + user_id, profile + ) + else: # autogen a sequential user ID attempts = 0 @@ -192,10 +215,17 @@ class RegistrationHandler(BaseHandler): token = None attempts += 1 + # auto-join the user to any rooms we're supposed to dump them into + fake_requester = create_requester(user_id) + for r in self.hs.config.auto_join_rooms: + try: + yield self._join_user_to_room(fake_requester, r) + except Exception as e: + logger.error("Failed to join new user to %r: %r", r, e) + # We used to generate default identicons here, but nowadays # we want clients to generate their own as part of their branding # rather than there being consistent matrix-wide ones, so we don't. - defer.returnValue((user_id, token)) @defer.inlineCallbacks @@ -253,11 +283,10 @@ class RegistrationHandler(BaseHandler): """ Registers email_id as SAML2 Based Auth. """ - if urllib.quote(localpart) != localpart: + if types.contains_invalid_mxid_characters(localpart): raise SynapseError( 400, - "User ID must only contain characters which do not" - " require URL encoding." + "User ID can only contain characters a-z, 0-9, or '=_-./'", ) user = UserID(localpart, self.hs.hostname) user_id = user.to_string() @@ -286,12 +315,12 @@ class RegistrationHandler(BaseHandler): """ for c in threepidCreds: - logger.info("validating theeepidcred sid %s on id server %s", + logger.info("validating threepidcred sid %s on id server %s", c['sid'], c['idServer']) try: identity_handler = self.hs.get_handlers().identity_handler threepid = yield identity_handler.threepid_from_creds(c) - except: + except Exception: logger.exception("Couldn't validate 3pid") raise RegistrationError(400, "Couldn't validate 3pid") @@ -300,6 +329,11 @@ class RegistrationHandler(BaseHandler): logger.info("got threepid with medium '%s' and address '%s'", threepid['medium'], threepid['address']) + if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']): + raise RegistrationError( + 403, "Third party identifier is not allowed" + ) + @defer.inlineCallbacks def bind_emails(self, user_id, threepidCreds): """Links emails with a user ID and informs an identity server. @@ -314,6 +348,14 @@ class RegistrationHandler(BaseHandler): yield identity_handler.bind_threepid(c, user_id) def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None): + # don't allow people to register the server notices mxid + if self._server_notices_mxid is not None: + if user_id == self._server_notices_mxid: + raise SynapseError( + 400, "This user ID is reserved.", + errcode=Codes.EXCLUSIVE + ) + # valid user IDs must not clash with any user ID namespaces claimed by # application services. services = self.store.get_app_services() @@ -332,9 +374,11 @@ class RegistrationHandler(BaseHandler): @defer.inlineCallbacks def _generate_user_id(self, reseed=False): if reseed or self._next_generated_user_id is None: - self._next_generated_user_id = ( - yield self.store.find_next_generated_user_id_localpart() - ) + with (yield self._generate_user_id_linearizer.queue(())): + if reseed or self._next_generated_user_id is None: + self._next_generated_user_id = ( + yield self.store.find_next_generated_user_id_localpart() + ) id = self._next_generated_user_id self._next_generated_user_id += 1 @@ -391,8 +435,6 @@ class RegistrationHandler(BaseHandler): Raises: RegistrationError if there was a problem registering. """ - yield run_on_reactor() - if localpart is None: raise SynapseError(400, "Request must include user id") @@ -418,13 +460,12 @@ class RegistrationHandler(BaseHandler): create_profile_with_localpart=user.localpart, ) else: - yield self.store.user_delete_access_tokens(user_id=user_id) + yield self._auth_handler.delete_access_tokens_for_user(user_id) yield self.store.add_access_token_to_user(user_id=user_id, token=token) if displayname is not None: logger.info("setting user display name: %s -> %s", user_id, displayname) - profile_handler = self.hs.get_handlers().profile_handler - yield profile_handler.set_displayname( + yield self.profile_handler.set_displayname( user, requester, displayname, by_admin=True, ) @@ -434,16 +475,59 @@ class RegistrationHandler(BaseHandler): return self.hs.get_auth_handler() @defer.inlineCallbacks - def guest_access_token_for(self, medium, address, inviter_user_id): + def get_or_register_3pid_guest(self, medium, address, inviter_user_id): + """Get a guest access token for a 3PID, creating a guest account if + one doesn't already exist. + + Args: + medium (str) + address (str) + inviter_user_id (str): The user ID who is trying to invite the + 3PID + + Returns: + Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the + 3PID guest account. + """ access_token = yield self.store.get_3pid_guest_access_token(medium, address) if access_token: - defer.returnValue(access_token) + user_info = yield self.auth.get_user_by_access_token( + access_token + ) + + defer.returnValue((user_info["user"].to_string(), access_token)) - _, access_token = yield self.register( + user_id, access_token = yield self.register( generate_token=True, make_guest=True ) access_token = yield self.store.save_or_get_3pid_guest_access_token( medium, address, access_token, inviter_user_id ) - defer.returnValue(access_token) + + defer.returnValue((user_id, access_token)) + + @defer.inlineCallbacks + def _join_user_to_room(self, requester, room_identifier): + room_id = None + room_member_handler = self.hs.get_room_member_handler() + if RoomID.is_valid(room_identifier): + room_id = room_identifier + elif RoomAlias.is_valid(room_identifier): + room_alias = RoomAlias.from_string(room_identifier) + room_id, remote_room_hosts = ( + yield room_member_handler.lookup_room_alias(room_alias) + ) + room_id = room_id.to_string() + else: + raise SynapseError(400, "%s was not legal room ID or room alias" % ( + room_identifier, + )) + + yield room_member_handler.update_membership( + requester=requester, + target=requester.user, + room_id=room_id, + remote_room_hosts=remote_room_hosts, + action="join", + ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 5698d28088..6150b7e226 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,23 +15,20 @@ # limitations under the License. """Contains functions for performing events on rooms.""" -from twisted.internet import defer +import logging +import math +import string +from collections import OrderedDict -from ._base import BaseHandler +from twisted.internet import defer -from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken -from synapse.api.constants import ( - EventTypes, JoinRules, RoomCreationPreset -) -from synapse.api.errors import AuthError, StoreError, SynapseError +from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset +from synapse.api.errors import AuthError, Codes, StoreError, SynapseError +from synapse.types import RoomAlias, RoomID, RoomStreamToken, UserID from synapse.util import stringutils from synapse.visibility import filter_events_for_client -from collections import OrderedDict - -import logging -import math -import string +from ._base import BaseHandler logger = logging.getLogger(__name__) @@ -60,21 +58,43 @@ class RoomCreationHandler(BaseHandler): }, } + def __init__(self, hs): + super(RoomCreationHandler, self).__init__(hs) + + self.spam_checker = hs.get_spam_checker() + self.event_creation_handler = hs.get_event_creation_handler() + @defer.inlineCallbacks - def create_room(self, requester, config, ratelimit=True): + def create_room(self, requester, config, ratelimit=True, + creator_join_profile=None): """ Creates a new room. Args: - requester (Requester): The user who requested the room creation. + requester (synapse.types.Requester): + The user who requested the room creation. config (dict) : A dict of configuration options. + ratelimit (bool): set to False to disable the rate limiter + + creator_join_profile (dict|None): + Set to override the displayname and avatar for the creating + user in this room. If unset, displayname and avatar will be + derived from the user's profile. If set, should contain the + values to go in the body of the 'join' event (typically + `avatar_url` and/or `displayname`. + Returns: - The new room ID. + Deferred[dict]: + a dict containing the keys `room_id` and, if an alias was + requested, `room_alias`. Raises: SynapseError if the room ID couldn't be stored, or something went horribly wrong. """ user_id = requester.user.to_string() + if not self.spam_checker.user_may_create_room(user_id): + raise SynapseError(403, "You are not permitted to create rooms") + if ratelimit: yield self.ratelimit(requester) @@ -83,7 +103,7 @@ class RoomCreationHandler(BaseHandler): if wchar in config["room_alias_name"]: raise SynapseError(400, "Invalid characters in room alias") - room_alias = RoomAlias.create( + room_alias = RoomAlias( config["room_alias_name"], self.hs.hostname, ) @@ -92,7 +112,11 @@ class RoomCreationHandler(BaseHandler): ) if mapping: - raise SynapseError(400, "Room alias already taken") + raise SynapseError( + 400, + "Room alias already taken", + Codes.ROOM_IN_USE + ) else: room_alias = None @@ -100,9 +124,13 @@ class RoomCreationHandler(BaseHandler): for i in invite_list: try: UserID.from_string(i) - except: + except Exception: raise SynapseError(400, "Invalid user_id: %s" % (i,)) + yield self.event_creation_handler.assert_accepted_privacy_policy( + requester, + ) + invite_3pid_list = config.get("invite_3pid", []) visibility = config.get("visibility", None) @@ -115,7 +143,7 @@ class RoomCreationHandler(BaseHandler): while attempts < 5: try: random_string = stringutils.random_string(18) - gen_room_id = RoomID.create( + gen_room_id = RoomID( random_string, self.hs.hostname, ) @@ -155,25 +183,24 @@ class RoomCreationHandler(BaseHandler): creation_content = config.get("creation_content", {}) - msg_handler = self.hs.get_handlers().message_handler - room_member_handler = self.hs.get_handlers().room_member_handler + room_member_handler = self.hs.get_room_member_handler() yield self._send_events_for_new_room( requester, room_id, - msg_handler, room_member_handler, preset_config=preset_config, invite_list=invite_list, initial_state=initial_state, creation_content=creation_content, room_alias=room_alias, - power_level_content_override=config.get("power_level_content_override", {}) + power_level_content_override=config.get("power_level_content_override", {}), + creator_join_profile=creator_join_profile, ) if "name" in config: name = config["name"] - yield msg_handler.create_and_send_nonmember_event( + yield self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Name, @@ -186,7 +213,7 @@ class RoomCreationHandler(BaseHandler): if "topic" in config: topic = config["topic"] - yield msg_handler.create_and_send_nonmember_event( + yield self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Topic, @@ -197,12 +224,12 @@ class RoomCreationHandler(BaseHandler): }, ratelimit=False) - content = {} - is_direct = config.get("is_direct", None) - if is_direct: - content["is_direct"] = is_direct - for invitee in invite_list: + content = {} + is_direct = config.get("is_direct", None) + if is_direct: + content["is_direct"] = is_direct + yield room_member_handler.update_membership( requester, UserID.from_string(invitee), @@ -216,7 +243,7 @@ class RoomCreationHandler(BaseHandler): id_server = invite_3pid["id_server"] address = invite_3pid["address"] medium = invite_3pid["medium"] - yield self.hs.get_handlers().room_member_handler.do_3pid_invite( + yield self.hs.get_room_member_handler().do_3pid_invite( room_id, requester.user, medium, @@ -241,7 +268,6 @@ class RoomCreationHandler(BaseHandler): self, creator, # A Requester object. room_id, - msg_handler, room_member_handler, preset_config, invite_list, @@ -249,6 +275,7 @@ class RoomCreationHandler(BaseHandler): creation_content, room_alias, power_level_content_override, + creator_join_profile, ): def create(etype, content, **kwargs): e = { @@ -264,7 +291,7 @@ class RoomCreationHandler(BaseHandler): @defer.inlineCallbacks def send(etype, content, **kwargs): event = create(etype, content, **kwargs) - yield msg_handler.create_and_send_nonmember_event( + yield self.event_creation_handler.create_and_send_nonmember_event( creator, event, ratelimit=False @@ -292,6 +319,7 @@ class RoomCreationHandler(BaseHandler): room_id, "join", ratelimit=False, + content=creator_join_profile, ) # We treat the power levels override specially as this needs to be one @@ -367,7 +395,11 @@ class RoomCreationHandler(BaseHandler): ) -class RoomContextHandler(BaseHandler): +class RoomContextHandler(object): + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + @defer.inlineCallbacks def get_event_context(self, user, room_id, event_id, limit): """Retrieves events, pagination tokens and state around a given event @@ -428,7 +460,7 @@ class RoomContextHandler(BaseHandler): state = yield self.store.get_state_for_events( [last_event_id], None ) - results["state"] = state[last_event_id].values() + results["state"] = list(state[last_event_id].values()) results["start"] = now_token.copy_and_replace( "room_key", results["start"] @@ -468,12 +500,9 @@ class RoomEventSource(object): user.to_string() ) if app_service: - events, end_key = yield self.store.get_appservice_room_stream( - service=app_service, - from_key=from_key, - to_key=to_key, - limit=limit, - ) + # We no longer support AS users using /sync directly. + # See https://github.com/matrix-org/matrix-doc/issues/1144 + raise NotImplementedError() else: room_events = yield self.store.get_membership_changes_for_user( user.to_string(), from_key, to_key diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 516cd9a6ac..828229f5c3 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -13,23 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +import logging +from collections import namedtuple -from ._base import BaseHandler +from six import iteritems +from six.moves import range + +import msgpack +from unpaddedbase64 import decode_base64, encode_base64 + +from twisted.internet import defer -from synapse.api.constants import ( - EventTypes, JoinRules, -) +from synapse.api.constants import EventTypes, JoinRules +from synapse.types import ThirdPartyInstanceID from synapse.util.async import concurrently_execute from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.response_cache import ResponseCache -from synapse.types import ThirdPartyInstanceID - -from collections import namedtuple -from unpaddedbase64 import encode_base64, decode_base64 -import logging -import msgpack +from ._base import BaseHandler logger = logging.getLogger(__name__) @@ -37,18 +38,19 @@ REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000 # This is used to indicate we should only return rooms published to the main list. -EMTPY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) +EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) class RoomListHandler(BaseHandler): def __init__(self, hs): super(RoomListHandler, self).__init__(hs) - self.response_cache = ResponseCache(hs) - self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000) + self.response_cache = ResponseCache(hs, "room_list") + self.remote_response_cache = ResponseCache(hs, "remote_room_list", + timeout_ms=30 * 1000) def get_local_public_room_list(self, limit=None, since_token=None, search_filter=None, - network_tuple=EMTPY_THIRD_PARTY_ID,): + network_tuple=EMPTY_THIRD_PARTY_ID,): """Generate a local public room list. There are multiple different lists: the main one plus one per third @@ -70,25 +72,22 @@ class RoomListHandler(BaseHandler): if search_filter: # We explicitly don't bother caching searches or requests for # appservice specific lists. + logger.info("Bypassing cache as search request.") return self._get_public_room_list( limit, since_token, search_filter, network_tuple=network_tuple, ) key = (limit, since_token, network_tuple) - result = self.response_cache.get(key) - if not result: - result = self.response_cache.set( - key, - self._get_public_room_list( - limit, since_token, network_tuple=network_tuple - ) - ) - return result + return self.response_cache.wrap( + key, + self._get_public_room_list, + limit, since_token, network_tuple=network_tuple, + ) @defer.inlineCallbacks def _get_public_room_list(self, limit=None, since_token=None, search_filter=None, - network_tuple=EMTPY_THIRD_PARTY_ID,): + network_tuple=EMPTY_THIRD_PARTY_ID,): if since_token and since_token != "END": since_token = RoomListNextBatch.from_token(since_token) else: @@ -149,6 +148,8 @@ class RoomListHandler(BaseHandler): # We want larger rooms to be first, hence negating num_joined_users rooms_to_order_value[room_id] = (-num_joined_users, room_id) + logger.info("Getting ordering for %i rooms since %s", + len(room_ids), stream_token) yield concurrently_execute(get_order_for_room, room_ids, 10) sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1]) @@ -176,34 +177,43 @@ class RoomListHandler(BaseHandler): rooms_to_scan = rooms_to_scan[:since_token.current_limit] rooms_to_scan.reverse() - # Actually generate the entries. _append_room_entry_to_chunk will append to - # chunk but will stop if len(chunk) > limit - chunk = [] - if limit and not search_filter: + logger.info("After sorting and filtering, %i rooms remain", + len(rooms_to_scan)) + + # _append_room_entry_to_chunk will append to chunk but will stop if + # len(chunk) > limit + # + # Normally we will generate enough results on the first iteration here, + # but if there is a search filter, _append_room_entry_to_chunk may + # filter some results out, in which case we loop again. + # + # We don't want to scan over the entire range either as that + # would potentially waste a lot of work. + # + # XXX if there is no limit, we may end up DoSing the server with + # calls to get_current_state_ids for every single room on the + # server. Surely we should cap this somehow? + # + if limit: step = limit + 1 - for i in xrange(0, len(rooms_to_scan), step): - # We iterate here because the vast majority of cases we'll stop - # at first iteration, but occaisonally _append_room_entry_to_chunk - # won't append to the chunk and so we need to loop again. - # We don't want to scan over the entire range either as that - # would potentially waste a lot of work. - yield concurrently_execute( - lambda r: self._append_room_entry_to_chunk( - r, rooms_to_num_joined[r], - chunk, limit, search_filter - ), - rooms_to_scan[i:i + step], 10 - ) - if len(chunk) >= limit + 1: - break else: + # step cannot be zero + step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1 + + chunk = [] + for i in range(0, len(rooms_to_scan), step): + batch = rooms_to_scan[i:i + step] + logger.info("Processing %i rooms for result", len(batch)) yield concurrently_execute( lambda r: self._append_room_entry_to_chunk( r, rooms_to_num_joined[r], chunk, limit, search_filter ), - rooms_to_scan, 5 + batch, 5, ) + logger.info("Now %i rooms in result", len(chunk)) + if len(chunk) >= limit + 1: + break chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"])) @@ -276,13 +286,14 @@ class RoomListHandler(BaseHandler): # We've already got enough, so lets just drop it. return - result = yield self._generate_room_entry(room_id, num_joined_users) + result = yield self.generate_room_entry(room_id, num_joined_users) if result and _matches_room_entry(result, search_filter): chunk.append(result) @cachedInlineCallbacks(num_args=1, cache_context=True) - def _generate_room_entry(self, room_id, num_joined_users, cache_context): + def generate_room_entry(self, room_id, num_joined_users, cache_context, + with_alias=True, allow_private=False): """Returns the entry for a room """ result = { @@ -295,7 +306,7 @@ class RoomListHandler(BaseHandler): ) event_map = yield self.store.get_events([ - event_id for key, event_id in current_state_ids.iteritems() + event_id for key, event_id in iteritems(current_state_ids) if key[0] in ( EventTypes.JoinRules, EventTypes.Name, @@ -316,14 +327,15 @@ class RoomListHandler(BaseHandler): join_rules_event = current_state.get((EventTypes.JoinRules, "")) if join_rules_event: join_rule = join_rules_event.content.get("join_rule", None) - if join_rule and join_rule != JoinRules.PUBLIC: + if not allow_private and join_rule and join_rule != JoinRules.PUBLIC: defer.returnValue(None) - aliases = yield self.store.get_aliases_for_room( - room_id, on_invalidate=cache_context.invalidate - ) - if aliases: - result["aliases"] = aliases + if with_alias: + aliases = yield self.store.get_aliases_for_room( + room_id, on_invalidate=cache_context.invalidate + ) + if aliases: + result["aliases"] = aliases name_event = yield current_state.get((EventTypes.Name, "")) if name_event: @@ -391,7 +403,7 @@ class RoomListHandler(BaseHandler): def _get_remote_list_cached(self, server_name, limit=None, since_token=None, search_filter=None, include_all_networks=False, third_party_instance_id=None,): - repl_layer = self.hs.get_replication_layer() + repl_layer = self.hs.get_federation_client() if search_filter: # We can't cache when asking for search return repl_layer.get_public_rooms( @@ -404,18 +416,14 @@ class RoomListHandler(BaseHandler): server_name, limit, since_token, include_all_networks, third_party_instance_id, ) - result = self.remote_response_cache.get(key) - if not result: - result = self.remote_response_cache.set( - key, - repl_layer.get_public_rooms( - server_name, limit=limit, since_token=since_token, - search_filter=search_filter, - include_all_networks=include_all_networks, - third_party_instance_id=third_party_instance_id, - ) - ) - return result + return self.remote_response_cache.wrap( + key, + repl_layer.get_public_rooms, + server_name, limit=limit, since_token=since_token, + search_filter=search_filter, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, + ) class RoomListNextBatch(namedtuple("RoomListNextBatch", ( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 1b8dfa8254..0d4a3f4677 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,63 +14,161 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import abc import logging +from six.moves import http_client + from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json -from twisted.internet import defer from unpaddedbase64 import decode_base64 +from twisted.internet import defer + +import synapse.server import synapse.types -from synapse.api.constants import ( - EventTypes, Membership, -) -from synapse.api.errors import AuthError, SynapseError, Codes -from synapse.types import UserID, RoomID +from synapse.api.constants import EventTypes, Membership +from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.types import RoomID, UserID from synapse.util.async import Linearizer -from synapse.util.distributor import user_left_room, user_joined_room -from ._base import BaseHandler +from synapse.util.distributor import user_joined_room, user_left_room logger = logging.getLogger(__name__) id_server_scheme = "https://" -class RoomMemberHandler(BaseHandler): +class RoomMemberHandler(object): # TODO(paul): This handler currently contains a messy conflation of # low-level API that works on UserID objects and so on, and REST-level # API that takes ID strings and returns pagination chunks. These concerns # ought to be separated out a lot better. + __metaclass__ = abc.ABCMeta + def __init__(self, hs): - super(RoomMemberHandler, self).__init__(hs) + """ + + Args: + hs (synapse.server.HomeServer): + """ + self.hs = hs + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.state_handler = hs.get_state_handler() + self.config = hs.config + self.simple_http_client = hs.get_simple_http_client() + + self.federation_handler = hs.get_handlers().federation_handler + self.directory_handler = hs.get_handlers().directory_handler + self.registration_handler = hs.get_handlers().registration_handler + self.profile_handler = hs.get_profile_handler() + self.event_creation_hander = hs.get_event_creation_handler() self.member_linearizer = Linearizer(name="member") self.clock = hs.get_clock() + self.spam_checker = hs.get_spam_checker() + self._server_notices_mxid = self.config.server_notices_mxid - self.distributor = hs.get_distributor() - self.distributor.declare("user_joined_room") - self.distributor.declare("user_left_room") + @abc.abstractmethod + def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + """Try and join a room that this server is not in + + Args: + requester (Requester) + remote_room_hosts (list[str]): List of servers that can be used + to join via. + room_id (str): Room that we are trying to join + user (UserID): User who is trying to join + content (dict): A dict that should be used as the content of the + join event. + + Returns: + Deferred + """ + raise NotImplementedError() + + @abc.abstractmethod + def _remote_reject_invite(self, remote_room_hosts, room_id, target): + """Attempt to reject an invite for a room this server is not in. If we + fail to do so we locally mark the invite as rejected. + + Args: + requester (Requester) + remote_room_hosts (list[str]): List of servers to use to try and + reject invite + room_id (str) + target (UserID): The user rejecting the invite + + Returns: + Deferred[dict]: A dictionary to be returned to the client, may + include event_id etc, or nothing if we locally rejected + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id): + """Get a guest access token for a 3PID, creating a guest account if + one doesn't already exist. + + Args: + requester (Requester) + medium (str) + address (str) + inviter_user_id (str): The user ID who is trying to invite the + 3PID + + Returns: + Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the + 3PID guest account. + """ + raise NotImplementedError() + + @abc.abstractmethod + def _user_joined_room(self, target, room_id): + """Notifies distributor on master process that the user has joined the + room. + + Args: + target (UserID) + room_id (str) + + Returns: + Deferred|None + """ + raise NotImplementedError() + + @abc.abstractmethod + def _user_left_room(self, target, room_id): + """Notifies distributor on master process that the user has left the + room. + + Args: + target (UserID) + room_id (str) + + Returns: + Deferred|None + """ + raise NotImplementedError() @defer.inlineCallbacks def _local_membership_update( self, requester, target, room_id, membership, - prev_event_ids, + prev_events_and_hashes, txn_id=None, ratelimit=True, content=None, ): if content is None: content = {} - msg_handler = self.hs.get_handlers().message_handler content["membership"] = membership if requester.is_guest: content["kind"] = "guest" - event, context = yield msg_handler.create_event( + event, context = yield self.event_creation_hander.create_event( requester, { "type": EventTypes.Member, @@ -83,16 +182,18 @@ class RoomMemberHandler(BaseHandler): }, token_id=requester.access_token_id, txn_id=txn_id, - prev_event_ids=prev_event_ids, + prev_events_and_hashes=prev_events_and_hashes, ) # Check if this event matches the previous membership event for the user. - duplicate = yield msg_handler.deduplicate_state_event(event, context) + duplicate = yield self.event_creation_hander.deduplicate_state_event( + event, context, + ) if duplicate is not None: # Discard the new event since this membership change is a no-op. defer.returnValue(duplicate) - yield msg_handler.handle_new_client_event( + yield self.event_creation_hander.handle_new_client_event( requester, event, context, @@ -100,7 +201,9 @@ class RoomMemberHandler(BaseHandler): ratelimit=ratelimit, ) - prev_member_event_id = context.prev_state_ids.get( + prev_state_ids = yield context.get_prev_state_ids(self.store) + + prev_member_event_id = prev_state_ids.get( (EventTypes.Member, target.to_string()), None ) @@ -114,33 +217,16 @@ class RoomMemberHandler(BaseHandler): prev_member_event = yield self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: - yield user_joined_room(self.distributor, target, room_id) + yield self._user_joined_room(target, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: prev_member_event = yield self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: - user_left_room(self.distributor, target, room_id) + yield self._user_left_room(target, room_id) defer.returnValue(event) @defer.inlineCallbacks - def remote_join(self, remote_room_hosts, room_id, user, content): - if len(remote_room_hosts) == 0: - raise SynapseError(404, "No known servers") - - # We don't do an auth check if we are doing an invite - # join dance for now, since we're kinda implicitly checking - # that we are allowed to join when we decide whether or not we - # need to do the invite/join dance. - yield self.hs.get_handlers().federation_handler.do_invite_join( - remote_room_hosts, - room_id, - user.to_string(), - content, - ) - yield user_joined_room(self.distributor, user, room_id) - - @defer.inlineCallbacks def update_membership( self, requester, @@ -186,14 +272,19 @@ class RoomMemberHandler(BaseHandler): content_specified = bool(content) if content is None: content = {} + else: + # We do a copy here as we potentially change some keys + # later on. + content = dict(content) effective_membership_state = action if action in ["kick", "unban"]: effective_membership_state = "leave" + # if this is a join with a 3pid signature, we may need to turn a 3pid + # invite into a normal invite before we can handle the join. if third_party_signed is not None: - replication = self.hs.get_replication_layer() - yield replication.exchange_third_party_invite( + yield self.federation_handler.exchange_third_party_invite( third_party_signed["sender"], target.to_string(), room_id, @@ -208,7 +299,51 @@ class RoomMemberHandler(BaseHandler): if is_blocked: raise SynapseError(403, "This room has been blocked on this server") - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + if effective_membership_state == Membership.INVITE: + # block any attempts to invite the server notices mxid + if target.to_string() == self._server_notices_mxid: + raise SynapseError( + http_client.FORBIDDEN, + "Cannot invite this user", + ) + + block_invite = False + + if (self._server_notices_mxid is not None and + requester.user.to_string() == self._server_notices_mxid): + # allow the server notices mxid to send invites + is_requester_admin = True + + else: + is_requester_admin = yield self.auth.is_server_admin( + requester.user, + ) + + if not is_requester_admin: + if self.config.block_non_admin_invites: + logger.info( + "Blocking invite: user is not admin and non-admin " + "invites disabled" + ) + block_invite = True + + if not self.spam_checker.user_may_invite( + requester.user.to_string(), target.to_string(), room_id, + ): + logger.info("Blocking invite due to spam checker") + block_invite = True + + if block_invite: + raise SynapseError( + 403, "Invites have been disabled on this server", + ) + + prev_events_and_hashes = yield self.store.get_prev_events_for_room( + room_id, + ) + latest_event_ids = ( + event_id for (event_id, _, _) in prev_events_and_hashes + ) current_state_ids = yield self.state_handler.get_current_state_ids( room_id, latest_event_ids=latest_event_ids, ) @@ -238,6 +373,20 @@ class RoomMemberHandler(BaseHandler): if same_sender and same_membership and same_content: defer.returnValue(old_state) + # we don't allow people to reject invites to the server notice + # room, but they can leave it once they are joined. + if ( + old_membership == Membership.INVITE and + effective_membership_state == Membership.LEAVE + ): + is_blocked = yield self._is_server_notice_room(room_id) + if is_blocked: + raise SynapseError( + http_client.FORBIDDEN, + "You cannot reject this invite", + errcode=Codes.CANNOT_LEAVE_SERVER_NOTICE_ROOM, + ) + is_host_in_room = yield self._is_host_in_room(current_state_ids) if effective_membership_state == Membership.JOIN: @@ -249,13 +398,13 @@ class RoomMemberHandler(BaseHandler): raise AuthError(403, "Guest access not allowed") if not is_host_in_room: - inviter = yield self.get_inviter(target.to_string(), room_id) + inviter = yield self._get_inviter(target.to_string(), room_id) if inviter and not self.hs.is_mine(inviter): remote_room_hosts.append(inviter.domain) content["membership"] = Membership.JOIN - profile = self.hs.get_handlers().profile_handler + profile = self.profile_handler if not content_specified: content["displayname"] = yield profile.get_displayname(target) content["avatar_url"] = yield profile.get_avatar_url(target) @@ -263,15 +412,15 @@ class RoomMemberHandler(BaseHandler): if requester.is_guest: content["kind"] = "guest" - ret = yield self.remote_join( - remote_room_hosts, room_id, target, content + ret = yield self._remote_join( + requester, remote_room_hosts, room_id, target, content ) defer.returnValue(ret) elif effective_membership_state == Membership.LEAVE: if not is_host_in_room: # perhaps we've been invited - inviter = yield self.get_inviter(target.to_string(), room_id) + inviter = yield self._get_inviter(target.to_string(), room_id) if not inviter: raise SynapseError(404, "Not a known room") @@ -285,28 +434,10 @@ class RoomMemberHandler(BaseHandler): else: # send the rejection to the inviter's HS. remote_room_hosts = remote_room_hosts + [inviter.domain] - fed_handler = self.hs.get_handlers().federation_handler - try: - ret = yield fed_handler.do_remotely_reject_invite( - remote_room_hosts, - room_id, - target.to_string(), - ) - defer.returnValue(ret) - except Exception as e: - # if we were unable to reject the exception, just mark - # it as rejected on our end and plough ahead. - # - # The 'except' clause is very broad, but we need to - # capture everything from DNS failures upwards - # - logger.warn("Failed to reject invite: %s", e) - - yield self.store.locally_reject_invite( - target.to_string(), room_id - ) - - defer.returnValue({}) + res = yield self._remote_reject_invite( + requester, remote_room_hosts, room_id, target, + ) + defer.returnValue(res) res = yield self._local_membership_update( requester=requester, @@ -315,7 +446,7 @@ class RoomMemberHandler(BaseHandler): membership=effective_membership_state, txn_id=txn_id, ratelimit=ratelimit, - prev_event_ids=latest_event_ids, + prev_events_and_hashes=prev_events_and_hashes, content=content, ) defer.returnValue(res) @@ -361,14 +492,16 @@ class RoomMemberHandler(BaseHandler): else: requester = synapse.types.create_requester(target_user) - message_handler = self.hs.get_handlers().message_handler - prev_event = yield message_handler.deduplicate_state_event(event, context) + prev_event = yield self.event_creation_hander.deduplicate_state_event( + event, context, + ) if prev_event is not None: return + prev_state_ids = yield context.get_prev_state_ids(self.store) if event.membership == Membership.JOIN: if requester.is_guest: - guest_can_join = yield self._can_guest_join(context.prev_state_ids) + guest_can_join = yield self._can_guest_join(prev_state_ids) if not guest_can_join: # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. @@ -379,7 +512,7 @@ class RoomMemberHandler(BaseHandler): if is_blocked: raise SynapseError(403, "This room has been blocked on this server") - yield message_handler.handle_new_client_event( + yield self.event_creation_hander.handle_new_client_event( requester, event, context, @@ -387,7 +520,7 @@ class RoomMemberHandler(BaseHandler): ratelimit=ratelimit, ) - prev_member_event_id = context.prev_state_ids.get( + prev_member_event_id = prev_state_ids.get( (EventTypes.Member, event.state_key), None ) @@ -401,12 +534,12 @@ class RoomMemberHandler(BaseHandler): prev_member_event = yield self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: - yield user_joined_room(self.distributor, target_user, room_id) + yield self._user_joined_room(target_user, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: prev_member_event = yield self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: - user_left_room(self.distributor, target_user, room_id) + yield self._user_left_room(target_user, room_id) @defer.inlineCallbacks def _can_guest_join(self, current_state_ids): @@ -440,7 +573,7 @@ class RoomMemberHandler(BaseHandler): Raises: SynapseError if room alias could not be found. """ - directory_handler = self.hs.get_handlers().directory_handler + directory_handler = self.directory_handler mapping = yield directory_handler.get_association(room_alias) if not mapping: @@ -452,7 +585,7 @@ class RoomMemberHandler(BaseHandler): defer.returnValue((RoomID.from_string(room_id), servers)) @defer.inlineCallbacks - def get_inviter(self, user_id, room_id): + def _get_inviter(self, user_id, room_id): invite = yield self.store.get_invite_for_user_in_room( user_id=user_id, room_id=room_id, @@ -471,6 +604,16 @@ class RoomMemberHandler(BaseHandler): requester, txn_id ): + if self.config.block_non_admin_invites: + is_requester_admin = yield self.auth.is_server_admin( + requester.user, + ) + if not is_requester_admin: + raise SynapseError( + 403, "Invites have been disabled on this server", + Codes.FORBIDDEN, + ) + invitee = yield self._lookup_3pid( id_server, medium, address ) @@ -508,7 +651,7 @@ class RoomMemberHandler(BaseHandler): str: the matrix ID of the 3pid, or None if it is not recognized. """ try: - data = yield self.hs.get_simple_http_client().get_json( + data = yield self.simple_http_client.get_json( "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,), { "medium": medium, @@ -519,7 +662,7 @@ class RoomMemberHandler(BaseHandler): if "mxid" in data: if "signatures" not in data: raise AuthError(401, "No signatures on 3pid binding") - self.verify_any_signature(data, id_server) + yield self._verify_any_signature(data, id_server) defer.returnValue(data["mxid"]) except IOError as e: @@ -527,11 +670,11 @@ class RoomMemberHandler(BaseHandler): defer.returnValue(None) @defer.inlineCallbacks - def verify_any_signature(self, data, server_hostname): + def _verify_any_signature(self, data, server_hostname): if server_hostname not in data["signatures"]: raise AuthError(401, "No signature from server %s" % (server_hostname,)) for key_name, signature in data["signatures"][server_hostname].items(): - key_data = yield self.hs.get_simple_http_client().get_json( + key_data = yield self.simple_http_client.get_json( "%s%s/_matrix/identity/api/v1/pubkey/%s" % (id_server_scheme, server_hostname, key_name,), ) @@ -556,7 +699,7 @@ class RoomMemberHandler(BaseHandler): user, txn_id ): - room_state = yield self.hs.get_state_handler().get_current_state(room_id) + room_state = yield self.state_handler.get_current_state(room_id) inviter_display_name = "" inviter_avatar_url = "" @@ -591,6 +734,7 @@ class RoomMemberHandler(BaseHandler): token, public_keys, fallback_public_key, display_name = ( yield self._ask_id_server_for_third_party_invite( + requester=requester, id_server=id_server, medium=medium, address=address, @@ -605,8 +749,7 @@ class RoomMemberHandler(BaseHandler): ) ) - msg_handler = self.hs.get_handlers().message_handler - yield msg_handler.create_and_send_nonmember_event( + yield self.event_creation_hander.create_and_send_nonmember_event( requester, { "type": EventTypes.ThirdPartyInvite, @@ -628,6 +771,7 @@ class RoomMemberHandler(BaseHandler): @defer.inlineCallbacks def _ask_id_server_for_third_party_invite( self, + requester, id_server, medium, address, @@ -644,6 +788,7 @@ class RoomMemberHandler(BaseHandler): Asks an identity server for a third party invite. Args: + requester (Requester) id_server (str): hostname + optional port for the identity server. medium (str): The literal string "email". address (str): The third party address being invited. @@ -685,24 +830,20 @@ class RoomMemberHandler(BaseHandler): "sender_avatar_url": inviter_avatar_url, } - if self.hs.config.invite_3pid_guest: - registration_handler = self.hs.get_handlers().registration_handler - guest_access_token = yield registration_handler.guest_access_token_for( + if self.config.invite_3pid_guest: + guest_access_token, guest_user_id = yield self.get_or_register_3pid_guest( + requester=requester, medium=medium, address=address, inviter_user_id=inviter_user_id, ) - guest_user_info = yield self.hs.get_auth().get_user_by_access_token( - guest_access_token - ) - invite_config.update({ "guest_access_token": guest_access_token, - "guest_user_id": guest_user_info["user"].to_string(), + "guest_user_id": guest_user_id, }) - data = yield self.hs.get_simple_http_client().post_urlencoded_get_json( + data = yield self.simple_http_client.post_urlencoded_get_json( is_url, invite_config ) @@ -725,27 +866,6 @@ class RoomMemberHandler(BaseHandler): defer.returnValue((token, public_keys, fallback_public_key, display_name)) @defer.inlineCallbacks - def forget(self, user, room_id): - user_id = user.to_string() - - member = yield self.state_handler.get_current_state( - room_id=room_id, - event_type=EventTypes.Member, - state_key=user_id - ) - membership = member.membership if member else None - - if membership is not None and membership not in [ - Membership.LEAVE, Membership.BAN - ]: - raise SynapseError(400, "User %s in room %s" % ( - user_id, room_id - )) - - if membership: - yield self.store.forget(user_id, room_id) - - @defer.inlineCallbacks def _is_host_in_room(self, current_state_ids): # Have we just created the room, and is this about to be the very # first member event? @@ -766,3 +886,109 @@ class RoomMemberHandler(BaseHandler): defer.returnValue(True) defer.returnValue(False) + + @defer.inlineCallbacks + def _is_server_notice_room(self, room_id): + if self._server_notices_mxid is None: + defer.returnValue(False) + user_ids = yield self.store.get_users_in_room(room_id) + defer.returnValue(self._server_notices_mxid in user_ids) + + +class RoomMemberMasterHandler(RoomMemberHandler): + def __init__(self, hs): + super(RoomMemberMasterHandler, self).__init__(hs) + + self.distributor = hs.get_distributor() + self.distributor.declare("user_joined_room") + self.distributor.declare("user_left_room") + + @defer.inlineCallbacks + def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + """Implements RoomMemberHandler._remote_join + """ + # filter ourselves out of remote_room_hosts: do_invite_join ignores it + # and if it is the only entry we'd like to return a 404 rather than a + # 500. + + remote_room_hosts = [ + host for host in remote_room_hosts if host != self.hs.hostname + ] + + if len(remote_room_hosts) == 0: + raise SynapseError(404, "No known servers") + + # We don't do an auth check if we are doing an invite + # join dance for now, since we're kinda implicitly checking + # that we are allowed to join when we decide whether or not we + # need to do the invite/join dance. + yield self.federation_handler.do_invite_join( + remote_room_hosts, + room_id, + user.to_string(), + content, + ) + yield self._user_joined_room(user, room_id) + + @defer.inlineCallbacks + def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target): + """Implements RoomMemberHandler._remote_reject_invite + """ + fed_handler = self.federation_handler + try: + ret = yield fed_handler.do_remotely_reject_invite( + remote_room_hosts, + room_id, + target.to_string(), + ) + defer.returnValue(ret) + except Exception as e: + # if we were unable to reject the exception, just mark + # it as rejected on our end and plough ahead. + # + # The 'except' clause is very broad, but we need to + # capture everything from DNS failures upwards + # + logger.warn("Failed to reject invite: %s", e) + + yield self.store.locally_reject_invite( + target.to_string(), room_id + ) + defer.returnValue({}) + + def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id): + """Implements RoomMemberHandler.get_or_register_3pid_guest + """ + rg = self.registration_handler + return rg.get_or_register_3pid_guest(medium, address, inviter_user_id) + + def _user_joined_room(self, target, room_id): + """Implements RoomMemberHandler._user_joined_room + """ + return user_joined_room(self.distributor, target, room_id) + + def _user_left_room(self, target, room_id): + """Implements RoomMemberHandler._user_left_room + """ + return user_left_room(self.distributor, target, room_id) + + @defer.inlineCallbacks + def forget(self, user, room_id): + user_id = user.to_string() + + member = yield self.state_handler.get_current_state( + room_id=room_id, + event_type=EventTypes.Member, + state_key=user_id + ) + membership = member.membership if member else None + + if membership is not None and membership not in [ + Membership.LEAVE, Membership.BAN + ]: + raise SynapseError(400, "User %s in room %s" % ( + user_id, room_id + )) + + if membership: + yield self.store.forget(user_id, room_id) diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py new file mode 100644 index 0000000000..22d8b4b0d3 --- /dev/null +++ b/synapse/handlers/room_member_worker.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.api.errors import SynapseError +from synapse.handlers.room_member import RoomMemberHandler +from synapse.replication.http.membership import ( + get_or_register_3pid_guest, + notify_user_membership_change, + remote_join, + remote_reject_invite, +) + +logger = logging.getLogger(__name__) + + +class RoomMemberWorkerHandler(RoomMemberHandler): + @defer.inlineCallbacks + def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + """Implements RoomMemberHandler._remote_join + """ + if len(remote_room_hosts) == 0: + raise SynapseError(404, "No known servers") + + ret = yield remote_join( + self.simple_http_client, + host=self.config.worker_replication_host, + port=self.config.worker_replication_http_port, + requester=requester, + remote_room_hosts=remote_room_hosts, + room_id=room_id, + user_id=user.to_string(), + content=content, + ) + + yield self._user_joined_room(user, room_id) + + defer.returnValue(ret) + + def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target): + """Implements RoomMemberHandler._remote_reject_invite + """ + return remote_reject_invite( + self.simple_http_client, + host=self.config.worker_replication_host, + port=self.config.worker_replication_http_port, + requester=requester, + remote_room_hosts=remote_room_hosts, + room_id=room_id, + user_id=target.to_string(), + ) + + def _user_joined_room(self, target, room_id): + """Implements RoomMemberHandler._user_joined_room + """ + return notify_user_membership_change( + self.simple_http_client, + host=self.config.worker_replication_host, + port=self.config.worker_replication_http_port, + user_id=target.to_string(), + room_id=room_id, + change="joined", + ) + + def _user_left_room(self, target, room_id): + """Implements RoomMemberHandler._user_left_room + """ + return notify_user_membership_change( + self.simple_http_client, + host=self.config.worker_replication_host, + port=self.config.worker_replication_http_port, + user_id=target.to_string(), + room_id=room_id, + change="left", + ) + + def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id): + """Implements RoomMemberHandler.get_or_register_3pid_guest + """ + return get_or_register_3pid_guest( + self.simple_http_client, + host=self.config.worker_replication_host, + port=self.config.worker_replication_http_port, + requester=requester, + medium=medium, + address=address, + inviter_user_id=inviter_user_id, + ) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index df75d70fac..69ae9731d5 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -13,21 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +import itertools +import logging -from ._base import BaseHandler +from unpaddedbase64 import decode_base64, encode_base64 -from synapse.api.constants import Membership, EventTypes -from synapse.api.filtering import Filter +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError +from synapse.api.filtering import Filter from synapse.events.utils import serialize_event from synapse.visibility import filter_events_for_client -from unpaddedbase64 import decode_base64, encode_base64 - -import itertools -import logging - +from ._base import BaseHandler logger = logging.getLogger(__name__) @@ -61,9 +60,16 @@ class SearchHandler(BaseHandler): assert batch_group is not None assert batch_group_key is not None assert batch_token is not None - except: + except Exception: raise SynapseError(400, "Invalid batch") + logger.info( + "Search batch properties: %r, %r, %r", + batch_group, batch_group_key, batch_token, + ) + + logger.info("Search content: %s", content) + try: room_cat = content["search_categories"]["room_events"] @@ -271,6 +277,8 @@ class SearchHandler(BaseHandler): # We should never get here due to the guard earlier. raise NotImplementedError() + logger.info("Found %d events to return", len(allowed_events)) + # If client has asked for "context" for each event (i.e. some surrounding # events and state), fetch that if event_context is not None: @@ -282,6 +290,11 @@ class SearchHandler(BaseHandler): event.room_id, event.event_id, before_limit, after_limit ) + logger.info( + "Context for search returned %d and %d events", + len(res["events_before"]), len(res["events_after"]), + ) + res["events_before"] = yield filter_events_for_client( self.store, user.to_string(), res["events_before"] ) @@ -348,7 +361,7 @@ class SearchHandler(BaseHandler): rooms = set(e.room_id for e in allowed_events) for room_id in rooms: state = yield self.state_handler.get_current_state(room_id) - state_results[room_id] = state.values() + state_results[room_id] = list(state.values()) state_results.values() diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py new file mode 100644 index 0000000000..7ecdede4dc --- /dev/null +++ b/synapse/handlers/set_password.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.internet import defer + +from synapse.api.errors import Codes, StoreError, SynapseError + +from ._base import BaseHandler + +logger = logging.getLogger(__name__) + + +class SetPasswordHandler(BaseHandler): + """Handler which deals with changing user account passwords""" + def __init__(self, hs): + super(SetPasswordHandler, self).__init__(hs) + self._auth_handler = hs.get_auth_handler() + self._device_handler = hs.get_device_handler() + + @defer.inlineCallbacks + def set_password(self, user_id, newpassword, requester=None): + password_hash = yield self._auth_handler.hash(newpassword) + + except_device_id = requester.device_id if requester else None + except_access_token_id = requester.access_token_id if requester else None + + try: + yield self.store.user_set_password_hash(user_id, password_hash) + except StoreError as e: + if e.code == 404: + raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) + raise e + + # we want to log out all of the user's other sessions. First delete + # all his other devices. + yield self._device_handler.delete_all_devices_for_user( + user_id, except_device_id=except_device_id, + ) + + # and now delete any access tokens which weren't associated with + # devices (or were associated with this device). + yield self._auth_handler.delete_access_tokens_for_user( + user_id, except_token_id=except_access_token_id, + ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 91c6c6be3c..c24e35362a 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -13,20 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.api.constants import Membership, EventTypes +import collections +import itertools +import logging + +from six import iteritems, itervalues + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.push.clientformat import format_push_rules_for_user +from synapse.types import RoomStreamToken from synapse.util.async import concurrently_execute +from synapse.util.caches.response_cache import ResponseCache from synapse.util.logcontext import LoggingContext from synapse.util.metrics import Measure, measure_func -from synapse.util.caches.response_cache import ResponseCache -from synapse.push.clientformat import format_push_rules_for_user from synapse.visibility import filter_events_for_client -from synapse.types import RoomStreamToken - -from twisted.internet import defer - -import collections -import logging -import itertools logger = logging.getLogger(__name__) @@ -52,6 +54,7 @@ class TimelineBatch(collections.namedtuple("TimelineBatch", [ to tell if room needs to be part of the sync result. """ return bool(self.events) + __bool__ = __nonzero__ # python3 class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ @@ -76,6 +79,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ # nb the notification count does not, er, count: if there's nothing # else in the result, we don't need to send it. ) + __bool__ = __nonzero__ # python3 class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [ @@ -95,6 +99,7 @@ class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [ or self.state or self.account_data ) + __bool__ = __nonzero__ # python3 class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [ @@ -106,6 +111,30 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [ def __nonzero__(self): """Invited rooms should always be reported to the client""" return True + __bool__ = __nonzero__ # python3 + + +class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [ + "join", + "invite", + "leave", +])): + __slots__ = [] + + def __nonzero__(self): + return bool(self.join or self.invite or self.leave) + __bool__ = __nonzero__ # python3 + + +class DeviceLists(collections.namedtuple("DeviceLists", [ + "changed", # list of user_ids whose devices may have changed + "left", # list of user_ids whose devices we no longer track +])): + __slots__ = [] + + def __nonzero__(self): + return bool(self.changed or self.left) + __bool__ = __nonzero__ # python3 class SyncResult(collections.namedtuple("SyncResult", [ @@ -116,9 +145,10 @@ class SyncResult(collections.namedtuple("SyncResult", [ "invited", # InvitedSyncResult for each invited room. "archived", # ArchivedSyncResult for each archived room. "to_device", # List of direct messages for the device. - "device_lists", # List of user_ids whose devices have chanegd + "device_lists", # List of user_ids whose devices have changed "device_one_time_keys_count", # Dict of algorithm to count for one time keys # for this device + "groups", ])): __slots__ = [] @@ -134,8 +164,10 @@ class SyncResult(collections.namedtuple("SyncResult", [ self.archived or self.account_data or self.to_device or - self.device_lists + self.device_lists or + self.groups ) + __bool__ = __nonzero__ # python3 class SyncHandler(object): @@ -146,7 +178,7 @@ class SyncHandler(object): self.presence_handler = hs.get_presence_handler() self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() - self.response_cache = ResponseCache(hs) + self.response_cache = ResponseCache(hs, "sync") self.state = hs.get_state_handler() def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, @@ -157,15 +189,11 @@ class SyncHandler(object): Returns: A Deferred SyncResult. """ - result = self.response_cache.get(sync_config.request_key) - if not result: - result = self.response_cache.set( - sync_config.request_key, - self._wait_for_sync_for_user( - sync_config, since_token, timeout, full_state - ) - ) - return result + return self.response_cache.wrap( + sync_config.request_key, + self._wait_for_sync_for_user, + sync_config, since_token, timeout, full_state, + ) @defer.inlineCallbacks def _wait_for_sync_for_user(self, sync_config, since_token, timeout, @@ -212,10 +240,10 @@ class SyncHandler(object): defer.returnValue(rules) @defer.inlineCallbacks - def ephemeral_by_room(self, sync_config, now_token, since_token=None): + def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None): """Get the ephemeral events for each room the user is in Args: - sync_config (SyncConfig): The flags, filters and user for the sync. + sync_result_builder(SyncResultBuilder) now_token (StreamToken): Where the server is currently up to. since_token (StreamToken): Where the server was when the client last synced. @@ -225,10 +253,12 @@ class SyncHandler(object): typing events for that room. """ + sync_config = sync_result_builder.sync_config + with Measure(self.clock, "ephemeral_by_room"): typing_key = since_token.typing_key if since_token else "0" - room_ids = yield self.store.get_rooms_for_user(sync_config.user.to_string()) + room_ids = sync_result_builder.joined_room_ids typing_source = self.event_sources.sources["typing"] typing, typing_key = yield typing_source.get_new_events( @@ -247,7 +277,7 @@ class SyncHandler(object): # result returned by the event source is poor form (it might cache # the object) room_id = event["room_id"] - event_copy = {k: v for (k, v) in event.iteritems() + event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"} ephemeral_by_room.setdefault(room_id, []).append(event_copy) @@ -266,7 +296,7 @@ class SyncHandler(object): for event in receipts: room_id = event["room_id"] # exclude room id, as above - event_copy = {k: v for (k, v) in event.iteritems() + event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"} ephemeral_by_room.setdefault(room_id, []).append(event_copy) @@ -290,10 +320,20 @@ class SyncHandler(object): if recents: recents = sync_config.filter_collection.filter_room_timeline(recents) + + # We check if there are any state events, if there are then we pass + # all current state events to the filter_events function. This is to + # ensure that we always include current state in the timeline + current_state_ids = frozenset() + if any(e.is_state() for e in recents): + current_state_ids = yield self.state.get_current_state_ids(room_id) + current_state_ids = frozenset(itervalues(current_state_ids)) + recents = yield filter_events_for_client( self.store, sync_config.user.to_string(), recents, + always_include_ids=current_state_ids, ) else: recents = [] @@ -316,19 +356,41 @@ class SyncHandler(object): since_key = since_token.room_key while limited and len(recents) < timeline_limit and max_repeat: - events, end_key = yield self.store.get_room_events_stream_for_room( - room_id, - limit=load_limit + 1, - from_key=since_key, - to_key=end_key, - ) + # If we have a since_key then we are trying to get any events + # that have happened since `since_key` up to `end_key`, so we + # can just use `get_room_events_stream_for_room`. + # Otherwise, we want to return the last N events in the room + # in toplogical ordering. + if since_key: + events, end_key = yield self.store.get_room_events_stream_for_room( + room_id, + limit=load_limit + 1, + from_key=since_key, + to_key=end_key, + ) + else: + events, end_key = yield self.store.get_recent_events_for_room( + room_id, + limit=load_limit + 1, + end_token=end_key, + ) loaded_recents = sync_config.filter_collection.filter_room_timeline( events ) + + # We check if there are any state events, if there are then we pass + # all current state events to the filter_events function. This is to + # ensure that we always include current state in the timeline + current_state_ids = frozenset() + if any(e.is_state() for e in loaded_recents): + current_state_ids = yield self.state.get_current_state_ids(room_id) + current_state_ids = frozenset(itervalues(current_state_ids)) + loaded_recents = yield filter_events_for_client( self.store, sync_config.user.to_string(), loaded_recents, + always_include_ids=current_state_ids, ) loaded_recents.extend(recents) recents = loaded_recents @@ -381,7 +443,11 @@ class SyncHandler(object): Returns: A Deferred map from ((type, state_key)->Event) """ - last_events, token = yield self.store.get_recent_events_for_room( + # FIXME this claims to get the state at a stream position, but + # get_recent_events_for_room operates by topo ordering. This therefore + # does not reliably give you the state at the given stream position. + # (https://github.com/matrix-org/synapse/issues/3305) + last_events, _ = yield self.store.get_recent_events_for_room( room_id, end_token=stream_position.room_key, limit=1, ) @@ -475,11 +541,11 @@ class SyncHandler(object): state = {} if state_ids: - state = yield self.store.get_events(state_ids.values()) + state = yield self.store.get_events(list(state_ids.values())) defer.returnValue({ (e.type, e.state_key): e - for e in sync_config.filter_collection.filter_room_state(state.values()) + for e in sync_config.filter_collection.filter_room_state(list(state.values())) }) @defer.inlineCallbacks @@ -522,10 +588,22 @@ class SyncHandler(object): # Always use the `now_token` in `SyncResultBuilder` now_token = yield self.event_sources.get_current_token() + user_id = sync_config.user.to_string() + app_service = self.store.get_app_service_by_user_id(user_id) + if app_service: + # We no longer support AS users using /sync directly. + # See https://github.com/matrix-org/matrix-doc/issues/1144 + raise NotImplementedError() + else: + joined_room_ids = yield self.get_rooms_for_user_at( + user_id, now_token.room_stream_id, + ) + sync_result_builder = SyncResultBuilder( sync_config, full_state, since_token=since_token, now_token=now_token, + joined_room_ids=joined_room_ids, ) account_data_by_room = yield self._generate_sync_entry_for_account_data( @@ -535,7 +613,8 @@ class SyncHandler(object): res = yield self._generate_sync_entry_for_rooms( sync_result_builder, account_data_by_room ) - newly_joined_rooms, newly_joined_users = res + newly_joined_rooms, newly_joined_users, _, _ = res + _, _, newly_left_rooms, newly_left_users = res block_all_presence_data = ( since_token is None and @@ -549,17 +628,22 @@ class SyncHandler(object): yield self._generate_sync_entry_for_to_device(sync_result_builder) device_lists = yield self._generate_sync_entry_for_device_list( - sync_result_builder + sync_result_builder, + newly_joined_rooms=newly_joined_rooms, + newly_joined_users=newly_joined_users, + newly_left_rooms=newly_left_rooms, + newly_left_users=newly_left_users, ) device_id = sync_config.device_id one_time_key_counts = {} if device_id: - user_id = sync_config.user.to_string() one_time_key_counts = yield self.store.count_e2e_one_time_keys( user_id, device_id ) + yield self._generate_sync_entry_for_groups(sync_result_builder) + defer.returnValue(SyncResult( presence=sync_result_builder.presence, account_data=sync_result_builder.account_data, @@ -568,31 +652,103 @@ class SyncHandler(object): archived=sync_result_builder.archived, to_device=sync_result_builder.to_device, device_lists=device_lists, + groups=sync_result_builder.groups, device_one_time_keys_count=one_time_key_counts, next_batch=sync_result_builder.now_token, )) + @measure_func("_generate_sync_entry_for_groups") + @defer.inlineCallbacks + def _generate_sync_entry_for_groups(self, sync_result_builder): + user_id = sync_result_builder.sync_config.user.to_string() + since_token = sync_result_builder.since_token + now_token = sync_result_builder.now_token + + if since_token and since_token.groups_key: + results = yield self.store.get_groups_changes_for_user( + user_id, since_token.groups_key, now_token.groups_key, + ) + else: + results = yield self.store.get_all_groups_for_user( + user_id, now_token.groups_key, + ) + + invited = {} + joined = {} + left = {} + for result in results: + membership = result["membership"] + group_id = result["group_id"] + gtype = result["type"] + content = result["content"] + + if membership == "join": + if gtype == "membership": + # TODO: Add profile + content.pop("membership", None) + joined[group_id] = content["content"] + else: + joined.setdefault(group_id, {})[gtype] = content + elif membership == "invite": + if gtype == "membership": + content.pop("membership", None) + invited[group_id] = content["content"] + else: + if gtype == "membership": + left[group_id] = content["content"] + + sync_result_builder.groups = GroupsSyncResult( + join=joined, + invite=invited, + leave=left, + ) + @measure_func("_generate_sync_entry_for_device_list") @defer.inlineCallbacks - def _generate_sync_entry_for_device_list(self, sync_result_builder): + def _generate_sync_entry_for_device_list(self, sync_result_builder, + newly_joined_rooms, newly_joined_users, + newly_left_rooms, newly_left_users): user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token if since_token and since_token.device_list_key: - room_ids = yield self.store.get_rooms_for_user(user_id) - - user_ids_changed = set() changed = yield self.store.get_user_whose_devices_changed( since_token.device_list_key ) - for other_user_id in changed: - other_room_ids = yield self.store.get_rooms_for_user(other_user_id) - if room_ids.intersection(other_room_ids): - user_ids_changed.add(other_user_id) - defer.returnValue(user_ids_changed) + # TODO: Be more clever than this, i.e. remove users who we already + # share a room with? + for room_id in newly_joined_rooms: + joined_users = yield self.state.get_current_user_in_room(room_id) + newly_joined_users.update(joined_users) + + for room_id in newly_left_rooms: + left_users = yield self.state.get_current_user_in_room(room_id) + newly_left_users.update(left_users) + + # TODO: Check that these users are actually new, i.e. either they + # weren't in the previous sync *or* they left and rejoined. + changed.update(newly_joined_users) + + if not changed and not newly_left_users: + defer.returnValue(DeviceLists( + changed=[], + left=newly_left_users, + )) + + users_who_share_room = yield self.store.get_users_who_share_room_with_user( + user_id + ) + + defer.returnValue(DeviceLists( + changed=users_who_share_room & changed, + left=set(newly_left_users) - users_who_share_room, + )) else: - defer.returnValue([]) + defer.returnValue(DeviceLists( + changed=[], + left=[], + )) @defer.inlineCallbacks def _generate_sync_entry_for_to_device(self, sync_result_builder): @@ -738,7 +894,7 @@ class SyncHandler(object): presence.extend(states) # Deduplicate the presence entries so that there's at most one per user - presence = {p.user_id: p for p in presence}.values() + presence = list({p.user_id: p for p in presence}.values()) presence = sync_config.filter_collection.filter_presence( presence @@ -756,8 +912,8 @@ class SyncHandler(object): account_data_by_room(dict): Dictionary of per room account data Returns: - Deferred(tuple): Returns a 2-tuple of - `(newly_joined_rooms, newly_joined_users)` + Deferred(tuple): Returns a 4-tuple of + `(newly_joined_rooms, newly_joined_users, newly_left_rooms, newly_left_users)` """ user_id = sync_result_builder.sync_config.user.to_string() block_all_room_ephemeral = ( @@ -769,7 +925,7 @@ class SyncHandler(object): ephemeral_by_room = {} else: now_token, ephemeral_by_room = yield self.ephemeral_by_room( - sync_result_builder.sync_config, + sync_result_builder, now_token=sync_result_builder.now_token, since_token=sync_result_builder.since_token, ) @@ -788,7 +944,7 @@ class SyncHandler(object): ) if not tags_by_room: logger.debug("no-oping sync") - defer.returnValue(([], [])) + defer.returnValue(([], [], [], [])) ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id=user_id, @@ -801,7 +957,7 @@ class SyncHandler(object): if since_token: res = yield self._get_rooms_changed(sync_result_builder, ignored_users) - room_entries, invited, newly_joined_rooms = res + room_entries, invited, newly_joined_rooms, newly_left_rooms = res tags_by_room = yield self.store.get_updated_tags( user_id, since_token.account_data_key, @@ -809,6 +965,7 @@ class SyncHandler(object): else: res = yield self._get_all_rooms(sync_result_builder, ignored_users) room_entries, invited, newly_joined_rooms = res + newly_left_rooms = [] tags_by_room = yield self.store.get_tags_for_user(user_id) @@ -829,17 +986,30 @@ class SyncHandler(object): # Now we want to get any newly joined users newly_joined_users = set() + newly_left_users = set() if since_token: for joined_sync in sync_result_builder.joined: it = itertools.chain( - joined_sync.timeline.events, joined_sync.state.values() + joined_sync.timeline.events, itervalues(joined_sync.state) ) for event in it: if event.type == EventTypes.Member: if event.membership == Membership.JOIN: newly_joined_users.add(event.state_key) - - defer.returnValue((newly_joined_rooms, newly_joined_users)) + else: + prev_content = event.unsigned.get("prev_content", {}) + prev_membership = prev_content.get("membership", None) + if prev_membership == Membership.JOIN: + newly_left_users.add(event.state_key) + + newly_left_users -= newly_joined_users + + defer.returnValue(( + newly_joined_rooms, + newly_joined_users, + newly_left_rooms, + newly_left_users, + )) @defer.inlineCallbacks def _have_rooms_changed(self, sync_result_builder): @@ -860,15 +1030,8 @@ class SyncHandler(object): if rooms_changed: defer.returnValue(True) - app_service = self.store.get_app_service_by_user_id(user_id) - if app_service: - rooms = yield self.store.get_app_service_rooms(app_service) - joined_room_ids = set(r.room_id for r in rooms) - else: - joined_room_ids = yield self.store.get_rooms_for_user(user_id) - stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream - for room_id in joined_room_ids: + for room_id in sync_result_builder.joined_room_ids: if self.store.has_room_changed_since(room_id, stream_id): defer.returnValue(True) defer.returnValue(False) @@ -883,7 +1046,13 @@ class SyncHandler(object): Returns: Deferred(tuple): Returns a tuple of the form: - `([RoomSyncResultBuilder], [InvitedSyncResult], newly_joined_rooms)` + `(room_entries, invited_rooms, newly_joined_rooms, newly_left_rooms)` + + where: + room_entries is a list [RoomSyncResultBuilder] + invited_rooms is a list [InvitedSyncResult] + newly_joined rooms is a list[str] of room ids + newly_left_rooms is a list[str] of room ids """ user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token @@ -892,13 +1061,6 @@ class SyncHandler(object): assert since_token - app_service = self.store.get_app_service_by_user_id(user_id) - if app_service: - rooms = yield self.store.get_app_service_rooms(app_service) - joined_room_ids = set(r.room_id for r in rooms) - else: - joined_room_ids = yield self.store.get_rooms_for_user(user_id) - # Get a list of membership change events that have happened. rooms_changed = yield self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key @@ -909,16 +1071,29 @@ class SyncHandler(object): mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) newly_joined_rooms = [] + newly_left_rooms = [] room_entries = [] invited = [] - for room_id, events in mem_change_events_by_room_id.items(): + for room_id, events in iteritems(mem_change_events_by_room_id): non_joins = [e for e in events if e.membership != Membership.JOIN] has_join = len(non_joins) != len(events) # We want to figure out if we joined the room at some point since # the last sync (even if we have since left). This is to make sure # we do send down the room, and with full state, where necessary - if room_id in joined_room_ids or has_join: + + old_state_ids = None + if room_id in sync_result_builder.joined_room_ids and non_joins: + # Always include if the user (re)joined the room, especially + # important so that device list changes are calculated correctly. + # If there are non join member events, but we are still in the room, + # then the user must have left and joined + newly_joined_rooms.append(room_id) + + # User is in the room so we don't need to do the invite/leave checks + continue + + if room_id in sync_result_builder.joined_room_ids or has_join: old_state_ids = yield self.get_state_at(room_id, since_token) old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None) old_mem_ev = None @@ -929,12 +1104,33 @@ class SyncHandler(object): if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: newly_joined_rooms.append(room_id) - if room_id in joined_room_ids: - continue + # If user is in the room then we don't need to do the invite/leave checks + if room_id in sync_result_builder.joined_room_ids: + continue if not non_joins: continue + # Check if we have left the room. This can either be because we were + # joined before *or* that we since joined and then left. + if events[-1].membership != Membership.JOIN: + if has_join: + newly_left_rooms.append(room_id) + else: + if not old_state_ids: + old_state_ids = yield self.get_state_at(room_id, since_token) + old_mem_ev_id = old_state_ids.get( + (EventTypes.Member, user_id), + None, + ) + old_mem_ev = None + if old_mem_ev_id: + old_mem_ev = yield self.store.get_event( + old_mem_ev_id, allow_none=True + ) + if old_mem_ev and old_mem_ev.membership == Membership.JOIN: + newly_left_rooms.append(room_id) + # Only bother if we're still currently invited should_invite = non_joins[-1].membership == Membership.INVITE if should_invite: @@ -976,7 +1172,7 @@ class SyncHandler(object): # Get all events for rooms we're currently joined to. room_to_events = yield self.store.get_room_events_stream_for_rooms( - room_ids=joined_room_ids, + room_ids=sync_result_builder.joined_room_ids, from_key=since_token.room_key, to_key=now_token.room_key, limit=timeline_limit + 1, @@ -984,7 +1180,7 @@ class SyncHandler(object): # We loop through all room ids, even if there are no new events, in case # there are non room events taht we need to notify about. - for room_id in joined_room_ids: + for room_id in sync_result_builder.joined_room_ids: room_entry = room_to_events.get(room_id, None) if room_entry: @@ -1012,7 +1208,7 @@ class SyncHandler(object): upto_token=since_token, )) - defer.returnValue((room_entries, invited, newly_joined_rooms)) + defer.returnValue((room_entries, invited, newly_joined_rooms, newly_left_rooms)) @defer.inlineCallbacks def _get_all_rooms(self, sync_result_builder, ignored_users): @@ -1192,6 +1388,54 @@ class SyncHandler(object): else: raise Exception("Unrecognized rtype: %r", room_builder.rtype) + @defer.inlineCallbacks + def get_rooms_for_user_at(self, user_id, stream_ordering): + """Get set of joined rooms for a user at the given stream ordering. + + The stream ordering *must* be recent, otherwise this may throw an + exception if older than a month. (This function is called with the + current token, which should be perfectly fine). + + Args: + user_id (str) + stream_ordering (int) + + ReturnValue: + Deferred[frozenset[str]]: Set of room_ids the user is in at given + stream_ordering. + """ + joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering( + user_id, + ) + + joined_room_ids = set() + + # We need to check that the stream ordering of the join for each room + # is before the stream_ordering asked for. This might not be the case + # if the user joins a room between us getting the current token and + # calling `get_rooms_for_user_with_stream_ordering`. + # If the membership's stream ordering is after the given stream + # ordering, we need to go and work out if the user was in the room + # before. + for room_id, membership_stream_ordering in joined_rooms: + if membership_stream_ordering <= stream_ordering: + joined_room_ids.add(room_id) + continue + + logger.info("User joined room after current token: %s", room_id) + + extrems = yield self.store.get_forward_extremeties_for_room( + room_id, stream_ordering, + ) + users_in_room = yield self.state.get_current_user_in_room( + room_id, extrems, + ) + if user_id in users_in_room: + joined_room_ids.add(room_id) + + joined_room_ids = frozenset(joined_room_ids) + defer.returnValue(joined_room_ids) + def _action_has_highlight(actions): for action in actions: @@ -1241,7 +1485,8 @@ def _calculate_state(timeline_contains, timeline_start, previous, current): class SyncResultBuilder(object): "Used to help build up a new SyncResult for a user" - def __init__(self, sync_config, full_state, since_token, now_token): + def __init__(self, sync_config, full_state, since_token, now_token, + joined_room_ids): """ Args: sync_config(SyncConfig) @@ -1253,6 +1498,7 @@ class SyncResultBuilder(object): self.full_state = full_state self.since_token = since_token self.now_token = now_token + self.joined_room_ids = joined_room_ids self.presence = [] self.account_data = [] @@ -1260,6 +1506,8 @@ class SyncResultBuilder(object): self.invited = [] self.archived = [] self.device = [] + self.groups = None + self.to_device = [] class RoomSyncResultBuilder(object): diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 82dedbbc99..2d2d3d5a0d 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -13,17 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +from collections import namedtuple + from twisted.internet import defer -from synapse.api.errors import SynapseError, AuthError -from synapse.util.logcontext import preserve_fn +from synapse.api.errors import AuthError, SynapseError +from synapse.types import UserID, get_domain_from_id +from synapse.util.logcontext import run_in_background from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer -from synapse.types import UserID, get_domain_from_id - -import logging - -from collections import namedtuple logger = logging.getLogger(__name__) @@ -56,7 +55,7 @@ class TypingHandler(object): self.federation = hs.get_federation_sender() - hs.get_replication_layer().register_edu_handler("m.typing", self._recv_edu) + hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu) hs.get_distributor().observe("user_left_room", self.user_left_room) @@ -97,7 +96,8 @@ class TypingHandler(object): if self.hs.is_mine_id(member.user_id): last_fed_poke = self._member_last_federation_poke.get(member, None) if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now: - preserve_fn(self._push_remote)( + run_in_background( + self._push_remote, member=member, typing=True ) @@ -196,7 +196,7 @@ class TypingHandler(object): def _push_update(self, member, typing): if self.hs.is_mine_id(member.user_id): # Only send updates for changes to our own users. - preserve_fn(self._push_remote)(member, typing) + run_in_background(self._push_remote, member, typing) self._push_update_local( member=member, @@ -205,28 +205,31 @@ class TypingHandler(object): @defer.inlineCallbacks def _push_remote(self, member, typing): - users = yield self.state.get_current_user_in_room(member.room_id) - self._member_last_federation_poke[member] = self.clock.time_msec() + try: + users = yield self.state.get_current_user_in_room(member.room_id) + self._member_last_federation_poke[member] = self.clock.time_msec() - now = self.clock.time_msec() - self.wheel_timer.insert( - now=now, - obj=member, - then=now + FEDERATION_PING_INTERVAL, - ) + now = self.clock.time_msec() + self.wheel_timer.insert( + now=now, + obj=member, + then=now + FEDERATION_PING_INTERVAL, + ) - for domain in set(get_domain_from_id(u) for u in users): - if domain != self.server_name: - self.federation.send_edu( - destination=domain, - edu_type="m.typing", - content={ - "room_id": member.room_id, - "user_id": member.user_id, - "typing": typing, - }, - key=member, - ) + for domain in set(get_domain_from_id(u) for u in users): + if domain != self.server_name: + self.federation.send_edu( + destination=domain, + edu_type="m.typing", + content={ + "room_id": member.room_id, + "user_id": member.user_id, + "typing": typing, + }, + key=member, + ) + except Exception: + logger.exception("Error pushing typing notif to remotes") @defer.inlineCallbacks def _recv_edu(self, origin, content): diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 2a49456bfc..37dda64587 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -14,18 +14,20 @@ # limitations under the License. import logging + +from six import iteritems + from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.storage.roommember import ProfileInfo +from synapse.types import get_localpart_from_id from synapse.util.metrics import Measure -from synapse.util.async import sleep - logger = logging.getLogger(__name__) -class UserDirectoyHandler(object): +class UserDirectoryHandler(object): """Handles querying of and keeping updated the user_directory. N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY @@ -41,9 +43,10 @@ class UserDirectoyHandler(object): one public room. """ - INITIAL_SLEEP_MS = 50 - INITIAL_SLEEP_COUNT = 100 - INITIAL_BATCH_SIZE = 100 + INITIAL_ROOM_SLEEP_MS = 50 + INITIAL_ROOM_SLEEP_COUNT = 100 + INITIAL_ROOM_BATCH_SIZE = 100 + INITIAL_USER_SLEEP_MS = 10 def __init__(self, hs): self.store = hs.get_datastore() @@ -53,6 +56,7 @@ class UserDirectoyHandler(object): self.notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id self.update_user_directory = hs.config.update_user_directory + self.search_all_users = hs.config.user_directory_search_all_users # When start up for the first time we need to populate the user_directory. # This is a set of user_id's we've inserted already @@ -111,6 +115,22 @@ class UserDirectoyHandler(object): self._is_processing = False @defer.inlineCallbacks + def handle_local_profile_change(self, user_id, profile): + """Called to update index of our local user profiles when they change + irrespective of any rooms the user may be in. + """ + yield self.store.update_profile_in_user_dir( + user_id, profile.display_name, profile.avatar_url, None, + ) + + @defer.inlineCallbacks + def handle_user_deactivated(self, user_id): + """Called when a user ID is deactivated + """ + yield self.store.remove_from_user_dir(user_id) + yield self.store.remove_from_user_in_public_room(user_id) + + @defer.inlineCallbacks def _unsafe_process(self): # If self.pos is None then means we haven't fetched it from DB if self.pos is None: @@ -148,16 +168,30 @@ class UserDirectoyHandler(object): room_ids = yield self.store.get_all_rooms() logger.info("Doing initial update of user directory. %d rooms", len(room_ids)) - num_processed_rooms = 1 + num_processed_rooms = 0 for room_id in room_ids: - logger.info("Handling room %d/%d", num_processed_rooms, len(room_ids)) - yield self._handle_intial_room(room_id) + logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids)) + yield self._handle_initial_room(room_id) num_processed_rooms += 1 - yield sleep(self.INITIAL_SLEEP_MS / 1000.) + yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.) logger.info("Processed all rooms.") + if self.search_all_users: + num_processed_users = 0 + user_ids = yield self.store.get_all_local_users() + logger.info("Doing initial update of user directory. %d users", len(user_ids)) + for user_id in user_ids: + # We add profiles for all users even if they don't match the + # include pattern, just in case we want to change it in future + logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids)) + yield self._handle_local_user(user_id) + num_processed_users += 1 + yield self.clock.sleep(self.INITIAL_USER_SLEEP_MS / 1000.) + + logger.info("Processed all users") + self.initially_handled_users = None self.initially_handled_users_in_public = None self.initially_handled_users_share = None @@ -166,7 +200,7 @@ class UserDirectoyHandler(object): yield self.store.update_user_directory_stream_pos(new_pos) @defer.inlineCallbacks - def _handle_intial_room(self, room_id): + def _handle_initial_room(self, room_id): """Called when we initially fill out user_directory one room at a time """ is_in_room = yield self.store.is_host_joined(room_id, self.server_name) @@ -201,8 +235,8 @@ class UserDirectoyHandler(object): to_update = set() count = 0 for user_id in user_ids: - if count % self.INITIAL_SLEEP_COUNT == 0: - yield sleep(self.INITIAL_SLEEP_MS / 1000.) + if count % self.INITIAL_ROOM_SLEEP_COUNT == 0: + yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.) if not self.is_mine_id(user_id): count += 1 @@ -216,8 +250,8 @@ class UserDirectoyHandler(object): if user_id == other_user_id: continue - if count % self.INITIAL_SLEEP_COUNT == 0: - yield sleep(self.INITIAL_SLEEP_MS / 1000.) + if count % self.INITIAL_ROOM_SLEEP_COUNT == 0: + yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.) count += 1 user_set = (user_id, other_user_id) @@ -237,13 +271,13 @@ class UserDirectoyHandler(object): else: self.initially_handled_users_share_private_room.add(user_set) - if len(to_insert) > self.INITIAL_BATCH_SIZE: + if len(to_insert) > self.INITIAL_ROOM_BATCH_SIZE: yield self.store.add_users_who_share_room( room_id, not is_public, to_insert, ) to_insert.clear() - if len(to_update) > self.INITIAL_BATCH_SIZE: + if len(to_update) > self.INITIAL_ROOM_BATCH_SIZE: yield self.store.update_users_who_share_room( room_id, not is_public, to_update, ) @@ -377,7 +411,7 @@ class UserDirectoyHandler(object): if change: users_with_profile = yield self.state.get_current_user_in_room(room_id) - for user_id, profile in users_with_profile.iteritems(): + for user_id, profile in iteritems(users_with_profile): yield self._handle_new_user(room_id, user_id, profile) else: users = yield self.store.get_users_in_public_due_to_room(room_id) @@ -385,14 +419,28 @@ class UserDirectoyHandler(object): yield self._handle_remove_user(room_id, user_id) @defer.inlineCallbacks + def _handle_local_user(self, user_id): + """Adds a new local roomless user into the user_directory_search table. + Used to populate up the user index when we have an + user_directory_search_all_users specified. + """ + logger.debug("Adding new local user to dir, %r", user_id) + + profile = yield self.store.get_profileinfo(get_localpart_from_id(user_id)) + + row = yield self.store.get_user_in_directory(user_id) + if not row: + yield self.store.add_profiles_to_user_dir(None, {user_id: profile}) + + @defer.inlineCallbacks def _handle_new_user(self, room_id, user_id, profile): """Called when we might need to add user to directory Args: - room_id (str): room_id that user joined or started being public that + room_id (str): room_id that user joined or started being public user_id (str) """ - logger.debug("Adding user to dir, %r", user_id) + logger.debug("Adding new user to dir, %r", user_id) row = yield self.store.get_user_in_directory(user_id) if not row: @@ -407,7 +455,7 @@ class UserDirectoyHandler(object): if not row: yield self.store.add_users_to_public_room(room_id, [user_id]) else: - logger.debug("Not adding user to public dir, %r", user_id) + logger.debug("Not adding new user to public dir, %r", user_id) # Now we update users who share rooms with users. We do this by getting # all the current users in the room and seeing which aren't already diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index bfebb0f644..58ef8d3ce4 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,3 +13,37 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re + +from twisted.internet.defer import CancelledError +from twisted.python import failure + +from synapse.api.errors import SynapseError + + +class RequestTimedOutError(SynapseError): + """Exception representing timeout of an outbound request""" + def __init__(self): + super(RequestTimedOutError, self).__init__(504, "Timed out") + + +def cancelled_to_request_timed_out_error(value, timeout): + """Turns CancelledErrors into RequestTimedOutErrors. + + For use with async.add_timeout_to_deferred + """ + if isinstance(value, failure.Failure): + value.trap(CancelledError) + raise RequestTimedOutError() + return value + + +ACCESS_TOKEN_RE = re.compile(br'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$') + + +def redact_uri(uri): + """Strips access tokens from the uri replaces with <redacted>""" + return ACCESS_TOKEN_RE.sub( + br'\1<redacted>\3', + uri + ) diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py new file mode 100644 index 0000000000..0e10e3f8f7 --- /dev/null +++ b/synapse/http/additional_resource.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET + +from synapse.http.server import wrap_json_request_handler + + +class AdditionalResource(Resource): + """Resource wrapper for additional_resources + + If the user has configured additional_resources, we need to wrap the + handler class with a Resource so that we can map it into the resource tree. + + This class is also where we wrap the request handler with logging, metrics, + and exception handling. + """ + def __init__(self, hs, handler): + """Initialise AdditionalResource + + The ``handler`` should return a deferred which completes when it has + done handling the request. It should write a response with + ``request.write()``, and call ``request.finish()``. + + Args: + hs (synapse.server.HomeServer): homeserver + handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred): + function to be called to handle the request. + """ + Resource.__init__(self) + self._handler = handler + + # required by the request_handler wrapper + self.clock = hs.get_clock() + + def render(self, request): + self._async_render(request) + return NOT_DONE_YET + + @wrap_json_request_handler + def _async_render(self, request): + return self._handler(request) diff --git a/synapse/http/client.py b/synapse/http/client.py index 9eba046bbf..25b6307884 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,49 +13,49 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from OpenSSL import SSL -from OpenSSL.SSL import VERIFY_NONE +import logging +import urllib -from synapse.api.errors import ( - CodeMessageException, MatrixCodeMessageException, SynapseError, Codes, -) -from synapse.util.logcontext import preserve_context_over_fn -from synapse.util import logcontext -import synapse.metrics -from synapse.http.endpoint import SpiderEndpoint +from six import StringIO -from canonicaljson import encode_canonical_json +from canonicaljson import encode_canonical_json, json +from prometheus_client import Counter -from twisted.internet import defer, reactor, ssl, protocol, task +from OpenSSL import SSL +from OpenSSL.SSL import VERIFY_NONE +from twisted.internet import defer, protocol, reactor, ssl, task from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS +from twisted.web._newclient import ResponseDone from twisted.web.client import ( - BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent, - readBody, PartialDownloadError, + Agent, + BrowserLikeRedirectAgent, + ContentDecoderAgent, + FileBodyProducer as TwistedFileBodyProducer, + GzipDecoder, + HTTPConnectionPool, + PartialDownloadError, + readBody, ) -from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer from twisted.web.http import PotentialDataLoss from twisted.web.http_headers import Headers -from twisted.web._newclient import ResponseDone - -from StringIO import StringIO - -import simplejson as json -import logging -import urllib +from synapse.api.errors import ( + CodeMessageException, + Codes, + MatrixCodeMessageException, + SynapseError, +) +from synapse.http import cancelled_to_request_timed_out_error, redact_uri +from synapse.http.endpoint import SpiderEndpoint +from synapse.util.async import add_timeout_to_deferred +from synapse.util.caches import CACHE_SIZE_FACTOR +from synapse.util.logcontext import make_deferred_yieldable logger = logging.getLogger(__name__) -metrics = synapse.metrics.get_metrics_for(__name__) - -outgoing_requests_counter = metrics.register_counter( - "requests", - labels=["method"], -) -incoming_responses_counter = metrics.register_counter( - "responses", - labels=["method", "code"], -) +outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"]) +incoming_responses_counter = Counter("synapse_http_client_responses", "", + ["method", "code"]) class SimpleHttpClient(object): @@ -64,13 +65,23 @@ class SimpleHttpClient(object): """ def __init__(self, hs): self.hs = hs + + pool = HTTPConnectionPool(reactor) + + # the pusher makes lots of concurrent SSL connections to sygnal, and + # tends to do so in batches, so we need to allow the pool to keep lots + # of idle connections around. + pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5)) + pool.cachedConnectionTimeout = 2 * 60 + # The default context factory in Twisted 14.0.0 (which we require) is # BrowserLikePolicyForHTTPS which will do regular cert validation # 'like a browser' self.agent = Agent( reactor, connectTimeout=15, - contextFactory=hs.get_http_client_context_factory() + contextFactory=hs.get_http_client_context_factory(), + pool=pool, ) self.user_agent = hs.version_string self.clock = hs.get_clock() @@ -81,76 +92,103 @@ class SimpleHttpClient(object): def request(self, method, uri, *args, **kwargs): # A small wrapper around self.agent.request() so we can easily attach # counters to it - outgoing_requests_counter.inc(method) + outgoing_requests_counter.labels(method).inc() - def send_request(): + # log request but strip `access_token` (AS requests for example include this) + logger.info("Sending request %s %s", method, redact_uri(uri)) + + try: request_deferred = self.agent.request( method, uri, *args, **kwargs ) - - return self.clock.time_bound_deferred( - request_deferred, - time_out=60, + add_timeout_to_deferred( + request_deferred, 60, self.hs.get_reactor(), + cancelled_to_request_timed_out_error, ) + response = yield make_deferred_yieldable(request_deferred) - logger.info("Sending request %s %s", method, uri) - - try: - with logcontext.PreserveLoggingContext(): - response = yield send_request() - - incoming_responses_counter.inc(method, response.code) + incoming_responses_counter.labels(method, response.code).inc() logger.info( "Received response to %s %s: %s", - method, uri, response.code + method, redact_uri(uri), response.code ) defer.returnValue(response) except Exception as e: - incoming_responses_counter.inc(method, "ERR") + incoming_responses_counter.labels(method, "ERR").inc() logger.info( "Error sending request to %s %s: %s %s", - method, uri, type(e).__name__, e.message + method, redact_uri(uri), type(e).__name__, e.message ) - raise e + raise @defer.inlineCallbacks - def post_urlencoded_get_json(self, uri, args={}): + def post_urlencoded_get_json(self, uri, args={}, headers=None): + """ + Args: + uri (str): + args (dict[str, str|List[str]]): query params + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header + + Returns: + Deferred[object]: parsed json + """ + # TODO: Do we ever want to log message contents? logger.debug("post_urlencoded_get_json args: %s", args) query_bytes = urllib.urlencode(encode_urlencode_args(args), True) + actual_headers = { + b"Content-Type": [b"application/x-www-form-urlencoded"], + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "POST", uri.encode("ascii"), - headers=Headers({ - b"Content-Type": [b"application/x-www-form-urlencoded"], - b"User-Agent": [self.user_agent], - }), + headers=Headers(actual_headers), bodyProducer=FileBodyProducer(StringIO(query_bytes)) ) - body = yield preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) defer.returnValue(json.loads(body)) @defer.inlineCallbacks - def post_json_get_json(self, uri, post_json): + def post_json_get_json(self, uri, post_json, headers=None): + """ + + Args: + uri (str): + post_json (object): + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header + + Returns: + Deferred[object]: parsed json + """ json_str = encode_canonical_json(post_json) logger.debug("HTTP POST %s -> %s", json_str, uri) + actual_headers = { + b"Content-Type": [b"application/json"], + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "POST", uri.encode("ascii"), - headers=Headers({ - b"Content-Type": [b"application/json"], - b"User-Agent": [self.user_agent], - }), + headers=Headers(actual_headers), bodyProducer=FileBodyProducer(StringIO(json_str)) ) - body = yield preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: defer.returnValue(json.loads(body)) @@ -160,7 +198,7 @@ class SimpleHttpClient(object): defer.returnValue(json.loads(body)) @defer.inlineCallbacks - def get_json(self, uri, args={}): + def get_json(self, uri, args={}, headers=None): """ Gets some json from the given URI. Args: @@ -169,6 +207,8 @@ class SimpleHttpClient(object): None. **Note**: The value of each key is assumed to be an iterable and *not* a string. + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header Returns: Deferred: Succeeds when we get *any* 2xx HTTP response, with the HTTP body as JSON. @@ -177,13 +217,13 @@ class SimpleHttpClient(object): error message. """ try: - body = yield self.get_raw(uri, args) + body = yield self.get_raw(uri, args, headers=headers) defer.returnValue(json.loads(body)) except CodeMessageException as e: raise self._exceptionFromFailedRequest(e.code, e.msg) @defer.inlineCallbacks - def put_json(self, uri, json_body, args={}): + def put_json(self, uri, json_body, args={}, headers=None): """ Puts some json to the given URI. Args: @@ -193,6 +233,8 @@ class SimpleHttpClient(object): None. **Note**: The value of each key is assumed to be an iterable and *not* a string. + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header Returns: Deferred: Succeeds when we get *any* 2xx HTTP response, with the HTTP body as JSON. @@ -205,17 +247,21 @@ class SimpleHttpClient(object): json_str = encode_canonical_json(json_body) + actual_headers = { + b"Content-Type": [b"application/json"], + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "PUT", uri.encode("ascii"), - headers=Headers({ - b"User-Agent": [self.user_agent], - "Content-Type": ["application/json"] - }), + headers=Headers(actual_headers), bodyProducer=FileBodyProducer(StringIO(json_str)) ) - body = yield preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: defer.returnValue(json.loads(body)) @@ -226,7 +272,7 @@ class SimpleHttpClient(object): raise CodeMessageException(response.code, body) @defer.inlineCallbacks - def get_raw(self, uri, args={}): + def get_raw(self, uri, args={}, headers=None): """ Gets raw text from the given URI. Args: @@ -235,6 +281,8 @@ class SimpleHttpClient(object): None. **Note**: The value of each key is assumed to be an iterable and *not* a string. + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header Returns: Deferred: Succeeds when we get *any* 2xx HTTP response, with the HTTP body at text. @@ -246,15 +294,19 @@ class SimpleHttpClient(object): query_bytes = urllib.urlencode(args, True) uri = "%s?%s" % (uri, query_bytes) + actual_headers = { + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "GET", uri.encode("ascii"), - headers=Headers({ - b"User-Agent": [self.user_agent], - }) + headers=Headers(actual_headers), ) - body = yield preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: defer.returnValue(body) @@ -274,27 +326,33 @@ class SimpleHttpClient(object): # The two should be factored out. @defer.inlineCallbacks - def get_file(self, url, output_stream, max_size=None): + def get_file(self, url, output_stream, max_size=None, headers=None): """GETs a file from a given URL Args: url (str): The URL to GET output_stream (file): File to write the response body to. + headers (dict[str, List[str]]|None): If not None, a map from + header name to a list of values for that header Returns: A (int,dict,string,int) tuple of the file length, dict of the response headers, absolute URI of the response and HTTP response code. """ + actual_headers = { + b"User-Agent": [self.user_agent], + } + if headers: + actual_headers.update(headers) + response = yield self.request( "GET", url.encode("ascii"), - headers=Headers({ - b"User-Agent": [self.user_agent], - }) + headers=Headers(actual_headers), ) - headers = dict(response.headers.getAllRawHeaders()) + resp_headers = dict(response.headers.getAllRawHeaders()) - if 'Content-Length' in headers and headers['Content-Length'] > max_size: + if 'Content-Length' in resp_headers and resp_headers['Content-Length'] > max_size: logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) raise SynapseError( 502, @@ -315,10 +373,9 @@ class SimpleHttpClient(object): # straight back in again try: - length = yield preserve_context_over_fn( - _readBodyToFile, - response, output_stream, max_size - ) + length = yield make_deferred_yieldable(_readBodyToFile( + response, output_stream, max_size, + )) except Exception as e: logger.exception("Failed to download body") raise SynapseError( @@ -327,7 +384,9 @@ class SimpleHttpClient(object): Codes.UNKNOWN, ) - defer.returnValue((length, headers, response.request.absoluteURI, response.code)) + defer.returnValue( + (length, resp_headers, response.request.absoluteURI, response.code), + ) # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. @@ -395,7 +454,7 @@ class CaptchaServerHttpClient(SimpleHttpClient): ) try: - body = yield preserve_context_over_fn(readBody, response) + body = yield make_deferred_yieldable(readBody(response)) defer.returnValue(body) except PartialDownloadError as e: # twisted dislikes google's response, no content length. @@ -446,7 +505,7 @@ class SpiderHttpClient(SimpleHttpClient): reactor, SpiderEndpointFactory(hs) ) - ), [('gzip', GzipDecoder)] + ), [(b'gzip', GzipDecoder)] ) # We could look like Chrome: # self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko) diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index d8923c9abb..d65daa72bb 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -12,30 +12,97 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS -from twisted.internet import defer, reactor -from twisted.internet.error import ConnectError -from twisted.names import client, dns -from twisted.names.error import DNSNameError, DomainError - import collections import logging import random +import re import time +from twisted.internet import defer +from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS +from twisted.internet.error import ConnectError +from twisted.names import client, dns +from twisted.names.error import DNSNameError, DomainError logger = logging.getLogger(__name__) SERVER_CACHE = {} - +# our record of an individual server which can be tried to reach a destination. +# +# "host" is the hostname acquired from the SRV record. Except when there's +# no SRV record, in which case it is the original hostname. _Server = collections.namedtuple( "_Server", "priority weight host port expires" ) +def parse_server_name(server_name): + """Split a server name into host/port parts. + + Args: + server_name (str): server name to parse + + Returns: + Tuple[str, int|None]: host/port parts. + + Raises: + ValueError if the server name could not be parsed. + """ + try: + if server_name[-1] == ']': + # ipv6 literal, hopefully + return server_name, None + + domain_port = server_name.rsplit(":", 1) + domain = domain_port[0] + port = int(domain_port[1]) if domain_port[1:] else None + return domain, port + except Exception: + raise ValueError("Invalid server name '%s'" % server_name) + + +VALID_HOST_REGEX = re.compile( + "\\A[0-9a-zA-Z.-]+\\Z", +) + + +def parse_and_validate_server_name(server_name): + """Split a server name into host/port parts and do some basic validation. + + Args: + server_name (str): server name to parse + + Returns: + Tuple[str, int|None]: host/port parts. + + Raises: + ValueError if the server name could not be parsed. + """ + host, port = parse_server_name(server_name) + + # these tests don't need to be bulletproof as we'll find out soon enough + # if somebody is giving us invalid data. What we *do* need is to be sure + # that nobody is sneaking IP literals in that look like hostnames, etc. + + # look for ipv6 literals + if host[0] == '[': + if host[-1] != ']': + raise ValueError("Mismatched [...] in server name '%s'" % ( + server_name, + )) + return host, port + + # otherwise it should only be alphanumerics. + if not VALID_HOST_REGEX.match(host): + raise ValueError("Server name '%s' contains invalid characters" % ( + server_name, + )) + + return host, port + + def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, timeout=None): """Construct an endpoint for the given matrix destination. @@ -48,9 +115,7 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, timeout (int): connection timeout in seconds """ - domain_port = destination.split(":") - domain = domain_port[0] - port = int(domain_port[1]) if domain_port[1:] else None + domain, port = parse_server_name(destination) endpoint_kw_args = {} @@ -72,21 +137,22 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, reactor, "matrix", domain, protocol="tcp", default_port=default_port, endpoint=transport_endpoint, endpoint_kw_args=endpoint_kw_args - )) + ), reactor) else: return _WrappingEndpointFac(transport_endpoint( reactor, domain, port, **endpoint_kw_args - )) + ), reactor) class _WrappingEndpointFac(object): - def __init__(self, endpoint_fac): + def __init__(self, endpoint_fac, reactor): self.endpoint_fac = endpoint_fac + self.reactor = reactor @defer.inlineCallbacks def connect(self, protocolFactory): conn = yield self.endpoint_fac.connect(protocolFactory) - conn = _WrappedConnection(conn) + conn = _WrappedConnection(conn, self.reactor) defer.returnValue(conn) @@ -96,9 +162,10 @@ class _WrappedConnection(object): """ __slots__ = ["conn", "last_request"] - def __init__(self, conn): + def __init__(self, conn, reactor): object.__setattr__(self, "conn", conn) object.__setattr__(self, "last_request", time.time()) + self._reactor = reactor def __getattr__(self, name): return getattr(self.conn, name) @@ -113,10 +180,15 @@ class _WrappedConnection(object): if time.time() - self.last_request >= 2.5 * 60: self.abort() # Abort the underlying TLS connection. The abort() method calls - # loseConnection() on the underlying TLS connection which tries to + # loseConnection() on the TLS connection which tries to # shutdown the connection cleanly. We call abortConnection() - # since that will promptly close the underlying TCP connection. - self.transport.abortConnection() + # since that will promptly close the TLS connection. + # + # In Twisted >18.4; the TLS connection will be None if it has closed + # which will make abortConnection() throw. Check that the TLS connection + # is not None before trying to close it. + if self.transport.getHandle() is not None: + self.transport.abortConnection() def request(self, request): self.last_request = time.time() @@ -124,14 +196,14 @@ class _WrappedConnection(object): # Time this connection out if we haven't send a request in the last # N minutes # TODO: Cancel the previous callLater? - reactor.callLater(3 * 60, self._time_things_out_maybe) + self._reactor.callLater(3 * 60, self._time_things_out_maybe) d = self.conn.request(request) def update_request_time(res): self.last_request = time.time() # TODO: Cancel the previous callLater? - reactor.callLater(3 * 60, self._time_things_out_maybe) + self._reactor.callLater(3 * 60, self._time_things_out_maybe) return res d.addCallback(update_request_time) @@ -219,9 +291,10 @@ class SRVClientEndpoint(object): return self.default_server else: raise ConnectError( - "Not server available for %s" % self.service_name + "No server available for %s" % self.service_name ) + # look for all servers with the same priority min_priority = self.servers[0].priority weight_indexes = list( (index, server.weight + 1) @@ -231,11 +304,22 @@ class SRVClientEndpoint(object): total_weight = sum(weight for index, weight in weight_indexes) target_weight = random.randint(0, total_weight) - for index, weight in weight_indexes: target_weight -= weight if target_weight <= 0: server = self.servers[index] + # XXX: this looks totally dubious: + # + # (a) we never reuse a server until we have been through + # all of the servers at the same priority, so if the + # weights are A: 100, B:1, we always do ABABAB instead of + # AAAA...AAAB (approximately). + # + # (b) After using all the servers at the lowest priority, + # we move onto the next priority. We should only use the + # second priority if servers at the top priority are + # unreachable. + # del self.servers[index] self.used_servers.append(server) return server @@ -272,7 +356,7 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t if (len(answers) == 1 and answers[0].type == dns.SRV and answers[0].payload - and answers[0].payload.target == dns.Name('.')): + and answers[0].payload.target == dns.Name(b'.')): raise ConnectError("Service %s unavailable" % service_name) for answer in answers: @@ -280,26 +364,14 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t continue payload = answer.payload - host = str(payload.target) - srv_ttl = answer.ttl - - try: - answers, _, _ = yield dns_client.lookupAddress(host) - except DNSNameError: - continue - for answer in answers: - if answer.type == dns.A and answer.payload: - ip = answer.payload.dottedQuad() - host_ttl = min(srv_ttl, answer.ttl) - - servers.append(_Server( - host=ip, - port=int(payload.port), - priority=int(payload.priority), - weight=int(payload.weight), - expires=int(clock.time()) + host_ttl, - )) + servers.append(_Server( + host=str(payload.target), + port=int(payload.port), + priority=int(payload.priority), + weight=int(payload.weight), + expires=int(clock.time()) + answer.ttl, + )) servers.sort() cache[service_name] = list(servers) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 747a791f83..bf1aa29502 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,48 +13,46 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import synapse.util.retryutils -from twisted.internet import defer, reactor, protocol -from twisted.internet.error import DNSLookupError -from twisted.web.client import readBody, HTTPConnectionPool, Agent -from twisted.web.http_headers import Headers -from twisted.web._newclient import ResponseDone - -from synapse.http.endpoint import matrix_federation_endpoint -from synapse.util.async import sleep -from synapse.util import logcontext -import synapse.metrics - -from canonicaljson import encode_canonical_json - -from synapse.api.errors import ( - SynapseError, Codes, HttpResponseException, -) - -from signedjson.sign import sign_json - import cgi -import simplejson as json import logging import random import sys import urllib -import urlparse +from six import string_types +from six.moves.urllib import parse as urlparse -logger = logging.getLogger(__name__) -outbound_logger = logging.getLogger("synapse.http.outbound") +from canonicaljson import encode_canonical_json, json +from prometheus_client import Counter +from signedjson.sign import sign_json -metrics = synapse.metrics.get_metrics_for(__name__) +from twisted.internet import defer, protocol, reactor +from twisted.internet.error import DNSLookupError +from twisted.web._newclient import ResponseDone +from twisted.web.client import Agent, HTTPConnectionPool, readBody +from twisted.web.http_headers import Headers -outgoing_requests_counter = metrics.register_counter( - "requests", - labels=["method"], -) -incoming_responses_counter = metrics.register_counter( - "responses", - labels=["method", "code"], +import synapse.metrics +import synapse.util.retryutils +from synapse.api.errors import ( + Codes, + FederationDeniedError, + HttpResponseException, + SynapseError, ) +from synapse.http import cancelled_to_request_timed_out_error +from synapse.http.endpoint import matrix_federation_endpoint +from synapse.util import logcontext +from synapse.util.async import add_timeout_to_deferred +from synapse.util.logcontext import make_deferred_yieldable + +logger = logging.getLogger(__name__) +outbound_logger = logging.getLogger("synapse.http.outbound") + +outgoing_requests_counter = Counter("synapse_http_matrixfederationclient_requests", + "", ["method"]) +incoming_responses_counter = Counter("synapse_http_matrixfederationclient_responses", + "", ["method", "code"]) MAX_LONG_RETRIES = 10 @@ -123,11 +122,22 @@ class MatrixFederationHttpClient(object): Fails with ``HTTPRequestException``: if we get an HTTP response code >= 300. + Fails with ``NotRetryingDestination`` if we are not yet ready to retry this server. + + Fails with ``FederationDeniedError`` if this destination + is not on our federation whitelist + (May also fail with plenty of other Exceptions for things like DNS failures, connection failures, SSL failures.) """ + if ( + self.hs.config.federation_domain_whitelist and + destination not in self.hs.config.federation_domain_whitelist + ): + raise FederationDeniedError(destination) + limiter = yield synapse.util.retryutils.get_retry_limiter( destination, self.clock, @@ -173,21 +183,21 @@ class MatrixFederationHttpClient(object): producer = body_callback(method, http_url_bytes, headers_dict) try: - def send_request(): - request_deferred = self.agent.request( - method, - url_bytes, - Headers(headers_dict), - producer - ) - - return self.clock.time_bound_deferred( - request_deferred, - time_out=timeout / 1000. if timeout else 60, - ) - - with logcontext.PreserveLoggingContext(): - response = yield send_request() + request_deferred = self.agent.request( + method, + url_bytes, + Headers(headers_dict), + producer + ) + add_timeout_to_deferred( + request_deferred, + timeout / 1000. if timeout else 60, + self.hs.get_reactor(), + cancelled_to_request_timed_out_error, + ) + response = yield make_deferred_yieldable( + request_deferred, + ) log_result = "%d %s" % (response.code, response.phrase,) break @@ -204,18 +214,15 @@ class MatrixFederationHttpClient(object): raise logger.warn( - "{%s} Sending request failed to %s: %s %s: %s - %s", + "{%s} Sending request failed to %s: %s %s: %s", txn_id, destination, method, url_bytes, - type(e).__name__, _flatten_response_never_received(e), ) - log_result = "%s - %s" % ( - type(e).__name__, _flatten_response_never_received(e), - ) + log_result = _flatten_response_never_received(e) if retries_left and not timeout: if long_retries: @@ -227,7 +234,7 @@ class MatrixFederationHttpClient(object): delay = min(delay, 2) delay *= random.uniform(0.8, 1.4) - yield sleep(delay) + yield self.clock.sleep(delay) retries_left -= 1 else: raise @@ -253,14 +260,35 @@ class MatrixFederationHttpClient(object): defer.returnValue(response) def sign_request(self, destination, method, url_bytes, headers_dict, - content=None): + content=None, destination_is=None): + """ + Signs a request by adding an Authorization header to headers_dict + Args: + destination (bytes|None): The desination home server of the request. + May be None if the destination is an identity server, in which case + destination_is must be non-None. + method (bytes): The HTTP method of the request + url_bytes (bytes): The URI path of the request + headers_dict (dict): Dictionary of request headers to append to + content (bytes): The body of the request + destination_is (bytes): As 'destination', but if the destination is an + identity server + + Returns: + None + """ request = { "method": method, "uri": url_bytes, "origin": self.server_name, - "destination": destination, } + if destination is not None: + request["destination"] = destination + + if destination_is is not None: + request["destination_is"] = destination_is + if content is not None: request["content"] = content @@ -278,7 +306,8 @@ class MatrixFederationHttpClient(object): headers_dict[b"Authorization"] = auth_headers @defer.inlineCallbacks - def put_json(self, destination, path, data={}, json_data_callback=None, + def put_json(self, destination, path, args={}, data={}, + json_data_callback=None, long_retries=False, timeout=None, ignore_backoff=False, backoff_on_404=False): @@ -288,6 +317,7 @@ class MatrixFederationHttpClient(object): destination (str): The remote server to send the HTTP request to. path (str): The HTTP path. + args (dict): query params data (dict): A dict containing the data that will be used as the request body. This will be encoded as JSON. json_data_callback (callable): A callable returning the dict to @@ -311,6 +341,9 @@ class MatrixFederationHttpClient(object): Fails with ``NotRetryingDestination`` if we are not yet ready to retry this server. + + Fails with ``FederationDeniedError`` if this destination + is not on our federation whitelist """ if not json_data_callback: @@ -331,6 +364,7 @@ class MatrixFederationHttpClient(object): path, body_callback=body_callback, headers_dict={"Content-Type": ["application/json"]}, + query_bytes=encode_query_args(args), long_retries=long_retries, timeout=timeout, ignore_backoff=ignore_backoff, @@ -347,7 +381,7 @@ class MatrixFederationHttpClient(object): @defer.inlineCallbacks def post_json(self, destination, path, data={}, long_retries=False, - timeout=None, ignore_backoff=False): + timeout=None, ignore_backoff=False, args={}): """ Sends the specifed json data using POST Args: @@ -362,6 +396,7 @@ class MatrixFederationHttpClient(object): giving up. None indicates no timeout. ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. + args (dict): query params Returns: Deferred: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. @@ -371,6 +406,9 @@ class MatrixFederationHttpClient(object): Fails with ``NotRetryingDestination`` if we are not yet ready to retry this server. + + Fails with ``FederationDeniedError`` if this destination + is not on our federation whitelist """ def body_callback(method, url_bytes, headers_dict): @@ -383,6 +421,7 @@ class MatrixFederationHttpClient(object): destination, "POST", path, + query_bytes=encode_query_args(args), body_callback=body_callback, headers_dict={"Content-Type": ["application/json"]}, long_retries=long_retries, @@ -424,16 +463,12 @@ class MatrixFederationHttpClient(object): Fails with ``NotRetryingDestination`` if we are not yet ready to retry this server. + + Fails with ``FederationDeniedError`` if this destination + is not on our federation whitelist """ logger.debug("get_json args: %s", args) - encoded_args = {} - for k, vs in args.items(): - if isinstance(vs, basestring): - vs = [vs] - encoded_args[k] = [v.encode("UTF-8") for v in vs] - - query_bytes = urllib.urlencode(encoded_args, True) logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail) def body_callback(method, url_bytes, headers_dict): @@ -444,7 +479,7 @@ class MatrixFederationHttpClient(object): destination, "GET", path, - query_bytes=query_bytes, + query_bytes=encode_query_args(args), body_callback=body_callback, retry_on_dns_fail=retry_on_dns_fail, timeout=timeout, @@ -461,6 +496,55 @@ class MatrixFederationHttpClient(object): defer.returnValue(json.loads(body)) @defer.inlineCallbacks + def delete_json(self, destination, path, long_retries=False, + timeout=None, ignore_backoff=False, args={}): + """Send a DELETE request to the remote expecting some json response + + Args: + destination (str): The remote server to send the HTTP request + to. + path (str): The HTTP path. + long_retries (bool): A boolean that indicates whether we should + retry for a short or long time. + timeout(int): How long to try (in ms) the destination for before + giving up. None indicates no timeout. + ignore_backoff (bool): true to ignore the historical backoff data and + try the request anyway. + Returns: + Deferred: Succeeds when we get a 2xx HTTP response. The result + will be the decoded JSON body. + + Fails with ``HTTPRequestException`` if we get an HTTP response + code >= 300. + + Fails with ``NotRetryingDestination`` if we are not yet ready + to retry this server. + + Fails with ``FederationDeniedError`` if this destination + is not on our federation whitelist + """ + + response = yield self._request( + destination, + "DELETE", + path, + query_bytes=encode_query_args(args), + headers_dict={"Content-Type": ["application/json"]}, + long_retries=long_retries, + timeout=timeout, + ignore_backoff=ignore_backoff, + ) + + if 200 <= response.code < 300: + # We need to update the transactions table to say it was sent? + check_content_type_is_json(response.headers) + + with logcontext.PreserveLoggingContext(): + body = yield readBody(response) + + defer.returnValue(json.loads(body)) + + @defer.inlineCallbacks def get_file(self, destination, path, output_stream, args={}, retry_on_dns_fail=True, max_size=None, ignore_backoff=False): @@ -481,11 +565,14 @@ class MatrixFederationHttpClient(object): Fails with ``NotRetryingDestination`` if we are not yet ready to retry this server. + + Fails with ``FederationDeniedError`` if this destination + is not on our federation whitelist """ encoded_args = {} for k, vs in args.items(): - if isinstance(vs, basestring): + if isinstance(vs, string_types): vs = [vs] encoded_args[k] = [v.encode("UTF-8") for v in vs] @@ -513,7 +600,7 @@ class MatrixFederationHttpClient(object): length = yield _readBodyToFile( response, output_stream, max_size ) - except: + except Exception: logger.exception("Failed to download body") raise @@ -578,12 +665,14 @@ class _JsonProducer(object): def _flatten_response_never_received(e): if hasattr(e, "reasons"): - return ", ".join( + reasons = ", ".join( _flatten_response_never_received(f.value) for f in e.reasons ) + + return "%s:[%s]" % (type(e).__name__, reasons) else: - return "%s: %s" % (type(e).__name__, e.message,) + return repr(e) def check_content_type_is_json(headers): @@ -598,7 +687,7 @@ def check_content_type_is_json(headers): RuntimeError if the """ - c_type = headers.getRawHeaders("Content-Type") + c_type = headers.getRawHeaders(b"Content-Type") if c_type is None: raise RuntimeError( "No Content-Type header" @@ -610,3 +699,15 @@ def check_content_type_is_json(headers): raise RuntimeError( "Content-Type not application/json: was '%s'" % c_type ) + + +def encode_query_args(args): + encoded_args = {} + for k, vs in args.items(): + if isinstance(vs, string_types): + vs = [vs] + encoded_args[k] = [v.encode("UTF-8") for v in vs] + + query_bytes = urllib.urlencode(encoded_args, True) + + return query_bytes diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py new file mode 100644 index 0000000000..588e280571 --- /dev/null +++ b/synapse/http/request_metrics.py @@ -0,0 +1,231 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from prometheus_client.core import Counter, Histogram + +from synapse.metrics import LaterGauge +from synapse.util.logcontext import LoggingContext + +logger = logging.getLogger(__name__) + + +# total number of responses served, split by method/servlet/tag +response_count = Counter( + "synapse_http_server_response_count", "", ["method", "servlet", "tag"] +) + +requests_counter = Counter( + "synapse_http_server_requests_received", "", ["method", "servlet"] +) + +outgoing_responses_counter = Counter( + "synapse_http_server_responses", "", ["method", "code"] +) + +response_timer = Histogram( + "synapse_http_server_response_time_seconds", "sec", + ["method", "servlet", "tag", "code"], +) + +response_ru_utime = Counter( + "synapse_http_server_response_ru_utime_seconds", "sec", ["method", "servlet", "tag"] +) + +response_ru_stime = Counter( + "synapse_http_server_response_ru_stime_seconds", "sec", ["method", "servlet", "tag"] +) + +response_db_txn_count = Counter( + "synapse_http_server_response_db_txn_count", "", ["method", "servlet", "tag"] +) + +# seconds spent waiting for db txns, excluding scheduling time, when processing +# this request +response_db_txn_duration = Counter( + "synapse_http_server_response_db_txn_duration_seconds", + "", + ["method", "servlet", "tag"], +) + +# seconds spent waiting for a db connection, when processing this request +response_db_sched_duration = Counter( + "synapse_http_server_response_db_sched_duration_seconds", + "", + ["method", "servlet", "tag"], +) + +# size in bytes of the response written +response_size = Counter( + "synapse_http_server_response_size", "", ["method", "servlet", "tag"] +) + +# In flight metrics are incremented while the requests are in flight, rather +# than when the response was written. + +in_flight_requests_ru_utime = Counter( + "synapse_http_server_in_flight_requests_ru_utime_seconds", + "", + ["method", "servlet"], +) + +in_flight_requests_ru_stime = Counter( + "synapse_http_server_in_flight_requests_ru_stime_seconds", + "", + ["method", "servlet"], +) + +in_flight_requests_db_txn_count = Counter( + "synapse_http_server_in_flight_requests_db_txn_count", "", ["method", "servlet"] +) + +# seconds spent waiting for db txns, excluding scheduling time, when processing +# this request +in_flight_requests_db_txn_duration = Counter( + "synapse_http_server_in_flight_requests_db_txn_duration_seconds", + "", + ["method", "servlet"], +) + +# seconds spent waiting for a db connection, when processing this request +in_flight_requests_db_sched_duration = Counter( + "synapse_http_server_in_flight_requests_db_sched_duration_seconds", + "", + ["method", "servlet"], +) + +# The set of all in flight requests, set[RequestMetrics] +_in_flight_requests = set() + + +def _get_in_flight_counts(): + """Returns a count of all in flight requests by (method, server_name) + + Returns: + dict[tuple[str, str], int] + """ + # Cast to a list to prevent it changing while the Prometheus + # thread is collecting metrics + reqs = list(_in_flight_requests) + + for rm in reqs: + rm.update_metrics() + + # Map from (method, name) -> int, the number of in flight requests of that + # type + counts = {} + for rm in reqs: + key = (rm.method, rm.name,) + counts[key] = counts.get(key, 0) + 1 + + return counts + + +LaterGauge( + "synapse_http_server_in_flight_requests_count", + "", + ["method", "servlet"], + _get_in_flight_counts, +) + + +class RequestMetrics(object): + def start(self, time_sec, name, method): + self.start = time_sec + self.start_context = LoggingContext.current_context() + self.name = name + self.method = method + + # _request_stats records resource usage that we have already added + # to the "in flight" metrics. + self._request_stats = self.start_context.get_resource_usage() + + _in_flight_requests.add(self) + + def stop(self, time_sec, request): + _in_flight_requests.discard(self) + + context = LoggingContext.current_context() + + tag = "" + if context: + tag = context.tag + + if context != self.start_context: + logger.warn( + "Context have unexpectedly changed %r, %r", + context, self.start_context + ) + return + + response_code = str(request.code) + + outgoing_responses_counter.labels(request.method, response_code).inc() + + response_count.labels(request.method, self.name, tag).inc() + + response_timer.labels(request.method, self.name, tag, response_code).observe( + time_sec - self.start + ) + + resource_usage = context.get_resource_usage() + + response_ru_utime.labels(request.method, self.name, tag).inc( + resource_usage.ru_utime, + ) + response_ru_stime.labels(request.method, self.name, tag).inc( + resource_usage.ru_stime, + ) + response_db_txn_count.labels(request.method, self.name, tag).inc( + resource_usage.db_txn_count + ) + response_db_txn_duration.labels(request.method, self.name, tag).inc( + resource_usage.db_txn_duration_sec + ) + response_db_sched_duration.labels(request.method, self.name, tag).inc( + resource_usage.db_sched_duration_sec + ) + + response_size.labels(request.method, self.name, tag).inc(request.sentLength) + + # We always call this at the end to ensure that we update the metrics + # regardless of whether a call to /metrics while the request was in + # flight. + self.update_metrics() + + def update_metrics(self): + """Updates the in flight metrics with values from this request. + """ + new_stats = self.start_context.get_resource_usage() + + diff = new_stats - self._request_stats + self._request_stats = new_stats + + in_flight_requests_ru_utime.labels(self.method, self.name).inc(diff.ru_utime) + in_flight_requests_ru_stime.labels(self.method, self.name).inc(diff.ru_stime) + + in_flight_requests_db_txn_count.labels(self.method, self.name).inc( + diff.db_txn_count + ) + + in_flight_requests_db_txn_duration.labels(self.method, self.name).inc( + diff.db_txn_duration_sec + ) + + in_flight_requests_db_sched_duration.labels(self.method, self.name).inc( + diff.db_sched_duration_sec + ) diff --git a/synapse/http/server.py b/synapse/http/server.py index 7ef3d526b1..c70fdbdfd2 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,148 +13,205 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import cgi +import collections +import logging +import urllib +from six.moves import http_client -from synapse.api.errors import ( - cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError, Codes -) -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext -from synapse.util.caches import intern_dict -from synapse.util.metrics import Measure -import synapse.metrics -import synapse.events - -from canonicaljson import ( - encode_canonical_json, encode_pretty_printed_json -) +from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json from twisted.internet import defer -from twisted.web import server, resource +from twisted.python import failure +from twisted.web import resource, server from twisted.web.server import NOT_DONE_YET from twisted.web.util import redirectTo -import collections -import logging -import urllib -import ujson +import synapse.events +import synapse.metrics +from synapse.api.errors import ( + CodeMessageException, + Codes, + SynapseError, + UnrecognizedRequestError, + cs_exception, +) +from synapse.http.request_metrics import requests_counter +from synapse.util.caches import intern_dict +from synapse.util.logcontext import LoggingContext, PreserveLoggingContext +from synapse.util.metrics import Measure logger = logging.getLogger(__name__) -metrics = synapse.metrics.get_metrics_for(__name__) +HTML_ERROR_TEMPLATE = """<!DOCTYPE html> +<html lang=en> + <head> + <meta charset="utf-8"> + <title>Error {code}</title> + </head> + <body> + <p>{msg}</p> + </body> +</html> +""" -incoming_requests_counter = metrics.register_counter( - "requests", - labels=["method", "servlet", "tag"], -) -outgoing_responses_counter = metrics.register_counter( - "responses", - labels=["method", "code"], -) -response_timer = metrics.register_distribution( - "response_time", - labels=["method", "servlet", "tag"] -) +def wrap_json_request_handler(h): + """Wraps a request handler method with exception handling. -response_ru_utime = metrics.register_distribution( - "response_ru_utime", labels=["method", "servlet", "tag"] -) + Also adds logging as per wrap_request_handler_with_logging. -response_ru_stime = metrics.register_distribution( - "response_ru_stime", labels=["method", "servlet", "tag"] -) + The handler method must have a signature of "handle_foo(self, request)", + where "self" must have a "clock" attribute (and "request" must be a + SynapseRequest). -response_db_txn_count = metrics.register_distribution( - "response_db_txn_count", labels=["method", "servlet", "tag"] -) + The handler must return a deferred. If the deferred succeeds we assume that + a response has been sent. If the deferred fails with a SynapseError we use + it to send a JSON response with the appropriate HTTP reponse code. If the + deferred fails with any other type of error we send a 500 reponse. + """ -response_db_txn_duration = metrics.register_distribution( - "response_db_txn_duration", labels=["method", "servlet", "tag"] -) + @defer.inlineCallbacks + def wrapped_request_handler(self, request): + try: + yield h(self, request) + except CodeMessageException as e: + code = e.code + if isinstance(e, SynapseError): + logger.info( + "%s SynapseError: %s - %s", request, code, e.msg + ) + else: + logger.exception(e) + respond_with_json( + request, code, cs_exception(e), send_cors=True, + pretty_print=_request_user_agent_is_curl(request), + ) + except Exception: + # failure.Failure() fishes the original Failure out + # of our stack, and thus gives us a sensible stack + # trace. + f = failure.Failure() + logger.error( + "Failed handle request via %r: %r: %s", + h, + request, + f.getTraceback().rstrip(), + ) + respond_with_json( + request, + 500, + { + "error": "Internal server error", + "errcode": Codes.UNKNOWN, + }, + send_cors=True, + pretty_print=_request_user_agent_is_curl(request), + ) -_next_request_id = 0 + return wrap_request_handler_with_logging(wrapped_request_handler) -def request_handler(include_metrics=False): - """Decorator for ``wrap_request_handler``""" - return lambda request_handler: wrap_request_handler(request_handler, include_metrics) +def wrap_html_request_handler(h): + """Wraps a request handler method with exception handling. + Also adds logging as per wrap_request_handler_with_logging. -def wrap_request_handler(request_handler, include_metrics=False): - """Wraps a method that acts as a request handler with the necessary logging - and exception handling. + The handler method must have a signature of "handle_foo(self, request)", + where "self" must have a "clock" attribute (and "request" must be a + SynapseRequest). + """ + def wrapped_request_handler(self, request): + d = defer.maybeDeferred(h, self, request) + d.addErrback(_return_html_error, request) + return d - The method must have a signature of "handle_foo(self, request)". The - argument "self" must have "version_string" and "clock" attributes. The - argument "request" must be a twisted HTTP request. + return wrap_request_handler_with_logging(wrapped_request_handler) - The method must return a deferred. If the deferred succeeds we assume that - a response has been sent. If the deferred fails with a SynapseError we use - it to send a JSON response with the appropriate HTTP reponse code. If the - deferred fails with any other type of error we send a 500 reponse. - We insert a unique request-id into the logging context for this request and - log the response and duration for this request. +def _return_html_error(f, request): + """Sends an HTML error page corresponding to the given failure + + Args: + f (twisted.python.failure.Failure): + request (twisted.web.iweb.IRequest): """ + if f.check(CodeMessageException): + cme = f.value + code = cme.code + msg = cme.msg + + if isinstance(cme, SynapseError): + logger.info( + "%s SynapseError: %s - %s", request, code, msg + ) + else: + logger.error( + "Failed handle request %r: %s", + request, + f.getTraceback().rstrip(), + ) + else: + code = http_client.INTERNAL_SERVER_ERROR + msg = "Internal server error" + + logger.error( + "Failed handle request %r: %s", + request, + f.getTraceback().rstrip(), + ) + + body = HTML_ERROR_TEMPLATE.format( + code=code, msg=cgi.escape(msg), + ).encode("utf-8") + request.setResponseCode(code) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%i" % (len(body),)) + request.write(body) + finish_request(request) + +def wrap_request_handler_with_logging(h): + """Wraps a request handler to provide logging and metrics + + The handler method must have a signature of "handle_foo(self, request)", + where "self" must have a "clock" attribute (and "request" must be a + SynapseRequest). + + As well as calling `request.processing` (which will log the response and + duration for this request), the wrapped request handler will insert the + request id into the logging context. + """ @defer.inlineCallbacks def wrapped_request_handler(self, request): - global _next_request_id - request_id = "%s-%s" % (request.method, _next_request_id) - _next_request_id += 1 + """ + Args: + self: + request (synapse.http.site.SynapseRequest): + """ + request_id = request.get_request_id() with LoggingContext(request_id) as request_context: + request_context.request = request_id with Measure(self.clock, "wrapped_request_handler"): - request_metrics = RequestMetrics() - request_metrics.start(self.clock, name=self.__class__.__name__) - - request_context.request = request_id - with request.processing(): - try: - with PreserveLoggingContext(request_context): - if include_metrics: - yield request_handler(self, request, request_metrics) - else: - yield request_handler(self, request) - except CodeMessageException as e: - code = e.code - if isinstance(e, SynapseError): - logger.info( - "%s SynapseError: %s - %s", request, code, e.msg - ) - else: - logger.exception(e) - outgoing_responses_counter.inc(request.method, str(code)) - respond_with_json( - request, code, cs_exception(e), send_cors=True, - pretty_print=_request_user_agent_is_curl(request), - version_string=self.version_string, - ) - except: - logger.exception( - "Failed handle request %s.%s on %r: %r", - request_handler.__module__, - request_handler.__name__, - self, - request - ) - respond_with_json( - request, - 500, - { - "error": "Internal server error", - "errcode": Codes.UNKNOWN, - }, - send_cors=True - ) - finally: - try: - request_metrics.stop( - self.clock, request - ) - except Exception as e: - logger.warn("Failed to stop metrics: %r", e) + # we start the request metrics timer here with an initial stab + # at the servlet name. For most requests that name will be + # JsonResource (or a subclass), and JsonResource._async_render + # will update it once it picks a servlet. + servlet_name = self.__class__.__name__ + with request.processing(servlet_name): + with PreserveLoggingContext(request_context): + d = defer.maybeDeferred(h, self, request) + + # record the arrival of the request *after* + # dispatching to the handler, so that the handler + # can update the servlet name in the request + # metrics + requests_counter.labels(request.method, + request.request_metrics.name).inc() + yield d return wrapped_request_handler @@ -183,7 +241,7 @@ class JsonResource(HttpServer, resource.Resource): """ This implements the HttpServer interface and provides JSON support for Resources. - Register callbacks via register_path() + Register callbacks via register_paths() Callbacks can return a tuple of status code and a dict in which case the the dict will automatically be sent to the client as a JSON object. @@ -203,7 +261,6 @@ class JsonResource(HttpServer, resource.Resource): self.canonical_json = canonical_json self.clock = hs.get_clock() self.path_regexs = {} - self.version_string = hs.version_string self.hs = hs def register_paths(self, method, path_patterns, callback): @@ -219,122 +276,103 @@ class JsonResource(HttpServer, resource.Resource): self._async_render(request) return server.NOT_DONE_YET - # Disable metric reporting because _async_render does its own metrics. - # It does its own metric reporting because _async_render dispatches to - # a callback and it's the class name of that callback we want to report - # against rather than the JsonResource itself. - @request_handler(include_metrics=True) + @wrap_json_request_handler @defer.inlineCallbacks - def _async_render(self, request, request_metrics): + def _async_render(self, request): """ This gets called from render() every time someone sends us a request. This checks if anyone has registered a callback for that method and path. """ - if request.method == "OPTIONS": - self._send_response(request, 200, {}) - return + callback, group_dict = self._get_handler_for_request(request) - # Loop through all the registered callbacks to check if the method - # and path regex match - for path_entry in self.path_regexs.get(request.method, []): - m = path_entry.pattern.match(request.path) - if not m: - continue + servlet_instance = getattr(callback, "__self__", None) + if servlet_instance is not None: + servlet_classname = servlet_instance.__class__.__name__ + else: + servlet_classname = "%r" % callback + request.request_metrics.name = servlet_classname - # We found a match! Trigger callback and then return the - # returned response. We pass both the request and any - # matched groups from the regex to the callback. + # Now trigger the callback. If it returns a response, we send it + # here. If it throws an exception, that is handled by the wrapper + # installed by @request_handler. - callback = path_entry.callback + kwargs = intern_dict({ + name: urllib.unquote(value).decode("UTF-8") if value else value + for name, value in group_dict.items() + }) - kwargs = intern_dict({ - name: urllib.unquote(value).decode("UTF-8") if value else value - for name, value in m.groupdict().items() - }) + callback_return = yield callback(request, **kwargs) + if callback_return is not None: + code, response = callback_return + self._send_response(request, code, response) - callback_return = yield callback(request, **kwargs) - if callback_return is not None: - code, response = callback_return - self._send_response(request, code, response) + def _get_handler_for_request(self, request): + """Finds a callback method to handle the given request - servlet_instance = getattr(callback, "__self__", None) - if servlet_instance is not None: - servlet_classname = servlet_instance.__class__.__name__ - else: - servlet_classname = "%r" % callback + Args: + request (twisted.web.http.Request): - request_metrics.name = servlet_classname + Returns: + Tuple[Callable, dict[str, str]]: callback method, and the dict + mapping keys to path components as specified in the handler's + path match regexp. - return + The callback will normally be a method registered via + register_paths, so will return (possibly via Deferred) either + None, or a tuple of (http code, response body). + """ + if request.method == b"OPTIONS": + return _options_handler, {} + + # Loop through all the registered callbacks to check if the method + # and path regex match + for path_entry in self.path_regexs.get(request.method, []): + m = path_entry.pattern.match(request.path) + if m: + # We found a match! + return path_entry.callback, m.groupdict() # Huh. No one wanted to handle that? Fiiiiiine. Send 400. - raise UnrecognizedRequestError() + return _unrecognised_request_handler, {} def _send_response(self, request, code, response_json_object, response_code_message=None): - # could alternatively use request.notifyFinish() and flip a flag when - # the Deferred fires, but since the flag is RIGHT THERE it seems like - # a waste. - if request._disconnected: - logger.warn( - "Not sending response to request %s, already disconnected.", - request) - return - - outgoing_responses_counter.inc(request.method, str(code)) - # TODO: Only enable CORS for the requests that need it. respond_with_json( request, code, response_json_object, send_cors=True, response_code_message=response_code_message, pretty_print=_request_user_agent_is_curl(request), - version_string=self.version_string, canonical_json=self.canonical_json, ) -class RequestMetrics(object): - def start(self, clock, name): - self.start = clock.time_msec() - self.start_context = LoggingContext.current_context() - self.name = name +def _options_handler(request): + """Request handler for OPTIONS requests - def stop(self, clock, request): - context = LoggingContext.current_context() + This is a request handler suitable for return from + _get_handler_for_request. It returns a 200 and an empty body. - tag = "" - if context: - tag = context.tag + Args: + request (twisted.web.http.Request): - if context != self.start_context: - logger.warn( - "Context have unexpectedly changed %r, %r", - context, self.start_context - ) - return + Returns: + Tuple[int, dict]: http code, response body. + """ + return 200, {} - incoming_requests_counter.inc(request.method, self.name, tag) - response_timer.inc_by( - clock.time_msec() - self.start, request.method, - self.name, tag - ) +def _unrecognised_request_handler(request): + """Request handler for unrecognised requests - ru_utime, ru_stime = context.get_resource_usage() + This is a request handler suitable for return from + _get_handler_for_request. It actually just raises an + UnrecognizedRequestError. - response_ru_utime.inc_by( - ru_utime, request.method, self.name, tag - ) - response_ru_stime.inc_by( - ru_stime, request.method, self.name, tag - ) - response_db_txn_count.inc_by( - context.db_txn_count, request.method, self.name, tag - ) - response_db_txn_duration.inc_by( - context.db_txn_duration, request.method, self.name, tag - ) + Args: + request (twisted.web.http.Request): + """ + raise UnrecognizedRequestError() class RootRedirect(resource.Resource): @@ -355,26 +393,33 @@ class RootRedirect(resource.Resource): def respond_with_json(request, code, json_object, send_cors=False, response_code_message=None, pretty_print=False, - version_string="", canonical_json=True): + canonical_json=True): + # could alternatively use request.notifyFinish() and flip a flag when + # the Deferred fires, but since the flag is RIGHT THERE it seems like + # a waste. + if request._disconnected: + logger.warn( + "Not sending response to request %s, already disconnected.", + request) + return + if pretty_print: json_bytes = encode_pretty_printed_json(json_object) + "\n" else: if canonical_json or synapse.events.USE_FROZEN_DICTS: json_bytes = encode_canonical_json(json_object) else: - # ujson doesn't like frozen_dicts. - json_bytes = ujson.dumps(json_object, ensure_ascii=False) + json_bytes = json.dumps(json_object) return respond_with_json_bytes( request, code, json_bytes, send_cors=send_cors, response_code_message=response_code_message, - version_string=version_string ) def respond_with_json_bytes(request, code, json_bytes, send_cors=False, - version_string="", response_code_message=None): + response_code_message=None): """Sends encoded JSON in response to the given request. Args: @@ -388,8 +433,8 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False, request.setResponseCode(code, message=response_code_message) request.setHeader(b"Content-Type", b"application/json") - request.setHeader(b"Server", version_string) request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),)) + request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate") if send_cors: set_cors_headers(request) @@ -437,9 +482,9 @@ def finish_request(request): def _request_user_agent_is_curl(request): user_agents = request.requestHeaders.getRawHeaders( - "User-Agent", default=[] + b"User-Agent", default=[] ) for user_agent in user_agents: - if "curl" in user_agent: + if b"curl" in user_agent: return True return False diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 9a4c36ad5d..882816dc8f 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -15,10 +15,11 @@ """ This module contains base REST classes for constructing REST servlets. """ -from synapse.api.errors import SynapseError, Codes - import logging -import simplejson + +from canonicaljson import json + +from synapse.api.errors import Codes, SynapseError logger = logging.getLogger(__name__) @@ -48,7 +49,7 @@ def parse_integer_from_args(args, name, default=None, required=False): if name in args: try: return int(args[name][0]) - except: + except Exception: message = "Query parameter %r must be an integer" % (name,) raise SynapseError(400, message) else: @@ -88,7 +89,7 @@ def parse_boolean_from_args(args, name, default=None, required=False): "true": True, "false": False, }[args[name][0]] - except: + except Exception: message = ( "Boolean query parameter %r must be one of" " ['true', 'false']" @@ -148,11 +149,13 @@ def parse_string_from_args(args, name, default=None, required=False, return default -def parse_json_value_from_request(request): +def parse_json_value_from_request(request, allow_empty_body=False): """Parse a JSON value from the body of a twisted HTTP request. Args: request: the twisted HTTP request. + allow_empty_body (bool): if True, an empty body will be accepted and + turned into None Returns: The JSON value. @@ -162,28 +165,39 @@ def parse_json_value_from_request(request): """ try: content_bytes = request.content.read() - except: + except Exception: raise SynapseError(400, "Error reading JSON content.") + if not content_bytes and allow_empty_body: + return None + try: - content = simplejson.loads(content_bytes) - except simplejson.JSONDecodeError: + content = json.loads(content_bytes) + except Exception as e: + logger.warn("Unable to parse JSON: %s", e) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) return content -def parse_json_object_from_request(request): +def parse_json_object_from_request(request, allow_empty_body=False): """Parse a JSON object from the body of a twisted HTTP request. Args: request: the twisted HTTP request. + allow_empty_body (bool): if True, an empty body will be accepted and + turned into an empty dict. Raises: SynapseError if the request body couldn't be decoded as JSON or if it wasn't a JSON object. """ - content = parse_json_value_from_request(request) + content = parse_json_value_from_request( + request, allow_empty_body=allow_empty_body, + ) + + if allow_empty_body and content is None: + return {} if type(content) != dict: message = "Content must be a JSON object." @@ -192,7 +206,7 @@ def parse_json_object_from_request(request): return content -def assert_params_in_request(body, required): +def assert_params_in_dict(body, required): absent = [] for k in required: if k not in body: diff --git a/synapse/http/site.py b/synapse/http/site.py index 4b09d7ee66..5fd30a4c2c 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -12,27 +12,50 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.logcontext import LoggingContext -from twisted.web.server import Site, Request - import contextlib import logging -import re import time -ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$') +from twisted.web.server import Request, Site + +from synapse.http import redact_uri +from synapse.http.request_metrics import RequestMetrics +from synapse.util.logcontext import ContextResourceUsage, LoggingContext + +logger = logging.getLogger(__name__) + +_next_request_seq = 0 class SynapseRequest(Request): - def __init__(self, site, *args, **kw): - Request.__init__(self, *args, **kw) + """Class which encapsulates an HTTP request to synapse. + + All of the requests processed in synapse are of this type. + + It extends twisted's twisted.web.server.Request, and adds: + * Unique request ID + * Redaction of access_token query-params in __repr__ + * Logging at start and end + * Metrics to record CPU, wallclock and DB time by endpoint. + + It provides a method `processing` which should be called by the Resource + which is handling the request, and returns a context manager. + + """ + def __init__(self, site, channel, *args, **kw): + Request.__init__(self, channel, *args, **kw) self.site = site + self._channel = channel self.authenticated_entity = None self.start_time = 0 + global _next_request_seq + self.request_seq = _next_request_seq + _next_request_seq += 1 + def __repr__(self): # We overwrite this so that we don't log ``access_token`` - return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % ( + return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % ( self.__class__.__name__, id(self), self.method, @@ -41,16 +64,27 @@ class SynapseRequest(Request): self.site.site_tag, ) + def get_request_id(self): + return "%s-%i" % (self.method, self.request_seq) + def get_redacted_uri(self): - return ACCESS_TOKEN_RE.sub( - r'\1<redacted>\3', - self.uri - ) + return redact_uri(self.uri) def get_user_agent(self): - return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1] + return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1] + + def render(self, resrc): + # override the Server header which is set by twisted + self.setHeader("Server", self.site.server_version_string) + return Request.render(self, resrc) + + def _started_processing(self, servlet_name): + self.start_time = time.time() + self.request_metrics = RequestMetrics() + self.request_metrics.start( + self.start_time, name=servlet_name, method=self.method, + ) - def started_processing(self): self.site.access_logger.info( "%s - %s - Received request: %s %s", self.getClientIP(), @@ -58,44 +92,85 @@ class SynapseRequest(Request): self.method, self.get_redacted_uri() ) - self.start_time = int(time.time() * 1000) - - def finished_processing(self): + def _finished_processing(self): try: context = LoggingContext.current_context() - ru_utime, ru_stime = context.get_resource_usage() - db_txn_count = context.db_txn_count - db_txn_duration = context.db_txn_duration - except: - ru_utime, ru_stime = (0, 0) - db_txn_count, db_txn_duration = (0, 0) + usage = context.get_resource_usage() + except Exception: + usage = ContextResourceUsage() + + end_time = time.time() + + # need to decode as it could be raw utf-8 bytes + # from a IDN servname in an auth header + authenticated_entity = self.authenticated_entity + if authenticated_entity is not None: + authenticated_entity = authenticated_entity.decode("utf-8", "replace") + + # ...or could be raw utf-8 bytes in the User-Agent header. + # N.B. if you don't do this, the logger explodes cryptically + # with maximum recursion trying to log errors about + # the charset problem. + # c.f. https://github.com/matrix-org/synapse/issues/3471 + user_agent = self.get_user_agent() + if user_agent is not None: + user_agent = user_agent.decode("utf-8", "replace") self.site.access_logger.info( "%s - %s - {%s}" - " Processed request: %dms (%dms, %dms) (%dms/%d)" - " %sB %s \"%s %s %s\" \"%s\"", + " Processed request: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)" + " %sB %s \"%s %s %s\" \"%s\" [%d dbevts]", self.getClientIP(), self.site.site_tag, - self.authenticated_entity, - int(time.time() * 1000) - self.start_time, - int(ru_utime * 1000), - int(ru_stime * 1000), - int(db_txn_duration * 1000), - int(db_txn_count), + authenticated_entity, + end_time - self.start_time, + usage.ru_utime, + usage.ru_stime, + usage.db_sched_duration_sec, + usage.db_txn_duration_sec, + int(usage.db_txn_count), self.sentLength, self.code, self.method, self.get_redacted_uri(), self.clientproto, - self.get_user_agent(), + user_agent, + usage.evt_db_fetch_count, ) + try: + self.request_metrics.stop(end_time, self) + except Exception as e: + logger.warn("Failed to stop metrics: %r", e) + @contextlib.contextmanager - def processing(self): - self.started_processing() + def processing(self, servlet_name): + """Record the fact that we are processing this request. + + Returns a context manager; the correct way to use this is: + + @defer.inlineCallbacks + def handle_request(request): + with request.processing("FooServlet"): + yield really_handle_the_request() + + This will log the request's arrival. Once the context manager is + closed, the completion of the request will be logged, and the various + metrics will be updated. + + Args: + servlet_name (str): the name of the servlet which will be + processing this request. This is used in the metrics. + + It is possible to update this afterwards by updating + self.request_metrics.servlet_name. + """ + # TODO: we should probably just move this into render() and finish(), + # to save having to call a separate method. + self._started_processing(servlet_name) yield - self.finished_processing() + self._finished_processing() class XForwardedForRequest(SynapseRequest): @@ -133,7 +208,8 @@ class SynapseSite(Site): Subclass of a twisted http Site that does access logging with python's standard logging """ - def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs): + def __init__(self, logger_name, site_tag, config, resource, + server_version_string, *args, **kwargs): Site.__init__(self, resource, *args, **kwargs) self.site_tag = site_tag @@ -141,6 +217,7 @@ class SynapseSite(Site): proxied = config.get("x_forwarded", False) self.requestFactory = SynapseRequestFactory(self, proxied) self.access_logger = logging.getLogger(logger_name) + self.server_version_string = server_version_string def log(self, request): pass diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 2265e6e8d6..a9158fc066 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -13,118 +13,198 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging import functools -import time import gc +import logging +import os +import platform +import time -from twisted.internet import reactor - -from .metric import ( - CounterMetric, CallbackMetric, DistributionMetric, CacheMetric, - MemoryUsageMetric, -) -from .process_collector import register_process_collector +import attr +from prometheus_client import Counter, Gauge, Histogram +from prometheus_client.core import REGISTRY, GaugeMetricFamily +from twisted.internet import reactor logger = logging.getLogger(__name__) - +running_on_pypy = platform.python_implementation() == "PyPy" all_metrics = [] all_collectors = [] +all_gauges = {} +HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat") -class Metrics(object): - """ A single Metrics object gives a (mutable) slice view of the all_metrics - dict, allowing callers to easily register new metrics that are namespaced - nicely.""" - def __init__(self, name): - self.name_prefix = name +class RegistryProxy(object): - def make_subspace(self, name): - return Metrics("%s_%s" % (self.name_prefix, name)) + @staticmethod + def collect(): + for metric in REGISTRY.collect(): + if not metric.name.startswith("__"): + yield metric - def register_collector(self, func): - all_collectors.append(func) - def _register(self, metric_class, name, *args, **kwargs): - full_name = "%s_%s" % (self.name_prefix, name) +@attr.s(hash=True) +class LaterGauge(object): - metric = metric_class(full_name, *args, **kwargs) + name = attr.ib() + desc = attr.ib() + labels = attr.ib(hash=False) + caller = attr.ib() - all_metrics.append(metric) - return metric + def collect(self): - def register_counter(self, *args, **kwargs): - return self._register(CounterMetric, *args, **kwargs) + g = GaugeMetricFamily(self.name, self.desc, labels=self.labels) - def register_callback(self, *args, **kwargs): - return self._register(CallbackMetric, *args, **kwargs) + try: + calls = self.caller() + except Exception: + logger.exception( + "Exception running callback for LaterGauge(%s)", + self.name, + ) + yield g + return - def register_distribution(self, *args, **kwargs): - return self._register(DistributionMetric, *args, **kwargs) + if isinstance(calls, dict): + for k, v in calls.items(): + g.add_metric(k, v) + else: + g.add_metric([], calls) - def register_cache(self, *args, **kwargs): - return self._register(CacheMetric, *args, **kwargs) + yield g + def __attrs_post_init__(self): + self._register() -def register_memory_metrics(hs): - try: - import psutil - process = psutil.Process() - process.memory_info().rss - except (ImportError, AttributeError): - logger.warn( - "psutil is not installed or incorrect version." - " Disabling memory metrics." - ) - return - metric = MemoryUsageMetric(hs, psutil) - all_metrics.append(metric) + def _register(self): + if self.name in all_gauges.keys(): + logger.warning("%s already registered, reregistering" % (self.name,)) + REGISTRY.unregister(all_gauges.pop(self.name)) + REGISTRY.register(self) + all_gauges[self.name] = self -def get_metrics_for(pkg_name): - """ Returns a Metrics instance for conveniently creating metrics - namespaced with the given name prefix. """ - # Convert a "package.name" to "package_name" because Prometheus doesn't - # let us use . in metric names - return Metrics(pkg_name.replace(".", "_")) +# +# Detailed CPU metrics +# +class CPUMetrics(object): -def render_all(): - strs = [] + def __init__(self): + ticks_per_sec = 100 + try: + # Try and get the system config + ticks_per_sec = os.sysconf('SC_CLK_TCK') + except (ValueError, TypeError, AttributeError): + pass - for collector in all_collectors: - collector() + self.ticks_per_sec = ticks_per_sec - for metric in all_metrics: - try: - strs += metric.render() - except Exception: - strs += ["# FAILED to render"] - logger.exception("Failed to render metric") + def collect(self): + if not HAVE_PROC_SELF_STAT: + return - strs.append("") # to generate a final CRLF + with open("/proc/self/stat") as s: + line = s.read() + raw_stats = line.split(") ", 1)[1].split(" ") - return "\n".join(strs) + user = GaugeMetricFamily("process_cpu_user_seconds_total", "") + user.add_metric([], float(raw_stats[11]) / self.ticks_per_sec) + yield user + sys = GaugeMetricFamily("process_cpu_system_seconds_total", "") + sys.add_metric([], float(raw_stats[12]) / self.ticks_per_sec) + yield sys -register_process_collector(get_metrics_for("process")) +REGISTRY.register(CPUMetrics()) -python_metrics = get_metrics_for("python") +# +# Python GC metrics +# -gc_time = python_metrics.register_distribution("gc_time", labels=["gen"]) -gc_unreachable = python_metrics.register_counter("gc_unreachable_total", labels=["gen"]) -python_metrics.register_callback( - "gc_counts", lambda: {(i,): v for i, v in enumerate(gc.get_count())}, labels=["gen"] +gc_unreachable = Gauge("python_gc_unreachable_total", "Unreachable GC objects", ["gen"]) +gc_time = Histogram( + "python_gc_time", + "Time taken to GC (sec)", + ["gen"], + buckets=[0.0025, 0.005, 0.01, 0.025, 0.05, 0.10, 0.25, 0.50, 1.00, 2.50, + 5.00, 7.50, 15.00, 30.00, 45.00, 60.00], ) -reactor_metrics = get_metrics_for("python.twisted.reactor") -tick_time = reactor_metrics.register_distribution("tick_time") -pending_calls_metric = reactor_metrics.register_distribution("pending_calls") + +class GCCounts(object): + + def collect(self): + cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"]) + for n, m in enumerate(gc.get_count()): + cm.add_metric([str(n)], m) + + yield cm + + +if not running_on_pypy: + REGISTRY.register(GCCounts()) + +# +# Twisted reactor metrics +# + +tick_time = Histogram( + "python_twisted_reactor_tick_time", + "Tick time of the Twisted reactor (sec)", + buckets=[0.001, 0.002, 0.005, 0.01, 0.025, 0.05, 0.1, 0.2, 0.5, 1, 2, 5], +) +pending_calls_metric = Histogram( + "python_twisted_reactor_pending_calls", + "Pending calls", + buckets=[1, 2, 5, 10, 25, 50, 100, 250, 500, 1000], +) + +# +# Federation Metrics +# + +sent_edus_counter = Counter("synapse_federation_client_sent_edus", "") + +sent_transactions_counter = Counter("synapse_federation_client_sent_transactions", "") + +events_processed_counter = Counter("synapse_federation_client_events_processed", "") + +# Used to track where various components have processed in the event stream, +# e.g. federation sending, appservice sending, etc. +event_processing_positions = Gauge("synapse_event_processing_positions", "", ["name"]) + +# Used to track the current max events stream position +event_persisted_position = Gauge("synapse_event_persisted_position", "") + +# Used to track the received_ts of the last event processed by various +# components +event_processing_last_ts = Gauge("synapse_event_processing_last_ts", "", ["name"]) + +# Used to track the lag processing events. This is the time difference +# between the last processed event's received_ts and the time it was +# finished being processed. +event_processing_lag = Gauge("synapse_event_processing_lag", "", ["name"]) + +last_ticked = time.time() + + +class ReactorLastSeenMetric(object): + + def collect(self): + cm = GaugeMetricFamily( + "python_twisted_reactor_last_seen", + "Seconds since the Twisted reactor was last seen", + ) + cm.add_metric([], time.time() - last_ticked) + yield cm + + +REGISTRY.register(ReactorLastSeenMetric()) def runUntilCurrentTimer(func): @@ -146,12 +226,25 @@ def runUntilCurrentTimer(func): num_pending += 1 num_pending += len(reactor.threadCallQueue) - - start = time.time() * 1000 + start = time.time() ret = func(*args, **kwargs) - end = time.time() * 1000 - tick_time.inc_by(end - start) - pending_calls_metric.inc_by(num_pending) + end = time.time() + + # record the amount of wallclock time spent running pending calls. + # This is a proxy for the actual amount of time between reactor polls, + # since about 25% of time is actually spent running things triggered by + # I/O events, but that is harder to capture without rewriting half the + # reactor. + tick_time.observe(end - start) + pending_calls_metric.observe(num_pending) + + # Update the time we last ticked, for the metric to test whether + # Synapse's reactor has frozen + global last_ticked + last_ticked = end + + if running_on_pypy: + return ret # Check if we need to do a manual GC (since its been disabled), and do # one if necessary. @@ -161,12 +254,12 @@ def runUntilCurrentTimer(func): if threshold[i] < counts[i]: logger.info("Collecting gc %d", i) - start = time.time() * 1000 + start = time.time() unreachable = gc.collect(i) - end = time.time() * 1000 + end = time.time() - gc_time.inc_by(end - start, i) - gc_unreachable.inc_by(unreachable, i) + gc_time.labels(i).observe(end - start) + gc_unreachable.labels(i).set(unreachable) return ret @@ -185,6 +278,7 @@ try: # We manually run the GC each reactor tick so that we can get some metrics # about time spent doing GC, - gc.disable() + if not running_on_pypy: + gc.disable() except AttributeError: pass diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py new file mode 100644 index 0000000000..9d820e44a6 --- /dev/null +++ b/synapse/metrics/background_process_metrics.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import six + +from prometheus_client.core import REGISTRY, Counter, GaugeMetricFamily + +from twisted.internet import defer + +from synapse.util.logcontext import LoggingContext, PreserveLoggingContext + +_background_process_start_count = Counter( + "synapse_background_process_start_count", + "Number of background processes started", + ["name"], +) + +# we set registry=None in all of these to stop them getting registered with +# the default registry. Instead we collect them all via the CustomCollector, +# which ensures that we can update them before they are collected. +# +_background_process_ru_utime = Counter( + "synapse_background_process_ru_utime_seconds", + "User CPU time used by background processes, in seconds", + ["name"], + registry=None, +) + +_background_process_ru_stime = Counter( + "synapse_background_process_ru_stime_seconds", + "System CPU time used by background processes, in seconds", + ["name"], + registry=None, +) + +_background_process_db_txn_count = Counter( + "synapse_background_process_db_txn_count", + "Number of database transactions done by background processes", + ["name"], + registry=None, +) + +_background_process_db_txn_duration = Counter( + "synapse_background_process_db_txn_duration_seconds", + ("Seconds spent by background processes waiting for database " + "transactions, excluding scheduling time"), + ["name"], + registry=None, +) + +_background_process_db_sched_duration = Counter( + "synapse_background_process_db_sched_duration_seconds", + "Seconds spent by background processes waiting for database connections", + ["name"], + registry=None, +) + +# map from description to a counter, so that we can name our logcontexts +# incrementally. (It actually duplicates _background_process_start_count, but +# it's much simpler to do so than to try to combine them.) +_background_process_counts = dict() # type: dict[str, int] + +# map from description to the currently running background processes. +# +# it's kept as a dict of sets rather than a big set so that we can keep track +# of process descriptions that no longer have any active processes. +_background_processes = dict() # type: dict[str, set[_BackgroundProcess]] + + +class _Collector(object): + """A custom metrics collector for the background process metrics. + + Ensures that all of the metrics are up-to-date with any in-flight processes + before they are returned. + """ + def collect(self): + background_process_in_flight_count = GaugeMetricFamily( + "synapse_background_process_in_flight_count", + "Number of background processes in flight", + labels=["name"], + ) + + for desc, processes in six.iteritems(_background_processes): + background_process_in_flight_count.add_metric( + (desc,), len(processes), + ) + for process in processes: + process.update_metrics() + + yield background_process_in_flight_count + + # now we need to run collect() over each of the static Counters, and + # yield each metric they return. + for m in ( + _background_process_ru_utime, + _background_process_ru_stime, + _background_process_db_txn_count, + _background_process_db_txn_duration, + _background_process_db_sched_duration, + ): + for r in m.collect(): + yield r + + +REGISTRY.register(_Collector()) + + +class _BackgroundProcess(object): + def __init__(self, desc, ctx): + self.desc = desc + self._context = ctx + self._reported_stats = None + + def update_metrics(self): + """Updates the metrics with values from this process.""" + new_stats = self._context.get_resource_usage() + if self._reported_stats is None: + diff = new_stats + else: + diff = new_stats - self._reported_stats + self._reported_stats = new_stats + + _background_process_ru_utime.labels(self.desc).inc(diff.ru_utime) + _background_process_ru_stime.labels(self.desc).inc(diff.ru_stime) + _background_process_db_txn_count.labels(self.desc).inc( + diff.db_txn_count, + ) + _background_process_db_txn_duration.labels(self.desc).inc( + diff.db_txn_duration_sec, + ) + _background_process_db_sched_duration.labels(self.desc).inc( + diff.db_sched_duration_sec, + ) + + +def run_as_background_process(desc, func, *args, **kwargs): + """Run the given function in its own logcontext, with resource metrics + + This should be used to wrap processes which are fired off to run in the + background, instead of being associated with a particular request. + + Args: + desc (str): a description for this background process type + func: a function, which may return a Deferred + args: positional args for func + kwargs: keyword args for func + + Returns: None + """ + @defer.inlineCallbacks + def run(): + count = _background_process_counts.get(desc, 0) + _background_process_counts[desc] = count + 1 + _background_process_start_count.labels(desc).inc() + + with LoggingContext(desc) as context: + context.request = "%s-%i" % (desc, count) + proc = _BackgroundProcess(desc, context) + _background_processes.setdefault(desc, set()).add(proc) + try: + yield func(*args, **kwargs) + finally: + proc.update_metrics() + _background_processes[desc].remove(proc) + + with PreserveLoggingContext(): + run() diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py deleted file mode 100644 index e87b2b80a7..0000000000 --- a/synapse/metrics/metric.py +++ /dev/null @@ -1,195 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from itertools import chain - - -# TODO(paul): I can't believe Python doesn't have one of these -def map_concat(func, items): - # flatten a list-of-lists - return list(chain.from_iterable(map(func, items))) - - -class BaseMetric(object): - - def __init__(self, name, labels=[]): - self.name = name - self.labels = labels # OK not to clone as we never write it - - def dimension(self): - return len(self.labels) - - def is_scalar(self): - return not len(self.labels) - - def _render_labelvalue(self, value): - # TODO: some kind of value escape - return '"%s"' % (value) - - def _render_key(self, values): - if self.is_scalar(): - return "" - return "{%s}" % ( - ",".join(["%s=%s" % (k, self._render_labelvalue(v)) - for k, v in zip(self.labels, values)]) - ) - - -class CounterMetric(BaseMetric): - """The simplest kind of metric; one that stores a monotonically-increasing - integer that counts events.""" - - def __init__(self, *args, **kwargs): - super(CounterMetric, self).__init__(*args, **kwargs) - - self.counts = {} - - # Scalar metrics are never empty - if self.is_scalar(): - self.counts[()] = 0 - - def inc_by(self, incr, *values): - if len(values) != self.dimension(): - raise ValueError( - "Expected as many values to inc() as labels (%d)" % (self.dimension()) - ) - - # TODO: should assert that the tag values are all strings - - if values not in self.counts: - self.counts[values] = incr - else: - self.counts[values] += incr - - def inc(self, *values): - self.inc_by(1, *values) - - def render_item(self, k): - return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])] - - def render(self): - return map_concat(self.render_item, sorted(self.counts.keys())) - - -class CallbackMetric(BaseMetric): - """A metric that returns the numeric value returned by a callback whenever - it is rendered. Typically this is used to implement gauges that yield the - size or other state of some in-memory object by actively querying it.""" - - def __init__(self, name, callback, labels=[]): - super(CallbackMetric, self).__init__(name, labels=labels) - - self.callback = callback - - def render(self): - value = self.callback() - - if self.is_scalar(): - return ["%s %.12g" % (self.name, value)] - - return ["%s%s %.12g" % (self.name, self._render_key(k), value[k]) - for k in sorted(value.keys())] - - -class DistributionMetric(object): - """A combination of an event counter and an accumulator, which counts - both the number of events and accumulates the total value. Typically this - could be used to keep track of method-running times, or other distributions - of values that occur in discrete occurances. - - TODO(paul): Try to export some heatmap-style stats? - """ - - def __init__(self, name, *args, **kwargs): - self.counts = CounterMetric(name + ":count", **kwargs) - self.totals = CounterMetric(name + ":total", **kwargs) - - def inc_by(self, inc, *values): - self.counts.inc(*values) - self.totals.inc_by(inc, *values) - - def render(self): - return self.counts.render() + self.totals.render() - - -class CacheMetric(object): - __slots__ = ("name", "cache_name", "hits", "misses", "size_callback") - - def __init__(self, name, size_callback, cache_name): - self.name = name - self.cache_name = cache_name - - self.hits = 0 - self.misses = 0 - - self.size_callback = size_callback - - def inc_hits(self): - self.hits += 1 - - def inc_misses(self): - self.misses += 1 - - def render(self): - size = self.size_callback() - hits = self.hits - total = self.misses + self.hits - - return [ - """%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits), - """%s:total{name="%s"} %d""" % (self.name, self.cache_name, total), - """%s:size{name="%s"} %d""" % (self.name, self.cache_name, size), - ] - - -class MemoryUsageMetric(object): - """Keeps track of the current memory usage, using psutil. - - The class will keep the current min/max/sum/counts of rss over the last - WINDOW_SIZE_SEC, by polling UPDATE_HZ times per second - """ - - UPDATE_HZ = 2 # number of times to get memory per second - WINDOW_SIZE_SEC = 30 # the size of the window in seconds - - def __init__(self, hs, psutil): - clock = hs.get_clock() - self.memory_snapshots = [] - - self.process = psutil.Process() - - clock.looping_call(self._update_curr_values, 1000 / self.UPDATE_HZ) - - def _update_curr_values(self): - max_size = self.UPDATE_HZ * self.WINDOW_SIZE_SEC - self.memory_snapshots.append(self.process.memory_info().rss) - self.memory_snapshots[:] = self.memory_snapshots[-max_size:] - - def render(self): - if not self.memory_snapshots: - return [] - - max_rss = max(self.memory_snapshots) - min_rss = min(self.memory_snapshots) - sum_rss = sum(self.memory_snapshots) - len_rss = len(self.memory_snapshots) - - return [ - "process_psutil_rss:max %d" % max_rss, - "process_psutil_rss:min %d" % min_rss, - "process_psutil_rss:total %d" % sum_rss, - "process_psutil_rss:count %d" % len_rss, - ] diff --git a/synapse/metrics/process_collector.py b/synapse/metrics/process_collector.py deleted file mode 100644 index 6fec3de399..0000000000 --- a/synapse/metrics/process_collector.py +++ /dev/null @@ -1,122 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - - -TICKS_PER_SEC = 100 -BYTES_PER_PAGE = 4096 - -HAVE_PROC_STAT = os.path.exists("/proc/stat") -HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat") -HAVE_PROC_SELF_LIMITS = os.path.exists("/proc/self/limits") -HAVE_PROC_SELF_FD = os.path.exists("/proc/self/fd") - -# Field indexes from /proc/self/stat, taken from the proc(5) manpage -STAT_FIELDS = { - "utime": 14, - "stime": 15, - "starttime": 22, - "vsize": 23, - "rss": 24, -} - - -stats = {} - -# In order to report process_start_time_seconds we need to know the -# machine's boot time, because the value in /proc/self/stat is relative to -# this -boot_time = None -if HAVE_PROC_STAT: - with open("/proc/stat") as _procstat: - for line in _procstat: - if line.startswith("btime "): - boot_time = int(line.split()[1]) - - -def update_resource_metrics(): - if HAVE_PROC_SELF_STAT: - global stats - with open("/proc/self/stat") as s: - line = s.read() - # line is PID (command) more stats go here ... - raw_stats = line.split(") ", 1)[1].split(" ") - - for (name, index) in STAT_FIELDS.iteritems(): - # subtract 3 from the index, because proc(5) is 1-based, and - # we've lost the first two fields in PID and COMMAND above - stats[name] = int(raw_stats[index - 3]) - - -def _count_fds(): - # Not every OS will have a /proc/self/fd directory - if not HAVE_PROC_SELF_FD: - return 0 - - return len(os.listdir("/proc/self/fd")) - - -def register_process_collector(process_metrics): - process_metrics.register_collector(update_resource_metrics) - - if HAVE_PROC_SELF_STAT: - process_metrics.register_callback( - "cpu_user_seconds_total", - lambda: float(stats["utime"]) / TICKS_PER_SEC - ) - process_metrics.register_callback( - "cpu_system_seconds_total", - lambda: float(stats["stime"]) / TICKS_PER_SEC - ) - process_metrics.register_callback( - "cpu_seconds_total", - lambda: (float(stats["utime"] + stats["stime"])) / TICKS_PER_SEC - ) - - process_metrics.register_callback( - "virtual_memory_bytes", - lambda: int(stats["vsize"]) - ) - process_metrics.register_callback( - "resident_memory_bytes", - lambda: int(stats["rss"]) * BYTES_PER_PAGE - ) - - process_metrics.register_callback( - "start_time_seconds", - lambda: boot_time + int(stats["starttime"]) / TICKS_PER_SEC - ) - - if HAVE_PROC_SELF_FD: - process_metrics.register_callback( - "open_fds", - lambda: _count_fds() - ) - - if HAVE_PROC_SELF_LIMITS: - def _get_max_fds(): - with open("/proc/self/limits") as limits: - for line in limits: - if not line.startswith("Max open files "): - continue - # Line is Max open files $SOFT $HARD - return int(line.split()[3]) - return None - - process_metrics.register_callback( - "max_fds", - lambda: _get_max_fds() - ) diff --git a/synapse/metrics/resource.py b/synapse/metrics/resource.py index 870f400600..9789359077 100644 --- a/synapse/metrics/resource.py +++ b/synapse/metrics/resource.py @@ -13,27 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.web.resource import Resource - -import synapse.metrics - +from prometheus_client.twisted import MetricsResource METRICS_PREFIX = "/_synapse/metrics" - -class MetricsResource(Resource): - isLeaf = True - - def __init__(self, hs): - Resource.__init__(self) # Resource is old-style, so no super() - - self.hs = hs - - def render_GET(self, request): - response = synapse.metrics.render_all() - - request.setHeader("Content-Type", "text/plain") - request.setHeader("Content-Length", str(len(response))) - - # Encode as UTF-8 (default) - return response.encode() +__all__ = ["MetricsResource", "METRICS_PREFIX"] diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py new file mode 100644 index 0000000000..097c844d31 --- /dev/null +++ b/synapse/module_api/__init__.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.internet import defer + +from synapse.types import UserID + + +class ModuleApi(object): + """A proxy object that gets passed to password auth providers so they + can register new users etc if necessary. + """ + def __init__(self, hs, auth_handler): + self.hs = hs + + self._store = hs.get_datastore() + self._auth = hs.get_auth() + self._auth_handler = auth_handler + + def get_user_by_req(self, req, allow_guest=False): + """Check the access_token provided for a request + + Args: + req (twisted.web.server.Request): Incoming HTTP request + allow_guest (bool): True if guest users should be allowed. If this + is False, and the access token is for a guest user, an + AuthError will be thrown + Returns: + twisted.internet.defer.Deferred[synapse.types.Requester]: + the requester for this request + Raises: + synapse.api.errors.AuthError: if no user by that token exists, + or the token is invalid. + """ + return self._auth.get_user_by_req(req, allow_guest) + + def get_qualified_user_id(self, username): + """Qualify a user id, if necessary + + Takes a user id provided by the user and adds the @ and :domain to + qualify it, if necessary + + Args: + username (str): provided user id + + Returns: + str: qualified @user:id + """ + if username.startswith('@'): + return username + return UserID(username, self.hs.hostname).to_string() + + def check_user_exists(self, user_id): + """Check if user exists. + + Args: + user_id (str): Complete @user:id + + Returns: + Deferred[str|None]: Canonical (case-corrected) user_id, or None + if the user is not registered. + """ + return self._auth_handler.check_user_exists(user_id) + + def register(self, localpart): + """Registers a new user with given localpart + + Returns: + Deferred: a 2-tuple of (user_id, access_token) + """ + reg = self.hs.get_handlers().registration_handler + return reg.register(localpart=localpart) + + @defer.inlineCallbacks + def invalidate_access_token(self, access_token): + """Invalidate an access token for a user + + Args: + access_token(str): access token + + Returns: + twisted.internet.defer.Deferred - resolves once the access token + has been removed. + + Raises: + synapse.api.errors.AuthError: the access token is invalid + """ + # see if the access token corresponds to a device + user_info = yield self._auth.get_user_by_access_token(access_token) + device_id = user_info.get("device_id") + user_id = user_info["user"].to_string() + if device_id: + # delete the device, which will also delete its access tokens + yield self.hs.get_device_handler().delete_device(user_id, device_id) + else: + # no associated device. Just delete the access token. + yield self._auth_handler.delete_access_token(access_token) + + def run_db_interaction(self, desc, func, *args, **kwargs): + """Run a function with a database connection + + Args: + desc (str): description for the transaction, for metrics etc + func (func): function to be run. Passed a database cursor object + as well as *args and **kwargs + *args: positional args to be passed to func + **kwargs: named args to be passed to func + + Returns: + Deferred[object]: result of func + """ + return self._store.runInteraction(desc, func, *args, **kwargs) diff --git a/synapse/notifier.py b/synapse/notifier.py index 385208b574..e650c3e494 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -13,34 +13,34 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +from collections import namedtuple + +from prometheus_client import Counter + from twisted.internet import defer + from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError from synapse.handlers.presence import format_user_presence_state - -from synapse.util import DeferredTimedOutError +from synapse.metrics import LaterGauge +from synapse.types import StreamToken +from synapse.util.async import ( + DeferredTimeoutError, + ObservableDeferred, + add_timeout_to_deferred, +) +from synapse.util.logcontext import PreserveLoggingContext, run_in_background from synapse.util.logutils import log_function -from synapse.util.async import ObservableDeferred -from synapse.util.logcontext import PreserveLoggingContext, preserve_fn from synapse.util.metrics import Measure -from synapse.types import StreamToken from synapse.visibility import filter_events_for_client -import synapse.metrics - -from collections import namedtuple - -import logging - logger = logging.getLogger(__name__) -metrics = synapse.metrics.get_metrics_for(__name__) +notified_events_counter = Counter("synapse_notifier_notified_events", "") -notified_events_counter = metrics.register_counter("notified_events") - -users_woken_by_stream_counter = metrics.register_counter( - "users_woken_by_stream", labels=["stream"] -) +users_woken_by_stream_counter = Counter( + "synapse_notifier_users_woken_by_stream", "", ["stream"]) # TODO(paul): Should be shared somewhere @@ -105,7 +105,7 @@ class _NotifierUserStream(object): self.last_notified_ms = time_now_ms noify_deferred = self.notify_deferred - users_woken_by_stream_counter.inc(stream_key) + users_woken_by_stream_counter.labels(stream_key).inc() with PreserveLoggingContext(): self.notify_deferred = ObservableDeferred(defer.Deferred()) @@ -144,6 +144,7 @@ class _NotifierUserStream(object): class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))): def __nonzero__(self): return bool(self.events) + __bool__ = __nonzero__ # python3 class Notifier(object): @@ -159,6 +160,7 @@ class Notifier(object): self.user_to_user_stream = {} self.room_to_user_streams = {} + self.hs = hs self.event_sources = hs.get_event_sources() self.store = hs.get_datastore() self.pending_new_room_events = [] @@ -193,14 +195,14 @@ class Notifier(object): all_user_streams.add(x) return sum(stream.count_listeners() for stream in all_user_streams) - metrics.register_callback("listeners", count_listeners) + LaterGauge("synapse_notifier_listeners", "", [], count_listeners) - metrics.register_callback( - "rooms", + LaterGauge( + "synapse_notifier_rooms", "", [], lambda: count(bool, self.room_to_user_streams.values()), ) - metrics.register_callback( - "users", + LaterGauge( + "synapse_notifier_users", "", [], lambda: len(self.user_to_user_stream), ) @@ -250,14 +252,10 @@ class Notifier(object): def _on_new_room_event(self, event, room_stream_id, extra_users=[]): """Notify any user streams that are interested in this room event""" # poke any interested application service. - preserve_fn(self.appservice_handler.notify_interested_services)( - room_stream_id - ) + run_in_background(self._notify_app_services, room_stream_id) if self.federation_sender: - preserve_fn(self.federation_sender.notify_new_events)( - room_stream_id - ) + self.federation_sender.notify_new_events(room_stream_id) if event.type == EventTypes.Member and event.membership == Membership.JOIN: self._user_joined_room(event.state_key, event.room_id) @@ -268,8 +266,15 @@ class Notifier(object): rooms=[event.room_id], ) + @defer.inlineCallbacks + def _notify_app_services(self, room_stream_id): + try: + yield self.appservice_handler.notify_interested_services(room_stream_id) + except Exception: + logger.exception("Error notifying application services of event") + def on_new_event(self, stream_key, new_token, users=[], rooms=[]): - """ Used to inform listeners that something has happend event wise. + """ Used to inform listeners that something has happened event wise. Will wake up all listeners for the given users and rooms. """ @@ -289,7 +294,7 @@ class Notifier(object): for user_stream in user_streams: try: user_stream.notify(stream_key, new_token, time_now_ms) - except: + except Exception: logger.exception("Failed to notify listener") self.notify_replication() @@ -297,8 +302,7 @@ class Notifier(object): def on_new_replication_data(self): """Used to inform replication listeners that something has happend without waking up any of the normal user event streams""" - with PreserveLoggingContext(): - self.notify_replication() + self.notify_replication() @defer.inlineCallbacks def wait_for_events(self, user_id, timeout, callback, room_ids=None, @@ -333,11 +337,13 @@ class Notifier(object): # Now we wait for the _NotifierUserStream to be told there # is a new token. listener = user_stream.new_listener(prev_token) + add_timeout_to_deferred( + listener.deferred, + (end_time - now) / 1000., + self.hs.get_reactor(), + ) with PreserveLoggingContext(): - yield self.clock.time_bound_deferred( - listener.deferred, - time_out=(end_time - now) / 1000. - ) + yield listener.deferred current_token = user_stream.current_token @@ -348,7 +354,7 @@ class Notifier(object): # Update the prev_token to the current_token since nothing # has happened between the old prev_token and the current_token prev_token = current_token - except DeferredTimedOutError: + except DeferredTimeoutError: break except defer.CancelledError: break @@ -516,8 +522,14 @@ class Notifier(object): self.replication_deferred = ObservableDeferred(defer.Deferred()) deferred.callback(None) - for cb in self.replication_callbacks: - preserve_fn(cb)() + # the callbacks may well outlast the current request, so we run + # them in the sentinel logcontext. + # + # (ideally it would be up to the callbacks to know if they were + # starting off background processes and drop the logcontext + # accordingly, but that requires more changes) + for cb in self.replication_callbacks: + cb() @defer.inlineCallbacks def wait_for_replication(self, callback, timeout): @@ -547,13 +559,15 @@ class Notifier(object): if end_time <= now: break + add_timeout_to_deferred( + listener.deferred.addTimeout, + (end_time - now) / 1000., + self.hs.get_reactor(), + ) try: with PreserveLoggingContext(): - yield self.clock.time_bound_deferred( - listener.deferred, - time_out=(end_time - now) / 1000. - ) - except DeferredTimedOutError: + yield listener.deferred + except DeferredTimeoutError: break except defer.CancelledError: break diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index fe09d50d55..a5de75c48a 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -13,13 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +import logging -from .bulk_push_rule_evaluator import BulkPushRuleEvaluator +from twisted.internet import defer from synapse.util.metrics import Measure -import logging +from .bulk_push_rule_evaluator import BulkPushRuleEvaluator logger = logging.getLogger(__name__) @@ -40,10 +40,6 @@ class ActionGenerator(object): @defer.inlineCallbacks def handle_push_actions_for_event(self, event, context): with Measure(self.clock, "action_for_event_by_user"): - actions_by_user = yield self.bulk_evaluator.action_for_event_by_user( + yield self.bulk_evaluator.action_for_event_by_user( event, context ) - - context.push_actions = [ - (uid, actions) for uid, actions in actions_by_user.iteritems() - ] diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 85effdfa46..8f0682c948 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -1,4 +1,5 @@ # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,9 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP import copy +from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP + def list_with_base_rules(rawrules): """Combine the list of rules set by the user with the default push rules @@ -38,7 +40,7 @@ def list_with_base_rules(rawrules): rawrules = [r for r in rawrules if r['priority_class'] >= 0] # shove the server default rules for each kind onto the end of each - current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1] + current_prio_class = list(PRIORITY_CLASS_INVERSE_MAP)[-1] ruleslist.extend(make_base_prepend_rules( PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules @@ -238,6 +240,28 @@ BASE_APPEND_OVERRIDE_RULES = [ } ] }, + { + 'rule_id': 'global/override/.m.rule.roomnotif', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'content.body', + 'pattern': '@room', + '_id': '_roomnotif_content', + }, + { + 'kind': 'sender_notification_permission', + 'key': 'room', + '_id': '_roomnotif_pl', + }, + ], + 'actions': [ + 'notify', { + 'set_tweak': 'highlight', + 'value': True, + } + ] + } ] diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 9a96e6fe8f..1d14d3639c 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015 OpenMarket Ltd +# Copyright 2017 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,18 +15,22 @@ # limitations under the License. import logging +from collections import namedtuple -from twisted.internet import defer +from six import iteritems, itervalues -from .push_rule_evaluator import PushRuleEvaluatorForEvent +from prometheus_client import Counter + +from twisted.internet import defer -from synapse.visibility import filter_events_for_clients_context from synapse.api.constants import EventTypes, Membership -from synapse.util.caches.descriptors import cached +from synapse.event_auth import get_user_power_level +from synapse.state import POWER_KEY from synapse.util.async import Linearizer +from synapse.util.caches import register_cache +from synapse.util.caches.descriptors import cached -from collections import namedtuple - +from .push_rule_evaluator import PushRuleEvaluatorForEvent logger = logging.getLogger(__name__) @@ -33,6 +38,20 @@ logger = logging.getLogger(__name__) rules_by_room = {} +push_rules_invalidation_counter = Counter( + "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", "") +push_rules_state_size_counter = Counter( + "synapse_push_bulk_push_rule_evaluator_push_rules_state_size_counter", "") + +# Measures whether we use the fast path of using state deltas, or if we have to +# recalculate from scratch +push_rules_delta_state_cache_metric = register_cache( + "cache", + "push_rules_delta_state_cache_metric", + cache=[], # Meaningless size, as this isn't a cache that stores values +) + + class BulkPushRuleEvaluator(object): """Calculates the outcome of push rules for an event for all users in the room at once. @@ -41,6 +60,13 @@ class BulkPushRuleEvaluator(object): def __init__(self, hs): self.hs = hs self.store = hs.get_datastore() + self.auth = hs.get_auth() + + self.room_push_rule_cache_metrics = register_cache( + "cache", + "room_push_rule_cache", + cache=[], # Meaningless size, as this isn't a cache that stores values + ) @defer.inlineCallbacks def _get_rules_for_event(self, event, context): @@ -79,37 +105,69 @@ class BulkPushRuleEvaluator(object): # It's important that RulesForRoom gets added to self._get_rules_for_room.cache # before any lookup methods get called on it as otherwise there may be # a race if invalidate_all gets called (which assumes its in the cache) - return RulesForRoom(self.hs, room_id, self._get_rules_for_room.cache) + return RulesForRoom( + self.hs, room_id, self._get_rules_for_room.cache, + self.room_push_rule_cache_metrics, + ) + + @defer.inlineCallbacks + def _get_power_levels_and_sender_level(self, event, context): + prev_state_ids = yield context.get_prev_state_ids(self.store) + pl_event_id = prev_state_ids.get(POWER_KEY) + if pl_event_id: + # fastpath: if there's a power level event, that's all we need, and + # not having a power level event is an extreme edge case + pl_event = yield self.store.get_event(pl_event_id) + auth_events = {POWER_KEY: pl_event} + else: + auth_events_ids = yield self.auth.compute_auth_events( + event, prev_state_ids, for_verification=False, + ) + auth_events = yield self.store.get_events(auth_events_ids) + auth_events = { + (e.type, e.state_key): e for e in itervalues(auth_events) + } + + sender_level = get_user_power_level(event.sender, auth_events) + + pl_event = auth_events.get(POWER_KEY) + + defer.returnValue((pl_event.content if pl_event else {}, sender_level)) @defer.inlineCallbacks def action_for_event_by_user(self, event, context): - """Given an event and context, evaluate the push rules and return - the results + """Given an event and context, evaluate the push rules and insert the + results into the event_push_actions_staging table. Returns: - dict of user_id -> action + Deferred """ rules_by_user = yield self._get_rules_for_event(event, context) actions_by_user = {} - # None of these users can be peeking since this list of users comes - # from the set of users in the room, so we know for sure they're all - # actually in the room. - user_tuples = [(u, False) for u in rules_by_user] - - filtered_by_user = yield filter_events_for_clients_context( - self.store, user_tuples, [event], {event.event_id: context} - ) - room_members = yield self.store.get_joined_users_from_context( event, context ) - evaluator = PushRuleEvaluatorForEvent(event, len(room_members)) + (power_levels, sender_power_level) = ( + yield self._get_power_levels_and_sender_level(event, context) + ) + + evaluator = PushRuleEvaluatorForEvent( + event, len(room_members), sender_power_level, power_levels, + ) condition_cache = {} - for uid, rules in rules_by_user.iteritems(): + for uid, rules in iteritems(rules_by_user): + if event.sender == uid: + continue + + if not event.is_state(): + is_ignored = yield self.store.is_ignored_by(event.sender, uid) + if is_ignored: + continue + display_name = None profile_info = room_members.get(uid) if profile_info: @@ -121,13 +179,6 @@ class BulkPushRuleEvaluator(object): if event.type == EventTypes.Member and event.state_key == uid: display_name = event.content.get("displayname", None) - filtered = filtered_by_user[uid] - if len(filtered) == 0: - continue - - if filtered[0].sender == uid: - continue - for rule in rules: if 'enabled' in rule and not rule['enabled']: continue @@ -138,9 +189,16 @@ class BulkPushRuleEvaluator(object): if matches: actions = [x for x in rule['actions'] if x != 'dont_notify'] if actions and 'notify' in actions: + # Push rules say we should notify the user of this event actions_by_user[uid] = actions break - defer.returnValue(actions_by_user) + + # Mark in the DB staging area the push actions for users who should be + # notified for this event. (This will then get handled when we persist + # the event) + yield self.store.add_push_actions_to_staging( + event.event_id, actions_by_user, + ) def _condition_checker(evaluator, conditions, uid, display_name, cache): @@ -170,17 +228,19 @@ class RulesForRoom(object): the entire cache for the room. """ - def __init__(self, hs, room_id, rules_for_room_cache): + def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metrics): """ Args: hs (HomeServer) room_id (str) rules_for_room_cache(Cache): The cache object that caches these RoomsForUser objects. + room_push_rule_cache_metrics (CacheMetric) """ self.room_id = room_id self.is_mine_id = hs.is_mine_id self.store = hs.get_datastore() + self.room_push_rule_cache_metrics = room_push_rule_cache_metrics self.linearizer = Linearizer(name="rules_for_room") @@ -222,11 +282,19 @@ class RulesForRoom(object): """ state_group = context.state_group + if state_group and self.state_group == state_group: + logger.debug("Using cached rules for %r", self.room_id) + self.room_push_rule_cache_metrics.inc_hits() + defer.returnValue(self.rules_by_user) + with (yield self.linearizer.queue(())): if state_group and self.state_group == state_group: logger.debug("Using cached rules for %r", self.room_id) + self.room_push_rule_cache_metrics.inc_hits() defer.returnValue(self.rules_by_user) + self.room_push_rule_cache_metrics.inc_misses() + ret_rules_by_user = {} missing_member_event_ids = {} if state_group and self.state_group == context.prev_group: @@ -234,8 +302,13 @@ class RulesForRoom(object): # results. ret_rules_by_user = self.rules_by_user current_state_ids = context.delta_ids + + push_rules_delta_state_cache_metric.inc_hits() else: - current_state_ids = context.current_state_ids + current_state_ids = yield context.get_current_state_ids(self.store) + push_rules_delta_state_cache_metric.inc_misses() + + push_rules_state_size_counter.inc(len(current_state_ids)) logger.debug( "Looking for member changes in %r %r", state_group, current_state_ids @@ -282,6 +355,14 @@ class RulesForRoom(object): yield self._update_rules_with_member_event_ids( ret_rules_by_user, missing_member_event_ids, state_group, event ) + else: + # The push rules didn't change but lets update the cache anyway + self.update_cache( + self.sequence, + members={}, # There were no membership changes + rules_by_user=ret_rules_by_user, + state_group=state_group + ) if logger.isEnabledFor(logging.DEBUG): logger.debug( @@ -324,7 +405,7 @@ class RulesForRoom(object): # If the event is a join event then it will be in current state evnts # map but not in the DB, so we have to explicitly insert it. if event.type == EventTypes.Member: - for event_id in member_event_ids.itervalues(): + for event_id in itervalues(member_event_ids): if event_id == event.event_id: members[event_id] = (event.state_key, event.membership) @@ -332,7 +413,7 @@ class RulesForRoom(object): logger.debug("Found members %r: %r", self.room_id, members.values()) interested_in_user_ids = set( - user_id for user_id, membership in members.itervalues() + user_id for user_id, membership in itervalues(members) if membership == Membership.JOIN ) @@ -344,7 +425,7 @@ class RulesForRoom(object): ) user_ids = set( - uid for uid, have_pusher in if_users_with_pushers.iteritems() if have_pusher + uid for uid, have_pusher in iteritems(if_users_with_pushers) if have_pusher ) logger.debug("With pushers: %r", user_ids) @@ -365,7 +446,7 @@ class RulesForRoom(object): ) ret_rules_by_user.update( - item for item in rules_by_user.iteritems() if item[0] is not None + item for item in iteritems(rules_by_user) if item[0] is not None ) self.update_cache(sequence, members, ret_rules_by_user, state_group) @@ -380,6 +461,7 @@ class RulesForRoom(object): self.state_group = object() self.member_map = {} self.rules_by_user = {} + push_rules_invalidation_counter.inc() def update_cache(self, sequence, members, rules_by_user, state_group): if sequence == self.sequence: diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py index e0331b2d2d..ecbf364a5e 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py @@ -13,12 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.push.rulekinds import ( - PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP -) - import copy +from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP + def format_push_rules_for_user(user, ruleslist): """Converts a list of rawrules and a enabled map into nested dictionaries diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index a69dda7b09..d746371420 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -13,14 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer, reactor -from twisted.internet.error import AlreadyCalled, AlreadyCancelled - import logging -from synapse.util.metrics import Measure -from synapse.util.logcontext import LoggingContext +from twisted.internet import defer +from twisted.internet.error import AlreadyCalled, AlreadyCancelled +from synapse.util.logcontext import LoggingContext +from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -77,10 +76,13 @@ class EmailPusher(object): @defer.inlineCallbacks def on_started(self): if self.mailer is not None: - self.throttle_params = yield self.store.get_throttle_params_by_room( - self.pusher_id - ) - yield self._process() + try: + self.throttle_params = yield self.store.get_throttle_params_by_room( + self.pusher_id + ) + yield self._process() + except Exception: + logger.exception("Error starting email pusher") def on_stop(self): if self.timed_call: @@ -121,7 +123,7 @@ class EmailPusher(object): starting_max_ordering = self.max_stream_ordering try: yield self._unsafe_process() - except: + except Exception: logger.exception("Exception processing notifs") if self.max_stream_ordering == starting_max_ordering: break @@ -196,7 +198,7 @@ class EmailPusher(object): self.timed_call = None if soonest_due_at is not None: - self.timed_call = reactor.callLater( + self.timed_call = self.hs.get_reactor().callLater( self.seconds_until(soonest_due_at), self.on_timer ) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 8a5d473108..81e18bcf7d 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,21 +13,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging -from synapse.push import PusherConfigException +from prometheus_client import Counter -from twisted.internet import defer, reactor +from twisted.internet import defer from twisted.internet.error import AlreadyCalled, AlreadyCancelled -import logging -import push_rule_evaluator -import push_tools - +from synapse.push import PusherConfigException from synapse.util.logcontext import LoggingContext from synapse.util.metrics import Measure +from . import push_rule_evaluator, push_tools + logger = logging.getLogger(__name__) +http_push_processed_counter = Counter("synapse_http_httppusher_http_pushes_processed", "") + +http_push_failed_counter = Counter("synapse_http_httppusher_http_pushes_failed", "") + class HttpPusher(object): INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes @@ -84,7 +89,10 @@ class HttpPusher(object): @defer.inlineCallbacks def on_started(self): - yield self._process() + try: + yield self._process() + except Exception: + logger.exception("Error starting http pusher") @defer.inlineCallbacks def on_new_notifications(self, min_stream_ordering, max_stream_ordering): @@ -131,7 +139,7 @@ class HttpPusher(object): starting_max_ordering = self.max_stream_ordering try: yield self._unsafe_process() - except: + except Exception: logger.exception("Exception processing notifs") if self.max_stream_ordering == starting_max_ordering: break @@ -151,9 +159,16 @@ class HttpPusher(object): self.user_id, self.last_stream_ordering, self.max_stream_ordering ) + logger.info( + "Processing %i unprocessed push actions for %s starting at " + "stream_ordering %s", + len(unprocessed), self.name, self.last_stream_ordering, + ) + for push_action in unprocessed: processed = yield self._process_one(push_action) if processed: + http_push_processed_counter.inc() self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action['stream_ordering'] yield self.store.update_pusher_last_stream_ordering_and_success( @@ -168,6 +183,7 @@ class HttpPusher(object): self.failing_since ) else: + http_push_failed_counter.inc() if not self.failing_since: self.failing_since = self.clock.time_msec() yield self.store.update_pusher_failing_since( @@ -204,7 +220,9 @@ class HttpPusher(object): ) else: logger.info("Push failed: delaying for %ds", self.backoff_delay) - self.timed_call = reactor.callLater(self.backoff_delay, self.on_timer) + self.timed_call = self.hs.get_reactor().callLater( + self.backoff_delay, self.on_timer + ) self.backoff_delay = min(self.backoff_delay * 2, self.MAX_BACKOFF_SEC) break @@ -244,6 +262,26 @@ class HttpPusher(object): @defer.inlineCallbacks def _build_notification_dict(self, event, tweaks, badge): + if self.data.get('format') == 'event_id_only': + d = { + 'notification': { + 'event_id': event.event_id, + 'room_id': event.room_id, + 'counts': { + 'unread': badge, + }, + 'devices': [ + { + 'app_id': self.app_id, + 'pushkey': self.pushkey, + 'pushkey_ts': long(self.pushkey_ts / 1000), + 'data': self.data_minus_url, + } + ] + } + } + defer.returnValue(d) + ctx = yield push_tools.get_context_for_event( self.store, self.state_handler, event, self.user_id ) @@ -275,7 +313,7 @@ class HttpPusher(object): if event.type == 'm.room.member': d['notification']['membership'] = event.content['membership'] d['notification']['user_is_target'] = event.state_key == self.user_id - if not self.hs.config.push_redact_content and 'content' in event: + if self.hs.config.push_include_content and 'content' in event: d['notification']['content'] = event.content # We no longer send aliases separately, instead, we send the human @@ -294,8 +332,11 @@ class HttpPusher(object): defer.returnValue([]) try: resp = yield self.http_client.post_json_get_json(self.url, notification_dict) - except: - logger.warn("Failed to push %s ", self.url) + except Exception: + logger.warn( + "Failed to push event %s to %s", + event.event_id, self.name, exc_info=True, + ) defer.returnValue(False) rejected = [] if 'rejected' in resp: @@ -304,7 +345,7 @@ class HttpPusher(object): @defer.inlineCallbacks def _send_badge(self, badge): - logger.info("Sending updated badge count %d to %r", badge, self.user_id) + logger.info("Sending updated badge count %d to %s", badge, self.name) d = { 'notification': { 'id': '', @@ -325,8 +366,11 @@ class HttpPusher(object): } try: resp = yield self.http_client.post_json_get_json(self.url, d) - except: - logger.exception("Failed to push %s ", self.url) + except Exception: + logger.warn( + "Failed to send badge count to %s", + self.name, exc_info=True, + ) defer.returnValue(False) rejected = [] if 'rejected' in resp: diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index b5cd9b426a..9d601208fd 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -13,30 +13,31 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer -from twisted.mail.smtp import sendmail - -import email.utils import email.mime.multipart -from email.mime.text import MIMEText +import email.utils +import logging +import time +import urllib from email.mime.multipart import MIMEMultipart +from email.mime.text import MIMEText -from synapse.util.async import concurrently_execute +import bleach +import jinja2 + +from twisted.internet import defer +from twisted.mail.smtp import sendmail + +from synapse.api.constants import EventTypes +from synapse.api.errors import StoreError from synapse.push.presentable_names import ( - calculate_room_name, name_from_member_event, descriptor_from_member_events + calculate_room_name, + descriptor_from_member_events, + name_from_member_event, ) from synapse.types import UserID -from synapse.api.errors import StoreError -from synapse.api.constants import EventTypes +from synapse.util.async import concurrently_execute from synapse.visibility import filter_events_for_client -import jinja2 -import bleach - -import time -import urllib - -import logging logger = logging.getLogger(__name__) @@ -229,7 +230,8 @@ class Mailer(object): if room_vars['notifs'] and 'messages' in room_vars['notifs'][-1]: prev_messages = room_vars['notifs'][-1]['messages'] for message in notifvars['messages']: - pm = filter(lambda pm: pm['id'] == message['id'], prev_messages) + pm = list(filter(lambda pm: pm['id'] == message['id'], + prev_messages)) if pm: if not message["is_historical"]: pm[0]["is_historical"] = False diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py index 277da3cd35..eef6e18c2e 100644 --- a/synapse/push/presentable_names.py +++ b/synapse/push/presentable_names.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - -import re import logging +import re + +from twisted.internet import defer logger = logging.getLogger(__name__) @@ -113,7 +113,7 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True # so find out who is in the room that isn't the user. if "m.room.member" in room_state_bytype_ids: member_events = yield store.get_events( - room_state_bytype_ids["m.room.member"].values() + list(room_state_bytype_ids["m.room.member"].values()) ) all_members = [ ev for ev in member_events.values() diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 172c27c137..2bd321d530 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,6 +17,8 @@ import logging import re +from six import string_types + from synapse.types import UserID from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache from synapse.util.caches.lrucache import LruCache @@ -29,6 +32,21 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") def _room_member_count(ev, condition, room_member_count): + return _test_ineq_condition(condition, room_member_count) + + +def _sender_notification_permission(ev, condition, sender_power_level, power_levels): + notif_level_key = condition.get('key') + if notif_level_key is None: + return False + + notif_levels = power_levels.get('notifications', {}) + room_notif_level = notif_levels.get(notif_level_key, 50) + + return sender_power_level >= room_notif_level + + +def _test_ineq_condition(condition, number): if 'is' not in condition: return False m = INEQUALITY_EXPR.match(condition['is']) @@ -41,15 +59,15 @@ def _room_member_count(ev, condition, room_member_count): rhs = int(rhs) if ineq == '' or ineq == '==': - return room_member_count == rhs + return number == rhs elif ineq == '<': - return room_member_count < rhs + return number < rhs elif ineq == '>': - return room_member_count > rhs + return number > rhs elif ineq == '>=': - return room_member_count >= rhs + return number >= rhs elif ineq == '<=': - return room_member_count <= rhs + return number <= rhs else: return False @@ -65,9 +83,11 @@ def tweaks_for_actions(actions): class PushRuleEvaluatorForEvent(object): - def __init__(self, event, room_member_count): + def __init__(self, event, room_member_count, sender_power_level, power_levels): self._event = event self._room_member_count = room_member_count + self._sender_power_level = sender_power_level + self._power_levels = power_levels # Maps strings of e.g. 'content.body' -> event["content"]["body"] self._value_cache = _flatten_dict(event) @@ -81,6 +101,10 @@ class PushRuleEvaluatorForEvent(object): return _room_member_count( self._event, condition, self._room_member_count ) + elif condition['kind'] == 'sender_notification_permission': + return _sender_notification_permission( + self._event, condition, self._sender_power_level, self._power_levels, + ) else: return True @@ -128,7 +152,7 @@ class PushRuleEvaluatorForEvent(object): # Caches (glob, word_boundary) -> regex for push. See _glob_matches regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR) -register_cache("regex_push_cache", regex_cache) +register_cache("cache", "regex_push_cache", regex_cache) def _glob_matches(glob, value, word_boundary=False): @@ -183,7 +207,7 @@ def _glob_to_re(glob, word_boundary): r, ) if word_boundary: - r = r"\b%s\b" % (r,) + r = _re_word_boundary(r) return re.compile(r, flags=re.IGNORECASE) else: @@ -192,7 +216,7 @@ def _glob_to_re(glob, word_boundary): return re.compile(r, flags=re.IGNORECASE) elif word_boundary: r = re.escape(glob) - r = r"\b%s\b" % (r,) + r = _re_word_boundary(r) return re.compile(r, flags=re.IGNORECASE) else: @@ -200,11 +224,23 @@ def _glob_to_re(glob, word_boundary): return re.compile(r, flags=re.IGNORECASE) +def _re_word_boundary(r): + """ + Adds word boundary characters to the start and end of an + expression to require that the match occur as a whole word, + but do so respecting the fact that strings starting or ending + with non-word characters will change word boundaries. + """ + # we can't use \b as it chokes on unicode. however \W seems to be okay + # as shorthand for [^0-9A-Za-z_]. + return r"(^|\W)%s(\W|$)" % (r,) + + def _flatten_dict(d, prefix=[], result=None): if result is None: result = {} for key, value in d.items(): - if isinstance(value, basestring): + if isinstance(value, string_types): result[".".join(prefix + [key])] = value.lower() elif hasattr(value, "items"): _flatten_dict(value, prefix=(prefix + [key]), result=result) diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 6835f54e97..8049c298c2 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -14,9 +14,8 @@ # limitations under the License. from twisted.internet import defer -from synapse.push.presentable_names import ( - calculate_room_name, name_from_member_event -) + +from synapse.push.presentable_names import calculate_room_name, name_from_member_event @defer.inlineCallbacks diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index 491f27bded..fcee6d9d7e 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -13,9 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from httppusher import HttpPusher - import logging + +from .httppusher import HttpPusher + logger = logging.getLogger(__name__) # We try importing this if we can (it will fail if we don't @@ -27,7 +28,7 @@ logger = logging.getLogger(__name__) try: from synapse.push.emailpusher import EmailPusher from synapse.push.mailer import Mailer, load_jinja2_templates -except: +except Exception: pass diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 43cb6e9c01..36bb5bbc65 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -14,13 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer -from .pusher import PusherFactory -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred -from synapse.util.async import run_on_reactor - -import logging +from synapse.push.pusher import PusherFactory +from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) @@ -103,23 +102,28 @@ class PusherPool: yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) @defer.inlineCallbacks - def remove_pushers_by_user(self, user_id, except_access_token_id=None): - all = yield self.store.get_all_pushers() - logger.info( - "Removing all pushers for user %s except access tokens id %r", - user_id, except_access_token_id - ) - for p in all: - if p['user_name'] == user_id and p['access_token'] != except_access_token_id: + def remove_pushers_by_access_token(self, user_id, access_tokens): + """Remove the pushers for a given user corresponding to a set of + access_tokens. + + Args: + user_id (str): user to remove pushers for + access_tokens (Iterable[int]): access token *ids* to remove pushers + for + """ + tokens = set(access_tokens) + for p in (yield self.store.get_pushers_by_user_id(user_id)): + if p['access_token'] in tokens: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", p['app_id'], p['pushkey'], p['user_name'] ) - yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) + yield self.remove_pusher( + p['app_id'], p['pushkey'], p['user_name'], + ) @defer.inlineCallbacks def on_new_notifications(self, min_stream_id, max_stream_id): - yield run_on_reactor() try: users_affected = yield self.store.get_push_action_users_in_range( min_stream_id, max_stream_id @@ -131,18 +135,20 @@ class PusherPool: if u in self.pushers: for p in self.pushers[u].values(): deferreds.append( - preserve_fn(p.on_new_notifications)( - min_stream_id, max_stream_id + run_in_background( + p.on_new_notifications, + min_stream_id, max_stream_id, ) ) - yield preserve_context_over_deferred(defer.gatherResults(deferreds)) - except: + yield make_deferred_yieldable( + defer.gatherResults(deferreds, consumeErrors=True), + ) + except Exception: logger.exception("Exception in pusher on_new_notifications") @defer.inlineCallbacks def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): - yield run_on_reactor() try: # Need to subtract 1 from the minimum because the lower bound here # is not inclusive @@ -158,11 +164,16 @@ class PusherPool: if u in self.pushers: for p in self.pushers[u].values(): deferreds.append( - preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id) + run_in_background( + p.on_new_receipts, + min_stream_id, max_stream_id, + ) ) - yield preserve_context_over_deferred(defer.gatherResults(deferreds)) - except: + yield make_deferred_yieldable( + defer.gatherResults(deferreds, consumeErrors=True), + ) + except Exception: logger.exception("Exception in pusher on_new_receipts") @defer.inlineCallbacks @@ -188,7 +199,7 @@ class PusherPool: for pusherdict in pushers: try: p = self.pusher_factory.create_pusher(pusherdict) - except: + except Exception: logger.exception("Couldn't start a pusher: caught Exception") continue if p: @@ -201,7 +212,7 @@ class PusherPool: if appid_pushkey in byuser: byuser[appid_pushkey].on_stop() byuser[appid_pushkey] = p - preserve_fn(p.on_started)() + run_in_background(p.on_started) logger.info("Started pushers") diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index ed7f1c89ad..987eec3ef2 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -1,5 +1,6 @@ # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,36 +19,52 @@ from distutils.version import LooseVersion logger = logging.getLogger(__name__) +# this dict maps from python package name to a list of modules we expect it to +# provide. +# +# the key is a "requirement specifier", as used as a parameter to `pip +# install`[1], or an `install_requires` argument to `setuptools.setup` [2]. +# +# the value is a sequence of strings; each entry should be the name of the +# python module, optionally followed by a version assertion which can be either +# ">=<ver>" or "==<ver>". +# +# [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers. +# [2] https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-dependencies REQUIREMENTS = { "jsonschema>=2.5.1": ["jsonschema>=2.5.1"], "frozendict>=0.4": ["frozendict"], "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"], - "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"], + "canonicaljson>=1.1.3": ["canonicaljson>=1.1.3"], "signedjson>=1.0.0": ["signedjson>=1.0.0"], - "pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"], + "pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"], "service_identity>=1.0.0": ["service_identity>=1.0.0"], "Twisted>=16.0.0": ["twisted>=16.0.0"], - "pyopenssl>=0.14": ["OpenSSL>=0.14"], + + # We use crypto.get_elliptic_curve which is only supported in >=0.15 + "pyopenssl>=0.15": ["OpenSSL>=0.15"], + "pyyaml": ["yaml"], "pyasn1": ["pyasn1"], "daemonize": ["daemonize"], - "py-bcrypt": ["bcrypt"], + "bcrypt": ["bcrypt>=3.1.0"], "pillow": ["PIL"], "pydenticon": ["pydenticon"], - "ujson": ["ujson"], - "blist": ["blist"], - "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"], + "sortedcontainers": ["sortedcontainers"], + "pysaml2>=3.0.0": ["saml2>=3.0.0"], "pymacaroons-pynacl": ["pymacaroons"], "msgpack-python>=0.3.0": ["msgpack"], "phonenumbers>=8.2.0": ["phonenumbers"], + "six": ["six"], + "prometheus_client": ["prometheus_client"], + "attrs": ["attr"], + "netaddr>=0.7.18": ["netaddr"], } + CONDITIONAL_REQUIREMENTS = { "web_client": { "matrix_angular_sdk>=0.6.8": ["syweb>=0.6.8"], }, - "preview_url": { - "netaddr>=0.7.18": ["netaddr"], - }, "email.enable_notifs": { "Jinja2>=2.8": ["Jinja2>=2.8"], "bleach>=1.4.2": ["bleach>=1.4.2"], @@ -58,6 +75,9 @@ CONDITIONAL_REQUIREMENTS = { "psutil": { "psutil>=2.0.0": ["psutil>=2.0.0"], }, + "affinity": { + "affinity": ["affinity"], + }, } diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py new file mode 100644 index 0000000000..589ee94c66 --- /dev/null +++ b/synapse/replication/http/__init__.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.http.server import JsonResource +from synapse.replication.http import membership, send_event + +REPLICATION_PREFIX = "/_synapse/replication" + + +class ReplicationRestResource(JsonResource): + def __init__(self, hs): + JsonResource.__init__(self, hs, canonical_json=False) + self.register_servlets(hs) + + def register_servlets(self, hs): + send_event.register_servlets(hs, self) + membership.register_servlets(hs, self) diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py new file mode 100644 index 0000000000..6bfc8a5b89 --- /dev/null +++ b/synapse/replication/http/membership.py @@ -0,0 +1,334 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re + +from twisted.internet import defer + +from synapse.api.errors import MatrixCodeMessageException, SynapseError +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.types import Requester, UserID +from synapse.util.distributor import user_joined_room, user_left_room + +logger = logging.getLogger(__name__) + + +@defer.inlineCallbacks +def remote_join(client, host, port, requester, remote_room_hosts, + room_id, user_id, content): + """Ask the master to do a remote join for the given user to the given room + + Args: + client (SimpleHttpClient) + host (str): host of master + port (int): port on master listening for HTTP replication + requester (Requester) + remote_room_hosts (list[str]): Servers to try and join via + room_id (str) + user_id (str) + content (dict): The event content to use for the join event + + Returns: + Deferred + """ + uri = "http://%s:%s/_synapse/replication/remote_join" % (host, port) + + payload = { + "requester": requester.serialize(), + "remote_room_hosts": remote_room_hosts, + "room_id": room_id, + "user_id": user_id, + "content": content, + } + + try: + result = yield client.post_json_get_json(uri, payload) + except MatrixCodeMessageException as e: + # We convert to SynapseError as we know that it was a SynapseError + # on the master process that we should send to the client. (And + # importantly, not stack traces everywhere) + raise SynapseError(e.code, e.msg, e.errcode) + defer.returnValue(result) + + +@defer.inlineCallbacks +def remote_reject_invite(client, host, port, requester, remote_room_hosts, + room_id, user_id): + """Ask master to reject the invite for the user and room. + + Args: + client (SimpleHttpClient) + host (str): host of master + port (int): port on master listening for HTTP replication + requester (Requester) + remote_room_hosts (list[str]): Servers to try and reject via + room_id (str) + user_id (str) + + Returns: + Deferred + """ + uri = "http://%s:%s/_synapse/replication/remote_reject_invite" % (host, port) + + payload = { + "requester": requester.serialize(), + "remote_room_hosts": remote_room_hosts, + "room_id": room_id, + "user_id": user_id, + } + + try: + result = yield client.post_json_get_json(uri, payload) + except MatrixCodeMessageException as e: + # We convert to SynapseError as we know that it was a SynapseError + # on the master process that we should send to the client. (And + # importantly, not stack traces everywhere) + raise SynapseError(e.code, e.msg, e.errcode) + defer.returnValue(result) + + +@defer.inlineCallbacks +def get_or_register_3pid_guest(client, host, port, requester, + medium, address, inviter_user_id): + """Ask the master to get/create a guest account for given 3PID. + + Args: + client (SimpleHttpClient) + host (str): host of master + port (int): port on master listening for HTTP replication + requester (Requester) + medium (str) + address (str) + inviter_user_id (str): The user ID who is trying to invite the + 3PID + + Returns: + Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the + 3PID guest account. + """ + + uri = "http://%s:%s/_synapse/replication/get_or_register_3pid_guest" % (host, port) + + payload = { + "requester": requester.serialize(), + "medium": medium, + "address": address, + "inviter_user_id": inviter_user_id, + } + + try: + result = yield client.post_json_get_json(uri, payload) + except MatrixCodeMessageException as e: + # We convert to SynapseError as we know that it was a SynapseError + # on the master process that we should send to the client. (And + # importantly, not stack traces everywhere) + raise SynapseError(e.code, e.msg, e.errcode) + defer.returnValue(result) + + +@defer.inlineCallbacks +def notify_user_membership_change(client, host, port, user_id, room_id, change): + """Notify master that a user has joined or left the room + + Args: + client (SimpleHttpClient) + host (str): host of master + port (int): port on master listening for HTTP replication. + user_id (str) + room_id (str) + change (str): Either "join" or "left" + + Returns: + Deferred + """ + assert change in ("joined", "left") + + uri = "http://%s:%s/_synapse/replication/user_%s_room" % (host, port, change) + + payload = { + "user_id": user_id, + "room_id": room_id, + } + + try: + result = yield client.post_json_get_json(uri, payload) + except MatrixCodeMessageException as e: + # We convert to SynapseError as we know that it was a SynapseError + # on the master process that we should send to the client. (And + # importantly, not stack traces everywhere) + raise SynapseError(e.code, e.msg, e.errcode) + defer.returnValue(result) + + +class ReplicationRemoteJoinRestServlet(RestServlet): + PATTERNS = [re.compile("^/_synapse/replication/remote_join$")] + + def __init__(self, hs): + super(ReplicationRemoteJoinRestServlet, self).__init__() + + self.federation_handler = hs.get_handlers().federation_handler + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + @defer.inlineCallbacks + def on_POST(self, request): + content = parse_json_object_from_request(request) + + remote_room_hosts = content["remote_room_hosts"] + room_id = content["room_id"] + user_id = content["user_id"] + event_content = content["content"] + + requester = Requester.deserialize(self.store, content["requester"]) + + if requester.user: + request.authenticated_entity = requester.user.to_string() + + logger.info( + "remote_join: %s into room: %s", + user_id, room_id, + ) + + yield self.federation_handler.do_invite_join( + remote_room_hosts, + room_id, + user_id, + event_content, + ) + + defer.returnValue((200, {})) + + +class ReplicationRemoteRejectInviteRestServlet(RestServlet): + PATTERNS = [re.compile("^/_synapse/replication/remote_reject_invite$")] + + def __init__(self, hs): + super(ReplicationRemoteRejectInviteRestServlet, self).__init__() + + self.federation_handler = hs.get_handlers().federation_handler + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + @defer.inlineCallbacks + def on_POST(self, request): + content = parse_json_object_from_request(request) + + remote_room_hosts = content["remote_room_hosts"] + room_id = content["room_id"] + user_id = content["user_id"] + + requester = Requester.deserialize(self.store, content["requester"]) + + if requester.user: + request.authenticated_entity = requester.user.to_string() + + logger.info( + "remote_reject_invite: %s out of room: %s", + user_id, room_id, + ) + + try: + event = yield self.federation_handler.do_remotely_reject_invite( + remote_room_hosts, + room_id, + user_id, + ) + ret = event.get_pdu_json() + except Exception as e: + # if we were unable to reject the exception, just mark + # it as rejected on our end and plough ahead. + # + # The 'except' clause is very broad, but we need to + # capture everything from DNS failures upwards + # + logger.warn("Failed to reject invite: %s", e) + + yield self.store.locally_reject_invite( + user_id, room_id + ) + ret = {} + + defer.returnValue((200, ret)) + + +class ReplicationRegister3PIDGuestRestServlet(RestServlet): + PATTERNS = [re.compile("^/_synapse/replication/get_or_register_3pid_guest$")] + + def __init__(self, hs): + super(ReplicationRegister3PIDGuestRestServlet, self).__init__() + + self.registeration_handler = hs.get_handlers().registration_handler + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + @defer.inlineCallbacks + def on_POST(self, request): + content = parse_json_object_from_request(request) + + medium = content["medium"] + address = content["address"] + inviter_user_id = content["inviter_user_id"] + + requester = Requester.deserialize(self.store, content["requester"]) + + if requester.user: + request.authenticated_entity = requester.user.to_string() + + logger.info("get_or_register_3pid_guest: %r", content) + + ret = yield self.registeration_handler.get_or_register_3pid_guest( + medium, address, inviter_user_id, + ) + + defer.returnValue((200, ret)) + + +class ReplicationUserJoinedLeftRoomRestServlet(RestServlet): + PATTERNS = [re.compile("^/_synapse/replication/user_(?P<change>joined|left)_room$")] + + def __init__(self, hs): + super(ReplicationUserJoinedLeftRoomRestServlet, self).__init__() + + self.registeration_handler = hs.get_handlers().registration_handler + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self.distributor = hs.get_distributor() + + def on_POST(self, request, change): + content = parse_json_object_from_request(request) + + user_id = content["user_id"] + room_id = content["room_id"] + + logger.info("user membership change: %s in %s", user_id, room_id) + + user = UserID.from_string(user_id) + + if change == "joined": + user_joined_room(self.distributor, user, room_id) + elif change == "left": + user_left_room(self.distributor, user, room_id) + else: + raise Exception("Unrecognized change: %r", change) + + return (200, {}) + + +def register_servlets(hs, http_server): + ReplicationRemoteJoinRestServlet(hs).register(http_server) + ReplicationRemoteRejectInviteRestServlet(hs).register(http_server) + ReplicationRegister3PIDGuestRestServlet(hs).register(http_server) + ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py new file mode 100644 index 0000000000..5227bc333d --- /dev/null +++ b/synapse/replication/http/send_event.py @@ -0,0 +1,165 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re + +from twisted.internet import defer + +from synapse.api.errors import ( + CodeMessageException, + MatrixCodeMessageException, + SynapseError, +) +from synapse.events import FrozenEvent +from synapse.events.snapshot import EventContext +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.types import Requester, UserID +from synapse.util.caches.response_cache import ResponseCache +from synapse.util.metrics import Measure + +logger = logging.getLogger(__name__) + + +@defer.inlineCallbacks +def send_event_to_master(clock, store, client, host, port, requester, event, context, + ratelimit, extra_users): + """Send event to be handled on the master + + Args: + clock (synapse.util.Clock) + store (DataStore) + client (SimpleHttpClient) + host (str): host of master + port (int): port on master listening for HTTP replication + requester (Requester) + event (FrozenEvent) + context (EventContext) + ratelimit (bool) + extra_users (list(UserID)): Any extra users to notify about event + """ + uri = "http://%s:%s/_synapse/replication/send_event/%s" % ( + host, port, event.event_id, + ) + + serialized_context = yield context.serialize(event, store) + + payload = { + "event": event.get_pdu_json(), + "internal_metadata": event.internal_metadata.get_dict(), + "rejected_reason": event.rejected_reason, + "context": serialized_context, + "requester": requester.serialize(), + "ratelimit": ratelimit, + "extra_users": [u.to_string() for u in extra_users], + } + + try: + # We keep retrying the same request for timeouts. This is so that we + # have a good idea that the request has either succeeded or failed on + # the master, and so whether we should clean up or not. + while True: + try: + result = yield client.put_json(uri, payload) + break + except CodeMessageException as e: + if e.code != 504: + raise + + logger.warn("send_event request timed out") + + # If we timed out we probably don't need to worry about backing + # off too much, but lets just wait a little anyway. + yield clock.sleep(1) + except MatrixCodeMessageException as e: + # We convert to SynapseError as we know that it was a SynapseError + # on the master process that we should send to the client. (And + # importantly, not stack traces everywhere) + raise SynapseError(e.code, e.msg, e.errcode) + defer.returnValue(result) + + +class ReplicationSendEventRestServlet(RestServlet): + """Handles events newly created on workers, including persisting and + notifying. + + The API looks like: + + POST /_synapse/replication/send_event/:event_id + + { + "event": { .. serialized event .. }, + "internal_metadata": { .. serialized internal_metadata .. }, + "rejected_reason": .., // The event.rejected_reason field + "context": { .. serialized event context .. }, + "requester": { .. serialized requester .. }, + "ratelimit": true, + "extra_users": [], + } + """ + PATTERNS = [re.compile("^/_synapse/replication/send_event/(?P<event_id>[^/]+)$")] + + def __init__(self, hs): + super(ReplicationSendEventRestServlet, self).__init__() + + self.event_creation_handler = hs.get_event_creation_handler() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + # The responses are tiny, so we may as well cache them for a while + self.response_cache = ResponseCache(hs, "send_event", timeout_ms=30 * 60 * 1000) + + def on_PUT(self, request, event_id): + return self.response_cache.wrap( + event_id, + self._handle_request, + request + ) + + @defer.inlineCallbacks + def _handle_request(self, request): + with Measure(self.clock, "repl_send_event_parse"): + content = parse_json_object_from_request(request) + + event_dict = content["event"] + internal_metadata = content["internal_metadata"] + rejected_reason = content["rejected_reason"] + event = FrozenEvent(event_dict, internal_metadata, rejected_reason) + + requester = Requester.deserialize(self.store, content["requester"]) + context = yield EventContext.deserialize(self.store, content["context"]) + + ratelimit = content["ratelimit"] + extra_users = [UserID.from_string(u) for u in content["extra_users"]] + + if requester.user: + request.authenticated_entity = requester.user.to_string() + + logger.info( + "Got event to send with ID: %s into room: %s", + event.event_id, event.room_id, + ) + + yield self.event_creation_handler.persist_and_notify_client_event( + requester, event, context, + ratelimit=ratelimit, + extra_users=extra_users, + ) + + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + ReplicationSendEventRestServlet(hs).register(http_server) diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index b962641166..3f7be74e02 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -13,19 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from synapse.storage._base import SQLBaseStore from synapse.storage.engines import PostgresEngine from ._slaved_id_tracker import SlavedIdTracker -import logging - logger = logging.getLogger(__name__) class BaseSlavedStore(SQLBaseStore): def __init__(self, db_conn, hs): - super(BaseSlavedStore, self).__init__(hs) + super(BaseSlavedStore, self).__init__(db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = SlavedIdTracker( db_conn, "cache_invalidation_stream", "stream_id", diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index efbd87918e..d9ba6d69b1 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,50 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker -from synapse.storage import DataStore -from synapse.storage.account_data import AccountDataStore -from synapse.storage.tags import TagsStore -from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.storage.account_data import AccountDataWorkerStore +from synapse.storage.tags import TagsWorkerStore -class SlavedAccountDataStore(BaseSlavedStore): +class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): - super(SlavedAccountDataStore, self).__init__(db_conn, hs) self._account_data_id_gen = SlavedIdTracker( db_conn, "account_data_max_stream_id", "stream_id", ) - self._account_data_stream_cache = StreamChangeCache( - "AccountDataAndTagsChangeCache", - self._account_data_id_gen.get_current_token(), - ) - - get_account_data_for_user = ( - AccountDataStore.__dict__["get_account_data_for_user"] - ) - - get_global_account_data_by_type_for_users = ( - AccountDataStore.__dict__["get_global_account_data_by_type_for_users"] - ) - get_global_account_data_by_type_for_user = ( - AccountDataStore.__dict__["get_global_account_data_by_type_for_user"] - ) - - get_tags_for_user = TagsStore.__dict__["get_tags_for_user"] - get_tags_for_room = ( - DataStore.get_tags_for_room.__func__ - ) - get_account_data_for_room = ( - DataStore.get_account_data_for_room.__func__ - ) - - get_updated_tags = DataStore.get_updated_tags.__func__ - get_updated_account_data_for_user = ( - DataStore.get_updated_account_data_for_user.__func__ - ) + super(SlavedAccountDataStore, self).__init__(db_conn, hs) def get_max_account_data_stream_id(self): return self._account_data_id_gen.get_current_token() @@ -85,6 +56,10 @@ class SlavedAccountDataStore(BaseSlavedStore): (row.data_type, row.user_id,) ) self.get_account_data_for_user.invalidate((row.user_id,)) + self.get_account_data_for_room.invalidate((row.user_id, row.room_id,)) + self.get_account_data_for_room_and_type.invalidate( + (row.user_id, row.room_id, row.data_type,), + ) self._account_data_stream_cache.entity_has_changed( row.user_id, token ) diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py index 0d3f31a50c..b53a4c6bd1 100644 --- a/synapse/replication/slave/storage/appservice.py +++ b/synapse/replication/slave/storage/appservice.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,33 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore -from synapse.storage import DataStore -from synapse.config.appservice import load_appservices -from synapse.storage.appservice import _make_exclusive_regex +from synapse.storage.appservice import ( + ApplicationServiceTransactionWorkerStore, + ApplicationServiceWorkerStore, +) -class SlavedApplicationServiceStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedApplicationServiceStore, self).__init__(db_conn, hs) - self.services_cache = load_appservices( - hs.config.server_name, - hs.config.app_service_config_files - ) - self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) - - get_app_service_by_token = DataStore.get_app_service_by_token.__func__ - get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__ - get_app_services = DataStore.get_app_services.__func__ - get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__ - create_appservice_txn = DataStore.create_appservice_txn.__func__ - get_appservices_by_state = DataStore.get_appservices_by_state.__func__ - get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__ - _get_last_txn = DataStore._get_last_txn.__func__ - complete_appservice_txn = DataStore.complete_appservice_txn.__func__ - get_appservice_state = DataStore.get_appservice_state.__func__ - set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__ - set_appservice_state = DataStore.set_appservice_state.__func__ - get_if_app_services_interested_in_user = ( - DataStore.get_if_app_services_interested_in_user.__func__ - ) +class SlavedApplicationServiceStore(ApplicationServiceTransactionWorkerStore, + ApplicationServiceWorkerStore): + pass diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 65250285e8..60641f1a49 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -13,11 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore from synapse.storage.client_ips import LAST_SEEN_GRANULARITY from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache +from ._base import BaseSlavedStore + class SlavedClientIpStore(BaseSlavedStore): def __init__(self, db_conn, hs): @@ -29,9 +30,8 @@ class SlavedClientIpStore(BaseSlavedStore): max_entries=50000 * CACHE_SIZE_FACTOR, ) - def insert_client_ip(self, user, access_token, ip, user_agent, device_id): + def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) - user_id = user.to_string() key = (user_id, access_token, ip) try: diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 6f3fb64770..87eaa53004 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -13,11 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker from synapse.storage import DataStore -from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.caches.stream_change_cache import StreamChangeCache + +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker class SlavedDeviceInboxStore(BaseSlavedStore): diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 7687867aee..8206a988f7 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker from synapse.storage import DataStore from synapse.storage.end_to_end_keys import EndToEndKeyStore from synapse.util.caches.stream_change_cache import StreamChangeCache +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker + class SlavedDeviceStore(BaseSlavedStore): def __init__(self, db_conn, hs): diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py index 7301d885f2..1d1d48709a 100644 --- a/synapse/replication/slave/storage/directory.py +++ b/synapse/replication/slave/storage/directory.py @@ -13,11 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.storage.directory import DirectoryWorkerStore + from ._base import BaseSlavedStore -from synapse.storage.directory import DirectoryStore -class DirectoryStore(BaseSlavedStore): - get_aliases_for_room = DirectoryStore.__dict__[ - "get_aliases_for_room" - ] +class DirectoryStore(DirectoryWorkerStore, BaseSlavedStore): + pass diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 94ebbffc1b..bdb5eee4af 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,20 +13,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker +import logging from synapse.api.constants import EventTypes -from synapse.storage import DataStore -from synapse.storage.roommember import RoomMemberStore -from synapse.storage.event_federation import EventFederationStore -from synapse.storage.event_push_actions import EventPushActionsStore -from synapse.storage.state import StateStore -from synapse.storage.stream import StreamStore -from synapse.util.caches.stream_change_cache import StreamChangeCache - -import logging +from synapse.storage.event_federation import EventFederationWorkerStore +from synapse.storage.event_push_actions import EventPushActionsWorkerStore +from synapse.storage.events_worker import EventsWorkerStore +from synapse.storage.roommember import RoomMemberWorkerStore +from synapse.storage.signatures import SignatureWorkerStore +from synapse.storage.state import StateGroupWorkerStore +from synapse.storage.stream import StreamWorkerStore +from synapse.storage.user_erasure_store import UserErasureWorkerStore +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker logger = logging.getLogger(__name__) @@ -39,163 +40,34 @@ logger = logging.getLogger(__name__) # the method descriptor on the DataStore and chuck them into our class. -class SlavedEventStore(BaseSlavedStore): +class SlavedEventStore(EventFederationWorkerStore, + RoomMemberWorkerStore, + EventPushActionsWorkerStore, + StreamWorkerStore, + EventsWorkerStore, + StateGroupWorkerStore, + SignatureWorkerStore, + UserErasureWorkerStore, + BaseSlavedStore): def __init__(self, db_conn, hs): - super(SlavedEventStore, self).__init__(db_conn, hs) self._stream_id_gen = SlavedIdTracker( db_conn, "events", "stream_ordering", ) self._backfill_id_gen = SlavedIdTracker( db_conn, "events", "stream_ordering", step=-1 ) - events_max = self._stream_id_gen.get_current_token() - event_cache_prefill, min_event_val = self._get_cache_dict( - db_conn, "events", - entity_column="room_id", - stream_column="stream_ordering", - max_value=events_max, - ) - self._events_stream_cache = StreamChangeCache( - "EventsRoomStreamChangeCache", min_event_val, - prefilled_cache=event_cache_prefill, - ) - self._membership_stream_cache = StreamChangeCache( - "MembershipStreamChangeCache", events_max, - ) - self.stream_ordering_month_ago = 0 - self._stream_order_on_start = self.get_room_max_stream_ordering() + super(SlavedEventStore, self).__init__(db_conn, hs) # Cached functions can't be accessed through a class instance so we need # to reach inside the __dict__ to extract them. - get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"] - get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"] - get_hosts_in_room = RoomMemberStore.__dict__["get_hosts_in_room"] - get_users_who_share_room_with_user = ( - RoomMemberStore.__dict__["get_users_who_share_room_with_user"] - ) - get_latest_event_ids_in_room = EventFederationStore.__dict__[ - "get_latest_event_ids_in_room" - ] - get_invited_rooms_for_user = RoomMemberStore.__dict__[ - "get_invited_rooms_for_user" - ] - get_unread_event_push_actions_by_room_for_user = ( - EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"] - ) - _get_unread_counts_by_receipt_txn = ( - DataStore._get_unread_counts_by_receipt_txn.__func__ - ) - _get_unread_counts_by_pos_txn = ( - DataStore._get_unread_counts_by_pos_txn.__func__ - ) - _get_state_group_for_events = ( - StateStore.__dict__["_get_state_group_for_events"] - ) - _get_state_group_for_event = ( - StateStore.__dict__["_get_state_group_for_event"] - ) - _get_state_groups_from_groups = ( - StateStore.__dict__["_get_state_groups_from_groups"] - ) - _get_state_groups_from_groups_txn = ( - DataStore._get_state_groups_from_groups_txn.__func__ - ) - get_recent_event_ids_for_room = ( - StreamStore.__dict__["get_recent_event_ids_for_room"] - ) - get_current_state_ids = ( - StateStore.__dict__["get_current_state_ids"] - ) - get_state_group_delta = StateStore.__dict__["get_state_group_delta"] - _get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"] - has_room_changed_since = DataStore.has_room_changed_since.__func__ - - get_unread_push_actions_for_user_in_range_for_http = ( - DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__ - ) - get_unread_push_actions_for_user_in_range_for_email = ( - DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__ - ) - get_push_action_users_in_range = ( - DataStore.get_push_action_users_in_range.__func__ - ) - get_event = DataStore.get_event.__func__ - get_events = DataStore.get_events.__func__ - get_rooms_for_user_where_membership_is = ( - DataStore.get_rooms_for_user_where_membership_is.__func__ - ) - get_membership_changes_for_user = ( - DataStore.get_membership_changes_for_user.__func__ - ) - get_room_events_max_id = DataStore.get_room_events_max_id.__func__ - get_room_events_stream_for_room = ( - DataStore.get_room_events_stream_for_room.__func__ - ) - get_events_around = DataStore.get_events_around.__func__ - get_state_for_event = DataStore.get_state_for_event.__func__ - get_state_for_events = DataStore.get_state_for_events.__func__ - get_state_groups = DataStore.get_state_groups.__func__ - get_state_groups_ids = DataStore.get_state_groups_ids.__func__ - get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__ - get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__ - get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__ - get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__ - _get_joined_users_from_context = ( - RoomMemberStore.__dict__["_get_joined_users_from_context"] - ) - - get_joined_hosts = DataStore.get_joined_hosts.__func__ - _get_joined_hosts = RoomMemberStore.__dict__["_get_joined_hosts"] - - get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__ - get_room_events_stream_for_rooms = ( - DataStore.get_room_events_stream_for_rooms.__func__ - ) - is_host_joined = RoomMemberStore.__dict__["is_host_joined"] - get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__ - - _set_before_and_after = staticmethod(DataStore._set_before_and_after) - - _get_events = DataStore._get_events.__func__ - _get_events_from_cache = DataStore._get_events_from_cache.__func__ - - _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__ - _enqueue_events = DataStore._enqueue_events.__func__ - _do_fetch = DataStore._do_fetch.__func__ - _fetch_event_rows = DataStore._fetch_event_rows.__func__ - _get_event_from_row = DataStore._get_event_from_row.__func__ - _get_rooms_for_user_where_membership_is_txn = ( - DataStore._get_rooms_for_user_where_membership_is_txn.__func__ - ) - _get_state_for_groups = DataStore._get_state_for_groups.__func__ - _get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__ - _get_events_around_txn = DataStore._get_events_around_txn.__func__ - _get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__ - - get_backfill_events = DataStore.get_backfill_events.__func__ - _get_backfill_events = DataStore._get_backfill_events.__func__ - get_missing_events = DataStore.get_missing_events.__func__ - _get_missing_events = DataStore._get_missing_events.__func__ - - get_auth_chain = DataStore.get_auth_chain.__func__ - get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__ - _get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__ - - get_room_max_stream_ordering = DataStore.get_room_max_stream_ordering.__func__ - - get_forward_extremeties_for_room = ( - DataStore.get_forward_extremeties_for_room.__func__ - ) - _get_forward_extremeties_for_room = ( - EventFederationStore.__dict__["_get_forward_extremeties_for_room"] - ) - - get_all_new_events_stream = DataStore.get_all_new_events_stream.__func__ - - get_federation_out_pos = DataStore.get_federation_out_pos.__func__ - update_federation_out_pos = DataStore.update_federation_out_pos.__func__ + + def get_room_max_stream_ordering(self): + return self._stream_id_gen.get_current_token() + + def get_room_min_stream_ordering(self): + return self._backfill_id_gen.get_current_token() def stream_positions(self): result = super(SlavedEventStore, self).stream_positions() diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py index 819ed62881..456a14cd5c 100644 --- a/synapse/replication/slave/storage/filtering.py +++ b/synapse/replication/slave/storage/filtering.py @@ -13,9 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore from synapse.storage.filtering import FilteringStore +from ._base import BaseSlavedStore + class SlavedFilteringStore(BaseSlavedStore): def __init__(self, db_conn, hs): diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py new file mode 100644 index 0000000000..5777f07c8d --- /dev/null +++ b/synapse/replication/slave/storage/groups.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage import DataStore +from synapse.util.caches.stream_change_cache import StreamChangeCache + +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker + + +class SlavedGroupServerStore(BaseSlavedStore): + def __init__(self, db_conn, hs): + super(SlavedGroupServerStore, self).__init__(db_conn, hs) + + self.hs = hs + + self._group_updates_id_gen = SlavedIdTracker( + db_conn, "local_group_updates", "stream_id", + ) + self._group_updates_stream_cache = StreamChangeCache( + "_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(), + ) + + get_groups_changes_for_user = DataStore.get_groups_changes_for_user.__func__ + get_group_stream_token = DataStore.get_group_stream_token.__func__ + get_all_groups_for_user = DataStore.get_all_groups_for_user.__func__ + + def stream_positions(self): + result = super(SlavedGroupServerStore, self).stream_positions() + result["groups"] = self._group_updates_id_gen.get_current_token() + return result + + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "groups": + self._group_updates_id_gen.advance(token) + for row in rows: + self._group_updates_stream_cache.entity_has_changed( + row.user_id, token + ) + + return super(SlavedGroupServerStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py index dd2ae49e48..05ed168463 100644 --- a/synapse/replication/slave/storage/keys.py +++ b/synapse/replication/slave/storage/keys.py @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore from synapse.storage import DataStore from synapse.storage.keys import KeyStore +from ._base import BaseSlavedStore + class SlavedKeyStore(BaseSlavedStore): _get_server_verify_key = KeyStore.__dict__[ diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index cfb9280181..80b744082a 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -13,12 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker - -from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.storage import DataStore from synapse.storage.presence import PresenceStore +from synapse.util.caches.stream_change_cache import StreamChangeCache + +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker class SlavedPresenceStore(BaseSlavedStore): diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py new file mode 100644 index 0000000000..46c28d4171 --- /dev/null +++ b/synapse/replication/slave/storage/profile.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.storage.profile import ProfileWorkerStore + + +class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore): + pass diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 83e880fdd2..f0200c1e98 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,31 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .events import SlavedEventStore +from synapse.storage.push_rule import PushRulesWorkerStore + from ._slaved_id_tracker import SlavedIdTracker -from synapse.storage import DataStore -from synapse.storage.push_rule import PushRuleStore -from synapse.util.caches.stream_change_cache import StreamChangeCache +from .events import SlavedEventStore -class SlavedPushRuleStore(SlavedEventStore): +class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore): def __init__(self, db_conn, hs): - super(SlavedPushRuleStore, self).__init__(db_conn, hs) self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id", ) - self.push_rules_stream_cache = StreamChangeCache( - "PushRulesStreamChangeCache", - self._push_rules_stream_id_gen.get_current_token(), - ) - - get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"] - get_push_rules_enabled_for_user = ( - PushRuleStore.__dict__["get_push_rules_enabled_for_user"] - ) - have_push_rules_changed_for_user = ( - DataStore.have_push_rules_changed_for_user.__func__ - ) + super(SlavedPushRuleStore, self).__init__(db_conn, hs) def get_push_rules_stream_token(self): return ( @@ -45,6 +33,9 @@ class SlavedPushRuleStore(SlavedEventStore): self._stream_id_gen.get_current_token(), ) + def get_max_push_rules_stream_id(self): + return self._push_rules_stream_id_gen.get_current_token() + def stream_positions(self): result = super(SlavedPushRuleStore, self).stream_positions() result["push_rules"] = self._push_rules_stream_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index 4e8d68ece9..3b2213c0d4 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.storage.pusher import PusherWorkerStore + from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker -from synapse.storage import DataStore - -class SlavedPusherStore(BaseSlavedStore): +class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): super(SlavedPusherStore, self).__init__(db_conn, hs) @@ -28,13 +29,6 @@ class SlavedPusherStore(BaseSlavedStore): extra_tables=[("deleted_pushers", "stream_id")], ) - get_all_pushers = DataStore.get_all_pushers.__func__ - get_pushers_by = DataStore.get_pushers_by.__func__ - get_pushers_by_app_id_and_pushkey = ( - DataStore.get_pushers_by_app_id_and_pushkey.__func__ - ) - _decode_pushers_rows = DataStore._decode_pushers_rows.__func__ - def stream_positions(self): result = super(SlavedPusherStore, self).stream_positions() result["pushers"] = self._pushers_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index b371574ece..ed12342f40 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,13 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.storage.receipts import ReceiptsWorkerStore + from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker -from synapse.storage import DataStore -from synapse.storage.receipts import ReceiptsStore -from synapse.util.caches.stream_change_cache import StreamChangeCache - # So, um, we want to borrow a load of functions intended for reading from # a DataStore, but we don't want to take functions that either write to the # DataStore or are cached and don't have cache invalidation logic. @@ -29,36 +28,19 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache # the method descriptor on the DataStore and chuck them into our class. -class SlavedReceiptsStore(BaseSlavedStore): +class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): - super(SlavedReceiptsStore, self).__init__(db_conn, hs) - + # We instantiate this first as the ReceiptsWorkerStore constructor + # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = SlavedIdTracker( db_conn, "receipts_linearized", "stream_id" ) - self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() - ) - - get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"] - get_linearized_receipts_for_room = ( - ReceiptsStore.__dict__["get_linearized_receipts_for_room"] - ) - _get_linearized_receipts_for_rooms = ( - ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"] - ) - get_last_receipt_event_id_for_user = ( - ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"] - ) - - get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__ - get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__ + super(SlavedReceiptsStore, self).__init__(db_conn, hs) - get_linearized_receipts_for_rooms = ( - DataStore.get_linearized_receipts_for_rooms.__func__ - ) + def get_max_receipt_stream_id(self): + return self._receipts_id_gen.get_current_token() def stream_positions(self): result = super(SlavedReceiptsStore, self).stream_positions() @@ -67,10 +49,12 @@ class SlavedReceiptsStore(BaseSlavedStore): def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): self.get_receipts_for_user.invalidate((user_id, receipt_type)) - self.get_linearized_receipts_for_room.invalidate_many((room_id,)) + self._get_linearized_receipts_for_room.invalidate_many((room_id,)) self.get_last_receipt_event_id_for_user.invalidate( (user_id, room_id, receipt_type) ) + self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) + self.get_receipts_for_room.invalidate((room_id, receipt_type)) def process_replication_rows(self, stream_name, token, rows): if stream_name == "receipts": diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py index e27c7332d2..408d91df1c 100644 --- a/synapse/replication/slave/storage/registration.py +++ b/synapse/replication/slave/storage/registration.py @@ -13,21 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore -from synapse.storage import DataStore -from synapse.storage.registration import RegistrationStore - +from synapse.storage.registration import RegistrationWorkerStore -class SlavedRegistrationStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedRegistrationStore, self).__init__(db_conn, hs) +from ._base import BaseSlavedStore - # TODO: use the cached version and invalidate deleted tokens - get_user_by_access_token = RegistrationStore.__dict__[ - "get_user_by_access_token" - ] - _query_for_auth = DataStore._query_for_auth.__func__ - get_user_by_id = RegistrationStore.__dict__[ - "get_user_by_id" - ] +class SlavedRegistrationStore(RegistrationWorkerStore, BaseSlavedStore): + pass diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index f510384033..0cb474928c 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -13,33 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.storage.room import RoomWorkerStore + from ._base import BaseSlavedStore -from synapse.storage import DataStore -from synapse.storage.room import RoomStore from ._slaved_id_tracker import SlavedIdTracker -class RoomStore(BaseSlavedStore): +class RoomStore(RoomWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): super(RoomStore, self).__init__(db_conn, hs) self._public_room_id_gen = SlavedIdTracker( db_conn, "public_room_list_stream", "stream_id" ) - get_public_room_ids = DataStore.get_public_room_ids.__func__ - get_current_public_room_stream_id = ( - DataStore.get_current_public_room_stream_id.__func__ - ) - get_public_room_ids_at_stream_id = ( - RoomStore.__dict__["get_public_room_ids_at_stream_id"] - ) - get_public_room_ids_at_stream_id_txn = ( - DataStore.get_public_room_ids_at_stream_id_txn.__func__ - ) - get_published_at_stream_id_txn = ( - DataStore.get_published_at_stream_id_txn.__func__ - ) - get_public_room_changes = DataStore.get_public_room_changes.__func__ + def get_current_public_room_stream_id(self): + return self._public_room_id_gen.get_current_token() def stream_positions(self): result = super(RoomStore, self).stream_positions() diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py index fbb58f35da..9c9a5eadd9 100644 --- a/synapse/replication/slave/storage/transactions.py +++ b/synapse/replication/slave/storage/transactions.py @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore from synapse.storage import DataStore from synapse.storage.transactions import TransactionStore +from ._base import BaseSlavedStore + class TransactionStore(BaseSlavedStore): get_destination_retry_timings = TransactionStore.__dict__[ diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 6d2513c4e2..e592ab57bf 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -15,17 +15,20 @@ """A replication client for use by synapse workers. """ -from twisted.internet import reactor, defer +import logging + +from twisted.internet import defer from twisted.internet.protocol import ReconnectingClientFactory from .commands import ( - FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand, + FederationAckCommand, + InvalidateCacheCommand, + RemovePusherCommand, UserIpCommand, + UserSyncCommand, ) from .protocol import ClientReplicationStreamProtocol -import logging - logger = logging.getLogger(__name__) @@ -44,7 +47,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): self.server_name = hs.config.server_name self._clock = hs.get_clock() # As self.clock is defined in super class - reactor.addSystemEventTrigger("before", "shutdown", self.stopTrying) + hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying) def startedConnecting(self, connector): logger.info("Connecting to replication: %r", connector.getDestination()) @@ -95,7 +98,7 @@ class ReplicationClientHandler(object): factory = ReplicationClientFactory(hs, client_name, self) host = hs.config.worker_replication_host port = hs.config.worker_replication_port - reactor.connectTCP(host, port, factory) + hs.get_reactor().connectTCP(host, port, factory) def on_rdata(self, stream_name, token, rows): """Called when we get new replication data. By default this just pokes diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index a009214e43..f3908df642 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -19,8 +19,14 @@ allowed to be sent by which side. """ import logging -import ujson as json +import platform +if platform.python_implementation() == "PyPy": + import json + _json_encoder = json.JSONEncoder() +else: + import simplejson as json + _json_encoder = json.JSONEncoder(namedtuple_as_object=False) logger = logging.getLogger(__name__) @@ -107,7 +113,7 @@ class RdataCommand(Command): return " ".join(( self.stream_name, str(self.token) if self.token is not None else "batch", - json.dumps(self.row), + _json_encoder.encode(self.row), )) @@ -301,7 +307,9 @@ class InvalidateCacheCommand(Command): return cls(cache_func, json.loads(keys_json)) def to_line(self): - return " ".join((self.cache_func, json.dumps(self.keys))) + return " ".join(( + self.cache_func, _json_encoder.encode(self.keys), + )) class UserIpCommand(Command): @@ -323,14 +331,18 @@ class UserIpCommand(Command): @classmethod def from_line(cls, line): - user_id, access_token, ip, device_id, last_seen, user_agent = line.split(" ", 5) + user_id, jsn = line.split(" ", 1) + + access_token, ip, user_agent, device_id, last_seen = json.loads(jsn) - return cls(user_id, access_token, ip, user_agent, device_id, int(last_seen)) + return cls( + user_id, access_token, ip, user_agent, device_id, last_seen + ) def to_line(self): - return " ".join(( - self.user_id, self.access_token, self.ip, self.device_id, - str(self.last_seen), self.user_agent, + return self.user_id + " " + _json_encoder.encode(( + self.access_token, self.ip, self.user_agent, self.device_id, + self.last_seen, )) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 062272f8dd..dec5ac0913 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -49,32 +49,40 @@ indicate which side is sending, these are *not* included on the wire:: * connection closed by server * """ +import fcntl +import logging +import struct +from collections import defaultdict + +from six import iteritems, iterkeys + +from prometheus_client import Counter + from twisted.internet import defer from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure -from commands import ( - COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS, - ErrorCommand, ServerCommand, RdataCommand, PositionCommand, PingCommand, - NameCommand, ReplicateCommand, UserSyncCommand, SyncCommand, -) -from streams import STREAMS_MAP - +from synapse.metrics import LaterGauge from synapse.util.stringutils import random_string -from synapse.metrics.metric import CounterMetric - -import logging -import synapse.metrics -import struct -import fcntl - -metrics = synapse.metrics.get_metrics_for(__name__) - -connection_close_counter = metrics.register_counter( - "close_reason", labels=["reason_type"], +from .commands import ( + COMMAND_MAP, + VALID_CLIENT_COMMANDS, + VALID_SERVER_COMMANDS, + ErrorCommand, + NameCommand, + PingCommand, + PositionCommand, + RdataCommand, + ReplicateCommand, + ServerCommand, + SyncCommand, + UserSyncCommand, ) +from .streams import STREAMS_MAP +connection_close_counter = Counter( + "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]) # A list of all connected protocols. This allows us to send metrics about the # connections. @@ -136,12 +144,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # The LoopingCall for sending pings. self._send_ping_loop = None - self.inbound_commands_counter = CounterMetric( - "inbound_commands", labels=["command"], - ) - self.outbound_commands_counter = CounterMetric( - "outbound_commands", labels=["command"], - ) + self.inbound_commands_counter = defaultdict(int) + self.outbound_commands_counter = defaultdict(int) def connectionMade(self): logger.info("[%s] Connection established", self.id()) @@ -201,7 +205,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.last_received_command = self.clock.time_msec() - self.inbound_commands_counter.inc(cmd_name) + self.inbound_commands_counter[cmd_name] = ( + self.inbound_commands_counter[cmd_name] + 1) cmd_cls = COMMAND_MAP[cmd_name] try: @@ -244,15 +249,15 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): becoming full. """ if self.state == ConnectionStates.CLOSED: - logger.info("[%s] Not sending, connection closed", self.id()) + logger.debug("[%s] Not sending, connection closed", self.id()) return if do_buffer and self.state != ConnectionStates.ESTABLISHED: self._queue_command(cmd) return - self.outbound_commands_counter.inc(cmd.NAME) - + self.outbound_commands_counter[cmd.NAME] = ( + self.outbound_commands_counter[cmd.NAME] + 1) string = "%s %s" % (cmd.NAME, cmd.to_line(),) if "\n" in string: raise Exception("Unexpected newline in command: %r", string) @@ -264,7 +269,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def _queue_command(self, cmd): """Queue the command until the connection is ready to write to again. """ - logger.info("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd) + logger.debug("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd) self.pending_commands.append(cmd) if len(self.pending_commands) > self.max_line_buffer: @@ -317,9 +322,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def connectionLost(self, reason): logger.info("[%s] Replication connection closed: %r", self.id(), reason) if isinstance(reason, Failure): - connection_close_counter.inc(reason.type.__name__) + connection_close_counter.labels(reason.type.__name__).inc() else: - connection_close_counter.inc(reason.__class__.__name__) + connection_close_counter.labels(reason.__class__.__name__).inc() try: # Remove us from list of connections to be monitored @@ -392,7 +397,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): if stream_name == "ALL": # Subscribe to all streams we're publishing to. - for stream in self.streamer.streams_by_name.iterkeys(): + for stream in iterkeys(self.streamer.streams_by_name): self.subscribe_to_stream(stream, token) else: self.subscribe_to_stream(stream_name, token) @@ -498,7 +503,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): BaseReplicationStreamProtocol.connectionMade(self) # Once we've connected subscribe to the necessary streams - for stream_name, token in self.handler.get_streams_to_replicate().iteritems(): + for stream_name, token in iteritems(self.handler.get_streams_to_replicate()): self.replicate(stream_name, token) # Tell the server if we have any users currently syncing (should only @@ -517,25 +522,28 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.send_error("Wrong remote") def on_RDATA(self, cmd): + stream_name = cmd.stream_name + inbound_rdata_count.labels(stream_name).inc() + try: - row = STREAMS_MAP[cmd.stream_name].ROW_TYPE(*cmd.row) + row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row) except Exception: logger.exception( "[%s] Failed to parse RDATA: %r %r", - self.id(), cmd.stream_name, cmd.row + self.id(), stream_name, cmd.row ) raise if cmd.token is None: # I.e. this is part of a batch of updates for this stream. Batch # until we get an update for the stream with a non None token - self.pending_batches.setdefault(cmd.stream_name, []).append(row) + self.pending_batches.setdefault(stream_name, []).append(row) else: # Check if this is the last of a batch of updates - rows = self.pending_batches.pop(cmd.stream_name, []) + rows = self.pending_batches.pop(stream_name, []) rows.append(row) - self.handler.on_rdata(cmd.stream_name, cmd.token, rows) + self.handler.on_rdata(stream_name, cmd.token, rows) def on_POSITION(self, cmd): self.handler.on_position(cmd.stream_name, cmd.token) @@ -563,13 +571,13 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # The following simply registers metrics for the replication connections -metrics.register_callback( - "pending_commands", +pending_commands = LaterGauge( + "synapse_replication_tcp_protocol_pending_commands", + "", + ["name", "conn_id"], lambda: { - (p.name, p.conn_id): len(p.pending_commands) - for p in connected_connections + (p.name, p.conn_id): len(p.pending_commands) for p in connected_connections }, - labels=["name", "conn_id"], ) @@ -580,13 +588,13 @@ def transport_buffer_size(protocol): return 0 -metrics.register_callback( - "transport_send_buffer", +transport_send_buffer = LaterGauge( + "synapse_replication_tcp_protocol_transport_send_buffer", + "", + ["name", "conn_id"], lambda: { - (p.name, p.conn_id): transport_buffer_size(p) - for p in connected_connections + (p.name, p.conn_id): transport_buffer_size(p) for p in connected_connections }, - labels=["name", "conn_id"], ) @@ -605,42 +613,51 @@ def transport_kernel_read_buffer_size(protocol, read=True): return 0 -metrics.register_callback( - "transport_kernel_send_buffer", +tcp_transport_kernel_send_buffer = LaterGauge( + "synapse_replication_tcp_protocol_transport_kernel_send_buffer", + "", + ["name", "conn_id"], lambda: { (p.name, p.conn_id): transport_kernel_read_buffer_size(p, False) for p in connected_connections }, - labels=["name", "conn_id"], ) -metrics.register_callback( - "transport_kernel_read_buffer", +tcp_transport_kernel_read_buffer = LaterGauge( + "synapse_replication_tcp_protocol_transport_kernel_read_buffer", + "", + ["name", "conn_id"], lambda: { (p.name, p.conn_id): transport_kernel_read_buffer_size(p, True) for p in connected_connections }, - labels=["name", "conn_id"], ) -metrics.register_callback( - "inbound_commands", +tcp_inbound_commands = LaterGauge( + "synapse_replication_tcp_protocol_inbound_commands", + "", + ["command", "name", "conn_id"], lambda: { (k[0], p.name, p.conn_id): count for p in connected_connections - for k, count in p.inbound_commands_counter.counts.iteritems() + for k, count in iteritems(p.inbound_commands_counter) }, - labels=["command", "name", "conn_id"], ) -metrics.register_callback( - "outbound_commands", +tcp_outbound_commands = LaterGauge( + "synapse_replication_tcp_protocol_outbound_commands", + "", + ["command", "name", "conn_id"], lambda: { (k[0], p.name, p.conn_id): count for p in connected_connections - for k, count in p.outbound_commands_counter.counts.iteritems() + for k, count in iteritems(p.outbound_commands_counter) }, - labels=["command", "name", "conn_id"], +) + +# number of updates received for each RDATA stream +inbound_rdata_count = Counter( + "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] ) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 3ea3ca5a6f..611fb66e1d 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -15,27 +15,29 @@ """The server side of the replication stream. """ -from twisted.internet import defer, reactor -from twisted.internet.protocol import Factory +import logging -from streams import STREAMS_MAP, FederationStream -from protocol import ServerReplicationStreamProtocol +from six import itervalues -from synapse.util.metrics import Measure, measure_func +from prometheus_client import Counter -import logging -import synapse.metrics +from twisted.internet import defer +from twisted.internet.protocol import Factory +from synapse.metrics import LaterGauge +from synapse.util.metrics import Measure, measure_func + +from .protocol import ServerReplicationStreamProtocol +from .streams import STREAMS_MAP, FederationStream -metrics = synapse.metrics.get_metrics_for(__name__) -stream_updates_counter = metrics.register_counter( - "stream_updates", labels=["stream_name"] -) -user_sync_counter = metrics.register_counter("user_sync") -federation_ack_counter = metrics.register_counter("federation_ack") -remove_pusher_counter = metrics.register_counter("remove_pusher") -invalidate_cache_counter = metrics.register_counter("invalidate_cache") -user_ip_cache_counter = metrics.register_counter("user_ip_cache") +stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates", + "", ["stream_name"]) +user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") +federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") +remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") +invalidate_cache_counter = Counter("synapse_replication_tcp_resource_invalidate_cache", + "") +user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") logger = logging.getLogger(__name__) @@ -69,33 +71,34 @@ class ReplicationStreamer(object): self.presence_handler = hs.get_presence_handler() self.clock = hs.get_clock() self.notifier = hs.get_notifier() + self._server_notices_sender = hs.get_server_notices_sender() # Current connections. self.connections = [] - metrics.register_callback("total_connections", lambda: len(self.connections)) + LaterGauge("synapse_replication_tcp_resource_total_connections", "", [], + lambda: len(self.connections)) # List of streams that clients can subscribe to. # We only support federation stream if federation sending hase been # disabled on the master. self.streams = [ - stream(hs) for stream in STREAMS_MAP.itervalues() + stream(hs) for stream in itervalues(STREAMS_MAP) if stream != FederationStream or not hs.config.send_federation ] self.streams_by_name = {stream.NAME: stream for stream in self.streams} - metrics.register_callback( - "connections_per_stream", + LaterGauge( + "synapse_replication_tcp_resource_connections_per_stream", "", + ["stream_name"], lambda: { (stream_name,): len([ conn for conn in self.connections if stream_name in conn.replication_streams ]) for stream_name in self.streams_by_name - }, - labels=["stream_name"], - ) + }) self.federation_sender = None if not hs.config.send_federation: @@ -107,7 +110,7 @@ class ReplicationStreamer(object): self.is_looping = False self.pending_updates = False - reactor.addSystemEventTrigger("before", "shutdown", self.on_shutdown) + hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown) def on_shutdown(self): # close all connections on shutdown @@ -160,7 +163,11 @@ class ReplicationStreamer(object): "Getting stream: %s: %s -> %s", stream.NAME, stream.last_token, stream.upto_token ) - updates, current_token = yield stream.get_updates() + try: + updates, current_token = yield stream.get_updates() + except Exception: + logger.info("Failed to handle stream %s", stream.NAME) + raise logger.debug( "Sending %d updates to %d connections", @@ -171,7 +178,7 @@ class ReplicationStreamer(object): logger.info( "Streaming: %s -> %s", stream.NAME, updates[-1][0] ) - stream_updates_counter.inc_by(len(updates), stream.NAME) + stream_updates_counter.labels(stream.NAME).inc(len(updates)) # Some streams return multiple rows with the same stream IDs, # we need to make sure they get sent out in batches. We do @@ -212,11 +219,12 @@ class ReplicationStreamer(object): self.federation_sender.federation_ack(token) @measure_func("repl.on_user_sync") + @defer.inlineCallbacks def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): """A client has started/stopped syncing on a worker. """ user_sync_counter.inc() - self.presence_handler.update_external_syncs_row( + yield self.presence_handler.update_external_syncs_row( conn_id, user_id, is_syncing, last_sync_ms, ) @@ -240,13 +248,15 @@ class ReplicationStreamer(object): getattr(self.store, cache_func).invalidate(tuple(keys)) @measure_func("repl.on_user_ip") + @defer.inlineCallbacks def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen): """The client saw a user request """ user_ip_cache_counter.inc() - self.store.insert_client_ip( + yield self.store.insert_client_ip( user_id, access_token, ip, user_agent, device_id, last_seen, ) + yield self._server_notices_sender.on_user_ip(user_id) def send_sync_to_all_connections(self, data): """Sends a SYNC command to all clients. diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py index fbafe12cc2..55fe701c5c 100644 --- a/synapse/replication/tcp/streams.py +++ b/synapse/replication/tcp/streams.py @@ -24,11 +24,10 @@ Each stream is defined by the following information: update_function: The function that returns a list of updates between two tokens """ -from twisted.internet import defer -from collections import namedtuple - import logging +from collections import namedtuple +from twisted.internet import defer logger = logging.getLogger(__name__) @@ -118,6 +117,12 @@ CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", ( "state_key", # str "event_id", # str, optional )) +GroupsStreamRow = namedtuple("GroupsStreamRow", ( + "group_id", # str + "user_id", # str + "type", # str + "content", # dict +)) class Stream(object): @@ -464,6 +469,19 @@ class CurrentStateDeltaStream(Stream): super(CurrentStateDeltaStream, self).__init__(hs) +class GroupServerStream(Stream): + NAME = "groups" + ROW_TYPE = GroupsStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_group_stream_token + self.update_function = store.get_all_groups_changes + + super(GroupServerStream, self).__init__(hs) + + STREAMS_MAP = { stream.NAME: stream for stream in ( @@ -482,5 +500,6 @@ STREAMS_MAP = { TagAccountDataStream, AccountDataStream, CurrentStateDeltaStream, + GroupServerStream, ) } diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 3d809d181b..3418f06fd6 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,48 +14,50 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.rest.client import ( - versions, -) +from six import PY3 +from synapse.http.server import JsonResource +from synapse.rest.client import versions from synapse.rest.client.v1 import ( - room, + admin, + directory, events, - profile, - presence, initial_sync, - directory, - voip, - admin, - pusher, - push_rule, - register as v1_register, login as v1_login, logout, + presence, + profile, + push_rule, + pusher, + room, + voip, ) - from synapse.rest.client.v2_alpha import ( - sync, - filter, account, - register, + account_data, auth, - receipts, - read_marker, + devices, + filter, + groups, keys, - tokenrefresh, - tags, - account_data, - report_event, - openid, notifications, - devices, - thirdparty, + openid, + read_marker, + receipts, + register, + report_event, sendtodevice, + sync, + tags, + thirdparty, + tokenrefresh, user_directory, ) -from synapse.http.server import JsonResource +if not PY3: + from synapse.rest.client.v1_only import ( + register as v1_register, + ) class ClientRestResource(JsonResource): @@ -68,14 +71,22 @@ class ClientRestResource(JsonResource): def register_servlets(client_resource, hs): versions.register_servlets(client_resource) - # "v1" - room.register_servlets(hs, client_resource) + if not PY3: + # "v1" (Python 2 only) + v1_register.register_servlets(hs, client_resource) + + # Deprecated in r0 + initial_sync.register_servlets(hs, client_resource) + room.register_deprecated_servlets(hs, client_resource) + + # Partially deprecated in r0 events.register_servlets(hs, client_resource) - v1_register.register_servlets(hs, client_resource) + + # "v1" + "r0" + room.register_servlets(hs, client_resource) v1_login.register_servlets(hs, client_resource) profile.register_servlets(hs, client_resource) presence.register_servlets(hs, client_resource) - initial_sync.register_servlets(hs, client_resource) directory.register_servlets(hs, client_resource) voip.register_servlets(hs, client_resource) admin.register_servlets(hs, client_resource) @@ -102,3 +113,4 @@ class ClientRestResource(JsonResource): thirdparty.register_servlets(hs, client_resource) sendtodevice.register_servlets(hs, client_resource) user_directory.register_servlets(hs, client_resource) + groups.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index fceca2edeb..00b1b3066e 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -17,37 +17,20 @@ to ensure idempotency when performing PUTs using the REST API.""" import logging -from synapse.api.auth import get_access_token_from_request from synapse.util.async import ObservableDeferred +from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) - -def get_transaction_key(request): - """A helper function which returns a transaction key that can be used - with TransactionCache for idempotent requests. - - Idempotency is based on the returned key being the same for separate - requests to the same endpoint. The key is formed from the HTTP request - path and the access_token for the requesting user. - - Args: - request (twisted.web.http.Request): The incoming request. Must - contain an access_token. - Returns: - str: A transaction key - """ - token = get_access_token_from_request(request) - return request.path + "/" + token - - CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins class HttpTransactionCache(object): - def __init__(self, clock): - self.clock = clock + def __init__(self, hs): + self.hs = hs + self.auth = self.hs.get_auth() + self.clock = self.hs.get_clock() self.transactions = { # $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp) } @@ -55,6 +38,23 @@ class HttpTransactionCache(object): # for at *LEAST* 30 mins, and at *MOST* 60 mins. self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS) + def _get_transaction_key(self, request): + """A helper function which returns a transaction key that can be used + with TransactionCache for idempotent requests. + + Idempotency is based on the returned key being the same for separate + requests to the same endpoint. The key is formed from the HTTP request + path and the access_token for the requesting user. + + Args: + request (twisted.web.http.Request): The incoming request. Must + contain an access_token. + Returns: + str: A transaction key + """ + token = self.auth.get_access_token_from_request(request) + return request.path + "/" + token + def fetch_or_execute_request(self, request, fn, *args, **kwargs): """A helper function for fetch_or_execute which extracts a transaction key from the given request. @@ -63,7 +63,7 @@ class HttpTransactionCache(object): fetch_or_execute """ return self.fetch_or_execute( - get_transaction_key(request), fn, *args, **kwargs + self._get_transaction_key(request), fn, *args, **kwargs ) def fetch_or_execute(self, txn_key, fn, *args, **kwargs): @@ -80,31 +80,30 @@ class HttpTransactionCache(object): Returns: Deferred which resolves to a tuple of (response_code, response_dict). """ - try: - return self.transactions[txn_key][0].observe() - except (KeyError, IndexError): - pass # execute the function instead. - - deferred = fn(*args, **kwargs) - - # if the request fails with a Twisted failure, remove it - # from the transaction map. This is done to ensure that we don't - # cache transient errors like rate-limiting errors, etc. - def remove_from_map(err): - self.transactions.pop(txn_key, None) - return err - deferred.addErrback(remove_from_map) - - # We don't add any other errbacks to the raw deferred, so we ask - # ObservableDeferred to swallow the error. This is fine as the error will - # still be reported to the observers. - observable = ObservableDeferred(deferred, consumeErrors=True) - self.transactions[txn_key] = (observable, self.clock.time_msec()) - return observable.observe() + if txn_key in self.transactions: + observable = self.transactions[txn_key][0] + else: + # execute the function instead. + deferred = run_in_background(fn, *args, **kwargs) + + observable = ObservableDeferred(deferred) + self.transactions[txn_key] = (observable, self.clock.time_msec()) + + # if the request fails with an exception, remove it + # from the transaction map. This is done to ensure that we don't + # cache transient errors like rate-limiting errors, etc. + def remove_from_map(err): + self.transactions.pop(txn_key, None) + # we deliberately do not propagate the error any further, as we + # expect the observers to have reported it. + + deferred.addErrback(remove_from_map) + + return make_deferred_yieldable(observable.observe()) def _cleanup(self): now = self.clock.time_msec() - for key in self.transactions.keys(): + for key in list(self.transactions): ts = self.transactions[key][1] if now > (ts + CLEANUP_PERIOD_MS): # after cleanup period del self.transactions[key] diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 7d786e8de3..99f6c6e3c3 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,17 +14,26 @@ # See the License for the specific language governing permissions and # limitations under the License. +import hashlib +import hmac +import logging + +from six.moves import http_client + from twisted.internet import defer from synapse.api.constants import Membership -from synapse.api.errors import AuthError, SynapseError +from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.http.servlet import ( + assert_params_in_dict, + parse_integer, + parse_json_object_from_request, + parse_string, +) from synapse.types import UserID, create_requester -from synapse.http.servlet import parse_json_object_from_request from .base import ClientV1RestServlet, client_path_patterns -import logging - logger = logging.getLogger(__name__) @@ -55,6 +65,125 @@ class UsersRestServlet(ClientV1RestServlet): defer.returnValue((200, ret)) +class UserRegisterServlet(ClientV1RestServlet): + """ + Attributes: + NONCE_TIMEOUT (int): Seconds until a generated nonce won't be accepted + nonces (dict[str, int]): The nonces that we will accept. A dict of + nonce to the time it was generated, in int seconds. + """ + PATTERNS = client_path_patterns("/admin/register") + NONCE_TIMEOUT = 60 + + def __init__(self, hs): + super(UserRegisterServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + self.reactor = hs.get_reactor() + self.nonces = {} + self.hs = hs + + def _clear_old_nonces(self): + """ + Clear out old nonces that are older than NONCE_TIMEOUT. + """ + now = int(self.reactor.seconds()) + + for k, v in list(self.nonces.items()): + if now - v > self.NONCE_TIMEOUT: + del self.nonces[k] + + def on_GET(self, request): + """ + Generate a new nonce. + """ + self._clear_old_nonces() + + nonce = self.hs.get_secrets().token_hex(64) + self.nonces[nonce] = int(self.reactor.seconds()) + return (200, {"nonce": nonce.encode('ascii')}) + + @defer.inlineCallbacks + def on_POST(self, request): + self._clear_old_nonces() + + if not self.hs.config.registration_shared_secret: + raise SynapseError(400, "Shared secret registration is not enabled") + + body = parse_json_object_from_request(request) + + if "nonce" not in body: + raise SynapseError( + 400, "nonce must be specified", errcode=Codes.BAD_JSON, + ) + + nonce = body["nonce"] + + if nonce not in self.nonces: + raise SynapseError( + 400, "unrecognised nonce", + ) + + # Delete the nonce, so it can't be reused, even if it's invalid + del self.nonces[nonce] + + if "username" not in body: + raise SynapseError( + 400, "username must be specified", errcode=Codes.BAD_JSON, + ) + else: + if (not isinstance(body['username'], str) or len(body['username']) > 512): + raise SynapseError(400, "Invalid username") + + username = body["username"].encode("utf-8") + if b"\x00" in username: + raise SynapseError(400, "Invalid username") + + if "password" not in body: + raise SynapseError( + 400, "password must be specified", errcode=Codes.BAD_JSON, + ) + else: + if (not isinstance(body['password'], str) or len(body['password']) > 512): + raise SynapseError(400, "Invalid password") + + password = body["password"].encode("utf-8") + if b"\x00" in password: + raise SynapseError(400, "Invalid password") + + admin = body.get("admin", None) + got_mac = body["mac"] + + want_mac = hmac.new( + key=self.hs.config.registration_shared_secret.encode(), + digestmod=hashlib.sha1, + ) + want_mac.update(nonce) + want_mac.update(b"\x00") + want_mac.update(username) + want_mac.update(b"\x00") + want_mac.update(password) + want_mac.update(b"\x00") + want_mac.update(b"admin" if admin else b"notadmin") + want_mac = want_mac.hexdigest() + + if not hmac.compare_digest(want_mac, got_mac): + raise SynapseError( + 403, "HMAC incorrect", + ) + + # Reuse the parts of RegisterRestServlet to reduce code duplication + from synapse.rest.client.v2_alpha.register import RegisterRestServlet + register = RegisterRestServlet(self.hs) + + (user_id, _) = yield register.registration_handler.register( + localpart=username.lower(), password=password, admin=bool(admin), + generate_token=False, + ) + + result = yield register._create_registration_details(user_id, body) + defer.returnValue((200, result)) + + class WhoisRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)") @@ -95,16 +224,8 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet): if not is_admin: raise AuthError(403, "You are not a server admin") - before_ts = request.args.get("before_ts", None) - if not before_ts: - raise SynapseError(400, "Missing 'before_ts' arg") - - logger.info("before_ts: %r", before_ts[0]) - - try: - before_ts = int(before_ts[0]) - except Exception: - raise SynapseError(400, "Invalid 'before_ts' arg") + before_ts = parse_integer(request, "before_ts", required=True) + logger.info("before_ts: %r", before_ts) ret = yield self.media_repository.delete_old_remote_media(before_ts) @@ -113,12 +234,18 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet): class PurgeHistoryRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns( - "/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" + "/admin/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?" ) def __init__(self, hs): + """ + + Args: + hs (synapse.server.HomeServer) + """ super(PurgeHistoryRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.pagination_handler = hs.get_pagination_handler() + self.store = hs.get_datastore() @defer.inlineCallbacks def on_POST(self, request, room_id, event_id): @@ -128,20 +255,127 @@ class PurgeHistoryRestServlet(ClientV1RestServlet): if not is_admin: raise AuthError(403, "You are not a server admin") - yield self.handlers.message_handler.purge_history(room_id, event_id) + body = parse_json_object_from_request(request, allow_empty_body=True) - defer.returnValue((200, {})) + delete_local_events = bool(body.get("delete_local_events", False)) + + # establish the topological ordering we should keep events from. The + # user can provide an event_id in the URL or the request body, or can + # provide a timestamp in the request body. + if event_id is None: + event_id = body.get('purge_up_to_event_id') + + if event_id is not None: + event = yield self.store.get_event(event_id) + + if event.room_id != room_id: + raise SynapseError(400, "Event is for wrong room.") + + token = yield self.store.get_topological_token_for_event(event_id) + + logger.info( + "[purge] purging up to token %s (event_id %s)", + token, event_id, + ) + elif 'purge_up_to_ts' in body: + ts = body['purge_up_to_ts'] + if not isinstance(ts, int): + raise SynapseError( + 400, "purge_up_to_ts must be an int", + errcode=Codes.BAD_JSON, + ) + + stream_ordering = ( + yield self.store.find_first_stream_ordering_after_ts(ts) + ) + + r = ( + yield self.store.get_room_event_after_stream_ordering( + room_id, stream_ordering, + ) + ) + if not r: + logger.warn( + "[purge] purging events not possible: No event found " + "(received_ts %i => stream_ordering %i)", + ts, stream_ordering, + ) + raise SynapseError( + 404, + "there is no event to be purged", + errcode=Codes.NOT_FOUND, + ) + (stream, topo, _event_id) = r + token = "t%d-%d" % (topo, stream) + logger.info( + "[purge] purging up to token %s (received_ts %i => " + "stream_ordering %i)", + token, ts, stream_ordering, + ) + else: + raise SynapseError( + 400, + "must specify purge_up_to_event_id or purge_up_to_ts", + errcode=Codes.BAD_JSON, + ) + + purge_id = yield self.pagination_handler.start_purge_history( + room_id, token, + delete_local_events=delete_local_events, + ) + + defer.returnValue((200, { + "purge_id": purge_id, + })) + + +class PurgeHistoryStatusRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns( + "/admin/purge_history_status/(?P<purge_id>[^/]+)" + ) + + def __init__(self, hs): + """ + + Args: + hs (synapse.server.HomeServer) + """ + super(PurgeHistoryStatusRestServlet, self).__init__(hs) + self.pagination_handler = hs.get_pagination_handler() + + @defer.inlineCallbacks + def on_GET(self, request, purge_id): + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + purge_status = self.pagination_handler.get_purge_status(purge_id) + if purge_status is None: + raise NotFoundError("purge id '%s' not found" % purge_id) + + defer.returnValue((200, purge_status.asdict())) class DeactivateAccountRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)") def __init__(self, hs): - self.store = hs.get_datastore() super(DeactivateAccountRestServlet, self).__init__(hs) + self._deactivate_account_handler = hs.get_deactivate_account_handler() @defer.inlineCallbacks def on_POST(self, request, target_user_id): + body = parse_json_object_from_request(request, allow_empty_body=True) + erase = body.get("erase", False) + if not isinstance(erase, bool): + raise SynapseError( + http_client.BAD_REQUEST, + "Param 'erase' must be a boolean, if given", + Codes.BAD_JSON, + ) + UserID.from_string(target_user_id) requester = yield self.auth.get_user_by_req(request) is_admin = yield self.auth.is_server_admin(requester.user) @@ -149,12 +383,9 @@ class DeactivateAccountRestServlet(ClientV1RestServlet): if not is_admin: raise AuthError(403, "You are not a server admin") - # FIXME: Theoretically there is a race here wherein user resets password - # using threepid. - yield self.store.user_delete_access_tokens(target_user_id) - yield self.store.user_delete_threepids(target_user_id) - yield self.store.user_set_password_hash(target_user_id, None) - + yield self._deactivate_account_handler.deactivate_account( + target_user_id, erase, + ) defer.returnValue((200, {})) @@ -168,14 +399,16 @@ class ShutdownRoomRestServlet(ClientV1RestServlet): DEFAULT_MESSAGE = ( "Sharing illegal content on this server is not permitted and rooms in" - " violatation will be blocked." + " violation will be blocked." ) def __init__(self, hs): super(ShutdownRoomRestServlet, self).__init__(hs) self.store = hs.get_datastore() - self.handlers = hs.get_handlers() self.state = hs.get_state_handler() + self._room_creation_handler = hs.get_room_creation_handler() + self.event_creation_handler = hs.get_event_creation_handler() + self.room_member_handler = hs.get_room_member_handler() @defer.inlineCallbacks def on_POST(self, request, room_id): @@ -185,17 +418,15 @@ class ShutdownRoomRestServlet(ClientV1RestServlet): raise AuthError(403, "You are not a server admin") content = parse_json_object_from_request(request) - - new_room_user_id = content.get("new_room_user_id") - if not new_room_user_id: - raise SynapseError(400, "Please provide field `new_room_user_id`") + assert_params_in_dict(content, ["new_room_user_id"]) + new_room_user_id = content["new_room_user_id"] room_creator_requester = create_requester(new_room_user_id) message = content.get("message", self.DEFAULT_MESSAGE) room_name = content.get("room_name", "Content Violation Notification") - info = yield self.handlers.room_creation_handler.create_room( + info = yield self._room_creation_handler.create_room( room_creator_requester, config={ "preset": "public_chat", @@ -208,8 +439,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet): ) new_room_id = info["room_id"] - msg_handler = self.handlers.message_handler - yield msg_handler.create_and_send_nonmember_event( + yield self.event_creation_handler.create_and_send_nonmember_event( room_creator_requester, { "type": "m.room.message", @@ -235,7 +465,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet): logger.info("Kicking %r from %r...", user_id, room_id) target_requester = create_requester(user_id) - yield self.handlers.room_member_handler.update_membership( + yield self.room_member_handler.update_membership( requester=target_requester, target=target_requester.user, room_id=room_id, @@ -244,9 +474,9 @@ class ShutdownRoomRestServlet(ClientV1RestServlet): ratelimit=False ) - yield self.handlers.room_member_handler.forget(target_requester.user, room_id) + yield self.room_member_handler.forget(target_requester.user, room_id) - yield self.handlers.room_member_handler.update_membership( + yield self.room_member_handler.update_membership( requester=target_requester, target=target_requester.user, room_id=new_room_id, @@ -294,9 +524,30 @@ class QuarantineMediaInRoom(ClientV1RestServlet): defer.returnValue((200, {"num_quarantined": num_quarantined})) +class ListMediaInRoom(ClientV1RestServlet): + """Lists all of the media in a given room. + """ + PATTERNS = client_path_patterns("/admin/room/(?P<room_id>[^/]+)/media") + + def __init__(self, hs): + super(ListMediaInRoom, self).__init__(hs) + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def on_GET(self, request, room_id): + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + if not is_admin: + raise AuthError(403, "You are not a server admin") + + local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id) + + defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs})) + + class ResetPasswordRestServlet(ClientV1RestServlet): """Post request to allow an administrator reset password for a user. - This need a user have a administrator access in Synapse. + This needs user to have administrator access in Synapse. Example: http://localhost:8008/_matrix/client/api/v1/admin/reset_password/ @user:to_reset_password?access_token=admin_access_token @@ -314,12 +565,12 @@ class ResetPasswordRestServlet(ClientV1RestServlet): super(ResetPasswordRestServlet, self).__init__(hs) self.hs = hs self.auth = hs.get_auth() - self.auth_handler = hs.get_auth_handler() + self._set_password_handler = hs.get_set_password_handler() @defer.inlineCallbacks def on_POST(self, request, target_user_id): """Post request to allow an administrator reset password for a user. - This need a user have a administrator access in Synapse. + This needs user to have administrator access in Synapse. """ UserID.from_string(target_user_id) requester = yield self.auth.get_user_by_req(request) @@ -329,13 +580,12 @@ class ResetPasswordRestServlet(ClientV1RestServlet): raise AuthError(403, "You are not a server admin") params = parse_json_object_from_request(request) + assert_params_in_dict(params, ["new_password"]) new_password = params['new_password'] - if not new_password: - raise SynapseError(400, "Missing 'new_password' arg") logger.info("new_password: %r", new_password) - yield self.auth_handler.set_password( + yield self._set_password_handler.set_password( target_user_id, new_password, requester ) defer.returnValue((200, {})) @@ -343,7 +593,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet): class GetUsersPaginatedRestServlet(ClientV1RestServlet): """Get request to get specific number of users from Synapse. - This need a user have a administrator access in Synapse. + This needs user to have administrator access in Synapse. Example: http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/ @admin:user?access_token=admin_access_token&start=0&limit=10 @@ -362,7 +612,7 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, target_user_id): """Get request to get specific number of users from Synapse. - This need a user have a administrator access in Synapse. + This needs user to have administrator access in Synapse. """ target_user = UserID.from_string(target_user_id) requester = yield self.auth.get_user_by_req(request) @@ -379,12 +629,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet): raise SynapseError(400, "Can only users a local user") order = "name" # order by name in user table - start = request.args.get("start")[0] - limit = request.args.get("limit")[0] - if not limit: - raise SynapseError(400, "Missing 'limit' arg") - if not start: - raise SynapseError(400, "Missing 'start' arg") + start = parse_integer(request, "start", required=True) + limit = parse_integer(request, "limit", required=True) + logger.info("limit: %s, start: %s", limit, start) ret = yield self.handlers.admin_handler.get_users_paginate( @@ -395,7 +642,7 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, target_user_id): """Post request to get specific number of users from Synapse.. - This need a user have a administrator access in Synapse. + This needs user to have administrator access in Synapse. Example: http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/ @admin:user?access_token=admin_access_token @@ -416,12 +663,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet): order = "name" # order by name in user table params = parse_json_object_from_request(request) + assert_params_in_dict(params, ["limit", "start"]) limit = params['limit'] start = params['start'] - if not limit: - raise SynapseError(400, "Missing 'limit' arg") - if not start: - raise SynapseError(400, "Missing 'start' arg") logger.info("limit: %s, start: %s", limit, start) ret = yield self.handlers.admin_handler.get_users_paginate( @@ -433,7 +677,7 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet): class SearchUsersRestServlet(ClientV1RestServlet): """Get request to search user table for specific users according to search term. - This need a user have a administrator access in Synapse. + This needs user to have administrator access in Synapse. Example: http://localhost:8008/_matrix/client/api/v1/admin/search_users/ @admin:user?access_token=admin_access_token&term=alice @@ -453,7 +697,7 @@ class SearchUsersRestServlet(ClientV1RestServlet): def on_GET(self, request, target_user_id): """Get request to search user table for specific users according to search term. - This need a user have a administrator access in Synapse. + This needs user to have a administrator access in Synapse. """ target_user = UserID.from_string(target_user_id) requester = yield self.auth.get_user_by_req(request) @@ -469,10 +713,7 @@ class SearchUsersRestServlet(ClientV1RestServlet): if not self.hs.is_mine(target_user): raise SynapseError(400, "Can only users a local user") - term = request.args.get("term")[0] - if not term: - raise SynapseError(400, "Missing 'term' arg") - + term = parse_string(request, "term", required=True) logger.info("term: %s ", term) ret = yield self.handlers.admin_handler.search_users( @@ -484,6 +725,7 @@ class SearchUsersRestServlet(ClientV1RestServlet): def register_servlets(hs, http_server): WhoisRestServlet(hs).register(http_server) PurgeMediaCacheRestServlet(hs).register(http_server) + PurgeHistoryStatusRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server) PurgeHistoryRestServlet(hs).register(http_server) UsersRestServlet(hs).register(http_server) @@ -492,3 +734,5 @@ def register_servlets(hs, http_server): SearchUsersRestServlet(hs).register(http_server) ShutdownRoomRestServlet(hs).register(http_server) QuarantineMediaInRoom(hs).register(http_server) + ListMediaInRoom(hs).register(http_server) + UserRegisterServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py index c7aa0bbf59..c77d7aba68 100644 --- a/synapse/rest/client/v1/base.py +++ b/synapse/rest/client/v1/base.py @@ -16,14 +16,12 @@ """This module contains base REST classes for constructing client v1 servlets. """ -from synapse.http.servlet import RestServlet -from synapse.api.urls import CLIENT_PREFIX -from synapse.rest.client.transactions import HttpTransactionCache - -import re - import logging +import re +from synapse.api.urls import CLIENT_PREFIX +from synapse.http.servlet import RestServlet +from synapse.rest.client.transactions import HttpTransactionCache logger = logging.getLogger(__name__) @@ -52,6 +50,10 @@ class ClientV1RestServlet(RestServlet): """A base Synapse REST Servlet for the client version 1 API. """ + # This subclass was presumably created to allow the auth for the v1 + # protocol version to be different, however this behaviour was removed. + # it may no longer be necessary + def __init__(self, hs): """ Args: @@ -59,5 +61,5 @@ class ClientV1RestServlet(RestServlet): """ self.hs = hs self.builder_factory = hs.get_event_builder_factory() - self.auth = hs.get_v1auth() - self.txns = HttpTransactionCache(hs.get_clock()) + self.auth = hs.get_auth() + self.txns = HttpTransactionCache(hs) diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index f15aa5c13f..69dcd618cb 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -14,17 +14,16 @@ # limitations under the License. +import logging + from twisted.internet import defer -from synapse.api.errors import AuthError, SynapseError, Codes -from synapse.types import RoomAlias +from synapse.api.errors import AuthError, Codes, SynapseError from synapse.http.servlet import parse_json_object_from_request +from synapse.types import RoomAlias from .base import ClientV1RestServlet, client_path_patterns -import logging - - logger = logging.getLogger(__name__) @@ -53,15 +52,14 @@ class ClientDirectoryServer(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_alias): + room_alias = RoomAlias.from_string(room_alias) + content = parse_json_object_from_request(request) if "room_id" not in content: - raise SynapseError(400, "Missing room_id key", + raise SynapseError(400, 'Missing params: ["room_id"]', errcode=Codes.BAD_JSON) logger.debug("Got content: %s", content) - - room_alias = RoomAlias.from_string(room_alias) - logger.debug("Got room name: %s", room_alias.to_string()) room_id = content["room_id"] @@ -93,7 +91,7 @@ class ClientDirectoryServer(ClientV1RestServlet): ) except SynapseError as e: raise e - except: + except Exception: logger.exception("Failed to create association") raise except AuthError: diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 701b6f549b..b70c9c2806 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -14,15 +14,15 @@ # limitations under the License. """This module contains REST servlets to do with event streaming, /events.""" +import logging + from twisted.internet import defer from synapse.api.errors import SynapseError -from synapse.streams.config import PaginationConfig -from .base import ClientV1RestServlet, client_path_patterns from synapse.events.utils import serialize_event +from synapse.streams.config import PaginationConfig -import logging - +from .base import ClientV1RestServlet, client_path_patterns logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 478e21eea8..fd5f85b53e 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -15,7 +15,9 @@ from twisted.internet import defer +from synapse.http.servlet import parse_boolean from synapse.streams.config import PaginationConfig + from .base import ClientV1RestServlet, client_path_patterns @@ -32,7 +34,7 @@ class InitialSyncRestServlet(ClientV1RestServlet): requester = yield self.auth.get_user_by_req(request) as_client_event = "raw" not in request.args pagination_config = PaginationConfig.from_request(request) - include_archived = request.args.get("archived", None) == ["true"] + include_archived = parse_boolean(request, "archived", default=False) content = yield self.initial_sync_handler.snapshot_all_rooms( user_id=requester.user.to_string(), pagin_config=pagination_config, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index a43410fb37..cb85fa1436 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -13,30 +13,27 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +import urllib +import xml.etree.ElementTree as ET + +from six.moves.urllib import parse as urlparse + +from canonicaljson import json +from saml2 import BINDING_HTTP_POST, config +from saml2.client import Saml2Client + from twisted.internet import defer +from twisted.web.client import PartialDownloadError -from synapse.api.errors import SynapseError, LoginError, Codes -from synapse.types import UserID +from synapse.api.errors import Codes, LoginError, SynapseError from synapse.http.server import finish_request from synapse.http.servlet import parse_json_object_from_request +from synapse.types import UserID from synapse.util.msisdn import phone_number_to_msisdn from .base import ClientV1RestServlet, client_path_patterns -import simplejson as json -import urllib -import urlparse - -import logging -from saml2 import BINDING_HTTP_POST -from saml2 import config -from saml2.client import Saml2Client - -import xml.etree.ElementTree as ET - -from twisted.web.client import PartialDownloadError - - logger = logging.getLogger(__name__) @@ -85,7 +82,6 @@ def login_id_thirdparty_from_phone(identifier): class LoginRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/login$") - PASS_TYPE = "m.login.password" SAML2_TYPE = "m.login.saml2" CAS_TYPE = "m.login.cas" TOKEN_TYPE = "m.login.token" @@ -94,7 +90,6 @@ class LoginRestServlet(ClientV1RestServlet): def __init__(self, hs): super(LoginRestServlet, self).__init__(hs) self.idp_redirect_url = hs.config.saml2_idp_redirect_url - self.password_enabled = hs.config.password_enabled self.saml2_enabled = hs.config.saml2_enabled self.jwt_enabled = hs.config.jwt_enabled self.jwt_secret = hs.config.jwt_secret @@ -121,8 +116,10 @@ class LoginRestServlet(ClientV1RestServlet): # fall back to the fallback API if they don't understand one of the # login flow types returned. flows.append({"type": LoginRestServlet.TOKEN_TYPE}) - if self.password_enabled: - flows.append({"type": LoginRestServlet.PASS_TYPE}) + + flows.extend(( + {"type": t} for t in self.auth_handler.get_supported_login_types() + )) return (200, {"flows": flows}) @@ -133,14 +130,8 @@ class LoginRestServlet(ClientV1RestServlet): def on_POST(self, request): login_submission = parse_json_object_from_request(request) try: - if login_submission["type"] == LoginRestServlet.PASS_TYPE: - if not self.password_enabled: - raise SynapseError(400, "Password login has been disabled.") - - result = yield self.do_password_login(login_submission) - defer.returnValue(result) - elif self.saml2_enabled and (login_submission["type"] == - LoginRestServlet.SAML2_TYPE): + if self.saml2_enabled and (login_submission["type"] == + LoginRestServlet.SAML2_TYPE): relay_state = "" if "relay_state" in login_submission: relay_state = "&RelayState=" + urllib.quote( @@ -157,15 +148,31 @@ class LoginRestServlet(ClientV1RestServlet): result = yield self.do_token_login(login_submission) defer.returnValue(result) else: - raise SynapseError(400, "Bad login type.") + result = yield self._do_other_login(login_submission) + defer.returnValue(result) except KeyError: raise SynapseError(400, "Missing JSON keys.") @defer.inlineCallbacks - def do_password_login(self, login_submission): - if "password" not in login_submission: - raise SynapseError(400, "Missing parameter: password") + def _do_other_login(self, login_submission): + """Handle non-token/saml/jwt logins + + Args: + login_submission: + Returns: + (int, object): HTTP code/response + """ + # Log the request we got, but only certain fields to minimise the chance of + # logging someone's password (even if they accidentally put it in the wrong + # field) + logger.info( + "Got login request with identifier: %r, medium: %r, address: %r, user: %r", + login_submission.get('identifier'), + login_submission.get('medium'), + login_submission.get('address'), + login_submission.get('user'), + ) login_submission_legacy_convert(login_submission) if "identifier" not in login_submission: @@ -181,19 +188,25 @@ class LoginRestServlet(ClientV1RestServlet): # convert threepid identifiers to user IDs if identifier["type"] == "m.id.thirdparty": - if 'medium' not in identifier or 'address' not in identifier: + address = identifier.get('address') + medium = identifier.get('medium') + + if medium is None or address is None: raise SynapseError(400, "Invalid thirdparty identifier") - address = identifier['address'] - if identifier['medium'] == 'email': + if medium == 'email': # For emails, transform the address to lowercase. # We store all email addreses as lowercase in the DB. # (See add_threepid in synapse/handlers/auth.py) address = address.lower() user_id = yield self.hs.get_datastore().get_user_id_by_threepid( - identifier['medium'], address + medium, address, ) if not user_id: + logger.warn( + "unknown 3pid identifier medium %s, address %r", + medium, address, + ) raise LoginError(403, "", errcode=Codes.FORBIDDEN) identifier = { @@ -208,30 +221,29 @@ class LoginRestServlet(ClientV1RestServlet): if "user" not in identifier: raise SynapseError(400, "User identifier is missing 'user' key") - user_id = identifier["user"] - - if not user_id.startswith('@'): - user_id = UserID.create( - user_id, self.hs.hostname - ).to_string() - auth_handler = self.auth_handler - user_id = yield auth_handler.validate_password_login( - user_id=user_id, - password=login_submission["password"], + canonical_user_id, callback = yield auth_handler.validate_login( + identifier["user"], + login_submission, + ) + + device_id = yield self._register_device( + canonical_user_id, login_submission, ) - device_id = yield self._register_device(user_id, login_submission) access_token = yield auth_handler.get_access_token_for_user_id( - user_id, device_id, - login_submission.get("initial_device_display_name"), + canonical_user_id, device_id, ) + result = { - "user_id": user_id, # may have changed + "user_id": canonical_user_id, "access_token": access_token, "home_server": self.hs.hostname, "device_id": device_id, } + if callback is not None: + yield callback(result) + defer.returnValue((200, result)) @defer.inlineCallbacks @@ -244,7 +256,6 @@ class LoginRestServlet(ClientV1RestServlet): device_id = yield self._register_device(user_id, login_submission) access_token = yield auth_handler.get_access_token_for_user_id( user_id, device_id, - login_submission.get("initial_device_display_name"), ) result = { "user_id": user_id, # may have changed @@ -278,7 +289,7 @@ class LoginRestServlet(ClientV1RestServlet): if user is None: raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) - user_id = UserID.create(user, self.hs.hostname).to_string() + user_id = UserID(user, self.hs.hostname).to_string() auth_handler = self.auth_handler registered_user_id = yield auth_handler.check_user_exists(user_id) if registered_user_id: @@ -287,7 +298,6 @@ class LoginRestServlet(ClientV1RestServlet): ) access_token = yield auth_handler.get_access_token_for_user_id( registered_user_id, device_id, - login_submission.get("initial_device_display_name"), ) result = { @@ -444,7 +454,7 @@ class CasTicketServlet(ClientV1RestServlet): if required_value != actual_value: raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) - user_id = UserID.create(user, self.hs.hostname).to_string() + user_id = UserID(user, self.hs.hostname).to_string() auth_handler = self.auth_handler registered_user_id = yield auth_handler.check_user_exists(user_id) if not registered_user_id: diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index 1358d0acab..430c692336 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -13,15 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer -from synapse.api.auth import get_access_token_from_request +from synapse.api.errors import AuthError from .base import ClientV1RestServlet, client_path_patterns -import logging - - logger = logging.getLogger(__name__) @@ -30,15 +29,33 @@ class LogoutRestServlet(ClientV1RestServlet): def __init__(self, hs): super(LogoutRestServlet, self).__init__(hs) - self.store = hs.get_datastore() + self._auth = hs.get_auth() + self._auth_handler = hs.get_auth_handler() + self._device_handler = hs.get_device_handler() def on_OPTIONS(self, request): return (200, {}) @defer.inlineCallbacks def on_POST(self, request): - access_token = get_access_token_from_request(request) - yield self.store.delete_access_token(access_token) + try: + requester = yield self.auth.get_user_by_req(request) + except AuthError: + # this implies the access token has already been deleted. + defer.returnValue((401, { + "errcode": "M_UNKNOWN_TOKEN", + "error": "Access Token unknown or expired" + })) + else: + if requester.device_id is None: + # the acccess token wasn't associated with a device. + # Just delete the access token + access_token = self._auth.get_access_token_from_request(request) + yield self._auth_handler.delete_access_token(access_token) + else: + yield self._device_handler.delete_device( + requester.user.to_string(), requester.device_id) + defer.returnValue((200, {})) @@ -47,8 +64,9 @@ class LogoutAllRestServlet(ClientV1RestServlet): def __init__(self, hs): super(LogoutAllRestServlet, self).__init__(hs) - self.store = hs.get_datastore() self.auth = hs.get_auth() + self._auth_handler = hs.get_auth_handler() + self._device_handler = hs.get_device_handler() def on_OPTIONS(self, request): return (200, {}) @@ -57,7 +75,13 @@ class LogoutAllRestServlet(ClientV1RestServlet): def on_POST(self, request): requester = yield self.auth.get_user_by_req(request) user_id = requester.user.to_string() - yield self.store.user_delete_access_tokens(user_id) + + # first delete all of the user's devices + yield self._device_handler.delete_all_devices_for_user(user_id) + + # .. and then delete any access tokens which weren't associated with + # devices. + yield self._auth_handler.delete_access_tokens_for_user(user_id) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 47b2dc45e7..a14f0c807e 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -15,15 +15,18 @@ """ This module contains REST servlets to do with presence: /presence/<paths> """ +import logging + +from six import string_types + from twisted.internet import defer -from synapse.api.errors import SynapseError, AuthError -from synapse.types import UserID +from synapse.api.errors import AuthError, SynapseError from synapse.handlers.presence import format_user_presence_state from synapse.http.servlet import parse_json_object_from_request -from .base import ClientV1RestServlet, client_path_patterns +from synapse.types import UserID -import logging +from .base import ClientV1RestServlet, client_path_patterns logger = logging.getLogger(__name__) @@ -71,14 +74,14 @@ class PresenceStatusRestServlet(ClientV1RestServlet): if "status_msg" in content: state["status_msg"] = content.pop("status_msg") - if not isinstance(state["status_msg"], basestring): + if not isinstance(state["status_msg"], string_types): raise SynapseError(400, "status_msg must be a string.") if content: raise KeyError() except SynapseError as e: raise e - except: + except Exception: raise SynapseError(400, "Unable to parse state") yield self.presence_handler.set_state(user, state) @@ -129,7 +132,7 @@ class PresenceListRestServlet(ClientV1RestServlet): if "invite" in content: for u in content["invite"]: - if not isinstance(u, basestring): + if not isinstance(u, string_types): raise SynapseError(400, "Bad invite value.") if len(u) == 0: continue @@ -140,7 +143,7 @@ class PresenceListRestServlet(ClientV1RestServlet): if "drop" in content: for u in content["drop"]: - if not isinstance(u, basestring): + if not isinstance(u, string_types): raise SynapseError(400, "Bad drop value.") if len(u) == 0: continue diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 1a5045c9ec..a23edd8fe5 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -16,9 +16,10 @@ """ This module contains REST servlets to do with profile: /profile/<paths> """ from twisted.internet import defer -from .base import ClientV1RestServlet, client_path_patterns -from synapse.types import UserID from synapse.http.servlet import parse_json_object_from_request +from synapse.types import UserID + +from .base import ClientV1RestServlet, client_path_patterns class ProfileDisplaynameRestServlet(ClientV1RestServlet): @@ -26,13 +27,13 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): def __init__(self, hs): super(ProfileDisplaynameRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.profile_handler = hs.get_profile_handler() @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) - displayname = yield self.handlers.profile_handler.get_displayname( + displayname = yield self.profile_handler.get_displayname( user, ) @@ -52,10 +53,10 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): try: new_name = content["displayname"] - except: + except Exception: defer.returnValue((400, "Unable to parse name")) - yield self.handlers.profile_handler.set_displayname( + yield self.profile_handler.set_displayname( user, requester, new_name, is_admin) defer.returnValue((200, {})) @@ -69,13 +70,13 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): def __init__(self, hs): super(ProfileAvatarURLRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.profile_handler = hs.get_profile_handler() @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) - avatar_url = yield self.handlers.profile_handler.get_avatar_url( + avatar_url = yield self.profile_handler.get_avatar_url( user, ) @@ -94,10 +95,10 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): content = parse_json_object_from_request(request) try: new_name = content["avatar_url"] - except: + except Exception: defer.returnValue((400, "Unable to parse name")) - yield self.handlers.profile_handler.set_avatar_url( + yield self.profile_handler.set_avatar_url( user, requester, new_name, is_admin) defer.returnValue((200, {})) @@ -111,16 +112,16 @@ class ProfileRestServlet(ClientV1RestServlet): def __init__(self, hs): super(ProfileRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.profile_handler = hs.get_profile_handler() @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) - displayname = yield self.handlers.profile_handler.get_displayname( + displayname = yield self.profile_handler.get_displayname( user, ) - avatar_url = yield self.handlers.profile_handler.get_avatar_url( + avatar_url = yield self.profile_handler.get_avatar_url( user, ) diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 6bb4821ec6..6e95d9bec2 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -16,16 +16,18 @@ from twisted.internet import defer from synapse.api.errors import ( - SynapseError, UnrecognizedRequestError, NotFoundError, StoreError + NotFoundError, + StoreError, + SynapseError, + UnrecognizedRequestError, ) -from .base import ClientV1RestServlet, client_path_patterns -from synapse.storage.push_rule import ( - InconsistentRuleException, RuleNotFoundException -) -from synapse.push.clientformat import format_push_rules_for_user +from synapse.http.servlet import parse_json_value_from_request, parse_string from synapse.push.baserules import BASE_RULE_IDS +from synapse.push.clientformat import format_push_rules_for_user from synapse.push.rulekinds import PRIORITY_CLASS_MAP -from synapse.http.servlet import parse_json_value_from_request +from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException + +from .base import ClientV1RestServlet, client_path_patterns class PushRuleRestServlet(ClientV1RestServlet): @@ -73,13 +75,13 @@ class PushRuleRestServlet(ClientV1RestServlet): except InvalidRuleException as e: raise SynapseError(400, e.message) - before = request.args.get("before", None) + before = parse_string(request, "before") if before: - before = _namespaced_rule_id(spec, before[0]) + before = _namespaced_rule_id(spec, before) - after = request.args.get("after", None) + after = parse_string(request, "after") if after: - after = _namespaced_rule_id(spec, after[0]) + after = _namespaced_rule_id(spec, after) try: yield self.store.add_push_rule( diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 9a2ed6ed88..182a68b1e2 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -13,20 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer -from synapse.api.errors import SynapseError, Codes -from synapse.push import PusherConfigException +from synapse.api.errors import Codes, StoreError, SynapseError +from synapse.http.server import finish_request from synapse.http.servlet import ( - parse_json_object_from_request, parse_string, RestServlet + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, + parse_string, ) -from synapse.http.server import finish_request -from synapse.api.errors import StoreError +from synapse.push import PusherConfigException from .base import ClientV1RestServlet, client_path_patterns -import logging - logger = logging.getLogger(__name__) @@ -73,6 +75,7 @@ class PushersSetRestServlet(ClientV1RestServlet): def __init__(self, hs): super(PushersSetRestServlet, self).__init__(hs) self.notifier = hs.get_notifier() + self.pusher_pool = self.hs.get_pusherpool() @defer.inlineCallbacks def on_POST(self, request): @@ -81,25 +84,19 @@ class PushersSetRestServlet(ClientV1RestServlet): content = parse_json_object_from_request(request) - pusher_pool = self.hs.get_pusherpool() - if ('pushkey' in content and 'app_id' in content and 'kind' in content and content['kind'] is None): - yield pusher_pool.remove_pusher( + yield self.pusher_pool.remove_pusher( content['app_id'], content['pushkey'], user_id=user.to_string() ) defer.returnValue((200, {})) - reqd = ['kind', 'app_id', 'app_display_name', - 'device_display_name', 'pushkey', 'lang', 'data'] - missing = [] - for i in reqd: - if i not in content: - missing.append(i) - if len(missing): - raise SynapseError(400, "Missing parameters: " + ','.join(missing), - errcode=Codes.MISSING_PARAM) + assert_params_in_dict( + content, + ['kind', 'app_id', 'app_display_name', + 'device_display_name', 'pushkey', 'lang', 'data'] + ) logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind']) logger.debug("Got pushers request with body: %r", content) @@ -109,14 +106,14 @@ class PushersSetRestServlet(ClientV1RestServlet): append = content['append'] if not append: - yield pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( + yield self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( app_id=content['app_id'], pushkey=content['pushkey'], not_user_id=user.to_string() ) try: - yield pusher_pool.add_pusher( + yield self.pusher_pool.add_pusher( user_id=user.to_string(), access_token=requester.access_token_id, kind=content['kind'], @@ -148,10 +145,11 @@ class PushersRemoveRestServlet(RestServlet): SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>" def __init__(self, hs): - super(RestServlet, self).__init__() + super(PushersRemoveRestServlet, self).__init__() self.hs = hs self.notifier = hs.get_notifier() - self.auth = hs.get_v1auth() + self.auth = hs.get_auth() + self.pusher_pool = self.hs.get_pusherpool() @defer.inlineCallbacks def on_GET(self, request): @@ -161,10 +159,8 @@ class PushersRemoveRestServlet(RestServlet): app_id = parse_string(request, "app_id", required=True) pushkey = parse_string(request, "pushkey", required=True) - pusher_pool = self.hs.get_pusherpool() - try: - yield pusher_pool.remove_pusher( + yield self.pusher_pool.remove_pusher( app_id=app_id, pushkey=pushkey, user_id=user.to_string(), @@ -178,7 +174,6 @@ class PushersRemoveRestServlet(RestServlet): request.setResponseCode(200) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Server", self.hs.version_string) request.setHeader(b"Content-Length", b"%d" % ( len(PushersRemoveRestServlet.SUCCESS_HTML), )) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index cd388770c8..b7bd878c90 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,22 +15,28 @@ # limitations under the License. """ This module contains REST servlets to do with rooms: /rooms/<paths> """ +import logging + +from six.moves.urllib import parse as urlparse + +from canonicaljson import json + from twisted.internet import defer -from .base import ClientV1RestServlet, client_path_patterns -from synapse.api.errors import SynapseError, Codes, AuthError -from synapse.streams.config import PaginationConfig from synapse.api.constants import EventTypes, Membership +from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.filtering import Filter -from synapse.types import UserID, RoomID, RoomAlias, ThirdPartyInstanceID -from synapse.events.utils import serialize_event, format_event_for_client_v2 +from synapse.events.utils import format_event_for_client_v2, serialize_event from synapse.http.servlet import ( - parse_json_object_from_request, parse_string, parse_integer + assert_params_in_dict, + parse_integer, + parse_json_object_from_request, + parse_string, ) +from synapse.streams.config import PaginationConfig +from synapse.types import RoomAlias, RoomID, ThirdPartyInstanceID, UserID -import logging -import urllib -import ujson as json +from .base import ClientV1RestServlet, client_path_patterns logger = logging.getLogger(__name__) @@ -39,7 +46,7 @@ class RoomCreateRestServlet(ClientV1RestServlet): def __init__(self, hs): super(RoomCreateRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self._room_creation_handler = hs.get_room_creation_handler() def register(self, http_server): PATTERNS = "/createRoom" @@ -62,8 +69,7 @@ class RoomCreateRestServlet(ClientV1RestServlet): def on_POST(self, request): requester = yield self.auth.get_user_by_req(request) - handler = self.handlers.room_creation_handler - info = yield handler.create_room( + info = yield self._room_creation_handler.create_room( requester, self.get_room_config(request) ) @@ -82,6 +88,9 @@ class RoomStateEventRestServlet(ClientV1RestServlet): def __init__(self, hs): super(RoomStateEventRestServlet, self).__init__(hs) self.handlers = hs.get_handlers() + self.event_creation_hander = hs.get_event_creation_handler() + self.room_member_handler = hs.get_room_member_handler() + self.message_handler = hs.get_message_handler() def register(self, http_server): # /room/$roomid/state/$eventtype @@ -116,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): format = parse_string(request, "format", default="content", allowed_values=["content", "event"]) - msg_handler = self.handlers.message_handler + msg_handler = self.message_handler data = yield msg_handler.get_room_data( user_id=requester.user.to_string(), room_id=room_id, @@ -154,7 +163,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): if event_type == EventTypes.Member: membership = content.get("membership", None) - event = yield self.handlers.room_member_handler.update_membership( + event = yield self.room_member_handler.update_membership( requester, target=UserID.from_string(state_key), room_id=room_id, @@ -162,16 +171,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet): content=content, ) else: - msg_handler = self.handlers.message_handler - event, context = yield msg_handler.create_event( + event = yield self.event_creation_hander.create_and_send_nonmember_event( requester, event_dict, - token_id=requester.access_token_id, txn_id=txn_id, ) - yield msg_handler.send_nonmember_event(requester, event, context) - ret = {} if event: ret = {"event_id": event.event_id} @@ -183,7 +188,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): def __init__(self, hs): super(RoomSendEventRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.event_creation_hander = hs.get_event_creation_handler() def register(self, http_server): # /rooms/$roomid/send/$event_type[/$txn_id] @@ -195,15 +200,19 @@ class RoomSendEventRestServlet(ClientV1RestServlet): requester = yield self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) - msg_handler = self.handlers.message_handler - event = yield msg_handler.create_and_send_nonmember_event( + event_dict = { + "type": event_type, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + } + + if 'ts' in request.args and requester.app_service: + event_dict['origin_server_ts'] = parse_integer(request, "ts", 0) + + event = yield self.event_creation_hander.create_and_send_nonmember_event( requester, - { - "type": event_type, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - }, + event_dict, txn_id=txn_id, ) @@ -222,7 +231,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): class JoinRoomAliasServlet(ClientV1RestServlet): def __init__(self, hs): super(JoinRoomAliasServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.room_member_handler = hs.get_room_member_handler() def register(self, http_server): # /join/$room_identifier[/$txn_id] @@ -238,7 +247,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): try: content = parse_json_object_from_request(request) - except: + except Exception: # Turns out we used to ignore the body entirely, and some clients # cheekily send invalid bodies. content = {} @@ -247,10 +256,10 @@ class JoinRoomAliasServlet(ClientV1RestServlet): room_id = room_identifier try: remote_room_hosts = request.args["server_name"] - except: + except Exception: remote_room_hosts = None elif RoomAlias.is_valid(room_identifier): - handler = self.handlers.room_member_handler + handler = self.room_member_handler room_alias = RoomAlias.from_string(room_identifier) room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias) room_id = room_id.to_string() @@ -259,7 +268,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): room_identifier, )) - yield self.handlers.room_member_handler.update_membership( + yield self.room_member_handler.update_membership( requester=requester, target=requester.user, room_id=room_id, @@ -369,14 +378,13 @@ class RoomMemberListRestServlet(ClientV1RestServlet): def __init__(self, hs): super(RoomMemberListRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.message_handler = hs.get_message_handler() @defer.inlineCallbacks def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) requester = yield self.auth.get_user_by_req(request) - handler = self.handlers.message_handler - events = yield handler.get_state_events( + events = yield self.message_handler.get_state_events( room_id=room_id, user_id=requester.user.to_string(), ) @@ -398,22 +406,18 @@ class JoinedRoomMemberListRestServlet(ClientV1RestServlet): def __init__(self, hs): super(JoinedRoomMemberListRestServlet, self).__init__(hs) - self.state = hs.get_state_handler() + self.message_handler = hs.get_message_handler() @defer.inlineCallbacks def on_GET(self, request, room_id): - yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) - users_with_profile = yield self.state.get_current_user_in_room(room_id) + users_with_profile = yield self.message_handler.get_joined_members( + requester, room_id, + ) defer.returnValue((200, { - "joined": { - user_id: { - "avatar_url": profile.avatar_url, - "display_name": profile.display_name, - } - for user_id, profile in users_with_profile.iteritems() - } + "joined": users_with_profile, })) @@ -423,7 +427,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet): def __init__(self, hs): super(RoomMessageListRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.pagination_handler = hs.get_pagination_handler() @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -432,14 +436,13 @@ class RoomMessageListRestServlet(ClientV1RestServlet): request, default_limit=10, ) as_client_event = "raw" not in request.args - filter_bytes = request.args.get("filter", None) + filter_bytes = parse_string(request, "filter") if filter_bytes: - filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8") + filter_json = urlparse.unquote(filter_bytes).decode("UTF-8") event_filter = Filter(json.loads(filter_json)) else: event_filter = None - handler = self.handlers.message_handler - msgs = yield handler.get_messages( + msgs = yield self.pagination_handler.get_messages( room_id=room_id, requester=requester, pagin_config=pagination_config, @@ -456,14 +459,13 @@ class RoomStateRestServlet(ClientV1RestServlet): def __init__(self, hs): super(RoomStateRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.message_handler = hs.get_message_handler() @defer.inlineCallbacks def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) - handler = self.handlers.message_handler # Get all the current state for this room - events = yield handler.get_state_events( + events = yield self.message_handler.get_state_events( room_id=room_id, user_id=requester.user.to_string(), is_guest=requester.is_guest, @@ -491,23 +493,45 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet): defer.returnValue((200, content)) -class RoomEventContext(ClientV1RestServlet): +class RoomEventServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns( + "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$" + ) + + def __init__(self, hs): + super(RoomEventServlet, self).__init__(hs) + self.clock = hs.get_clock() + self.event_handler = hs.get_event_handler() + + @defer.inlineCallbacks + def on_GET(self, request, room_id, event_id): + requester = yield self.auth.get_user_by_req(request) + event = yield self.event_handler.get_event(requester.user, event_id) + + time_now = self.clock.time_msec() + if event: + defer.returnValue((200, serialize_event(event, time_now))) + else: + defer.returnValue((404, "Event not found.")) + + +class RoomEventContextServlet(ClientV1RestServlet): PATTERNS = client_path_patterns( "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$" ) def __init__(self, hs): - super(RoomEventContext, self).__init__(hs) + super(RoomEventContextServlet, self).__init__(hs) self.clock = hs.get_clock() - self.handlers = hs.get_handlers() + self.room_context_handler = hs.get_room_context_handler() @defer.inlineCallbacks def on_GET(self, request, room_id, event_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) - limit = int(request.args.get("limit", [10])[0]) + limit = parse_integer(request, "limit", default=10) - results = yield self.handlers.room_context_handler.get_event_context( + results = yield self.room_context_handler.get_event_context( requester.user, room_id, event_id, @@ -537,7 +561,7 @@ class RoomEventContext(ClientV1RestServlet): class RoomForgetRestServlet(ClientV1RestServlet): def __init__(self, hs): super(RoomForgetRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.room_member_handler = hs.get_room_member_handler() def register(self, http_server): PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget") @@ -550,7 +574,7 @@ class RoomForgetRestServlet(ClientV1RestServlet): allow_guest=False, ) - yield self.handlers.room_member_handler.forget( + yield self.room_member_handler.forget( user=requester.user, room_id=room_id, ) @@ -568,12 +592,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet): def __init__(self, hs): super(RoomMembershipRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.room_member_handler = hs.get_room_member_handler() def register(self, http_server): # /rooms/$roomid/[invite|join|leave] PATTERNS = ("/rooms/(?P<room_id>[^/]*)/" - "(?P<membership_action>join|invite|leave|ban|unban|kick|forget)") + "(?P<membership_action>join|invite|leave|ban|unban|kick)") register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks @@ -591,13 +615,13 @@ class RoomMembershipRestServlet(ClientV1RestServlet): try: content = parse_json_object_from_request(request) - except: + except Exception: # Turns out we used to ignore the body entirely, and some clients # cheekily send invalid bodies. content = {} if membership_action == "invite" and self._has_3pid_invite_keys(content): - yield self.handlers.room_member_handler.do_3pid_invite( + yield self.room_member_handler.do_3pid_invite( room_id, requester.user, content["medium"], @@ -611,15 +635,14 @@ class RoomMembershipRestServlet(ClientV1RestServlet): target = requester.user if membership_action in ["invite", "ban", "unban", "kick"]: - if "user_id" not in content: - raise SynapseError(400, "Missing user_id key.") + assert_params_in_dict(content, ["user_id"]) target = UserID.from_string(content["user_id"]) event_content = None if 'reason' in content and membership_action in ['kick', 'ban']: event_content = {'reason': content['reason']} - yield self.handlers.room_member_handler.update_membership( + yield self.room_member_handler.update_membership( requester=requester, target=target, room_id=room_id, @@ -629,7 +652,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet): content=event_content, ) - defer.returnValue((200, {})) + return_value = {} + + if membership_action == "join": + return_value["room_id"] = room_id + + defer.returnValue((200, return_value)) def _has_3pid_invite_keys(self, content): for key in {"id_server", "medium", "address"}: @@ -647,6 +675,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): def __init__(self, hs): super(RoomRedactEventRestServlet, self).__init__(hs) self.handlers = hs.get_handlers() + self.event_creation_handler = hs.get_event_creation_handler() def register(self, http_server): PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") @@ -657,8 +686,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): requester = yield self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) - msg_handler = self.handlers.message_handler - event = yield msg_handler.create_and_send_nonmember_event( + event = yield self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Redaction, @@ -692,8 +720,8 @@ class RoomTypingRestServlet(ClientV1RestServlet): def on_PUT(self, request, room_id, user_id): requester = yield self.auth.get_user_by_req(request) - room_id = urllib.unquote(room_id) - target_user = UserID.from_string(urllib.unquote(user_id)) + room_id = urlparse.unquote(room_id) + target_user = UserID.from_string(urlparse.unquote(user_id)) content = parse_json_object_from_request(request) @@ -734,7 +762,7 @@ class SearchRestServlet(ClientV1RestServlet): content = parse_json_object_from_request(request) - batch = request.args.get("next_batch", [None])[0] + batch = parse_string(request, "next_batch") results = yield self.handlers.search_handler.search( requester.user, content, @@ -802,9 +830,13 @@ def register_servlets(hs, http_server): RoomSendEventRestServlet(hs).register(http_server) PublicRoomListRestServlet(hs).register(http_server) RoomStateRestServlet(hs).register(http_server) - RoomInitialSyncRestServlet(hs).register(http_server) RoomRedactEventRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server) SearchRestServlet(hs).register(http_server) JoinedRoomsRestServlet(hs).register(http_server) - RoomEventContext(hs).register(http_server) + RoomEventServlet(hs).register(http_server) + RoomEventContextServlet(hs).register(http_server) + + +def register_deprecated_servlets(hs, http_server): + RoomInitialSyncRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index c43b30b73a..62f4c3d93e 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -13,16 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import base64 +import hashlib +import hmac + from twisted.internet import defer from .base import ClientV1RestServlet, client_path_patterns -import hmac -import hashlib -import base64 - - class VoipRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/voip/turnServer$") diff --git a/synapse/rest/client/v1_only/__init__.py b/synapse/rest/client/v1_only/__init__.py new file mode 100644 index 0000000000..936f902ace --- /dev/null +++ b/synapse/rest/client/v1_only/__init__.py @@ -0,0 +1,3 @@ +""" +REST APIs that are only used in v1 (the legacy API). +""" diff --git a/synapse/rest/client/v1_only/base.py b/synapse/rest/client/v1_only/base.py new file mode 100644 index 0000000000..9d4db7437c --- /dev/null +++ b/synapse/rest/client/v1_only/base.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains base REST classes for constructing client v1 servlets. +""" + +import re + +from synapse.api.urls import CLIENT_PREFIX + + +def v1_only_client_path_patterns(path_regex, include_in_unstable=True): + """Creates a regex compiled client path with the correct client path + prefix. + + Args: + path_regex (str): The regex string to match. This should NOT have a ^ + as this will be prefixed. + Returns: + list of SRE_Pattern + """ + patterns = [re.compile("^" + CLIENT_PREFIX + path_regex)] + if include_in_unstable: + unstable_prefix = CLIENT_PREFIX.replace("/api/v1", "/unstable") + patterns.append(re.compile("^" + unstable_prefix + path_regex)) + return patterns diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1_only/register.py index ecf7e311a9..3439c3c6d4 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1_only/register.py @@ -14,21 +14,20 @@ # limitations under the License. """This module contains REST servlets to do with registration: /register""" +import hmac +import logging +from hashlib import sha1 + from twisted.internet import defer -from synapse.api.errors import SynapseError, Codes -from synapse.api.constants import LoginType -from synapse.api.auth import get_access_token_from_request -from .base import ClientV1RestServlet, client_path_patterns import synapse.util.stringutils as stringutils -from synapse.http.servlet import parse_json_object_from_request +from synapse.api.constants import LoginType +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request +from synapse.rest.client.v1.base import ClientV1RestServlet from synapse.types import create_requester -from synapse.util.async import run_on_reactor - -from hashlib import sha1 -import hmac -import logging +from .base import v1_only_client_path_patterns logger = logging.getLogger(__name__) @@ -51,7 +50,7 @@ class RegisterRestServlet(ClientV1RestServlet): handler doesn't have a concept of multi-stages or sessions. """ - PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False) + PATTERNS = v1_only_client_path_patterns("/register$", include_in_unstable=False) def __init__(self, hs): """ @@ -66,14 +65,20 @@ class RegisterRestServlet(ClientV1RestServlet): # TODO: persistent storage self.sessions = {} self.enable_registration = hs.config.enable_registration + self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self.handlers = hs.get_handlers() def on_GET(self, request): + + require_email = 'email' in self.hs.config.registrations_require_3pid + require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid + + flows = [] if self.hs.config.enable_registration_captcha: - return ( - 200, - {"flows": [ + # only support the email-only flow if we don't require MSISDN 3PIDs + if not require_msisdn: + flows.extend([ { "type": LoginType.RECAPTCHA, "stages": [ @@ -82,27 +87,34 @@ class RegisterRestServlet(ClientV1RestServlet): LoginType.PASSWORD ] }, + ]) + # only support 3PIDless registration if no 3PIDs are required + if not require_email and not require_msisdn: + flows.extend([ { "type": LoginType.RECAPTCHA, "stages": [LoginType.RECAPTCHA, LoginType.PASSWORD] } - ]} - ) + ]) else: - return ( - 200, - {"flows": [ + # only support the email-only flow if we don't require MSISDN 3PIDs + if require_email or not require_msisdn: + flows.extend([ { "type": LoginType.EMAIL_IDENTITY, "stages": [ LoginType.EMAIL_IDENTITY, LoginType.PASSWORD ] - }, + } + ]) + # only support 3PIDless registration if no 3PIDs are required + if not require_email and not require_msisdn: + flows.extend([ { "type": LoginType.PASSWORD } - ]} - ) + ]) + return (200, {"flows": flows}) @defer.inlineCallbacks def on_POST(self, request): @@ -111,8 +123,7 @@ class RegisterRestServlet(ClientV1RestServlet): session = (register_json["session"] if "session" in register_json else None) login_type = None - if "type" not in register_json: - raise SynapseError(400, "Missing 'type' key.") + assert_params_in_dict(register_json, ["type"]) try: login_type = register_json["type"] @@ -258,7 +269,6 @@ class RegisterRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def _do_password(self, request, register_json, session): - yield run_on_reactor() if (self.hs.config.enable_registration_captcha and not session[LoginType.RECAPTCHA]): # captcha should've been done by this stage! @@ -298,11 +308,9 @@ class RegisterRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def _do_app_service(self, request, register_json, session): - as_token = get_access_token_from_request(request) - - if "user" not in register_json: - raise SynapseError(400, "Expected 'user' key.") + as_token = self.auth.get_access_token_from_request(request) + assert_params_in_dict(register_json, ["user"]) user_localpart = register_json["user"].encode("utf-8") handler = self.handlers.registration_handler @@ -319,14 +327,7 @@ class RegisterRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def _do_shared_secret(self, request, register_json, session): - yield run_on_reactor() - - if not isinstance(register_json.get("mac", None), basestring): - raise SynapseError(400, "Expected mac.") - if not isinstance(register_json.get("user", None), basestring): - raise SynapseError(400, "Expected 'user' key.") - if not isinstance(register_json.get("password", None), basestring): - raise SynapseError(400, "Expected 'password' key.") + assert_params_in_dict(register_json, ["mac", "user", "password"]) if not self.hs.config.registration_shared_secret: raise SynapseError(400, "Shared secret registration is not enabled") @@ -336,9 +337,9 @@ class RegisterRestServlet(ClientV1RestServlet): admin = register_json.get("admin", None) # Its important to check as we use null bytes as HMAC field separators - if "\x00" in user: + if b"\x00" in user: raise SynapseError(400, "Invalid user") - if "\x00" in password: + if b"\x00" in password: raise SynapseError(400, "Invalid password") # str() because otherwise hmac complains that 'unicode' does not @@ -346,20 +347,20 @@ class RegisterRestServlet(ClientV1RestServlet): got_mac = str(register_json["mac"]) want_mac = hmac.new( - key=self.hs.config.registration_shared_secret, + key=self.hs.config.registration_shared_secret.encode(), digestmod=sha1, ) want_mac.update(user) - want_mac.update("\x00") + want_mac.update(b"\x00") want_mac.update(password) - want_mac.update("\x00") - want_mac.update("admin" if admin else "notadmin") + want_mac.update(b"\x00") + want_mac.update(b"admin" if admin else b"notadmin") want_mac = want_mac.hexdigest() if compare_digest(want_mac, got_mac): handler = self.handlers.registration_handler user_id, token = yield handler.register( - localpart=user, + localpart=user.lower(), password=password, admin=bool(admin), ) @@ -379,7 +380,7 @@ class CreateUserRestServlet(ClientV1RestServlet): """Handles user creation via a server-to-server interface """ - PATTERNS = client_path_patterns("/createUser$", releases=()) + PATTERNS = v1_only_client_path_patterns("/createUser$") def __init__(self, hs): super(CreateUserRestServlet, self).__init__(hs) @@ -390,7 +391,7 @@ class CreateUserRestServlet(ClientV1RestServlet): def on_POST(self, request): user_json = parse_json_object_from_request(request) - access_token = get_access_token_from_request(request) + access_token = self.auth.get_access_token_from_request(request) app_service = self.store.get_app_service_by_token( access_token ) @@ -409,13 +410,7 @@ class CreateUserRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def _do_create(self, requester, user_json): - yield run_on_reactor() - - if "localpart" not in user_json: - raise SynapseError(400, "Expected 'localpart' key.") - - if "displayname" not in user_json: - raise SynapseError(400, "Expected 'displayname' key.") + assert_params_in_dict(user_json, ["localpart", "displayname"]) localpart = user_json["localpart"].encode("utf-8") displayname = user_json["displayname"].encode("utf-8") diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 1f5bc24cc3..77434937ff 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -15,12 +15,13 @@ """This module contains base REST classes for constructing client v1 servlets. """ - -from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX +import logging import re -import logging +from twisted.internet import defer +from synapse.api.errors import InteractiveAuthIncompleteError +from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX logger = logging.getLogger(__name__) @@ -57,3 +58,37 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit): filter_json['room']['timeline']["limit"] = min( filter_json['room']['timeline']['limit'], filter_timeline_limit) + + +def interactive_auth_handler(orig): + """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors + + Takes a on_POST method which returns a deferred (errcode, body) response + and adds exception handling to turn a InteractiveAuthIncompleteError into + a 401 response. + + Normal usage is: + + @interactive_auth_handler + @defer.inlineCallbacks + def on_POST(self, request): + # ... + yield self.auth_handler.check_auth + """ + def wrapped(*args, **kwargs): + res = defer.maybeDeferred(orig, *args, **kwargs) + res.addErrback(_catch_incomplete_interactive_auth) + return res + return wrapped + + +def _catch_incomplete_interactive_auth(f): + """helper for interactive_auth_handler + + Catches InteractiveAuthIncompleteErrors and turns them into 401 responses + + Args: + f (failure.Failure): + """ + f.trap(InteractiveAuthIncompleteError) + return 401, f.value.result diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 4990b22b9f..eeae466d82 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,21 +14,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging + +from six.moves import http_client from twisted.internet import defer from synapse.api.constants import LoginType -from synapse.api.errors import LoginError, SynapseError, Codes +from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import ( - RestServlet, parse_json_object_from_request, assert_params_in_request + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, ) -from synapse.util.async import run_on_reactor from synapse.util.msisdn import phone_number_to_msisdn +from synapse.util.threepids import check_3pid_allowed -from ._base import client_v2_patterns - -import logging - +from ._base import client_v2_patterns, interactive_auth_handler logger = logging.getLogger(__name__) @@ -44,10 +47,15 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_request(body, [ + assert_params_in_dict(body, [ 'id_server', 'client_secret', 'email', 'send_attempt' ]) + if not check_3pid_allowed(self.hs, "email", body['email']): + raise SynapseError( + 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + ) + existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( 'email', body['email'] ) @@ -72,13 +80,18 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_request(body, [ + assert_params_in_dict(body, [ 'id_server', 'client_secret', 'country', 'phone_number', 'send_attempt', ]) msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) + if not check_3pid_allowed(self.hs, "msisdn", msisdn): + raise SynapseError( + 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + ) + existingUid = yield self.datastore.get_user_id_by_threepid( 'msisdn', msisdn ) @@ -99,56 +112,60 @@ class PasswordRestServlet(RestServlet): self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self.datastore = self.hs.get_datastore() + self._set_password_handler = hs.get_set_password_handler() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): - yield run_on_reactor() - body = parse_json_object_from_request(request) - authed, result, params, _ = yield self.auth_handler.check_auth([ - [LoginType.PASSWORD], - [LoginType.EMAIL_IDENTITY], - [LoginType.MSISDN], - ], body, self.hs.get_ip_from_request(request)) - - if not authed: - defer.returnValue((401, result)) - - user_id = None - requester = None - - if LoginType.PASSWORD in result: - # if using password, they should also be logged in + # there are two possibilities here. Either the user does not have an + # access token, and needs to do a password reset; or they have one and + # need to validate their identity. + # + # In the first case, we offer a couple of means of identifying + # themselves (email and msisdn, though it's unclear if msisdn actually + # works). + # + # In the second case, we require a password to confirm their identity. + + if self.auth.has_access_token(request): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - if user_id != result[LoginType.PASSWORD]: - raise LoginError(400, "", Codes.UNKNOWN) - elif LoginType.EMAIL_IDENTITY in result: - threepid = result[LoginType.EMAIL_IDENTITY] - if 'medium' not in threepid or 'address' not in threepid: - raise SynapseError(500, "Malformed threepid") - if threepid['medium'] == 'email': - # For emails, transform the address to lowercase. - # We store all email addreses as lowercase in the DB. - # (See add_threepid in synapse/handlers/auth.py) - threepid['address'] = threepid['address'].lower() - # if using email, we must know about the email they're authing with! - threepid_user_id = yield self.datastore.get_user_id_by_threepid( - threepid['medium'], threepid['address'] + params = yield self.auth_handler.validate_user_via_ui_auth( + requester, body, self.hs.get_ip_from_request(request), ) - if not threepid_user_id: - raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) - user_id = threepid_user_id + user_id = requester.user.to_string() else: - logger.error("Auth succeeded but no known type!", result.keys()) - raise SynapseError(500, "", Codes.UNKNOWN) + requester = None + result, params, _ = yield self.auth_handler.check_auth( + [[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]], + body, self.hs.get_ip_from_request(request), + ) - if 'new_password' not in params: - raise SynapseError(400, "", Codes.MISSING_PARAM) + if LoginType.EMAIL_IDENTITY in result: + threepid = result[LoginType.EMAIL_IDENTITY] + if 'medium' not in threepid or 'address' not in threepid: + raise SynapseError(500, "Malformed threepid") + if threepid['medium'] == 'email': + # For emails, transform the address to lowercase. + # We store all email addreses as lowercase in the DB. + # (See add_threepid in synapse/handlers/auth.py) + threepid['address'] = threepid['address'].lower() + # if using email, we must know about the email they're authing with! + threepid_user_id = yield self.datastore.get_user_id_by_threepid( + threepid['medium'], threepid['address'] + ) + if not threepid_user_id: + raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) + user_id = threepid_user_id + else: + logger.error("Auth succeeded but no known type! %r", result.keys()) + raise SynapseError(500, "", Codes.UNKNOWN) + + assert_params_in_dict(params, ["new_password"]) new_password = params['new_password'] - yield self.auth_handler.set_password( + yield self._set_password_handler.set_password( user_id, new_password, requester ) @@ -162,42 +179,39 @@ class DeactivateAccountRestServlet(RestServlet): PATTERNS = client_v2_patterns("/account/deactivate$") def __init__(self, hs): + super(DeactivateAccountRestServlet, self).__init__() self.hs = hs - self.store = hs.get_datastore() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - super(DeactivateAccountRestServlet, self).__init__() + self._deactivate_account_handler = hs.get_deactivate_account_handler() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) + erase = body.get("erase", False) + if not isinstance(erase, bool): + raise SynapseError( + http_client.BAD_REQUEST, + "Param 'erase' must be a boolean, if given", + Codes.BAD_JSON, + ) - authed, result, params, _ = yield self.auth_handler.check_auth([ - [LoginType.PASSWORD], - ], body, self.hs.get_ip_from_request(request)) - - if not authed: - defer.returnValue((401, result)) - - user_id = None - requester = None - - if LoginType.PASSWORD in result: - # if using password, they should also be logged in - requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - if user_id != result[LoginType.PASSWORD]: - raise LoginError(400, "", Codes.UNKNOWN) - else: - logger.error("Auth succeeded but no known type!", result.keys()) - raise SynapseError(500, "", Codes.UNKNOWN) + requester = yield self.auth.get_user_by_req(request) - # FIXME: Theoretically there is a race here wherein user resets password - # using threepid. - yield self.store.user_delete_access_tokens(user_id) - yield self.store.user_delete_threepids(user_id) - yield self.store.user_set_password_hash(user_id, None) + # allow ASes to dectivate their own users + if requester.app_service: + yield self._deactivate_account_handler.deactivate_account( + requester.user.to_string(), erase, + ) + defer.returnValue((200, {})) + yield self.auth_handler.validate_user_via_ui_auth( + requester, body, self.hs.get_ip_from_request(request), + ) + yield self._deactivate_account_handler.deactivate_account( + requester.user.to_string(), erase, + ) defer.returnValue((200, {})) @@ -213,15 +227,15 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) + assert_params_in_dict( + body, + ['id_server', 'client_secret', 'email', 'send_attempt'], + ) - required = ['id_server', 'client_secret', 'email', 'send_attempt'] - absent = [] - for k in required: - if k not in body: - absent.append(k) - - if absent: - raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + if not check_3pid_allowed(self.hs, "email", body['email']): + raise SynapseError( + 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + ) existingUid = yield self.datastore.get_user_id_by_threepid( 'email', body['email'] @@ -246,21 +260,18 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) - - required = [ + assert_params_in_dict(body, [ 'id_server', 'client_secret', 'country', 'phone_number', 'send_attempt', - ] - absent = [] - for k in required: - if k not in body: - absent.append(k) - - if absent: - raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + ]) msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) + if not check_3pid_allowed(self.hs, "msisdn", msisdn): + raise SynapseError( + 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + ) + existingUid = yield self.datastore.get_user_id_by_threepid( 'msisdn', msisdn ) @@ -285,8 +296,6 @@ class ThreepidRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): - yield run_on_reactor() - requester = yield self.auth.get_user_by_req(request) threepids = yield self.datastore.user_get_threepids( @@ -297,8 +306,6 @@ class ThreepidRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): - yield run_on_reactor() - body = parse_json_object_from_request(request) threePidCreds = body.get('threePidCreds') @@ -350,29 +357,40 @@ class ThreepidDeleteRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): - yield run_on_reactor() - body = parse_json_object_from_request(request) - - required = ['medium', 'address'] - absent = [] - for k in required: - if k not in body: - absent.append(k) - - if absent: - raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + assert_params_in_dict(body, ['medium', 'address']) requester = yield self.auth.get_user_by_req(request) user_id = requester.user.to_string() - yield self.auth_handler.delete_threepid( - user_id, body['medium'], body['address'] - ) + try: + yield self.auth_handler.delete_threepid( + user_id, body['medium'], body['address'] + ) + except Exception: + # NB. This endpoint should succeed if there is nothing to + # delete, so it should only throw if something is wrong + # that we ought to care about. + logger.exception("Failed to remove threepid") + raise SynapseError(500, "Failed to remove threepid") defer.returnValue((200, {})) +class WhoamiRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/account/whoami$") + + def __init__(self, hs): + super(WhoamiRestServlet, self).__init__() + self.auth = hs.get_auth() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request) + + defer.returnValue((200, {'user_id': requester.user.to_string()})) + + def register_servlets(hs, http_server): EmailPasswordRequestTokenRestServlet(hs).register(http_server) MsisdnPasswordRequestTokenRestServlet(hs).register(http_server) @@ -382,3 +400,4 @@ def register_servlets(hs, http_server): MsisdnThreepidRequestTokenRestServlet(hs).register(http_server) ThreepidRestServlet(hs).register(http_server) ThreepidDeleteRestServlet(hs).register(http_server) + WhoamiRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index 0e0a187efd..371e9aa354 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import client_v2_patterns - -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.api.errors import AuthError, SynapseError +import logging from twisted.internet import defer -import logging +from synapse.api.errors import AuthError, SynapseError +from synapse.http.servlet import RestServlet, parse_json_object_from_request + +from ._base import client_v2_patterns logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 8e5577148f..bd8b5f4afa 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer from synapse.api.constants import LoginType @@ -23,9 +25,6 @@ from synapse.http.servlet import RestServlet from ._base import client_v2_patterns -import logging - - logger = logging.getLogger(__name__) RECAPTCHA_TEMPLATE = """ @@ -129,7 +128,6 @@ class AuthRestServlet(RestServlet): html_bytes = html.encode("utf8") request.setResponseCode(200) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Server", self.hs.version_string) request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) request.write(html_bytes) @@ -175,7 +173,6 @@ class AuthRestServlet(RestServlet): html_bytes = html.encode("utf8") request.setResponseCode(200) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Server", self.hs.version_string) request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) request.write(html_bytes) diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index b57ba95d24..9b75bb1377 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -17,15 +17,20 @@ import logging from twisted.internet import defer -from synapse.api import constants, errors -from synapse.http import servlet -from ._base import client_v2_patterns +from synapse.api import errors +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) + +from ._base import client_v2_patterns, interactive_auth_handler logger = logging.getLogger(__name__) -class DevicesRestServlet(servlet.RestServlet): - PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False) +class DevicesRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/devices$", v2_alpha=False) def __init__(self, hs): """ @@ -46,12 +51,12 @@ class DevicesRestServlet(servlet.RestServlet): defer.returnValue((200, {"devices": devices})) -class DeleteDevicesRestServlet(servlet.RestServlet): +class DeleteDevicesRestServlet(RestServlet): """ API for bulk deletion of devices. Accepts a JSON object with a devices key which lists the device_ids to delete. Requires user interactive auth. """ - PATTERNS = client_v2_patterns("/delete_devices", releases=[], v2_alpha=False) + PATTERNS = client_v2_patterns("/delete_devices", v2_alpha=False) def __init__(self, hs): super(DeleteDevicesRestServlet, self).__init__() @@ -60,31 +65,28 @@ class DeleteDevicesRestServlet(servlet.RestServlet): self.device_handler = hs.get_device_handler() self.auth_handler = hs.get_auth_handler() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): + requester = yield self.auth.get_user_by_req(request) + try: - body = servlet.parse_json_object_from_request(request) + body = parse_json_object_from_request(request) except errors.SynapseError as e: if e.errcode == errors.Codes.NOT_JSON: - # deal with older clients which didn't pass a J*DELETESON dict + # DELETE + # deal with older clients which didn't pass a JSON dict # the same as those that pass an empty dict body = {} else: raise e - if 'devices' not in body: - raise errors.SynapseError( - 400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM - ) + assert_params_in_dict(body, ["devices"]) - authed, result, params, _ = yield self.auth_handler.check_auth([ - [constants.LoginType.PASSWORD], - ], body, self.hs.get_ip_from_request(request)) - - if not authed: - defer.returnValue((401, result)) + yield self.auth_handler.validate_user_via_ui_auth( + requester, body, self.hs.get_ip_from_request(request), + ) - requester = yield self.auth.get_user_by_req(request) yield self.device_handler.delete_devices( requester.user.to_string(), body['devices'], @@ -92,9 +94,8 @@ class DeleteDevicesRestServlet(servlet.RestServlet): defer.returnValue((200, {})) -class DeviceRestServlet(servlet.RestServlet): - PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", - releases=[], v2_alpha=False) +class DeviceRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False) def __init__(self, hs): """ @@ -116,10 +117,13 @@ class DeviceRestServlet(servlet.RestServlet): ) defer.returnValue((200, device)) + @interactive_auth_handler @defer.inlineCallbacks def on_DELETE(self, request, device_id): + requester = yield self.auth.get_user_by_req(request) + try: - body = servlet.parse_json_object_from_request(request) + body = parse_json_object_from_request(request) except errors.SynapseError as e: if e.errcode == errors.Codes.NOT_JSON: @@ -129,17 +133,12 @@ class DeviceRestServlet(servlet.RestServlet): else: raise - authed, result, params, _ = yield self.auth_handler.check_auth([ - [constants.LoginType.PASSWORD], - ], body, self.hs.get_ip_from_request(request)) - - if not authed: - defer.returnValue((401, result)) + yield self.auth_handler.validate_user_via_ui_auth( + requester, body, self.hs.get_ip_from_request(request), + ) - requester = yield self.auth.get_user_by_req(request) yield self.device_handler.delete_device( - requester.user.to_string(), - device_id, + requester.user.to_string(), device_id, ) defer.returnValue((200, {})) @@ -147,7 +146,7 @@ class DeviceRestServlet(servlet.RestServlet): def on_PUT(self, request, device_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) - body = servlet.parse_json_object_from_request(request) + body = parse_json_object_from_request(request) yield self.device_handler.update_device( requester.user.to_string(), device_id, diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index d2b2fd66e6..ae86728879 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -13,17 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer -from synapse.api.errors import AuthError, SynapseError, StoreError, Codes +from synapse.api.errors import AuthError, Codes, StoreError, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import UserID -from ._base import client_v2_patterns -from ._base import set_timeline_upper_limit - -import logging - +from ._base import client_v2_patterns, set_timeline_upper_limit logger = logging.getLogger(__name__) @@ -50,7 +48,7 @@ class GetFilterRestServlet(RestServlet): try: filter_id = int(filter_id) - except: + except Exception: raise SynapseError(400, "Invalid filter_id") try: diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py new file mode 100644 index 0000000000..21e02c07c0 --- /dev/null +++ b/synapse/rest/client/v2_alpha/groups.py @@ -0,0 +1,786 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.types import GroupID + +from ._base import client_v2_patterns + +logger = logging.getLogger(__name__) + + +class GroupServlet(RestServlet): + """Get the group profile + """ + PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/profile$") + + def __init__(self, hs): + super(GroupServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + group_description = yield self.groups_handler.get_group_profile( + group_id, + requester_user_id, + ) + + defer.returnValue((200, group_description)) + + @defer.inlineCallbacks + def on_POST(self, request, group_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + yield self.groups_handler.update_group_profile( + group_id, requester_user_id, content, + ) + + defer.returnValue((200, {})) + + +class GroupSummaryServlet(RestServlet): + """Get the full group summary + """ + PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/summary$") + + def __init__(self, hs): + super(GroupSummaryServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + get_group_summary = yield self.groups_handler.get_group_summary( + group_id, + requester_user_id, + ) + + defer.returnValue((200, get_group_summary)) + + +class GroupSummaryRoomsCatServlet(RestServlet): + """Update/delete a rooms entry in the summary. + + Matches both: + - /groups/:group/summary/rooms/:room_id + - /groups/:group/summary/categories/:category/rooms/:room_id + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/summary" + "(/categories/(?P<category_id>[^/]+))?" + "/rooms/(?P<room_id>[^/]*)$" + ) + + def __init__(self, hs): + super(GroupSummaryRoomsCatServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id, category_id, room_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + resp = yield self.groups_handler.update_group_summary_room( + group_id, requester_user_id, + room_id=room_id, + category_id=category_id, + content=content, + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_DELETE(self, request, group_id, category_id, room_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + resp = yield self.groups_handler.delete_group_summary_room( + group_id, requester_user_id, + room_id=room_id, + category_id=category_id, + ) + + defer.returnValue((200, resp)) + + +class GroupCategoryServlet(RestServlet): + """Get/add/update/delete a group category + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$" + ) + + def __init__(self, hs): + super(GroupCategoryServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id, category_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + category = yield self.groups_handler.get_group_category( + group_id, requester_user_id, + category_id=category_id, + ) + + defer.returnValue((200, category)) + + @defer.inlineCallbacks + def on_PUT(self, request, group_id, category_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + resp = yield self.groups_handler.update_group_category( + group_id, requester_user_id, + category_id=category_id, + content=content, + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_DELETE(self, request, group_id, category_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + resp = yield self.groups_handler.delete_group_category( + group_id, requester_user_id, + category_id=category_id, + ) + + defer.returnValue((200, resp)) + + +class GroupCategoriesServlet(RestServlet): + """Get all group categories + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/categories/$" + ) + + def __init__(self, hs): + super(GroupCategoriesServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + category = yield self.groups_handler.get_group_categories( + group_id, requester_user_id, + ) + + defer.returnValue((200, category)) + + +class GroupRoleServlet(RestServlet): + """Get/add/update/delete a group role + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$" + ) + + def __init__(self, hs): + super(GroupRoleServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id, role_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + category = yield self.groups_handler.get_group_role( + group_id, requester_user_id, + role_id=role_id, + ) + + defer.returnValue((200, category)) + + @defer.inlineCallbacks + def on_PUT(self, request, group_id, role_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + resp = yield self.groups_handler.update_group_role( + group_id, requester_user_id, + role_id=role_id, + content=content, + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_DELETE(self, request, group_id, role_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + resp = yield self.groups_handler.delete_group_role( + group_id, requester_user_id, + role_id=role_id, + ) + + defer.returnValue((200, resp)) + + +class GroupRolesServlet(RestServlet): + """Get all group roles + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/roles/$" + ) + + def __init__(self, hs): + super(GroupRolesServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + category = yield self.groups_handler.get_group_roles( + group_id, requester_user_id, + ) + + defer.returnValue((200, category)) + + +class GroupSummaryUsersRoleServlet(RestServlet): + """Update/delete a user's entry in the summary. + + Matches both: + - /groups/:group/summary/users/:room_id + - /groups/:group/summary/roles/:role/users/:user_id + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/summary" + "(/roles/(?P<role_id>[^/]+))?" + "/users/(?P<user_id>[^/]*)$" + ) + + def __init__(self, hs): + super(GroupSummaryUsersRoleServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id, role_id, user_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + resp = yield self.groups_handler.update_group_summary_user( + group_id, requester_user_id, + user_id=user_id, + role_id=role_id, + content=content, + ) + + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def on_DELETE(self, request, group_id, role_id, user_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + resp = yield self.groups_handler.delete_group_summary_user( + group_id, requester_user_id, + user_id=user_id, + role_id=role_id, + ) + + defer.returnValue((200, resp)) + + +class GroupRoomServlet(RestServlet): + """Get all rooms in a group + """ + PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/rooms$") + + def __init__(self, hs): + super(GroupRoomServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id) + + defer.returnValue((200, result)) + + +class GroupUsersServlet(RestServlet): + """Get all users in a group + """ + PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/users$") + + def __init__(self, hs): + super(GroupUsersServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id) + + defer.returnValue((200, result)) + + +class GroupInvitedUsersServlet(RestServlet): + """Get users invited to a group + """ + PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/invited_users$") + + def __init__(self, hs): + super(GroupInvitedUsersServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + result = yield self.groups_handler.get_invited_users_in_group( + group_id, + requester_user_id, + ) + + defer.returnValue((200, result)) + + +class GroupSettingJoinPolicyServlet(RestServlet): + """Set group join policy + """ + PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$") + + def __init__(self, hs): + super(GroupSettingJoinPolicyServlet, self).__init__() + self.auth = hs.get_auth() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + + result = yield self.groups_handler.set_group_join_policy( + group_id, + requester_user_id, + content, + ) + + defer.returnValue((200, result)) + + +class GroupCreateServlet(RestServlet): + """Create a group + """ + PATTERNS = client_v2_patterns("/create_group$") + + def __init__(self, hs): + super(GroupCreateServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + self.server_name = hs.hostname + + @defer.inlineCallbacks + def on_POST(self, request): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + # TODO: Create group on remote server + content = parse_json_object_from_request(request) + localpart = content.pop("localpart") + group_id = GroupID(localpart, self.server_name).to_string() + + result = yield self.groups_handler.create_group( + group_id, + requester_user_id, + content, + ) + + defer.returnValue((200, result)) + + +class GroupAdminRoomsServlet(RestServlet): + """Add a room to the group + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$" + ) + + def __init__(self, hs): + super(GroupAdminRoomsServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id, room_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + result = yield self.groups_handler.add_room_to_group( + group_id, requester_user_id, room_id, content, + ) + + defer.returnValue((200, result)) + + @defer.inlineCallbacks + def on_DELETE(self, request, group_id, room_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + result = yield self.groups_handler.remove_room_from_group( + group_id, requester_user_id, room_id, + ) + + defer.returnValue((200, result)) + + +class GroupAdminRoomsConfigServlet(RestServlet): + """Update the config of a room in a group + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)" + "/config/(?P<config_key>[^/]*)$" + ) + + def __init__(self, hs): + super(GroupAdminRoomsConfigServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id, room_id, config_key): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + result = yield self.groups_handler.update_room_in_group( + group_id, requester_user_id, room_id, config_key, content, + ) + + defer.returnValue((200, result)) + + +class GroupAdminUsersInviteServlet(RestServlet): + """Invite a user to the group + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$" + ) + + def __init__(self, hs): + super(GroupAdminUsersInviteServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + self.store = hs.get_datastore() + self.is_mine_id = hs.is_mine_id + + @defer.inlineCallbacks + def on_PUT(self, request, group_id, user_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + config = content.get("config", {}) + result = yield self.groups_handler.invite( + group_id, user_id, requester_user_id, config, + ) + + defer.returnValue((200, result)) + + +class GroupAdminUsersKickServlet(RestServlet): + """Kick a user from the group + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$" + ) + + def __init__(self, hs): + super(GroupAdminUsersKickServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id, user_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + result = yield self.groups_handler.remove_user_from_group( + group_id, user_id, requester_user_id, content, + ) + + defer.returnValue((200, result)) + + +class GroupSelfLeaveServlet(RestServlet): + """Leave a joined group + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/self/leave$" + ) + + def __init__(self, hs): + super(GroupSelfLeaveServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + result = yield self.groups_handler.remove_user_from_group( + group_id, requester_user_id, requester_user_id, content, + ) + + defer.returnValue((200, result)) + + +class GroupSelfJoinServlet(RestServlet): + """Attempt to join a group, or knock + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/self/join$" + ) + + def __init__(self, hs): + super(GroupSelfJoinServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + result = yield self.groups_handler.join_group( + group_id, requester_user_id, content, + ) + + defer.returnValue((200, result)) + + +class GroupSelfAcceptInviteServlet(RestServlet): + """Accept a group invite + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/self/accept_invite$" + ) + + def __init__(self, hs): + super(GroupSelfAcceptInviteServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + result = yield self.groups_handler.accept_invite( + group_id, requester_user_id, content, + ) + + defer.returnValue((200, result)) + + +class GroupSelfUpdatePublicityServlet(RestServlet): + """Update whether we publicise a users membership of a group + """ + PATTERNS = client_v2_patterns( + "/groups/(?P<group_id>[^/]*)/self/update_publicity$" + ) + + def __init__(self, hs): + super(GroupSelfUpdatePublicityServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + publicise = content["publicise"] + yield self.store.update_group_publicity( + group_id, requester_user_id, publicise, + ) + + defer.returnValue((200, {})) + + +class PublicisedGroupsForUserServlet(RestServlet): + """Get the list of groups a user is advertising + """ + PATTERNS = client_v2_patterns( + "/publicised_groups/(?P<user_id>[^/]*)$" + ) + + def __init__(self, hs): + super(PublicisedGroupsForUserServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.store = hs.get_datastore() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, user_id): + yield self.auth.get_user_by_req(request, allow_guest=True) + + result = yield self.groups_handler.get_publicised_groups_for_user( + user_id + ) + + defer.returnValue((200, result)) + + +class PublicisedGroupsForUsersServlet(RestServlet): + """Get the list of groups a user is advertising + """ + PATTERNS = client_v2_patterns( + "/publicised_groups$" + ) + + def __init__(self, hs): + super(PublicisedGroupsForUsersServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.store = hs.get_datastore() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_POST(self, request): + yield self.auth.get_user_by_req(request, allow_guest=True) + + content = parse_json_object_from_request(request) + user_ids = content["user_ids"] + + result = yield self.groups_handler.bulk_get_publicised_groups( + user_ids + ) + + defer.returnValue((200, result)) + + +class GroupsForUserServlet(RestServlet): + """Get all groups the logged in user is joined to + """ + PATTERNS = client_v2_patterns( + "/joined_groups$" + ) + + def __init__(self, hs): + super(GroupsForUserServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + result = yield self.groups_handler.get_joined_groups(requester_user_id) + + defer.returnValue((200, result)) + + +def register_servlets(hs, http_server): + GroupServlet(hs).register(http_server) + GroupSummaryServlet(hs).register(http_server) + GroupInvitedUsersServlet(hs).register(http_server) + GroupUsersServlet(hs).register(http_server) + GroupRoomServlet(hs).register(http_server) + GroupSettingJoinPolicyServlet(hs).register(http_server) + GroupCreateServlet(hs).register(http_server) + GroupAdminRoomsServlet(hs).register(http_server) + GroupAdminRoomsConfigServlet(hs).register(http_server) + GroupAdminUsersInviteServlet(hs).register(http_server) + GroupAdminUsersKickServlet(hs).register(http_server) + GroupSelfLeaveServlet(hs).register(http_server) + GroupSelfJoinServlet(hs).register(http_server) + GroupSelfAcceptInviteServlet(hs).register(http_server) + GroupsForUserServlet(hs).register(http_server) + GroupCategoryServlet(hs).register(http_server) + GroupCategoriesServlet(hs).register(http_server) + GroupSummaryRoomsCatServlet(hs).register(http_server) + GroupRoleServlet(hs).register(http_server) + GroupRolesServlet(hs).register(http_server) + GroupSelfUpdatePublicityServlet(hs).register(http_server) + GroupSummaryUsersRoleServlet(hs).register(http_server) + PublicisedGroupsForUserServlet(hs).register(http_server) + PublicisedGroupsForUsersServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 6a3cfe84f8..8486086b51 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -19,10 +19,13 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.http.servlet import ( - RestServlet, parse_json_object_from_request, parse_integer + RestServlet, + parse_integer, + parse_json_object_from_request, + parse_string, ) -from synapse.http.servlet import parse_string from synapse.types import StreamToken + from ._base import client_v2_patterns logger = logging.getLogger(__name__) @@ -53,8 +56,7 @@ class KeyUploadServlet(RestServlet): }, } """ - PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$", - releases=()) + PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") def __init__(self, hs): """ @@ -128,10 +130,7 @@ class KeyQueryServlet(RestServlet): } } } } } } """ - PATTERNS = client_v2_patterns( - "/keys/query$", - releases=() - ) + PATTERNS = client_v2_patterns("/keys/query$") def __init__(self, hs): """ @@ -160,10 +159,7 @@ class KeyChangesServlet(RestServlet): 200 OK { "changed": ["@foo:example.com"] } """ - PATTERNS = client_v2_patterns( - "/keys/changes$", - releases=() - ) + PATTERNS = client_v2_patterns("/keys/changes$") def __init__(self, hs): """ @@ -188,13 +184,11 @@ class KeyChangesServlet(RestServlet): user_id = requester.user.to_string() - changed = yield self.device_handler.get_user_ids_changed( + results = yield self.device_handler.get_user_ids_changed( user_id, from_token, ) - defer.returnValue((200, { - "changed": list(changed), - })) + defer.returnValue((200, results)) class OneTimeKeyServlet(RestServlet): @@ -215,10 +209,7 @@ class OneTimeKeyServlet(RestServlet): } } } } """ - PATTERNS = client_v2_patterns( - "/keys/claim$", - releases=() - ) + PATTERNS = client_v2_patterns("/keys/claim$") def __init__(self, hs): super(OneTimeKeyServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index fd2a3d69d4..2a6ea3df5f 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -13,24 +13,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer -from synapse.http.servlet import ( - RestServlet, parse_string, parse_integer -) from synapse.events.utils import ( - serialize_event, format_event_for_client_v2_without_room_id, + format_event_for_client_v2_without_room_id, + serialize_event, ) +from synapse.http.servlet import RestServlet, parse_integer, parse_string from ._base import client_v2_patterns -import logging - logger = logging.getLogger(__name__) class NotificationsServlet(RestServlet): - PATTERNS = client_v2_patterns("/notifications$", releases=()) + PATTERNS = client_v2_patterns("/notifications$") def __init__(self, hs): super(NotificationsServlet, self).__init__() @@ -88,7 +87,7 @@ class NotificationsServlet(RestServlet): pa["topological_ordering"], pa["stream_ordering"] ) returned_push_actions.append(returned_pa) - next_token = pa["stream_ordering"] + next_token = str(pa["stream_ordering"]) defer.returnValue((200, { "notifications": returned_push_actions, diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py index aa1cae8e1e..01c90aa2a3 100644 --- a/synapse/rest/client/v2_alpha/openid.py +++ b/synapse/rest/client/v2_alpha/openid.py @@ -14,15 +14,15 @@ # limitations under the License. -from ._base import client_v2_patterns +import logging + +from twisted.internet import defer -from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.api.errors import AuthError +from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.util.stringutils import random_string -from twisted.internet import defer - -import logging +from ._base import client_v2_patterns logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py index 2f8784fe06..a6e582a5ae 100644 --- a/synapse/rest/client/v2_alpha/read_marker.py +++ b/synapse/rest/client/v2_alpha/read_marker.py @@ -13,13 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer from synapse.http.servlet import RestServlet, parse_json_object_from_request -from ._base import client_v2_patterns - -import logging +from ._base import client_v2_patterns logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 1fbff2edd8..de370cac45 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet -from ._base import client_v2_patterns - -import logging +from ._base import client_v2_patterns logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 1421c18152..d6cf915d86 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -14,25 +14,29 @@ # See the License for the specific language governing permissions and # limitations under the License. +import hmac +import logging +from hashlib import sha1 + +from six import string_types + from twisted.internet import defer import synapse -from synapse.api.auth import get_access_token_from_request, has_access_token +import synapse.types from synapse.api.constants import LoginType -from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError +from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError from synapse.http.servlet import ( - RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, + parse_string, ) from synapse.util.msisdn import phone_number_to_msisdn - -from ._base import client_v2_patterns - -import logging -import hmac -from hashlib import sha1 -from synapse.util.async import run_on_reactor from synapse.util.ratelimitutils import FederationRateLimiter +from synapse.util.threepids import check_3pid_allowed +from ._base import client_v2_patterns, interactive_auth_handler # We ought to be using hmac.compare_digest() but on older pythons it doesn't # exist. It's a _really minor_ security flaw to use plain string comparison @@ -64,10 +68,15 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_request(body, [ + assert_params_in_dict(body, [ 'id_server', 'client_secret', 'email', 'send_attempt' ]) + if not check_3pid_allowed(self.hs, "email", body['email']): + raise SynapseError( + 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + ) + existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( 'email', body['email'] ) @@ -95,7 +104,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_request(body, [ + assert_params_in_dict(body, [ 'id_server', 'client_secret', 'country', 'phone_number', 'send_attempt', @@ -103,6 +112,11 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) + if not check_3pid_allowed(self.hs, "msisdn", msisdn): + raise SynapseError( + 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + ) + existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( 'msisdn', msisdn ) @@ -170,13 +184,13 @@ class RegisterRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_handlers().registration_handler self.identity_handler = hs.get_handlers().identity_handler + self.room_member_handler = hs.get_room_member_handler() self.device_handler = hs.get_device_handler() self.macaroon_gen = hs.get_macaroon_generator() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): - yield run_on_reactor() - body = parse_json_object_from_request(request) kind = "user" @@ -196,20 +210,20 @@ class RegisterRestServlet(RestServlet): # in sessions. Pull out the username/password provided to us. desired_password = None if 'password' in body: - if (not isinstance(body['password'], basestring) or + if (not isinstance(body['password'], string_types) or len(body['password']) > 512): raise SynapseError(400, "Invalid password") desired_password = body["password"] desired_username = None if 'username' in body: - if (not isinstance(body['username'], basestring) or + if (not isinstance(body['username'], string_types) or len(body['username']) > 512): raise SynapseError(400, "Invalid username") desired_username = body['username'] appservice = None - if has_access_token(request): + if self.auth.has_access_token(request): appservice = yield self.auth.get_appservice_by_req(request) # fork off as soon as possible for ASes and shared secret auth which @@ -221,15 +235,30 @@ class RegisterRestServlet(RestServlet): # 'user' key not 'username'). Since this is a new addition, we'll # fallback to 'username' if they gave one. desired_username = body.get("user", desired_username) - access_token = get_access_token_from_request(request) - if isinstance(desired_username, basestring): + # XXX we should check that desired_username is valid. Currently + # we give appservices carte blanche for any insanity in mxids, + # because the IRC bridges rely on being able to register stupid + # IDs. + + access_token = self.auth.get_access_token_from_request(request) + + if isinstance(desired_username, string_types): result = yield self._do_appservice_registration( desired_username, access_token, body ) defer.returnValue((200, result)) # we throw for non 200 responses return + # for either shared secret or regular registration, downcase the + # provided username before attempting to register it. This should mean + # that people who try to register with upper-case in their usernames + # don't get a nasty surprise. (Note that we treat username + # case-insenstively in login, so they are free to carry on imagining + # that their username is CrAzYh4cKeR if that keeps them happy) + if desired_username is not None: + desired_username = desired_username.lower() + # == Shared Secret Registration == (e.g. create new user scripts) if 'mac' in body: # FIXME: Should we really be determining if this is shared secret @@ -286,34 +315,66 @@ class RegisterRestServlet(RestServlet): if 'x_show_msisdn' in body and body['x_show_msisdn']: show_msisdn = True + # FIXME: need a better error than "no auth flow found" for scenarios + # where we required 3PID for registration but the user didn't give one + require_email = 'email' in self.hs.config.registrations_require_3pid + require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid + + flows = [] if self.hs.config.enable_registration_captcha: - flows = [ - [LoginType.RECAPTCHA], - [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA], - ] + # only support 3PIDless registration if no 3PIDs are required + if not require_email and not require_msisdn: + flows.extend([[LoginType.RECAPTCHA]]) + # only support the email-only flow if we don't require MSISDN 3PIDs + if not require_msisdn: + flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]]) + if show_msisdn: + # only support the MSISDN-only flow if we don't require email 3PIDs + if not require_email: + flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]]) + # always let users provide both MSISDN & email flows.extend([ - [LoginType.MSISDN, LoginType.RECAPTCHA], [LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA], ]) else: - flows = [ - [LoginType.DUMMY], - [LoginType.EMAIL_IDENTITY], - ] + # only support 3PIDless registration if no 3PIDs are required + if not require_email and not require_msisdn: + flows.extend([[LoginType.DUMMY]]) + # only support the email-only flow if we don't require MSISDN 3PIDs + if not require_msisdn: + flows.extend([[LoginType.EMAIL_IDENTITY]]) + if show_msisdn: + # only support the MSISDN-only flow if we don't require email 3PIDs + if not require_email or require_msisdn: + flows.extend([[LoginType.MSISDN]]) + # always let users provide both MSISDN & email flows.extend([ - [LoginType.MSISDN], - [LoginType.MSISDN, LoginType.EMAIL_IDENTITY], + [LoginType.MSISDN, LoginType.EMAIL_IDENTITY] ]) - authed, auth_result, params, session_id = yield self.auth_handler.check_auth( + auth_result, params, session_id = yield self.auth_handler.check_auth( flows, body, self.hs.get_ip_from_request(request) ) - if not authed: - defer.returnValue((401, auth_result)) - return + # Check that we're not trying to register a denied 3pid. + # + # the user-facing checks will probably already have happened in + # /register/email/requestToken when we requested a 3pid, but that's not + # guaranteed. + + if auth_result: + for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: + if login_type in auth_result: + medium = auth_result[login_type]['medium'] + address = auth_result[login_type]['address'] + + if not check_3pid_allowed(self.hs, medium, address): + raise SynapseError( + 403, "Third party identifier is not allowed", + Codes.THREEPID_DENIED, + ) if registered_user_id is not None: logger.info( @@ -325,14 +386,15 @@ class RegisterRestServlet(RestServlet): add_msisdn = False else: # NB: This may be from the auth handler and NOT from the POST - if 'password' not in params: - raise SynapseError(400, "Missing password.", - Codes.MISSING_PARAM) + assert_params_in_dict(params, ["password"]) desired_username = params.get("username", None) new_password = params.get("password", None) guest_access_token = params.get("guest_access_token", None) + if desired_username is not None: + desired_username = desired_username.lower() + (registered_user_id, _) = yield self.registration_handler.register( localpart=desired_username, password=new_password, @@ -383,15 +445,24 @@ class RegisterRestServlet(RestServlet): def _do_shared_secret_registration(self, username, password, body): if not self.hs.config.registration_shared_secret: raise SynapseError(400, "Shared secret registration is not enabled") + if not username: + raise SynapseError( + 400, "username must be specified", errcode=Codes.BAD_JSON, + ) - user = username.encode("utf-8") + # use the username from the original request rather than the + # downcased one in `username` for the mac calculation + user = body["username"].encode("utf-8") # str() because otherwise hmac complains that 'unicode' does not # have the buffer interface got_mac = str(body["mac"]) + # FIXME this is different to the /v1/register endpoint, which + # includes the password and admin flag in the hashed text. Why are + # these different? want_mac = hmac.new( - key=self.hs.config.registration_shared_secret, + key=self.hs.config.registration_shared_secret.encode(), msg=user, digestmod=sha1, ).hexdigest() @@ -492,11 +563,14 @@ class RegisterRestServlet(RestServlet): Returns: defer.Deferred: """ - reqd = ('medium', 'address', 'validated_at') - if any(x not in threepid for x in reqd): - # This will only happen if the ID server returns a malformed response - logger.info("Can't add incomplete 3pid") - defer.returnValue() + try: + assert_params_in_dict(threepid, ['medium', 'address', 'validated_at']) + except SynapseError as ex: + if ex.errcode == Codes.MISSING_PARAM: + # This will only happen if the ID server returns a malformed response + logger.info("Can't add incomplete 3pid") + defer.returnValue(None) + raise yield self.auth_handler.add_threepid( user_id, @@ -523,25 +597,28 @@ class RegisterRestServlet(RestServlet): Args: (str) user_id: full canonical @user:id (object) params: registration parameters, from which we pull - device_id and initial_device_name + device_id, initial_device_name and inhibit_login Returns: defer.Deferred: (object) dictionary for response from /register """ - device_id = yield self._register_device(user_id, params) + result = { + "user_id": user_id, + "home_server": self.hs.hostname, + } + if not params.get("inhibit_login", False): + device_id = yield self._register_device(user_id, params) - access_token = ( - yield self.auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id, - initial_display_name=params.get("initial_device_display_name") + access_token = ( + yield self.auth_handler.get_access_token_for_user_id( + user_id, device_id=device_id, + ) ) - ) - defer.returnValue({ - "user_id": user_id, - "access_token": access_token, - "home_server": self.hs.hostname, - "device_id": device_id, - }) + result.update({ + "access_token": access_token, + "device_id": device_id, + }) + defer.returnValue(result) def _register_device(self, user_id, params): """Register a device for a user. @@ -566,7 +643,7 @@ class RegisterRestServlet(RestServlet): @defer.inlineCallbacks def _do_guest_registration(self, params): if not self.hs.config.allow_guest_access: - defer.returnValue((403, "Guest access is disabled")) + raise SynapseError(403, "Guest access is disabled") user_id, _ = yield self.registration_handler.register( generate_token=False, make_guest=True diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index 8903e12405..95d2a71ec2 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -13,13 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +import logging -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from ._base import client_v2_patterns +from six import string_types +from six.moves import http_client -import logging +from twisted.internet import defer + +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) +from ._base import client_v2_patterns logger = logging.getLogger(__name__) @@ -42,12 +50,26 @@ class ReportEventRestServlet(RestServlet): user_id = requester.user.to_string() body = parse_json_object_from_request(request) + assert_params_in_dict(body, ("reason", "score")) + + if not isinstance(body["reason"], string_types): + raise SynapseError( + http_client.BAD_REQUEST, + "Param 'reason' must be a string", + Codes.BAD_JSON, + ) + if not isinstance(body["score"], int): + raise SynapseError( + http_client.BAD_REQUEST, + "Param 'score' must be an integer", + Codes.BAD_JSON, + ) yield self.store.add_event_report( room_id=room_id, event_id=event_id, user_id=user_id, - reason=body.get("reason"), + reason=body["reason"], content=body, received_ts=self.clock.time_msec(), ) diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index d607bd2970..a9e9a47a0b 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class SendToDeviceRestServlet(servlet.RestServlet): PATTERNS = client_v2_patterns( "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$", - releases=[], v2_alpha=False + v2_alpha=False ) def __init__(self, hs): @@ -40,7 +40,7 @@ class SendToDeviceRestServlet(servlet.RestServlet): super(SendToDeviceRestServlet, self).__init__() self.hs = hs self.auth = hs.get_auth() - self.txns = HttpTransactionCache(hs.get_clock()) + self.txns = HttpTransactionCache(hs) self.device_message_handler = hs.get_device_message_handler() def on_PUT(self, request, message_type, txn_id): diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 83e209d18f..8aa06faf23 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -13,27 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools +import logging + +from canonicaljson import json + from twisted.internet import defer -from synapse.http.servlet import ( - RestServlet, parse_string, parse_integer, parse_boolean +from synapse.api.constants import PresenceState +from synapse.api.errors import SynapseError +from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection +from synapse.events.utils import ( + format_event_for_client_v2_without_room_id, + serialize_event, ) from synapse.handlers.presence import format_user_presence_state from synapse.handlers.sync import SyncConfig +from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.types import StreamToken -from synapse.events.utils import ( - serialize_event, format_event_for_client_v2_without_room_id, -) -from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION -from synapse.api.errors import SynapseError -from synapse.api.constants import PresenceState -from ._base import client_v2_patterns -from ._base import set_timeline_upper_limit - -import itertools -import logging -import ujson as json +from ._base import client_v2_patterns, set_timeline_upper_limit logger = logging.getLogger(__name__) @@ -85,6 +84,7 @@ class SyncRestServlet(RestServlet): self.clock = hs.get_clock() self.filtering = hs.get_filtering() self.presence_handler = hs.get_presence_handler() + self._server_notices_sender = hs.get_server_notices_sender() @defer.inlineCallbacks def on_GET(self, request): @@ -110,7 +110,7 @@ class SyncRestServlet(RestServlet): filter_id = parse_string(request, "filter", default=None) full_state = parse_boolean(request, "full_state", default=False) - logger.info( + logger.debug( "/sync: user=%r, timeout=%r, since=%r," " set_presence=%r, filter_id=%r, device_id=%r" % ( user, timeout, since, set_presence, filter_id, device_id @@ -125,7 +125,7 @@ class SyncRestServlet(RestServlet): filter_object = json.loads(filter_id) set_timeline_upper_limit(filter_object, self.hs.config.filter_timeline_limit) - except: + except Exception: raise SynapseError(400, "Invalid filter JSON") self.filtering.check_valid_filter(filter_object) filter = FilterCollection(filter_object) @@ -149,6 +149,9 @@ class SyncRestServlet(RestServlet): else: since_token = None + # send any outstanding server notices to the user. + yield self._server_notices_sender.on_user_syncing(user.to_string()) + affect_presence = set_presence != PresenceState.OFFLINE if affect_presence: @@ -164,27 +167,35 @@ class SyncRestServlet(RestServlet): ) time_now = self.clock.time_msec() + response_content = self.encode_response( + time_now, sync_result, requester.access_token_id, filter + ) + + defer.returnValue((200, response_content)) - joined = self.encode_joined( - sync_result.joined, time_now, requester.access_token_id, filter.event_fields + @staticmethod + def encode_response(time_now, sync_result, access_token_id, filter): + joined = SyncRestServlet.encode_joined( + sync_result.joined, time_now, access_token_id, filter.event_fields ) - invited = self.encode_invited( - sync_result.invited, time_now, requester.access_token_id + invited = SyncRestServlet.encode_invited( + sync_result.invited, time_now, access_token_id, ) - archived = self.encode_archived( - sync_result.archived, time_now, requester.access_token_id, + archived = SyncRestServlet.encode_archived( + sync_result.archived, time_now, access_token_id, filter.event_fields, ) - response_content = { + return { "account_data": {"events": sync_result.account_data}, "to_device": {"events": sync_result.to_device}, "device_lists": { - "changed": list(sync_result.device_lists), + "changed": list(sync_result.device_lists.changed), + "left": list(sync_result.device_lists.left), }, - "presence": self.encode_presence( + "presence": SyncRestServlet.encode_presence( sync_result.presence, time_now ), "rooms": { @@ -192,13 +203,17 @@ class SyncRestServlet(RestServlet): "invite": invited, "leave": archived, }, + "groups": { + "join": sync_result.groups.join, + "invite": sync_result.groups.invite, + "leave": sync_result.groups.leave, + }, "device_one_time_keys_count": sync_result.device_one_time_keys_count, "next_batch": sync_result.next_batch.to_string(), } - defer.returnValue((200, response_content)) - - def encode_presence(self, events, time_now): + @staticmethod + def encode_presence(events, time_now): return { "events": [ { @@ -212,7 +227,8 @@ class SyncRestServlet(RestServlet): ] } - def encode_joined(self, rooms, time_now, token_id, event_fields): + @staticmethod + def encode_joined(rooms, time_now, token_id, event_fields): """ Encode the joined rooms in a sync result @@ -231,13 +247,14 @@ class SyncRestServlet(RestServlet): """ joined = {} for room in rooms: - joined[room.room_id] = self.encode_room( + joined[room.room_id] = SyncRestServlet.encode_room( room, time_now, token_id, only_fields=event_fields ) return joined - def encode_invited(self, rooms, time_now, token_id): + @staticmethod + def encode_invited(rooms, time_now, token_id): """ Encode the invited rooms in a sync result @@ -270,7 +287,8 @@ class SyncRestServlet(RestServlet): return invited - def encode_archived(self, rooms, time_now, token_id, event_fields): + @staticmethod + def encode_archived(rooms, time_now, token_id, event_fields): """ Encode the archived rooms in a sync result @@ -289,7 +307,7 @@ class SyncRestServlet(RestServlet): """ joined = {} for room in rooms: - joined[room.room_id] = self.encode_room( + joined[room.room_id] = SyncRestServlet.encode_room( room, time_now, token_id, joined=False, only_fields=event_fields ) diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index dac8603b07..4fea614e95 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import client_v2_patterns - -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.api.errors import AuthError +import logging from twisted.internet import defer -import logging +from synapse.api.errors import AuthError +from synapse.http.servlet import RestServlet, parse_json_object_from_request + +from ._base import client_v2_patterns logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 6fceb23e26..d9d379182e 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -20,13 +20,14 @@ from twisted.internet import defer from synapse.api.constants import ThirdPartyEntityKind from synapse.http.servlet import RestServlet + from ._base import client_v2_patterns logger = logging.getLogger(__name__) class ThirdPartyProtocolsServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=()) + PATTERNS = client_v2_patterns("/thirdparty/protocols") def __init__(self, hs): super(ThirdPartyProtocolsServlet, self).__init__() @@ -43,8 +44,7 @@ class ThirdPartyProtocolsServlet(RestServlet): class ThirdPartyProtocolServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$", - releases=()) + PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$") def __init__(self, hs): super(ThirdPartyProtocolServlet, self).__init__() @@ -66,8 +66,7 @@ class ThirdPartyProtocolServlet(RestServlet): class ThirdPartyUserServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$", - releases=()) + PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$") def __init__(self, hs): super(ThirdPartyUserServlet, self).__init__() @@ -90,8 +89,7 @@ class ThirdPartyUserServlet(RestServlet): class ThirdPartyLocationServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$", - releases=()) + PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$") def __init__(self, hs): super(ThirdPartyLocationServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py index 6e012da4aa..cac0624ba7 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py @@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request + from ._base import client_v2_patterns logger = logging.getLogger(__name__) @@ -65,7 +66,7 @@ class UserDirectorySearchRestServlet(RestServlet): try: search_term = body["search_term"] - except: + except Exception: raise SynapseError(400, "`search_term` is required field") results = yield self.user_directory_handler.search_users( diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index e984ea47db..6ac2987b98 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.http.servlet import RestServlet - import logging import re +from synapse.http.servlet import RestServlet + logger = logging.getLogger(__name__) @@ -30,6 +30,7 @@ class VersionsRestServlet(RestServlet): "r0.0.1", "r0.1.0", "r0.2.0", + "r0.3.0", ] }) diff --git a/synapse/rest/consent/__init__.py b/synapse/rest/consent/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/synapse/rest/consent/__init__.py diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py new file mode 100644 index 0000000000..147ff7d79b --- /dev/null +++ b/synapse/rest/consent/consent_resource.py @@ -0,0 +1,220 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hmac +import logging +from hashlib import sha256 +from os import path + +from six.moves import http_client + +import jinja2 +from jinja2 import TemplateNotFound + +from twisted.internet import defer +from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET + +from synapse.api.errors import NotFoundError, StoreError, SynapseError +from synapse.config import ConfigError +from synapse.http.server import finish_request, wrap_html_request_handler +from synapse.http.servlet import parse_string +from synapse.types import UserID + +# language to use for the templates. TODO: figure this out from Accept-Language +TEMPLATE_LANGUAGE = "en" + +logger = logging.getLogger(__name__) + +# use hmac.compare_digest if we have it (python 2.7.7), else just use equality +if hasattr(hmac, "compare_digest"): + compare_digest = hmac.compare_digest +else: + def compare_digest(a, b): + return a == b + + +class ConsentResource(Resource): + """A twisted Resource to display a privacy policy and gather consent to it + + When accessed via GET, returns the privacy policy via a template. + + When accessed via POST, records the user's consent in the database and + displays a success page. + + The config should include a template_dir setting which contains templates + for the HTML. The directory should contain one subdirectory per language + (eg, 'en', 'fr'), and each language directory should contain the policy + document (named as '<version>.html') and a success page (success.html). + + Both forms take a set of parameters from the browser. For the POST form, + these are normally sent as form parameters (but may be query-params); for + GET requests they must be query params. These are: + + u: the complete mxid, or the localpart of the user giving their + consent. Required for both GET (where it is used as an input to the + template) and for POST (where it is used to find the row in the db + to update). + + h: hmac_sha256(secret, u), where 'secret' is the privacy_secret in the + config file. If it doesn't match, the request is 403ed. + + v: the version of the privacy policy being agreed to. + + For GET: optional, and defaults to whatever was set in the config + file. Used to choose the version of the policy to pick from the + templates directory. + + For POST: required; gives the value to be recorded in the database + against the user. + """ + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): homeserver + """ + Resource.__init__(self) + + self.hs = hs + self.store = hs.get_datastore() + + # this is required by the request_handler wrapper + self.clock = hs.get_clock() + + self._default_consent_version = hs.config.user_consent_version + if self._default_consent_version is None: + raise ConfigError( + "Consent resource is enabled but user_consent section is " + "missing in config file.", + ) + + # daemonize changes the cwd to /, so make the path absolute now. + consent_template_directory = path.abspath( + hs.config.user_consent_template_dir, + ) + if not path.isdir(consent_template_directory): + raise ConfigError( + "Could not find template directory '%s'" % ( + consent_template_directory, + ), + ) + + loader = jinja2.FileSystemLoader(consent_template_directory) + self._jinja_env = jinja2.Environment( + loader=loader, + autoescape=jinja2.select_autoescape(['html', 'htm', 'xml']), + ) + + if hs.config.form_secret is None: + raise ConfigError( + "Consent resource is enabled but form_secret is not set in " + "config file. It should be set to an arbitrary secret string.", + ) + + self._hmac_secret = hs.config.form_secret.encode("utf-8") + + def render_GET(self, request): + self._async_render_GET(request) + return NOT_DONE_YET + + @wrap_html_request_handler + @defer.inlineCallbacks + def _async_render_GET(self, request): + """ + Args: + request (twisted.web.http.Request): + """ + + version = parse_string(request, "v", + default=self._default_consent_version) + username = parse_string(request, "u", required=True) + userhmac = parse_string(request, "h", required=True) + + self._check_hash(username, userhmac) + + if username.startswith('@'): + qualified_user_id = username + else: + qualified_user_id = UserID(username, self.hs.hostname).to_string() + + u = yield self.store.get_user_by_id(qualified_user_id) + if u is None: + raise NotFoundError("Unknown user") + + try: + self._render_template( + request, "%s.html" % (version,), + user=username, userhmac=userhmac, version=version, + has_consented=(u["consent_version"] == version), + ) + except TemplateNotFound: + raise NotFoundError("Unknown policy version") + + def render_POST(self, request): + self._async_render_POST(request) + return NOT_DONE_YET + + @wrap_html_request_handler + @defer.inlineCallbacks + def _async_render_POST(self, request): + """ + Args: + request (twisted.web.http.Request): + """ + version = parse_string(request, "v", required=True) + username = parse_string(request, "u", required=True) + userhmac = parse_string(request, "h", required=True) + + self._check_hash(username, userhmac) + + if username.startswith('@'): + qualified_user_id = username + else: + qualified_user_id = UserID(username, self.hs.hostname).to_string() + + try: + yield self.store.user_set_consent_version(qualified_user_id, version) + except StoreError as e: + if e.code != 404: + raise + raise NotFoundError("Unknown user") + + try: + self._render_template(request, "success.html") + except TemplateNotFound: + raise NotFoundError("success.html not found") + + def _render_template(self, request, template_name, **template_args): + # get_template checks for ".." so we don't need to worry too much + # about path traversal here. + template_html = self._jinja_env.get_template( + path.join(TEMPLATE_LANGUAGE, template_name) + ) + html_bytes = template_html.render(**template_args).encode("utf8") + + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%i" % len(html_bytes)) + request.write(html_bytes) + finish_request(request) + + def _check_hash(self, userid, userhmac): + want_mac = hmac.new( + key=self._hmac_secret, + msg=userid, + digestmod=sha256, + ).hexdigest() + + if not compare_digest(want_mac, userhmac): + raise SynapseError(http_client.FORBIDDEN, "HMAC incorrect") diff --git a/synapse/rest/key/v1/server_key_resource.py b/synapse/rest/key/v1/server_key_resource.py index bd4fea5774..b9ee6e1c13 100644 --- a/synapse/rest/key/v1/server_key_resource.py +++ b/synapse/rest/key/v1/server_key_resource.py @@ -14,14 +14,16 @@ # limitations under the License. -from twisted.web.resource import Resource -from synapse.http.server import respond_with_json_bytes +import logging + +from canonicaljson import encode_canonical_json from signedjson.sign import sign_json from unpaddedbase64 import encode_base64 -from canonicaljson import encode_canonical_json + from OpenSSL import crypto -import logging +from twisted.web.resource import Resource +from synapse.http.server import respond_with_json_bytes logger = logging.getLogger(__name__) @@ -49,7 +51,6 @@ class LocalKey(Resource): """ def __init__(self, hs): - self.version_string = hs.version_string self.response_body = encode_canonical_json( self.response_json_object(hs.config) ) @@ -84,7 +85,6 @@ class LocalKey(Resource): def render_GET(self, request): return respond_with_json_bytes( request, 200, self.response_body, - version_string=self.version_string ) def getChild(self, name, request): diff --git a/synapse/rest/key/v2/__init__.py b/synapse/rest/key/v2/__init__.py index a07224148c..3491fd2118 100644 --- a/synapse/rest/key/v2/__init__.py +++ b/synapse/rest/key/v2/__init__.py @@ -14,6 +14,7 @@ # limitations under the License. from twisted.web.resource import Resource + from .local_key_resource import LocalKey from .remote_key_resource import RemoteKey diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index be68d9a096..ec0ec7b431 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -14,13 +14,15 @@ # limitations under the License. -from twisted.web.resource import Resource -from synapse.http.server import respond_with_json_bytes +import logging + +from canonicaljson import encode_canonical_json from signedjson.sign import sign_json from unpaddedbase64 import encode_base64 -from canonicaljson import encode_canonical_json -import logging +from twisted.web.resource import Resource + +from synapse.http.server import respond_with_json_bytes logger = logging.getLogger(__name__) @@ -63,7 +65,6 @@ class LocalKey(Resource): isLeaf = True def __init__(self, hs): - self.version_string = hs.version_string self.config = hs.config self.clock = hs.clock self.update_response_body(self.clock.time_msec()) @@ -115,5 +116,4 @@ class LocalKey(Resource): self.update_response_body(time_now) return respond_with_json_bytes( request, 200, self.response_body, - version_string=self.version_string ) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 9fe2013657..7d67e4b064 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -12,18 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.http.server import request_handler, respond_with_json_bytes -from synapse.http.servlet import parse_integer, parse_json_object_from_request -from synapse.api.errors import SynapseError, Codes -from synapse.crypto.keyring import KeyLookupError +import logging +from io import BytesIO +from twisted.internet import defer from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET -from twisted.internet import defer +from synapse.api.errors import Codes, SynapseError +from synapse.crypto.keyring import KeyLookupError +from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler +from synapse.http.servlet import parse_integer, parse_json_object_from_request -from io import BytesIO -import logging logger = logging.getLogger(__name__) @@ -91,14 +91,14 @@ class RemoteKey(Resource): def __init__(self, hs): self.keyring = hs.get_keyring() self.store = hs.get_datastore() - self.version_string = hs.version_string self.clock = hs.get_clock() + self.federation_domain_whitelist = hs.config.federation_domain_whitelist def render_GET(self, request): self.async_render_GET(request) return NOT_DONE_YET - @request_handler() + @wrap_json_request_handler @defer.inlineCallbacks def async_render_GET(self, request): if len(request.postpath) == 1: @@ -123,7 +123,7 @@ class RemoteKey(Resource): self.async_render_POST(request) return NOT_DONE_YET - @request_handler() + @wrap_json_request_handler @defer.inlineCallbacks def async_render_POST(self, request): content = parse_json_object_from_request(request) @@ -137,6 +137,13 @@ class RemoteKey(Resource): logger.info("Handling query for keys %r", query) store_queries = [] for server_name, key_ids in query.items(): + if ( + self.federation_domain_whitelist is not None and + server_name not in self.federation_domain_whitelist + ): + logger.debug("Federation denied with %s", server_name) + continue + if not key_ids: key_ids = (None,) for key_id in key_ids: @@ -213,7 +220,7 @@ class RemoteKey(Resource): ) except KeyLookupError as e: logger.info("Failed to fetch key: %s", e) - except: + except Exception: logger.exception("Failed to get key for %r", server_name) yield self.query_keys( request, query, query_remote_on_cache_miss=False @@ -232,5 +239,4 @@ class RemoteKey(Resource): respond_with_json_bytes( request, 200, result_io.getvalue(), - version_string=self.version_string ) diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py index 956bd5da75..f255f2883f 100644 --- a/synapse/rest/media/v0/content_repository.py +++ b/synapse/rest/media/v0/content_repository.py @@ -13,21 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.http.server import respond_with_json_bytes, finish_request - -from synapse.api.errors import ( - Codes, cs_error -) - -from twisted.protocols.basic import FileSender -from twisted.web import server, resource - import base64 -import simplejson as json import logging import os import re +from canonicaljson import json + +from twisted.protocols.basic import FileSender +from twisted.web import resource, server + +from synapse.api.errors import Codes, cs_error +from synapse.http.server import finish_request, respond_with_json_bytes + logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index b9600f2167..65f4bd2910 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -13,22 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.http.server import respond_with_json, finish_request -from synapse.api.errors import ( - cs_error, Codes, SynapseError -) +import logging +import os +import urllib + +from six.moves.urllib import parse as urlparse from twisted.internet import defer from twisted.protocols.basic import FileSender +from synapse.api.errors import Codes, SynapseError, cs_error +from synapse.http.server import finish_request, respond_with_json +from synapse.util import logcontext from synapse.util.stringutils import is_ascii -import os - -import logging -import urllib -import urlparse - logger = logging.getLogger(__name__) @@ -44,7 +42,7 @@ def parse_media_id(request): except UnicodeDecodeError: pass return server_name, media_id, file_name - except: + except Exception: raise SynapseError( 404, "Invalid media id token %r" % (request.postpath,), @@ -69,42 +67,133 @@ def respond_with_file(request, media_type, file_path, logger.debug("Responding with %r", file_path) if os.path.isfile(file_path): - request.setHeader(b"Content-Type", media_type.encode("UTF-8")) - if upload_name: - if is_ascii(upload_name): - request.setHeader( - b"Content-Disposition", - b"inline; filename=%s" % ( - urllib.quote(upload_name.encode("utf-8")), - ), - ) - else: - request.setHeader( - b"Content-Disposition", - b"inline; filename*=utf-8''%s" % ( - urllib.quote(upload_name.encode("utf-8")), - ), - ) - - # cache for at least a day. - # XXX: we might want to turn this off for data we don't want to - # recommend caching as it's sensitive or private - or at least - # select private. don't bother setting Expires as all our - # clients are smart enough to be happy with Cache-Control - request.setHeader( - b"Cache-Control", b"public,max-age=86400,s-maxage=86400" - ) if file_size is None: stat = os.stat(file_path) file_size = stat.st_size - request.setHeader( - b"Content-Length", b"%d" % (file_size,) - ) + add_file_headers(request, media_type, file_size, upload_name) with open(file_path, "rb") as f: - yield FileSender().beginFileTransfer(f, request) + yield logcontext.make_deferred_yieldable( + FileSender().beginFileTransfer(f, request) + ) finish_request(request) else: respond_404(request) + + +def add_file_headers(request, media_type, file_size, upload_name): + """Adds the correct response headers in preparation for responding with the + media. + + Args: + request (twisted.web.http.Request) + media_type (str): The media/content type. + file_size (int): Size in bytes of the media, if known. + upload_name (str): The name of the requested file, if any. + """ + request.setHeader(b"Content-Type", media_type.encode("UTF-8")) + if upload_name: + if is_ascii(upload_name): + request.setHeader( + b"Content-Disposition", + b"inline; filename=%s" % ( + urllib.quote(upload_name.encode("utf-8")), + ), + ) + else: + request.setHeader( + b"Content-Disposition", + b"inline; filename*=utf-8''%s" % ( + urllib.quote(upload_name.encode("utf-8")), + ), + ) + + # cache for at least a day. + # XXX: we might want to turn this off for data we don't want to + # recommend caching as it's sensitive or private - or at least + # select private. don't bother setting Expires as all our + # clients are smart enough to be happy with Cache-Control + request.setHeader( + b"Cache-Control", b"public,max-age=86400,s-maxage=86400" + ) + + request.setHeader( + b"Content-Length", b"%d" % (file_size,) + ) + + +@defer.inlineCallbacks +def respond_with_responder(request, responder, media_type, file_size, upload_name=None): + """Responds to the request with given responder. If responder is None then + returns 404. + + Args: + request (twisted.web.http.Request) + responder (Responder|None) + media_type (str): The media/content type. + file_size (int|None): Size in bytes of the media. If not known it should be None + upload_name (str|None): The name of the requested file, if any. + """ + if not responder: + respond_404(request) + return + + logger.debug("Responding to media request with responder %s") + add_file_headers(request, media_type, file_size, upload_name) + with responder: + yield responder.write_to_consumer(request) + finish_request(request) + + +class Responder(object): + """Represents a response that can be streamed to the requester. + + Responder is a context manager which *must* be used, so that any resources + held can be cleaned up. + """ + def write_to_consumer(self, consumer): + """Stream response into consumer + + Args: + consumer (IConsumer) + + Returns: + Deferred: Resolves once the response has finished being written + """ + pass + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + +class FileInfo(object): + """Details about a requested/uploaded file. + + Attributes: + server_name (str): The server name where the media originated from, + or None if local. + file_id (str): The local ID of the file. For local files this is the + same as the media_id + url_cache (bool): If the file is for the url preview cache + thumbnail (bool): Whether the file is a thumbnail or not. + thumbnail_width (int) + thumbnail_height (int) + thumbnail_method (str) + thumbnail_type (str): Content type of thumbnail, e.g. image/png + """ + def __init__(self, server_name, file_id, url_cache=False, + thumbnail=False, thumbnail_width=None, thumbnail_height=None, + thumbnail_method=None, thumbnail_type=None): + self.server_name = server_name + self.file_id = file_id + self.url_cache = url_cache + self.thumbnail = thumbnail + self.thumbnail_width = thumbnail_width + self.thumbnail_height = thumbnail_height + self.thumbnail_method = thumbnail_method + self.thumbnail_type = thumbnail_type diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index 6879249c8a..fbfa85f74f 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -12,16 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import synapse.http.servlet +import logging -from ._base import parse_media_id, respond_with_file, respond_404 +from twisted.internet import defer from twisted.web.resource import Resource -from synapse.http.server import request_handler, set_cors_headers - from twisted.web.server import NOT_DONE_YET -from twisted.internet import defer -import logging +import synapse.http.servlet +from synapse.http.server import set_cors_headers, wrap_json_request_handler + +from ._base import parse_media_id, respond_404 logger = logging.getLogger(__name__) @@ -32,18 +32,17 @@ class DownloadResource(Resource): def __init__(self, hs, media_repo): Resource.__init__(self) - self.filepaths = media_repo.filepaths self.media_repo = media_repo self.server_name = hs.hostname - self.store = hs.get_datastore() - self.version_string = hs.version_string + + # this is expected by @wrap_json_request_handler self.clock = hs.get_clock() def render_GET(self, request): self._async_render_GET(request) return NOT_DONE_YET - @request_handler() + @wrap_json_request_handler @defer.inlineCallbacks def _async_render_GET(self, request): set_cors_headers(request) @@ -57,59 +56,16 @@ class DownloadResource(Resource): ) server_name, media_id, name = parse_media_id(request) if server_name == self.server_name: - yield self._respond_local_file(request, media_id, name) + yield self.media_repo.get_local_media(request, media_id, name) else: - yield self._respond_remote_file( - request, server_name, media_id, name - ) - - @defer.inlineCallbacks - def _respond_local_file(self, request, media_id, name): - media_info = yield self.store.get_local_media(media_id) - if not media_info or media_info["quarantined_by"]: - respond_404(request) - return - - media_type = media_info["media_type"] - media_length = media_info["media_length"] - upload_name = name if name else media_info["upload_name"] - if media_info["url_cache"]: - # TODO: Check the file still exists, if it doesn't we can redownload - # it from the url `media_info["url_cache"]` - file_path = self.filepaths.url_cache_filepath(media_id) - else: - file_path = self.filepaths.local_media_filepath(media_id) - - yield respond_with_file( - request, media_type, file_path, media_length, - upload_name=upload_name, - ) - - @defer.inlineCallbacks - def _respond_remote_file(self, request, server_name, media_id, name): - # don't forward requests for remote media if allow_remote is false - allow_remote = synapse.http.servlet.parse_boolean( - request, "allow_remote", default=True) - if not allow_remote: - logger.info( - "Rejecting request for remote media %s/%s due to allow_remote", - server_name, media_id, - ) - respond_404(request) - return - - media_info = yield self.media_repo.get_remote_media(server_name, media_id) - - media_type = media_info["media_type"] - media_length = media_info["media_length"] - filesystem_id = media_info["filesystem_id"] - upload_name = name if name else media_info["upload_name"] - - file_path = self.filepaths.remote_media_filepath( - server_name, filesystem_id - ) - - yield respond_with_file( - request, media_type, file_path, media_length, - upload_name=upload_name, - ) + allow_remote = synapse.http.servlet.parse_boolean( + request, "allow_remote", default=True) + if not allow_remote: + logger.info( + "Rejecting request for remote media %s/%s due to allow_remote", + server_name, media_id, + ) + respond_404(request) + return + + yield self.media_repo.get_remote_media(request, server_name, media_id, name) diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index d92b7ff337..c8586fa280 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -13,79 +13,201 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools import os +import re + +NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") + + +def _wrap_in_base_path(func): + """Takes a function that returns a relative path and turns it into an + absolute path based on the location of the primary media store + """ + @functools.wraps(func) + def _wrapped(self, *args, **kwargs): + path = func(self, *args, **kwargs) + return os.path.join(self.base_path, path) + + return _wrapped class MediaFilePaths(object): + """Describes where files are stored on disk. - def __init__(self, base_path): - self.base_path = base_path + Most of the functions have a `*_rel` variant which returns a file path that + is relative to the base media store path. This is mainly used when we want + to write to the backup media store (when one is configured) + """ - def default_thumbnail(self, default_top_level, default_sub_type, width, - height, content_type, method): + def __init__(self, primary_base_path): + self.base_path = primary_base_path + + def default_thumbnail_rel(self, default_top_level, default_sub_type, width, + height, content_type, method): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % ( width, height, top_level_type, sub_type, method ) return os.path.join( - self.base_path, "default_thumbnails", default_top_level, + "default_thumbnails", default_top_level, default_sub_type, file_name ) - def local_media_filepath(self, media_id): + default_thumbnail = _wrap_in_base_path(default_thumbnail_rel) + + def local_media_filepath_rel(self, media_id): return os.path.join( - self.base_path, "local_content", + "local_content", media_id[0:2], media_id[2:4], media_id[4:] ) - def local_media_thumbnail(self, media_id, width, height, content_type, - method): + local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) + + def local_media_thumbnail_rel(self, media_id, width, height, content_type, + method): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % ( width, height, top_level_type, sub_type, method ) return os.path.join( - self.base_path, "local_thumbnails", + "local_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], file_name ) - def remote_media_filepath(self, server_name, file_id): + local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel) + + def remote_media_filepath_rel(self, server_name, file_id): return os.path.join( - self.base_path, "remote_content", server_name, + "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:] ) - def remote_media_thumbnail(self, server_name, file_id, width, height, - content_type, method): + remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) + + def remote_media_thumbnail_rel(self, server_name, file_id, width, height, + content_type, method): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) return os.path.join( - self.base_path, "remote_thumbnail", server_name, + "remote_thumbnail", server_name, file_id[0:2], file_id[2:4], file_id[4:], file_name ) + remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel) + def remote_media_thumbnail_dir(self, server_name, file_id): return os.path.join( self.base_path, "remote_thumbnail", server_name, file_id[0:2], file_id[2:4], file_id[4:], ) - def url_cache_filepath(self, media_id): - return os.path.join( - self.base_path, "url_cache", - media_id[0:2], media_id[2:4], media_id[4:] - ) + def url_cache_filepath_rel(self, media_id): + if NEW_FORMAT_ID_RE.match(media_id): + # Media id is of the form <DATE><RANDOM_STRING> + # E.g.: 2017-09-28-fsdRDt24DS234dsf + return os.path.join( + "url_cache", + media_id[:10], media_id[11:] + ) + else: + return os.path.join( + "url_cache", + media_id[0:2], media_id[2:4], media_id[4:], + ) + + url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) + + def url_cache_filepath_dirs_to_delete(self, media_id): + "The dirs to try and remove if we delete the media_id file" + if NEW_FORMAT_ID_RE.match(media_id): + return [ + os.path.join( + self.base_path, "url_cache", + media_id[:10], + ), + ] + else: + return [ + os.path.join( + self.base_path, "url_cache", + media_id[0:2], media_id[2:4], + ), + os.path.join( + self.base_path, "url_cache", + media_id[0:2], + ), + ] + + def url_cache_thumbnail_rel(self, media_id, width, height, content_type, + method): + # Media id is of the form <DATE><RANDOM_STRING> + # E.g.: 2017-09-28-fsdRDt24DS234dsf - def url_cache_thumbnail(self, media_id, width, height, content_type, - method): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % ( width, height, top_level_type, sub_type, method ) - return os.path.join( - self.base_path, "url_cache_thumbnails", - media_id[0:2], media_id[2:4], media_id[4:], - file_name - ) + + if NEW_FORMAT_ID_RE.match(media_id): + return os.path.join( + "url_cache_thumbnails", + media_id[:10], media_id[11:], + file_name + ) + else: + return os.path.join( + "url_cache_thumbnails", + media_id[0:2], media_id[2:4], media_id[4:], + file_name + ) + + url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) + + def url_cache_thumbnail_directory(self, media_id): + # Media id is of the form <DATE><RANDOM_STRING> + # E.g.: 2017-09-28-fsdRDt24DS234dsf + + if NEW_FORMAT_ID_RE.match(media_id): + return os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[:10], media_id[11:], + ) + else: + return os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[0:2], media_id[2:4], media_id[4:], + ) + + def url_cache_thumbnail_dirs_to_delete(self, media_id): + "The dirs to try and remove if we delete the media_id thumbnails" + # Media id is of the form <DATE><RANDOM_STRING> + # E.g.: 2017-09-28-fsdRDt24DS234dsf + if NEW_FORMAT_ID_RE.match(media_id): + return [ + os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[:10], media_id[11:], + ), + os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[:10], + ), + ] + else: + return [ + os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[0:2], media_id[2:4], media_id[4:], + ), + os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[0:2], media_id[2:4], + ), + os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[0:2], + ), + ] diff --git a/synapse/rest/media/v1/identicon_resource.py b/synapse/rest/media/v1/identicon_resource.py index 66f2b6bd30..bdbd8d50dd 100644 --- a/synapse/rest/media/v1/identicon_resource.py +++ b/synapse/rest/media/v1/identicon_resource.py @@ -13,8 +13,11 @@ # limitations under the License. from pydenticon import Generator + from twisted.web.resource import Resource +from synapse.http.servlet import parse_integer + FOREGROUND = [ "rgb(45,79,255)", "rgb(254,180,44)", @@ -55,8 +58,8 @@ class IdenticonResource(Resource): def render_GET(self, request): name = "/".join(request.postpath) - width = int(request.args.get("width", [96])[0]) - height = int(request.args.get("height", [96])[0]) + width = parse_integer(request, "width", default=96) + height = parse_integer(request, "height", default=96) identicon_bytes = self.generate_identicon(name, width, height) request.setHeader(b"Content-Type", b"image/png") request.setHeader( diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 0ea1248ce6..30242c525a 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,45 +14,52 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer, threads +import cgi +import errno +import logging +import os +import shutil + +from six import iteritems +from six.moves.urllib import parse as urlparse + import twisted.internet.error import twisted.web.http +from twisted.internet import defer, threads from twisted.web.resource import Resource -from .upload_resource import UploadResource -from .download_resource import DownloadResource -from .thumbnail_resource import ThumbnailResource -from .identicon_resource import IdenticonResource -from .preview_url_resource import PreviewUrlResource -from .filepath import MediaFilePaths -from .thumbnailer import Thumbnailer - +from synapse.api.errors import ( + FederationDeniedError, + HttpResponseException, + NotFoundError, + SynapseError, +) from synapse.http.matrixfederationclient import MatrixFederationHttpClient -from synapse.util.stringutils import random_string -from synapse.api.errors import SynapseError, HttpResponseException, \ - NotFoundError - from synapse.util.async import Linearizer -from synapse.util.stringutils import is_ascii -from synapse.util.logcontext import preserve_context_over_fn +from synapse.util.logcontext import make_deferred_yieldable from synapse.util.retryutils import NotRetryingDestination +from synapse.util.stringutils import is_ascii, random_string -import os -import errno -import shutil - -import cgi -import logging -import urlparse +from ._base import FileInfo, respond_404, respond_with_responder +from .download_resource import DownloadResource +from .filepath import MediaFilePaths +from .identicon_resource import IdenticonResource +from .media_storage import MediaStorage +from .preview_url_resource import PreviewUrlResource +from .storage_provider import StorageProviderWrapper +from .thumbnail_resource import ThumbnailResource +from .thumbnailer import Thumbnailer +from .upload_resource import UploadResource logger = logging.getLogger(__name__) -UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000 +UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000 class MediaRepository(object): def __init__(self, hs): + self.hs = hs self.auth = hs.get_auth() self.client = MatrixFederationHttpClient(hs) self.clock = hs.get_clock() @@ -59,46 +67,90 @@ class MediaRepository(object): self.store = hs.get_datastore() self.max_upload_size = hs.config.max_upload_size self.max_image_pixels = hs.config.max_image_pixels - self.filepaths = MediaFilePaths(hs.config.media_store_path) + + self.primary_base_path = hs.config.media_store_path + self.filepaths = MediaFilePaths(self.primary_base_path) + self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.thumbnail_requirements = hs.config.thumbnail_requirements self.remote_media_linearizer = Linearizer(name="media_remote") self.recently_accessed_remotes = set() + self.recently_accessed_locals = set() + + self.federation_domain_whitelist = hs.config.federation_domain_whitelist + + # List of StorageProviders where we should search for media and + # potentially upload to. + storage_providers = [] + + for clz, provider_config, wrapper_config in hs.config.media_storage_providers: + backend = clz(hs, provider_config) + provider = StorageProviderWrapper( + backend, + store_local=wrapper_config.store_local, + store_remote=wrapper_config.store_remote, + store_synchronous=wrapper_config.store_synchronous, + ) + storage_providers.append(provider) + + self.media_storage = MediaStorage( + self.hs, self.primary_base_path, self.filepaths, storage_providers, + ) self.clock.looping_call( - self._update_recently_accessed_remotes, - UPDATE_RECENTLY_ACCESSED_REMOTES_TS + self._update_recently_accessed, + UPDATE_RECENTLY_ACCESSED_TS, ) @defer.inlineCallbacks - def _update_recently_accessed_remotes(self): - media = self.recently_accessed_remotes + def _update_recently_accessed(self): + remote_media = self.recently_accessed_remotes self.recently_accessed_remotes = set() + local_media = self.recently_accessed_locals + self.recently_accessed_locals = set() + yield self.store.update_cached_last_access_time( - media, self.clock.time_msec() + local_media, remote_media, self.clock.time_msec() ) - @staticmethod - def _makedirs(filepath): - dirname = os.path.dirname(filepath) - if not os.path.exists(dirname): - os.makedirs(dirname) + def mark_recently_accessed(self, server_name, media_id): + """Mark the given media as recently accessed. + + Args: + server_name (str|None): Origin server of media, or None if local + media_id (str): The media ID of the content + """ + if server_name: + self.recently_accessed_remotes.add((server_name, media_id)) + else: + self.recently_accessed_locals.add(media_id) @defer.inlineCallbacks def create_content(self, media_type, upload_name, content, content_length, auth_user): + """Store uploaded content for a local user and return the mxc URL + + Args: + media_type(str): The content type of the file + upload_name(str): The name of the file + content: A file like object that is the content to store + content_length(int): The length of the content + auth_user(str): The user_id of the uploader + + Returns: + Deferred[str]: The mxc url of the stored content + """ media_id = random_string(24) - fname = self.filepaths.local_media_filepath(media_id) - self._makedirs(fname) + file_info = FileInfo( + server_name=None, + file_id=media_id, + ) - # This shouldn't block for very long because the content will have - # already been uploaded at this point. - with open(fname, "wb") as f: - f.write(content) + fname = yield self.media_storage.store_file(content, file_info) logger.info("Stored local media in file %r", fname) @@ -110,131 +162,275 @@ class MediaRepository(object): media_length=content_length, user_id=auth_user, ) - media_info = { - "media_type": media_type, - "media_length": content_length, - } - yield self._generate_local_thumbnails(media_id, media_info) + yield self._generate_thumbnails( + None, media_id, media_id, media_type, + ) defer.returnValue("mxc://%s/%s" % (self.server_name, media_id)) @defer.inlineCallbacks - def get_remote_media(self, server_name, media_id): + def get_local_media(self, request, media_id, name): + """Responds to reqests for local media, if exists, or returns 404. + + Args: + request(twisted.web.http.Request) + media_id (str): The media ID of the content. (This is the same as + the file_id for local content.) + name (str|None): Optional name that, if specified, will be used as + the filename in the Content-Disposition header of the response. + + Returns: + Deferred: Resolves once a response has successfully been written + to request + """ + media_info = yield self.store.get_local_media(media_id) + if not media_info or media_info["quarantined_by"]: + respond_404(request) + return + + self.mark_recently_accessed(None, media_id) + + media_type = media_info["media_type"] + media_length = media_info["media_length"] + upload_name = name if name else media_info["upload_name"] + url_cache = media_info["url_cache"] + + file_info = FileInfo( + None, media_id, + url_cache=url_cache, + ) + + responder = yield self.media_storage.fetch_media(file_info) + yield respond_with_responder( + request, responder, media_type, media_length, upload_name, + ) + + @defer.inlineCallbacks + def get_remote_media(self, request, server_name, media_id, name): + """Respond to requests for remote media. + + Args: + request(twisted.web.http.Request) + server_name (str): Remote server_name where the media originated. + media_id (str): The media ID of the content (as defined by the + remote server). + name (str|None): Optional name that, if specified, will be used as + the filename in the Content-Disposition header of the response. + + Returns: + Deferred: Resolves once a response has successfully been written + to request + """ + if ( + self.federation_domain_whitelist is not None and + server_name not in self.federation_domain_whitelist + ): + raise FederationDeniedError(server_name) + + self.mark_recently_accessed(server_name, media_id) + + # We linearize here to ensure that we don't try and download remote + # media multiple times concurrently + key = (server_name, media_id) + with (yield self.remote_media_linearizer.queue(key)): + responder, media_info = yield self._get_remote_media_impl( + server_name, media_id, + ) + + # We deliberately stream the file outside the lock + if responder: + media_type = media_info["media_type"] + media_length = media_info["media_length"] + upload_name = name if name else media_info["upload_name"] + yield respond_with_responder( + request, responder, media_type, media_length, upload_name, + ) + else: + respond_404(request) + + @defer.inlineCallbacks + def get_remote_media_info(self, server_name, media_id): + """Gets the media info associated with the remote file, downloading + if necessary. + + Args: + server_name (str): Remote server_name where the media originated. + media_id (str): The media ID of the content (as defined by the + remote server). + + Returns: + Deferred[dict]: The media_info of the file + """ + if ( + self.federation_domain_whitelist is not None and + server_name not in self.federation_domain_whitelist + ): + raise FederationDeniedError(server_name) + + # We linearize here to ensure that we don't try and download remote + # media multiple times concurrently key = (server_name, media_id) with (yield self.remote_media_linearizer.queue(key)): - media_info = yield self._get_remote_media_impl(server_name, media_id) + responder, media_info = yield self._get_remote_media_impl( + server_name, media_id, + ) + + # Ensure we actually use the responder so that it releases resources + if responder: + with responder: + pass + defer.returnValue(media_info) @defer.inlineCallbacks def _get_remote_media_impl(self, server_name, media_id): + """Looks for media in local cache, if not there then attempt to + download from remote server. + + Args: + server_name (str): Remote server_name where the media originated. + media_id (str): The media ID of the content (as defined by the + remote server). + + Returns: + Deferred[(Responder, media_info)] + """ media_info = yield self.store.get_cached_remote_media( server_name, media_id ) - if not media_info: - media_info = yield self._download_remote_file( - server_name, media_id - ) - elif media_info["quarantined_by"]: - raise NotFoundError() + + # file_id is the ID we use to track the file locally. If we've already + # seen the file then reuse the existing ID, otherwise genereate a new + # one. + if media_info: + file_id = media_info["filesystem_id"] else: - self.recently_accessed_remotes.add((server_name, media_id)) - yield self.store.update_cached_last_access_time( - [(server_name, media_id)], self.clock.time_msec() - ) - defer.returnValue(media_info) + file_id = random_string(24) - @defer.inlineCallbacks - def _download_remote_file(self, server_name, media_id): - file_id = random_string(24) + file_info = FileInfo(server_name, file_id) + + # If we have an entry in the DB, try and look for it + if media_info: + if media_info["quarantined_by"]: + logger.info("Media is quarantined") + raise NotFoundError() + + responder = yield self.media_storage.fetch_media(file_info) + if responder: + defer.returnValue((responder, media_info)) + + # Failed to find the file anywhere, lets download it. - fname = self.filepaths.remote_media_filepath( - server_name, file_id + media_info = yield self._download_remote_file( + server_name, media_id, file_id ) - self._makedirs(fname) - try: - with open(fname, "wb") as f: - request_path = "/".join(( - "/_matrix/media/v1/download", server_name, media_id, - )) + responder = yield self.media_storage.fetch_media(file_info) + defer.returnValue((responder, media_info)) + + @defer.inlineCallbacks + def _download_remote_file(self, server_name, media_id, file_id): + """Attempt to download the remote file from the given server name, + using the given file_id as the local id. + + Args: + server_name (str): Originating server + media_id (str): The media ID of the content (as defined by the + remote server). This is different than the file_id, which is + locally generated. + file_id (str): Local file ID + + Returns: + Deferred[MediaInfo] + """ + + file_info = FileInfo( + server_name=server_name, + file_id=file_id, + ) + + with self.media_storage.store_into_file(file_info) as (f, fname, finish): + request_path = "/".join(( + "/_matrix/media/v1/download", server_name, media_id, + )) + try: + length, headers = yield self.client.get_file( + server_name, request_path, output_stream=f, + max_size=self.max_upload_size, args={ + # tell the remote server to 404 if it doesn't + # recognise the server_name, to make sure we don't + # end up with a routing loop. + "allow_remote": "false", + } + ) + except twisted.internet.error.DNSLookupError as e: + logger.warn("HTTP error fetching remote media %s/%s: %r", + server_name, media_id, e) + raise NotFoundError() + + except HttpResponseException as e: + logger.warn("HTTP error fetching remote media %s/%s: %s", + server_name, media_id, e.response) + if e.code == twisted.web.http.NOT_FOUND: + raise SynapseError.from_http_response_exception(e) + raise SynapseError(502, "Failed to fetch remote media") + + except SynapseError: + logger.exception("Failed to fetch remote media %s/%s", + server_name, media_id) + raise + except NotRetryingDestination: + logger.warn("Not retrying destination %r", server_name) + raise SynapseError(502, "Failed to fetch remote media") + except Exception: + logger.exception("Failed to fetch remote media %s/%s", + server_name, media_id) + raise SynapseError(502, "Failed to fetch remote media") + + yield finish() + + media_type = headers["Content-Type"][0] + + time_now_ms = self.clock.time_msec() + + content_disposition = headers.get("Content-Disposition", None) + if content_disposition: + _, params = cgi.parse_header(content_disposition[0],) + upload_name = None + + # First check if there is a valid UTF-8 filename + upload_name_utf8 = params.get("filename*", None) + if upload_name_utf8: + if upload_name_utf8.lower().startswith("utf-8''"): + upload_name = upload_name_utf8[7:] + + # If there isn't check for an ascii name. + if not upload_name: + upload_name_ascii = params.get("filename", None) + if upload_name_ascii and is_ascii(upload_name_ascii): + upload_name = upload_name_ascii + + if upload_name: + upload_name = urlparse.unquote(upload_name) try: - length, headers = yield self.client.get_file( - server_name, request_path, output_stream=f, - max_size=self.max_upload_size, args={ - # tell the remote server to 404 if it doesn't - # recognise the server_name, to make sure we don't - # end up with a routing loop. - "allow_remote": "false", - } - ) - except twisted.internet.error.DNSLookupError as e: - logger.warn("HTTP error fetching remote media %s/%s: %r", - server_name, media_id, e) - raise NotFoundError() - - except HttpResponseException as e: - logger.warn("HTTP error fetching remote media %s/%s: %s", - server_name, media_id, e.response) - if e.code == twisted.web.http.NOT_FOUND: - raise SynapseError.from_http_response_exception(e) - raise SynapseError(502, "Failed to fetch remote media") - - except SynapseError: - logger.exception("Failed to fetch remote media %s/%s", - server_name, media_id) - raise - except NotRetryingDestination: - logger.warn("Not retrying destination %r", server_name) - raise SynapseError(502, "Failed to fetch remote media") - except Exception: - logger.exception("Failed to fetch remote media %s/%s", - server_name, media_id) - raise SynapseError(502, "Failed to fetch remote media") - - media_type = headers["Content-Type"][0] - time_now_ms = self.clock.time_msec() - - content_disposition = headers.get("Content-Disposition", None) - if content_disposition: - _, params = cgi.parse_header(content_disposition[0],) - upload_name = None - - # First check if there is a valid UTF-8 filename - upload_name_utf8 = params.get("filename*", None) - if upload_name_utf8: - if upload_name_utf8.lower().startswith("utf-8''"): - upload_name = upload_name_utf8[7:] - - # If there isn't check for an ascii name. - if not upload_name: - upload_name_ascii = params.get("filename", None) - if upload_name_ascii and is_ascii(upload_name_ascii): - upload_name = upload_name_ascii - - if upload_name: - upload_name = urlparse.unquote(upload_name) - try: - upload_name = upload_name.decode("utf-8") - except UnicodeDecodeError: - upload_name = None - else: - upload_name = None - - logger.info("Stored remote media in file %r", fname) - - yield self.store.store_cached_remote_media( - origin=server_name, - media_id=media_id, - media_type=media_type, - time_now_ms=self.clock.time_msec(), - upload_name=upload_name, - media_length=length, - filesystem_id=file_id, - ) - except: - os.remove(fname) - raise + upload_name = upload_name.decode("utf-8") + except UnicodeDecodeError: + upload_name = None + else: + upload_name = None + + logger.info("Stored remote media in file %r", fname) + + yield self.store.store_cached_remote_media( + origin=server_name, + media_id=media_id, + media_type=media_type, + time_now_ms=self.clock.time_msec(), + upload_name=upload_name, + media_length=length, + filesystem_id=file_id, + ) media_info = { "media_type": media_type, @@ -244,8 +440,8 @@ class MediaRepository(object): "filesystem_id": file_id, } - yield self._generate_remote_thumbnails( - server_name, media_id, media_info + yield self._generate_thumbnails( + server_name, media_id, file_id, media_type, ) defer.returnValue(media_info) @@ -253,9 +449,8 @@ class MediaRepository(object): def _get_thumbnail_requirements(self, media_type): return self.thumbnail_requirements.get(media_type, ()) - def _generate_thumbnail(self, input_path, t_path, t_width, t_height, + def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type): - thumbnailer = Thumbnailer(input_path) m_width = thumbnailer.width m_height = thumbnailer.height @@ -267,75 +462,125 @@ class MediaRepository(object): return if t_method == "crop": - t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) + t_byte_source = thumbnailer.crop(t_width, t_height, t_type) elif t_method == "scale": t_width, t_height = thumbnailer.aspect(t_width, t_height) t_width = min(m_width, t_width) t_height = min(m_height, t_height) - t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) + t_byte_source = thumbnailer.scale(t_width, t_height, t_type) else: - t_len = None + t_byte_source = None - return t_len + return t_byte_source @defer.inlineCallbacks def generate_local_exact_thumbnail(self, media_id, t_width, t_height, - t_method, t_type): - input_path = self.filepaths.local_media_filepath(media_id) - - t_path = self.filepaths.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) + t_method, t_type, url_cache): + input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo( + None, media_id, url_cache=url_cache, + )) - t_len = yield preserve_context_over_fn( - threads.deferToThread, + thumbnailer = Thumbnailer(input_path) + t_byte_source = yield make_deferred_yieldable(threads.deferToThread( self._generate_thumbnail, - input_path, t_path, t_width, t_height, t_method, t_type - ) + thumbnailer, t_width, t_height, t_method, t_type + )) + + if t_byte_source: + try: + file_info = FileInfo( + server_name=None, + file_id=media_id, + url_cache=url_cache, + thumbnail=True, + thumbnail_width=t_width, + thumbnail_height=t_height, + thumbnail_method=t_method, + thumbnail_type=t_type, + ) + + output_path = yield self.media_storage.store_file( + t_byte_source, file_info, + ) + finally: + t_byte_source.close() + + logger.info("Stored thumbnail in file %r", output_path) + + t_len = os.path.getsize(output_path) - if t_len: yield self.store.store_local_thumbnail( media_id, t_width, t_height, t_type, t_method, t_len ) - defer.returnValue(t_path) + defer.returnValue(output_path) @defer.inlineCallbacks def generate_remote_exact_thumbnail(self, server_name, file_id, media_id, t_width, t_height, t_method, t_type): - input_path = self.filepaths.remote_media_filepath(server_name, file_id) - - t_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) + input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo( + server_name, file_id, url_cache=False, + )) - t_len = yield preserve_context_over_fn( - threads.deferToThread, + thumbnailer = Thumbnailer(input_path) + t_byte_source = yield make_deferred_yieldable(threads.deferToThread( self._generate_thumbnail, - input_path, t_path, t_width, t_height, t_method, t_type - ) + thumbnailer, t_width, t_height, t_method, t_type + )) + + if t_byte_source: + try: + file_info = FileInfo( + server_name=server_name, + file_id=media_id, + thumbnail=True, + thumbnail_width=t_width, + thumbnail_height=t_height, + thumbnail_method=t_method, + thumbnail_type=t_type, + ) + + output_path = yield self.media_storage.store_file( + t_byte_source, file_info, + ) + finally: + t_byte_source.close() + + logger.info("Stored thumbnail in file %r", output_path) + + t_len = os.path.getsize(output_path) - if t_len: yield self.store.store_remote_media_thumbnail( server_name, media_id, file_id, t_width, t_height, t_type, t_method, t_len ) - defer.returnValue(t_path) + defer.returnValue(output_path) @defer.inlineCallbacks - def _generate_local_thumbnails(self, media_id, media_info, url_cache=False): - media_type = media_info["media_type"] + def _generate_thumbnails(self, server_name, media_id, file_id, media_type, + url_cache=False): + """Generate and store thumbnails for an image. + + Args: + server_name (str|None): The server name if remote media, else None if local + media_id (str): The media ID of the content. (This is the same as + the file_id for local content) + file_id (str): Local file ID + media_type (str): The content type of the file + url_cache (bool): If we are thumbnailing images downloaded for the URL cache, + used exclusively by the url previewer + + Returns: + Deferred[dict]: Dict with "width" and "height" keys of original image + """ requirements = self._get_thumbnail_requirements(media_type) if not requirements: return - if url_cache: - input_path = self.filepaths.url_cache_filepath(media_id) - else: - input_path = self.filepaths.local_media_filepath(media_id) + input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo( + server_name, file_id, url_cache=url_cache, + )) thumbnailer = Thumbnailer(input_path) m_width = thumbnailer.width @@ -348,135 +593,68 @@ class MediaRepository(object): ) return - local_thumbnails = [] - - def generate_thumbnails(): - scales = set() - crops = set() - for r_width, r_height, r_method, r_type in requirements: - if r_method == "scale": - t_width, t_height = thumbnailer.aspect(r_width, r_height) - scales.add(( - min(m_width, t_width), min(m_height, t_height), r_type, - )) - elif r_method == "crop": - crops.add((r_width, r_height, r_type)) - - for t_width, t_height, t_type in scales: - t_method = "scale" - if url_cache: - t_path = self.filepaths.url_cache_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - else: - t_path = self.filepaths.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) - - local_thumbnails.append(( - media_id, t_width, t_height, t_type, t_method, t_len + # We deduplicate the thumbnail sizes by ignoring the cropped versions if + # they have the same dimensions of a scaled one. + thumbnails = {} + for r_width, r_height, r_method, r_type in requirements: + if r_method == "crop": + thumbnails.setdefault((r_width, r_height, r_type), r_method) + elif r_method == "scale": + t_width, t_height = thumbnailer.aspect(r_width, r_height) + t_width = min(m_width, t_width) + t_height = min(m_height, t_height) + thumbnails[(t_width, t_height, r_type)] = r_method + + # Now we generate the thumbnails for each dimension, store it + for (t_width, t_height, t_type), t_method in iteritems(thumbnails): + # Generate the thumbnail + if t_method == "crop": + t_byte_source = yield make_deferred_yieldable(threads.deferToThread( + thumbnailer.crop, + t_width, t_height, t_type, )) - - for t_width, t_height, t_type in crops: - if (t_width, t_height, t_type) in scales: - # If the aspect ratio of the cropped thumbnail matches a purely - # scaled one then there is no point in calculating a separate - # thumbnail. - continue - t_method = "crop" - if url_cache: - t_path = self.filepaths.url_cache_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - else: - t_path = self.filepaths.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) - local_thumbnails.append(( - media_id, t_width, t_height, t_type, t_method, t_len + elif t_method == "scale": + t_byte_source = yield make_deferred_yieldable(threads.deferToThread( + thumbnailer.scale, + t_width, t_height, t_type, )) + else: + logger.error("Unrecognized method: %r", t_method) + continue + + if not t_byte_source: + continue + + try: + file_info = FileInfo( + server_name=server_name, + file_id=file_id, + thumbnail=True, + thumbnail_width=t_width, + thumbnail_height=t_height, + thumbnail_method=t_method, + thumbnail_type=t_type, + url_cache=url_cache, + ) - yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails) - - for l in local_thumbnails: - yield self.store.store_local_thumbnail(*l) - - defer.returnValue({ - "width": m_width, - "height": m_height, - }) - - @defer.inlineCallbacks - def _generate_remote_thumbnails(self, server_name, media_id, media_info): - media_type = media_info["media_type"] - file_id = media_info["filesystem_id"] - requirements = self._get_thumbnail_requirements(media_type) - if not requirements: - return - - remote_thumbnails = [] + output_path = yield self.media_storage.store_file( + t_byte_source, file_info, + ) + finally: + t_byte_source.close() - input_path = self.filepaths.remote_media_filepath(server_name, file_id) - thumbnailer = Thumbnailer(input_path) - m_width = thumbnailer.width - m_height = thumbnailer.height + t_len = os.path.getsize(output_path) - def generate_thumbnails(): - if m_width * m_height >= self.max_image_pixels: - logger.info( - "Image too large to thumbnail %r x %r > %r", - m_width, m_height, self.max_image_pixels - ) - return - - scales = set() - crops = set() - for r_width, r_height, r_method, r_type in requirements: - if r_method == "scale": - t_width, t_height = thumbnailer.aspect(r_width, r_height) - scales.add(( - min(m_width, t_width), min(m_height, t_height), r_type, - )) - elif r_method == "crop": - crops.add((r_width, r_height, r_type)) - - for t_width, t_height, t_type in scales: - t_method = "scale" - t_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) - remote_thumbnails.append([ + # Write to database + if server_name: + yield self.store.store_remote_media_thumbnail( server_name, media_id, file_id, t_width, t_height, t_type, t_method, t_len - ]) - - for t_width, t_height, t_type in crops: - if (t_width, t_height, t_type) in scales: - # If the aspect ratio of the cropped thumbnail matches a purely - # scaled one then there is no point in calculating a separate - # thumbnail. - continue - t_method = "crop" - t_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, t_width, t_height, t_type, t_method ) - self._makedirs(t_path) - t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) - remote_thumbnails.append([ - server_name, media_id, file_id, - t_width, t_height, t_type, t_method, t_len - ]) - - yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails) - - for r in remote_thumbnails: - yield self.store.store_remote_media_thumbnail(*r) + else: + yield self.store.store_local_thumbnail( + media_id, t_width, t_height, t_type, t_method, t_len + ) defer.returnValue({ "width": m_width, @@ -497,6 +675,8 @@ class MediaRepository(object): logger.info("Deleting: %r", key) + # TODO: Should we delete from the backup store + with (yield self.remote_media_linearizer.queue(key)): full_path = self.filepaths.remote_media_filepath(origin, file_id) try: @@ -571,7 +751,11 @@ class MediaRepositoryResource(Resource): self.putChild("upload", UploadResource(hs, media_repo)) self.putChild("download", DownloadResource(hs, media_repo)) - self.putChild("thumbnail", ThumbnailResource(hs, media_repo)) + self.putChild("thumbnail", ThumbnailResource( + hs, media_repo, media_repo.media_storage, + )) self.putChild("identicon", IdenticonResource()) if hs.config.url_preview_enabled: - self.putChild("preview_url", PreviewUrlResource(hs, media_repo)) + self.putChild("preview_url", PreviewUrlResource( + hs, media_repo, media_repo.media_storage, + )) diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py new file mode 100644 index 0000000000..b25993fcb5 --- /dev/null +++ b/synapse/rest/media/v1/media_storage.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vecotr Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import logging +import os +import shutil +import sys + +import six + +from twisted.internet import defer, threads +from twisted.protocols.basic import FileSender + +from synapse.util.file_consumer import BackgroundFileConsumer +from synapse.util.logcontext import make_deferred_yieldable + +from ._base import Responder + +logger = logging.getLogger(__name__) + + +class MediaStorage(object): + """Responsible for storing/fetching files from local sources. + + Args: + hs (synapse.server.Homeserver) + local_media_directory (str): Base path where we store media on disk + filepaths (MediaFilePaths) + storage_providers ([StorageProvider]): List of StorageProvider that are + used to fetch and store files. + """ + + def __init__(self, hs, local_media_directory, filepaths, storage_providers): + self.hs = hs + self.local_media_directory = local_media_directory + self.filepaths = filepaths + self.storage_providers = storage_providers + + @defer.inlineCallbacks + def store_file(self, source, file_info): + """Write `source` to the on disk media store, and also any other + configured storage providers + + Args: + source: A file like object that should be written + file_info (FileInfo): Info about the file to store + + Returns: + Deferred[str]: the file path written to in the primary media store + """ + + with self.store_into_file(file_info) as (f, fname, finish_cb): + # Write to the main repository + yield make_deferred_yieldable(threads.deferToThread( + _write_file_synchronously, source, f, + )) + yield finish_cb() + + defer.returnValue(fname) + + @contextlib.contextmanager + def store_into_file(self, file_info): + """Context manager used to get a file like object to write into, as + described by file_info. + + Actually yields a 3-tuple (file, fname, finish_cb), where file is a file + like object that can be written to, fname is the absolute path of file + on disk, and finish_cb is a function that returns a Deferred. + + fname can be used to read the contents from after upload, e.g. to + generate thumbnails. + + finish_cb must be called and waited on after the file has been + successfully been written to. Should not be called if there was an + error. + + Args: + file_info (FileInfo): Info about the file to store + + Example: + + with media_storage.store_into_file(info) as (f, fname, finish_cb): + # .. write into f ... + yield finish_cb() + """ + + path = self._file_info_to_path(file_info) + fname = os.path.join(self.local_media_directory, path) + + dirname = os.path.dirname(fname) + if not os.path.exists(dirname): + os.makedirs(dirname) + + finished_called = [False] + + @defer.inlineCallbacks + def finish(): + for provider in self.storage_providers: + yield provider.store_file(path, file_info) + + finished_called[0] = True + + try: + with open(fname, "wb") as f: + yield f, fname, finish + except Exception: + t, v, tb = sys.exc_info() + try: + os.remove(fname) + except Exception: + pass + six.reraise(t, v, tb) + + if not finished_called: + raise Exception("Finished callback not called") + + @defer.inlineCallbacks + def fetch_media(self, file_info): + """Attempts to fetch media described by file_info from the local cache + and configured storage providers. + + Args: + file_info (FileInfo) + + Returns: + Deferred[Responder|None]: Returns a Responder if the file was found, + otherwise None. + """ + + path = self._file_info_to_path(file_info) + local_path = os.path.join(self.local_media_directory, path) + if os.path.exists(local_path): + defer.returnValue(FileResponder(open(local_path, "rb"))) + + for provider in self.storage_providers: + res = yield provider.fetch(path, file_info) + if res: + defer.returnValue(res) + + defer.returnValue(None) + + @defer.inlineCallbacks + def ensure_media_is_in_local_cache(self, file_info): + """Ensures that the given file is in the local cache. Attempts to + download it from storage providers if it isn't. + + Args: + file_info (FileInfo) + + Returns: + Deferred[str]: Full path to local file + """ + path = self._file_info_to_path(file_info) + local_path = os.path.join(self.local_media_directory, path) + if os.path.exists(local_path): + defer.returnValue(local_path) + + dirname = os.path.dirname(local_path) + if not os.path.exists(dirname): + os.makedirs(dirname) + + for provider in self.storage_providers: + res = yield provider.fetch(path, file_info) + if res: + with res: + consumer = BackgroundFileConsumer( + open(local_path, "w"), self.hs.get_reactor()) + yield res.write_to_consumer(consumer) + yield consumer.wait() + defer.returnValue(local_path) + + raise Exception("file could not be found") + + def _file_info_to_path(self, file_info): + """Converts file_info into a relative path. + + The path is suitable for storing files under a directory, e.g. used to + store files on local FS under the base media repository directory. + + Args: + file_info (FileInfo) + + Returns: + str + """ + if file_info.url_cache: + if file_info.thumbnail: + return self.filepaths.url_cache_thumbnail_rel( + media_id=file_info.file_id, + width=file_info.thumbnail_width, + height=file_info.thumbnail_height, + content_type=file_info.thumbnail_type, + method=file_info.thumbnail_method, + ) + return self.filepaths.url_cache_filepath_rel(file_info.file_id) + + if file_info.server_name: + if file_info.thumbnail: + return self.filepaths.remote_media_thumbnail_rel( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail_width, + height=file_info.thumbnail_height, + content_type=file_info.thumbnail_type, + method=file_info.thumbnail_method + ) + return self.filepaths.remote_media_filepath_rel( + file_info.server_name, file_info.file_id, + ) + + if file_info.thumbnail: + return self.filepaths.local_media_thumbnail_rel( + media_id=file_info.file_id, + width=file_info.thumbnail_width, + height=file_info.thumbnail_height, + content_type=file_info.thumbnail_type, + method=file_info.thumbnail_method + ) + return self.filepaths.local_media_filepath_rel( + file_info.file_id, + ) + + +def _write_file_synchronously(source, dest): + """Write `source` to the file like `dest` synchronously. Should be called + from a thread. + + Args: + source: A file like object that's to be written + dest: A file like object to be written to + """ + source.seek(0) # Ensure we read from the start of the file + shutil.copyfileobj(source, dest) + + +class FileResponder(Responder): + """Wraps an open file that can be sent to a request. + + Args: + open_file (file): A file like object to be streamed ot the client, + is closed when finished streaming. + """ + def __init__(self, open_file): + self.open_file = open_file + + def write_to_consumer(self, consumer): + return make_deferred_yieldable( + FileSender().beginFileTransfer(self.open_file, consumer) + ) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.open_file.close() diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index b81a336c5d..b70b15c4c2 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -12,80 +12,98 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import cgi +import datetime +import errno +import fnmatch +import itertools +import logging +import os +import re +import shutil +import sys +import traceback + +from six import string_types +from six.moves import urllib_parse as urlparse + +from canonicaljson import json -from twisted.web.server import NOT_DONE_YET from twisted.internet import defer from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET -from synapse.api.errors import ( - SynapseError, Codes, -) -from synapse.util.stringutils import random_string -from synapse.util.caches.expiringcache import ExpiringCache +from synapse.api.errors import Codes, SynapseError from synapse.http.client import SpiderHttpClient from synapse.http.server import ( - request_handler, respond_with_json_bytes + respond_with_json, + respond_with_json_bytes, + wrap_json_request_handler, ) +from synapse.http.servlet import parse_integer, parse_string from synapse.util.async import ObservableDeferred -from synapse.util.stringutils import is_ascii +from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.logcontext import make_deferred_yieldable, run_in_background +from synapse.util.stringutils import is_ascii, random_string -import os -import re -import fnmatch -import cgi -import ujson as json -import urlparse -import itertools +from ._base import FileInfo -import logging logger = logging.getLogger(__name__) class PreviewUrlResource(Resource): isLeaf = True - def __init__(self, hs, media_repo): + def __init__(self, hs, media_repo, media_storage): Resource.__init__(self) self.auth = hs.get_auth() self.clock = hs.get_clock() - self.version_string = hs.version_string self.filepaths = media_repo.filepaths self.max_spider_size = hs.config.max_spider_size self.server_name = hs.hostname self.store = hs.get_datastore() self.client = SpiderHttpClient(hs) self.media_repo = media_repo + self.primary_base_path = media_repo.primary_base_path + self.media_storage = media_storage self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist - # simple memory cache mapping urls to OG metadata - self.cache = ExpiringCache( + # memory cache mapping urls to an ObservableDeferred returning + # JSON-encoded OG metadata + self._cache = ExpiringCache( cache_name="url_previews", clock=self.clock, # don't spider URLs more often than once an hour expiry_ms=60 * 60 * 1000, ) - self.cache.start() + self._cache.start() - self.downloads = {} + self._cleaner_loop = self.clock.looping_call( + self._expire_url_cache_data, 10 * 1000 + ) + + def render_OPTIONS(self, request): + return respond_with_json(request, 200, {}, send_cors=True) def render_GET(self, request): self._async_render_GET(request) return NOT_DONE_YET - @request_handler() + @wrap_json_request_handler @defer.inlineCallbacks def _async_render_GET(self, request): # XXX: if get_user_by_req fails, what should we do in an async render? requester = yield self.auth.get_user_by_req(request) - url = request.args.get("url")[0] + url = parse_string(request, "url") if "ts" in request.args: - ts = int(request.args.get("ts")[0]) + ts = parse_integer(request, "ts") else: ts = self.clock.time_msec() + # XXX: we could move this into _do_preview if we wanted. url_tuple = urlparse.urlsplit(url) for entry in self.url_preview_url_blacklist: match = True @@ -118,53 +136,62 @@ class PreviewUrlResource(Resource): Codes.UNKNOWN ) - # first check the memory cache - good to handle all the clients on this - # HS thundering away to preview the same URL at the same time. - og = self.cache.get(url) - if og: - respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) - return + # the in-memory cache: + # * ensures that only one request is active at a time + # * takes load off the DB for the thundering herds + # * also caches any failures (unlike the DB) so we don't keep + # requesting the same endpoint + + observable = self._cache.get(url) + + if not observable: + download = run_in_background( + self._do_preview, + url, requester.user, ts, + ) + observable = ObservableDeferred( + download, + consumeErrors=True + ) + self._cache[url] = observable + else: + logger.info("Returning cached response") + + og = yield make_deferred_yieldable(observable.observe()) + respond_with_json_bytes(request, 200, og, send_cors=True) - # then check the URL cache in the DB (which will also provide us with + @defer.inlineCallbacks + def _do_preview(self, url, user, ts): + """Check the db, and download the URL and build a preview + + Args: + url (str): + user (str): + ts (int): + + Returns: + Deferred[str]: json-encoded og data + """ + # check the URL cache in the DB (which will also provide us with # historical previews, if we have any) cache_result = yield self.store.get_url_cache(url, ts) if ( cache_result and - cache_result["download_ts"] + cache_result["expires"] > ts and + cache_result["expires_ts"] > ts and cache_result["response_code"] / 100 == 2 ): - respond_with_json_bytes( - request, 200, cache_result["og"].encode('utf-8'), - send_cors=True - ) + defer.returnValue(cache_result["og"]) return - # Ensure only one download for a given URL is active at a time - download = self.downloads.get(url) - if download is None: - download = self._download_url(url, requester.user) - download = ObservableDeferred( - download, - consumeErrors=True - ) - self.downloads[url] = download - - @download.addBoth - def callback(media_info): - del self.downloads[url] - return media_info - media_info = yield download.observe() - - # FIXME: we should probably update our cache now anyway, so that - # even if the OG calculation raises, we don't keep hammering on the - # remote server. For now, leave it uncached to aid debugging OG - # calculation problems + media_info = yield self._download_url(url, user) logger.debug("got media_info of '%s'" % media_info) if _is_media(media_info['media_type']): - dims = yield self.media_repo._generate_local_thumbnails( - media_info['filesystem_id'], media_info, url_cache=True, + file_id = media_info['filesystem_id'] + dims = yield self.media_repo._generate_thumbnails( + None, file_id, file_id, media_info["media_type"], + url_cache=True, ) og = { @@ -204,13 +231,15 @@ class PreviewUrlResource(Resource): # just rely on the caching on the master request to speed things up. if 'og:image' in og and og['og:image']: image_info = yield self._download_url( - _rebase_url(og['og:image'], media_info['uri']), requester.user + _rebase_url(og['og:image'], media_info['uri']), user ) if _is_media(image_info['media_type']): # TODO: make sure we don't choke on white-on-transparent images - dims = yield self.media_repo._generate_local_thumbnails( - image_info['filesystem_id'], image_info, url_cache=True, + file_id = image_info['filesystem_id'] + dims = yield self.media_repo._generate_thumbnails( + None, file_id, file_id, image_info["media_type"], + url_cache=True, ) if dims: og["og:image:width"] = dims['width'] @@ -231,21 +260,20 @@ class PreviewUrlResource(Resource): logger.debug("Calculated OG for %s as %s" % (url, og)) - # store OG in ephemeral in-memory cache - self.cache[url] = og + jsonog = json.dumps(og) # store OG in history-aware DB cache yield self.store.store_url_cache( url, media_info["response_code"], media_info["etag"], - media_info["expires"], - json.dumps(og), + media_info["expires"] + media_info["created_ts"], + jsonog, media_info["filesystem_id"], media_info["created_ts"], ) - respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) + defer.returnValue(jsonog) @defer.inlineCallbacks def _download_url(self, url, user): @@ -253,21 +281,36 @@ class PreviewUrlResource(Resource): # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? - # XXX: horrible duplication with base_resource's _download_remote_file() - file_id = random_string(24) + file_id = datetime.date.today().isoformat() + '_' + random_string(16) - fname = self.filepaths.url_cache_filepath(file_id) - self.media_repo._makedirs(fname) + file_info = FileInfo( + server_name=None, + file_id=file_id, + url_cache=True, + ) - try: - with open(fname, "wb") as f: + with self.media_storage.store_into_file(file_info) as (f, fname, finish): + try: logger.debug("Trying to get url '%s'" % url) length, headers, uri, code = yield self.client.get_file( url, output_stream=f, max_size=self.max_spider_size, ) + except Exception as e: # FIXME: pass through 404s and other error messages nicely + logger.warn("Error downloading %s: %r", url, e) + raise SynapseError( + 500, "Failed to download content: %s" % ( + traceback.format_exception_only(sys.exc_type, e), + ), + Codes.UNKNOWN, + ) + yield finish() - media_type = headers["Content-Type"][0] + try: + if "Content-Type" in headers: + media_type = headers["Content-Type"][0] + else: + media_type = "application/octet-stream" time_now_ms = self.clock.time_msec() content_disposition = headers.get("Content-Disposition", None) @@ -307,11 +350,11 @@ class PreviewUrlResource(Resource): ) except Exception as e: - os.remove(fname) - raise SynapseError( - 500, ("Failed to download content: %s" % e), - Codes.UNKNOWN - ) + logger.error("Error handling downloaded %s: %r", url, e) + # TODO: we really ought to delete the downloaded file in this + # case, since we won't have recorded it in the db, and will + # therefore not expire it. + raise defer.returnValue({ "media_type": media_type, @@ -328,6 +371,95 @@ class PreviewUrlResource(Resource): "etag": headers["ETag"][0] if "ETag" in headers else None, }) + @defer.inlineCallbacks + def _expire_url_cache_data(self): + """Clean up expired url cache content, media and thumbnails. + """ + # TODO: Delete from backup media store + + now = self.clock.time_msec() + + logger.info("Running url preview cache expiry") + + if not (yield self.store.has_completed_background_updates()): + logger.info("Still running DB updates; skipping expiry") + return + + # First we delete expired url cache entries + media_ids = yield self.store.get_expired_url_cache(now) + + removed_media = [] + for media_id in media_ids: + fname = self.filepaths.url_cache_filepath(media_id) + try: + os.remove(fname) + except OSError as e: + # If the path doesn't exist, meh + if e.errno != errno.ENOENT: + logger.warn("Failed to remove media: %r: %s", media_id, e) + continue + + removed_media.append(media_id) + + try: + dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id) + for dir in dirs: + os.rmdir(dir) + except Exception: + pass + + yield self.store.delete_url_cache(removed_media) + + if removed_media: + logger.info("Deleted %d entries from url cache", len(removed_media)) + + # Now we delete old images associated with the url cache. + # These may be cached for a bit on the client (i.e., they + # may have a room open with a preview url thing open). + # So we wait a couple of days before deleting, just in case. + expire_before = now - 2 * 24 * 60 * 60 * 1000 + media_ids = yield self.store.get_url_cache_media_before(expire_before) + + removed_media = [] + for media_id in media_ids: + fname = self.filepaths.url_cache_filepath(media_id) + try: + os.remove(fname) + except OSError as e: + # If the path doesn't exist, meh + if e.errno != errno.ENOENT: + logger.warn("Failed to remove media: %r: %s", media_id, e) + continue + + try: + dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id) + for dir in dirs: + os.rmdir(dir) + except Exception: + pass + + thumbnail_dir = self.filepaths.url_cache_thumbnail_directory(media_id) + try: + shutil.rmtree(thumbnail_dir) + except OSError as e: + # If the path doesn't exist, meh + if e.errno != errno.ENOENT: + logger.warn("Failed to remove media: %r: %s", media_id, e) + continue + + removed_media.append(media_id) + + try: + dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id) + for dir in dirs: + os.rmdir(dir) + except Exception: + pass + + yield self.store.delete_url_cache_media(removed_media) + + logger.info("Deleted %d media from url cache", len(removed_media)) + def decode_and_calc_og(body, media_uri, request_encoding=None): from lxml import etree @@ -425,7 +557,14 @@ def _calc_og(tree, media_uri): from lxml import etree TAGS_TO_REMOVE = ( - "header", "nav", "aside", "footer", "script", "style", etree.Comment + "header", + "nav", + "aside", + "footer", + "script", + "noscript", + "style", + etree.Comment ) # Split all the text nodes into paragraphs (by splitting on new @@ -452,8 +591,8 @@ def _iterate_over_text(tree, *tags_to_ignore): # to be returned. elements = iter([tree]) while True: - el = elements.next() - if isinstance(el, basestring): + el = next(elements) + if isinstance(el, string_types): yield el elif el is not None and el.tag not in tags_to_ignore: # el.text is the text before the first child, so we can immediately diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py new file mode 100644 index 0000000000..7b9f8b4d79 --- /dev/null +++ b/synapse/rest/media/v1/storage_provider.py @@ -0,0 +1,144 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import shutil + +from twisted.internet import defer, threads + +from synapse.config._base import Config +from synapse.util.logcontext import run_in_background + +from .media_storage import FileResponder + +logger = logging.getLogger(__name__) + + +class StorageProvider(object): + """A storage provider is a service that can store uploaded media and + retrieve them. + """ + def store_file(self, path, file_info): + """Store the file described by file_info. The actual contents can be + retrieved by reading the file in file_info.upload_path. + + Args: + path (str): Relative path of file in local cache + file_info (FileInfo) + + Returns: + Deferred + """ + pass + + def fetch(self, path, file_info): + """Attempt to fetch the file described by file_info and stream it + into writer. + + Args: + path (str): Relative path of file in local cache + file_info (FileInfo) + + Returns: + Deferred(Responder): Returns a Responder if the provider has the file, + otherwise returns None. + """ + pass + + +class StorageProviderWrapper(StorageProvider): + """Wraps a storage provider and provides various config options + + Args: + backend (StorageProvider) + store_local (bool): Whether to store new local files or not. + store_synchronous (bool): Whether to wait for file to be successfully + uploaded, or todo the upload in the backgroud. + store_remote (bool): Whether remote media should be uploaded + """ + def __init__(self, backend, store_local, store_synchronous, store_remote): + self.backend = backend + self.store_local = store_local + self.store_synchronous = store_synchronous + self.store_remote = store_remote + + def store_file(self, path, file_info): + if not file_info.server_name and not self.store_local: + return defer.succeed(None) + + if file_info.server_name and not self.store_remote: + return defer.succeed(None) + + if self.store_synchronous: + return self.backend.store_file(path, file_info) + else: + # TODO: Handle errors. + def store(): + try: + return self.backend.store_file(path, file_info) + except Exception: + logger.exception("Error storing file") + run_in_background(store) + return defer.succeed(None) + + def fetch(self, path, file_info): + return self.backend.fetch(path, file_info) + + +class FileStorageProviderBackend(StorageProvider): + """A storage provider that stores files in a directory on a filesystem. + + Args: + hs (HomeServer) + config: The config returned by `parse_config`. + """ + + def __init__(self, hs, config): + self.cache_directory = hs.config.media_store_path + self.base_directory = config + + def store_file(self, path, file_info): + """See StorageProvider.store_file""" + + primary_fname = os.path.join(self.cache_directory, path) + backup_fname = os.path.join(self.base_directory, path) + + dirname = os.path.dirname(backup_fname) + if not os.path.exists(dirname): + os.makedirs(dirname) + + return threads.deferToThread( + shutil.copyfile, primary_fname, backup_fname, + ) + + def fetch(self, path, file_info): + """See StorageProvider.fetch""" + + backup_fname = os.path.join(self.base_directory, path) + if os.path.isfile(backup_fname): + return FileResponder(open(backup_fname, "rb")) + + @staticmethod + def parse_config(config): + """Called on startup to parse config supplied. This should parse + the config and raise if there is a problem. + + The returned value is passed into the constructor. + + In this case we only care about a single param, the directory, so let's + just pull that out. + """ + return Config.ensure_directory(config["directory"]) diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 68d56b2b10..5305e9175f 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -14,15 +14,22 @@ # limitations under the License. -from ._base import parse_media_id, respond_404, respond_with_file -from twisted.web.resource import Resource -from synapse.http.servlet import parse_string, parse_integer -from synapse.http.server import request_handler, set_cors_headers +import logging -from twisted.web.server import NOT_DONE_YET from twisted.internet import defer +from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET -import logging +from synapse.http.server import set_cors_headers, wrap_json_request_handler +from synapse.http.servlet import parse_integer, parse_string + +from ._base import ( + FileInfo, + parse_media_id, + respond_404, + respond_with_file, + respond_with_responder, +) logger = logging.getLogger(__name__) @@ -30,22 +37,21 @@ logger = logging.getLogger(__name__) class ThumbnailResource(Resource): isLeaf = True - def __init__(self, hs, media_repo): + def __init__(self, hs, media_repo, media_storage): Resource.__init__(self) self.store = hs.get_datastore() - self.filepaths = media_repo.filepaths self.media_repo = media_repo + self.media_storage = media_storage self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.server_name = hs.hostname - self.version_string = hs.version_string self.clock = hs.get_clock() def render_GET(self, request): self._async_render_GET(request) return NOT_DONE_YET - @request_handler() + @wrap_json_request_handler @defer.inlineCallbacks def _async_render_GET(self, request): set_cors_headers(request) @@ -64,6 +70,7 @@ class ThumbnailResource(Resource): yield self._respond_local_thumbnail( request, media_id, width, height, method, m_type ) + self.media_repo.mark_recently_accessed(None, media_id) else: if self.dynamic_thumbnails: yield self._select_or_generate_remote_thumbnail( @@ -75,20 +82,20 @@ class ThumbnailResource(Resource): request, server_name, media_id, width, height, method, m_type ) + self.media_repo.mark_recently_accessed(server_name, media_id) @defer.inlineCallbacks def _respond_local_thumbnail(self, request, media_id, width, height, method, m_type): media_info = yield self.store.get_local_media(media_id) - if not media_info or media_info["quarantined_by"]: + if not media_info: + respond_404(request) + return + if media_info["quarantined_by"]: + logger.info("Media is quarantined") respond_404(request) return - - # if media_info["media_type"] == "image/svg+xml": - # file_path = self.filepaths.local_media_filepath(media_id) - # yield respond_with_file(request, media_info["media_type"], file_path) - # return thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) @@ -96,42 +103,39 @@ class ThumbnailResource(Resource): thumbnail_info = self._select_thumbnail( width, height, method, m_type, thumbnail_infos ) - t_width = thumbnail_info["thumbnail_width"] - t_height = thumbnail_info["thumbnail_height"] - t_type = thumbnail_info["thumbnail_type"] - t_method = thumbnail_info["thumbnail_method"] - - if media_info["url_cache"]: - # TODO: Check the file still exists, if it doesn't we can redownload - # it from the url `media_info["url_cache"]` - file_path = self.filepaths.url_cache_thumbnail( - media_id, t_width, t_height, t_type, t_method, - ) - else: - file_path = self.filepaths.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method, - ) - yield respond_with_file(request, t_type, file_path) - else: - yield self._respond_default_thumbnail( - request, media_info, width, height, method, m_type, + file_info = FileInfo( + server_name=None, file_id=media_id, + url_cache=media_info["url_cache"], + thumbnail=True, + thumbnail_width=thumbnail_info["thumbnail_width"], + thumbnail_height=thumbnail_info["thumbnail_height"], + thumbnail_type=thumbnail_info["thumbnail_type"], + thumbnail_method=thumbnail_info["thumbnail_method"], ) + t_type = file_info.thumbnail_type + t_length = thumbnail_info["thumbnail_length"] + + responder = yield self.media_storage.fetch_media(file_info) + yield respond_with_responder(request, responder, t_type, t_length) + else: + logger.info("Couldn't find any generated thumbnails") + respond_404(request) + @defer.inlineCallbacks def _select_or_generate_local_thumbnail(self, request, media_id, desired_width, desired_height, desired_method, desired_type): media_info = yield self.store.get_local_media(media_id) - if not media_info or media_info["quarantined_by"]: + if not media_info: + respond_404(request) + return + if media_info["quarantined_by"]: + logger.info("Media is quarantined") respond_404(request) return - - # if media_info["media_type"] == "image/svg+xml": - # file_path = self.filepaths.local_media_filepath(media_id) - # yield respond_with_file(request, media_info["media_type"], file_path) - # return thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) for info in thumbnail_infos: @@ -141,46 +145,43 @@ class ThumbnailResource(Resource): t_type = info["thumbnail_type"] == desired_type if t_w and t_h and t_method and t_type: - if media_info["url_cache"]: - # TODO: Check the file still exists, if it doesn't we can redownload - # it from the url `media_info["url_cache"]` - file_path = self.filepaths.url_cache_thumbnail( - media_id, desired_width, desired_height, desired_type, - desired_method, - ) - else: - file_path = self.filepaths.local_media_thumbnail( - media_id, desired_width, desired_height, desired_type, - desired_method, - ) - yield respond_with_file(request, desired_type, file_path) - return - - logger.debug("We don't have a local thumbnail of that size. Generating") + file_info = FileInfo( + server_name=None, file_id=media_id, + url_cache=media_info["url_cache"], + thumbnail=True, + thumbnail_width=info["thumbnail_width"], + thumbnail_height=info["thumbnail_height"], + thumbnail_type=info["thumbnail_type"], + thumbnail_method=info["thumbnail_method"], + ) + + t_type = file_info.thumbnail_type + t_length = info["thumbnail_length"] + + responder = yield self.media_storage.fetch_media(file_info) + if responder: + yield respond_with_responder(request, responder, t_type, t_length) + return + + logger.debug("We don't have a thumbnail of that size. Generating") # Okay, so we generate one. file_path = yield self.media_repo.generate_local_exact_thumbnail( - media_id, desired_width, desired_height, desired_method, desired_type + media_id, desired_width, desired_height, desired_method, desired_type, + url_cache=media_info["url_cache"], ) if file_path: yield respond_with_file(request, desired_type, file_path) else: - yield self._respond_default_thumbnail( - request, media_info, desired_width, desired_height, - desired_method, desired_type, - ) + logger.warn("Failed to generate thumbnail") + respond_404(request) @defer.inlineCallbacks def _select_or_generate_remote_thumbnail(self, request, server_name, media_id, desired_width, desired_height, desired_method, desired_type): - media_info = yield self.media_repo.get_remote_media(server_name, media_id) - - # if media_info["media_type"] == "image/svg+xml": - # file_path = self.filepaths.remote_media_filepath(server_name, media_id) - # yield respond_with_file(request, media_info["media_type"], file_path) - # return + media_info = yield self.media_repo.get_remote_media_info(server_name, media_id) thumbnail_infos = yield self.store.get_remote_media_thumbnails( server_name, media_id, @@ -195,14 +196,24 @@ class ThumbnailResource(Resource): t_type = info["thumbnail_type"] == desired_type if t_w and t_h and t_method and t_type: - file_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, desired_width, desired_height, - desired_type, desired_method, + file_info = FileInfo( + server_name=server_name, file_id=media_info["filesystem_id"], + thumbnail=True, + thumbnail_width=info["thumbnail_width"], + thumbnail_height=info["thumbnail_height"], + thumbnail_type=info["thumbnail_type"], + thumbnail_method=info["thumbnail_method"], ) - yield respond_with_file(request, desired_type, file_path) - return - logger.debug("We don't have a local thumbnail of that size. Generating") + t_type = file_info.thumbnail_type + t_length = info["thumbnail_length"] + + responder = yield self.media_storage.fetch_media(file_info) + if responder: + yield respond_with_responder(request, responder, t_type, t_length) + return + + logger.debug("We don't have a thumbnail of that size. Generating") # Okay, so we generate one. file_path = yield self.media_repo.generate_remote_exact_thumbnail( @@ -213,22 +224,16 @@ class ThumbnailResource(Resource): if file_path: yield respond_with_file(request, desired_type, file_path) else: - yield self._respond_default_thumbnail( - request, media_info, desired_width, desired_height, - desired_method, desired_type, - ) + logger.warn("Failed to generate thumbnail") + respond_404(request) @defer.inlineCallbacks def _respond_remote_thumbnail(self, request, server_name, media_id, width, height, method, m_type): # TODO: Don't download the whole remote file - # We should proxy the thumbnail from the remote server instead. - media_info = yield self.media_repo.get_remote_media(server_name, media_id) - - # if media_info["media_type"] == "image/svg+xml": - # file_path = self.filepaths.remote_media_filepath(server_name, media_id) - # yield respond_with_file(request, media_info["media_type"], file_path) - # return + # We should proxy the thumbnail from the remote server instead of + # downloading the remote file and generating our own thumbnails. + media_info = yield self.media_repo.get_remote_media_info(server_name, media_id) thumbnail_infos = yield self.store.get_remote_media_thumbnails( server_name, media_id, @@ -238,59 +243,23 @@ class ThumbnailResource(Resource): thumbnail_info = self._select_thumbnail( width, height, method, m_type, thumbnail_infos ) - t_width = thumbnail_info["thumbnail_width"] - t_height = thumbnail_info["thumbnail_height"] - t_type = thumbnail_info["thumbnail_type"] - t_method = thumbnail_info["thumbnail_method"] - file_id = thumbnail_info["filesystem_id"] + file_info = FileInfo( + server_name=server_name, file_id=media_info["filesystem_id"], + thumbnail=True, + thumbnail_width=thumbnail_info["thumbnail_width"], + thumbnail_height=thumbnail_info["thumbnail_height"], + thumbnail_type=thumbnail_info["thumbnail_type"], + thumbnail_method=thumbnail_info["thumbnail_method"], + ) + + t_type = file_info.thumbnail_type t_length = thumbnail_info["thumbnail_length"] - file_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, t_width, t_height, t_type, t_method, - ) - yield respond_with_file(request, t_type, file_path, t_length) + responder = yield self.media_storage.fetch_media(file_info) + yield respond_with_responder(request, responder, t_type, t_length) else: - yield self._respond_default_thumbnail( - request, media_info, width, height, method, m_type, - ) - - @defer.inlineCallbacks - def _respond_default_thumbnail(self, request, media_info, width, height, - method, m_type): - # XXX: how is this meant to work? store.get_default_thumbnails - # appears to always return [] so won't this always 404? - media_type = media_info["media_type"] - top_level_type = media_type.split("/")[0] - sub_type = media_type.split("/")[-1].split(";")[0] - thumbnail_infos = yield self.store.get_default_thumbnails( - top_level_type, sub_type, - ) - if not thumbnail_infos: - thumbnail_infos = yield self.store.get_default_thumbnails( - top_level_type, "_default", - ) - if not thumbnail_infos: - thumbnail_infos = yield self.store.get_default_thumbnails( - "_default", "_default", - ) - if not thumbnail_infos: + logger.info("Failed to find any generated thumbnails") respond_404(request) - return - - thumbnail_info = self._select_thumbnail( - width, height, "crop", m_type, thumbnail_infos - ) - - t_width = thumbnail_info["thumbnail_width"] - t_height = thumbnail_info["thumbnail_height"] - t_type = thumbnail_info["thumbnail_type"] - t_method = thumbnail_info["thumbnail_method"] - t_length = thumbnail_info["thumbnail_length"] - - file_path = self.filepaths.default_thumbnail( - top_level_type, sub_type, t_width, t_height, t_type, t_method, - ) - yield respond_with_file(request, t_type, file_path, t_length) def _select_thumbnail(self, desired_width, desired_height, desired_method, desired_type, thumbnail_infos): diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 3868d4f65f..a4b26c2587 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import PIL.Image as Image +import logging from io import BytesIO -import logging +import PIL.Image as Image logger = logging.getLogger(__name__) @@ -50,12 +50,16 @@ class Thumbnailer(object): else: return ((max_height * self.width) // self.height, max_height) - def scale(self, output_path, width, height, output_type): - """Rescales the image to the given dimensions""" + def scale(self, width, height, output_type): + """Rescales the image to the given dimensions. + + Returns: + BytesIO: the bytes of the encoded image ready to be written to disk + """ scaled = self.image.resize((width, height), Image.ANTIALIAS) - return self.save_image(scaled, output_type, output_path) + return self._encode_image(scaled, output_type) - def crop(self, output_path, width, height, output_type): + def crop(self, width, height, output_type): """Rescales and crops the image to the given dimensions preserving aspect:: (w_in / h_in) = (w_scaled / h_scaled) @@ -65,6 +69,9 @@ class Thumbnailer(object): Args: max_width: The largest possible width. max_height: The larget possible height. + + Returns: + BytesIO: the bytes of the encoded image ready to be written to disk """ if width * self.height > height * self.width: scaled_height = (width * self.height) // self.width @@ -82,13 +89,9 @@ class Thumbnailer(object): crop_left = (scaled_width - width) // 2 crop_right = width + crop_left cropped = scaled_image.crop((crop_left, 0, crop_right, height)) - return self.save_image(cropped, output_type, output_path) + return self._encode_image(cropped, output_type) - def save_image(self, output_image, output_type, output_path): + def _encode_image(self, output_image, output_type): output_bytes_io = BytesIO() output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80) - output_bytes = output_bytes_io.getvalue() - with open(output_path, "wb") as output_file: - output_file.write(output_bytes) - logger.info("Stored thumbnail in file %r", output_path) - return len(output_bytes) + return output_bytes_io diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 4ab33f73bf..9b22d204a6 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -13,16 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.http.server import respond_with_json, request_handler - -from synapse.api.errors import SynapseError +import logging -from twisted.web.server import NOT_DONE_YET from twisted.internet import defer - from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET -import logging +from synapse.api.errors import SynapseError +from synapse.http.server import respond_with_json, wrap_json_request_handler +from synapse.http.servlet import parse_string logger = logging.getLogger(__name__) @@ -40,7 +39,6 @@ class UploadResource(Resource): self.server_name = hs.hostname self.auth = hs.get_auth() self.max_upload_size = hs.config.max_upload_size - self.version_string = hs.version_string self.clock = hs.get_clock() def render_POST(self, request): @@ -51,7 +49,7 @@ class UploadResource(Resource): respond_with_json(request, 200, {}, send_cors=True) return NOT_DONE_YET - @request_handler() + @wrap_json_request_handler @defer.inlineCallbacks def _async_render_POST(self, request): requester = yield self.auth.get_user_by_req(request) @@ -68,10 +66,10 @@ class UploadResource(Resource): code=413, ) - upload_name = request.args.get("filename", None) + upload_name = parse_string(request, "filename") if upload_name: try: - upload_name = upload_name[0].decode('UTF-8') + upload_name = upload_name.decode('UTF-8') except UnicodeDecodeError: raise SynapseError( msg="Invalid UTF-8 filename parameter: %r" % (upload_name), @@ -81,19 +79,19 @@ class UploadResource(Resource): headers = request.requestHeaders if headers.hasHeader("Content-Type"): - media_type = headers.getRawHeaders("Content-Type")[0] + media_type = headers.getRawHeaders(b"Content-Type")[0] else: raise SynapseError( msg="Upload request missing 'Content-Type'", code=400, ) - # if headers.hasHeader("Content-Disposition"): - # disposition = headers.getRawHeaders("Content-Disposition")[0] + # if headers.hasHeader(b"Content-Disposition"): + # disposition = headers.getRawHeaders(b"Content-Disposition")[0] # TODO(markjh): parse content-dispostion content_uri = yield self.media_repo.create_content( - media_type, upload_name, request.content.read(), + media_type, upload_name, request.content, content_length, requester.user ) diff --git a/synapse/secrets.py b/synapse/secrets.py new file mode 100644 index 0000000000..f397daaa5e --- /dev/null +++ b/synapse/secrets.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Injectable secrets module for Synapse. + +See https://docs.python.org/3/library/secrets.html#module-secrets for the API +used in Python 3.6, and the API emulated in Python 2.7. +""" + +import six + +if six.PY3: + import secrets + + def Secrets(): + return secrets + + +else: + + import os + import binascii + + class Secrets(object): + def token_bytes(self, nbytes=32): + return os.urandom(nbytes) + + def token_hex(self, nbytes=32): + return binascii.hexlify(self.token_bytes(nbytes)) diff --git a/synapse/server.py b/synapse/server.py index a38e5179e0..140be9ebe8 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -31,32 +31,55 @@ from synapse.appservice.api import ApplicationServiceApi from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.crypto.keyring import Keyring from synapse.events.builder import EventBuilderFactory -from synapse.federation import initialize_http_replication +from synapse.events.spamcheck import SpamChecker +from synapse.federation.federation_client import FederationClient +from synapse.federation.federation_server import ( + FederationHandlerRegistry, + FederationServer, +) from synapse.federation.send_queue import FederationRemoteSendQueue -from synapse.federation.transport.client import TransportLayerClient from synapse.federation.transaction_queue import TransactionQueue +from synapse.federation.transport.client import TransportLayerClient +from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer +from synapse.groups.groups_server import GroupsServerHandler from synapse.handlers import Handlers from synapse.handlers.appservice import ApplicationServicesHandler -from synapse.handlers.auth import AuthHandler, MacaroonGeneartor -from synapse.handlers.devicemessage import DeviceMessageHandler +from synapse.handlers.auth import AuthHandler, MacaroonGenerator +from synapse.handlers.deactivate_account import DeactivateAccountHandler from synapse.handlers.device import DeviceHandler +from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.e2e_keys import E2eKeysHandler +from synapse.handlers.events import EventHandler, EventStreamHandler +from synapse.handlers.groups_local import GroupsLocalHandler +from synapse.handlers.initial_sync import InitialSyncHandler +from synapse.handlers.message import EventCreationHandler, MessageHandler +from synapse.handlers.pagination import PaginationHandler from synapse.handlers.presence import PresenceHandler +from synapse.handlers.profile import ProfileHandler +from synapse.handlers.read_marker import ReadMarkerHandler +from synapse.handlers.receipts import ReceiptsHandler +from synapse.handlers.room import RoomContextHandler, RoomCreationHandler from synapse.handlers.room_list import RoomListHandler +from synapse.handlers.room_member import RoomMemberMasterHandler +from synapse.handlers.room_member_worker import RoomMemberWorkerHandler +from synapse.handlers.set_password import SetPasswordHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import TypingHandler -from synapse.handlers.events import EventHandler, EventStreamHandler -from synapse.handlers.initial_sync import InitialSyncHandler -from synapse.handlers.receipts import ReceiptsHandler -from synapse.handlers.read_marker import ReadMarkerHandler -from synapse.handlers.user_directory import UserDirectoyHandler -from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory +from synapse.handlers.user_directory import UserDirectoryHandler +from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.notifier import Notifier from synapse.push.action_generator import ActionGenerator from synapse.push.pusherpool import PusherPool -from synapse.rest.media.v1.media_repository import MediaRepository -from synapse.state import StateHandler +from synapse.rest.media.v1.media_repository import ( + MediaRepository, + MediaRepositoryResource, +) +from synapse.secrets import Secrets +from synapse.server_notices.server_notices_manager import ServerNoticesManager +from synapse.server_notices.server_notices_sender import ServerNoticesSender +from synapse.server_notices.worker_server_notices_sender import WorkerServerNoticesSender +from synapse.state import StateHandler, StateResolutionHandler from synapse.storage import DataStore from synapse.streams.events import EventSources from synapse.util import Clock @@ -82,21 +105,21 @@ class HomeServer(object): which must be implemented by the subclass. This code may call any of the required "get" methods on the instance to obtain the sub-dependencies that one requires. + + Attributes: + config (synapse.config.homeserver.HomeserverConfig): """ DEPENDENCIES = [ - 'config', - 'clock', 'http_client', 'db_pool', - 'persistence_service', - 'replication_layer', - 'datastore', + 'federation_client', + 'federation_server', 'handlers', - 'v1auth', 'auth', - 'rest_servlet_factory', + 'room_creation_handler', 'state_handler', + 'state_resolution_handler', 'presence_handler', 'sync_handler', 'typing_handler', @@ -111,19 +134,12 @@ class HomeServer(object): 'application_service_scheduler', 'application_service_handler', 'device_message_handler', + 'profile_handler', + 'event_creation_handler', + 'deactivate_account_handler', + 'set_password_handler', 'notifier', - 'distributor', - 'client_resource', - 'resource_for_federation', - 'resource_for_static_content', - 'resource_for_web_client', - 'resource_for_content_repo', - 'resource_for_server_key', - 'resource_for_server_key_v2', - 'resource_for_media_repository', - 'resource_for_metrics', 'event_sources', - 'ratelimiter', 'keyring', 'pusherpool', 'event_builder_factory', @@ -131,6 +147,7 @@ class HomeServer(object): 'http_client_context_factory', 'simple_http_client', 'media_repository', + 'media_repository_resource', 'federation_transport_client', 'federation_sender', 'receipts_handler', @@ -139,17 +156,34 @@ class HomeServer(object): 'read_marker_handler', 'action_generator', 'user_directory_handler', + 'groups_local_handler', + 'groups_server_handler', + 'groups_attestation_signing', + 'groups_attestation_renewer', + 'secrets', + 'spam_checker', + 'room_member_handler', + 'federation_registry', + 'server_notices_manager', + 'server_notices_sender', + 'message_handler', + 'pagination_handler', + 'room_context_handler', ] - def __init__(self, hostname, **kwargs): + def __init__(self, hostname, reactor=None, **kwargs): """ Args: hostname : The hostname for the server. """ + if not reactor: + from twisted.internet import reactor + + self._reactor = reactor self.hostname = hostname self._building = {} - self.clock = Clock() + self.clock = Clock(reactor) self.distributor = Distributor() self.ratelimiter = Ratelimiter() @@ -162,6 +196,12 @@ class HomeServer(object): self.datastore = DataStore(self.get_db_conn(), self) logger.info("Finished setting up.") + def get_reactor(self): + """ + Fetch the Twisted reactor in use by this HomeServer. + """ + return self._reactor + def get_ip_from_request(self, request): # X-Forwarded-For is handled by our custom request type. return request.getClientIP() @@ -172,8 +212,26 @@ class HomeServer(object): def is_mine_id(self, string): return string.split(":", 1)[1] == self.hostname - def build_replication_layer(self): - return initialize_http_replication(self) + def get_clock(self): + return self.clock + + def get_datastore(self): + return self.datastore + + def get_config(self): + return self.config + + def get_distributor(self): + return self.distributor + + def get_ratelimiter(self): + return self.ratelimiter + + def build_federation_client(self): + return FederationClient(self) + + def build_federation_server(self): + return FederationServer(self) def build_handlers(self): return Handlers(self) @@ -194,18 +252,15 @@ class HomeServer(object): def build_simple_http_client(self): return SimpleHttpClient(self) - def build_v1auth(self): - orf = Auth(self) - # Matrix spec makes no reference to what HTTP status code is returned, - # but the V1 API uses 403 where it means 401, and the webclient - # relies on this behaviour, so V1 gets its own copy of the auth - # with backwards compat behaviour. - orf.TOKEN_NOT_FOUND_HTTP_STATUS = 403 - return orf + def build_room_creation_handler(self): + return RoomCreationHandler(self) def build_state_handler(self): return StateHandler(self) + def build_state_resolution_handler(self): + return StateResolutionHandler(self) + def build_presence_handler(self): return PresenceHandler(self) @@ -222,7 +277,7 @@ class HomeServer(object): return AuthHandler(self) def build_macaroon_generator(self): - return MacaroonGeneartor(self) + return MacaroonGenerator(self) def build_device_handler(self): return DeviceHandler(self) @@ -251,6 +306,18 @@ class HomeServer(object): def build_initial_sync_handler(self): return InitialSyncHandler(self) + def build_profile_handler(self): + return ProfileHandler(self) + + def build_event_creation_handler(self): + return EventCreationHandler(self) + + def build_deactivate_account_handler(self): + return DeactivateAccountHandler(self) + + def build_set_password_handler(self): + return SetPasswordHandler(self) + def build_event_sources(self): return EventSources(self) @@ -277,9 +344,32 @@ class HomeServer(object): return adbapi.ConnectionPool( name, + cp_reactor=self.get_reactor(), **self.db_config.get("args", {}) ) + def get_db_conn(self, run_new_connection=True): + """Makes a new connection to the database, skipping the db pool + + Returns: + Connection: a connection object implementing the PEP-249 spec + """ + # Any param beginning with cp_ is a parameter for adbapi, and should + # not be passed to the database engine. + db_params = { + k: v for k, v in self.db_config.get("args", {}).items() + if not k.startswith("cp_") + } + db_conn = self.database_engine.module.connect(**db_params) + if run_new_connection: + self.database_engine.on_new_connection(db_conn) + return db_conn + + def build_media_repository_resource(self): + # build the media repo resource. This indirects through the HomeServer + # to ensure that we only have a single instance of + return MediaRepositoryResource(self) + def build_media_repository(self): return MediaRepository(self) @@ -307,7 +397,52 @@ class HomeServer(object): return ActionGenerator(self) def build_user_directory_handler(self): - return UserDirectoyHandler(self) + return UserDirectoryHandler(self) + + def build_groups_local_handler(self): + return GroupsLocalHandler(self) + + def build_groups_server_handler(self): + return GroupsServerHandler(self) + + def build_groups_attestation_signing(self): + return GroupAttestationSigning(self) + + def build_groups_attestation_renewer(self): + return GroupAttestionRenewer(self) + + def build_secrets(self): + return Secrets() + + def build_spam_checker(self): + return SpamChecker(self) + + def build_room_member_handler(self): + if self.config.worker_app: + return RoomMemberWorkerHandler(self) + return RoomMemberMasterHandler(self) + + def build_federation_registry(self): + return FederationHandlerRegistry() + + def build_server_notices_manager(self): + if self.config.worker_app: + raise Exception("Workers cannot send server notices") + return ServerNoticesManager(self) + + def build_server_notices_sender(self): + if self.config.worker_app: + return WorkerServerNoticesSender(self) + return ServerNoticesSender(self) + + def build_message_handler(self): + return MessageHandler(self) + + def build_pagination_handler(self): + return PaginationHandler(self) + + def build_room_context_handler(self): + return RoomContextHandler(self) def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/server.pyi b/synapse/server.pyi index 9570df5537..ce28486233 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -1,12 +1,25 @@ import synapse.api.auth +import synapse.config.homeserver +import synapse.federation.transaction_queue +import synapse.federation.transport.client import synapse.handlers import synapse.handlers.auth +import synapse.handlers.deactivate_account import synapse.handlers.device import synapse.handlers.e2e_keys -import synapse.storage +import synapse.handlers.set_password +import synapse.rest.media.v1.media_repository +import synapse.server_notices.server_notices_manager +import synapse.server_notices.server_notices_sender import synapse.state +import synapse.storage + class HomeServer(object): + @property + def config(self) -> synapse.config.homeserver.HomeServerConfig: + pass + def get_auth(self) -> synapse.api.auth.Auth: pass @@ -27,3 +40,36 @@ class HomeServer(object): def get_state_handler(self) -> synapse.state.StateHandler: pass + + def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler: + pass + + def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler: + pass + + def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler: + pass + + def get_event_creation_handler(self) -> synapse.handlers.message.EventCreationHandler: + pass + + def get_set_password_handler(self) -> synapse.handlers.set_password.SetPasswordHandler: + pass + + def get_federation_sender(self) -> synapse.federation.transaction_queue.TransactionQueue: + pass + + def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient: + pass + + def get_media_repository_resource(self) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource: + pass + + def get_media_repository(self) -> synapse.rest.media.v1.media_repository.MediaRepository: + pass + + def get_server_notices_manager(self) -> synapse.server_notices.server_notices_manager.ServerNoticesManager: + pass + + def get_server_notices_sender(self) -> synapse.server_notices.server_notices_sender.ServerNoticesSender: + pass diff --git a/synapse/server_notices/__init__.py b/synapse/server_notices/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/synapse/server_notices/__init__.py diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py new file mode 100644 index 0000000000..5e3044d164 --- /dev/null +++ b/synapse/server_notices/consent_server_notices.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from six import iteritems, string_types + +from twisted.internet import defer + +from synapse.api.errors import SynapseError +from synapse.api.urls import ConsentURIBuilder +from synapse.config import ConfigError +from synapse.types import get_localpart_from_id + +logger = logging.getLogger(__name__) + + +class ConsentServerNotices(object): + """Keeps track of whether we need to send users server_notices about + privacy policy consent, and sends one if we do. + """ + def __init__(self, hs): + """ + + Args: + hs (synapse.server.HomeServer): + """ + self._server_notices_manager = hs.get_server_notices_manager() + self._store = hs.get_datastore() + + self._users_in_progress = set() + + self._current_consent_version = hs.config.user_consent_version + self._server_notice_content = hs.config.user_consent_server_notice_content + self._send_to_guests = hs.config.user_consent_server_notice_to_guests + + if self._server_notice_content is not None: + if not self._server_notices_manager.is_enabled(): + raise ConfigError( + "user_consent configuration requires server notices, but " + "server notices are not enabled.", + ) + if 'body' not in self._server_notice_content: + raise ConfigError( + "user_consent server_notice_consent must contain a 'body' " + "key.", + ) + + self._consent_uri_builder = ConsentURIBuilder(hs.config) + + @defer.inlineCallbacks + def maybe_send_server_notice_to_user(self, user_id): + """Check if we need to send a notice to this user, and does so if so + + Args: + user_id (str): user to check + + Returns: + Deferred + """ + if self._server_notice_content is None: + # not enabled + return + + # make sure we don't send two messages to the same user at once + if user_id in self._users_in_progress: + return + self._users_in_progress.add(user_id) + try: + u = yield self._store.get_user_by_id(user_id) + + if u["is_guest"] and not self._send_to_guests: + # don't send to guests + return + + if u["consent_version"] == self._current_consent_version: + # user has already consented + return + + if u["consent_server_notice_sent"] == self._current_consent_version: + # we've already sent a notice to the user + return + + # need to send a message. + try: + consent_uri = self._consent_uri_builder.build_user_consent_uri( + get_localpart_from_id(user_id), + ) + content = copy_with_str_subst( + self._server_notice_content, { + 'consent_uri': consent_uri, + }, + ) + yield self._server_notices_manager.send_notice( + user_id, content, + ) + yield self._store.user_set_consent_server_notice_sent( + user_id, self._current_consent_version, + ) + except SynapseError as e: + logger.error("Error sending server notice about user consent: %s", e) + finally: + self._users_in_progress.remove(user_id) + + +def copy_with_str_subst(x, substitutions): + """Deep-copy a structure, carrying out string substitions on any strings + + Args: + x (object): structure to be copied + substitutions (object): substitutions to be made - passed into the + string '%' operator + + Returns: + copy of x + """ + if isinstance(x, string_types): + return x % substitutions + if isinstance(x, dict): + return { + k: copy_with_str_subst(v, substitutions) for (k, v) in iteritems(x) + } + if isinstance(x, (list, tuple)): + return [copy_with_str_subst(y) for y in x] + + # assume it's uninterested and can be shallow-copied. + return x diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py new file mode 100644 index 0000000000..a26deace53 --- /dev/null +++ b/synapse/server_notices/server_notices_manager.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership, RoomCreationPreset +from synapse.types import create_requester +from synapse.util.caches.descriptors import cachedInlineCallbacks + +logger = logging.getLogger(__name__) + + +class ServerNoticesManager(object): + def __init__(self, hs): + """ + + Args: + hs (synapse.server.HomeServer): + """ + + self._store = hs.get_datastore() + self._config = hs.config + self._room_creation_handler = hs.get_room_creation_handler() + self._event_creation_handler = hs.get_event_creation_handler() + self._is_mine_id = hs.is_mine_id + + def is_enabled(self): + """Checks if server notices are enabled on this server. + + Returns: + bool + """ + return self._config.server_notices_mxid is not None + + @defer.inlineCallbacks + def send_notice(self, user_id, event_content): + """Send a notice to the given user + + Creates the server notices room, if none exists. + + Args: + user_id (str): mxid of user to send event to. + event_content (dict): content of event to send + + Returns: + Deferred[None] + """ + room_id = yield self.get_notice_room_for_user(user_id) + + system_mxid = self._config.server_notices_mxid + requester = create_requester(system_mxid) + + logger.info("Sending server notice to %s", user_id) + + yield self._event_creation_handler.create_and_send_nonmember_event( + requester, { + "type": EventTypes.Message, + "room_id": room_id, + "sender": system_mxid, + "content": event_content, + }, + ratelimit=False, + ) + + @cachedInlineCallbacks() + def get_notice_room_for_user(self, user_id): + """Get the room for notices for a given user + + If we have not yet created a notice room for this user, create it + + Args: + user_id (str): complete user id for the user we want a room for + + Returns: + str: room id of notice room. + """ + if not self.is_enabled(): + raise Exception("Server notices not enabled") + + assert self._is_mine_id(user_id), \ + "Cannot send server notices to remote users" + + rooms = yield self._store.get_rooms_for_user_where_membership_is( + user_id, [Membership.INVITE, Membership.JOIN], + ) + system_mxid = self._config.server_notices_mxid + for room in rooms: + # it's worth noting that there is an asymmetry here in that we + # expect the user to be invited or joined, but the system user must + # be joined. This is kinda deliberate, in that if somebody somehow + # manages to invite the system user to a room, that doesn't make it + # the server notices room. + user_ids = yield self._store.get_users_in_room(room.room_id) + if system_mxid in user_ids: + # we found a room which our user shares with the system notice + # user + logger.info("Using room %s", room.room_id) + defer.returnValue(room.room_id) + + # apparently no existing notice room: create a new one + logger.info("Creating server notices room for %s", user_id) + + # see if we want to override the profile info for the server user. + # note that if we want to override either the display name or the + # avatar, we have to use both. + join_profile = None + if ( + self._config.server_notices_mxid_display_name is not None or + self._config.server_notices_mxid_avatar_url is not None + ): + join_profile = { + "displayname": self._config.server_notices_mxid_display_name, + "avatar_url": self._config.server_notices_mxid_avatar_url, + } + + requester = create_requester(system_mxid) + info = yield self._room_creation_handler.create_room( + requester, + config={ + "preset": RoomCreationPreset.PRIVATE_CHAT, + "name": self._config.server_notices_room_name, + "power_level_content_override": { + "users_default": -10, + }, + "invite": (user_id,) + }, + ratelimit=False, + creator_join_profile=join_profile, + ) + room_id = info['room_id'] + + logger.info("Created server notices room %s for %s", room_id, user_id) + defer.returnValue(room_id) diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py new file mode 100644 index 0000000000..5d23965f34 --- /dev/null +++ b/synapse/server_notices/server_notices_sender.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.server_notices.consent_server_notices import ConsentServerNotices + + +class ServerNoticesSender(object): + """A centralised place which sends server notices automatically when + Certain Events take place + """ + def __init__(self, hs): + """ + + Args: + hs (synapse.server.HomeServer): + """ + # todo: it would be nice to make this more dynamic + self._consent_server_notices = ConsentServerNotices(hs) + + def on_user_syncing(self, user_id): + """Called when the user performs a sync operation. + + Args: + user_id (str): mxid of user who synced + + Returns: + Deferred + """ + return self._consent_server_notices.maybe_send_server_notice_to_user( + user_id, + ) + + def on_user_ip(self, user_id): + """Called on the master when a worker process saw a client request. + + Args: + user_id (str): mxid + + Returns: + Deferred + """ + # The synchrotrons use a stubbed version of ServerNoticesSender, so + # we check for notices to send to the user in on_user_ip as well as + # in on_user_syncing + return self._consent_server_notices.maybe_send_server_notice_to_user( + user_id, + ) diff --git a/synapse/server_notices/worker_server_notices_sender.py b/synapse/server_notices/worker_server_notices_sender.py new file mode 100644 index 0000000000..4a133026c3 --- /dev/null +++ b/synapse/server_notices/worker_server_notices_sender.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.internet import defer + + +class WorkerServerNoticesSender(object): + """Stub impl of ServerNoticesSender which does nothing""" + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ + + def on_user_syncing(self, user_id): + """Called when the user performs a sync operation. + + Args: + user_id (str): mxid of user who synced + + Returns: + Deferred + """ + return defer.succeed(None) + + def on_user_ip(self, user_id): + """Called on the master when a worker process saw a client request. + + Args: + user_id (str): mxid + + Returns: + Deferred + """ + raise AssertionError("on_user_ip unexpectedly called on worker") diff --git a/synapse/state.py b/synapse/state.py index 390799fbd5..033f55d967 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -14,23 +14,25 @@ # limitations under the License. +import hashlib +import logging +from collections import namedtuple + +from six import iteritems, iterkeys, itervalues + +from frozendict import frozendict + from twisted.internet import defer from synapse import event_auth -from synapse.util.logutils import log_function -from synapse.util.caches.expiringcache import ExpiringCache -from synapse.util.metrics import Measure from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.events.snapshot import EventContext from synapse.util.async import Linearizer from synapse.util.caches import CACHE_SIZE_FACTOR - -from collections import namedtuple -from frozendict import frozendict - -import logging -import hashlib +from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.logutils import log_function +from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -58,7 +60,11 @@ class _StateCacheEntry(object): __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"] def __init__(self, state, state_group, prev_group=None, delta_ids=None): + # dict[(str, str), str] map from (type, state_key) to event_id self.state = frozendict(state) + + # the ID of a state group if one and only one is involved. + # otherwise, None otherwise? self.state_group = state_group self.prev_group = prev_group @@ -81,31 +87,19 @@ class _StateCacheEntry(object): class StateHandler(object): - """ Responsible for doing state conflict resolution. + """Fetches bits of state from the stores, and does state resolution + where necessary """ def __init__(self, hs): self.clock = hs.get_clock() self.store = hs.get_datastore() self.hs = hs - - # dict of set of event_ids -> _StateCacheEntry. - self._state_cache = None - self.resolve_linearizer = Linearizer(name="state_resolve_lock") + self._state_resolution_handler = hs.get_state_resolution_handler() def start_caching(self): - logger.debug("start_caching") - - self._state_cache = ExpiringCache( - cache_name="state_cache", - clock=self.clock, - max_len=SIZE_OF_CACHE, - expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, - iterable=True, - reset_expiry_on_get=True, - ) - - self._state_cache.start() + # TODO: remove this shim + self._state_resolution_handler.start_caching() @defer.inlineCallbacks def get_current_state(self, room_id, event_type=None, state_key="", @@ -127,7 +121,7 @@ class StateHandler(object): latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) logger.debug("calling resolve_state_groups from get_current_state") - ret = yield self.resolve_state_groups(room_id, latest_event_ids) + ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) state = ret.state if event_type: @@ -138,27 +132,36 @@ class StateHandler(object): defer.returnValue(event) return - state_map = yield self.store.get_events(state.values(), get_prev_content=False) + state_map = yield self.store.get_events(list(state.values()), + get_prev_content=False) state = { - key: state_map[e_id] for key, e_id in state.items() if e_id in state_map + key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map } defer.returnValue(state) @defer.inlineCallbacks - def get_current_state_ids(self, room_id, event_type=None, state_key="", - latest_event_ids=None): + def get_current_state_ids(self, room_id, latest_event_ids=None): + """Get the current state, or the state at a set of events, for a room + + Args: + room_id (str): + + latest_event_ids (iterable[str]|None): if given, the forward + extremities to resolve. If None, we look them up from the + database (via a cache) + + Returns: + Deferred[dict[(str, str), str)]]: the state dict, mapping from + (event_type, state_key) -> event_id + """ if not latest_event_ids: latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) logger.debug("calling resolve_state_groups from get_current_state_ids") - ret = yield self.resolve_state_groups(room_id, latest_event_ids) + ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) state = ret.state - if event_type: - defer.returnValue(state.get((event_type, state_key))) - return - defer.returnValue(state) @defer.inlineCallbacks @@ -166,7 +169,7 @@ class StateHandler(object): if not latest_event_ids: latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) logger.debug("calling resolve_state_groups from get_current_user_in_room") - entry = yield self.resolve_state_groups(room_id, latest_event_ids) + entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) joined_users = yield self.store.get_joined_users_from_state(room_id, entry) defer.returnValue(joined_users) @@ -175,7 +178,7 @@ class StateHandler(object): if not latest_event_ids: latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) logger.debug("calling resolve_state_groups from get_current_hosts_in_room") - entry = yield self.resolve_state_groups(room_id, latest_event_ids) + entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) joined_hosts = yield self.store.get_joined_hosts(room_id, entry) defer.returnValue(joined_hosts) @@ -183,8 +186,15 @@ class StateHandler(object): def compute_event_context(self, event, old_state=None): """Build an EventContext structure for the event. + This works out what the current state should be for the event, and + generates a new state group if necessary. + Args: event (synapse.events.EventBase): + old_state (dict|None): The state at the event if it can't be + calculated from existing events. This is normally only specified + when receiving an event from federation where we don't have the + prev events for, e.g. when backfilling. Returns: synapse.events.snapshot.EventContext: """ @@ -193,113 +203,158 @@ class StateHandler(object): # If this is an outlier, then we know it shouldn't have any current # state. Certainly store.get_current_state won't return any, and # persisting the event won't store the state group. - context = EventContext() if old_state: - context.prev_state_ids = { + prev_state_ids = { (s.type, s.state_key): s.event_id for s in old_state } if event.is_state(): - context.current_state_ids = dict(context.prev_state_ids) + current_state_ids = dict(prev_state_ids) key = (event.type, event.state_key) - context.current_state_ids[key] = event.event_id + current_state_ids[key] = event.event_id else: - context.current_state_ids = context.prev_state_ids + current_state_ids = prev_state_ids else: - context.current_state_ids = {} - context.prev_state_ids = {} - context.prev_state_events = [] - context.state_group = self.store.get_next_state_group() + current_state_ids = {} + prev_state_ids = {} + + # We don't store state for outliers, so we don't generate a state + # group for it. + context = EventContext.with_state( + state_group=None, + current_state_ids=current_state_ids, + prev_state_ids=prev_state_ids, + ) + defer.returnValue(context) if old_state: - context = EventContext() - context.prev_state_ids = { + # We already have the state, so we don't need to calculate it. + # Let's just correctly fill out the context and create a + # new state group for it. + + prev_state_ids = { (s.type, s.state_key): s.event_id for s in old_state } - context.state_group = self.store.get_next_state_group() if event.is_state(): key = (event.type, event.state_key) - if key in context.prev_state_ids: - replaces = context.prev_state_ids[key] + if key in prev_state_ids: + replaces = prev_state_ids[key] if replaces != event.event_id: # Paranoia check event.unsigned["replaces_state"] = replaces - context.current_state_ids = dict(context.prev_state_ids) - context.current_state_ids[key] = event.event_id + current_state_ids = dict(prev_state_ids) + current_state_ids[key] = event.event_id else: - context.current_state_ids = context.prev_state_ids + current_state_ids = prev_state_ids + + state_group = yield self.store.store_state_group( + event.event_id, + event.room_id, + prev_group=None, + delta_ids=None, + current_state_ids=current_state_ids, + ) + + context = EventContext.with_state( + state_group=state_group, + current_state_ids=current_state_ids, + prev_state_ids=prev_state_ids, + ) - context.prev_state_events = [] defer.returnValue(context) logger.debug("calling resolve_state_groups from compute_event_context") - entry = yield self.resolve_state_groups( + entry = yield self.resolve_state_groups_for_events( event.room_id, [e for e, _ in event.prev_events], ) - curr_state = entry.state + prev_state_ids = entry.state + prev_group = None + delta_ids = None - context = EventContext() - context.prev_state_ids = curr_state if event.is_state(): - context.state_group = self.store.get_next_state_group() + # If this is a state event then we need to create a new state + # group for the state after this event. key = (event.type, event.state_key) - if key in context.prev_state_ids: - replaces = context.prev_state_ids[key] + if key in prev_state_ids: + replaces = prev_state_ids[key] event.unsigned["replaces_state"] = replaces - context.current_state_ids = dict(context.prev_state_ids) - context.current_state_ids[key] = event.event_id + current_state_ids = dict(prev_state_ids) + current_state_ids[key] = event.event_id if entry.state_group: - context.prev_group = entry.state_group - context.delta_ids = { + # If the state at the event has a state group assigned then + # we can use that as the prev group + prev_group = entry.state_group + delta_ids = { key: event.event_id } elif entry.prev_group: - context.prev_group = entry.prev_group - context.delta_ids = dict(entry.delta_ids) - context.delta_ids[key] = event.event_id + # If the state at the event only has a prev group, then we can + # use that as a prev group too. + prev_group = entry.prev_group + delta_ids = dict(entry.delta_ids) + delta_ids[key] = event.event_id + + state_group = yield self.store.store_state_group( + event.event_id, + event.room_id, + prev_group=prev_group, + delta_ids=delta_ids, + current_state_ids=current_state_ids, + ) else: + current_state_ids = prev_state_ids + prev_group = entry.prev_group + delta_ids = entry.delta_ids + if entry.state_group is None: - entry.state_group = self.store.get_next_state_group() + entry.state_group = yield self.store.store_state_group( + event.event_id, + event.room_id, + prev_group=entry.prev_group, + delta_ids=entry.delta_ids, + current_state_ids=current_state_ids, + ) entry.state_id = entry.state_group - context.state_group = entry.state_group - context.current_state_ids = context.prev_state_ids - context.prev_group = entry.prev_group - context.delta_ids = entry.delta_ids + state_group = entry.state_group + + context = EventContext.with_state( + state_group=state_group, + current_state_ids=current_state_ids, + prev_state_ids=prev_state_ids, + prev_group=prev_group, + delta_ids=delta_ids, + ) - context.prev_state_events = [] defer.returnValue(context) @defer.inlineCallbacks - @log_function - def resolve_state_groups(self, room_id, event_ids): + def resolve_state_groups_for_events(self, room_id, event_ids): """ Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. + Args: + room_id (str): + event_ids (list[str]): + Returns: - a Deferred tuple of (`state_group`, `state`, `prev_state`). - `state_group` is the name of a state group if one and only one is - involved. `state` is a map from (type, state_key) to event, and - `prev_state` is a list of event ids. + Deferred[_StateCacheEntry]: resolved state """ logger.debug("resolve_state_groups event_ids %s", event_ids) + # map from state group id to the state in that state group (where + # 'state' is a map from state key to event id) + # dict[int, dict[(str, str), str]] state_groups_ids = yield self.store.get_state_groups_ids( room_id, event_ids ) - logger.debug( - "resolve_state_groups state_groups %s", - state_groups_ids.keys() - ) - - group_names = frozenset(state_groups_ids.keys()) - if len(group_names) == 1: - name, state_list = state_groups_ids.items().pop() + if len(state_groups_ids) == 1: + name, state_list = list(state_groups_ids.items()).pop() prev_group, delta_ids = yield self.store.get_state_group_delta(name) @@ -310,6 +365,102 @@ class StateHandler(object): delta_ids=delta_ids, )) + result = yield self._state_resolution_handler.resolve_state_groups( + room_id, state_groups_ids, None, self._state_map_factory, + ) + defer.returnValue(result) + + def _state_map_factory(self, ev_ids): + return self.store.get_events( + ev_ids, get_prev_content=False, check_redacted=False, + ) + + def resolve_events(self, state_sets, event): + logger.info( + "Resolving state for %s with %d groups", event.room_id, len(state_sets) + ) + state_set_ids = [{ + (ev.type, ev.state_key): ev.event_id + for ev in st + } for st in state_sets] + + state_map = { + ev.event_id: ev + for st in state_sets + for ev in st + } + + with Measure(self.clock, "state._resolve_events"): + new_state = resolve_events_with_state_map(state_set_ids, state_map) + + new_state = { + key: state_map[ev_id] for key, ev_id in iteritems(new_state) + } + + return new_state + + +class StateResolutionHandler(object): + """Responsible for doing state conflict resolution. + + Note that the storage layer depends on this handler, so all functions must + be storage-independent. + """ + def __init__(self, hs): + self.clock = hs.get_clock() + + # dict of set of event_ids -> _StateCacheEntry. + self._state_cache = None + self.resolve_linearizer = Linearizer(name="state_resolve_lock") + + def start_caching(self): + logger.debug("start_caching") + + self._state_cache = ExpiringCache( + cache_name="state_cache", + clock=self.clock, + max_len=SIZE_OF_CACHE, + expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, + iterable=True, + reset_expiry_on_get=True, + ) + + self._state_cache.start() + + @defer.inlineCallbacks + @log_function + def resolve_state_groups( + self, room_id, state_groups_ids, event_map, state_map_factory, + ): + """Resolves conflicts between a set of state groups + + Always generates a new state group (unless we hit the cache), so should + not be called for a single state group + + Args: + room_id (str): room we are resolving for (used for logging) + state_groups_ids (dict[int, dict[(str, str), str]]): + map from state group id to the state in that state group + (where 'state' is a map from state key to event id) + + event_map(dict[str,FrozenEvent]|None): + a dict from event_id to event, for any events that we happen to + have in flight (eg, those currently being persisted). This will be + used as a starting point fof finding the state we need; any missing + events will be requested via state_map_factory. + + If None, all events will be fetched via state_map_factory. + + Returns: + Deferred[_StateCacheEntry]: resolved state + """ + logger.debug( + "resolve_state_groups state_groups %s", + state_groups_ids.keys() + ) + + group_names = frozenset(state_groups_ids.keys()) + with (yield self.resolve_linearizer.queue(group_names)): if self._state_cache is not None: cache = self._state_cache.get(group_names, None) @@ -320,112 +471,128 @@ class StateHandler(object): "Resolving state for %s with %d groups", room_id, len(state_groups_ids) ) - state = {} - for st in state_groups_ids.values(): - for key, e_id in st.items(): - state.setdefault(key, set()).add(e_id) - - conflicted_state = { - k: list(v) - for k, v in state.items() - if len(v) > 1 - } + # start by assuming we won't have any conflicted state, and build up the new + # state map by iterating through the state groups. If we discover a conflict, + # we give up and instead use `resolve_events_with_factory`. + # + # XXX: is this actually worthwhile, or should we just let + # resolve_events_with_factory do it? + new_state = {} + conflicted_state = False + for st in itervalues(state_groups_ids): + for key, e_id in iteritems(st): + if key in new_state: + conflicted_state = True + break + new_state[key] = e_id + if conflicted_state: + break if conflicted_state: logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): - new_state = yield resolve_events( - state_groups_ids.values(), - state_map_factory=lambda ev_ids: self.store.get_events( - ev_ids, get_prev_content=False, check_redacted=False, - ), + new_state = yield resolve_events_with_factory( + list(itervalues(state_groups_ids)), + event_map=event_map, + state_map_factory=state_map_factory, ) - else: - new_state = { - key: e_ids.pop() for key, e_ids in state.items() - } - state_group = None - new_state_event_ids = frozenset(new_state.values()) - for sg, events in state_groups_ids.items(): - if new_state_event_ids == frozenset(e_id for e_id in events): - state_group = sg - break + # if the new state matches any of the input state groups, we can + # use that state group again. Otherwise we will generate a state_id + # which will be used as a cache key for future resolutions, but + # not get persisted. - # TODO: We want to create a state group for this set of events, to - # increase cache hits, but we need to make sure that it doesn't - # end up as a prev_group without being added to the database - - prev_group = None - delta_ids = None - for old_group, old_ids in state_groups_ids.iteritems(): - if not set(new_state) - set(old_ids): - n_delta_ids = { - k: v - for k, v in new_state.iteritems() - if old_ids.get(k) != v - } - if not delta_ids or len(n_delta_ids) < len(delta_ids): - prev_group = old_group - delta_ids = n_delta_ids - - cache = _StateCacheEntry( - state=new_state, - state_group=state_group, - prev_group=prev_group, - delta_ids=delta_ids, - ) + with Measure(self.clock, "state.create_group_ids"): + cache = _make_state_cache_entry(new_state, state_groups_ids) if self._state_cache is not None: self._state_cache[group_names] = cache defer.returnValue(cache) - def resolve_events(self, state_sets, event): - logger.info( - "Resolving state for %s with %d groups", event.room_id, len(state_sets) - ) - state_set_ids = [{ - (ev.type, ev.state_key): ev.event_id - for ev in st - } for st in state_sets] - state_map = { - ev.event_id: ev - for st in state_sets - for ev in st - } +def _make_state_cache_entry( + new_state, + state_groups_ids, +): + """Given a resolved state, and a set of input state groups, pick one to base + a new state group on (if any), and return an appropriately-constructed + _StateCacheEntry. - with Measure(self.clock, "state._resolve_events"): - new_state = resolve_events(state_set_ids, state_map) + Args: + new_state (dict[(str, str), str]): resolved state map (mapping from + (type, state_key) to event_id) - new_state = { - key: state_map[ev_id] for key, ev_id in new_state.items() - } + state_groups_ids (dict[int, dict[(str, str), str]]): + map from state group id to the state in that state group + (where 'state' is a map from state key to event id) - return new_state + Returns: + _StateCacheEntry + """ + # if the new state matches any of the input state groups, we can + # use that state group again. Otherwise we will generate a state_id + # which will be used as a cache key for future resolutions, but + # not get persisted. + + # first look for exact matches + new_state_event_ids = set(itervalues(new_state)) + for sg, state in iteritems(state_groups_ids): + if len(new_state_event_ids) != len(state): + continue + + old_state_event_ids = set(itervalues(state)) + if new_state_event_ids == old_state_event_ids: + # got an exact match. + return _StateCacheEntry( + state=new_state, + state_group=sg, + ) + + # TODO: We want to create a state group for this set of events, to + # increase cache hits, but we need to make sure that it doesn't + # end up as a prev_group without being added to the database + + # failing that, look for the closest match. + prev_group = None + delta_ids = None + + for old_group, old_state in iteritems(state_groups_ids): + n_delta_ids = { + k: v + for k, v in iteritems(new_state) + if old_state.get(k) != v + } + if not delta_ids or len(n_delta_ids) < len(delta_ids): + prev_group = old_group + delta_ids = n_delta_ids + + return _StateCacheEntry( + state=new_state, + state_group=None, + prev_group=prev_group, + delta_ids=delta_ids, + ) def _ordered_events(events): def key_func(e): - return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() + return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest() return sorted(events, key=key_func) -def resolve_events(state_sets, state_map_factory): +def resolve_events_with_state_map(state_sets, state_map): """ Args: state_sets(list): List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - state_map_factory(dict|callable): If callable, then will be called - with a list of event_ids that are needed, and should return with - a Deferred of dict of event_id to event. Otherwise, should be - a dict from event_id to event of all events in state_sets. + state_map(dict): a dict from event_id to event, for all events in + state_sets. Returns - dict[(str, str), synapse.events.FrozenEvent] is a map from - (type, state_key) to event. + dict[(str, str), str]: + a map from (type, state_key) to event_id. """ if len(state_sets) == 1: return state_sets[0] @@ -434,13 +601,6 @@ def resolve_events(state_sets, state_map_factory): state_sets, ) - if callable(state_map_factory): - return _resolve_with_state_fac( - unconflicted_state, conflicted_state, state_map_factory - ) - - state_map = state_map_factory - auth_events = _create_auth_events_from_maps( unconflicted_state, conflicted_state, state_map ) @@ -454,12 +614,28 @@ def _seperate(state_sets): """Takes the state_sets and figures out which keys are conflicted and which aren't. i.e., which have multiple different event_ids associated with them in different state sets. + + Args: + state_sets(iterable[dict[(str, str), str]]): + List of dicts of (type, state_key) -> event_id, which are the + different state groups to resolve. + + Returns: + (dict[(str, str), str], dict[(str, str), set[str]]): + A tuple of (unconflicted_state, conflicted_state), where: + + unconflicted_state is a dict mapping (type, state_key)->event_id + for unconflicted state keys. + + conflicted_state is a dict mapping (type, state_key) to a set of + event ids for conflicted state keys. """ - unconflicted_state = dict(state_sets[0]) + state_set_iterator = iter(state_sets) + unconflicted_state = dict(next(state_set_iterator)) conflicted_state = {} - for state_set in state_sets[1:]: - for key, value in state_set.iteritems(): + for state_set in state_set_iterator: + for key, value in iteritems(state_set): # Check if there is an unconflicted entry for the state key. unconflicted_value = unconflicted_state.get(key) if unconflicted_value is None: @@ -484,24 +660,63 @@ def _seperate(state_sets): @defer.inlineCallbacks -def _resolve_with_state_fac(unconflicted_state, conflicted_state, - state_map_factory): +def resolve_events_with_factory(state_sets, event_map, state_map_factory): + """ + Args: + state_sets(list): List of dicts of (type, state_key) -> event_id, + which are the different state groups to resolve. + + event_map(dict[str,FrozenEvent]|None): + a dict from event_id to event, for any events that we happen to + have in flight (eg, those currently being persisted). This will be + used as a starting point fof finding the state we need; any missing + events will be requested via state_map_factory. + + If None, all events will be fetched via state_map_factory. + + state_map_factory(func): will be called + with a list of event_ids that are needed, and should return with + a Deferred of dict of event_id to event. + + Returns + Deferred[dict[(str, str), str]]: + a map from (type, state_key) to event_id. + """ + if len(state_sets) == 1: + defer.returnValue(state_sets[0]) + + unconflicted_state, conflicted_state = _seperate( + state_sets, + ) + needed_events = set( event_id - for event_ids in conflicted_state.itervalues() + for event_ids in itervalues(conflicted_state) for event_id in event_ids ) + if event_map is not None: + needed_events -= set(iterkeys(event_map)) logger.info("Asking for %d conflicted events", len(needed_events)) + # dict[str, FrozenEvent]: a map from state event id to event. Only includes + # the state events which are in conflict (and those in event_map) state_map = yield state_map_factory(needed_events) + if event_map is not None: + state_map.update(event_map) + # get the ids of the auth events which allow us to authenticate the + # conflicted state, picking only from the unconflicting state. + # + # dict[(str, str), str]: a map from state key to event id auth_events = _create_auth_events_from_maps( unconflicted_state, conflicted_state, state_map ) - new_needed_events = set(auth_events.itervalues()) + new_needed_events = set(itervalues(auth_events)) new_needed_events -= needed_events + if event_map is not None: + new_needed_events -= set(iterkeys(event_map)) logger.info("Asking for %d auth events", len(new_needed_events)) @@ -515,7 +730,7 @@ def _resolve_with_state_fac(unconflicted_state, conflicted_state, def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map): auth_events = {} - for event_ids in conflicted_state.itervalues(): + for event_ids in itervalues(conflicted_state): for event_id in event_ids: if event_id in state_map: keys = event_auth.auth_types_for_event(state_map[event_id]) @@ -527,10 +742,10 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma return auth_events -def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ids, +def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map): conflicted_state = {} - for key, event_ids in conflicted_state_ds.iteritems(): + for key, event_ids in iteritems(conflicted_state_ids): events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map] if len(events) > 1: conflicted_state[key] = events @@ -539,7 +754,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ auth_events = { key: state_map[ev_id] - for key, ev_id in auth_event_ids.items() + for key, ev_id in iteritems(auth_event_ids) if ev_id in state_map } @@ -547,12 +762,12 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ resolved_state = _resolve_state_events( conflicted_state, auth_events ) - except: + except Exception: logger.exception("Failed to resolve state") raise new_state = unconflicted_state_ids - for key, event in resolved_state.iteritems(): + for key, event in iteritems(resolved_state): new_state[key] = event.event_id return new_state @@ -577,7 +792,7 @@ def _resolve_state_events(conflicted_state, auth_events): auth_events.update(resolved_state) - for key, events in conflicted_state.items(): + for key, events in iteritems(conflicted_state): if key[0] == EventTypes.JoinRules: logger.debug("Resolving conflicted join rules %r", events) resolved_state[key] = _resolve_auth_events( @@ -587,7 +802,7 @@ def _resolve_state_events(conflicted_state, auth_events): auth_events.update(resolved_state) - for key, events in conflicted_state.items(): + for key, events in iteritems(conflicted_state): if key[0] == EventTypes.Member: logger.debug("Resolving conflicted member lists %r", events) resolved_state[key] = _resolve_auth_events( @@ -597,7 +812,7 @@ def _resolve_state_events(conflicted_state, auth_events): auth_events.update(resolved_state) - for key, events in conflicted_state.items(): + for key, events in iteritems(conflicted_state): if key not in resolved_state: logger.debug("Resolving conflicted state %r:%r", key, events) resolved_state[key] = _resolve_normal_events( diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index b92472df33..ba88a54979 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,53 +14,49 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +import datetime +import logging +import time + +from dateutil import tz +from synapse.api.constants import PresenceState from synapse.storage.devices import DeviceStore -from .appservice import ( - ApplicationServiceStore, ApplicationServiceTransactionStore -) -from ._base import LoggingTransaction +from synapse.storage.user_erasure_store import UserErasureStore +from synapse.util.caches.stream_change_cache import StreamChangeCache + +from .account_data import AccountDataStore +from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore +from .client_ips import ClientIpStore +from .deviceinbox import DeviceInboxStore from .directory import DirectoryStore +from .end_to_end_keys import EndToEndKeyStore +from .engines import PostgresEngine +from .event_federation import EventFederationStore +from .event_push_actions import EventPushActionsStore from .events import EventsStore +from .filtering import FilteringStore +from .group_server import GroupServerStore +from .keys import KeyStore +from .media_repository import MediaRepositoryStore +from .openid import OpenIdStore from .presence import PresenceStore, UserPresenceState from .profile import ProfileStore +from .push_rule import PushRuleStore +from .pusher import PusherStore +from .receipts import ReceiptsStore from .registration import RegistrationStore +from .rejections import RejectionsStore from .room import RoomStore from .roommember import RoomMemberStore -from .stream import StreamStore -from .transactions import TransactionStore -from .keys import KeyStore -from .event_federation import EventFederationStore -from .pusher import PusherStore -from .push_rule import PushRuleStore -from .media_repository import MediaRepositoryStore -from .rejections import RejectionsStore -from .event_push_actions import EventPushActionsStore -from .deviceinbox import DeviceInboxStore - -from .state import StateStore -from .signatures import SignatureStore -from .filtering import FilteringStore -from .end_to_end_keys import EndToEndKeyStore - -from .receipts import ReceiptsStore from .search import SearchStore +from .signatures import SignatureStore +from .state import StateStore +from .stream import StreamStore from .tags import TagsStore -from .account_data import AccountDataStore -from .openid import OpenIdStore -from .client_ips import ClientIpStore +from .transactions import TransactionStore from .user_directory import UserDirectoryStore - -from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator -from .engines import PostgresEngine - -from synapse.api.constants import PresenceState -from synapse.util.caches.stream_change_cache import StreamChangeCache - - -import logging - +from .util.id_generators import ChainedIdGenerator, IdGenerator, StreamIdGenerator logger = logging.getLogger(__name__) @@ -88,6 +85,8 @@ class DataStore(RoomMemberStore, RoomStore, DeviceStore, DeviceInboxStore, UserDirectoryStore, + GroupServerStore, + UserErasureStore, ): def __init__(self, db_conn, hs): @@ -103,12 +102,6 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "events", "stream_ordering", step=-1, extra_tables=[("ex_outlier_stream", "event_stream_ordering")] ) - self._receipts_id_gen = StreamIdGenerator( - db_conn, "receipts_linearized", "stream_id" - ) - self._account_data_id_gen = StreamIdGenerator( - db_conn, "account_data_max_stream_id", "stream_id" - ) self._presence_id_gen = StreamIdGenerator( db_conn, "presence_stream", "stream_id" ) @@ -123,7 +116,6 @@ class DataStore(RoomMemberStore, RoomStore, ) self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") - self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") @@ -135,6 +127,9 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")], ) + self._group_updates_id_gen = StreamIdGenerator( + db_conn, "local_group_updates", "stream_id", + ) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = StreamIdGenerator( @@ -143,27 +138,6 @@ class DataStore(RoomMemberStore, RoomStore, else: self._cache_id_gen = None - events_max = self._stream_id_gen.get_current_token() - event_cache_prefill, min_event_val = self._get_cache_dict( - db_conn, "events", - entity_column="room_id", - stream_column="stream_ordering", - max_value=events_max, - ) - self._events_stream_cache = StreamChangeCache( - "EventsRoomStreamChangeCache", min_event_val, - prefilled_cache=event_cache_prefill, - ) - - self._membership_stream_cache = StreamChangeCache( - "MembershipStreamChangeCache", events_max, - ) - - account_max = self._account_data_id_gen.get_current_token() - self._account_data_stream_cache = StreamChangeCache( - "AccountDataAndTagsChangeCache", account_max, - ) - self._presence_on_startup = self._get_active_presence(db_conn) presence_cache_prefill, min_presence_val = self._get_cache_dict( @@ -177,18 +151,6 @@ class DataStore(RoomMemberStore, RoomStore, prefilled_cache=presence_cache_prefill ) - push_rules_prefill, push_rules_id = self._get_cache_dict( - db_conn, "push_rules_stream", - entity_column="user_id", - stream_column="stream_id", - max_value=self._push_rules_stream_id_gen.get_current_token()[0], - ) - - self.push_rules_stream_cache = StreamChangeCache( - "PushRulesStreamChangeCache", push_rules_id, - prefilled_cache=push_rules_prefill, - ) - max_device_inbox_id = self._device_inbox_id_gen.get_current_token() device_inbox_prefill, min_device_inbox_id = self._get_cache_dict( db_conn, "device_inbox", @@ -223,6 +185,7 @@ class DataStore(RoomMemberStore, RoomStore, "DeviceListFederationStreamChangeCache", device_list_max, ) + events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", @@ -235,24 +198,25 @@ class DataStore(RoomMemberStore, RoomStore, prefilled_cache=curr_state_delta_prefill, ) - cur = LoggingTransaction( - db_conn.cursor(), - name="_find_stream_orderings_for_times_txn", - database_engine=self.database_engine, - after_callbacks=[], - final_callbacks=[], + _group_updates_prefill, min_group_updates_id = self._get_cache_dict( + db_conn, "local_group_updates", + entity_column="user_id", + stream_column="stream_id", + max_value=self._group_updates_id_gen.get_current_token(), + limit=1000, ) - self._find_stream_orderings_for_times_txn(cur) - cur.close() - - self.find_stream_orderings_looping_call = self._clock.looping_call( - self._find_stream_orderings_for_times, 10 * 60 * 1000 + self._group_updates_stream_cache = StreamChangeCache( + "_group_updates_stream_cache", min_group_updates_id, + prefilled_cache=_group_updates_prefill, ) self._stream_order_on_start = self.get_room_max_stream_ordering() self._min_stream_order_on_start = self.get_room_min_stream_ordering() - super(DataStore, self).__init__(hs) + # Used in _generate_user_daily_visits to keep track of progress + self._last_user_visit_update = self._get_start_of_day() + + super(DataStore, self).__init__(db_conn, hs) def take_presence_startup_info(self): active_on_startup = self._presence_on_startup @@ -281,13 +245,12 @@ class DataStore(RoomMemberStore, RoomStore, return [UserPresenceState(**row) for row in rows] - @defer.inlineCallbacks def count_daily_users(self): """ Counts the number of users who used this homeserver in the last 24 hours. """ def _count_users(txn): - yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24), + yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) sql = """ SELECT COALESCE(count(*), 0) FROM ( @@ -301,8 +264,154 @@ class DataStore(RoomMemberStore, RoomStore, count, = txn.fetchone() return count - ret = yield self.runInteraction("count_users", _count_users) - defer.returnValue(ret) + return self.runInteraction("count_users", _count_users) + + def count_r30_users(self): + """ + Counts the number of 30 day retained users, defined as:- + * Users who have created their accounts more than 30 days ago + * Where last seen at most 30 days ago + * Where account creation and last_seen are > 30 days apart + + Returns counts globaly for a given user as well as breaking + by platform + """ + def _count_r30_users(txn): + thirty_days_in_secs = 86400 * 30 + now = int(self._clock.time()) + thirty_days_ago_in_secs = now - thirty_days_in_secs + + sql = """ + SELECT platform, COALESCE(count(*), 0) FROM ( + SELECT + users.name, platform, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen, + CASE + WHEN user_agent LIKE '%%Android%%' THEN 'android' + WHEN user_agent LIKE '%%iOS%%' THEN 'ios' + WHEN user_agent LIKE '%%Electron%%' THEN 'electron' + WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' + WHEN user_agent LIKE '%%Gecko%%' THEN 'web' + ELSE 'unknown' + END + AS platform + FROM user_ips + ) uip + ON users.name = uip.user_id + AND users.appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, platform, users.creation_ts + ) u GROUP BY platform + """ + + results = {} + txn.execute(sql, (thirty_days_ago_in_secs, + thirty_days_ago_in_secs)) + + for row in txn: + if row[0] is 'unknown': + pass + results[row[0]] = row[1] + + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT users.name, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen + FROM user_ips + ) uip + ON users.name = uip.user_id + AND appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, users.creation_ts + ) u + """ + + txn.execute(sql, (thirty_days_ago_in_secs, + thirty_days_ago_in_secs)) + + count, = txn.fetchone() + results['all'] = count + + return results + + return self.runInteraction("count_r30_users", _count_r30_users) + + def _get_start_of_day(self): + """ + Returns millisecond unixtime for start of UTC day. + """ + now = datetime.datetime.utcnow() + today_start = datetime.datetime(now.year, now.month, + now.day, tzinfo=tz.tzutc()) + return int(time.mktime(today_start.timetuple())) * 1000 + + def generate_user_daily_visits(self): + """ + Generates daily visit data for use in cohort/ retention analysis + """ + def _generate_user_daily_visits(txn): + logger.info("Calling _generate_user_daily_visits") + today_start = self._get_start_of_day() + a_day_in_milliseconds = 24 * 60 * 60 * 1000 + now = self.clock.time_msec() + + sql = """ + INSERT INTO user_daily_visits (user_id, device_id, timestamp) + SELECT u.user_id, u.device_id, ? + FROM user_ips AS u + LEFT JOIN ( + SELECT user_id, device_id, timestamp FROM user_daily_visits + WHERE timestamp = ? + ) udv + ON u.user_id = udv.user_id AND u.device_id=udv.device_id + INNER JOIN users ON users.name=u.user_id + WHERE last_seen > ? AND last_seen <= ? + AND udv.timestamp IS NULL AND users.is_guest=0 + AND users.appservice_id IS NULL + GROUP BY u.user_id, u.device_id + """ + + # This means that the day has rolled over but there could still + # be entries from the previous day. There is an edge case + # where if the user logs in at 23:59 and overwrites their + # last_seen at 00:01 then they will not be counted in the + # previous day's stats - it is important that the query is run + # often to minimise this case. + if today_start > self._last_user_visit_update: + yesterday_start = today_start - a_day_in_milliseconds + txn.execute(sql, ( + yesterday_start, yesterday_start, + self._last_user_visit_update, today_start + )) + self._last_user_visit_update = today_start + + txn.execute(sql, ( + today_start, today_start, + self._last_user_visit_update, + now + )) + # Update _last_user_visit_update to now. The reason to do this + # rather just clamping to the beginning of the day is to limit + # the size of the join - meaning that the query can be run more + # frequently + self._last_user_visit_update = now + + return self.runInteraction("generate_user_daily_visits", + _generate_user_daily_visits) def get_users(self): """Function to reterive a list of users in users table. diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 6f54036d67..1d41d8d445 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -13,36 +13,38 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import sys +import threading +import time -from synapse.api.errors import StoreError -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext -from synapse.util.caches import CACHE_SIZE_FACTOR -from synapse.util.caches.dictionary_cache import DictionaryCache -from synapse.util.caches.descriptors import Cache -from synapse.storage.engines import PostgresEngine -import synapse.metrics +from six import iteritems, iterkeys, itervalues +from six.moves import intern, range +from prometheus_client import Histogram from twisted.internet import defer -import sys -import time -import threading - +from synapse.api.errors import StoreError +from synapse.storage.engines import PostgresEngine +from synapse.util.caches.descriptors import Cache +from synapse.util.logcontext import LoggingContext, PreserveLoggingContext logger = logging.getLogger(__name__) +try: + MAX_TXN_ID = sys.maxint - 1 +except AttributeError: + # python 3 does not have a maximum int value + MAX_TXN_ID = 2**63 - 1 + sql_logger = logging.getLogger("synapse.storage.SQL") transaction_logger = logging.getLogger("synapse.storage.txn") perf_logger = logging.getLogger("synapse.storage.TIME") +sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec") -metrics = synapse.metrics.get_metrics_for("synapse.storage") - -sql_scheduling_timer = metrics.register_distribution("schedule_time") - -sql_query_timer = metrics.register_distribution("query_time", labels=["verb"]) -sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"]) +sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"]) +sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"]) class LoggingTransaction(object): @@ -50,16 +52,16 @@ class LoggingTransaction(object): passed to the constructor. Adds logging and metrics to the .execute() method.""" __slots__ = [ - "txn", "name", "database_engine", "after_callbacks", "final_callbacks", + "txn", "name", "database_engine", "after_callbacks", "exception_callbacks", ] def __init__(self, txn, name, database_engine, after_callbacks, - final_callbacks): + exception_callbacks): object.__setattr__(self, "txn", txn) object.__setattr__(self, "name", name) object.__setattr__(self, "database_engine", database_engine) object.__setattr__(self, "after_callbacks", after_callbacks) - object.__setattr__(self, "final_callbacks", final_callbacks) + object.__setattr__(self, "exception_callbacks", exception_callbacks) def call_after(self, callback, *args, **kwargs): """Call the given callback on the main twisted thread after the @@ -68,8 +70,8 @@ class LoggingTransaction(object): """ self.after_callbacks.append((callback, args, kwargs)) - def call_finally(self, callback, *args, **kwargs): - self.final_callbacks.append((callback, args, kwargs)) + def call_on_exception(self, callback, *args, **kwargs): + self.exception_callbacks.append((callback, args, kwargs)) def __getattr__(self, name): return getattr(self.txn, name) @@ -103,11 +105,11 @@ class LoggingTransaction(object): "[SQL values] {%s} %r", self.name, args[0] ) - except: + except Exception: # Don't let logging failures stop SQL from working pass - start = time.time() * 1000 + start = time.time() try: return func( @@ -117,9 +119,9 @@ class LoggingTransaction(object): logger.debug("[SQL FAIL] {%s} %s", self.name, e) raise finally: - msecs = (time.time() * 1000) - start - sql_logger.debug("[SQL time] {%s} %f", self.name, msecs) - sql_query_timer.inc_by(msecs, sql.split()[0]) + secs = time.time() - start + sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs) + sql_query_timer.labels(sql.split()[0]).observe(secs) class PerformanceCounters(object): @@ -129,7 +131,7 @@ class PerformanceCounters(object): def update(self, key, start_time, end_time=None): if end_time is None: - end_time = time.time() * 1000 + end_time = time.time() duration = end_time - start_time count, cum_time = self.current_counters.get(key, (0, 0)) count += 1 @@ -139,7 +141,7 @@ class PerformanceCounters(object): def interval(self, interval_duration, limit=3): counters = [] - for name, (count, cum_time) in self.current_counters.iteritems(): + for name, (count, cum_time) in iteritems(self.current_counters): prev_count, prev_time = self.previous_counters.get(name, (0, 0)) counters.append(( (cum_time - prev_time) / interval_duration, @@ -162,7 +164,7 @@ class PerformanceCounters(object): class SQLBaseStore(object): _TXN_ID = 0 - def __init__(self, hs): + def __init__(self, db_conn, hs): self.hs = hs self._clock = hs.get_clock() self._db_pool = hs.get_db_pool() @@ -180,10 +182,6 @@ class SQLBaseStore(object): self._get_event_cache = Cache("*getEvent*", keylen=3, max_entries=hs.config.event_cache_size) - self._state_group_cache = DictionaryCache( - "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR - ) - self._event_fetch_lock = threading.Condition() self._event_fetch_list = [] self._event_fetch_ongoing = 0 @@ -221,14 +219,14 @@ class SQLBaseStore(object): self._clock.looping_call(loop, 10000) - def _new_transaction(self, conn, desc, after_callbacks, final_callbacks, - logging_context, func, *args, **kwargs): - start = time.time() * 1000 + def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks, + func, *args, **kwargs): + start = time.time() txn_id = self._TXN_ID # We don't really need these to be unique, so lets stop it from # growing really large. - self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1) + self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID) name = "%s-%x" % (desc, txn_id, ) @@ -242,7 +240,7 @@ class SQLBaseStore(object): txn = conn.cursor() txn = LoggingTransaction( txn, name, self.database_engine, after_callbacks, - final_callbacks, + exception_callbacks, ) r = func(txn, *args, **kwargs) conn.commit() @@ -283,73 +281,85 @@ class SQLBaseStore(object): logger.debug("[TXN FAIL] {%s} %s", name, e) raise finally: - end = time.time() * 1000 + end = time.time() duration = end - start - if logging_context is not None: - logging_context.add_database_transaction(duration) + LoggingContext.current_context().add_database_transaction(duration) - transaction_logger.debug("[TXN END] {%s} %f", name, duration) + transaction_logger.debug("[TXN END] {%s} %f sec", name, duration) self._current_txn_total_time += duration self._txn_perf_counters.update(desc, start, end) - sql_txn_timer.inc_by(duration, desc) + sql_txn_timer.labels(desc).observe(duration) @defer.inlineCallbacks def runInteraction(self, desc, func, *args, **kwargs): - """Wraps the .runInteraction() method on the underlying db_pool.""" - current_context = LoggingContext.current_context() - - start_time = time.time() * 1000 + """Starts a transaction on the database and runs a given function - after_callbacks = [] - final_callbacks = [] + Arguments: + desc (str): description of the transaction, for logging and metrics + func (func): callback function, which will be called with a + database transaction (twisted.enterprise.adbapi.Transaction) as + its first argument, followed by `args` and `kwargs`. - def inner_func(conn, *args, **kwargs): - with LoggingContext("runInteraction") as context: - sql_scheduling_timer.inc_by(time.time() * 1000 - start_time) + args (list): positional args to pass to `func` + kwargs (dict): named args to pass to `func` - if self.database_engine.is_connection_closed(conn): - logger.debug("Reconnecting closed database connection") - conn.reconnect() - - current_context.copy_to(context) - return self._new_transaction( - conn, desc, after_callbacks, final_callbacks, current_context, - func, *args, **kwargs - ) + Returns: + Deferred: The result of func + """ + after_callbacks = [] + exception_callbacks = [] try: - with PreserveLoggingContext(): - result = yield self._db_pool.runWithConnection( - inner_func, *args, **kwargs - ) + result = yield self.runWithConnection( + self._new_transaction, + desc, after_callbacks, exception_callbacks, func, + *args, **kwargs + ) for after_callback, after_args, after_kwargs in after_callbacks: after_callback(*after_args, **after_kwargs) - finally: - for after_callback, after_args, after_kwargs in final_callbacks: + except: # noqa: E722, as we reraise the exception this is fine. + for after_callback, after_args, after_kwargs in exception_callbacks: after_callback(*after_args, **after_kwargs) + raise defer.returnValue(result) @defer.inlineCallbacks def runWithConnection(self, func, *args, **kwargs): - """Wraps the .runInteraction() method on the underlying db_pool.""" - current_context = LoggingContext.current_context() + """Wraps the .runWithConnection() method on the underlying db_pool. - start_time = time.time() * 1000 + Arguments: + func (func): callback function, which will be called with a + database connection (twisted.enterprise.adbapi.Connection) as + its first argument, followed by `args` and `kwargs`. + args (list): positional args to pass to `func` + kwargs (dict): named args to pass to `func` + + Returns: + Deferred: The result of func + """ + parent_context = LoggingContext.current_context() + if parent_context == LoggingContext.sentinel: + logger.warn( + "Starting db connection from sentinel context: metrics will be lost", + ) + parent_context = None + + start_time = time.time() def inner_func(conn, *args, **kwargs): - with LoggingContext("runWithConnection") as context: - sql_scheduling_timer.inc_by(time.time() * 1000 - start_time) + with LoggingContext("runWithConnection", parent_context) as context: + sched_duration_sec = time.time() - start_time + sql_scheduling_timer.observe(sched_duration_sec) + context.add_database_scheduled(sched_duration_sec) if self.database_engine.is_connection_closed(conn): logger.debug("Reconnecting closed database connection") conn.reconnect() - current_context.copy_to(context) - return func(conn, *args, **kwargs) with PreserveLoggingContext(): @@ -368,7 +378,7 @@ class SQLBaseStore(object): Returns: A list of dicts where the key is the column header. """ - col_headers = list(intern(column[0]) for column in cursor.description) + col_headers = list(intern(str(column[0])) for column in cursor.description) results = list( dict(zip(col_headers, row)) for row in cursor ) @@ -475,23 +485,53 @@ class SQLBaseStore(object): txn.executemany(sql, vals) + @defer.inlineCallbacks def _simple_upsert(self, table, keyvalues, values, insertion_values={}, desc="_simple_upsert", lock=True): """ + + `lock` should generally be set to True (the default), but can be set + to False if either of the following are true: + + * there is a UNIQUE INDEX on the key columns. In this case a conflict + will cause an IntegrityError in which case this function will retry + the update. + + * we somehow know that we are the only thread which will be updating + this table. + Args: table (str): The table to upsert into keyvalues (dict): The unique key tables and their new values values (dict): The nonunique columns and their new values - insertion_values (dict): key/values to use when inserting + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. Returns: Deferred(bool): True if a new entry was created, False if an existing one was updated. """ - return self.runInteraction( - desc, - self._simple_upsert_txn, table, keyvalues, values, insertion_values, - lock - ) + attempts = 0 + while True: + try: + result = yield self.runInteraction( + desc, + self._simple_upsert_txn, table, keyvalues, values, insertion_values, + lock=lock + ) + defer.returnValue(result) + except self.database_engine.module.IntegrityError as e: + attempts += 1 + if attempts >= 5: + # don't retry forever, because things other than races + # can cause IntegrityErrors + raise + + # presumably we raced with another transaction: let's retry. + logger.warn( + "IntegrityError when upserting into %s; retrying: %s", + table, e + ) def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={}, lock=True): @@ -499,37 +539,38 @@ class SQLBaseStore(object): if lock: self.database_engine.lock_table(txn, table) - # Try to update + # First try to update. sql = "UPDATE %s SET %s WHERE %s" % ( table, ", ".join("%s = ?" % (k,) for k in values), " AND ".join("%s = ?" % (k,) for k in keyvalues) ) - sqlargs = values.values() + keyvalues.values() + sqlargs = list(values.values()) + list(keyvalues.values()) txn.execute(sql, sqlargs) - if txn.rowcount == 0: - # We didn't update and rows so insert a new one - allvalues = {} - allvalues.update(keyvalues) - allvalues.update(values) - allvalues.update(insertion_values) + if txn.rowcount > 0: + # successfully updated at least one row. + return False - sql = "INSERT INTO %s (%s) VALUES (%s)" % ( - table, - ", ".join(k for k in allvalues), - ", ".join("?" for _ in allvalues) - ) - txn.execute(sql, allvalues.values()) + # We didn't update any rows so insert a new one + allvalues = {} + allvalues.update(keyvalues) + allvalues.update(values) + allvalues.update(insertion_values) - return True - else: - return False + sql = "INSERT INTO %s (%s) VALUES (%s)" % ( + table, + ", ".join(k for k in allvalues), + ", ".join("?" for _ in allvalues) + ) + txn.execute(sql, list(allvalues.values())) + # successfully inserted + return True def _simple_select_one(self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"): """Executes a SELECT query on the named table, which is expected to - return a single row, returning a single column from it. + return a single row, returning multiple columns from it. Args: table : string giving the table name @@ -582,20 +623,18 @@ class SQLBaseStore(object): @staticmethod def _simple_select_onecol_txn(txn, table, keyvalues, retcol): - if keyvalues: - where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys()) - else: - where = "" - sql = ( - "SELECT %(retcol)s FROM %(table)s %(where)s" + "SELECT %(retcol)s FROM %(table)s" ) % { "retcol": retcol, "table": table, - "where": where, } - txn.execute(sql, keyvalues.values()) + if keyvalues: + sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) + txn.execute(sql, list(keyvalues.values())) + else: + txn.execute(sql) return [r[0] for r in txn] @@ -606,7 +645,7 @@ class SQLBaseStore(object): Args: table (str): table name - keyvalues (dict): column names and values to select the rows with + keyvalues (dict|None): column names and values to select the rows with retcol (str): column whos value we wish to retrieve. Returns: @@ -657,7 +696,7 @@ class SQLBaseStore(object): table, " AND ".join("%s = ?" % (k, ) for k in keyvalues) ) - txn.execute(sql, keyvalues.values()) + txn.execute(sql, list(keyvalues.values())) else: sql = "SELECT %s FROM %s" % ( ", ".join(retcols), @@ -688,9 +727,12 @@ class SQLBaseStore(object): if not iterable: defer.returnValue(results) + # iterables can not be sliced, so convert it to a list first + it_list = list(iterable) + chunks = [ - iterable[i:i + batch_size] - for i in xrange(0, len(iterable), batch_size) + it_list[i:i + batch_size] + for i in range(0, len(it_list), batch_size) ] for chunk in chunks: rows = yield self.runInteraction( @@ -730,7 +772,7 @@ class SQLBaseStore(object): ) values.extend(iterable) - for key, value in keyvalues.iteritems(): + for key, value in iteritems(keyvalues): clauses.append("%s = ?" % (key,)) values.append(value) @@ -743,6 +785,33 @@ class SQLBaseStore(object): txn.execute(sql, values) return cls.cursor_to_dict(txn) + def _simple_update(self, table, keyvalues, updatevalues, desc): + return self.runInteraction( + desc, + self._simple_update_txn, + table, keyvalues, updatevalues, + ) + + @staticmethod + def _simple_update_txn(txn, table, keyvalues, updatevalues): + if keyvalues: + where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) + else: + where = "" + + update_sql = "UPDATE %s SET %s %s" % ( + table, + ", ".join("%s = ?" % (k,) for k in updatevalues), + where, + ) + + txn.execute( + update_sql, + list(updatevalues.values()) + list(keyvalues.values()) + ) + + return txn.rowcount + def _simple_update_one(self, table, keyvalues, updatevalues, desc="_simple_update_one"): """Executes an UPDATE query on the named table, setting new values for @@ -768,27 +837,13 @@ class SQLBaseStore(object): table, keyvalues, updatevalues, ) - @staticmethod - def _simple_update_one_txn(txn, table, keyvalues, updatevalues): - if keyvalues: - where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys()) - else: - where = "" - - update_sql = "UPDATE %s SET %s %s" % ( - table, - ", ".join("%s = ?" % (k,) for k in updatevalues), - where, - ) - - txn.execute( - update_sql, - updatevalues.values() + keyvalues.values() - ) + @classmethod + def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): + rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues) - if txn.rowcount == 0: + if rowcount == 0: raise StoreError(404, "No row found") - if txn.rowcount > 1: + if rowcount > 1: raise StoreError(500, "More than one row matched") @staticmethod @@ -800,7 +855,7 @@ class SQLBaseStore(object): " AND ".join("%s = ?" % (k,) for k in keyvalues) ) - txn.execute(select_sql, keyvalues.values()) + txn.execute(select_sql, list(keyvalues.values())) row = txn.fetchone() if not row: @@ -838,7 +893,7 @@ class SQLBaseStore(object): " AND ".join("%s = ?" % (k, ) for k in keyvalues) ) - txn.execute(sql, keyvalues.values()) + txn.execute(sql, list(keyvalues.values())) if txn.rowcount == 0: raise StoreError(404, "No row found") if txn.rowcount > 1: @@ -856,7 +911,7 @@ class SQLBaseStore(object): " AND ".join("%s = ?" % (k, ) for k in keyvalues) ) - return txn.execute(sql, keyvalues.values()) + return txn.execute(sql, list(keyvalues.values())) def _simple_delete_many(self, table, column, iterable, keyvalues, desc): return self.runInteraction( @@ -888,7 +943,7 @@ class SQLBaseStore(object): ) values.extend(iterable) - for key, value in keyvalues.iteritems(): + for key, value in iteritems(keyvalues): clauses.append("%s = ?" % (key,)) values.append(value) @@ -928,7 +983,7 @@ class SQLBaseStore(object): txn.close() if cache: - min_val = min(cache.itervalues()) + min_val = min(itervalues(cache)) else: min_val = max_value @@ -951,7 +1006,8 @@ class SQLBaseStore(object): # __exit__ called after the transaction finishes. ctx = self._cache_id_gen.get_next() stream_id = ctx.__enter__() - txn.call_finally(ctx.__exit__, None, None, None) + txn.call_on_exception(ctx.__exit__, None, None, None) + txn.call_after(ctx.__exit__, None, None, None) txn.call_after(self.hs.get_notifier().on_new_replication_data) self._simple_insert_txn( @@ -1042,7 +1098,7 @@ class SQLBaseStore(object): " AND ".join("%s = ?" % (k,) for k in keyvalues), " ? ASC LIMIT ? OFFSET ?" ) - txn.execute(sql, keyvalues.values() + pagevalues) + txn.execute(sql, list(keyvalues.values()) + list(pagevalues)) else: sql = "SELECT %s FROM %s ORDER BY %s" % ( ", ".join(retcols), diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index aa84ffc2b0..bbc3355c73 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,18 +14,46 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from twisted.internet import defer +import abc +import logging -from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks +from canonicaljson import json -import ujson as json -import logging +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore +from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) -class AccountDataStore(SQLBaseStore): +class AccountDataWorkerStore(SQLBaseStore): + """This is an abstract base class where subclasses must implement + `get_max_account_data_stream_id` which can be called in the initializer. + """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. + __metaclass__ = abc.ABCMeta + + def __init__(self, db_conn, hs): + account_max = self.get_max_account_data_stream_id() + self._account_data_stream_cache = StreamChangeCache( + "AccountDataAndTagsChangeCache", account_max, + ) + + super(AccountDataWorkerStore, self).__init__(db_conn, hs) + + @abc.abstractmethod + def get_max_account_data_stream_id(self): + """Get the current max stream ID for account data stream + + Returns: + int + """ + raise NotImplementedError() @cached() def get_account_data_for_user(self, user_id): @@ -63,7 +92,7 @@ class AccountDataStore(SQLBaseStore): "get_account_data_for_user", get_account_data_for_user_txn ) - @cachedInlineCallbacks(num_args=2) + @cachedInlineCallbacks(num_args=2, max_entries=5000) def get_global_account_data_by_type_for_user(self, data_type, user_id): """ Returns: @@ -85,25 +114,7 @@ class AccountDataStore(SQLBaseStore): else: defer.returnValue(None) - @cachedList(cached_method_name="get_global_account_data_by_type_for_user", - num_args=2, list_name="user_ids", inlineCallbacks=True) - def get_global_account_data_by_type_for_users(self, data_type, user_ids): - rows = yield self._simple_select_many_batch( - table="account_data", - column="user_id", - iterable=user_ids, - keyvalues={ - "account_data_type": data_type, - }, - retcols=("user_id", "content",), - desc="get_global_account_data_by_type_for_users", - ) - - defer.returnValue({ - row["user_id"]: json.loads(row["content"]) if row["content"] else None - for row in rows - }) - + @cached(num_args=2) def get_account_data_for_room(self, user_id, room_id): """Get all the client account_data for a user for a room. @@ -127,6 +138,38 @@ class AccountDataStore(SQLBaseStore): "get_account_data_for_room", get_account_data_for_room_txn ) + @cached(num_args=3, max_entries=5000) + def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type): + """Get the client account_data of given type for a user for a room. + + Args: + user_id(str): The user to get the account_data for. + room_id(str): The room to get the account_data for. + account_data_type (str): The account data type to get. + Returns: + A deferred of the room account_data for that type, or None if + there isn't any set. + """ + def get_account_data_for_room_and_type_txn(txn): + content_json = self._simple_select_one_onecol_txn( + txn, + table="room_account_data", + keyvalues={ + "user_id": user_id, + "room_id": room_id, + "account_data_type": account_data_type, + }, + retcol="content", + allow_none=True + ) + + return json.loads(content_json) if content_json else None + + return self.runInteraction( + "get_account_data_for_room_and_type", + get_account_data_for_room_and_type_txn, + ) + def get_all_updated_account_data(self, last_global_id, last_room_id, current_id, limit): """Get all the client account_data that has changed on the server @@ -209,6 +252,36 @@ class AccountDataStore(SQLBaseStore): "get_updated_account_data_for_user", get_updated_account_data_for_user_txn ) + @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000) + def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): + ignored_account_data = yield self.get_global_account_data_by_type_for_user( + "m.ignored_user_list", ignorer_user_id, + on_invalidate=cache_context.invalidate, + ) + if not ignored_account_data: + defer.returnValue(False) + + defer.returnValue( + ignored_user_id in ignored_account_data.get("ignored_users", {}) + ) + + +class AccountDataStore(AccountDataWorkerStore): + def __init__(self, db_conn, hs): + self._account_data_id_gen = StreamIdGenerator( + db_conn, "account_data_max_stream_id", "stream_id" + ) + + super(AccountDataStore, self).__init__(db_conn, hs) + + def get_max_account_data_stream_id(self): + """Get the current max stream id for the private user data stream + + Returns: + A deferred int. + """ + return self._account_data_id_gen.get_current_token() + @defer.inlineCallbacks def add_account_data_to_room(self, user_id, room_id, account_data_type, content): """Add some account_data to a room for a user. @@ -222,9 +295,12 @@ class AccountDataStore(SQLBaseStore): """ content_json = json.dumps(content) - def add_account_data_txn(txn, next_id): - self._simple_upsert_txn( - txn, + with self._account_data_id_gen.get_next() as next_id: + # no need to lock here as room_account_data has a unique constraint + # on (user_id, room_id, account_data_type) so _simple_upsert will + # retry if there is a conflict. + yield self._simple_upsert( + desc="add_room_account_data", table="room_account_data", keyvalues={ "user_id": user_id, @@ -234,18 +310,23 @@ class AccountDataStore(SQLBaseStore): values={ "stream_id": next_id, "content": content_json, - } - ) - txn.call_after( - self._account_data_stream_cache.entity_has_changed, - user_id, next_id, + }, + lock=False, ) - txn.call_after(self.get_account_data_for_user.invalidate, (user_id,)) - self._update_max_stream_id(txn, next_id) - with self._account_data_id_gen.get_next() as next_id: - yield self.runInteraction( - "add_room_account_data", add_account_data_txn, next_id + # it's theoretically possible for the above to succeed and the + # below to fail - in which case we might reuse a stream id on + # restart, and the above update might not get propagated. That + # doesn't sound any worse than the whole update getting lost, + # which is what would happen if we combined the two into one + # transaction. + yield self._update_max_stream_id(next_id) + + self._account_data_stream_cache.entity_has_changed(user_id, next_id) + self.get_account_data_for_user.invalidate((user_id,)) + self.get_account_data_for_room.invalidate((user_id, room_id,)) + self.get_account_data_for_room_and_type.prefill( + (user_id, room_id, account_data_type,), content, ) result = self._account_data_id_gen.get_current_token() @@ -263,9 +344,12 @@ class AccountDataStore(SQLBaseStore): """ content_json = json.dumps(content) - def add_account_data_txn(txn, next_id): - self._simple_upsert_txn( - txn, + with self._account_data_id_gen.get_next() as next_id: + # no need to lock here as account_data has a unique constraint on + # (user_id, account_data_type) so _simple_upsert will retry if + # there is a conflict. + yield self._simple_upsert( + desc="add_user_account_data", table="account_data", keyvalues={ "user_id": user_id, @@ -274,37 +358,43 @@ class AccountDataStore(SQLBaseStore): values={ "stream_id": next_id, "content": content_json, - } + }, + lock=False, ) - txn.call_after( - self._account_data_stream_cache.entity_has_changed, + + # it's theoretically possible for the above to succeed and the + # below to fail - in which case we might reuse a stream id on + # restart, and the above update might not get propagated. That + # doesn't sound any worse than the whole update getting lost, + # which is what would happen if we combined the two into one + # transaction. + yield self._update_max_stream_id(next_id) + + self._account_data_stream_cache.entity_has_changed( user_id, next_id, ) - txn.call_after(self.get_account_data_for_user.invalidate, (user_id,)) - txn.call_after( - self.get_global_account_data_by_type_for_user.invalidate, + self.get_account_data_for_user.invalidate((user_id,)) + self.get_global_account_data_by_type_for_user.invalidate( (account_data_type, user_id,) ) - self._update_max_stream_id(txn, next_id) - - with self._account_data_id_gen.get_next() as next_id: - yield self.runInteraction( - "add_user_account_data", add_account_data_txn, next_id - ) result = self._account_data_id_gen.get_current_token() defer.returnValue(result) - def _update_max_stream_id(self, txn, next_id): + def _update_max_stream_id(self, next_id): """Update the max stream_id Args: - txn: The database cursor next_id(int): The the revision to advance to. """ - update_max_id_sql = ( - "UPDATE account_data_max_stream_id" - " SET stream_id = ?" - " WHERE stream_id < ?" + def _update(txn): + update_max_id_sql = ( + "UPDATE account_data_max_stream_id" + " SET stream_id = ?" + " WHERE stream_id < ?" + ) + txn.execute(update_max_id_sql, (next_id, next_id)) + return self.runInteraction( + "update_account_data_max_stream_id", + _update, ) - txn.execute(update_max_id_sql, (next_id, next_id)) diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index c63935cb07..9f12b360bc 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,15 +15,16 @@ # limitations under the License. import logging import re -import simplejson as json + +from canonicaljson import json + from twisted.internet import defer -from synapse.api.constants import Membership from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices -from synapse.storage.roommember import RoomsForUser -from ._base import SQLBaseStore +from synapse.storage.events import EventsWorkerStore +from ._base import SQLBaseStore logger = logging.getLogger(__name__) @@ -46,17 +48,16 @@ def _make_exclusive_regex(services_cache): return exclusive_user_regex -class ApplicationServiceStore(SQLBaseStore): - - def __init__(self, hs): - super(ApplicationServiceStore, self).__init__(hs) - self.hostname = hs.hostname +class ApplicationServiceWorkerStore(SQLBaseStore): + def __init__(self, db_conn, hs): self.services_cache = load_appservices( hs.hostname, hs.config.app_service_config_files ) self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) + super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs) + def get_app_services(self): return self.services_cache @@ -99,83 +100,30 @@ class ApplicationServiceStore(SQLBaseStore): return service return None - def get_app_service_rooms(self, service): - """Get a list of RoomsForUser for this application service. - - Application services may be "interested" in lots of rooms depending on - the room ID, the room aliases, or the members in the room. This function - takes all of these into account and returns a list of RoomsForUser which - represent the entire list of room IDs that this application service - wants to know about. + def get_app_service_by_id(self, as_id): + """Get the application service with the given appservice ID. Args: - service: The application service to get a room list for. + as_id (str): The application service ID. Returns: - A list of RoomsForUser. + synapse.appservice.ApplicationService or None. """ - return self.runInteraction( - "get_app_service_rooms", - self._get_app_service_rooms_txn, - service, - ) - - def _get_app_service_rooms_txn(self, txn, service): - # get all rooms matching the room ID regex. - room_entries = self._simple_select_list_txn( - txn=txn, table="rooms", keyvalues=None, retcols=["room_id"] - ) - matching_room_list = set([ - r["room_id"] for r in room_entries if - service.is_interested_in_room(r["room_id"]) - ]) - - # resolve room IDs for matching room alias regex. - room_alias_mappings = self._simple_select_list_txn( - txn=txn, table="room_aliases", keyvalues=None, - retcols=["room_id", "room_alias"] - ) - matching_room_list |= set([ - r["room_id"] for r in room_alias_mappings if - service.is_interested_in_alias(r["room_alias"]) - ]) - - # get all rooms for every user for this AS. This is scoped to users on - # this HS only. - user_list = self._simple_select_list_txn( - txn=txn, table="users", keyvalues=None, retcols=["name"] - ) - user_list = [ - u["name"] for u in user_list if - service.is_interested_in_user(u["name"]) - ] - rooms_for_user_matching_user_id = set() # RoomsForUser list - for user_id in user_list: - # FIXME: This assumes this store is linked with RoomMemberStore :( - rooms_for_user = self._get_rooms_for_user_where_membership_is_txn( - txn=txn, - user_id=user_id, - membership_list=[Membership.JOIN] - ) - rooms_for_user_matching_user_id |= set(rooms_for_user) - - # make RoomsForUser tuples for room ids and aliases which are not in the - # main rooms_for_user_list - e.g. they are rooms which do not have AS - # registered users in it. - known_room_ids = [r.room_id for r in rooms_for_user_matching_user_id] - missing_rooms_for_user = [ - RoomsForUser(r, service.sender, "join") for r in - matching_room_list if r not in known_room_ids - ] - rooms_for_user_matching_user_id |= set(missing_rooms_for_user) - - return rooms_for_user_matching_user_id + for service in self.services_cache: + if service.id == as_id: + return service + return None -class ApplicationServiceTransactionStore(SQLBaseStore): +class ApplicationServiceStore(ApplicationServiceWorkerStore): + # This is currently empty due to there not being any AS storage functions + # that can't be run on the workers. Since this may change in future, and + # to keep consistency with the other stores, we keep this empty class for + # now. + pass - def __init__(self, hs): - super(ApplicationServiceTransactionStore, self).__init__(hs) +class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore, + EventsWorkerStore): @defer.inlineCallbacks def get_appservices_by_state(self, state): """Get a list of application services based on their state. @@ -420,3 +368,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore): events = yield self._get_events(event_ids) defer.returnValue((upper_bound, events)) + + +class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore): + # This is currently empty due to there not being any AS storage functions + # that can't be run on the workers. Since this may change in future, and + # to keep consistency with the other stores, we keep this empty class for + # now. + pass diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 7157fb1dfb..5fe1ca2de7 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -12,15 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import synapse.util.async -from ._base import SQLBaseStore -from . import engines +import logging + +from canonicaljson import json from twisted.internet import defer -import ujson as json -import logging +from synapse.metrics.background_process_metrics import run_as_background_process + +from . import engines +from ._base import SQLBaseStore logger = logging.getLogger(__name__) @@ -80,25 +82,30 @@ class BackgroundUpdateStore(SQLBaseStore): BACKGROUND_UPDATE_INTERVAL_MS = 1000 BACKGROUND_UPDATE_DURATION_MS = 100 - def __init__(self, hs): - super(BackgroundUpdateStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(BackgroundUpdateStore, self).__init__(db_conn, hs) self._background_update_performance = {} self._background_update_queue = [] self._background_update_handlers = {} + self._all_done = False - @defer.inlineCallbacks def start_doing_background_updates(self): - logger.info("Starting background schema updates") + run_as_background_process( + "background_updates", self._run_background_updates, + ) + @defer.inlineCallbacks + def _run_background_updates(self): + logger.info("Starting background schema updates") while True: - yield synapse.util.async.sleep( + yield self.hs.get_clock().sleep( self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.) try: result = yield self.do_next_background_update( self.BACKGROUND_UPDATE_DURATION_MS ) - except: + except Exception: logger.exception("Error doing update") else: if result is None: @@ -106,9 +113,41 @@ class BackgroundUpdateStore(SQLBaseStore): "No more background updates to do." " Unscheduling background update task." ) + self._all_done = True defer.returnValue(None) @defer.inlineCallbacks + def has_completed_background_updates(self): + """Check if all the background updates have completed + + Returns: + Deferred[bool]: True if all background updates have completed + """ + # if we've previously determined that there is nothing left to do, that + # is easy + if self._all_done: + defer.returnValue(True) + + # obviously, if we have things in our queue, we're not done. + if self._background_update_queue: + defer.returnValue(False) + + # otherwise, check if there are updates to be run. This is important, + # as we may be running on a worker which doesn't perform the bg updates + # itself, but still wants to wait for them to happen. + updates = yield self._simple_select_onecol( + "background_updates", + keyvalues=None, + retcol="1", + desc="check_background_updates", + ) + if not updates: + self._all_done = True + defer.returnValue(True) + + defer.returnValue(False) + + @defer.inlineCallbacks def do_next_background_update(self, desired_duration_ms): """Does some amount of work on the next queued background update @@ -209,6 +248,25 @@ class BackgroundUpdateStore(SQLBaseStore): """ self._background_update_handlers[update_name] = update_handler + def register_noop_background_update(self, update_name): + """Register a noop handler for a background update. + + This is useful when we previously did a background update, but no + longer wish to do the update. In this case the background update should + be removed from the schema delta files, but there may still be some + users who have the background update queued, so this method should + also be called to clear the update. + + Args: + update_name (str): Name of update + """ + @defer.inlineCallbacks + def noop_update(progress, batch_size): + yield self._end_background_update(update_name) + defer.returnValue(1) + + self.register_background_update_handler(update_name, noop_update) + def register_background_index_update(self, update_name, index_name, table, columns, where_clause=None, unique=False, @@ -269,7 +327,7 @@ class BackgroundUpdateStore(SQLBaseStore): # Sqlite doesn't support concurrent creation of indexes. # # We don't use partial indices on SQLite as it wasn't introduced - # until 3.8, and wheezy has 3.7 + # until 3.8, and wheezy and CentOS 7 have 3.7 # # We assume that sqlite doesn't give us invalid indices; however # we may still end up with the index existing but the diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index fc468ea185..77ae10da3d 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -15,13 +15,15 @@ import logging -from twisted.internet import defer, reactor +from six import iteritems -from ._base import Cache -from . import background_updates +from twisted.internet import defer +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.caches import CACHE_SIZE_FACTOR +from . import background_updates +from ._base import Cache logger = logging.getLogger(__name__) @@ -32,14 +34,14 @@ LAST_SEEN_GRANULARITY = 120 * 1000 class ClientIpStore(background_updates.BackgroundUpdateStore): - def __init__(self, hs): + def __init__(self, db_conn, hs): self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR, ) - super(ClientIpStore, self).__init__(hs) + super(ClientIpStore, self).__init__(db_conn, hs) self.register_background_index_update( "user_ips_device_index", @@ -48,17 +50,35 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): columns=["user_id", "device_id", "last_seen"], ) + self.register_background_index_update( + "user_ips_last_seen_index", + index_name="user_ips_last_seen", + table="user_ips", + columns=["user_id", "last_seen"], + ) + + self.register_background_index_update( + "user_ips_last_seen_only_index", + index_name="user_ips_last_seen_only", + table="user_ips", + columns=["last_seen"], + ) + # (user_id, access_token, ip) -> (user_agent, device_id, last_seen) self._batch_row_update = {} self._client_ip_looper = self._clock.looping_call( self._update_client_ips_batch, 5 * 1000 ) - reactor.addSystemEventTrigger("before", "shutdown", self._update_client_ips_batch) + self.hs.get_reactor().addSystemEventTrigger( + "before", "shutdown", self._update_client_ips_batch + ) - def insert_client_ip(self, user, access_token, ip, user_agent, device_id): - now = int(self._clock.time_msec()) - key = (user.to_string(), access_token, ip) + def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id, + now=None): + if not now: + now = int(self._clock.time_msec()) + key = (user_id, access_token, ip) try: last_seen = self.client_ip_last_seen.get(key) @@ -74,16 +94,22 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): self._batch_row_update[key] = (user_agent, device_id, now) def _update_client_ips_batch(self): - to_update = self._batch_row_update - self._batch_row_update = {} - return self.runInteraction( - "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update + def update(): + to_update = self._batch_row_update + self._batch_row_update = {} + return self.runInteraction( + "_update_client_ips_batch", self._update_client_ips_batch_txn, + to_update, + ) + + run_as_background_process( + "update_client_ips", update, ) def _update_client_ips_batch_txn(self, txn, to_update): self.database_engine.lock_table(txn, "user_ips") - for entry in to_update.iteritems(): + for entry in iteritems(to_update): (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry self._simple_upsert_txn( @@ -215,5 +241,5 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): "user_agent": user_agent, "last_seen": last_seen, } - for (access_token, ip), (user_agent, last_seen) in results.iteritems() + for (access_token, ip), (user_agent, last_seen) in iteritems(results) )) diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index 0b62b493d5..73646da025 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -14,14 +14,14 @@ # limitations under the License. import logging -import ujson -from twisted.internet import defer +from canonicaljson import json -from .background_updates import BackgroundUpdateStore +from twisted.internet import defer from synapse.util.caches.expiringcache import ExpiringCache +from .background_updates import BackgroundUpdateStore logger = logging.getLogger(__name__) @@ -29,8 +29,8 @@ logger = logging.getLogger(__name__) class DeviceInboxStore(BackgroundUpdateStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - def __init__(self, hs): - super(DeviceInboxStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(DeviceInboxStore, self).__init__(db_conn, hs) self.register_background_index_update( "device_inbox_stream_index", @@ -85,7 +85,7 @@ class DeviceInboxStore(BackgroundUpdateStore): ) rows = [] for destination, edu in remote_messages_by_destination.items(): - edu_json = ujson.dumps(edu) + edu_json = json.dumps(edu) rows.append((destination, stream_id, now_ms, edu_json)) txn.executemany(sql, rows) @@ -177,7 +177,7 @@ class DeviceInboxStore(BackgroundUpdateStore): " WHERE user_id = ?" ) txn.execute(sql, (user_id,)) - message_json = ujson.dumps(messages_by_device["*"]) + message_json = json.dumps(messages_by_device["*"]) for row in txn: # Add the message for all devices for this user on this # server. @@ -199,7 +199,7 @@ class DeviceInboxStore(BackgroundUpdateStore): # Only insert into the local inbox if the device exists on # this server device = row[0] - message_json = ujson.dumps(messages_by_device[device]) + message_json = json.dumps(messages_by_device[device]) messages_json_for_user[device] = message_json if messages_json_for_user: @@ -253,7 +253,7 @@ class DeviceInboxStore(BackgroundUpdateStore): messages = [] for row in txn: stream_pos = row[0] - messages.append(ujson.loads(row[1])) + messages.append(json.loads(row[1])) if len(messages) < limit: stream_pos = current_stream_id return (messages, stream_pos) @@ -389,7 +389,7 @@ class DeviceInboxStore(BackgroundUpdateStore): messages = [] for row in txn: stream_pos = row[0] - messages.append(ujson.loads(row[1])) + messages.append(json.loads(row[1])) if len(messages) < limit: stream_pos = current_stream_id return (messages, stream_pos) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index bb27fd1f70..cc3cdf2ebc 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -13,21 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import ujson as json + +from six import iteritems, itervalues + +from canonicaljson import json from twisted.internet import defer from synapse.api.errors import StoreError -from ._base import SQLBaseStore, Cache -from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from ._base import Cache, SQLBaseStore logger = logging.getLogger(__name__) class DeviceStore(SQLBaseStore): - def __init__(self, hs): - super(DeviceStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(DeviceStore, self).__init__(db_conn, hs) # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. @@ -245,17 +248,31 @@ class DeviceStore(SQLBaseStore): def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id, content, stream_id): - self._simple_upsert_txn( - txn, - table="device_lists_remote_cache", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, - values={ - "content": json.dumps(content), - } - ) + if content.get("deleted"): + self._simple_delete_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + ) + + txn.call_after( + self.device_id_exists_cache.invalidate, (user_id, device_id,) + ) + else: + self._simple_upsert_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + values={ + "content": json.dumps(content), + } + ) txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,)) txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) @@ -360,10 +377,10 @@ class DeviceStore(SQLBaseStore): return (now_stream_id, []) if len(query_map) >= 20: - now_stream_id = max(stream_id for stream_id in query_map.itervalues()) + now_stream_id = max(stream_id for stream_id in itervalues(query_map)) devices = self._get_e2e_device_keys_txn( - txn, query_map.keys(), include_all_devices=True + txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True ) prev_sent_id_sql = """ @@ -373,13 +390,13 @@ class DeviceStore(SQLBaseStore): """ results = [] - for user_id, user_devices in devices.iteritems(): + for user_id, user_devices in iteritems(devices): # The prev_id for the first row is always the last row before # `from_stream_id` txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id)) rows = txn.fetchall() prev_id = rows[0][0] - for device_id, device in user_devices.iteritems(): + for device_id, device in iteritems(user_devices): stream_id = query_map[(user_id, device_id)] result = { "user_id": user_id, @@ -390,12 +407,15 @@ class DeviceStore(SQLBaseStore): prev_id = stream_id - key_json = device.get("key_json", None) - if key_json: - result["keys"] = json.loads(key_json) - device_display_name = device.get("device_display_name", None) - if device_display_name: - result["device_display_name"] = device_display_name + if device is not None: + key_json = device.get("key_json", None) + if key_json: + result["keys"] = json.loads(key_json) + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + else: + result["deleted"] = True results.append(result) @@ -483,7 +503,7 @@ class DeviceStore(SQLBaseStore): if devices: user_devices = devices[user_id] results = [] - for device_id, device in user_devices.iteritems(): + for device_id, device in iteritems(user_devices): result = { "device_id": device_id, } diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 79e7c540ad..808194236a 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -13,15 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached - -from synapse.api.errors import SynapseError +from collections import namedtuple from twisted.internet import defer -from collections import namedtuple +from synapse.api.errors import SynapseError +from synapse.util.caches.descriptors import cached +from ._base import SQLBaseStore RoomAliasMapping = namedtuple( "RoomAliasMapping", @@ -29,8 +28,7 @@ RoomAliasMapping = namedtuple( ) -class DirectoryStore(SQLBaseStore): - +class DirectoryWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_association_from_room_alias(self, room_alias): """ Get's the room_id and server list for a given room_alias @@ -69,6 +67,28 @@ class DirectoryStore(SQLBaseStore): RoomAliasMapping(room_id, room_alias.to_string(), servers) ) + def get_room_alias_creator(self, room_alias): + return self._simple_select_one_onecol( + table="room_aliases", + keyvalues={ + "room_alias": room_alias, + }, + retcol="creator", + desc="get_room_alias_creator", + allow_none=True + ) + + @cached(max_entries=5000) + def get_aliases_for_room(self, room_id): + return self._simple_select_onecol( + "room_aliases", + {"room_id": room_id}, + "room_alias", + desc="get_aliases_for_room", + ) + + +class DirectoryStore(DirectoryWorkerStore): @defer.inlineCallbacks def create_room_alias_association(self, room_alias, room_id, servers, creator=None): """ Creates an associatin between a room alias and room_id/servers @@ -116,17 +136,6 @@ class DirectoryStore(SQLBaseStore): ) defer.returnValue(ret) - def get_room_alias_creator(self, room_alias): - return self._simple_select_one_onecol( - table="room_aliases", - keyvalues={ - "room_alias": room_alias, - }, - retcol="creator", - desc="get_room_alias_creator", - allow_none=True - ) - @defer.inlineCallbacks def delete_room_alias(self, room_alias): room_id = yield self.runInteraction( @@ -135,7 +144,6 @@ class DirectoryStore(SQLBaseStore): room_alias, ) - self.get_aliases_for_room.invalidate((room_id,)) defer.returnValue(room_id) def _delete_room_alias_txn(self, txn, room_alias): @@ -160,17 +168,12 @@ class DirectoryStore(SQLBaseStore): (room_alias.to_string(),) ) - return room_id - - @cached(max_entries=5000) - def get_aliases_for_room(self, room_id): - return self._simple_select_onecol( - "room_aliases", - {"room_id": room_id}, - "room_alias", - desc="get_aliases_for_room", + self._invalidate_cache_and_stream( + txn, self.get_aliases_for_room, (room_id,) ) + return room_id + def update_aliases_for_room(self, old_room_id, new_room_id, creator): def _update_aliases_for_room_txn(txn): sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?" diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 2cebb203c6..523b4360c3 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -12,13 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from six import iteritems + +from canonicaljson import encode_canonical_json, json + from twisted.internet import defer from synapse.util.caches.descriptors import cached -from canonicaljson import encode_canonical_json -import ujson as json - from ._base import SQLBaseStore @@ -63,12 +64,18 @@ class EndToEndKeyStore(SQLBaseStore): ) @defer.inlineCallbacks - def get_e2e_device_keys(self, query_list, include_all_devices=False): + def get_e2e_device_keys( + self, query_list, include_all_devices=False, + include_deleted_devices=False, + ): """Fetch a list of device keys. Args: query_list(list): List of pairs of user_ids and device_ids. include_all_devices (bool): whether to include entries for devices that don't have device keys + include_deleted_devices (bool): whether to include null entries for + devices which no longer exist (but were in the query_list). + This option only takes effect if include_all_devices is true. Returns: Dict mapping from user-id to dict mapping from device_id to dict containing "key_json", "device_display_name". @@ -78,19 +85,28 @@ class EndToEndKeyStore(SQLBaseStore): results = yield self.runInteraction( "get_e2e_device_keys", self._get_e2e_device_keys_txn, - query_list, include_all_devices, + query_list, include_all_devices, include_deleted_devices, ) - for user_id, device_keys in results.iteritems(): - for device_id, device_info in device_keys.iteritems(): + for user_id, device_keys in iteritems(results): + for device_id, device_info in iteritems(device_keys): device_info["keys"] = json.loads(device_info.pop("key_json")) defer.returnValue(results) - def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices): + def _get_e2e_device_keys_txn( + self, txn, query_list, include_all_devices=False, + include_deleted_devices=False, + ): query_clauses = [] query_params = [] + if include_all_devices is False: + include_deleted_devices = False + + if include_deleted_devices: + deleted_devices = set(query_list) + for (user_id, device_id) in query_list: query_clause = "user_id = ?" query_params.append(user_id) @@ -118,8 +134,14 @@ class EndToEndKeyStore(SQLBaseStore): result = {} for row in rows: + if include_deleted_devices: + deleted_devices.remove((row["user_id"], row["device_id"])) result.setdefault(row["user_id"], {})[row["device_id"]] = row + if include_deleted_devices: + for user_id, device_id in deleted_devices: + result.setdefault(user_id, {})[device_id] = None + return result @defer.inlineCallbacks diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index 338b495611..e2f9de8451 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -13,13 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib +import platform + from ._base import IncorrectDatabaseSetup from .postgres import PostgresEngine from .sqlite3 import Sqlite3Engine -import importlib - - SUPPORTED_MODULE = { "sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine, @@ -31,6 +31,10 @@ def create_engine(database_config): engine_class = SUPPORTED_MODULE.get(name, None) if engine_class: + # pypy requires psycopg2cffi rather than psycopg2 + if (name == "psycopg2" and + platform.python_implementation() == "PyPy"): + name = "psycopg2cffi" module = importlib.import_module(name) return engine_class(module, database_config) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index a6ae79dfad..8a0386c1a4 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -62,3 +62,9 @@ class PostgresEngine(object): def lock_table(self, txn, table): txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,)) + + def get_next_state_group_id(self, txn): + """Returns an int that can be used as a new state_group ID + """ + txn.execute("SELECT nextval('state_group_id_seq')") + return txn.fetchone()[0] diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index 755c9a1f07..19949fc474 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -13,9 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.prepare_database import prepare_database - import struct +import threading + +from synapse.storage.prepare_database import prepare_database class Sqlite3Engine(object): @@ -24,6 +25,11 @@ class Sqlite3Engine(object): def __init__(self, database_module, database_config): self.module = database_module + # The current max state_group, or None if we haven't looked + # in the DB yet. + self._current_state_group_id = None + self._current_state_group_id_lock = threading.Lock() + def check_database(self, txn): pass @@ -43,6 +49,19 @@ class Sqlite3Engine(object): def lock_table(self, txn, table): return + def get_next_state_group_id(self, txn): + """Returns an int that can be used as a new state_group ID + """ + # We do application locking here since if we're using sqlite then + # we are a single process synapse. + with self._current_state_group_id_lock: + if self._current_state_group_id is None: + txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") + self._current_state_group_id = txn.fetchone()[0] + + self._current_state_group_id += 1 + return self._current_state_group_id + # Following functions taken from: https://github.com/coleifer/peewee diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index e8133de2fa..8d366d1b91 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -12,45 +12,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +import random -from twisted.internet import defer +from six.moves import range +from six.moves.queue import Empty, PriorityQueue -from ._base import SQLBaseStore -from synapse.api.errors import StoreError -from synapse.util.caches.descriptors import cached from unpaddedbase64 import encode_base64 -import logging -from Queue import PriorityQueue, Empty +from twisted.internet import defer +from synapse.api.errors import StoreError +from synapse.storage._base import SQLBaseStore +from synapse.storage.events import EventsWorkerStore +from synapse.storage.signatures import SignatureWorkerStore +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) -class EventFederationStore(SQLBaseStore): - """ Responsible for storing and serving up the various graphs associated - with an event. Including the main event graph and the auth chains for an - event. - - Also has methods for getting the front (latest) and back (oldest) edges - of the event graphs. These are used to generate the parents for new events - and backfilling from another server respectively. - """ - - EVENT_AUTH_STATE_ONLY = "event_auth_state_only" - - def __init__(self, hs): - super(EventFederationStore, self).__init__(hs) - - self.register_background_update_handler( - self.EVENT_AUTH_STATE_ONLY, - self._background_delete_non_state_event_auth, - ) - - hs.get_clock().looping_call( - self._delete_old_forward_extrem_cache, 60 * 60 * 1000 - ) - +class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, + SQLBaseStore): def get_auth_chain(self, event_ids, include_given=False): """Get auth events for given event_ids. The events *must* be state events. @@ -97,7 +79,7 @@ class EventFederationStore(SQLBaseStore): front_list = list(front) chunks = [ front_list[x:x + 100] - for x in xrange(0, len(front), 100) + for x in range(0, len(front), 100) ] for chunk in chunks: txn.execute( @@ -152,7 +134,47 @@ class EventFederationStore(SQLBaseStore): retcol="event_id", ) + @defer.inlineCallbacks + def get_prev_events_for_room(self, room_id): + """ + Gets a subset of the current forward extremities in the given room. + + Limits the result to 10 extremities, so that we can avoid creating + events which refer to hundreds of prev_events. + + Args: + room_id (str): room_id + + Returns: + Deferred[list[(str, dict[str, str], int)]] + for each event, a tuple of (event_id, hashes, depth) + where *hashes* is a map from algorithm to hash. + """ + res = yield self.get_latest_event_ids_and_hashes_in_room(room_id) + if len(res) > 10: + # Sort by reverse depth, so we point to the most recent. + res.sort(key=lambda a: -a[2]) + + # we use half of the limit for the actual most recent events, and + # the other half to randomly point to some of the older events, to + # make sure that we don't completely ignore the older events. + res = res[0:5] + random.sample(res[5:], 5) + + defer.returnValue(res) + def get_latest_event_ids_and_hashes_in_room(self, room_id): + """ + Gets the current forward extremities in the given room + + Args: + room_id (str): room_id + + Returns: + Deferred[list[(str, dict[str, str], int)]] + for each event, a tuple of (event_id, hashes, depth) + where *hashes* is a map from algorithm to hash. + """ + return self.runInteraction( "get_latest_event_ids_and_hashes_in_room", self._get_latest_event_ids_and_hashes_in_room, @@ -201,22 +223,6 @@ class EventFederationStore(SQLBaseStore): room_id, ) - @defer.inlineCallbacks - def get_max_depth_of_events(self, event_ids): - sql = ( - "SELECT MAX(depth) FROM events WHERE event_id IN (%s)" - ) % (",".join(["?"] * len(event_ids)),) - - rows = yield self._execute( - "get_max_depth_of_events", None, - sql, *event_ids - ) - - if rows: - defer.returnValue(rows[0][0]) - else: - defer.returnValue(1) - def _get_min_depth_interaction(self, txn, room_id): min_depth = self._simple_select_one_onecol_txn( txn, @@ -228,88 +234,6 @@ class EventFederationStore(SQLBaseStore): return int(min_depth) if min_depth is not None else None - def _update_min_depth_for_room_txn(self, txn, room_id, depth): - min_depth = self._get_min_depth_interaction(txn, room_id) - - if min_depth and depth >= min_depth: - return - - self._simple_upsert_txn( - txn, - table="room_depth", - keyvalues={ - "room_id": room_id, - }, - values={ - "min_depth": depth, - }, - ) - - def _handle_mult_prev_events(self, txn, events): - """ - For the given event, update the event edges table and forward and - backward extremities tables. - """ - self._simple_insert_many_txn( - txn, - table="event_edges", - values=[ - { - "event_id": ev.event_id, - "prev_event_id": e_id, - "room_id": ev.room_id, - "is_state": False, - } - for ev in events - for e_id, _ in ev.prev_events - ], - ) - - self._update_backward_extremeties(txn, events) - - def _update_backward_extremeties(self, txn, events): - """Updates the event_backward_extremities tables based on the new/updated - events being persisted. - - This is called for new events *and* for events that were outliers, but - are now being persisted as non-outliers. - - Forward extremities are handled when we first start persisting the events. - """ - events_by_room = {} - for ev in events: - events_by_room.setdefault(ev.room_id, []).append(ev) - - query = ( - "INSERT INTO event_backward_extremities (event_id, room_id)" - " SELECT ?, ? WHERE NOT EXISTS (" - " SELECT 1 FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" - " )" - " AND NOT EXISTS (" - " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " - " AND outlier = ?" - " )" - ) - - txn.executemany(query, [ - (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) - for ev in events for e_id, _ in ev.prev_events - if not ev.internal_metadata.is_outlier() - ]) - - query = ( - "DELETE FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" - ) - txn.executemany( - query, - [ - (ev.event_id, ev.room_id) for ev in events - if not ev.internal_metadata.is_outlier() - ] - ) - def get_forward_extremeties_for_room(self, room_id, stream_ordering): """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". @@ -371,28 +295,6 @@ class EventFederationStore(SQLBaseStore): get_forward_extremeties_for_room_txn ) - def _delete_old_forward_extrem_cache(self): - def _delete_old_forward_extrem_cache_txn(txn): - # Delete entries older than a month, while making sure we don't delete - # the only entries for a room. - sql = (""" - DELETE FROM stream_ordering_to_exterm - WHERE - room_id IN ( - SELECT room_id - FROM stream_ordering_to_exterm - WHERE stream_ordering > ? - ) AND stream_ordering < ? - """) - txn.execute( - sql, - (self.stream_ordering_month_ago, self.stream_ordering_month_ago,) - ) - return self.runInteraction( - "_delete_old_forward_extrem_cache", - _delete_old_forward_extrem_cache_txn - ) - def get_backfill_events(self, room_id, event_list, limit): """Get a list of Events for a given topic that occurred before (and including) the events in event_list. Return a list of max size `limit` @@ -522,6 +424,135 @@ class EventFederationStore(SQLBaseStore): return event_results + +class EventFederationStore(EventFederationWorkerStore): + """ Responsible for storing and serving up the various graphs associated + with an event. Including the main event graph and the auth chains for an + event. + + Also has methods for getting the front (latest) and back (oldest) edges + of the event graphs. These are used to generate the parents for new events + and backfilling from another server respectively. + """ + + EVENT_AUTH_STATE_ONLY = "event_auth_state_only" + + def __init__(self, db_conn, hs): + super(EventFederationStore, self).__init__(db_conn, hs) + + self.register_background_update_handler( + self.EVENT_AUTH_STATE_ONLY, + self._background_delete_non_state_event_auth, + ) + + hs.get_clock().looping_call( + self._delete_old_forward_extrem_cache, 60 * 60 * 1000 + ) + + def _update_min_depth_for_room_txn(self, txn, room_id, depth): + min_depth = self._get_min_depth_interaction(txn, room_id) + + if min_depth and depth >= min_depth: + return + + self._simple_upsert_txn( + txn, + table="room_depth", + keyvalues={ + "room_id": room_id, + }, + values={ + "min_depth": depth, + }, + ) + + def _handle_mult_prev_events(self, txn, events): + """ + For the given event, update the event edges table and forward and + backward extremities tables. + """ + self._simple_insert_many_txn( + txn, + table="event_edges", + values=[ + { + "event_id": ev.event_id, + "prev_event_id": e_id, + "room_id": ev.room_id, + "is_state": False, + } + for ev in events + for e_id, _ in ev.prev_events + ], + ) + + self._update_backward_extremeties(txn, events) + + def _update_backward_extremeties(self, txn, events): + """Updates the event_backward_extremities tables based on the new/updated + events being persisted. + + This is called for new events *and* for events that were outliers, but + are now being persisted as non-outliers. + + Forward extremities are handled when we first start persisting the events. + """ + events_by_room = {} + for ev in events: + events_by_room.setdefault(ev.room_id, []).append(ev) + + query = ( + "INSERT INTO event_backward_extremities (event_id, room_id)" + " SELECT ?, ? WHERE NOT EXISTS (" + " SELECT 1 FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" + " )" + " AND NOT EXISTS (" + " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " + " AND outlier = ?" + " )" + ) + + txn.executemany(query, [ + (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) + for ev in events for e_id, _ in ev.prev_events + if not ev.internal_metadata.is_outlier() + ]) + + query = ( + "DELETE FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" + ) + txn.executemany( + query, + [ + (ev.event_id, ev.room_id) for ev in events + if not ev.internal_metadata.is_outlier() + ] + ) + + def _delete_old_forward_extrem_cache(self): + def _delete_old_forward_extrem_cache_txn(txn): + # Delete entries older than a month, while making sure we don't delete + # the only entries for a room. + sql = (""" + DELETE FROM stream_ordering_to_exterm + WHERE + room_id IN ( + SELECT room_id + FROM stream_ordering_to_exterm + WHERE stream_ordering > ? + ) AND stream_ordering < ? + """) + txn.execute( + sql, + (self.stream_ordering_month_ago, self.stream_ordering_month_ago,) + ) + return self.runInteraction( + "_delete_old_forward_extrem_cache", + _delete_old_forward_extrem_cache_txn + ) + def clean_room_for_join(self, room_id): return self.runInteraction( "clean_room_for_join", diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index d6d8723b4a..29b511ae5e 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,15 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore +import logging + +from six import iteritems + +from canonicaljson import json + from twisted.internet import defer -from synapse.util.async import sleep -from synapse.util.caches.descriptors import cachedInlineCallbacks -from synapse.types import RoomStreamToken -from .stream import lower_bound -import logging -import ujson as json +from synapse.storage._base import LoggingTransaction, SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks logger = logging.getLogger(__name__) @@ -62,59 +64,29 @@ def _deserialize_action(actions, is_highlight): return DEFAULT_NOTIF_ACTION -class EventPushActionsStore(SQLBaseStore): - EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - - def __init__(self, hs): - super(EventPushActionsStore, self).__init__(hs) +class EventPushActionsWorkerStore(SQLBaseStore): + def __init__(self, db_conn, hs): + super(EventPushActionsWorkerStore, self).__init__(db_conn, hs) - self.register_background_index_update( - self.EPA_HIGHLIGHT_INDEX, - index_name="event_push_actions_u_highlight", - table="event_push_actions", - columns=["user_id", "stream_ordering"], - ) + # These get correctly set by _find_stream_orderings_for_times_txn + self.stream_ordering_month_ago = None + self.stream_ordering_day_ago = None - self.register_background_index_update( - "event_push_actions_highlights_index", - index_name="event_push_actions_highlights_index", - table="event_push_actions", - columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], - where_clause="highlight=1" + cur = LoggingTransaction( + db_conn.cursor(), + name="_find_stream_orderings_for_times_txn", + database_engine=self.database_engine, + after_callbacks=[], + exception_callbacks=[], ) + self._find_stream_orderings_for_times_txn(cur) + cur.close() - self._doing_notif_rotation = False - self._rotate_notif_loop = self._clock.looping_call( - self._rotate_notifs, 30 * 60 * 1000 + self.find_stream_orderings_looping_call = self._clock.looping_call( + self._find_stream_orderings_for_times, 10 * 60 * 1000 ) - - def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): - """ - Args: - event: the event set actions for - tuples: list of tuples of (user_id, actions) - """ - values = [] - for uid, actions in tuples: - is_highlight = 1 if _action_has_highlight(actions) else 0 - - values.append({ - 'room_id': event.room_id, - 'event_id': event.event_id, - 'user_id': uid, - 'actions': _serialize_action(actions, is_highlight), - 'stream_ordering': event.internal_metadata.stream_ordering, - 'topological_ordering': event.depth, - 'notif': 1, - 'highlight': is_highlight, - }) - - for uid, __ in tuples: - txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (event.room_id, uid) - ) - self._simple_insert_many_txn(txn, "event_push_actions", values) + self._rotate_delay = 3 + self._rotate_count = 10000 @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000) def get_unread_event_push_actions_by_room_for_user( @@ -130,7 +102,7 @@ class EventPushActionsStore(SQLBaseStore): def _get_unread_counts_by_receipt_txn(self, txn, room_id, user_id, last_read_event_id): sql = ( - "SELECT stream_ordering, topological_ordering" + "SELECT stream_ordering" " FROM events" " WHERE room_id = ? AND event_id = ?" ) @@ -142,17 +114,12 @@ class EventPushActionsStore(SQLBaseStore): return {"notify_count": 0, "highlight_count": 0} stream_ordering = results[0][0] - topological_ordering = results[0][1] return self._get_unread_counts_by_pos_txn( - txn, room_id, user_id, topological_ordering, stream_ordering + txn, room_id, user_id, stream_ordering ) - def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, topological_ordering, - stream_ordering): - token = RoomStreamToken( - topological_ordering, stream_ordering - ) + def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering): # First get number of notifications. # We don't need to put a notif=1 clause as all rows always have @@ -163,10 +130,10 @@ class EventPushActionsStore(SQLBaseStore): " WHERE" " user_id = ?" " AND room_id = ?" - " AND %s" - ) % (lower_bound(token, self.database_engine, inclusive=False),) + " AND stream_ordering > ?" + ) - txn.execute(sql, (user_id, room_id)) + txn.execute(sql, (user_id, room_id, stream_ordering)) row = txn.fetchone() notify_count = row[0] if row else 0 @@ -186,10 +153,10 @@ class EventPushActionsStore(SQLBaseStore): " highlight = 1" " AND user_id = ?" " AND room_id = ?" - " AND %s" - ) % (lower_bound(token, self.database_engine, inclusive=False),) + " AND stream_ordering > ?" + ) - txn.execute(sql, (user_id, room_id)) + txn.execute(sql, (user_id, room_id, stream_ordering)) row = txn.fetchone() highlight_count = row[0] if row else 0 @@ -240,7 +207,6 @@ class EventPushActionsStore(SQLBaseStore): " ep.highlight " " FROM (" " SELECT room_id," - " MAX(topological_ordering) as topological_ordering," " MAX(stream_ordering) as stream_ordering" " FROM events" " INNER JOIN receipts_linearized USING (room_id, event_id)" @@ -250,13 +216,7 @@ class EventPushActionsStore(SQLBaseStore): " event_push_actions AS ep" " WHERE" " ep.room_id = rl.room_id" - " AND (" - " ep.topological_ordering > rl.topological_ordering" - " OR (" - " ep.topological_ordering = rl.topological_ordering" - " AND ep.stream_ordering > rl.stream_ordering" - " )" - " )" + " AND ep.stream_ordering > rl.stream_ordering" " AND ep.user_id = ?" " AND ep.stream_ordering > ?" " AND ep.stream_ordering <= ?" @@ -349,7 +309,6 @@ class EventPushActionsStore(SQLBaseStore): " ep.highlight, e.received_ts" " FROM (" " SELECT room_id," - " MAX(topological_ordering) as topological_ordering," " MAX(stream_ordering) as stream_ordering" " FROM events" " INNER JOIN receipts_linearized USING (room_id, event_id)" @@ -360,13 +319,7 @@ class EventPushActionsStore(SQLBaseStore): " INNER JOIN events AS e USING (room_id, event_id)" " WHERE" " ep.room_id = rl.room_id" - " AND (" - " ep.topological_ordering > rl.topological_ordering" - " OR (" - " ep.topological_ordering = rl.topological_ordering" - " AND ep.stream_ordering > rl.stream_ordering" - " )" - " )" + " AND ep.stream_ordering > rl.stream_ordering" " AND ep.user_id = ?" " AND ep.stream_ordering > ?" " AND ep.stream_ordering <= ?" @@ -432,6 +385,290 @@ class EventPushActionsStore(SQLBaseStore): # Now return the first `limit` defer.returnValue(notifs[:limit]) + def add_push_actions_to_staging(self, event_id, user_id_actions): + """Add the push actions for the event to the push action staging area. + + Args: + event_id (str) + user_id_actions (dict[str, list[dict|str])]): A dictionary mapping + user_id to list of push actions, where an action can either be + a string or dict. + + Returns: + Deferred + """ + + if not user_id_actions: + return + + # This is a helper function for generating the necessary tuple that + # can be used to inert into the `event_push_actions_staging` table. + def _gen_entry(user_id, actions): + is_highlight = 1 if _action_has_highlight(actions) else 0 + return ( + event_id, # event_id column + user_id, # user_id column + _serialize_action(actions, is_highlight), # actions column + 1, # notif column + is_highlight, # highlight column + ) + + def _add_push_actions_to_staging_txn(txn): + # We don't use _simple_insert_many here to avoid the overhead + # of generating lists of dicts. + + sql = """ + INSERT INTO event_push_actions_staging + (event_id, user_id, actions, notif, highlight) + VALUES (?, ?, ?, ?, ?) + """ + + txn.executemany(sql, ( + _gen_entry(user_id, actions) + for user_id, actions in iteritems(user_id_actions) + )) + + return self.runInteraction( + "add_push_actions_to_staging", _add_push_actions_to_staging_txn + ) + + @defer.inlineCallbacks + def remove_push_actions_from_staging(self, event_id): + """Called if we failed to persist the event to ensure that stale push + actions don't build up in the DB + + Args: + event_id (str) + """ + + try: + res = yield self._simple_delete( + table="event_push_actions_staging", + keyvalues={ + "event_id": event_id, + }, + desc="remove_push_actions_from_staging", + ) + defer.returnValue(res) + except Exception: + # this method is called from an exception handler, so propagating + # another exception here really isn't helpful - there's nothing + # the caller can do about it. Just log the exception and move on. + logger.exception( + "Error removing push actions after event persistence failure", + ) + + @defer.inlineCallbacks + def _find_stream_orderings_for_times(self): + yield self.runInteraction( + "_find_stream_orderings_for_times", + self._find_stream_orderings_for_times_txn + ) + + def _find_stream_orderings_for_times_txn(self, txn): + logger.info("Searching for stream ordering 1 month ago") + self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn( + txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000 + ) + logger.info( + "Found stream ordering 1 month ago: it's %d", + self.stream_ordering_month_ago + ) + logger.info("Searching for stream ordering 1 day ago") + self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn( + txn, self._clock.time_msec() - 24 * 60 * 60 * 1000 + ) + logger.info( + "Found stream ordering 1 day ago: it's %d", + self.stream_ordering_day_ago + ) + + def find_first_stream_ordering_after_ts(self, ts): + """Gets the stream ordering corresponding to a given timestamp. + + Specifically, finds the stream_ordering of the first event that was + received on or after the timestamp. This is done by a binary search on + the events table, since there is no index on received_ts, so is + relatively slow. + + Args: + ts (int): timestamp in millis + + Returns: + Deferred[int]: stream ordering of the first event received on/after + the timestamp + """ + return self.runInteraction( + "_find_first_stream_ordering_after_ts_txn", + self._find_first_stream_ordering_after_ts_txn, + ts, + ) + + @staticmethod + def _find_first_stream_ordering_after_ts_txn(txn, ts): + """ + Find the stream_ordering of the first event that was received on or + after a given timestamp. This is relatively slow as there is no index + on received_ts but we can then use this to delete push actions before + this. + + received_ts must necessarily be in the same order as stream_ordering + and stream_ordering is indexed, so we manually binary search using + stream_ordering + + Args: + txn (twisted.enterprise.adbapi.Transaction): + ts (int): timestamp to search for + + Returns: + int: stream ordering + """ + txn.execute("SELECT MAX(stream_ordering) FROM events") + max_stream_ordering = txn.fetchone()[0] + + if max_stream_ordering is None: + return 0 + + # We want the first stream_ordering in which received_ts is greater + # than or equal to ts. Call this point X. + # + # We maintain the invariants: + # + # range_start <= X <= range_end + # + range_start = 0 + range_end = max_stream_ordering + 1 + + # Given a stream_ordering, look up the timestamp at that + # stream_ordering. + # + # The array may be sparse (we may be missing some stream_orderings). + # We treat the gaps as the same as having the same value as the + # preceding entry, because we will pick the lowest stream_ordering + # which satisfies our requirement of received_ts >= ts. + # + # For example, if our array of events indexed by stream_ordering is + # [10, <none>, 20], we should treat this as being equivalent to + # [10, 10, 20]. + # + sql = ( + "SELECT received_ts FROM events" + " WHERE stream_ordering <= ?" + " ORDER BY stream_ordering DESC" + " LIMIT 1" + ) + + while range_end - range_start > 0: + middle = (range_end + range_start) // 2 + txn.execute(sql, (middle,)) + row = txn.fetchone() + if row is None: + # no rows with stream_ordering<=middle + range_start = middle + 1 + continue + + middle_ts = row[0] + if ts > middle_ts: + # we got a timestamp lower than the one we were looking for. + # definitely need to look higher: X > middle. + range_start = middle + 1 + else: + # we got a timestamp higher than (or the same as) the one we + # were looking for. We aren't yet sure about the point we + # looked up, but we can be sure that X <= middle. + range_end = middle + + return range_end + + +class EventPushActionsStore(EventPushActionsWorkerStore): + EPA_HIGHLIGHT_INDEX = "epa_highlight_index" + + def __init__(self, db_conn, hs): + super(EventPushActionsStore, self).__init__(db_conn, hs) + + self.register_background_index_update( + self.EPA_HIGHLIGHT_INDEX, + index_name="event_push_actions_u_highlight", + table="event_push_actions", + columns=["user_id", "stream_ordering"], + ) + + self.register_background_index_update( + "event_push_actions_highlights_index", + index_name="event_push_actions_highlights_index", + table="event_push_actions", + columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], + where_clause="highlight=1" + ) + + self._doing_notif_rotation = False + self._rotate_notif_loop = self._clock.looping_call( + self._rotate_notifs, 30 * 60 * 1000 + ) + + def _set_push_actions_for_event_and_users_txn(self, txn, events_and_contexts, + all_events_and_contexts): + """Handles moving push actions from staging table to main + event_push_actions table for all events in `events_and_contexts`. + + Also ensures that all events in `all_events_and_contexts` are removed + from the push action staging area. + + Args: + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + all_events_and_contexts (list[(EventBase, EventContext)]): all + events that we were going to persist. This includes events + we've already persisted, etc, that wouldn't appear in + events_and_context. + """ + + sql = """ + INSERT INTO event_push_actions ( + room_id, event_id, user_id, actions, stream_ordering, + topological_ordering, notif, highlight + ) + SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight + FROM event_push_actions_staging + WHERE event_id = ? + """ + + if events_and_contexts: + txn.executemany(sql, ( + ( + event.room_id, event.internal_metadata.stream_ordering, + event.depth, event.event_id, + ) + for event, _ in events_and_contexts + )) + + for event, _ in events_and_contexts: + user_ids = self._simple_select_onecol_txn( + txn, + table="event_push_actions_staging", + keyvalues={ + "event_id": event.event_id, + }, + retcol="user_id", + ) + + for uid in user_ids: + txn.call_after( + self.get_unread_event_push_actions_by_room_for_user.invalidate_many, + (event.room_id, uid,) + ) + + # Now we delete the staging area for *all* events that were being + # persisted. + txn.executemany( + "DELETE FROM event_push_actions_staging WHERE event_id = ?", + ( + (event.event_id,) + for event, _ in all_events_and_contexts + ) + ) + @defer.inlineCallbacks def get_push_actions_for_user(self, user_id, before=None, limit=50, only_highlight=False): @@ -509,10 +746,10 @@ class EventPushActionsStore(SQLBaseStore): ) def _remove_old_push_actions_before_txn(self, txn, room_id, user_id, - topological_ordering, stream_ordering): + stream_ordering): """ Purges old push actions for a user and room before a given - topological_ordering. + stream_ordering. We however keep a months worth of highlighted notifications, so that users can still get a list of recent highlights. @@ -521,7 +758,7 @@ class EventPushActionsStore(SQLBaseStore): txn: The transcation room_id: Room ID to delete from user_id: user ID to delete for - topological_ordering: The lowest topological ordering which will + stream_ordering: The lowest stream ordering which will not be deleted. """ txn.call_after( @@ -540,9 +777,9 @@ class EventPushActionsStore(SQLBaseStore): txn.execute( "DELETE FROM event_push_actions " " WHERE user_id = ? AND room_id = ? AND " - " topological_ordering <= ?" + " stream_ordering <= ?" " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)", - (user_id, room_id, topological_ordering, self.stream_ordering_month_ago) + (user_id, room_id, stream_ordering, self.stream_ordering_month_ago) ) txn.execute(""" @@ -551,69 +788,6 @@ class EventPushActionsStore(SQLBaseStore): """, (room_id, user_id, stream_ordering)) @defer.inlineCallbacks - def _find_stream_orderings_for_times(self): - yield self.runInteraction( - "_find_stream_orderings_for_times", - self._find_stream_orderings_for_times_txn - ) - - def _find_stream_orderings_for_times_txn(self, txn): - logger.info("Searching for stream ordering 1 month ago") - self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn( - txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000 - ) - logger.info( - "Found stream ordering 1 month ago: it's %d", - self.stream_ordering_month_ago - ) - logger.info("Searching for stream ordering 1 day ago") - self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn( - txn, self._clock.time_msec() - 24 * 60 * 60 * 1000 - ) - logger.info( - "Found stream ordering 1 day ago: it's %d", - self.stream_ordering_day_ago - ) - - def _find_first_stream_ordering_after_ts_txn(self, txn, ts): - """ - Find the stream_ordering of the first event that was received after - a given timestamp. This is relatively slow as there is no index on - received_ts but we can then use this to delete push actions before - this. - - received_ts must necessarily be in the same order as stream_ordering - and stream_ordering is indexed, so we manually binary search using - stream_ordering - """ - txn.execute("SELECT MAX(stream_ordering) FROM events") - max_stream_ordering = txn.fetchone()[0] - - if max_stream_ordering is None: - return 0 - - range_start = 0 - range_end = max_stream_ordering - - sql = ( - "SELECT received_ts FROM events" - " WHERE stream_ordering > ?" - " ORDER BY stream_ordering" - " LIMIT 1" - ) - - while range_end - range_start > 1: - middle = int((range_end + range_start) / 2) - txn.execute(sql, (middle,)) - middle_ts = txn.fetchone()[0] - if ts > middle_ts: - range_start = middle - else: - range_end = middle - - return range_end - - @defer.inlineCallbacks def _rotate_notifs(self): if self._doing_notif_rotation or self.stream_ordering_day_ago is None: return @@ -629,7 +803,7 @@ class EventPushActionsStore(SQLBaseStore): ) if caught_up: break - yield sleep(5) + yield self.hs.get_clock().sleep(self._rotate_delay) finally: self._doing_notif_rotation = False @@ -650,8 +824,8 @@ class EventPushActionsStore(SQLBaseStore): txn.execute(""" SELECT stream_ordering FROM event_push_actions WHERE stream_ordering > ? - ORDER BY stream_ordering ASC LIMIT 1 OFFSET 50000 - """, (old_rotate_stream_ordering,)) + ORDER BY stream_ordering ASC LIMIT 1 OFFSET ? + """, (old_rotate_stream_ordering, self._rotate_count)) stream_row = txn.fetchone() if stream_row: offset_stream_ordering, = stream_row diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 7002b3752e..906a405031 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,64 +13,59 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from twisted.internet import defer, reactor +import itertools +import logging +from collections import OrderedDict, deque, namedtuple +from functools import wraps -from synapse.events import FrozenEvent, USE_FROZEN_DICTS -from synapse.events.utils import prune_event +from six import iteritems +from six.moves import range -from synapse.util.async import ObservableDeferred -from synapse.util.logcontext import ( - preserve_fn, PreserveLoggingContext, preserve_context_over_deferred -) -from synapse.util.logutils import log_function -from synapse.util.metrics import Measure -from synapse.api.constants import EventTypes -from synapse.api.errors import SynapseError -from synapse.state import resolve_events -from synapse.util.caches.descriptors import cached -from synapse.types import get_domain_from_id +from canonicaljson import json +from prometheus_client import Counter -from canonicaljson import encode_canonical_json -from collections import deque, namedtuple, OrderedDict -from functools import wraps +from twisted.internet import defer import synapse.metrics - -import logging -import ujson as json - +from synapse.api.constants import EventTypes +from synapse.api.errors import SynapseError # these are only included to make the type annotations work -from synapse.events import EventBase # noqa: F401 -from synapse.events.snapshot import EventContext # noqa: F401 +from synapse.events import EventBase # noqa: F401 +from synapse.events.snapshot import EventContext # noqa: F401 +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.events_worker import EventsWorkerStore +from synapse.types import RoomStreamToken, get_domain_from_id +from synapse.util.async import ObservableDeferred +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.frozenutils import frozendict_json_encoder +from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable +from synapse.util.logutils import log_function +from synapse.util.metrics import Measure logger = logging.getLogger(__name__) +persist_event_counter = Counter("synapse_storage_events_persisted_events", "") +event_counter = Counter("synapse_storage_events_persisted_events_sep", "", + ["type", "origin_type", "origin_entity"]) -metrics = synapse.metrics.get_metrics_for(__name__) -persist_event_counter = metrics.register_counter("persisted_events") -event_counter = metrics.register_counter( - "persisted_events_sep", labels=["type", "origin_type", "origin_entity"] -) +# The number of times we are recalculating the current state +state_delta_counter = Counter("synapse_storage_events_state_delta", "") +# The number of times we are recalculating state when there is only a +# single forward extremity +state_delta_single_event_counter = Counter( + "synapse_storage_events_state_delta_single_event", "") -def encode_json(json_object): - if USE_FROZEN_DICTS: - # ujson doesn't like frozen_dicts - return encode_canonical_json(json_object) - else: - return json.dumps(json_object, ensure_ascii=False) +# The number of times we are reculating state when we could have resonably +# calculated the delta when we calculated the state for an event we were +# persisting. +state_delta_reuse_delta_counter = Counter( + "synapse_storage_events_state_delta_reuse_delta", "") -# These values are used in the `enqueus_event` and `_do_fetch` methods to -# control how we batch/bulk fetch events from the database. -# The values are plucked out of thing air to make initial sync run faster -# on jki.re -# TODO: Make these configurable. -EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events -EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events -EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events +def encode_json(json_object): + return frozendict_json_encoder.encode(json_object) class _EventPeristenceQueue(object): @@ -88,19 +84,29 @@ class _EventPeristenceQueue(object): def add_to_queue(self, room_id, events_and_contexts, backfilled): """Add events to the queue, with the given persist_event options. + NB: due to the normal usage pattern of this method, it does *not* + follow the synapse logcontext rules, and leaves the logcontext in + place whether or not the returned deferred is ready. + Args: room_id (str): events_and_contexts (list[(EventBase, EventContext)]): backfilled (bool): + + Returns: + defer.Deferred: a deferred which will resolve once the events are + persisted. Runs its callbacks *without* a logcontext. """ queue = self._event_persist_queues.setdefault(room_id, deque()) if queue: + # if the last item in the queue has the same `backfilled` setting, + # we can just add these new events to that item. end_item = queue[-1] if end_item.backfilled == backfilled: end_item.events_and_contexts.extend(events_and_contexts) return end_item.deferred.observe() - deferred = ObservableDeferred(defer.Deferred()) + deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) queue.append(self._EventPersistQueueItem( events_and_contexts=events_and_contexts, @@ -113,11 +119,11 @@ class _EventPeristenceQueue(object): def handle_queue(self, room_id, per_item_callback): """Attempts to handle the queue for a room if not already being handled. - The given callback will be invoked with for each item in the queue,1 + The given callback will be invoked with for each item in the queue, of type _EventPersistQueueItem. The per_item_callback will continuously be called with new items, unless the queue becomnes empty. The return value of the function will be given to the deferreds waiting on the item, - exceptions will be passed to the deferres as well. + exceptions will be passed to the deferreds as well. This function should therefore be called whenever anything is added to the queue. @@ -136,18 +142,23 @@ class _EventPeristenceQueue(object): try: queue = self._get_drainining_queue(room_id) for item in queue: + # handle_queue_loop runs in the sentinel logcontext, so + # there is no need to preserve_fn when running the + # callbacks on the deferred. try: ret = yield per_item_callback(item) - item.deferred.callback(ret) - except Exception as e: - item.deferred.errback(e) + with PreserveLoggingContext(): + item.deferred.callback(ret) + except Exception: + item.deferred.errback() finally: queue = self._event_persist_queues.pop(room_id, None) if queue: self._event_persist_queues[room_id] = queue self._currently_persisting_rooms.discard(room_id) - preserve_fn(handle_queue_loop)() + # set handle_queue_loop off in the background + run_as_background_process("persist_events", handle_queue_loop) def _get_drainining_queue(self, room_id): queue = self._event_persist_queues.setdefault(room_id, deque()) @@ -183,13 +194,12 @@ def _retry_on_integrity_error(func): return f -class EventsStore(SQLBaseStore): +class EventsStore(EventsWorkerStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" - def __init__(self, hs): - super(EventsStore, self).__init__(hs) - self._clock = hs.get_clock() + def __init__(self, db_conn, hs): + super(EventsStore, self).__init__(db_conn, hs) self.register_background_update_handler( self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts ) @@ -220,6 +230,8 @@ class EventsStore(SQLBaseStore): self._event_persist_queue = _EventPeristenceQueue() + self._state_resolution_handler = hs.get_state_resolution_handler() + def persist_events(self, events_and_contexts, backfilled=False): """ Write events to the database @@ -232,8 +244,8 @@ class EventsStore(SQLBaseStore): partitioned.setdefault(event.room_id, []).append((event, ctx)) deferreds = [] - for room_id, evs_ctxs in partitioned.iteritems(): - d = preserve_fn(self._event_persist_queue.add_to_queue)( + for room_id, evs_ctxs in iteritems(partitioned): + d = self._event_persist_queue.add_to_queue( room_id, evs_ctxs, backfilled=backfilled, ) @@ -242,7 +254,7 @@ class EventsStore(SQLBaseStore): for room_id in partitioned: self._maybe_start_persisting(room_id) - return preserve_context_over_deferred( + return make_deferred_yieldable( defer.gatherResults(deferreds, consumeErrors=True) ) @@ -267,7 +279,7 @@ class EventsStore(SQLBaseStore): self._maybe_start_persisting(event.room_id) - yield preserve_context_over_deferred(deferred) + yield make_deferred_yieldable(deferred) max_persisted_id = yield self._stream_id_gen.get_current_token() defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id)) @@ -275,10 +287,11 @@ class EventsStore(SQLBaseStore): def _maybe_start_persisting(self, room_id): @defer.inlineCallbacks def persisting_queue(item): - yield self._persist_events( - item.events_and_contexts, - backfilled=item.backfilled, - ) + with Measure(self._clock, "persist_events"): + yield self._persist_events( + item.events_and_contexts, + backfilled=item.backfilled, + ) self._event_persist_queue.handle_queue(room_id, persisting_queue) @@ -316,7 +329,7 @@ class EventsStore(SQLBaseStore): chunks = [ events_and_contexts[x:x + 100] - for x in xrange(0, len(events_and_contexts), 100) + for x in range(0, len(events_and_contexts), 100) ] for chunk in chunks: @@ -325,8 +338,23 @@ class EventsStore(SQLBaseStore): # NB: Assumes that we are only persisting events for one room # at a time. + + # map room_id->list[event_ids] giving the new forward + # extremities in each room new_forward_extremeties = {} + + # map room_id->(type,state_key)->event_id tracking the full + # state in each room after adding these events. + # This is simply used to prefill the get_current_state_ids + # cache current_state_for_room = {} + + # map room_id->(to_delete, to_insert) where to_delete is a list + # of type/state keys to remove from current state, and to_insert + # is a map (type,key)->event_id giving the state delta in each + # room + state_delta_for_room = {} + if not backfilled: with Measure(self._clock, "_calculate_state_and_extrem"): # Work out the new "current state" for each room. @@ -338,7 +366,7 @@ class EventsStore(SQLBaseStore): (event, context) ) - for room_id, ev_ctx_rm in events_by_room.iteritems(): + for room_id, ev_ctx_rm in iteritems(events_by_room): # Work out new extremities by recursively adding and removing # the new events. latest_event_ids = yield self.get_latest_event_ids_in_room( @@ -348,7 +376,8 @@ class EventsStore(SQLBaseStore): room_id, ev_ctx_rm, latest_event_ids ) - if new_latest_event_ids == set(latest_event_ids): + latest_event_ids = set(latest_event_ids) + if new_latest_event_ids == latest_event_ids: # No change in extremities, so no change in state continue @@ -369,11 +398,63 @@ class EventsStore(SQLBaseStore): if all_single_prev_not_state: continue - state = yield self._calculate_state_delta( - room_id, ev_ctx_rm, new_latest_event_ids + state_delta_counter.inc() + if len(new_latest_event_ids) == 1: + state_delta_single_event_counter.inc() + + # This is a fairly handwavey check to see if we could + # have guessed what the delta would have been when + # processing one of these events. + # What we're interested in is if the latest extremities + # were the same when we created the event as they are + # now. When this server creates a new event (as opposed + # to receiving it over federation) it will use the + # forward extremities as the prev_events, so we can + # guess this by looking at the prev_events and checking + # if they match the current forward extremities. + for ev, _ in ev_ctx_rm: + prev_event_ids = set(e for e, _ in ev.prev_events) + if latest_event_ids == prev_event_ids: + state_delta_reuse_delta_counter.inc() + break + + logger.info( + "Calculating state delta for room %s", room_id, ) - if state: - current_state_for_room[room_id] = state + with Measure( + self._clock, + "persist_events.get_new_state_after_events", + ): + res = yield self._get_new_state_after_events( + room_id, + ev_ctx_rm, + latest_event_ids, + new_latest_event_ids, + ) + current_state, delta_ids = res + + # If either are not None then there has been a change, + # and we need to work out the delta (or use that + # given) + if delta_ids is not None: + # If there is a delta we know that we've + # only added or replaced state, never + # removed keys entirely. + state_delta_for_room[room_id] = ([], delta_ids) + elif current_state is not None: + with Measure( + self._clock, + "persist_events.calculate_state_delta", + ): + delta = yield self._calculate_state_delta( + room_id, current_state, + ) + state_delta_for_room[room_id] = delta + + # If we have the current_state then lets prefill + # the cache with it. + if current_state is not None: + current_state_for_room[room_id] = current_state yield self.runInteraction( "persist_events", @@ -381,10 +462,13 @@ class EventsStore(SQLBaseStore): events_and_contexts=chunk, backfilled=backfilled, delete_existing=delete_existing, - current_state_for_room=current_state_for_room, + state_delta_for_room=state_delta_for_room, new_forward_extremeties=new_forward_extremeties, ) - persist_event_counter.inc_by(len(chunk)) + persist_event_counter.inc(len(chunk)) + synapse.metrics.event_persisted_position.set( + chunk[-1][0].internal_metadata.stream_ordering, + ) for event, context in chunk: if context.app_service: origin_type = "local" @@ -396,14 +480,14 @@ class EventsStore(SQLBaseStore): origin_type = "remote" origin_entity = get_domain_from_id(event.sender) - event_counter.inc(event.type, origin_type, origin_entity) + event_counter.labels(event.type, origin_type, origin_entity).inc() - for room_id, (_, _, new_state) in current_state_for_room.iteritems(): + for room_id, new_state in iteritems(current_state_for_room): self.get_current_state_ids.prefill( (room_id, ), new_state ) - for room_id, latest_event_ids in new_forward_extremeties.iteritems(): + for room_id, latest_event_ids in iteritems(new_forward_extremeties): self.get_latest_event_ids_in_room.prefill( (room_id,), list(latest_event_ids) ) @@ -450,183 +534,187 @@ class EventsStore(SQLBaseStore): defer.returnValue(new_latest_event_ids) @defer.inlineCallbacks - def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids): - """Calculate the new state deltas for a room. + def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids, + new_latest_event_ids): + """Calculate the current state dict after adding some new events to + a room - Assumes that we are only persisting events for one room at a time. + Args: + room_id (str): + room to which the events are being added. Used for logging etc + + events_context (list[(EventBase, EventContext)]): + events and contexts which are being added to the room + + old_latest_event_ids (iterable[str]): + the old forward extremities for the room. + + new_latest_event_ids (iterable[str]): + the new forward extremities for the room. Returns: - 3-tuple (to_delete, to_insert, new_state) where both are state dicts, - i.e. (type, state_key) -> event_id. `to_delete` are the entries to - first be deleted from current_state_events, `to_insert` are entries - to insert. `new_state` is the full set of state. - May return None if there are no changes to be applied. + Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]: + Returns a tuple of two state maps, the first being the full new current + state and the second being the delta to the existing current state. + If both are None then there has been no change. + + If there has been a change then we only return the delta if its + already been calculated. Conversely if we do know the delta then + the new current state is only returned if we've already calculated + it. """ - # Now we need to work out the different state sets for - # each state extremities - state_sets = [] - state_groups = set() - missing_event_ids = [] - was_updated = False + + if not new_latest_event_ids: + return + + # map from state_group to ((type, key) -> event_id) state map + state_groups_map = {} + + # Map from (prev state group, new state group) -> delta state dict + state_group_deltas = {} + + for ev, ctx in events_context: + if ctx.state_group is None: + # I don't think this can happen, but let's double-check + raise Exception( + "Context for new extremity event %s has no state " + "group" % (ev.event_id, ), + ) + + if ctx.state_group in state_groups_map: + continue + + # We're only interested in pulling out state that has already + # been cached in the context. We'll pull stuff out of the DB later + # if necessary. + current_state_ids = ctx.get_cached_current_state_ids() + if current_state_ids is not None: + state_groups_map[ctx.state_group] = current_state_ids + + if ctx.prev_group: + state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids + + # We need to map the event_ids to their state groups. First, let's + # check if the event is one we're persisting, in which case we can + # pull the state group from its context. + # Otherwise we need to pull the state group from the database. + + # Set of events we need to fetch groups for. (We know none of the old + # extremities are going to be in events_context). + missing_event_ids = set(old_latest_event_ids) + + event_id_to_state_group = {} for event_id in new_latest_event_ids: - # First search in the list of new events we're adding, - # and then use the current state from that + # First search in the list of new events we're adding. for ev, ctx in events_context: if event_id == ev.event_id: - if ctx.current_state_ids is None: - raise Exception("Unknown current state") - - # If we've already seen the state group don't bother adding - # it to the state sets again - if ctx.state_group not in state_groups: - state_sets.append(ctx.current_state_ids) - if ctx.delta_ids or hasattr(ev, "state_key"): - was_updated = True - if ctx.state_group: - # Add this as a seen state group (if it has a state - # group) - state_groups.add(ctx.state_group) + event_id_to_state_group[event_id] = ctx.state_group break else: # If we couldn't find it, then we'll need to pull # the state from the database - was_updated = True - missing_event_ids.append(event_id) + missing_event_ids.add(event_id) if missing_event_ids: - # Now pull out the state for any missing events from DB + # Now pull out the state groups for any missing events from DB event_to_groups = yield self._get_state_group_for_events( missing_event_ids, ) + event_id_to_state_group.update(event_to_groups) - groups = set(event_to_groups.itervalues()) - state_groups + # State groups of old_latest_event_ids + old_state_groups = set( + event_id_to_state_group[evid] for evid in old_latest_event_ids + ) - if groups: - group_to_state = yield self._get_state_for_groups(groups) - state_sets.extend(group_to_state.itervalues()) + # State groups of new_latest_event_ids + new_state_groups = set( + event_id_to_state_group[evid] for evid in new_latest_event_ids + ) - if not new_latest_event_ids: - current_state = {} - elif was_updated: - if len(state_sets) == 1: - # If there is only one state set, then we know what the current - # state is. - current_state = state_sets[0] - else: - # We work out the current state by passing the state sets to the - # state resolution algorithm. It may ask for some events, including - # the events we have yet to persist, so we need a slightly more - # complicated event lookup function than simply looking the events - # up in the db. - events_map = {ev.event_id: ev for ev, _ in events_context} - - @defer.inlineCallbacks - def get_events(ev_ids): - # We get the events by first looking at the list of events we - # are trying to persist, and then fetching the rest from the DB. - db = [] - to_return = {} - for ev_id in ev_ids: - ev = events_map.get(ev_id, None) - if ev: - to_return[ev_id] = ev - else: - db.append(ev_id) - - if db: - evs = yield self.get_events( - ev_ids, get_prev_content=False, check_redacted=False, - ) - to_return.update(evs) - defer.returnValue(to_return) - - current_state = yield resolve_events( - state_sets, - state_map_factory=get_events, - ) - else: - return + # If they old and new groups are the same then we don't need to do + # anything. + if old_state_groups == new_state_groups: + defer.returnValue((None, None)) - existing_state = yield self.get_current_state_ids(room_id) + if len(new_state_groups) == 1 and len(old_state_groups) == 1: + # If we're going from one state group to another, lets check if + # we have a delta for that transition. If we do then we can just + # return that. - existing_events = set(existing_state.itervalues()) - new_events = set(ev_id for ev_id in current_state.itervalues()) - changed_events = existing_events ^ new_events + new_state_group = next(iter(new_state_groups)) + old_state_group = next(iter(old_state_groups)) - if not changed_events: - return + delta_ids = state_group_deltas.get( + (old_state_group, new_state_group,), None + ) + if delta_ids is not None: + # We have a delta from the existing to new current state, + # so lets just return that. If we happen to already have + # the current state in memory then lets also return that, + # but it doesn't matter if we don't. + new_state = state_groups_map.get(new_state_group) + defer.returnValue((new_state, delta_ids)) + + # Now that we have calculated new_state_groups we need to get + # their state IDs so we can resolve to a single state set. + missing_state = new_state_groups - set(state_groups_map) + if missing_state: + group_to_state = yield self._get_state_for_groups(missing_state) + state_groups_map.update(group_to_state) + + if len(new_state_groups) == 1: + # If there is only one state group, then we know what the current + # state is. + defer.returnValue((state_groups_map[new_state_groups.pop()], None)) + + # Ok, we need to defer to the state handler to resolve our state sets. + + def get_events(ev_ids): + return self.get_events( + ev_ids, get_prev_content=False, check_redacted=False, + ) - to_delete = { - key: ev_id for key, ev_id in existing_state.iteritems() - if ev_id in changed_events + state_groups = { + sg: state_groups_map[sg] for sg in new_state_groups } - events_to_insert = (new_events - existing_events) - to_insert = { - key: ev_id for key, ev_id in current_state.iteritems() - if ev_id in events_to_insert - } - - defer.returnValue((to_delete, to_insert, current_state)) - - @defer.inlineCallbacks - def get_event(self, event_id, check_redacted=True, - get_prev_content=False, allow_rejected=False, - allow_none=False): - """Get an event from the database by event_id. - - Args: - event_id (str): The event_id of the event to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, - include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. - allow_none (bool): If True, return None if no event found, if - False throw an exception. - Returns: - Deferred : A FrozenEvent. - """ - events = yield self._get_events( - [event_id], - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, + events_map = {ev.event_id: ev for ev, _ in events_context} + logger.debug("calling resolve_state_groups from preserve_events") + res = yield self._state_resolution_handler.resolve_state_groups( + room_id, state_groups, events_map, get_events ) - if not events and not allow_none: - raise SynapseError(404, "Could not find event %s" % (event_id,)) - - defer.returnValue(events[0] if events else None) + defer.returnValue((res.state, None)) @defer.inlineCallbacks - def get_events(self, event_ids, check_redacted=True, - get_prev_content=False, allow_rejected=False): - """Get events from the database + def _calculate_state_delta(self, room_id, current_state): + """Calculate the new state deltas for a room. - Args: - event_ids (list): The event_ids of the events to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, - include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. + Assumes that we are only persisting events for one room at a time. Returns: - Deferred : Dict from event_id to event. + tuple[list, dict] (to_delete, to_insert): where to_delete are the + type/state_keys to remove from current_state_events and `to_insert` + are the updates to current_state_events. """ - events = yield self._get_events( - event_ids, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - ) + existing_state = yield self.get_current_state_ids(room_id) + + to_delete = [ + key for key in existing_state + if key not in current_state + ] + + to_insert = { + key: ev_id for key, ev_id in iteritems(current_state) + if ev_id != existing_state.get(key) + } - defer.returnValue({e.event_id: e for e in events}) + defer.returnValue((to_delete, to_insert)) @log_function def _persist_events_txn(self, txn, events_and_contexts, backfilled, - delete_existing=False, current_state_for_room={}, + delete_existing=False, state_delta_for_room={}, new_forward_extremeties={}): """Insert some number of room events into the necessary database tables. @@ -642,19 +730,21 @@ class EventsStore(SQLBaseStore): delete_existing (bool): True to purge existing table rows for the events from the database. This is useful when retrying due to IntegrityError. - current_state_for_room (dict[str, (list[str], list[str])]): + state_delta_for_room (dict[str, (list, dict)]): The current-state delta for each room. For each room, a tuple - (to_delete, to_insert), being a list of event ids to be removed - from the current state, and a list of event ids to be added to + (to_delete, to_insert), being a list of type/state keys to be + removed from the current state, and a state set to be added to the current state. new_forward_extremeties (dict[str, list[str]]): The new forward extremities for each room. For each room, a list of the event ids which are the forward extremities. """ + all_events_and_contexts = events_and_contexts + max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering - self._update_current_state_txn(txn, current_state_for_room, max_stream_order) + self._update_current_state_txn(txn, state_delta_for_room, max_stream_order) self._update_forward_extremities_txn( txn, @@ -698,9 +788,8 @@ class EventsStore(SQLBaseStore): events_and_contexts=events_and_contexts, ) - # Insert into the state_groups, state_groups_state, and - # event_to_state_groups tables. - self._store_mult_state_groups_txn(txn, events_and_contexts) + # Insert into event_to_state_groups. + self._store_event_state_mappings_txn(txn, events_and_contexts) # _store_rejected_events_txn filters out any events which were # rejected, and returns the filtered list. @@ -715,15 +804,53 @@ class EventsStore(SQLBaseStore): self._update_metadata_tables_txn( txn, events_and_contexts=events_and_contexts, + all_events_and_contexts=all_events_and_contexts, backfilled=backfilled, ) def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order): - for room_id, current_state_tuple in state_delta_by_room.iteritems(): - to_delete, to_insert, _ = current_state_tuple + for room_id, current_state_tuple in iteritems(state_delta_by_room): + to_delete, to_insert = current_state_tuple + + # First we add entries to the current_state_delta_stream. We + # do this before updating the current_state_events table so + # that we can use it to calculate the `prev_event_id`. (This + # allows us to not have to pull out the existing state + # unnecessarily). + sql = """ + INSERT INTO current_state_delta_stream + (stream_id, room_id, type, state_key, event_id, prev_event_id) + SELECT ?, ?, ?, ?, ?, ( + SELECT event_id FROM current_state_events + WHERE room_id = ? AND type = ? AND state_key = ? + ) + """ + txn.executemany(sql, ( + ( + max_stream_order, room_id, etype, state_key, None, + room_id, etype, state_key, + ) + for etype, state_key in to_delete + # We sanity check that we're deleting rather than updating + if (etype, state_key) not in to_insert + )) + txn.executemany(sql, ( + ( + max_stream_order, room_id, etype, state_key, ev_id, + room_id, etype, state_key, + ) + for (etype, state_key), ev_id in iteritems(to_insert) + )) + + # Now we actually update the current_state_events table + txn.executemany( - "DELETE FROM current_state_events WHERE event_id = ?", - [(ev_id,) for ev_id in to_delete.itervalues()], + "DELETE FROM current_state_events" + " WHERE room_id = ? AND type = ? AND state_key = ?", + ( + (room_id, etype, state_key) + for etype, state_key in itertools.chain(to_delete, to_insert) + ), ) self._simple_insert_many_txn( @@ -736,30 +863,12 @@ class EventsStore(SQLBaseStore): "type": key[0], "state_key": key[1], } - for key, ev_id in to_insert.iteritems() + for key, ev_id in iteritems(to_insert) ], ) - state_deltas = {key: None for key in to_delete} - state_deltas.update(to_insert) - - self._simple_insert_many_txn( - txn, - table="current_state_delta_stream", - values=[ - { - "stream_id": max_stream_order, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": ev_id, - "prev_event_id": to_delete.get(key, None), - } - for key, ev_id in state_deltas.iteritems() - ] - ) - - self._curr_state_delta_stream_cache.entity_has_changed( + txn.call_after( + self._curr_state_delta_stream_cache.entity_has_changed, room_id, max_stream_order, ) @@ -771,19 +880,23 @@ class EventsStore(SQLBaseStore): # and which we have added, then we invlidate the caches for all # those users. members_changed = set( - state_key for ev_type, state_key in state_deltas + state_key + for ev_type, state_key in itertools.chain(to_delete, to_insert) if ev_type == EventTypes.Member ) for member in members_changed: self._invalidate_cache_and_stream( - txn, self.get_rooms_for_user, (member,) + txn, self.get_rooms_for_user_with_stream_ordering, (member,) ) for host in set(get_domain_from_id(u) for u in members_changed): self._invalidate_cache_and_stream( txn, self.is_host_joined, (room_id, host) ) + self._invalidate_cache_and_stream( + txn, self.was_host_joined, (room_id, host) + ) self._invalidate_cache_and_stream( txn, self.get_users_in_room, (room_id,) @@ -795,7 +908,7 @@ class EventsStore(SQLBaseStore): def _update_forward_extremities_txn(self, txn, new_forward_extremities, max_stream_order): - for room_id, new_extrem in new_forward_extremities.iteritems(): + for room_id, new_extrem in iteritems(new_forward_extremities): self._simple_delete_txn( txn, table="event_forward_extremities", @@ -813,7 +926,7 @@ class EventsStore(SQLBaseStore): "event_id": ev_id, "room_id": room_id, } - for room_id, new_extrem in new_forward_extremities.iteritems() + for room_id, new_extrem in iteritems(new_forward_extremities) for ev_id in new_extrem ], ) @@ -830,7 +943,7 @@ class EventsStore(SQLBaseStore): "event_id": event_id, "stream_ordering": max_stream_order, } - for room_id, new_extrem in new_forward_extremities.iteritems() + for room_id, new_extrem in iteritems(new_forward_extremities) for event_id in new_extrem ] ) @@ -858,7 +971,7 @@ class EventsStore(SQLBaseStore): new_events_and_contexts[event.event_id] = (event, context) else: new_events_and_contexts[event.event_id] = (event, context) - return new_events_and_contexts.values() + return list(new_events_and_contexts.values()) def _update_room_depths_txn(self, txn, events_and_contexts, backfilled): """Update min_depth for each room @@ -884,7 +997,7 @@ class EventsStore(SQLBaseStore): event.depth, depth_updates.get(event.room_id, event.depth) ) - for room_id, depth in depth_updates.iteritems(): + for room_id, depth in iteritems(depth_updates): self._update_min_depth_for_room_txn(txn, room_id, depth) def _update_outliers_txn(self, txn, events_and_contexts): @@ -932,10 +1045,9 @@ class EventsStore(SQLBaseStore): # an outlier in the database. We now have some state at that # so we need to update the state_groups table with that state. - # insert into the state_group, state_groups_state and - # event_to_state_groups tables. + # insert into event_to_state_groups. try: - self._store_mult_state_groups_txn(txn, ((event, context),)) + self._store_event_state_mappings_txn(txn, ((event, context),)) except Exception: logger.exception("") raise @@ -1001,7 +1113,6 @@ class EventsStore(SQLBaseStore): "event_edge_hashes", "event_edges", "event_forward_extremities", - "event_push_actions", "event_reference_hashes", "event_search", "event_signatures", @@ -1021,6 +1132,14 @@ class EventsStore(SQLBaseStore): [(ev.event_id,) for ev, _ in events_and_contexts] ) + for table in ( + "event_push_actions", + ): + txn.executemany( + "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,), + [(ev.event_id,) for ev, _ in events_and_contexts] + ) + def _store_event_txn(self, txn, events_and_contexts): """Insert new events into the event and event_json tables @@ -1110,27 +1229,33 @@ class EventsStore(SQLBaseStore): ec for ec in events_and_contexts if ec[0] not in to_remove ] - def _update_metadata_tables_txn(self, txn, events_and_contexts, backfilled): + def _update_metadata_tables_txn(self, txn, events_and_contexts, + all_events_and_contexts, backfilled): """Update all the miscellaneous tables for new events Args: txn (twisted.enterprise.adbapi.Connection): db connection events_and_contexts (list[(EventBase, EventContext)]): events we are persisting + all_events_and_contexts (list[(EventBase, EventContext)]): all + events that we were going to persist. This includes events + we've already persisted, etc, that wouldn't appear in + events_and_context. backfilled (bool): True if the events were backfilled """ + # Insert all the push actions into the event_push_actions table. + self._set_push_actions_for_event_and_users_txn( + txn, + events_and_contexts=events_and_contexts, + all_events_and_contexts=all_events_and_contexts, + ) + if not events_and_contexts: # nothing to do here return for event, context in events_and_contexts: - # Insert all the push actions into the event_push_actions table. - if context.push_actions: - self._set_push_actions_for_event_and_users_txn( - txn, event, context.push_actions - ) - if event.type == EventTypes.Redaction and event.redacts is not None: # Remove the entries in the event_push_actions table for the # redacted event. @@ -1263,7 +1388,7 @@ class EventsStore(SQLBaseStore): " WHERE e.event_id IN (%s)" ) % (",".join(["?"] * len(ev_map)),) - txn.execute(sql, ev_map.keys()) + txn.execute(sql, list(ev_map)) rows = self.cursor_to_dict(txn) for row in rows: event = ev_map[row["event_id"]] @@ -1302,13 +1427,49 @@ class EventsStore(SQLBaseStore): defer.returnValue(set(r["event_id"] for r in rows)) - def have_events(self, event_ids): + @defer.inlineCallbacks + def have_seen_events(self, event_ids): """Given a list of event ids, check if we have already processed them. + Args: + event_ids (iterable[str]): + Returns: - dict: Has an entry for each event id we already have seen. Maps to - the rejected reason string if we rejected the event, else maps to - None. + Deferred[set[str]]: The events we have already seen. + """ + results = set() + + def have_seen_events_txn(txn, chunk): + sql = ( + "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" + % (",".join("?" * len(chunk)), ) + ) + txn.execute(sql, chunk) + for (event_id, ) in txn: + results.add(event_id) + + # break the input up into chunks of 100 + input_iterator = iter(event_ids) + for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), + []): + yield self.runInteraction( + "have_seen_events", + have_seen_events_txn, + chunk, + ) + defer.returnValue(results) + + def get_seen_events_with_rejections(self, event_ids): + """Given a list of event ids, check if we rejected them. + + Args: + event_ids (list[str]) + + Returns: + Deferred[dict[str, str|None): + Has an entry for each event id we already have seen. Maps to + the rejected reason string if we rejected the event, else maps + to None. """ if not event_ids: return defer.succeed({}) @@ -1330,295 +1491,7 @@ class EventsStore(SQLBaseStore): return res - return self.runInteraction( - "have_events", f, - ) - - @defer.inlineCallbacks - def _get_events(self, event_ids, check_redacted=True, - get_prev_content=False, allow_rejected=False): - if not event_ids: - defer.returnValue([]) - - event_id_list = event_ids - event_ids = set(event_ids) - - event_entry_map = self._get_events_from_cache( - event_ids, - allow_rejected=allow_rejected, - ) - - missing_events_ids = [e for e in event_ids if e not in event_entry_map] - - if missing_events_ids: - missing_events = yield self._enqueue_events( - missing_events_ids, - check_redacted=check_redacted, - allow_rejected=allow_rejected, - ) - - event_entry_map.update(missing_events) - - events = [] - for event_id in event_id_list: - entry = event_entry_map.get(event_id, None) - if not entry: - continue - - if allow_rejected or not entry.event.rejected_reason: - if check_redacted and entry.redacted_event: - event = entry.redacted_event - else: - event = entry.event - - events.append(event) - - if get_prev_content: - if "replaces_state" in event.unsigned: - prev = yield self.get_event( - event.unsigned["replaces_state"], - get_prev_content=False, - allow_none=True, - ) - if prev: - event.unsigned = dict(event.unsigned) - event.unsigned["prev_content"] = prev.content - event.unsigned["prev_sender"] = prev.sender - - defer.returnValue(events) - - def _invalidate_get_event_cache(self, event_id): - self._get_event_cache.invalidate((event_id,)) - - def _get_events_from_cache(self, events, allow_rejected, update_metrics=True): - """Fetch events from the caches - - Args: - events (list(str)): list of event_ids to fetch - allow_rejected (bool): Whether to teturn events that were rejected - update_metrics (bool): Whether to update the cache hit ratio metrics - - Returns: - dict of event_id -> _EventCacheEntry for each event_id in cache. If - allow_rejected is `False` then there will still be an entry but it - will be `None` - """ - event_map = {} - - for event_id in events: - ret = self._get_event_cache.get( - (event_id,), None, - update_metrics=update_metrics, - ) - if not ret: - continue - - if allow_rejected or not ret.event.rejected_reason: - event_map[event_id] = ret - else: - event_map[event_id] = None - - return event_map - - def _do_fetch(self, conn): - """Takes a database connection and waits for requests for events from - the _event_fetch_list queue. - """ - event_list = [] - i = 0 - while True: - try: - with self._event_fetch_lock: - event_list = self._event_fetch_list - self._event_fetch_list = [] - - if not event_list: - single_threaded = self.database_engine.single_threaded - if single_threaded or i > EVENT_QUEUE_ITERATIONS: - self._event_fetch_ongoing -= 1 - return - else: - self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) - i += 1 - continue - i = 0 - - event_id_lists = zip(*event_list)[0] - event_ids = [ - item for sublist in event_id_lists for item in sublist - ] - - rows = self._new_transaction( - conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids - ) - - row_dict = { - r["event_id"]: r - for r in rows - } - - # We only want to resolve deferreds from the main thread - def fire(lst, res): - for ids, d in lst: - if not d.called: - try: - with PreserveLoggingContext(): - d.callback([ - res[i] - for i in ids - if i in res - ]) - except: - logger.exception("Failed to callback") - with PreserveLoggingContext(): - reactor.callFromThread(fire, event_list, row_dict) - except Exception as e: - logger.exception("do_fetch") - - # We only want to resolve deferreds from the main thread - def fire(evs): - for _, d in evs: - if not d.called: - with PreserveLoggingContext(): - d.errback(e) - - if event_list: - with PreserveLoggingContext(): - reactor.callFromThread(fire, event_list) - - @defer.inlineCallbacks - def _enqueue_events(self, events, check_redacted=True, allow_rejected=False): - """Fetches events from the database using the _event_fetch_list. This - allows batch and bulk fetching of events - it allows us to fetch events - without having to create a new transaction for each request for events. - """ - if not events: - defer.returnValue({}) - - events_d = defer.Deferred() - with self._event_fetch_lock: - self._event_fetch_list.append( - (events, events_d) - ) - - self._event_fetch_lock.notify() - - if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: - self._event_fetch_ongoing += 1 - should_start = True - else: - should_start = False - - if should_start: - with PreserveLoggingContext(): - self.runWithConnection( - self._do_fetch - ) - - logger.debug("Loading %d events", len(events)) - with PreserveLoggingContext(): - rows = yield events_d - logger.debug("Loaded %d events (%d rows)", len(events), len(rows)) - - if not allow_rejected: - rows[:] = [r for r in rows if not r["rejects"]] - - res = yield preserve_context_over_deferred(defer.gatherResults( - [ - preserve_fn(self._get_event_from_row)( - row["internal_metadata"], row["json"], row["redacts"], - rejected_reason=row["rejects"], - ) - for row in rows - ], - consumeErrors=True - )) - - defer.returnValue({ - e.event.event_id: e - for e in res if e - }) - - def _fetch_event_rows(self, txn, events): - rows = [] - N = 200 - for i in range(1 + len(events) / N): - evs = events[i * N:(i + 1) * N] - if not evs: - break - - sql = ( - "SELECT " - " e.event_id as event_id, " - " e.internal_metadata," - " e.json," - " r.redacts as redacts," - " rej.event_id as rejects " - " FROM event_json as e" - " LEFT JOIN rejections as rej USING (event_id)" - " LEFT JOIN redactions as r ON e.event_id = r.redacts" - " WHERE e.event_id IN (%s)" - ) % (",".join(["?"] * len(evs)),) - - txn.execute(sql, evs) - rows.extend(self.cursor_to_dict(txn)) - - return rows - - @defer.inlineCallbacks - def _get_event_from_row(self, internal_metadata, js, redacted, - rejected_reason=None): - with Measure(self._clock, "_get_event_from_row"): - d = json.loads(js) - internal_metadata = json.loads(internal_metadata) - - if rejected_reason: - rejected_reason = yield self._simple_select_one_onecol( - table="rejections", - keyvalues={"event_id": rejected_reason}, - retcol="reason", - desc="_get_event_from_row_rejected_reason", - ) - - original_ev = FrozenEvent( - d, - internal_metadata_dict=internal_metadata, - rejected_reason=rejected_reason, - ) - - redacted_event = None - if redacted: - redacted_event = prune_event(original_ev) - - redaction_id = yield self._simple_select_one_onecol( - table="redactions", - keyvalues={"redacts": redacted_event.event_id}, - retcol="event_id", - desc="_get_event_from_row_redactions", - ) - - redacted_event.unsigned["redacted_by"] = redaction_id - # Get the redaction event. - - because = yield self.get_event( - redaction_id, - check_redacted=False, - allow_none=True, - ) - - if because: - # It's fine to do add the event directly, since get_pdu_json - # will serialise this field correctly - redacted_event.unsigned["redacted_because"] = because - - cache_entry = _EventCacheEntry( - event=original_ev, - redacted_event=redacted_event, - ) - - self._get_event_cache.prefill((original_ev.event_id,), cache_entry) - - defer.returnValue(cache_entry) + return self.runInteraction("get_rejection_reasons", f) @defer.inlineCallbacks def count_daily_messages(self): @@ -1778,7 +1651,7 @@ class EventsStore(SQLBaseStore): chunks = [ event_ids[i:i + 100] - for i in xrange(0, len(event_ids), 100) + for i in range(0, len(event_ids), 100) ] for chunk in chunks: ev_rows = self._simple_select_many_txn( @@ -2005,15 +1878,32 @@ class EventsStore(SQLBaseStore): ) return self.runInteraction("get_all_new_events", get_all_new_events_txn) - def delete_old_state(self, room_id, topological_ordering): + def purge_history( + self, room_id, token, delete_local_events, + ): + """Deletes room history before a certain point + + Args: + room_id (str): + + token (str): A topological token to delete events before + + delete_local_events (bool): + if True, we will delete local events as well as remote ones + (instead of just marking them as outliers and deleting their + state groups). + """ + return self.runInteraction( - "delete_old_state", - self._delete_old_state_txn, room_id, topological_ordering + "purge_history", + self._purge_history_txn, room_id, token, + delete_local_events, ) - def _delete_old_state_txn(self, txn, room_id, topological_ordering): - """Deletes old room state - """ + def _purge_history_txn( + self, txn, room_id, token_str, delete_local_events, + ): + token = RoomStreamToken.parse(token_str) # Tables that should be pruned: # event_auth @@ -2035,6 +1925,37 @@ class EventsStore(SQLBaseStore): # state_groups # state_groups_state + # we will build a temporary table listing the events so that we don't + # have to keep shovelling the list back and forth across the + # connection. Annoyingly the python sqlite driver commits the + # transaction on CREATE, so let's do this first. + # + # furthermore, we might already have the table from a previous (failed) + # purge attempt, so let's drop the table first. + + txn.execute("DROP TABLE IF EXISTS events_to_purge") + + txn.execute( + "CREATE TEMPORARY TABLE events_to_purge (" + " event_id TEXT NOT NULL," + " should_delete BOOLEAN NOT NULL" + ")" + ) + + # create an index on should_delete because later we'll be looking for + # the should_delete / shouldn't_delete subsets + txn.execute( + "CREATE INDEX events_to_purge_should_delete" + " ON events_to_purge(should_delete)", + ) + + # We do joins against events_to_purge for e.g. calculating state + # groups to purge, etc., so lets make an index. + txn.execute( + "CREATE INDEX events_to_purge_id" + " ON events_to_purge(event_id)", + ) + # First ensure that we're not about to delete all the forward extremeties txn.execute( "SELECT e.event_id, e.depth FROM events as e " @@ -2047,7 +1968,7 @@ class EventsStore(SQLBaseStore): rows = txn.fetchall() max_depth = max(row[0] for row in rows) - if max_depth <= topological_ordering: + if max_depth <= token.topological: # We need to ensure we don't delete all the events from the datanase # otherwise we wouldn't be able to send any events (due to not # having any backwards extremeties) @@ -2055,42 +1976,48 @@ class EventsStore(SQLBaseStore): 400, "topological_ordering is greater than forward extremeties" ) - logger.debug("[purge] looking for events to delete") + logger.info("[purge] looking for events to delete") + + should_delete_expr = "state_key IS NULL" + should_delete_params = () + if not delete_local_events: + should_delete_expr += " AND event_id NOT LIKE ?" + should_delete_params += ("%:" + self.hs.hostname, ) + + should_delete_params += (room_id, token.topological) txn.execute( - "SELECT event_id, state_key FROM events" - " LEFT JOIN state_events USING (room_id, event_id)" - " WHERE room_id = ? AND topological_ordering < ?", - (room_id, topological_ordering,) + "INSERT INTO events_to_purge" + " SELECT event_id, %s" + " FROM events AS e LEFT JOIN state_events USING (event_id)" + " WHERE e.room_id = ? AND topological_ordering < ?" % ( + should_delete_expr, + ), + should_delete_params, + ) + txn.execute( + "SELECT event_id, should_delete FROM events_to_purge" ) event_rows = txn.fetchall() - - to_delete = [ - (event_id,) for event_id, state_key in event_rows - if state_key is None and not self.hs.is_mine_id(event_id) - ] logger.info( - "[purge] found %i events before cutoff, of which %i are remote" - " non-state events to delete", len(event_rows), len(to_delete)) - - for event_id, state_key in event_rows: - txn.call_after(self._get_state_group_for_event.invalidate, (event_id,)) + "[purge] found %i events before cutoff, of which %i can be deleted", + len(event_rows), sum(1 for e in event_rows if e[1]), + ) - logger.debug("[purge] Finding new backward extremities") + logger.info("[purge] Finding new backward extremities") # We calculate the new entries for the backward extremeties by finding - # all events that point to events that are to be purged + # events to be purged that are pointed to by events we're not going to + # purge. txn.execute( - "SELECT DISTINCT e.event_id FROM events as e" - " INNER JOIN event_edges as ed ON e.event_id = ed.prev_event_id" - " INNER JOIN events as e2 ON e2.event_id = ed.event_id" - " WHERE e.room_id = ? AND e.topological_ordering < ?" - " AND e2.topological_ordering >= ?", - (room_id, topological_ordering, topological_ordering) + "SELECT DISTINCT e.event_id FROM events_to_purge AS e" + " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id" + " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id" + " WHERE ep2.event_id IS NULL", ) new_backwards_extrems = txn.fetchall() - logger.debug("[purge] replacing backward extremities: %r", new_backwards_extrems) + logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems) txn.execute( "DELETE FROM event_backward_extremities WHERE room_id = ?", @@ -2106,34 +2033,39 @@ class EventsStore(SQLBaseStore): ] ) - logger.debug("[purge] finding redundant state groups") + logger.info("[purge] finding redundant state groups") # Get all state groups that are only referenced by events that are # to be deleted. - txn.execute( - "SELECT state_group FROM event_to_state_groups" - " INNER JOIN events USING (event_id)" - " WHERE state_group IN (" - " SELECT DISTINCT state_group FROM events" - " INNER JOIN event_to_state_groups USING (event_id)" - " WHERE room_id = ? AND topological_ordering < ?" - " )" - " GROUP BY state_group HAVING MAX(topological_ordering) < ?", - (room_id, topological_ordering, topological_ordering) - ) + # This works by first getting state groups that we may want to delete, + # joining against event_to_state_groups to get events that use that + # state group, then left joining against events_to_purge again. Any + # state group where the left join produce *no nulls* are referenced + # only by events that are going to be purged. + txn.execute(""" + SELECT state_group FROM + ( + SELECT DISTINCT state_group FROM events_to_purge + INNER JOIN event_to_state_groups USING (event_id) + ) AS sp + INNER JOIN event_to_state_groups USING (state_group) + LEFT JOIN events_to_purge AS ep USING (event_id) + GROUP BY state_group + HAVING SUM(CASE WHEN ep.event_id IS NULL THEN 1 ELSE 0 END) = 0 + """) state_rows = txn.fetchall() - logger.debug("[purge] found %i redundant state groups", len(state_rows)) + logger.info("[purge] found %i redundant state groups", len(state_rows)) # make a set of the redundant state groups, so that we can look them up # efficiently state_groups_to_delete = set([sg for sg, in state_rows]) # Now we get all the state groups that rely on these state groups - logger.debug("[purge] finding state groups which depend on redundant" - " state groups") + logger.info("[purge] finding state groups which depend on redundant" + " state groups") remaining_state_groups = [] - for i in xrange(0, len(state_rows), 100): + for i in range(0, len(state_rows), 100): chunk = [sg for sg, in state_rows[i:i + 100]] # look for state groups whose prev_state_group is one we are about # to delete @@ -2156,7 +2088,7 @@ class EventsStore(SQLBaseStore): # Now we turn the state groups that reference to-be-deleted state # groups to non delta versions. for sg in remaining_state_groups: - logger.debug("[purge] de-delta-ing remaining state group %s", sg) + logger.info("[purge] de-delta-ing remaining state group %s", sg) curr_state = self._get_state_groups_from_groups_txn( txn, [sg], types=None ) @@ -2189,11 +2121,11 @@ class EventsStore(SQLBaseStore): "state_key": key[1], "event_id": state_id, } - for key, state_id in curr_state.iteritems() + for key, state_id in iteritems(curr_state) ], ) - logger.debug("[purge] removing redundant state groups") + logger.info("[purge] removing redundant state groups") txn.executemany( "DELETE FROM state_groups_state WHERE state_group = ?", state_rows @@ -2203,18 +2135,15 @@ class EventsStore(SQLBaseStore): state_rows ) - # Delete all non-state - logger.debug("[purge] removing events from event_to_state_groups") - txn.executemany( - "DELETE FROM event_to_state_groups WHERE event_id = ?", - [(event_id,) for event_id, _ in event_rows] - ) - - logger.debug("[purge] updating room_depth") + logger.info("[purge] removing events from event_to_state_groups") txn.execute( - "UPDATE room_depth SET min_depth = ? WHERE room_id = ?", - (topological_ordering, room_id,) + "DELETE FROM event_to_state_groups " + "WHERE event_id IN (SELECT event_id from events_to_purge)" ) + for event_id, _ in event_rows: + txn.call_after(self._get_state_group_for_event.invalidate, ( + event_id, + )) # Delete all remote non-state events for table in ( @@ -2226,28 +2155,75 @@ class EventsStore(SQLBaseStore): "event_edge_hashes", "event_edges", "event_forward_extremities", - "event_push_actions", "event_reference_hashes", "event_search", "event_signatures", "rejections", ): - logger.debug("[purge] removing remote non-state events from %s", table) + logger.info("[purge] removing events from %s", table) - txn.executemany( - "DELETE FROM %s WHERE event_id = ?" % (table,), - to_delete + txn.execute( + "DELETE FROM %s WHERE event_id IN (" + " SELECT event_id FROM events_to_purge WHERE should_delete" + ")" % (table,), + ) + + # event_push_actions lacks an index on event_id, and has one on + # (room_id, event_id) instead. + for table in ( + "event_push_actions", + ): + logger.info("[purge] removing events from %s", table) + + txn.execute( + "DELETE FROM %s WHERE room_id = ? AND event_id IN (" + " SELECT event_id FROM events_to_purge WHERE should_delete" + ")" % (table,), + (room_id, ) ) # Mark all state and own events as outliers - logger.debug("[purge] marking remaining events as outliers") - txn.executemany( + logger.info("[purge] marking remaining events as outliers") + txn.execute( "UPDATE events SET outlier = ?" - " WHERE event_id = ?", - [ - (True, event_id,) for event_id, state_key in event_rows - if state_key is not None or self.hs.is_mine_id(event_id) - ] + " WHERE event_id IN (" + " SELECT event_id FROM events_to_purge " + " WHERE NOT should_delete" + ")", + (True,), + ) + + # synapse tries to take out an exclusive lock on room_depth whenever it + # persists events (because upsert), and once we run this update, we + # will block that for the rest of our transaction. + # + # So, let's stick it at the end so that we don't block event + # persistence. + # + # We do this by calculating the minimum depth of the backwards + # extremities. However, the events in event_backward_extremities + # are ones we don't have yet so we need to look at the events that + # point to it via event_edges table. + txn.execute(""" + SELECT COALESCE(MIN(depth), 0) + FROM event_backward_extremities AS eb + INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id + INNER JOIN events AS e ON e.event_id = eg.event_id + WHERE eb.room_id = ? + """, (room_id,)) + min_depth, = txn.fetchone() + + logger.info("[purge] updating room_depth to %d", min_depth) + + txn.execute( + "UPDATE room_depth SET min_depth = ? WHERE room_id = ?", + (min_depth, room_id,) + ) + + # finally, drop the temp table. this will commit the txn in sqlite, + # so make sure to keep this actually last. + txn.execute( + "DROP TABLE events_to_purge" ) logger.info("[purge] done") @@ -2260,7 +2236,7 @@ class EventsStore(SQLBaseStore): to_2, so_2 = yield self._get_event_ordering(event_id2) defer.returnValue((to_1, so_1) > (to_2, so_2)) - @defer.inlineCallbacks + @cachedInlineCallbacks(max_entries=5000) def _get_event_ordering(self, event_id): res = yield self._simple_select_one( table="events", diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py new file mode 100644 index 0000000000..f28239a808 --- /dev/null +++ b/synapse/storage/events_worker.py @@ -0,0 +1,436 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from collections import namedtuple + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.api.errors import SynapseError +# these are only included to make the type annotations work +from synapse.events import EventBase # noqa: F401 +from synapse.events import FrozenEvent +from synapse.events.snapshot import EventContext # noqa: F401 +from synapse.events.utils import prune_event +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util.logcontext import ( + LoggingContext, + PreserveLoggingContext, + make_deferred_yieldable, + run_in_background, +) +from synapse.util.metrics import Measure + +from ._base import SQLBaseStore + +logger = logging.getLogger(__name__) + + +# These values are used in the `enqueus_event` and `_do_fetch` methods to +# control how we batch/bulk fetch events from the database. +# The values are plucked out of thing air to make initial sync run faster +# on jki.re +# TODO: Make these configurable. +EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events +EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events +EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events + + +_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) + + +class EventsWorkerStore(SQLBaseStore): + def get_received_ts(self, event_id): + """Get received_ts (when it was persisted) for the event. + + Raises an exception for unknown events. + + Args: + event_id (str) + + Returns: + Deferred[int|None]: Timestamp in milliseconds, or None for events + that were persisted before received_ts was implemented. + """ + return self._simple_select_one_onecol( + table="events", + keyvalues={ + "event_id": event_id, + }, + retcol="received_ts", + desc="get_received_ts", + ) + + @defer.inlineCallbacks + def get_event(self, event_id, check_redacted=True, + get_prev_content=False, allow_rejected=False, + allow_none=False): + """Get an event from the database by event_id. + + Args: + event_id (str): The event_id of the event to fetch + check_redacted (bool): If True, check if event has been redacted + and redact it. + get_prev_content (bool): If True and event is a state event, + include the previous states content in the unsigned field. + allow_rejected (bool): If True return rejected events. + allow_none (bool): If True, return None if no event found, if + False throw an exception. + + Returns: + Deferred : A FrozenEvent. + """ + events = yield self._get_events( + [event_id], + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + if not events and not allow_none: + raise SynapseError(404, "Could not find event %s" % (event_id,)) + + defer.returnValue(events[0] if events else None) + + @defer.inlineCallbacks + def get_events(self, event_ids, check_redacted=True, + get_prev_content=False, allow_rejected=False): + """Get events from the database + + Args: + event_ids (list): The event_ids of the events to fetch + check_redacted (bool): If True, check if event has been redacted + and redact it. + get_prev_content (bool): If True and event is a state event, + include the previous states content in the unsigned field. + allow_rejected (bool): If True return rejected events. + + Returns: + Deferred : Dict from event_id to event. + """ + events = yield self._get_events( + event_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + defer.returnValue({e.event_id: e for e in events}) + + @defer.inlineCallbacks + def _get_events(self, event_ids, check_redacted=True, + get_prev_content=False, allow_rejected=False): + if not event_ids: + defer.returnValue([]) + + event_id_list = event_ids + event_ids = set(event_ids) + + event_entry_map = self._get_events_from_cache( + event_ids, + allow_rejected=allow_rejected, + ) + + missing_events_ids = [e for e in event_ids if e not in event_entry_map] + + if missing_events_ids: + log_ctx = LoggingContext.current_context() + log_ctx.record_event_fetch(len(missing_events_ids)) + + missing_events = yield self._enqueue_events( + missing_events_ids, + check_redacted=check_redacted, + allow_rejected=allow_rejected, + ) + + event_entry_map.update(missing_events) + + events = [] + for event_id in event_id_list: + entry = event_entry_map.get(event_id, None) + if not entry: + continue + + if allow_rejected or not entry.event.rejected_reason: + if check_redacted and entry.redacted_event: + event = entry.redacted_event + else: + event = entry.event + + events.append(event) + + if get_prev_content: + if "replaces_state" in event.unsigned: + prev = yield self.get_event( + event.unsigned["replaces_state"], + get_prev_content=False, + allow_none=True, + ) + if prev: + event.unsigned = dict(event.unsigned) + event.unsigned["prev_content"] = prev.content + event.unsigned["prev_sender"] = prev.sender + + defer.returnValue(events) + + def _invalidate_get_event_cache(self, event_id): + self._get_event_cache.invalidate((event_id,)) + + def _get_events_from_cache(self, events, allow_rejected, update_metrics=True): + """Fetch events from the caches + + Args: + events (list(str)): list of event_ids to fetch + allow_rejected (bool): Whether to teturn events that were rejected + update_metrics (bool): Whether to update the cache hit ratio metrics + + Returns: + dict of event_id -> _EventCacheEntry for each event_id in cache. If + allow_rejected is `False` then there will still be an entry but it + will be `None` + """ + event_map = {} + + for event_id in events: + ret = self._get_event_cache.get( + (event_id,), None, + update_metrics=update_metrics, + ) + if not ret: + continue + + if allow_rejected or not ret.event.rejected_reason: + event_map[event_id] = ret + else: + event_map[event_id] = None + + return event_map + + def _do_fetch(self, conn): + """Takes a database connection and waits for requests for events from + the _event_fetch_list queue. + """ + i = 0 + while True: + with self._event_fetch_lock: + event_list = self._event_fetch_list + self._event_fetch_list = [] + + if not event_list: + single_threaded = self.database_engine.single_threaded + if single_threaded or i > EVENT_QUEUE_ITERATIONS: + self._event_fetch_ongoing -= 1 + return + else: + self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) + i += 1 + continue + i = 0 + + self._fetch_event_list(conn, event_list) + + def _fetch_event_list(self, conn, event_list): + """Handle a load of requests from the _event_fetch_list queue + + Args: + conn (twisted.enterprise.adbapi.Connection): database connection + + event_list (list[Tuple[list[str], Deferred]]): + The fetch requests. Each entry consists of a list of event + ids to be fetched, and a deferred to be completed once the + events have been fetched. + + """ + with Measure(self._clock, "_fetch_event_list"): + try: + event_id_lists = zip(*event_list)[0] + event_ids = [ + item for sublist in event_id_lists for item in sublist + ] + + rows = self._new_transaction( + conn, "do_fetch", [], [], + self._fetch_event_rows, event_ids, + ) + + row_dict = { + r["event_id"]: r + for r in rows + } + + # We only want to resolve deferreds from the main thread + def fire(lst, res): + for ids, d in lst: + if not d.called: + try: + with PreserveLoggingContext(): + d.callback([ + res[i] + for i in ids + if i in res + ]) + except Exception: + logger.exception("Failed to callback") + with PreserveLoggingContext(): + self.hs.get_reactor().callFromThread(fire, event_list, row_dict) + except Exception as e: + logger.exception("do_fetch") + + # We only want to resolve deferreds from the main thread + def fire(evs): + for _, d in evs: + if not d.called: + with PreserveLoggingContext(): + d.errback(e) + + with PreserveLoggingContext(): + self.hs.get_reactor().callFromThread(fire, event_list) + + @defer.inlineCallbacks + def _enqueue_events(self, events, check_redacted=True, allow_rejected=False): + """Fetches events from the database using the _event_fetch_list. This + allows batch and bulk fetching of events - it allows us to fetch events + without having to create a new transaction for each request for events. + """ + if not events: + defer.returnValue({}) + + events_d = defer.Deferred() + with self._event_fetch_lock: + self._event_fetch_list.append( + (events, events_d) + ) + + self._event_fetch_lock.notify() + + if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: + self._event_fetch_ongoing += 1 + should_start = True + else: + should_start = False + + if should_start: + run_as_background_process( + "fetch_events", + self.runWithConnection, + self._do_fetch, + ) + + logger.debug("Loading %d events", len(events)) + with PreserveLoggingContext(): + rows = yield events_d + logger.debug("Loaded %d events (%d rows)", len(events), len(rows)) + + if not allow_rejected: + rows[:] = [r for r in rows if not r["rejects"]] + + res = yield make_deferred_yieldable(defer.gatherResults( + [ + run_in_background( + self._get_event_from_row, + row["internal_metadata"], row["json"], row["redacts"], + rejected_reason=row["rejects"], + ) + for row in rows + ], + consumeErrors=True + )) + + defer.returnValue({ + e.event.event_id: e + for e in res if e + }) + + def _fetch_event_rows(self, txn, events): + rows = [] + N = 200 + for i in range(1 + len(events) // N): + evs = events[i * N:(i + 1) * N] + if not evs: + break + + sql = ( + "SELECT " + " e.event_id as event_id, " + " e.internal_metadata," + " e.json," + " r.redacts as redacts," + " rej.event_id as rejects " + " FROM event_json as e" + " LEFT JOIN rejections as rej USING (event_id)" + " LEFT JOIN redactions as r ON e.event_id = r.redacts" + " WHERE e.event_id IN (%s)" + ) % (",".join(["?"] * len(evs)),) + + txn.execute(sql, evs) + rows.extend(self.cursor_to_dict(txn)) + + return rows + + @defer.inlineCallbacks + def _get_event_from_row(self, internal_metadata, js, redacted, + rejected_reason=None): + with Measure(self._clock, "_get_event_from_row"): + d = json.loads(js) + internal_metadata = json.loads(internal_metadata) + + if rejected_reason: + rejected_reason = yield self._simple_select_one_onecol( + table="rejections", + keyvalues={"event_id": rejected_reason}, + retcol="reason", + desc="_get_event_from_row_rejected_reason", + ) + + original_ev = FrozenEvent( + d, + internal_metadata_dict=internal_metadata, + rejected_reason=rejected_reason, + ) + + redacted_event = None + if redacted: + redacted_event = prune_event(original_ev) + + redaction_id = yield self._simple_select_one_onecol( + table="redactions", + keyvalues={"redacts": redacted_event.event_id}, + retcol="event_id", + desc="_get_event_from_row_redactions", + ) + + redacted_event.unsigned["redacted_by"] = redaction_id + # Get the redaction event. + + because = yield self.get_event( + redaction_id, + check_redacted=False, + allow_none=True, + ) + + if because: + # It's fine to do add the event directly, since get_pdu_json + # will serialise this field correctly + redacted_event.unsigned["redacted_because"] = because + + cache_entry = _EventCacheEntry( + event=original_ev, + redacted_event=redacted_event, + ) + + self._get_event_cache.prefill((original_ev.event_id,), cache_entry) + + defer.returnValue(cache_entry) diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py index 78b1e30945..2d5896c5b4 100644 --- a/synapse/storage/filtering.py +++ b/synapse/storage/filtering.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from canonicaljson import encode_canonical_json, json + from twisted.internet import defer -from ._base import SQLBaseStore -from synapse.api.errors import SynapseError, Codes +from synapse.api.errors import Codes, SynapseError from synapse.util.caches.descriptors import cachedInlineCallbacks -from canonicaljson import encode_canonical_json -import simplejson as json +from ._base import SQLBaseStore class FilteringStore(SQLBaseStore): @@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore): desc="get_user_filter", ) - defer.returnValue(json.loads(str(def_json).decode("utf-8"))) + defer.returnValue(json.loads(bytes(def_json).decode("utf-8"))) def add_user_filter(self, user_localpart, user_filter): def_json = encode_canonical_json(user_filter) diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py new file mode 100644 index 0000000000..592d1b4c2a --- /dev/null +++ b/synapse/storage/group_server.py @@ -0,0 +1,1252 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.api.errors import SynapseError + +from ._base import SQLBaseStore + +# The category ID for the "default" category. We don't store as null in the +# database to avoid the fun of null != null +_DEFAULT_CATEGORY_ID = "" +_DEFAULT_ROLE_ID = "" + + +class GroupServerStore(SQLBaseStore): + def set_group_join_policy(self, group_id, join_policy): + """Set the join policy of a group. + + join_policy can be one of: + * "invite" + * "open" + """ + return self._simple_update_one( + table="groups", + keyvalues={ + "group_id": group_id, + }, + updatevalues={ + "join_policy": join_policy, + }, + desc="set_group_join_policy", + ) + + def get_group(self, group_id): + return self._simple_select_one( + table="groups", + keyvalues={ + "group_id": group_id, + }, + retcols=( + "name", "short_description", "long_description", + "avatar_url", "is_public", "join_policy", + ), + allow_none=True, + desc="get_group", + ) + + def get_users_in_group(self, group_id, include_private=False): + # TODO: Pagination + + keyvalues = { + "group_id": group_id, + } + if not include_private: + keyvalues["is_public"] = True + + return self._simple_select_list( + table="group_users", + keyvalues=keyvalues, + retcols=("user_id", "is_public", "is_admin",), + desc="get_users_in_group", + ) + + def get_invited_users_in_group(self, group_id): + # TODO: Pagination + + return self._simple_select_onecol( + table="group_invites", + keyvalues={ + "group_id": group_id, + }, + retcol="user_id", + desc="get_invited_users_in_group", + ) + + def get_rooms_in_group(self, group_id, include_private=False): + # TODO: Pagination + + keyvalues = { + "group_id": group_id, + } + if not include_private: + keyvalues["is_public"] = True + + return self._simple_select_list( + table="group_rooms", + keyvalues=keyvalues, + retcols=("room_id", "is_public",), + desc="get_rooms_in_group", + ) + + def get_rooms_for_summary_by_category(self, group_id, include_private=False): + """Get the rooms and categories that should be included in a summary request + + Returns ([rooms], [categories]) + """ + def _get_rooms_for_summary_txn(txn): + keyvalues = { + "group_id": group_id, + } + if not include_private: + keyvalues["is_public"] = True + + sql = """ + SELECT room_id, is_public, category_id, room_order + FROM group_summary_rooms + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + rooms = [ + { + "room_id": row[0], + "is_public": row[1], + "category_id": row[2] if row[2] != _DEFAULT_CATEGORY_ID else None, + "order": row[3], + } + for row in txn + ] + + sql = """ + SELECT category_id, is_public, profile, cat_order + FROM group_summary_room_categories + INNER JOIN group_room_categories USING (group_id, category_id) + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + categories = { + row[0]: { + "is_public": row[1], + "profile": json.loads(row[2]), + "order": row[3], + } + for row in txn + } + + return rooms, categories + return self.runInteraction( + "get_rooms_for_summary", _get_rooms_for_summary_txn + ) + + def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): + return self.runInteraction( + "add_room_to_summary", self._add_room_to_summary_txn, + group_id, room_id, category_id, order, is_public, + ) + + def _add_room_to_summary_txn(self, txn, group_id, room_id, category_id, order, + is_public): + """Add (or update) room's entry in summary. + + Args: + group_id (str) + room_id (str) + category_id (str): If not None then adds the category to the end of + the summary if its not already there. [Optional] + order (int): If not None inserts the room at that position, e.g. + an order of 1 will put the room first. Otherwise, the room gets + added to the end. + """ + room_in_group = self._simple_select_one_onecol_txn( + txn, + table="group_rooms", + keyvalues={ + "group_id": group_id, + "room_id": room_id, + }, + retcol="room_id", + allow_none=True, + ) + if not room_in_group: + raise SynapseError(400, "room not in group") + + if category_id is None: + category_id = _DEFAULT_CATEGORY_ID + else: + cat_exists = self._simple_select_one_onecol_txn( + txn, + table="group_room_categories", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + }, + retcol="group_id", + allow_none=True, + ) + if not cat_exists: + raise SynapseError(400, "Category doesn't exist") + + # TODO: Check category is part of summary already + cat_exists = self._simple_select_one_onecol_txn( + txn, + table="group_summary_room_categories", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + }, + retcol="group_id", + allow_none=True, + ) + if not cat_exists: + # If not, add it with an order larger than all others + txn.execute(""" + INSERT INTO group_summary_room_categories + (group_id, category_id, cat_order) + SELECT ?, ?, COALESCE(MAX(cat_order), 0) + 1 + FROM group_summary_room_categories + WHERE group_id = ? AND category_id = ? + """, (group_id, category_id, group_id, category_id)) + + existing = self._simple_select_one_txn( + txn, + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "room_id": room_id, + "category_id": category_id, + }, + retcols=("room_order", "is_public",), + allow_none=True, + ) + + if order is not None: + # Shuffle other room orders that come after the given order + sql = """ + UPDATE group_summary_rooms SET room_order = room_order + 1 + WHERE group_id = ? AND category_id = ? AND room_order >= ? + """ + txn.execute(sql, (group_id, category_id, order,)) + elif not existing: + sql = """ + SELECT COALESCE(MAX(room_order), 0) + 1 FROM group_summary_rooms + WHERE group_id = ? AND category_id = ? + """ + txn.execute(sql, (group_id, category_id,)) + order, = txn.fetchone() + + if existing: + to_update = {} + if order is not None: + to_update["room_order"] = order + if is_public is not None: + to_update["is_public"] = is_public + self._simple_update_txn( + txn, + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + "room_id": room_id, + }, + values=to_update, + ) + else: + if is_public is None: + is_public = True + + self._simple_insert_txn( + txn, + table="group_summary_rooms", + values={ + "group_id": group_id, + "category_id": category_id, + "room_id": room_id, + "room_order": order, + "is_public": is_public, + }, + ) + + def remove_room_from_summary(self, group_id, room_id, category_id): + if category_id is None: + category_id = _DEFAULT_CATEGORY_ID + + return self._simple_delete( + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + "room_id": room_id, + }, + desc="remove_room_from_summary", + ) + + @defer.inlineCallbacks + def get_group_categories(self, group_id): + rows = yield self._simple_select_list( + table="group_room_categories", + keyvalues={ + "group_id": group_id, + }, + retcols=("category_id", "is_public", "profile"), + desc="get_group_categories", + ) + + defer.returnValue({ + row["category_id"]: { + "is_public": row["is_public"], + "profile": json.loads(row["profile"]), + } + for row in rows + }) + + @defer.inlineCallbacks + def get_group_category(self, group_id, category_id): + category = yield self._simple_select_one( + table="group_room_categories", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + }, + retcols=("is_public", "profile"), + desc="get_group_category", + ) + + category["profile"] = json.loads(category["profile"]) + + defer.returnValue(category) + + def upsert_group_category(self, group_id, category_id, profile, is_public): + """Add/update room category for group + """ + insertion_values = {} + update_values = {"category_id": category_id} # This cannot be empty + + if profile is None: + insertion_values["profile"] = "{}" + else: + update_values["profile"] = json.dumps(profile) + + if is_public is None: + insertion_values["is_public"] = True + else: + update_values["is_public"] = is_public + + return self._simple_upsert( + table="group_room_categories", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + }, + values=update_values, + insertion_values=insertion_values, + desc="upsert_group_category", + ) + + def remove_group_category(self, group_id, category_id): + return self._simple_delete( + table="group_room_categories", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + }, + desc="remove_group_category", + ) + + @defer.inlineCallbacks + def get_group_roles(self, group_id): + rows = yield self._simple_select_list( + table="group_roles", + keyvalues={ + "group_id": group_id, + }, + retcols=("role_id", "is_public", "profile"), + desc="get_group_roles", + ) + + defer.returnValue({ + row["role_id"]: { + "is_public": row["is_public"], + "profile": json.loads(row["profile"]), + } + for row in rows + }) + + @defer.inlineCallbacks + def get_group_role(self, group_id, role_id): + role = yield self._simple_select_one( + table="group_roles", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + }, + retcols=("is_public", "profile"), + desc="get_group_role", + ) + + role["profile"] = json.loads(role["profile"]) + + defer.returnValue(role) + + def upsert_group_role(self, group_id, role_id, profile, is_public): + """Add/remove user role + """ + insertion_values = {} + update_values = {"role_id": role_id} # This cannot be empty + + if profile is None: + insertion_values["profile"] = "{}" + else: + update_values["profile"] = json.dumps(profile) + + if is_public is None: + insertion_values["is_public"] = True + else: + update_values["is_public"] = is_public + + return self._simple_upsert( + table="group_roles", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + }, + values=update_values, + insertion_values=insertion_values, + desc="upsert_group_role", + ) + + def remove_group_role(self, group_id, role_id): + return self._simple_delete( + table="group_roles", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + }, + desc="remove_group_role", + ) + + def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): + return self.runInteraction( + "add_user_to_summary", self._add_user_to_summary_txn, + group_id, user_id, role_id, order, is_public, + ) + + def _add_user_to_summary_txn(self, txn, group_id, user_id, role_id, order, + is_public): + """Add (or update) user's entry in summary. + + Args: + group_id (str) + user_id (str) + role_id (str): If not None then adds the role to the end of + the summary if its not already there. [Optional] + order (int): If not None inserts the user at that position, e.g. + an order of 1 will put the user first. Otherwise, the user gets + added to the end. + """ + user_in_group = self._simple_select_one_onecol_txn( + txn, + table="group_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcol="user_id", + allow_none=True, + ) + if not user_in_group: + raise SynapseError(400, "user not in group") + + if role_id is None: + role_id = _DEFAULT_ROLE_ID + else: + role_exists = self._simple_select_one_onecol_txn( + txn, + table="group_roles", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + }, + retcol="group_id", + allow_none=True, + ) + if not role_exists: + raise SynapseError(400, "Role doesn't exist") + + # TODO: Check role is part of the summary already + role_exists = self._simple_select_one_onecol_txn( + txn, + table="group_summary_roles", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + }, + retcol="group_id", + allow_none=True, + ) + if not role_exists: + # If not, add it with an order larger than all others + txn.execute(""" + INSERT INTO group_summary_roles + (group_id, role_id, role_order) + SELECT ?, ?, COALESCE(MAX(role_order), 0) + 1 + FROM group_summary_roles + WHERE group_id = ? AND role_id = ? + """, (group_id, role_id, group_id, role_id)) + + existing = self._simple_select_one_txn( + txn, + table="group_summary_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + "role_id": role_id, + }, + retcols=("user_order", "is_public",), + allow_none=True, + ) + + if order is not None: + # Shuffle other users orders that come after the given order + sql = """ + UPDATE group_summary_users SET user_order = user_order + 1 + WHERE group_id = ? AND role_id = ? AND user_order >= ? + """ + txn.execute(sql, (group_id, role_id, order,)) + elif not existing: + sql = """ + SELECT COALESCE(MAX(user_order), 0) + 1 FROM group_summary_users + WHERE group_id = ? AND role_id = ? + """ + txn.execute(sql, (group_id, role_id,)) + order, = txn.fetchone() + + if existing: + to_update = {} + if order is not None: + to_update["user_order"] = order + if is_public is not None: + to_update["is_public"] = is_public + self._simple_update_txn( + txn, + table="group_summary_users", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + "user_id": user_id, + }, + values=to_update, + ) + else: + if is_public is None: + is_public = True + + self._simple_insert_txn( + txn, + table="group_summary_users", + values={ + "group_id": group_id, + "role_id": role_id, + "user_id": user_id, + "user_order": order, + "is_public": is_public, + }, + ) + + def remove_user_from_summary(self, group_id, user_id, role_id): + if role_id is None: + role_id = _DEFAULT_ROLE_ID + + return self._simple_delete( + table="group_summary_users", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + "user_id": user_id, + }, + desc="remove_user_from_summary", + ) + + def get_users_for_summary_by_role(self, group_id, include_private=False): + """Get the users and roles that should be included in a summary request + + Returns ([users], [roles]) + """ + def _get_users_for_summary_txn(txn): + keyvalues = { + "group_id": group_id, + } + if not include_private: + keyvalues["is_public"] = True + + sql = """ + SELECT user_id, is_public, role_id, user_order + FROM group_summary_users + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + users = [ + { + "user_id": row[0], + "is_public": row[1], + "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None, + "order": row[3], + } + for row in txn + ] + + sql = """ + SELECT role_id, is_public, profile, role_order + FROM group_summary_roles + INNER JOIN group_roles USING (group_id, role_id) + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + roles = { + row[0]: { + "is_public": row[1], + "profile": json.loads(row[2]), + "order": row[3], + } + for row in txn + } + + return users, roles + return self.runInteraction( + "get_users_for_summary_by_role", _get_users_for_summary_txn + ) + + def is_user_in_group(self, user_id, group_id): + return self._simple_select_one_onecol( + table="group_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcol="user_id", + allow_none=True, + desc="is_user_in_group", + ).addCallback(lambda r: bool(r)) + + def is_user_admin_in_group(self, group_id, user_id): + return self._simple_select_one_onecol( + table="group_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcol="is_admin", + allow_none=True, + desc="is_user_admin_in_group", + ) + + def add_group_invite(self, group_id, user_id): + """Record that the group server has invited a user + """ + return self._simple_insert( + table="group_invites", + values={ + "group_id": group_id, + "user_id": user_id, + }, + desc="add_group_invite", + ) + + def is_user_invited_to_local_group(self, group_id, user_id): + """Has the group server invited a user? + """ + return self._simple_select_one_onecol( + table="group_invites", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcol="user_id", + desc="is_user_invited_to_local_group", + allow_none=True, + ) + + def get_users_membership_info_in_group(self, group_id, user_id): + """Get a dict describing the membership of a user in a group. + + Example if joined: + + { + "membership": "join", + "is_public": True, + "is_privileged": False, + } + + Returns an empty dict if the user is not join/invite/etc + """ + def _get_users_membership_in_group_txn(txn): + row = self._simple_select_one_txn( + txn, + table="group_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcols=("is_admin", "is_public"), + allow_none=True, + ) + + if row: + return { + "membership": "join", + "is_public": row["is_public"], + "is_privileged": row["is_admin"], + } + + row = self._simple_select_one_onecol_txn( + txn, + table="group_invites", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcol="user_id", + allow_none=True, + ) + + if row: + return { + "membership": "invite", + } + + return {} + + return self.runInteraction( + "get_users_membership_info_in_group", _get_users_membership_in_group_txn, + ) + + def add_user_to_group(self, group_id, user_id, is_admin=False, is_public=True, + local_attestation=None, remote_attestation=None): + """Add a user to the group server. + + Args: + group_id (str) + user_id (str) + is_admin (bool) + is_public (bool) + local_attestation (dict): The attestation the GS created to give + to the remote server. Optional if the user and group are on the + same server + remote_attestation (dict): The attestation given to GS by remote + server. Optional if the user and group are on the same server + """ + def _add_user_to_group_txn(txn): + self._simple_insert_txn( + txn, + table="group_users", + values={ + "group_id": group_id, + "user_id": user_id, + "is_admin": is_admin, + "is_public": is_public, + }, + ) + + self._simple_delete_txn( + txn, + table="group_invites", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + + if local_attestation: + self._simple_insert_txn( + txn, + table="group_attestations_renewals", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": local_attestation["valid_until_ms"], + }, + ) + if remote_attestation: + self._simple_insert_txn( + txn, + table="group_attestations_remote", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": remote_attestation["valid_until_ms"], + "attestation_json": json.dumps(remote_attestation), + }, + ) + + return self.runInteraction( + "add_user_to_group", _add_user_to_group_txn + ) + + def remove_user_from_group(self, group_id, user_id): + def _remove_user_from_group_txn(txn): + self._simple_delete_txn( + txn, + table="group_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + self._simple_delete_txn( + txn, + table="group_invites", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + self._simple_delete_txn( + txn, + table="group_attestations_renewals", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + self._simple_delete_txn( + txn, + table="group_attestations_remote", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + self._simple_delete_txn( + txn, + table="group_summary_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + return self.runInteraction("remove_user_from_group", _remove_user_from_group_txn) + + def add_room_to_group(self, group_id, room_id, is_public): + return self._simple_insert( + table="group_rooms", + values={ + "group_id": group_id, + "room_id": room_id, + "is_public": is_public, + }, + desc="add_room_to_group", + ) + + def update_room_in_group_visibility(self, group_id, room_id, is_public): + return self._simple_update( + table="group_rooms", + keyvalues={ + "group_id": group_id, + "room_id": room_id, + }, + updatevalues={ + "is_public": is_public, + }, + desc="update_room_in_group_visibility", + ) + + def remove_room_from_group(self, group_id, room_id): + def _remove_room_from_group_txn(txn): + self._simple_delete_txn( + txn, + table="group_rooms", + keyvalues={ + "group_id": group_id, + "room_id": room_id, + }, + ) + + self._simple_delete_txn( + txn, + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "room_id": room_id, + }, + ) + return self.runInteraction( + "remove_room_from_group", _remove_room_from_group_txn, + ) + + def get_publicised_groups_for_user(self, user_id): + """Get all groups a user is publicising + """ + return self._simple_select_onecol( + table="local_group_membership", + keyvalues={ + "user_id": user_id, + "membership": "join", + "is_publicised": True, + }, + retcol="group_id", + desc="get_publicised_groups_for_user", + ) + + def update_group_publicity(self, group_id, user_id, publicise): + """Update whether the user is publicising their membership of the group + """ + return self._simple_update_one( + table="local_group_membership", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + updatevalues={ + "is_publicised": publicise, + }, + desc="update_group_publicity" + ) + + @defer.inlineCallbacks + def register_user_group_membership(self, group_id, user_id, membership, + is_admin=False, content={}, + local_attestation=None, + remote_attestation=None, + is_publicised=False, + ): + """Registers that a local user is a member of a (local or remote) group. + + Args: + group_id (str) + user_id (str) + membership (str) + is_admin (bool) + content (dict): Content of the membership, e.g. includes the inviter + if the user has been invited. + local_attestation (dict): If remote group then store the fact that we + have given out an attestation, else None. + remote_attestation (dict): If remote group then store the remote + attestation from the group, else None. + """ + def _register_user_group_membership_txn(txn, next_id): + # TODO: Upsert? + self._simple_delete_txn( + txn, + table="local_group_membership", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + self._simple_insert_txn( + txn, + table="local_group_membership", + values={ + "group_id": group_id, + "user_id": user_id, + "is_admin": is_admin, + "membership": membership, + "is_publicised": is_publicised, + "content": json.dumps(content), + }, + ) + + self._simple_insert_txn( + txn, + table="local_group_updates", + values={ + "stream_id": next_id, + "group_id": group_id, + "user_id": user_id, + "type": "membership", + "content": json.dumps({"membership": membership, "content": content}), + } + ) + self._group_updates_stream_cache.entity_has_changed(user_id, next_id) + + # TODO: Insert profile to ensure it comes down stream if its a join. + + if membership == "join": + if local_attestation: + self._simple_insert_txn( + txn, + table="group_attestations_renewals", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": local_attestation["valid_until_ms"], + } + ) + if remote_attestation: + self._simple_insert_txn( + txn, + table="group_attestations_remote", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": remote_attestation["valid_until_ms"], + "attestation_json": json.dumps(remote_attestation), + } + ) + else: + self._simple_delete_txn( + txn, + table="group_attestations_renewals", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + self._simple_delete_txn( + txn, + table="group_attestations_remote", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + + return next_id + + with self._group_updates_id_gen.get_next() as next_id: + res = yield self.runInteraction( + "register_user_group_membership", + _register_user_group_membership_txn, next_id, + ) + defer.returnValue(res) + + @defer.inlineCallbacks + def create_group(self, group_id, user_id, name, avatar_url, short_description, + long_description,): + yield self._simple_insert( + table="groups", + values={ + "group_id": group_id, + "name": name, + "avatar_url": avatar_url, + "short_description": short_description, + "long_description": long_description, + "is_public": True, + }, + desc="create_group", + ) + + @defer.inlineCallbacks + def update_group_profile(self, group_id, profile,): + yield self._simple_update_one( + table="groups", + keyvalues={ + "group_id": group_id, + }, + updatevalues=profile, + desc="update_group_profile", + ) + + def get_attestations_need_renewals(self, valid_until_ms): + """Get all attestations that need to be renewed until givent time + """ + def _get_attestations_need_renewals_txn(txn): + sql = """ + SELECT group_id, user_id FROM group_attestations_renewals + WHERE valid_until_ms <= ? + """ + txn.execute(sql, (valid_until_ms,)) + return self.cursor_to_dict(txn) + return self.runInteraction( + "get_attestations_need_renewals", _get_attestations_need_renewals_txn + ) + + def update_attestation_renewal(self, group_id, user_id, attestation): + """Update an attestation that we have renewed + """ + return self._simple_update_one( + table="group_attestations_renewals", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + updatevalues={ + "valid_until_ms": attestation["valid_until_ms"], + }, + desc="update_attestation_renewal", + ) + + def update_remote_attestion(self, group_id, user_id, attestation): + """Update an attestation that a remote has renewed + """ + return self._simple_update_one( + table="group_attestations_remote", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + updatevalues={ + "valid_until_ms": attestation["valid_until_ms"], + "attestation_json": json.dumps(attestation) + }, + desc="update_remote_attestion", + ) + + def remove_attestation_renewal(self, group_id, user_id): + """Remove an attestation that we thought we should renew, but actually + shouldn't. Ideally this would never get called as we would never + incorrectly try and do attestations for local users on local groups. + + Args: + group_id (str) + user_id (str) + """ + return self._simple_delete( + table="group_attestations_renewals", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + desc="remove_attestation_renewal", + ) + + @defer.inlineCallbacks + def get_remote_attestation(self, group_id, user_id): + """Get the attestation that proves the remote agrees that the user is + in the group. + """ + row = yield self._simple_select_one( + table="group_attestations_remote", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcols=("valid_until_ms", "attestation_json"), + desc="get_remote_attestation", + allow_none=True, + ) + + now = int(self._clock.time_msec()) + if row and now < row["valid_until_ms"]: + defer.returnValue(json.loads(row["attestation_json"])) + + defer.returnValue(None) + + def get_joined_groups(self, user_id): + return self._simple_select_onecol( + table="local_group_membership", + keyvalues={ + "user_id": user_id, + "membership": "join", + }, + retcol="group_id", + desc="get_joined_groups", + ) + + def get_all_groups_for_user(self, user_id, now_token): + def _get_all_groups_for_user_txn(txn): + sql = """ + SELECT group_id, type, membership, u.content + FROM local_group_updates AS u + INNER JOIN local_group_membership USING (group_id, user_id) + WHERE user_id = ? AND membership != 'leave' + AND stream_id <= ? + """ + txn.execute(sql, (user_id, now_token,)) + return [ + { + "group_id": row[0], + "type": row[1], + "membership": row[2], + "content": json.loads(row[3]), + } + for row in txn + ] + return self.runInteraction( + "get_all_groups_for_user", _get_all_groups_for_user_txn, + ) + + def get_groups_changes_for_user(self, user_id, from_token, to_token): + from_token = int(from_token) + has_changed = self._group_updates_stream_cache.has_entity_changed( + user_id, from_token, + ) + if not has_changed: + return [] + + def _get_groups_changes_for_user_txn(txn): + sql = """ + SELECT group_id, membership, type, u.content + FROM local_group_updates AS u + INNER JOIN local_group_membership USING (group_id, user_id) + WHERE user_id = ? AND ? < stream_id AND stream_id <= ? + """ + txn.execute(sql, (user_id, from_token, to_token,)) + return [{ + "group_id": group_id, + "membership": membership, + "type": gtype, + "content": json.loads(content_json), + } for group_id, membership, gtype, content_json in txn] + return self.runInteraction( + "get_groups_changes_for_user", _get_groups_changes_for_user_txn, + ) + + def get_all_groups_changes(self, from_token, to_token, limit): + from_token = int(from_token) + has_changed = self._group_updates_stream_cache.has_any_entity_changed( + from_token, + ) + if not has_changed: + return [] + + def _get_all_groups_changes_txn(txn): + sql = """ + SELECT stream_id, group_id, user_id, type, content + FROM local_group_updates + WHERE ? < stream_id AND stream_id <= ? + LIMIT ? + """ + txn.execute(sql, (from_token, to_token, limit,)) + return [( + stream_id, + group_id, + user_id, + gtype, + json.loads(content_json), + ) for stream_id, group_id, user_id, gtype, content_json in txn] + return self.runInteraction( + "get_all_groups_changes", _get_all_groups_changes_txn, + ) + + def get_group_stream_token(self): + return self._group_updates_id_gen.get_current_token() diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 3b5e0a4fb9..f547977600 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -13,19 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cachedInlineCallbacks +import hashlib +import logging -from twisted.internet import defer +import six -import OpenSSL from signedjson.key import decode_verify_key_bytes -import hashlib -import logging +import OpenSSL +from twisted.internet import defer + +from synapse.util.caches.descriptors import cachedInlineCallbacks + +from ._base import SQLBaseStore logger = logging.getLogger(__name__) +# py2 sqlite has buffer hardcoded as only binary type, so we must use it, +# despite being deprecated and removed in favor of memoryview +if six.PY2: + db_binary_type = buffer +else: + db_binary_type = memoryview + class KeyStore(SQLBaseStore): """Persistence for signature verification keys and tls X.509 certificates @@ -72,7 +82,7 @@ class KeyStore(SQLBaseStore): values={ "from_server": from_server, "ts_added_ms": time_now_ms, - "tls_certificate": buffer(tls_certificate_bytes), + "tls_certificate": db_binary_type(tls_certificate_bytes), }, desc="store_server_certificate", ) @@ -92,7 +102,7 @@ class KeyStore(SQLBaseStore): if verify_key_bytes: defer.returnValue(decode_verify_key_bytes( - key_id, str(verify_key_bytes) + key_id, bytes(verify_key_bytes) )) @defer.inlineCallbacks @@ -113,30 +123,37 @@ class KeyStore(SQLBaseStore): keys[key_id] = key defer.returnValue(keys) - @defer.inlineCallbacks def store_server_verify_key(self, server_name, from_server, time_now_ms, verify_key): """Stores a NACL verification key for the given server. Args: server_name (str): The name of the server. - key_id (str): The version of the key for the server. from_server (str): Where the verification key was looked up - ts_now_ms (int): The time now in milliseconds - verification_key (VerifyKey): The NACL verify key. + time_now_ms (int): The time now in milliseconds + verify_key (nacl.signing.VerifyKey): The NACL verify key. """ - yield self._simple_upsert( - table="server_signature_keys", - keyvalues={ - "server_name": server_name, - "key_id": "%s:%s" % (verify_key.alg, verify_key.version), - }, - values={ - "from_server": from_server, - "ts_added_ms": time_now_ms, - "verify_key": buffer(verify_key.encode()), - }, - desc="store_server_verify_key", - ) + key_id = "%s:%s" % (verify_key.alg, verify_key.version) + + def _txn(txn): + self._simple_upsert_txn( + txn, + table="server_signature_keys", + keyvalues={ + "server_name": server_name, + "key_id": key_id, + }, + values={ + "from_server": from_server, + "ts_added_ms": time_now_ms, + "verify_key": db_binary_type(verify_key.encode()), + }, + ) + txn.call_after( + self._get_server_verify_key.invalidate, + (server_name, key_id) + ) + + return self.runInteraction("store_server_verify_key", _txn) def store_server_keys_json(self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes): @@ -165,7 +182,7 @@ class KeyStore(SQLBaseStore): "from_server": from_server, "ts_added_ms": ts_now_ms, "ts_valid_until_ms": ts_expires_ms, - "key_json": buffer(key_json_bytes), + "key_json": db_binary_type(key_json_bytes), }, desc="store_server_keys_json", ) diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py index 82bb61b811..e6cdbb0545 100644 --- a/synapse/storage/media_repository.py +++ b/synapse/storage/media_repository.py @@ -12,15 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from synapse.storage.background_updates import BackgroundUpdateStore -from ._base import SQLBaseStore - -class MediaRepositoryStore(SQLBaseStore): +class MediaRepositoryStore(BackgroundUpdateStore): """Persistence for attachments and avatars""" - def get_default_thumbnails(self, top_level_type, sub_type): - return [] + def __init__(self, db_conn, hs): + super(MediaRepositoryStore, self).__init__(db_conn, hs) + + self.register_background_index_update( + update_name='local_media_repository_url_idx', + index_name='local_media_repository_url_idx', + table='local_media_repository', + columns=['created_ts'], + where_clause='url_cache IS NOT NULL', + ) def get_local_media(self, media_id): """Get the metadata for a local piece of media @@ -62,7 +69,7 @@ class MediaRepositoryStore(SQLBaseStore): def get_url_cache_txn(txn): # get the most recently cached result (relative to the given ts) sql = ( - "SELECT response_code, etag, expires, og, media_id, download_ts" + "SELECT response_code, etag, expires_ts, og, media_id, download_ts" " FROM local_media_repository_url_cache" " WHERE url = ? AND download_ts <= ?" " ORDER BY download_ts DESC LIMIT 1" @@ -74,7 +81,7 @@ class MediaRepositoryStore(SQLBaseStore): # ...or if we've requested a timestamp older than the oldest # copy in the cache, return the oldest copy (if any) sql = ( - "SELECT response_code, etag, expires, og, media_id, download_ts" + "SELECT response_code, etag, expires_ts, og, media_id, download_ts" " FROM local_media_repository_url_cache" " WHERE url = ? AND download_ts > ?" " ORDER BY download_ts ASC LIMIT 1" @@ -86,14 +93,14 @@ class MediaRepositoryStore(SQLBaseStore): return None return dict(zip(( - 'response_code', 'etag', 'expires', 'og', 'media_id', 'download_ts' + 'response_code', 'etag', 'expires_ts', 'og', 'media_id', 'download_ts' ), row)) return self.runInteraction( "get_url_cache", get_url_cache_txn ) - def store_url_cache(self, url, response_code, etag, expires, og, media_id, + def store_url_cache(self, url, response_code, etag, expires_ts, og, media_id, download_ts): return self._simple_insert( "local_media_repository_url_cache", @@ -101,7 +108,7 @@ class MediaRepositoryStore(SQLBaseStore): "url": url, "response_code": response_code, "etag": etag, - "expires": expires, + "expires_ts": expires_ts, "og": og, "media_id": media_id, "download_ts": download_ts, @@ -166,7 +173,14 @@ class MediaRepositoryStore(SQLBaseStore): desc="store_cached_remote_media", ) - def update_cached_last_access_time(self, origin_id_tuples, time_ts): + def update_cached_last_access_time(self, local_media, remote_media, time_ms): + """Updates the last access time of the given media + + Args: + local_media (iterable[str]): Set of media_ids + remote_media (iterable[(str, str)]): Set of (server_name, media_id) + time_ms: Current time in milliseconds + """ def update_cache_txn(txn): sql = ( "UPDATE remote_media_cache SET last_access_ts = ?" @@ -174,8 +188,18 @@ class MediaRepositoryStore(SQLBaseStore): ) txn.executemany(sql, ( - (time_ts, media_origin, media_id) - for media_origin, media_id in origin_id_tuples + (time_ms, media_origin, media_id) + for media_origin, media_id in remote_media + )) + + sql = ( + "UPDATE local_media_repository SET last_access_ts = ?" + " WHERE media_id = ?" + ) + + txn.executemany(sql, ( + (time_ms, media_id) + for media_id in local_media )) return self.runInteraction("update_cached_last_access_time", update_cache_txn) @@ -238,3 +262,70 @@ class MediaRepositoryStore(SQLBaseStore): }, ) return self.runInteraction("delete_remote_media", delete_remote_media_txn) + + def get_expired_url_cache(self, now_ts): + sql = ( + "SELECT media_id FROM local_media_repository_url_cache" + " WHERE expires_ts < ?" + " ORDER BY expires_ts ASC" + " LIMIT 500" + ) + + def _get_expired_url_cache_txn(txn): + txn.execute(sql, (now_ts,)) + return [row[0] for row in txn] + + return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn) + + def delete_url_cache(self, media_ids): + if len(media_ids) == 0: + return + + sql = ( + "DELETE FROM local_media_repository_url_cache" + " WHERE media_id = ?" + ) + + def _delete_url_cache_txn(txn): + txn.executemany(sql, [(media_id,) for media_id in media_ids]) + + return self.runInteraction("delete_url_cache", _delete_url_cache_txn) + + def get_url_cache_media_before(self, before_ts): + sql = ( + "SELECT media_id FROM local_media_repository" + " WHERE created_ts < ? AND url_cache IS NOT NULL" + " ORDER BY created_ts ASC" + " LIMIT 500" + ) + + def _get_url_cache_media_before_txn(txn): + txn.execute(sql, (before_ts,)) + return [row[0] for row in txn] + + return self.runInteraction( + "get_url_cache_media_before", _get_url_cache_media_before_txn, + ) + + def delete_url_cache_media(self, media_ids): + if len(media_ids) == 0: + return + + def _delete_url_cache_media_txn(txn): + sql = ( + "DELETE FROM local_media_repository" + " WHERE media_id = ?" + ) + + txn.executemany(sql, [(media_id,) for media_id in media_ids]) + + sql = ( + "DELETE FROM local_media_repository_thumbnails" + " WHERE media_id = ?" + ) + + txn.executemany(sql, [(media_id,) for media_id in media_ids]) + + return self.runInteraction( + "delete_url_cache_media", _delete_url_cache_media_txn, + ) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 72b670b83b..b290f834b3 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,13 +20,12 @@ import logging import os import re - logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 43 +SCHEMA_VERSION = 50 dir_path = os.path.abspath(os.path.dirname(__file__)) @@ -44,6 +44,13 @@ def prepare_database(db_conn, database_engine, config): If `config` is None then prepare_database will assert that no upgrade is necessary, *or* will create a fresh database if the database is empty. + + Args: + db_conn: + database_engine: + config (synapse.config.homeserver.HomeServerConfig|None): + application config, or None if we are connecting to an existing + database which we expect to be configured already """ try: cur = db_conn.cursor() @@ -64,9 +71,13 @@ def prepare_database(db_conn, database_engine, config): else: _setup_new_database(cur, database_engine) + # check if any of our configured dynamic modules want a database + if config is not None: + _apply_module_schemas(cur, database_engine, config) + cur.close() db_conn.commit() - except: + except Exception: db_conn.rollback() raise @@ -283,6 +294,65 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, ) +def _apply_module_schemas(txn, database_engine, config): + """Apply the module schemas for the dynamic modules, if any + + Args: + cur: database cursor + database_engine: synapse database engine class + config (synapse.config.homeserver.HomeServerConfig): + application config + """ + for (mod, _config) in config.password_providers: + if not hasattr(mod, 'get_db_schema_files'): + continue + modname = ".".join((mod.__module__, mod.__name__)) + _apply_module_schema_files( + txn, database_engine, modname, mod.get_db_schema_files(), + ) + + +def _apply_module_schema_files(cur, database_engine, modname, names_and_streams): + """Apply the module schemas for a single module + + Args: + cur: database cursor + database_engine: synapse database engine class + modname (str): fully qualified name of the module + names_and_streams (Iterable[(str, file)]): the names and streams of + schemas to be applied + """ + cur.execute( + database_engine.convert_param_style( + "SELECT file FROM applied_module_schemas WHERE module_name = ?" + ), + (modname,) + ) + applied_deltas = set(d for d, in cur) + for (name, stream) in names_and_streams: + if name in applied_deltas: + continue + + root_name, ext = os.path.splitext(name) + if ext != '.sql': + raise PrepareDatabaseException( + "only .sql files are currently supported for module schemas", + ) + + logger.info("applying schema %s for %s", name, modname) + for statement in get_statements(stream): + cur.execute(statement) + + # Mark as done. + cur.execute( + database_engine.convert_param_style( + "INSERT INTO applied_module_schemas (module_name, file)" + " VALUES (?,?)", + ), + (modname, name) + ) + + def get_statements(f): statement_buffer = "" in_comment = False # If we're in a /* ... */ style comment diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 9e9d3c2591..a0c7a0dc87 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -13,12 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore +from collections import namedtuple + +from twisted.internet import defer + from synapse.api.constants import PresenceState +from synapse.util import batch_iter from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList -from collections import namedtuple -from twisted.internet import defer +from ._base import SQLBaseStore class UserPresenceState(namedtuple("UserPresenceState", @@ -115,11 +118,7 @@ class PresenceStore(SQLBaseStore): " AND user_id IN (%s)" ) - batches = ( - presence_states[i:i + 50] - for i in xrange(0, len(presence_states), 50) - ) - for states in batches: + for states in batch_iter(presence_states, 50): args = [stream_id] args.extend(s.user_id for s in states) txn.execute( diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py index 26a40905ae..60295da254 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py @@ -13,15 +13,37 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + +from synapse.api.errors import StoreError +from synapse.storage.roommember import ProfileInfo + from ._base import SQLBaseStore -class ProfileStore(SQLBaseStore): - def create_profile(self, user_localpart): - return self._simple_insert( - table="profiles", - values={"user_id": user_localpart}, - desc="create_profile", +class ProfileWorkerStore(SQLBaseStore): + @defer.inlineCallbacks + def get_profileinfo(self, user_localpart): + try: + profile = yield self._simple_select_one( + table="profiles", + keyvalues={"user_id": user_localpart}, + retcols=("displayname", "avatar_url"), + desc="get_profileinfo", + ) + except StoreError as e: + if e.code == 404: + # no match + defer.returnValue(ProfileInfo(None, None)) + return + else: + raise + + defer.returnValue( + ProfileInfo( + avatar_url=profile['avatar_url'], + display_name=profile['displayname'], + ) ) def get_profile_displayname(self, user_localpart): @@ -32,14 +54,6 @@ class ProfileStore(SQLBaseStore): desc="get_profile_displayname", ) - def set_profile_displayname(self, user_localpart, new_displayname): - return self._simple_update_one( - table="profiles", - keyvalues={"user_id": user_localpart}, - updatevalues={"displayname": new_displayname}, - desc="set_profile_displayname", - ) - def get_profile_avatar_url(self, user_localpart): return self._simple_select_one_onecol( table="profiles", @@ -48,6 +62,32 @@ class ProfileStore(SQLBaseStore): desc="get_profile_avatar_url", ) + def get_from_remote_profile_cache(self, user_id): + return self._simple_select_one( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + retcols=("displayname", "avatar_url",), + allow_none=True, + desc="get_from_remote_profile_cache", + ) + + +class ProfileStore(ProfileWorkerStore): + def create_profile(self, user_localpart): + return self._simple_insert( + table="profiles", + values={"user_id": user_localpart}, + desc="create_profile", + ) + + def set_profile_displayname(self, user_localpart, new_displayname): + return self._simple_update_one( + table="profiles", + keyvalues={"user_id": user_localpart}, + updatevalues={"displayname": new_displayname}, + desc="set_profile_displayname", + ) + def set_profile_avatar_url(self, user_localpart, new_avatar_url): return self._simple_update_one( table="profiles", @@ -55,3 +95,90 @@ class ProfileStore(SQLBaseStore): updatevalues={"avatar_url": new_avatar_url}, desc="set_profile_avatar_url", ) + + def add_remote_profile_cache(self, user_id, displayname, avatar_url): + """Ensure we are caching the remote user's profiles. + + This should only be called when `is_subscribed_remote_profile_for_user` + would return true for the user. + """ + return self._simple_upsert( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + values={ + "displayname": displayname, + "avatar_url": avatar_url, + "last_check": self._clock.time_msec(), + }, + desc="add_remote_profile_cache", + ) + + def update_remote_profile_cache(self, user_id, displayname, avatar_url): + return self._simple_update( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + values={ + "displayname": displayname, + "avatar_url": avatar_url, + "last_check": self._clock.time_msec(), + }, + desc="update_remote_profile_cache", + ) + + @defer.inlineCallbacks + def maybe_delete_remote_profile_cache(self, user_id): + """Check if we still care about the remote user's profile, and if we + don't then remove their profile from the cache + """ + subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) + if not subscribed: + yield self._simple_delete( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + desc="delete_remote_profile_cache", + ) + + def get_remote_profile_cache_entries_that_expire(self, last_checked): + """Get all users who haven't been checked since `last_checked` + """ + def _get_remote_profile_cache_entries_that_expire_txn(txn): + sql = """ + SELECT user_id, displayname, avatar_url + FROM remote_profile_cache + WHERE last_check < ? + """ + + txn.execute(sql, (last_checked,)) + + return self.cursor_to_dict(txn) + + return self.runInteraction( + "get_remote_profile_cache_entries_that_expire", + _get_remote_profile_cache_entries_that_expire_txn, + ) + + @defer.inlineCallbacks + def is_subscribed_remote_profile_for_user(self, user_id): + """Check whether we are interested in a remote user's profile. + """ + res = yield self._simple_select_one_onecol( + table="group_users", + keyvalues={"user_id": user_id}, + retcol="user_id", + allow_none=True, + desc="should_update_remote_profile_cache_for_user", + ) + + if res: + defer.returnValue(True) + + res = yield self._simple_select_one_onecol( + table="group_invites", + keyvalues={"user_id": user_id}, + retcol="user_id", + allow_none=True, + desc="should_update_remote_profile_cache_for_user", + ) + + if res: + defer.returnValue(True) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 8758b1c0c7..6a5028961d 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,14 +14,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList -from synapse.push.baserules import list_with_base_rules -from synapse.api.constants import EventTypes +import abc +import logging + +from canonicaljson import json + from twisted.internet import defer -import logging -import simplejson as json +from synapse.push.baserules import list_with_base_rules +from synapse.storage.appservice import ApplicationServiceWorkerStore +from synapse.storage.pusher import PusherWorkerStore +from synapse.storage.receipts import ReceiptsWorkerStore +from synapse.storage.roommember import RoomMemberWorkerStore +from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList +from synapse.util.caches.stream_change_cache import StreamChangeCache + +from ._base import SQLBaseStore logger = logging.getLogger(__name__) @@ -48,7 +57,43 @@ def _load_rules(rawrules, enabled_map): return rules -class PushRuleStore(SQLBaseStore): +class PushRulesWorkerStore(ApplicationServiceWorkerStore, + ReceiptsWorkerStore, + PusherWorkerStore, + RoomMemberWorkerStore, + SQLBaseStore): + """This is an abstract base class where subclasses must implement + `get_max_push_rules_stream_id` which can be called in the initializer. + """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. + __metaclass__ = abc.ABCMeta + + def __init__(self, db_conn, hs): + super(PushRulesWorkerStore, self).__init__(db_conn, hs) + + push_rules_prefill, push_rules_id = self._get_cache_dict( + db_conn, "push_rules_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=self.get_max_push_rules_stream_id(), + ) + + self.push_rules_stream_cache = StreamChangeCache( + "PushRulesStreamChangeCache", push_rules_id, + prefilled_cache=push_rules_prefill, + ) + + @abc.abstractmethod + def get_max_push_rules_stream_id(self): + """Get the position of the push rules stream. + + Returns: + int + """ + raise NotImplementedError() + @cachedInlineCallbacks(max_entries=5000) def get_push_rules_for_user(self, user_id): rows = yield self._simple_select_list( @@ -89,6 +134,22 @@ class PushRuleStore(SQLBaseStore): r['rule_id']: False if r['enabled'] == 0 else True for r in results }) + def have_push_rules_changed_for_user(self, user_id, last_id): + if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): + return defer.succeed(False) + else: + def have_push_rules_changed_txn(txn): + sql = ( + "SELECT COUNT(stream_id) FROM push_rules_stream" + " WHERE user_id = ? AND ? < stream_id" + ) + txn.execute(sql, (user_id, last_id)) + count, = txn.fetchone() + return bool(count) + return self.runInteraction( + "have_push_rules_changed", have_push_rules_changed_txn + ) + @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1, inlineCallbacks=True) def bulk_get_push_rules(self, user_ids): @@ -124,6 +185,7 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(results) + @defer.inlineCallbacks def bulk_get_push_rules_for_room(self, event, context): state_group = context.state_group if not state_group: @@ -133,9 +195,11 @@ class PushRuleStore(SQLBaseStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - return self._bulk_get_push_rules_for_room( - event.room_id, state_group, context.current_state_ids, event=event + current_state_ids = yield context.get_current_state_ids(self) + result = yield self._bulk_get_push_rules_for_room( + event.room_id, state_group, current_state_ids, event=event ) + defer.returnValue(result) @cachedInlineCallbacks(num_args=2, cache_context=True) def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids, @@ -185,18 +249,6 @@ class PushRuleStore(SQLBaseStore): if uid in local_users_in_room: user_ids.add(uid) - forgotten = yield self.who_forgot_in_room( - event.room_id, on_invalidate=cache_context.invalidate, - ) - - for row in forgotten: - user_id = row["user_id"] - event_id = row["event_id"] - - mem_id = current_state_ids.get((EventTypes.Member, user_id), None) - if event_id == mem_id: - user_ids.discard(user_id) - rules_by_user = yield self.bulk_get_push_rules( user_ids, on_invalidate=cache_context.invalidate, ) @@ -228,6 +280,8 @@ class PushRuleStore(SQLBaseStore): results.setdefault(row['user_name'], {})[row['rule_id']] = enabled defer.returnValue(results) + +class PushRuleStore(PushRulesWorkerStore): @defer.inlineCallbacks def add_push_rule( self, user_id, rule_id, priority_class, conditions, actions, @@ -526,21 +580,8 @@ class PushRuleStore(SQLBaseStore): room stream ordering it corresponds to.""" return self._push_rules_stream_id_gen.get_current_token() - def have_push_rules_changed_for_user(self, user_id, last_id): - if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): - return defer.succeed(False) - else: - def have_push_rules_changed_txn(txn): - sql = ( - "SELECT COUNT(stream_id) FROM push_rules_stream" - " WHERE user_id = ? AND ? < stream_id" - ) - txn.execute(sql, (user_id, last_id)) - count, = txn.fetchone() - return bool(count) - return self.runInteraction( - "have_push_rules_changed", have_push_rules_changed_txn - ) + def get_max_push_rules_stream_id(self): + return self.get_push_rules_stream_token()[0] class RuleNotFoundException(Exception): diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 34d2f82b7f..8443bd4c1b 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,21 +14,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from twisted.internet import defer +import logging +import types + +from canonicaljson import encode_canonical_json, json -from canonicaljson import encode_canonical_json +from twisted.internet import defer from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList -import logging -import simplejson as json -import types +from ._base import SQLBaseStore logger = logging.getLogger(__name__) -class PusherStore(SQLBaseStore): +class PusherWorkerStore(SQLBaseStore): def _decode_pushers_rows(self, rows): for r in rows: dataJson = r['data'] @@ -102,9 +103,6 @@ class PusherStore(SQLBaseStore): rows = yield self.runInteraction("get_all_pushers", get_pushers) defer.returnValue(rows) - def get_pushers_stream_token(self): - return self._pushers_id_gen.get_current_token() - def get_all_updated_pushers(self, last_id, current_id, limit): if last_id == current_id: return defer.succeed(([], [])) @@ -198,56 +196,74 @@ class PusherStore(SQLBaseStore): defer.returnValue(result) + +class PusherStore(PusherWorkerStore): + def get_pushers_stream_token(self): + return self._pushers_id_gen.get_current_token() + @defer.inlineCallbacks def add_pusher(self, user_id, access_token, kind, app_id, app_display_name, device_display_name, pushkey, pushkey_ts, lang, data, last_stream_ordering, profile_tag=""): with self._pushers_id_gen.get_next() as stream_id: - def f(txn): - newly_inserted = self._simple_upsert_txn( - txn, - "pushers", - { - "app_id": app_id, - "pushkey": pushkey, - "user_name": user_id, - }, - { - "access_token": access_token, - "kind": kind, - "app_display_name": app_display_name, - "device_display_name": device_display_name, - "ts": pushkey_ts, - "lang": lang, - "data": encode_canonical_json(data), - "last_stream_ordering": last_stream_ordering, - "profile_tag": profile_tag, - "id": stream_id, - }, - ) - if newly_inserted: - # get_if_user_has_pusher only cares if the user has - # at least *one* pusher. - txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,)) + # no need to lock because `pushers` has a unique key on + # (app_id, pushkey, user_name) so _simple_upsert will retry + newly_inserted = yield self._simple_upsert( + table="pushers", + keyvalues={ + "app_id": app_id, + "pushkey": pushkey, + "user_name": user_id, + }, + values={ + "access_token": access_token, + "kind": kind, + "app_display_name": app_display_name, + "device_display_name": device_display_name, + "ts": pushkey_ts, + "lang": lang, + "data": encode_canonical_json(data), + "last_stream_ordering": last_stream_ordering, + "profile_tag": profile_tag, + "id": stream_id, + }, + desc="add_pusher", + lock=False, + ) - yield self.runInteraction("add_pusher", f) + if newly_inserted: + yield self.runInteraction( + "add_pusher", + self._invalidate_cache_and_stream, + self.get_if_user_has_pusher, (user_id,) + ) @defer.inlineCallbacks def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id): def delete_pusher_txn(txn, stream_id): - txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,)) + self._invalidate_cache_and_stream( + txn, self.get_if_user_has_pusher, (user_id,) + ) self._simple_delete_one_txn( txn, "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id} ) - self._simple_upsert_txn( + + # it's possible for us to end up with duplicate rows for + # (app_id, pushkey, user_id) at different stream_ids, but that + # doesn't really matter. + self._simple_insert_txn( txn, - "deleted_pushers", - {"app_id": app_id, "pushkey": pushkey, "user_id": user_id}, - {"stream_id": stream_id}, + table="deleted_pushers", + values={ + "stream_id": stream_id, + "app_id": app_id, + "pushkey": pushkey, + "user_id": user_id, + }, ) with self._pushers_id_gen.get_next() as stream_id: @@ -310,9 +326,12 @@ class PusherStore(SQLBaseStore): @defer.inlineCallbacks def set_throttle_params(self, pusher_id, room_id, params): + # no need to lock because `pusher_throttle` has a primary key on + # (pusher, room_id) so _simple_upsert will retry yield self._simple_upsert( "pusher_throttle", {"pusher": pusher_id, "room_id": room_id}, params, - desc="set_throttle_params" + desc="set_throttle_params", + lock=False, ) diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index f42b8014c7..0ac665e967 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,52 +14,52 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached -from synapse.util.caches.stream_change_cache import StreamChangeCache +import abc +import logging + +from canonicaljson import json from twisted.internet import defer -import logging -import ujson as json +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.caches.stream_change_cache import StreamChangeCache +from ._base import SQLBaseStore +from .util.id_generators import StreamIdGenerator logger = logging.getLogger(__name__) -class ReceiptsStore(SQLBaseStore): - def __init__(self, hs): - super(ReceiptsStore, self).__init__(hs) +class ReceiptsWorkerStore(SQLBaseStore): + """This is an abstract base class where subclasses must implement + `get_max_receipt_stream_id` which can be called in the initializer. + """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. + __metaclass__ = abc.ABCMeta + + def __init__(self, db_conn, hs): + super(ReceiptsWorkerStore, self).__init__(db_conn, hs) self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() + "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() ) + @abc.abstractmethod + def get_max_receipt_stream_id(self): + """Get the current max stream ID for receipts stream + + Returns: + int + """ + raise NotImplementedError() + @cachedInlineCallbacks() def get_users_with_read_receipts_in_room(self, room_id): receipts = yield self.get_receipts_for_room(room_id, "m.read") defer.returnValue(set(r['user_id'] for r in receipts)) - def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type, - user_id): - if receipt_type != "m.read": - return - - # Returns an ObservableDeferred - res = self.get_users_with_read_receipts_in_room.cache.get( - room_id, None, update_metrics=False, - ) - - if res: - if isinstance(res, defer.Deferred) and res.called: - res = res.result - if user_id in res: - # We'd only be adding to the set, so no point invalidating if the - # user is already there - return - - self.get_users_with_read_receipts_in_room.invalidate((room_id,)) - @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): return self._simple_select_list( @@ -139,7 +140,9 @@ class ReceiptsStore(SQLBaseStore): """ room_ids = set(room_ids) - if from_key: + if from_key is not None: + # Only ask the database about rooms where there have been new + # receipts added since `from_key` room_ids = yield self._receipts_stream_cache.get_entities_changed( room_ids, from_key ) @@ -150,7 +153,6 @@ class ReceiptsStore(SQLBaseStore): defer.returnValue([ev for res in results.values() for ev in res]) - @cachedInlineCallbacks(num_args=3, tree=True) def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): """Get receipts for a single room for sending to clients. @@ -161,7 +163,19 @@ class ReceiptsStore(SQLBaseStore): from the start. Returns: - list: A list of receipts. + Deferred[list]: A list of receipts. + """ + if from_key is not None: + # Check the cache first to see if any new receipts have been added + # since`from_key`. If not we can no-op. + if not self._receipts_stream_cache.has_entity_changed(room_id, from_key): + defer.succeed([]) + + return self._get_linearized_receipts_for_room(room_id, to_key, from_key) + + @cachedInlineCallbacks(num_args=3, tree=True) + def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): + """See get_linearized_receipts_for_room """ def f(txn): if from_key: @@ -210,7 +224,7 @@ class ReceiptsStore(SQLBaseStore): "content": content, }]) - @cachedList(cached_method_name="get_linearized_receipts_for_room", + @cachedList(cached_method_name="_get_linearized_receipts_for_room", list_name="room_ids", num_args=3, inlineCallbacks=True) def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): if not room_ids: @@ -270,11 +284,97 @@ class ReceiptsStore(SQLBaseStore): } defer.returnValue(results) + def get_all_updated_receipts(self, last_id, current_id, limit=None): + if last_id == current_id: + return defer.succeed([]) + + def get_all_updated_receipts_txn(txn): + sql = ( + "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" + " FROM receipts_linearized" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + ) + args = [last_id, current_id] + if limit is not None: + sql += " LIMIT ?" + args.append(limit) + txn.execute(sql, args) + + return txn.fetchall() + return self.runInteraction( + "get_all_updated_receipts", get_all_updated_receipts_txn + ) + + def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type, + user_id): + if receipt_type != "m.read": + return + + # Returns either an ObservableDeferred or the raw result + res = self.get_users_with_read_receipts_in_room.cache.get( + room_id, None, update_metrics=False, + ) + + # first handle the Deferred case + if isinstance(res, defer.Deferred): + if res.called: + res = res.result + else: + res = None + + if res and user_id in res: + # We'd only be adding to the set, so no point invalidating if the + # user is already there + return + + self.get_users_with_read_receipts_in_room.invalidate((room_id,)) + + +class ReceiptsStore(ReceiptsWorkerStore): + def __init__(self, db_conn, hs): + # We instantiate this first as the ReceiptsWorkerStore constructor + # needs to be able to call get_max_receipt_stream_id + self._receipts_id_gen = StreamIdGenerator( + db_conn, "receipts_linearized", "stream_id" + ) + + super(ReceiptsStore, self).__init__(db_conn, hs) + def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, user_id, event_id, data, stream_id): + res = self._simple_select_one_txn( + txn, + table="events", + retcols=["topological_ordering", "stream_ordering"], + keyvalues={"event_id": event_id}, + allow_none=True + ) + + stream_ordering = int(res["stream_ordering"]) if res else None + + # We don't want to clobber receipts for more recent events, so we + # have to compare orderings of existing receipts + if stream_ordering is not None: + sql = ( + "SELECT stream_ordering, event_id FROM events" + " INNER JOIN receipts_linearized as r USING (event_id, room_id)" + " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?" + ) + txn.execute(sql, (room_id, receipt_type, user_id)) + + for so, eid in txn: + if int(so) >= stream_ordering: + logger.debug( + "Ignoring new receipt for %s in favour of existing " + "one for later event %s", + event_id, eid, + ) + return False + txn.call_after( self.get_receipts_for_room.invalidate, (room_id, receipt_type) ) @@ -286,7 +386,7 @@ class ReceiptsStore(SQLBaseStore): self.get_receipts_for_user.invalidate, (user_id, receipt_type) ) # FIXME: This shouldn't invalidate the whole cache - txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,)) + txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,)) txn.call_after( self._receipts_stream_cache.entity_has_changed, @@ -298,34 +398,6 @@ class ReceiptsStore(SQLBaseStore): (user_id, room_id, receipt_type) ) - res = self._simple_select_one_txn( - txn, - table="events", - retcols=["topological_ordering", "stream_ordering"], - keyvalues={"event_id": event_id}, - allow_none=True - ) - - topological_ordering = int(res["topological_ordering"]) if res else None - stream_ordering = int(res["stream_ordering"]) if res else None - - # We don't want to clobber receipts for more recent events, so we - # have to compare orderings of existing receipts - sql = ( - "SELECT topological_ordering, stream_ordering, event_id FROM events" - " INNER JOIN receipts_linearized as r USING (event_id, room_id)" - " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?" - ) - - txn.execute(sql, (room_id, receipt_type, user_id)) - - if topological_ordering: - for to, so, _ in txn: - if int(to) > topological_ordering: - return False - elif int(to) == topological_ordering and int(so) >= stream_ordering: - return False - self._simple_delete_txn( txn, table="receipts_linearized", @@ -349,12 +421,11 @@ class ReceiptsStore(SQLBaseStore): } ) - if receipt_type == "m.read" and topological_ordering: + if receipt_type == "m.read" and stream_ordering is not None: self._remove_old_push_actions_before_txn( txn, room_id=room_id, user_id=user_id, - topological_ordering=topological_ordering, stream_ordering=stream_ordering, ) @@ -435,7 +506,7 @@ class ReceiptsStore(SQLBaseStore): self.get_receipts_for_user.invalidate, (user_id, receipt_type) ) # FIXME: This shouldn't invalidate the whole cache - txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,)) + txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,)) self._simple_delete_txn( txn, @@ -457,25 +528,3 @@ class ReceiptsStore(SQLBaseStore): "data": json.dumps(data), } ) - - def get_all_updated_receipts(self, last_id, current_id, limit=None): - if last_id == current_id: - return defer.succeed([]) - - def get_all_updated_receipts_txn(txn): - sql = ( - "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" - " FROM receipts_linearized" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" - ) - args = [last_id, current_id] - if limit is not None: - sql += " LIMIT ?" - args.append(limit) - txn.execute(sql, args) - - return txn.fetchall() - return self.runInteraction( - "get_all_updated_receipts", get_all_updated_receipts_txn - ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 20acd58fcf..07333f777d 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -15,17 +15,83 @@ import re +from six.moves import range + from twisted.internet import defer -from synapse.api.errors import StoreError, Codes +from synapse.api.errors import Codes, StoreError from synapse.storage import background_updates +from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -class RegistrationStore(background_updates.BackgroundUpdateStore): +class RegistrationWorkerStore(SQLBaseStore): + @cached() + def get_user_by_id(self, user_id): + return self._simple_select_one( + table="users", + keyvalues={ + "name": user_id, + }, + retcols=[ + "name", "password_hash", "is_guest", + "consent_version", "consent_server_notice_sent", + "appservice_id", + ], + allow_none=True, + desc="get_user_by_id", + ) + + @cached() + def get_user_by_access_token(self, token): + """Get a user from the given access token. + + Args: + token (str): The access token of a user. + Returns: + defer.Deferred: None, if the token did not match, otherwise dict + including the keys `name`, `is_guest`, `device_id`, `token_id`. + """ + return self.runInteraction( + "get_user_by_access_token", + self._query_for_auth, + token + ) + + @defer.inlineCallbacks + def is_server_admin(self, user): + res = yield self._simple_select_one_onecol( + table="users", + keyvalues={"name": user.to_string()}, + retcol="admin", + allow_none=True, + desc="is_server_admin", + ) + + defer.returnValue(res if res else False) + + def _query_for_auth(self, txn, token): + sql = ( + "SELECT users.name, users.is_guest, access_tokens.id as token_id," + " access_tokens.device_id" + " FROM users" + " INNER JOIN access_tokens on users.name = access_tokens.user_id" + " WHERE token = ?" + ) + + txn.execute(sql, (token,)) + rows = self.cursor_to_dict(txn) + if rows: + return rows[0] + + return None + + +class RegistrationStore(RegistrationWorkerStore, + background_updates.BackgroundUpdateStore): - def __init__(self, hs): - super(RegistrationStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(RegistrationStore, self).__init__(db_conn, hs) self.clock = hs.get_clock() @@ -37,12 +103,17 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): ) self.register_background_index_update( - "refresh_tokens_device_index", - index_name="refresh_tokens_device_id", - table="refresh_tokens", - columns=["user_id", "device_id"], + "users_creation_ts", + index_name="users_creation_ts", + table="users", + columns=["creation_ts"], ) + # we no longer use refresh tokens, but it's possible that some people + # might have a background update queued to build this index. Just + # clear the background update. + self.register_noop_background_update("refresh_tokens_device_index") + @defer.inlineCallbacks def add_access_token_to_user(self, user_id, token, device_id=None): """Adds an access token for the given user. @@ -177,9 +248,11 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): ) if create_profile_with_localpart: + # set a default displayname serverside to avoid ugly race + # between auto-joins and clients trying to set displaynames txn.execute( - "INSERT INTO profiles(user_id) VALUES (?)", - (create_profile_with_localpart,) + "INSERT INTO profiles(user_id, displayname) VALUES (?,?)", + (create_profile_with_localpart, create_profile_with_localpart) ) self._invalidate_cache_and_stream( @@ -187,18 +260,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): ) txn.call_after(self.is_guest.invalidate, (user_id,)) - @cached() - def get_user_by_id(self, user_id): - return self._simple_select_one( - table="users", - keyvalues={ - "name": user_id, - }, - retcols=["name", "password_hash", "is_guest"], - allow_none=True, - desc="get_user_by_id", - ) - def get_users_by_id_case_insensitive(self, user_id): """Gets users that match user_id case insensitively. Returns a mapping of user_id -> password_hash. @@ -236,12 +297,57 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): "user_set_password_hash", user_set_password_hash_txn ) - @defer.inlineCallbacks + def user_set_consent_version(self, user_id, consent_version): + """Updates the user table to record privacy policy consent + + Args: + user_id (str): full mxid of the user to update + consent_version (str): version of the policy the user has consented + to + + Raises: + StoreError(404) if user not found + """ + def f(txn): + self._simple_update_one_txn( + txn, + table='users', + keyvalues={'name': user_id, }, + updatevalues={'consent_version': consent_version, }, + ) + self._invalidate_cache_and_stream( + txn, self.get_user_by_id, (user_id,) + ) + return self.runInteraction("user_set_consent_version", f) + + def user_set_consent_server_notice_sent(self, user_id, consent_version): + """Updates the user table to record that we have sent the user a server + notice about privacy policy consent + + Args: + user_id (str): full mxid of the user to update + consent_version (str): version of the policy we have notified the + user about + + Raises: + StoreError(404) if user not found + """ + def f(txn): + self._simple_update_one_txn( + txn, + table='users', + keyvalues={'name': user_id, }, + updatevalues={'consent_server_notice_sent': consent_version, }, + ) + self._invalidate_cache_and_stream( + txn, self.get_user_by_id, (user_id,) + ) + return self.runInteraction("user_set_consent_server_notice_sent", f) + def user_delete_access_tokens(self, user_id, except_token_id=None, - device_id=None, - delete_refresh_tokens=False): + device_id=None): """ - Invalidate access/refresh tokens belonging to a user + Invalidate access tokens belonging to a user Args: user_id (str): ID of user the tokens belong to @@ -250,10 +356,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): device_id (str|None): ID of device the tokens are associated with. If None, tokens associated with any device (or no device) will be deleted - delete_refresh_tokens (bool): True to delete refresh tokens as - well as access tokens. Returns: - defer.Deferred: + defer.Deferred[list[str, int, str|None, int]]: a list of + (token, token id, device id) for each of the deleted tokens """ def f(txn): keyvalues = { @@ -262,13 +367,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): if device_id is not None: keyvalues["device_id"] = device_id - if delete_refresh_tokens: - self._simple_delete_txn( - txn, - table="refresh_tokens", - keyvalues=keyvalues, - ) - items = keyvalues.items() where_clause = " AND ".join(k + " = ?" for k, _ in items) values = [v for _, v in items] @@ -277,14 +375,14 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): values.append(except_token_id) txn.execute( - "SELECT token FROM access_tokens WHERE %s" % where_clause, + "SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause, values ) - rows = self.cursor_to_dict(txn) + tokens_and_devices = [(r[0], r[1], r[2]) for r in txn] - for row in rows: + for token, _, _ in tokens_and_devices: self._invalidate_cache_and_stream( - txn, self.get_user_by_access_token, (row["token"],) + txn, self.get_user_by_access_token, (token,) ) txn.execute( @@ -292,7 +390,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): values ) - yield self.runInteraction( + return tokens_and_devices + + return self.runInteraction( "user_delete_access_tokens", f, ) @@ -312,34 +412,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): return self.runInteraction("delete_access_token", f) - @cached() - def get_user_by_access_token(self, token): - """Get a user from the given access token. - - Args: - token (str): The access token of a user. - Returns: - defer.Deferred: None, if the token did not match, otherwise dict - including the keys `name`, `is_guest`, `device_id`, `token_id`. - """ - return self.runInteraction( - "get_user_by_access_token", - self._query_for_auth, - token - ) - - @defer.inlineCallbacks - def is_server_admin(self, user): - res = yield self._simple_select_one_onecol( - table="users", - keyvalues={"name": user.to_string()}, - retcol="admin", - allow_none=True, - desc="is_server_admin", - ) - - defer.returnValue(res if res else False) - @cachedInlineCallbacks() def is_guest(self, user_id): res = yield self._simple_select_one_onecol( @@ -352,22 +424,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): defer.returnValue(res if res else False) - def _query_for_auth(self, txn, token): - sql = ( - "SELECT users.name, users.is_guest, access_tokens.id as token_id," - " access_tokens.device_id" - " FROM users" - " INNER JOIN access_tokens on users.name = access_tokens.user_id" - " WHERE token = ?" - ) - - txn.execute(sql, (token,)) - rows = self.cursor_to_dict(txn) - if rows: - return rows[0] - - return None - @defer.inlineCallbacks def user_add_threepid(self, user_id, medium, address, validated_at, added_at): yield self._simple_upsert("user_threepids", { @@ -404,15 +460,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): defer.returnValue(ret['user_id']) defer.returnValue(None) - def user_delete_threepids(self, user_id): - return self._simple_delete( - "user_threepids", - keyvalues={ - "user_id": user_id, - }, - desc="user_delete_threepids", - ) - def user_delete_threepid(self, user_id, medium, address): return self._simple_delete( "user_threepids", @@ -437,6 +484,35 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): ret = yield self.runInteraction("count_users", _count_users) defer.returnValue(ret) + def count_daily_user_type(self): + """ + Counts 1) native non guest users + 2) native guests users + 3) bridged users + who registered on the homeserver in the past 24 hours + """ + def _count_daily_user_type(txn): + yesterday = int(self._clock.time()) - (60 * 60 * 24) + + sql = """ + SELECT user_type, COALESCE(count(*), 0) AS count FROM ( + SELECT + CASE + WHEN is_guest=0 AND appservice_id IS NULL THEN 'native' + WHEN is_guest=1 AND appservice_id IS NULL THEN 'guest' + WHEN is_guest=0 AND appservice_id IS NOT NULL THEN 'bridged' + END AS user_type + FROM users + WHERE creation_ts > ? + ) AS t GROUP BY user_type + """ + results = {'native': 0, 'guest': 0, 'bridged': 0} + txn.execute(sql, (yesterday,)) + for row in txn: + results[row[0]] = row[1] + return results + return self.runInteraction("count_daily_user_type", _count_daily_user_type) + @defer.inlineCallbacks def count_nonbridged_users(self): def _count_users(txn): @@ -464,18 +540,16 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): """ def _find_next_generated_user_id(txn): txn.execute("SELECT name FROM users") - rows = self.cursor_to_dict(txn) regex = re.compile("^@(\d+):") found = set() - for r in rows: - user_id = r["name"] + for user_id, in txn: match = regex.search(user_id) if match: found.add(int(match.group(1))) - for i in xrange(len(found) + 1): + for i in range(len(found) + 1): if i not in found: return i @@ -530,3 +604,44 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): except self.database_engine.module.IntegrityError: ret = yield self.get_3pid_guest_access_token(medium, address) defer.returnValue(ret) + + def add_user_pending_deactivation(self, user_id): + """ + Adds a user to the table of users who need to be parted from all the rooms they're + in + """ + return self._simple_insert( + "users_pending_deactivation", + values={ + "user_id": user_id, + }, + desc="add_user_pending_deactivation", + ) + + def del_user_pending_deactivation(self, user_id): + """ + Removes the given user to the table of users who need to be parted from all the + rooms they're in, effectively marking that user as fully deactivated. + """ + # XXX: This should be simple_delete_one but we failed to put a unique index on + # the table, so somehow duplicate entries have ended up in it. + return self._simple_delete( + "users_pending_deactivation", + keyvalues={ + "user_id": user_id, + }, + desc="del_user_pending_deactivation", + ) + + def get_user_pending_deactivation(self): + """ + Gets one user from the table of users waiting to be parted from all the rooms + they're in. + """ + return self._simple_select_one_onecol( + "users_pending_deactivation", + keyvalues={}, + retcol="user_id", + allow_none=True, + desc="get_users_pending_deactivation", + ) diff --git a/synapse/storage/rejections.py b/synapse/storage/rejections.py index 40acb5c4ed..880f047adb 100644 --- a/synapse/storage/rejections.py +++ b/synapse/storage/rejections.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore - import logging +from ._base import SQLBaseStore + logger = logging.getLogger(__name__) diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 23688430b7..3147fb6827 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -13,19 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import collections +import logging +import re + +from canonicaljson import json + from twisted.internet import defer from synapse.api.errors import StoreError +from synapse.storage._base import SQLBaseStore +from synapse.storage.search import SearchStore from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -from ._base import SQLBaseStore -from .engines import PostgresEngine, Sqlite3Engine - -import collections -import logging -import ujson as json -import re - logger = logging.getLogger(__name__) @@ -40,7 +40,138 @@ RatelimitOverride = collections.namedtuple( ) -class RoomStore(SQLBaseStore): +class RoomWorkerStore(SQLBaseStore): + def get_public_room_ids(self): + return self._simple_select_onecol( + table="rooms", + keyvalues={ + "is_public": True, + }, + retcol="room_id", + desc="get_public_room_ids", + ) + + @cached(num_args=2, max_entries=100) + def get_public_room_ids_at_stream_id(self, stream_id, network_tuple): + """Get pulbic rooms for a particular list, or across all lists. + + Args: + stream_id (int) + network_tuple (ThirdPartyInstanceID): The list to use (None, None) + means the main list, None means all lsits. + """ + return self.runInteraction( + "get_public_room_ids_at_stream_id", + self.get_public_room_ids_at_stream_id_txn, + stream_id, network_tuple=network_tuple + ) + + def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, + network_tuple): + return { + rm + for rm, vis in self.get_published_at_stream_id_txn( + txn, stream_id, network_tuple=network_tuple + ).items() + if vis + } + + def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple): + if network_tuple: + # We want to get from a particular list. No aggregation required. + + sql = (""" + SELECT room_id, visibility FROM public_room_list_stream + INNER JOIN ( + SELECT room_id, max(stream_id) AS stream_id + FROM public_room_list_stream + WHERE stream_id <= ? %s + GROUP BY room_id + ) grouped USING (room_id, stream_id) + """) + + if network_tuple.appservice_id is not None: + txn.execute( + sql % ("AND appservice_id = ? AND network_id = ?",), + (stream_id, network_tuple.appservice_id, network_tuple.network_id,) + ) + else: + txn.execute( + sql % ("AND appservice_id IS NULL",), + (stream_id,) + ) + return dict(txn) + else: + # We want to get from all lists, so we need to aggregate the results + + logger.info("Executing full list") + + sql = (""" + SELECT room_id, visibility + FROM public_room_list_stream + INNER JOIN ( + SELECT + room_id, max(stream_id) AS stream_id, appservice_id, + network_id + FROM public_room_list_stream + WHERE stream_id <= ? + GROUP BY room_id, appservice_id, network_id + ) grouped USING (room_id, stream_id) + """) + + txn.execute( + sql, + (stream_id,) + ) + + results = {} + # A room is visible if its visible on any list. + for room_id, visibility in txn: + results[room_id] = bool(visibility) or results.get(room_id, False) + + return results + + def get_public_room_changes(self, prev_stream_id, new_stream_id, + network_tuple): + def get_public_room_changes_txn(txn): + then_rooms = self.get_public_room_ids_at_stream_id_txn( + txn, prev_stream_id, network_tuple + ) + + now_rooms_dict = self.get_published_at_stream_id_txn( + txn, new_stream_id, network_tuple + ) + + now_rooms_visible = set( + rm for rm, vis in now_rooms_dict.items() if vis + ) + now_rooms_not_visible = set( + rm for rm, vis in now_rooms_dict.items() if not vis + ) + + newly_visible = now_rooms_visible - then_rooms + newly_unpublished = now_rooms_not_visible & then_rooms + + return newly_visible, newly_unpublished + + return self.runInteraction( + "get_public_room_changes", get_public_room_changes_txn + ) + + @cached(max_entries=10000) + def is_room_blocked(self, room_id): + return self._simple_select_one_onecol( + table="blocked_rooms", + keyvalues={ + "room_id": room_id, + }, + retcol="1", + allow_none=True, + desc="is_room_blocked", + ) + + +class RoomStore(RoomWorkerStore, SearchStore): @defer.inlineCallbacks def store_room(self, room_id, room_creator_user_id, is_public): @@ -227,16 +358,6 @@ class RoomStore(SQLBaseStore): ) self.hs.get_notifier().on_new_replication_data() - def get_public_room_ids(self): - return self._simple_select_onecol( - table="rooms", - keyvalues={ - "is_public": True, - }, - retcol="room_id", - desc="get_public_room_ids", - ) - def get_room_count(self): """Retrieve a list of all rooms """ @@ -263,8 +384,8 @@ class RoomStore(SQLBaseStore): }, ) - self._store_event_search_txn( - txn, event, "content.topic", event.content["topic"] + self.store_event_search_txn( + txn, event, "content.topic", event.content["topic"], ) def _store_room_name_txn(self, txn, event): @@ -279,14 +400,14 @@ class RoomStore(SQLBaseStore): } ) - self._store_event_search_txn( - txn, event, "content.name", event.content["name"] + self.store_event_search_txn( + txn, event, "content.name", event.content["name"], ) def _store_room_message_txn(self, txn, event): if hasattr(event, "content") and "body" in event.content: - self._store_event_search_txn( - txn, event, "content.body", event.content["body"] + self.store_event_search_txn( + txn, event, "content.body", event.content["body"], ) def _store_history_visibility_txn(self, txn, event): @@ -308,31 +429,6 @@ class RoomStore(SQLBaseStore): event.content[key] )) - def _store_event_search_txn(self, txn, event, key, value): - if isinstance(self.database_engine, PostgresEngine): - sql = ( - "INSERT INTO event_search" - " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" - " VALUES (?,?,?,to_tsvector('english', ?),?,?)" - ) - txn.execute( - sql, - ( - event.event_id, event.room_id, key, value, - event.internal_metadata.stream_ordering, - event.origin_server_ts, - ) - ) - elif isinstance(self.database_engine, Sqlite3Engine): - sql = ( - "INSERT INTO event_search (event_id, room_id, key, value)" - " VALUES (?,?,?,?)" - ) - txn.execute(sql, (event.event_id, event.room_id, key, value,)) - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") - def add_event_report(self, room_id, event_id, user_id, reason, content, received_ts): next_id = self._event_reports_id_gen.get_next() @@ -353,113 +449,6 @@ class RoomStore(SQLBaseStore): def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() - @cached(num_args=2, max_entries=100) - def get_public_room_ids_at_stream_id(self, stream_id, network_tuple): - """Get pulbic rooms for a particular list, or across all lists. - - Args: - stream_id (int) - network_tuple (ThirdPartyInstanceID): The list to use (None, None) - means the main list, None means all lsits. - """ - return self.runInteraction( - "get_public_room_ids_at_stream_id", - self.get_public_room_ids_at_stream_id_txn, - stream_id, network_tuple=network_tuple - ) - - def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, - network_tuple): - return { - rm - for rm, vis in self.get_published_at_stream_id_txn( - txn, stream_id, network_tuple=network_tuple - ).items() - if vis - } - - def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple): - if network_tuple: - # We want to get from a particular list. No aggregation required. - - sql = (""" - SELECT room_id, visibility FROM public_room_list_stream - INNER JOIN ( - SELECT room_id, max(stream_id) AS stream_id - FROM public_room_list_stream - WHERE stream_id <= ? %s - GROUP BY room_id - ) grouped USING (room_id, stream_id) - """) - - if network_tuple.appservice_id is not None: - txn.execute( - sql % ("AND appservice_id = ? AND network_id = ?",), - (stream_id, network_tuple.appservice_id, network_tuple.network_id,) - ) - else: - txn.execute( - sql % ("AND appservice_id IS NULL",), - (stream_id,) - ) - return dict(txn) - else: - # We want to get from all lists, so we need to aggregate the results - - logger.info("Executing full list") - - sql = (""" - SELECT room_id, visibility - FROM public_room_list_stream - INNER JOIN ( - SELECT - room_id, max(stream_id) AS stream_id, appservice_id, - network_id - FROM public_room_list_stream - WHERE stream_id <= ? - GROUP BY room_id, appservice_id, network_id - ) grouped USING (room_id, stream_id) - """) - - txn.execute( - sql, - (stream_id,) - ) - - results = {} - # A room is visible if its visible on any list. - for room_id, visibility in txn: - results[room_id] = bool(visibility) or results.get(room_id, False) - - return results - - def get_public_room_changes(self, prev_stream_id, new_stream_id, - network_tuple): - def get_public_room_changes_txn(txn): - then_rooms = self.get_public_room_ids_at_stream_id_txn( - txn, prev_stream_id, network_tuple - ) - - now_rooms_dict = self.get_published_at_stream_id_txn( - txn, new_stream_id, network_tuple - ) - - now_rooms_visible = set( - rm for rm, vis in now_rooms_dict.items() if vis - ) - now_rooms_not_visible = set( - rm for rm, vis in now_rooms_dict.items() if not vis - ) - - newly_visible = now_rooms_visible - then_rooms - newly_unpublished = now_rooms_not_visible & then_rooms - - return newly_visible, newly_unpublished - - return self.runInteraction( - "get_public_room_changes", get_public_room_changes_txn - ) - def get_all_new_public_rooms(self, prev_id, current_id, limit): def get_all_new_public_rooms(txn): sql = (""" @@ -509,18 +498,6 @@ class RoomStore(SQLBaseStore): else: defer.returnValue(None) - @cached(max_entries=10000) - def is_room_blocked(self, room_id): - return self._simple_select_one_onecol( - table="blocked_rooms", - keyvalues={ - "room_id": room_id, - }, - retcol="1", - allow_none=True, - desc="is_room_blocked", - ) - @defer.inlineCallbacks def block_room(self, room_id, user_id): yield self._simple_insert( @@ -531,75 +508,121 @@ class RoomStore(SQLBaseStore): }, desc="block_room", ) - self.is_room_blocked.invalidate((room_id,)) + yield self.runInteraction( + "block_room_invalidation", + self._invalidate_cache_and_stream, + self.is_room_blocked, (room_id,), + ) + + def get_media_mxcs_in_room(self, room_id): + """Retrieves all the local and remote media MXC URIs in a given room + + Args: + room_id (str) + + Returns: + The local and remote media as a lists of tuples where the key is + the hostname and the value is the media ID. + """ + def _get_media_mxcs_in_room_txn(txn): + local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) + local_media_mxcs = [] + remote_media_mxcs = [] + + # Convert the IDs to MXC URIs + for media_id in local_mxcs: + local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id)) + for hostname, media_id in remote_mxcs: + remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id)) + + return local_media_mxcs, remote_media_mxcs + return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn) def quarantine_media_ids_in_room(self, room_id, quarantined_by): """For a room loops through all events with media and quarantines the associated media """ - def _get_media_ids_in_room(txn): - mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") + def _quarantine_media_in_room_txn(txn): + local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) + total_media_quarantined = 0 - next_token = self.get_current_events_token() + 1 + # Now update all the tables to set the quarantined_by flag - total_media_quarantined = 0 + txn.executemany(""" + UPDATE local_media_repository + SET quarantined_by = ? + WHERE media_id = ? + """, ((quarantined_by, media_id) for media_id in local_mxcs)) - while next_token: - sql = """ - SELECT stream_ordering, content FROM events - WHERE room_id = ? - AND stream_ordering < ? - AND contains_url = ? AND outlier = ? - ORDER BY stream_ordering DESC - LIMIT ? + txn.executemany( """ - txn.execute(sql, (room_id, next_token, True, False, 100)) - - next_token = None - local_media_mxcs = [] - remote_media_mxcs = [] - for stream_ordering, content_json in txn: - next_token = stream_ordering - content = json.loads(content_json) - - content_url = content.get("url") - thumbnail_url = content.get("info", {}).get("thumbnail_url") - - for url in (content_url, thumbnail_url): - if not url: - continue - matches = mxc_re.match(url) - if matches: - hostname = matches.group(1) - media_id = matches.group(2) - if hostname == self.hostname: - local_media_mxcs.append(media_id) - else: - remote_media_mxcs.append((hostname, media_id)) - - # Now update all the tables to set the quarantined_by flag - - txn.executemany(""" - UPDATE local_media_repository + UPDATE remote_media_cache SET quarantined_by = ? - WHERE media_id = ? - """, ((quarantined_by, media_id) for media_id in local_media_mxcs)) - - txn.executemany( - """ - UPDATE remote_media_cache - SET quarantined_by = ? - WHERE media_origin AND media_id = ? - """, - ( - (quarantined_by, origin, media_id) - for origin, media_id in remote_media_mxcs - ) + WHERE media_origin = ? AND media_id = ? + """, + ( + (quarantined_by, origin, media_id) + for origin, media_id in remote_mxcs ) + ) - total_media_quarantined += len(local_media_mxcs) - total_media_quarantined += len(remote_media_mxcs) + total_media_quarantined += len(local_mxcs) + total_media_quarantined += len(remote_mxcs) return total_media_quarantined - return self.runInteraction("get_media_ids_in_room", _get_media_ids_in_room) + return self.runInteraction( + "quarantine_media_in_room", + _quarantine_media_in_room_txn, + ) + + def _get_media_mxcs_in_room_txn(self, txn, room_id): + """Retrieves all the local and remote media MXC URIs in a given room + + Args: + txn (cursor) + room_id (str) + + Returns: + The local and remote media as a lists of tuples where the key is + the hostname and the value is the media ID. + """ + mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") + + next_token = self.get_current_events_token() + 1 + local_media_mxcs = [] + remote_media_mxcs = [] + + while next_token: + sql = """ + SELECT stream_ordering, json FROM events + JOIN event_json USING (room_id, event_id) + WHERE room_id = ? + AND stream_ordering < ? + AND contains_url = ? AND outlier = ? + ORDER BY stream_ordering DESC + LIMIT ? + """ + txn.execute(sql, (room_id, next_token, True, False, 100)) + + next_token = None + for stream_ordering, content_json in txn: + next_token = stream_ordering + event_json = json.loads(content_json) + content = event_json["content"] + content_url = content.get("url") + thumbnail_url = content.get("info", {}).get("thumbnail_url") + + for url in (content_url, thumbnail_url): + if not url: + continue + matches = mxc_re.match(url) + if matches: + hostname = matches.group(1) + media_id = matches.group(2) + if hostname == self.hs.hostname: + local_media_mxcs.append(media_id) + else: + remote_media_mxcs.append((hostname, media_id)) + + return local_media_mxcs, remote_media_mxcs diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 457ca288d0..01697ab2c9 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,22 +14,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - +import logging from collections import namedtuple -from ._base import SQLBaseStore +from six import iteritems, itervalues + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.storage.events import EventsWorkerStore +from synapse.types import get_domain_from_id from synapse.util.async import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.stringutils import to_ascii -from synapse.api.constants import Membership, EventTypes -from synapse.types import get_domain_from_id - -import logging -import ujson as json - logger = logging.getLogger(__name__) @@ -37,6 +39,11 @@ RoomsForUser = namedtuple( ("room_id", "sender", "membership", "event_id", "stream_ordering") ) +GetRoomsForUserWithStreamOrdering = namedtuple( + "_GetRoomsForUserWithStreamOrdering", + ("room_id", "stream_ordering",) +) + # We store this using a namedtuple so that we save about 3x space over using a # dict. @@ -48,97 +55,7 @@ ProfileInfo = namedtuple( _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" -class RoomMemberStore(SQLBaseStore): - def __init__(self, hs): - super(RoomMemberStore, self).__init__(hs) - self.register_background_update_handler( - _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile - ) - - def _store_room_members_txn(self, txn, events, backfilled): - """Store a room member in the database. - """ - self._simple_insert_many_txn( - txn, - table="room_memberships", - values=[ - { - "event_id": event.event_id, - "user_id": event.state_key, - "sender": event.user_id, - "room_id": event.room_id, - "membership": event.membership, - "display_name": event.content.get("displayname", None), - "avatar_url": event.content.get("avatar_url", None), - } - for event in events - ] - ) - - for event in events: - txn.call_after( - self._membership_stream_cache.entity_has_changed, - event.state_key, event.internal_metadata.stream_ordering - ) - txn.call_after( - self.get_invited_rooms_for_user.invalidate, (event.state_key,) - ) - - # We update the local_invites table only if the event is "current", - # i.e., its something that has just happened. - # The only current event that can also be an outlier is if its an - # invite that has come in across federation. - is_new_state = not backfilled and ( - not event.internal_metadata.is_outlier() - or event.internal_metadata.is_invite_from_remote() - ) - is_mine = self.hs.is_mine_id(event.state_key) - if is_new_state and is_mine: - if event.membership == Membership.INVITE: - self._simple_insert_txn( - txn, - table="local_invites", - values={ - "event_id": event.event_id, - "invitee": event.state_key, - "inviter": event.sender, - "room_id": event.room_id, - "stream_id": event.internal_metadata.stream_ordering, - } - ) - else: - sql = ( - "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" - " room_id = ? AND invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) - - txn.execute(sql, ( - event.internal_metadata.stream_ordering, - event.event_id, - event.room_id, - event.state_key, - )) - - @defer.inlineCallbacks - def locally_reject_invite(self, user_id, room_id): - sql = ( - "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" - " room_id = ? AND invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) - - def f(txn, stream_ordering): - txn.execute(sql, ( - stream_ordering, - True, - room_id, - user_id, - )) - - with self._stream_id_gen.get_next() as stream_ordering: - yield self.runInteraction("locally_reject_invite", f, stream_ordering) - +class RoomMemberWorkerStore(EventsWorkerStore): @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True) def get_hosts_in_room(self, room_id, cache_context): """Returns the set of all hosts currently in the room @@ -270,12 +187,32 @@ class RoomMemberStore(SQLBaseStore): return results @cachedInlineCallbacks(max_entries=500000, iterable=True) - def get_rooms_for_user(self, user_id): + def get_rooms_for_user_with_stream_ordering(self, user_id): """Returns a set of room_ids the user is currently joined to + + Args: + user_id (str) + + Returns: + Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns + the rooms the user is in currently, along with the stream ordering + of the most recent join for that user and room. """ rooms = yield self.get_rooms_for_user_where_membership_is( user_id, membership_list=[Membership.JOIN], ) + defer.returnValue(frozenset( + GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering) + for r in rooms + )) + + @defer.inlineCallbacks + def get_rooms_for_user(self, user_id, on_invalidate=None): + """Returns a set of room_ids the user is currently joined to + """ + rooms = yield self.get_rooms_for_user_with_stream_ordering( + user_id, on_invalidate=on_invalidate, + ) defer.returnValue(frozenset(r.room_id for r in rooms)) @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) @@ -295,89 +232,7 @@ class RoomMemberStore(SQLBaseStore): defer.returnValue(user_who_share_room) - def forget(self, user_id, room_id): - """Indicate that user_id wishes to discard history for room_id.""" - def f(txn): - sql = ( - "UPDATE" - " room_memberships" - " SET" - " forgotten = 1" - " WHERE" - " user_id = ?" - " AND" - " room_id = ?" - ) - txn.execute(sql, (user_id, room_id)) - - txn.call_after(self.was_forgotten_at.invalidate_all) - txn.call_after(self.did_forget.invalidate, (user_id, room_id)) - self._invalidate_cache_and_stream( - txn, self.who_forgot_in_room, (room_id,) - ) - return self.runInteraction("forget_membership", f) - - @cachedInlineCallbacks(num_args=2) - def did_forget(self, user_id, room_id): - """Returns whether user_id has elected to discard history for room_id. - - Returns False if they have since re-joined.""" - def f(txn): - sql = ( - "SELECT" - " COUNT(*)" - " FROM" - " room_memberships" - " WHERE" - " user_id = ?" - " AND" - " room_id = ?" - " AND" - " forgotten = 0" - ) - txn.execute(sql, (user_id, room_id)) - rows = txn.fetchall() - return rows[0][0] - count = yield self.runInteraction("did_forget_membership", f) - defer.returnValue(count == 0) - - @cachedInlineCallbacks(num_args=3) - def was_forgotten_at(self, user_id, room_id, event_id): - """Returns whether user_id has elected to discard history for room_id at - event_id. - - event_id must be a membership event.""" - def f(txn): - sql = ( - "SELECT" - " forgotten" - " FROM" - " room_memberships" - " WHERE" - " user_id = ?" - " AND" - " room_id = ?" - " AND" - " event_id = ?" - ) - txn.execute(sql, (user_id, room_id, event_id)) - rows = txn.fetchall() - return rows[0][0] - forgot = yield self.runInteraction("did_forget_membership_at", f) - defer.returnValue(forgot == 1) - - @cached() - def who_forgot_in_room(self, room_id): - return self._simple_select_list( - table="room_memberships", - retcols=("user_id", "event_id"), - keyvalues={ - "room_id": room_id, - "forgotten": 1, - }, - desc="who_forgot" - ) - + @defer.inlineCallbacks def get_joined_users_from_context(self, event, context): state_group = context.state_group if not state_group: @@ -387,11 +242,13 @@ class RoomMemberStore(SQLBaseStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - return self._get_joined_users_from_context( - event.room_id, state_group, context.current_state_ids, + current_state_ids = yield context.get_current_state_ids(self) + result = yield self._get_joined_users_from_context( + event.room_id, state_group, current_state_ids, event=event, context=context, ) + defer.returnValue(result) def get_joined_users_from_state(self, room_id, state_entry): state_group = state_entry.state_group @@ -419,7 +276,7 @@ class RoomMemberStore(SQLBaseStore): users_in_room = {} member_event_ids = [ e_id - for key, e_id in current_state_ids.iteritems() + for key, e_id in iteritems(current_state_ids) if key[0] == EventTypes.Member ] @@ -436,7 +293,7 @@ class RoomMemberStore(SQLBaseStore): users_in_room = dict(prev_res) member_event_ids = [ e_id - for key, e_id in context.delta_ids.iteritems() + for key, e_id in iteritems(context.delta_ids) if key[0] == EventTypes.Member ] for etype, state_key in context.delta_ids: @@ -533,6 +390,46 @@ class RoomMemberStore(SQLBaseStore): defer.returnValue(True) + @cachedInlineCallbacks() + def was_host_joined(self, room_id, host): + """Check whether the server is or ever was in the room. + + Args: + room_id (str) + host (str) + + Returns: + Deferred: Resolves to True if the host is/was in the room, otherwise + False. + """ + if '%' in host or '_' in host: + raise Exception("Invalid host name") + + sql = """ + SELECT user_id FROM room_memberships + WHERE room_id = ? + AND user_id LIKE ? + AND membership = 'join' + LIMIT 1 + """ + + # We do need to be careful to ensure that host doesn't have any wild cards + # in it, but we checked above for known ones and we'll check below that + # the returned user actually has the correct domain. + like_clause = "%:" + host + + rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause) + + if not rows: + defer.returnValue(False) + + user_id = rows[0][0] + if get_domain_from_id(user_id) != host: + # This can only happen if the host name has something funky in it + raise Exception("Invalid host name") + + defer.returnValue(True) + def get_joined_hosts(self, room_id, state_entry): state_group = state_entry.state_group if not state_group: @@ -560,6 +457,144 @@ class RoomMemberStore(SQLBaseStore): defer.returnValue(joined_hosts) + @cached(max_entries=10000) + def _get_joined_hosts_cache(self, room_id): + return _JoinedHostsCache(self, room_id) + + +class RoomMemberStore(RoomMemberWorkerStore): + def __init__(self, db_conn, hs): + super(RoomMemberStore, self).__init__(db_conn, hs) + self.register_background_update_handler( + _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile + ) + + def _store_room_members_txn(self, txn, events, backfilled): + """Store a room member in the database. + """ + self._simple_insert_many_txn( + txn, + table="room_memberships", + values=[ + { + "event_id": event.event_id, + "user_id": event.state_key, + "sender": event.user_id, + "room_id": event.room_id, + "membership": event.membership, + "display_name": event.content.get("displayname", None), + "avatar_url": event.content.get("avatar_url", None), + } + for event in events + ] + ) + + for event in events: + txn.call_after( + self._membership_stream_cache.entity_has_changed, + event.state_key, event.internal_metadata.stream_ordering + ) + txn.call_after( + self.get_invited_rooms_for_user.invalidate, (event.state_key,) + ) + + # We update the local_invites table only if the event is "current", + # i.e., its something that has just happened. + # The only current event that can also be an outlier is if its an + # invite that has come in across federation. + is_new_state = not backfilled and ( + not event.internal_metadata.is_outlier() + or event.internal_metadata.is_invite_from_remote() + ) + is_mine = self.hs.is_mine_id(event.state_key) + if is_new_state and is_mine: + if event.membership == Membership.INVITE: + self._simple_insert_txn( + txn, + table="local_invites", + values={ + "event_id": event.event_id, + "invitee": event.state_key, + "inviter": event.sender, + "room_id": event.room_id, + "stream_id": event.internal_metadata.stream_ordering, + } + ) + else: + sql = ( + "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + txn.execute(sql, ( + event.internal_metadata.stream_ordering, + event.event_id, + event.room_id, + event.state_key, + )) + + @defer.inlineCallbacks + def locally_reject_invite(self, user_id, room_id): + sql = ( + "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + def f(txn, stream_ordering): + txn.execute(sql, ( + stream_ordering, + True, + room_id, + user_id, + )) + + with self._stream_id_gen.get_next() as stream_ordering: + yield self.runInteraction("locally_reject_invite", f, stream_ordering) + + def forget(self, user_id, room_id): + """Indicate that user_id wishes to discard history for room_id.""" + def f(txn): + sql = ( + "UPDATE" + " room_memberships" + " SET" + " forgotten = 1" + " WHERE" + " user_id = ?" + " AND" + " room_id = ?" + ) + txn.execute(sql, (user_id, room_id)) + + txn.call_after(self.did_forget.invalidate, (user_id, room_id)) + return self.runInteraction("forget_membership", f) + + @cachedInlineCallbacks(num_args=2) + def did_forget(self, user_id, room_id): + """Returns whether user_id has elected to discard history for room_id. + + Returns False if they have since re-joined.""" + def f(txn): + sql = ( + "SELECT" + " COUNT(*)" + " FROM" + " room_memberships" + " WHERE" + " user_id = ?" + " AND" + " room_id = ?" + " AND" + " forgotten = 0" + ) + txn.execute(sql, (user_id, room_id)) + rows = txn.fetchall() + return rows[0][0] + count = yield self.runInteraction("did_forget_membership", f) + defer.returnValue(count == 0) + @defer.inlineCallbacks def _background_add_membership_profile(self, progress, batch_size): target_min_stream_id = progress.get( @@ -573,8 +608,9 @@ class RoomMemberStore(SQLBaseStore): def add_membership_profile_txn(txn): sql = (""" - SELECT stream_ordering, event_id, events.room_id, content + SELECT stream_ordering, event_id, events.room_id, event_json.json FROM events + INNER JOIN event_json USING (event_id) INNER JOIN room_memberships USING (event_id) WHERE ? <= stream_ordering AND stream_ordering < ? AND type = 'm.room.member' @@ -595,8 +631,9 @@ class RoomMemberStore(SQLBaseStore): event_id = row["event_id"] room_id = row["room_id"] try: - content = json.loads(row["content"]) - except: + event_json = json.loads(row["json"]) + content = event_json['content'] + except Exception: continue display_name = content.get("displayname", None) @@ -635,10 +672,6 @@ class RoomMemberStore(SQLBaseStore): defer.returnValue(result) - @cached(max_entries=10000, iterable=True) - def _get_joined_hosts_cache(self, room_id): - return _JoinedHostsCache(self, room_id) - class _JoinedHostsCache(object): """Cache for joined hosts in a room that is optimised to handle updates @@ -671,7 +704,7 @@ class _JoinedHostsCache(object): if state_entry.state_group == self.state_group: pass elif state_entry.prev_group == self.state_group: - for (typ, state_key), event_id in state_entry.delta_ids.iteritems(): + for (typ, state_key), event_id in iteritems(state_entry.delta_ids): if typ != EventTypes.Member: continue @@ -701,7 +734,7 @@ class _JoinedHostsCache(object): self.state_group = state_entry.state_group else: self.state_group = object() - self._len = sum(len(v) for v in self.hosts_to_joined_users.itervalues()) + self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users)) defer.returnValue(frozenset(self.hosts_to_joined_users)) def __len__(self): diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py index 8755bb2e49..4d725b92fe 100644 --- a/synapse/storage/schema/delta/14/upgrade_appservice_db.py +++ b/synapse/storage/schema/delta/14/upgrade_appservice_db.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import logging +import simplejson as json + logger = logging.getLogger(__name__) diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py index 4269ac69ad..4b2ffd35fd 100644 --- a/synapse/storage/schema/delta/25/fts.py +++ b/synapse/storage/schema/delta/25/fts.py @@ -14,10 +14,10 @@ import logging -from synapse.storage.prepare_database import get_statements -from synapse.storage.engines import PostgresEngine, Sqlite3Engine +import simplejson -import ujson +from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.prepare_database import get_statements logger = logging.getLogger(__name__) @@ -66,7 +66,7 @@ def run_create(cur, database_engine, *args, **kwargs): "max_stream_id_exclusive": max_stream_id + 1, "rows_inserted": 0, } - progress_json = ujson.dumps(progress) + progress_json = simplejson.dumps(progress) sql = ( "INSERT into background_updates (update_name, progress_json)" diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/schema/delta/27/ts.py index 71b12a2731..414f9f5aa0 100644 --- a/synapse/storage/schema/delta/27/ts.py +++ b/synapse/storage/schema/delta/27/ts.py @@ -14,9 +14,9 @@ import logging -from synapse.storage.prepare_database import get_statements +import simplejson -import ujson +from synapse.storage.prepare_database import get_statements logger = logging.getLogger(__name__) @@ -45,7 +45,7 @@ def run_create(cur, database_engine, *args, **kwargs): "max_stream_id_exclusive": max_stream_id + 1, "rows_inserted": 0, } - progress_json = ujson.dumps(progress) + progress_json = simplejson.dumps(progress) sql = ( "INSERT into background_updates (update_name, progress_json)" diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py index 5b7d8d1ab5..ef7ec34346 100644 --- a/synapse/storage/schema/delta/30/as_users.py +++ b/synapse/storage/schema/delta/30/as_users.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from synapse.config.appservice import load_appservices +from six.moves import range + +from synapse.config.appservice import load_appservices logger = logging.getLogger(__name__) @@ -22,7 +24,7 @@ def run_create(cur, database_engine, *args, **kwargs): # NULL indicates user was not registered by an appservice. try: cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT") - except: + except Exception: # Maybe we already added the column? Hope so... pass @@ -58,7 +60,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): for as_id, user_ids in owned.items(): n = 100 - user_chunks = (user_ids[i:i + 100] for i in xrange(0, len(user_ids), n)) + user_chunks = (user_ids[i:i + 100] for i in range(0, len(user_ids), n)) for chunk in user_chunks: cur.execute( database_engine.convert_param_style( diff --git a/synapse/storage/schema/delta/31/search_update.py b/synapse/storage/schema/delta/31/search_update.py index 470ae0c005..7d8ca5f93f 100644 --- a/synapse/storage/schema/delta/31/search_update.py +++ b/synapse/storage/schema/delta/31/search_update.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + +import simplejson + from synapse.storage.engines import PostgresEngine from synapse.storage.prepare_database import get_statements -import logging -import ujson - logger = logging.getLogger(__name__) @@ -49,7 +50,7 @@ def run_create(cur, database_engine, *args, **kwargs): "rows_inserted": 0, "have_added_indexes": False, } - progress_json = ujson.dumps(progress) + progress_json = simplejson.dumps(progress) sql = ( "INSERT into background_updates (update_name, progress_json)" diff --git a/synapse/storage/schema/delta/33/event_fields.py b/synapse/storage/schema/delta/33/event_fields.py index 83066cccc9..bff1256a7b 100644 --- a/synapse/storage/schema/delta/33/event_fields.py +++ b/synapse/storage/schema/delta/33/event_fields.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.prepare_database import get_statements - import logging -import ujson + +import simplejson + +from synapse.storage.prepare_database import get_statements logger = logging.getLogger(__name__) @@ -44,7 +45,7 @@ def run_create(cur, database_engine, *args, **kwargs): "max_stream_id_exclusive": max_stream_id + 1, "rows_inserted": 0, } - progress_json = ujson.dumps(progress) + progress_json = simplejson.dumps(progress) sql = ( "INSERT into background_updates (update_name, progress_json)" diff --git a/synapse/storage/schema/delta/33/remote_media_ts.py b/synapse/storage/schema/delta/33/remote_media_ts.py index 55ae43f395..9754d3ccfb 100644 --- a/synapse/storage/schema/delta/33/remote_media_ts.py +++ b/synapse/storage/schema/delta/33/remote_media_ts.py @@ -14,7 +14,6 @@ import time - ALTER_TABLE = "ALTER TABLE remote_media_cache ADD COLUMN last_access_ts BIGINT" diff --git a/synapse/storage/schema/delta/34/cache_stream.py b/synapse/storage/schema/delta/34/cache_stream.py index 3b63a1562d..cf09e43e2b 100644 --- a/synapse/storage/schema/delta/34/cache_stream.py +++ b/synapse/storage/schema/delta/34/cache_stream.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.prepare_database import get_statements -from synapse.storage.engines import PostgresEngine - import logging +from synapse.storage.engines import PostgresEngine +from synapse.storage.prepare_database import get_statements + logger = logging.getLogger(__name__) diff --git a/synapse/storage/schema/delta/34/received_txn_purge.py b/synapse/storage/schema/delta/34/received_txn_purge.py index 033144341c..67d505e68b 100644 --- a/synapse/storage/schema/delta/34/received_txn_purge.py +++ b/synapse/storage/schema/delta/34/received_txn_purge.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.engines import PostgresEngine - import logging +from synapse.storage.engines import PostgresEngine + logger = logging.getLogger(__name__) diff --git a/synapse/storage/schema/delta/34/sent_txn_purge.py b/synapse/storage/schema/delta/34/sent_txn_purge.py index 81948e3431..0ffab10b6f 100644 --- a/synapse/storage/schema/delta/34/sent_txn_purge.py +++ b/synapse/storage/schema/delta/34/sent_txn_purge.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.engines import PostgresEngine - import logging +from synapse.storage.engines import PostgresEngine + logger = logging.getLogger(__name__) diff --git a/synapse/storage/schema/delta/37/remove_auth_idx.py b/synapse/storage/schema/delta/37/remove_auth_idx.py index 20ad8bd5a6..a377884169 100644 --- a/synapse/storage/schema/delta/37/remove_auth_idx.py +++ b/synapse/storage/schema/delta/37/remove_auth_idx.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.prepare_database import get_statements -from synapse.storage.engines import PostgresEngine - import logging +from synapse.storage.engines import PostgresEngine +from synapse.storage.prepare_database import get_statements + logger = logging.getLogger(__name__) DROP_INDICES = """ diff --git a/synapse/storage/schema/delta/38/postgres_fts_gist.sql b/synapse/storage/schema/delta/38/postgres_fts_gist.sql index f090a7b75a..515e6b8e84 100644 --- a/synapse/storage/schema/delta/38/postgres_fts_gist.sql +++ b/synapse/storage/schema/delta/38/postgres_fts_gist.sql @@ -13,5 +13,7 @@ * limitations under the License. */ - INSERT into background_updates (update_name, progress_json) - VALUES ('event_search_postgres_gist', '{}'); +-- We no longer do this given we back it out again in schema 47 + +-- INSERT into background_updates (update_name, progress_json) +-- VALUES ('event_search_postgres_gist', '{}'); diff --git a/synapse/storage/schema/delta/42/user_dir.py b/synapse/storage/schema/delta/42/user_dir.py index ea6a18196d..506f326f4d 100644 --- a/synapse/storage/schema/delta/42/user_dir.py +++ b/synapse/storage/schema/delta/42/user_dir.py @@ -14,8 +14,8 @@ import logging -from synapse.storage.prepare_database import get_statements from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.prepare_database import get_statements logger = logging.getLogger(__name__) diff --git a/synapse/storage/schema/delta/43/user_share.sql b/synapse/storage/schema/delta/43/user_share.sql index 4501d90cbb..ee7062abe4 100644 --- a/synapse/storage/schema/delta/43/user_share.sql +++ b/synapse/storage/schema/delta/43/user_share.sql @@ -29,5 +29,5 @@ CREATE INDEX users_who_share_rooms_r_idx ON users_who_share_rooms(room_id); CREATE INDEX users_who_share_rooms_o_idx ON users_who_share_rooms(other_user_id); --- Make sure that we popualte the table initially +-- Make sure that we populate the table initially UPDATE user_directory_stream_pos SET stream_id = NULL; diff --git a/synapse/storage/schema/delta/44/expire_url_cache.sql b/synapse/storage/schema/delta/44/expire_url_cache.sql new file mode 100644 index 0000000000..b12f9b2ebf --- /dev/null +++ b/synapse/storage/schema/delta/44/expire_url_cache.sql @@ -0,0 +1,41 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- this didn't work on SQLite 3.7 (because of lack of partial indexes), so was +-- removed and replaced with 46/local_media_repository_url_idx.sql. +-- +-- CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL; + +-- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support +-- indices on expressions until 3.9. +CREATE TABLE local_media_repository_url_cache_new( + url TEXT, + response_code INTEGER, + etag TEXT, + expires_ts BIGINT, + og TEXT, + media_id TEXT, + download_ts BIGINT +); + +INSERT INTO local_media_repository_url_cache_new + SELECT url, response_code, etag, expires + download_ts, og, media_id, download_ts FROM local_media_repository_url_cache; + +DROP TABLE local_media_repository_url_cache; +ALTER TABLE local_media_repository_url_cache_new RENAME TO local_media_repository_url_cache; + +CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts); +CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts); +CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id); diff --git a/synapse/storage/schema/delta/45/group_server.sql b/synapse/storage/schema/delta/45/group_server.sql new file mode 100644 index 0000000000..b2333848a0 --- /dev/null +++ b/synapse/storage/schema/delta/45/group_server.sql @@ -0,0 +1,167 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE groups ( + group_id TEXT NOT NULL, + name TEXT, -- the display name of the room + avatar_url TEXT, + short_description TEXT, + long_description TEXT +); + +CREATE UNIQUE INDEX groups_idx ON groups(group_id); + + +-- list of users the group server thinks are joined +CREATE TABLE group_users ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + is_admin BOOLEAN NOT NULL, + is_public BOOLEAN NOT NULL -- whether the users membership can be seen by everyone +); + + +CREATE INDEX groups_users_g_idx ON group_users(group_id, user_id); +CREATE INDEX groups_users_u_idx ON group_users(user_id); + +-- list of users the group server thinks are invited +CREATE TABLE group_invites ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL +); + +CREATE INDEX groups_invites_g_idx ON group_invites(group_id, user_id); +CREATE INDEX groups_invites_u_idx ON group_invites(user_id); + + +CREATE TABLE group_rooms ( + group_id TEXT NOT NULL, + room_id TEXT NOT NULL, + is_public BOOLEAN NOT NULL -- whether the room can be seen by everyone +); + +CREATE UNIQUE INDEX groups_rooms_g_idx ON group_rooms(group_id, room_id); +CREATE INDEX groups_rooms_r_idx ON group_rooms(room_id); + + +-- Rooms to include in the summary +CREATE TABLE group_summary_rooms ( + group_id TEXT NOT NULL, + room_id TEXT NOT NULL, + category_id TEXT NOT NULL, + room_order BIGINT NOT NULL, + is_public BOOLEAN NOT NULL, -- whether the room should be show to everyone + UNIQUE (group_id, category_id, room_id, room_order), + CHECK (room_order > 0) +); + +CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms(group_id, room_id, category_id); + + +-- Categories to include in the summary +CREATE TABLE group_summary_room_categories ( + group_id TEXT NOT NULL, + category_id TEXT NOT NULL, + cat_order BIGINT NOT NULL, + UNIQUE (group_id, category_id, cat_order), + CHECK (cat_order > 0) +); + +-- The categories in the group +CREATE TABLE group_room_categories ( + group_id TEXT NOT NULL, + category_id TEXT NOT NULL, + profile TEXT NOT NULL, + is_public BOOLEAN NOT NULL, -- whether the category should be show to everyone + UNIQUE (group_id, category_id) +); + +-- The users to include in the group summary +CREATE TABLE group_summary_users ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + role_id TEXT NOT NULL, + user_order BIGINT NOT NULL, + is_public BOOLEAN NOT NULL -- whether the user should be show to everyone +); + +CREATE INDEX group_summary_users_g_idx ON group_summary_users(group_id); + +-- The roles to include in the group summary +CREATE TABLE group_summary_roles ( + group_id TEXT NOT NULL, + role_id TEXT NOT NULL, + role_order BIGINT NOT NULL, + UNIQUE (group_id, role_id, role_order), + CHECK (role_order > 0) +); + + +-- The roles in a groups +CREATE TABLE group_roles ( + group_id TEXT NOT NULL, + role_id TEXT NOT NULL, + profile TEXT NOT NULL, + is_public BOOLEAN NOT NULL, -- whether the role should be show to everyone + UNIQUE (group_id, role_id) +); + + +-- List of attestations we've given out and need to renew +CREATE TABLE group_attestations_renewals ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + valid_until_ms BIGINT NOT NULL +); + +CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals(group_id, user_id); +CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals(user_id); +CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals(valid_until_ms); + + +-- List of attestations we've received from remotes and are interested in. +CREATE TABLE group_attestations_remote ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + valid_until_ms BIGINT NOT NULL, + attestation_json TEXT NOT NULL +); + +CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote(group_id, user_id); +CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote(user_id); +CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote(valid_until_ms); + + +-- The group membership for the HS's users +CREATE TABLE local_group_membership ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + is_admin BOOLEAN NOT NULL, + membership TEXT NOT NULL, + is_publicised BOOLEAN NOT NULL, -- if the user is publicising their membership + content TEXT NOT NULL +); + +CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id); +CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id); + + +CREATE TABLE local_group_updates ( + stream_id BIGINT NOT NULL, + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + type TEXT NOT NULL, + content TEXT NOT NULL +); diff --git a/synapse/storage/schema/delta/45/profile_cache.sql b/synapse/storage/schema/delta/45/profile_cache.sql new file mode 100644 index 0000000000..e5ddc84df0 --- /dev/null +++ b/synapse/storage/schema/delta/45/profile_cache.sql @@ -0,0 +1,28 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- A subset of remote users whose profiles we have cached. +-- Whether a user is in this table or not is defined by the storage function +-- `is_subscribed_remote_profile_for_user` +CREATE TABLE remote_profile_cache ( + user_id TEXT NOT NULL, + displayname TEXT, + avatar_url TEXT, + last_check BIGINT NOT NULL +); + +CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache(user_id); +CREATE INDEX remote_profile_cache_time ON remote_profile_cache(last_check); diff --git a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql b/synapse/storage/schema/delta/46/drop_refresh_tokens.sql index bb225dafbf..68c48a89a9 100644 --- a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql +++ b/synapse/storage/schema/delta/46/drop_refresh_tokens.sql @@ -1,4 +1,4 @@ -/* Copyright 2016 OpenMarket Ltd +/* Copyright 2017 New Vector Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,5 +13,5 @@ * limitations under the License. */ -INSERT INTO background_updates (update_name, progress_json) VALUES - ('refresh_tokens_device_index', '{}'); +/* we no longer use (or create) the refresh_tokens table */ +DROP TABLE IF EXISTS refresh_tokens; diff --git a/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql b/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql new file mode 100644 index 0000000000..bb307889c1 --- /dev/null +++ b/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql @@ -0,0 +1,35 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- drop the unique constraint on deleted_pushers so that we can just insert +-- into it rather than upserting. + +CREATE TABLE deleted_pushers2 ( + stream_id BIGINT NOT NULL, + app_id TEXT NOT NULL, + pushkey TEXT NOT NULL, + user_id TEXT NOT NULL +); + +INSERT INTO deleted_pushers2 (stream_id, app_id, pushkey, user_id) + SELECT stream_id, app_id, pushkey, user_id from deleted_pushers; + +DROP TABLE deleted_pushers; +ALTER TABLE deleted_pushers2 RENAME TO deleted_pushers; + +-- create the index after doing the inserts because that's more efficient. +-- it also means we can give it the same name as the old one without renaming. +CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id); + diff --git a/synapse/storage/schema/delta/46/group_server.sql b/synapse/storage/schema/delta/46/group_server.sql new file mode 100644 index 0000000000..097679bc9a --- /dev/null +++ b/synapse/storage/schema/delta/46/group_server.sql @@ -0,0 +1,32 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE groups_new ( + group_id TEXT NOT NULL, + name TEXT, -- the display name of the room + avatar_url TEXT, + short_description TEXT, + long_description TEXT, + is_public BOOL NOT NULL -- whether non-members can access group APIs +); + +-- NB: awful hack to get the default to be true on postgres and 1 on sqlite +INSERT INTO groups_new + SELECT group_id, name, avatar_url, short_description, long_description, (1=1) FROM groups; + +DROP TABLE groups; +ALTER TABLE groups_new RENAME TO groups; + +CREATE UNIQUE INDEX groups_idx ON groups(group_id); diff --git a/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql b/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql new file mode 100644 index 0000000000..bbfc7f5d1a --- /dev/null +++ b/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql @@ -0,0 +1,24 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- register a background update which will recreate the +-- local_media_repository_url_idx index. +-- +-- We do this as a bg update not because it is a particularly onerous +-- operation, but because we'd like it to be a partial index if possible, and +-- the background_index_update code will understand whether we are on +-- postgres or sqlite and behave accordingly. +INSERT INTO background_updates (update_name, progress_json) VALUES + ('local_media_repository_url_idx', '{}'); diff --git a/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql new file mode 100644 index 0000000000..cb0d5a2576 --- /dev/null +++ b/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql @@ -0,0 +1,35 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- change the user_directory table to also cover global local user profiles +-- rather than just profiles within specific rooms. + +CREATE TABLE user_directory2 ( + user_id TEXT NOT NULL, + room_id TEXT, + display_name TEXT, + avatar_url TEXT +); + +INSERT INTO user_directory2(user_id, room_id, display_name, avatar_url) + SELECT user_id, room_id, display_name, avatar_url from user_directory; + +DROP TABLE user_directory; +ALTER TABLE user_directory2 RENAME TO user_directory; + +-- create indexes after doing the inserts because that's more efficient. +-- it also means we can give it the same name as the old one without renaming. +CREATE INDEX user_directory_room_idx ON user_directory(room_id); +CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id); diff --git a/synapse/storage/schema/delta/46/user_dir_typos.sql b/synapse/storage/schema/delta/46/user_dir_typos.sql new file mode 100644 index 0000000000..d9505f8da1 --- /dev/null +++ b/synapse/storage/schema/delta/46/user_dir_typos.sql @@ -0,0 +1,24 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- this is just embarassing :| +ALTER TABLE users_in_pubic_room RENAME TO users_in_public_rooms; + +-- this is only 300K rows on matrix.org and takes ~3s to generate the index, +-- so is hopefully not going to block anyone else for that long... +CREATE INDEX users_in_public_rooms_room_idx ON users_in_public_rooms(room_id); +CREATE UNIQUE INDEX users_in_public_rooms_user_idx ON users_in_public_rooms(user_id); +DROP INDEX users_in_pubic_room_room_idx; +DROP INDEX users_in_pubic_room_user_idx; diff --git a/synapse/storage/schema/delta/33/refreshtoken_device.sql b/synapse/storage/schema/delta/47/last_access_media.sql index 290bd6da86..f505fb22b5 100644 --- a/synapse/storage/schema/delta/33/refreshtoken_device.sql +++ b/synapse/storage/schema/delta/47/last_access_media.sql @@ -1,4 +1,4 @@ -/* Copyright 2016 OpenMarket Ltd +/* Copyright 2018 New Vector Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,4 +13,4 @@ * limitations under the License. */ -ALTER TABLE refresh_tokens ADD COLUMN device_id TEXT; +ALTER TABLE local_media_repository ADD COLUMN last_access_ts BIGINT; diff --git a/synapse/storage/schema/delta/23/refresh_tokens.sql b/synapse/storage/schema/delta/47/postgres_fts_gin.sql index 34db0cf12b..31d7a817eb 100644 --- a/synapse/storage/schema/delta/23/refresh_tokens.sql +++ b/synapse/storage/schema/delta/47/postgres_fts_gin.sql @@ -1,4 +1,4 @@ -/* Copyright 2015, 2016 OpenMarket Ltd +/* Copyright 2018 New Vector Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,9 +13,5 @@ * limitations under the License. */ -CREATE TABLE IF NOT EXISTS refresh_tokens( - id INTEGER PRIMARY KEY, - token TEXT NOT NULL, - user_id TEXT NOT NULL, - UNIQUE (token) -); +INSERT into background_updates (update_name, progress_json) + VALUES ('event_search_postgres_gin', '{}'); diff --git a/synapse/storage/schema/delta/47/push_actions_staging.sql b/synapse/storage/schema/delta/47/push_actions_staging.sql new file mode 100644 index 0000000000..edccf4a96f --- /dev/null +++ b/synapse/storage/schema/delta/47/push_actions_staging.sql @@ -0,0 +1,28 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Temporary staging area for push actions that have been calculated for an +-- event, but the event hasn't yet been persisted. +-- When the event is persisted the rows are moved over to the +-- event_push_actions table. +CREATE TABLE event_push_actions_staging ( + event_id TEXT NOT NULL, + user_id TEXT NOT NULL, + actions TEXT NOT NULL, + notif SMALLINT NOT NULL, + highlight SMALLINT NOT NULL +); + +CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging(event_id); diff --git a/synapse/storage/schema/delta/47/state_group_seq.py b/synapse/storage/schema/delta/47/state_group_seq.py new file mode 100644 index 0000000000..f6766501d2 --- /dev/null +++ b/synapse/storage/schema/delta/47/state_group_seq.py @@ -0,0 +1,37 @@ +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.engines import PostgresEngine + + +def run_create(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + # if we already have some state groups, we want to start making new + # ones with a higher id. + cur.execute("SELECT max(id) FROM state_groups") + row = cur.fetchone() + + if row[0] is None: + start_val = 1 + else: + start_val = row[0] + 1 + + cur.execute( + "CREATE SEQUENCE state_group_id_seq START WITH %s", + (start_val, ), + ) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/48/add_user_consent.sql b/synapse/storage/schema/delta/48/add_user_consent.sql new file mode 100644 index 0000000000..5237491506 --- /dev/null +++ b/synapse/storage/schema/delta/48/add_user_consent.sql @@ -0,0 +1,18 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* record the version of the privacy policy the user has consented to + */ +ALTER TABLE users ADD COLUMN consent_version TEXT; diff --git a/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql b/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql new file mode 100644 index 0000000000..9248b0b24a --- /dev/null +++ b/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('user_ips_last_seen_index', '{}'); diff --git a/synapse/storage/schema/delta/48/deactivated_users.sql b/synapse/storage/schema/delta/48/deactivated_users.sql new file mode 100644 index 0000000000..e9013a6969 --- /dev/null +++ b/synapse/storage/schema/delta/48/deactivated_users.sql @@ -0,0 +1,25 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Store any accounts that have been requested to be deactivated. + * We part the account from all the rooms its in when its + * deactivated. This can take some time and synapse may be restarted + * before it completes, so store the user IDs here until the process + * is complete. + */ +CREATE TABLE users_pending_deactivation ( + user_id TEXT NOT NULL +); diff --git a/synapse/storage/schema/delta/48/group_unique_indexes.py b/synapse/storage/schema/delta/48/group_unique_indexes.py new file mode 100644 index 0000000000..2233af87d7 --- /dev/null +++ b/synapse/storage/schema/delta/48/group_unique_indexes.py @@ -0,0 +1,57 @@ +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.engines import PostgresEngine +from synapse.storage.prepare_database import get_statements + +FIX_INDEXES = """ +-- rebuild indexes as uniques +DROP INDEX groups_invites_g_idx; +CREATE UNIQUE INDEX group_invites_g_idx ON group_invites(group_id, user_id); +DROP INDEX groups_users_g_idx; +CREATE UNIQUE INDEX group_users_g_idx ON group_users(group_id, user_id); + +-- rename other indexes to actually match their table names.. +DROP INDEX groups_users_u_idx; +CREATE INDEX group_users_u_idx ON group_users(user_id); +DROP INDEX groups_invites_u_idx; +CREATE INDEX group_invites_u_idx ON group_invites(user_id); +DROP INDEX groups_rooms_g_idx; +CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id); +DROP INDEX groups_rooms_r_idx; +CREATE INDEX group_rooms_r_idx ON group_rooms(room_id); +""" + + +def run_create(cur, database_engine, *args, **kwargs): + rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid" + + # remove duplicates from group_users & group_invites tables + cur.execute(""" + DELETE FROM group_users WHERE %s NOT IN ( + SELECT min(%s) FROM group_users GROUP BY group_id, user_id + ); + """ % (rowid, rowid)) + cur.execute(""" + DELETE FROM group_invites WHERE %s NOT IN ( + SELECT min(%s) FROM group_invites GROUP BY group_id, user_id + ); + """ % (rowid, rowid)) + + for statement in get_statements(FIX_INDEXES.splitlines()): + cur.execute(statement) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/48/groups_joinable.sql b/synapse/storage/schema/delta/48/groups_joinable.sql new file mode 100644 index 0000000000..ce26eaf0c9 --- /dev/null +++ b/synapse/storage/schema/delta/48/groups_joinable.sql @@ -0,0 +1,22 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This isn't a real ENUM because sqlite doesn't support it + * and we use a default of NULL for inserted rows and interpret + * NULL at the python store level as necessary so that existing + * rows are given the correct default policy. + */ +ALTER TABLE groups ADD COLUMN join_policy TEXT NOT NULL DEFAULT 'invite'; diff --git a/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql b/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql new file mode 100644 index 0000000000..14dcf18d73 --- /dev/null +++ b/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql @@ -0,0 +1,20 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* record whether we have sent a server notice about consenting to the + * privacy policy. Specifically records the version of the policy we sent + * a message about. + */ +ALTER TABLE users ADD COLUMN consent_server_notice_sent TEXT; diff --git a/synapse/storage/schema/delta/49/add_user_daily_visits.sql b/synapse/storage/schema/delta/49/add_user_daily_visits.sql new file mode 100644 index 0000000000..3dd478196f --- /dev/null +++ b/synapse/storage/schema/delta/49/add_user_daily_visits.sql @@ -0,0 +1,21 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL, + device_id TEXT, + timestamp BIGINT NOT NULL ); +CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp); +CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp); diff --git a/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql b/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql new file mode 100644 index 0000000000..3a4ed59b5b --- /dev/null +++ b/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('user_ips_last_seen_only_index', '{}'); diff --git a/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql b/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql new file mode 100644 index 0000000000..c93ae47532 --- /dev/null +++ b/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql @@ -0,0 +1,19 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + +INSERT into background_updates (update_name, progress_json) + VALUES ('users_creation_ts', '{}'); diff --git a/synapse/storage/schema/delta/50/erasure_store.sql b/synapse/storage/schema/delta/50/erasure_store.sql new file mode 100644 index 0000000000..5d8641a9ab --- /dev/null +++ b/synapse/storage/schema/delta/50/erasure_store.sql @@ -0,0 +1,21 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- a table of users who have requested that their details be erased +CREATE TABLE erased_users ( + user_id TEXT NOT NULL +); + +CREATE UNIQUE INDEX erased_users_user ON erased_users(user_id); diff --git a/synapse/storage/schema/schema_version.sql b/synapse/storage/schema/schema_version.sql index a7ade69986..42e5cb6df5 100644 --- a/synapse/storage/schema/schema_version.sql +++ b/synapse/storage/schema/schema_version.sql @@ -25,3 +25,10 @@ CREATE TABLE IF NOT EXISTS applied_schema_deltas( file TEXT NOT NULL, UNIQUE(version, file) ); + +-- a list of schema files we have loaded on behalf of dynamic modules +CREATE TABLE IF NOT EXISTS applied_module_schemas( + module_name TEXT NOT NULL, + file TEXT NOT NULL, + UNIQUE(module_name, file) +); diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 8f2b3c4435..d5b5df93e6 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -13,28 +13,38 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +import re +from collections import namedtuple + +from six import string_types + +from canonicaljson import json + from twisted.internet import defer -from .background_updates import BackgroundUpdateStore from synapse.api.errors import SynapseError from synapse.storage.engines import PostgresEngine, Sqlite3Engine -import logging -import re -import ujson as json - +from .background_updates import BackgroundUpdateStore logger = logging.getLogger(__name__) +SearchEntry = namedtuple('SearchEntry', [ + 'key', 'value', 'event_id', 'room_id', 'stream_ordering', + 'origin_server_ts', +]) + class SearchStore(BackgroundUpdateStore): EVENT_SEARCH_UPDATE_NAME = "event_search" EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" + EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" - def __init__(self, hs): - super(SearchStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(SearchStore, self).__init__(db_conn, hs) self.register_background_update_handler( self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search ) @@ -42,23 +52,35 @@ class SearchStore(BackgroundUpdateStore): self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order ) - self.register_background_update_handler( + + # we used to have a background update to turn the GIN index into a + # GIST one; we no longer do that (obviously) because we actually want + # a GIN index. However, it's possible that some people might still have + # the background update queued, so we register a handler to clear the + # background update. + self.register_noop_background_update( self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME, - self._background_reindex_gist_search + ) + + self.register_background_update_handler( + self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, + self._background_reindex_gin_search ) @defer.inlineCallbacks def _background_reindex_search(self, progress, batch_size): + # we work through the events table from highest stream id to lowest target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) - INSERT_CLUMP_SIZE = 1000 TYPES = ["m.room.name", "m.room.message", "m.room.topic"] def reindex_search_txn(txn): sql = ( - "SELECT stream_ordering, event_id, room_id, type, content FROM events" + "SELECT stream_ordering, event_id, room_id, type, json, " + " origin_server_ts FROM events" + " JOIN event_json USING (room_id, event_id)" " WHERE ? <= stream_ordering AND stream_ordering < ?" " AND (%s)" " ORDER BY stream_ordering DESC" @@ -67,6 +89,10 @@ class SearchStore(BackgroundUpdateStore): txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) + # we could stream straight from the results into + # store_search_entries_txn with a generator function, but that + # would mean having two cursors open on the database at once. + # Instead we just build a list of results. rows = self.cursor_to_dict(txn) if not rows: return 0 @@ -79,9 +105,12 @@ class SearchStore(BackgroundUpdateStore): event_id = row["event_id"] room_id = row["room_id"] etype = row["type"] + stream_ordering = row["stream_ordering"] + origin_server_ts = row["origin_server_ts"] try: - content = json.loads(row["content"]) - except: + event_json = json.loads(row["json"]) + content = event_json["content"] + except Exception: continue if etype == "m.room.message": @@ -93,35 +122,28 @@ class SearchStore(BackgroundUpdateStore): elif etype == "m.room.name": key = "content.name" value = content["name"] + else: + raise Exception("unexpected event type %s" % etype) except (KeyError, AttributeError): # If the event is missing a necessary field then # skip over it. continue - if not isinstance(value, basestring): + if not isinstance(value, string_types): # If the event body, name or topic isn't a string # then skip over it continue - event_search_rows.append((event_id, room_id, key, value)) - - if isinstance(self.database_engine, PostgresEngine): - sql = ( - "INSERT INTO event_search (event_id, room_id, key, vector)" - " VALUES (?,?,?,to_tsvector('english', ?))" - ) - elif isinstance(self.database_engine, Sqlite3Engine): - sql = ( - "INSERT INTO event_search (event_id, room_id, key, value)" - " VALUES (?,?,?,?)" - ) - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") + event_search_rows.append(SearchEntry( + key=key, + value=value, + event_id=event_id, + room_id=room_id, + stream_ordering=stream_ordering, + origin_server_ts=origin_server_ts, + )) - for index in range(0, len(event_search_rows), INSERT_CLUMP_SIZE): - clump = event_search_rows[index:index + INSERT_CLUMP_SIZE] - txn.executemany(sql, clump) + self.store_search_entries_txn(txn, event_search_rows) progress = { "target_min_stream_id_inclusive": target_min_stream_id, @@ -145,25 +167,48 @@ class SearchStore(BackgroundUpdateStore): defer.returnValue(result) @defer.inlineCallbacks - def _background_reindex_gist_search(self, progress, batch_size): + def _background_reindex_gin_search(self, progress, batch_size): + """This handles old synapses which used GIST indexes, if any; + converting them back to be GIN as per the actual schema. + """ + def create_index(conn): conn.rollback() - conn.set_session(autocommit=True) - c = conn.cursor() - c.execute( - "CREATE INDEX CONCURRENTLY event_search_fts_idx_gist" - " ON event_search USING GIST (vector)" - ) + # we have to set autocommit, because postgres refuses to + # CREATE INDEX CONCURRENTLY without it. + conn.set_session(autocommit=True) - c.execute("DROP INDEX event_search_fts_idx") + try: + c = conn.cursor() - conn.set_session(autocommit=False) + # if we skipped the conversion to GIST, we may already/still + # have an event_search_fts_idx; unfortunately postgres 9.4 + # doesn't support CREATE INDEX IF EXISTS so we just catch the + # exception and ignore it. + import psycopg2 + try: + c.execute( + "CREATE INDEX CONCURRENTLY event_search_fts_idx" + " ON event_search USING GIN (vector)" + ) + except psycopg2.ProgrammingError as e: + logger.warn( + "Ignoring error %r when trying to switch from GIST to GIN", + e + ) + + # we should now be able to delete the GIST index. + c.execute( + "DROP INDEX IF EXISTS event_search_fts_idx_gist" + ) + finally: + conn.set_session(autocommit=False) if isinstance(self.database_engine, PostgresEngine): yield self.runWithConnection(create_index) - yield self._end_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME) + yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME) defer.returnValue(1) @defer.inlineCallbacks @@ -242,6 +287,85 @@ class SearchStore(BackgroundUpdateStore): defer.returnValue(num_rows) + def store_event_search_txn(self, txn, event, key, value): + """Add event to the search table + + Args: + txn (cursor): + event (EventBase): + key (str): + value (str): + """ + self.store_search_entries_txn( + txn, + (SearchEntry( + key=key, + value=value, + event_id=event.event_id, + room_id=event.room_id, + stream_ordering=event.internal_metadata.stream_ordering, + origin_server_ts=event.origin_server_ts, + ),), + ) + + def store_search_entries_txn(self, txn, entries): + """Add entries to the search table + + Args: + txn (cursor): + entries (iterable[SearchEntry]): + entries to be added to the table + """ + if isinstance(self.database_engine, PostgresEngine): + sql = ( + "INSERT INTO event_search" + " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" + " VALUES (?,?,?,to_tsvector('english', ?),?,?)" + ) + + args = (( + entry.event_id, entry.room_id, entry.key, entry.value, + entry.stream_ordering, entry.origin_server_ts, + ) for entry in entries) + + # inserts to a GIN index are normally batched up into a pending + # list, and then all committed together once the list gets to a + # certain size. The trouble with that is that postgres (pre-9.5) + # uses work_mem to determine the length of the list, and work_mem + # is typically very large. + # + # We therefore reduce work_mem while we do the insert. + # + # (postgres 9.5 uses the separate gin_pending_list_limit setting, + # so doesn't suffer the same problem, but changing work_mem will + # be harmless) + # + # Note that we don't need to worry about restoring it on + # exception, because exceptions will cause the transaction to be + # rolled back, including the effects of the SET command. + # + # Also: we use SET rather than SET LOCAL because there's lots of + # other stuff going on in this transaction, which want to have the + # normal work_mem setting. + + txn.execute("SET work_mem='256kB'") + txn.executemany(sql, args) + txn.execute("RESET work_mem") + + elif isinstance(self.database_engine, Sqlite3Engine): + sql = ( + "INSERT INTO event_search (event_id, room_id, key, value)" + " VALUES (?,?,?,?)" + ) + args = (( + entry.event_id, entry.room_id, entry.key, entry.value, + ) for entry in entries) + + txn.executemany(sql, args) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + @defer.inlineCallbacks def search_msgs(self, room_ids, search_term, keys): """Performs a full text search over events with given keys. @@ -326,7 +450,7 @@ class SearchStore(BackgroundUpdateStore): "search_msgs", self.cursor_to_dict, sql, *args ) - results = filter(lambda row: row["room_id"] in room_ids, results) + results = list(filter(lambda row: row["room_id"] in room_ids, results)) events = yield self._get_events([r["event_id"] for r in results]) @@ -407,7 +531,7 @@ class SearchStore(BackgroundUpdateStore): origin_server_ts, stream = pagination_token.split(",") origin_server_ts = int(origin_server_ts) stream = int(stream) - except: + except Exception: raise SynapseError(400, "Invalid pagination token") clauses.append( @@ -481,7 +605,7 @@ class SearchStore(BackgroundUpdateStore): "search_rooms", self.cursor_to_dict, sql, *args ) - results = filter(lambda row: row["room_id"] in room_ids, results) + results = list(filter(lambda row: row["room_id"] in room_ids, results)) events = yield self._get_events([r["event_id"] for r in results]) diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index 67d5d9969a..470212aa2a 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -13,21 +13,31 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - -from ._base import SQLBaseStore +import six from unpaddedbase64 import encode_base64 + +from twisted.internet import defer + from synapse.crypto.event_signing import compute_event_reference_hash from synapse.util.caches.descriptors import cached, cachedList +from ._base import SQLBaseStore + +# py2 sqlite has buffer hardcoded as only binary type, so we must use it, +# despite being deprecated and removed in favor of memoryview +if six.PY2: + db_binary_type = buffer +else: + db_binary_type = memoryview -class SignatureStore(SQLBaseStore): - """Persistence for event signatures and hashes""" +class SignatureWorkerStore(SQLBaseStore): @cached() def get_event_reference_hash(self, event_id): - return self._get_event_reference_hashes_txn(event_id) + # This is a dummy function to allow get_event_reference_hashes + # to use its cache + raise NotImplementedError() @cachedList(cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1) @@ -56,7 +66,7 @@ class SignatureStore(SQLBaseStore): for e_id, h in hashes.items() } - defer.returnValue(hashes.items()) + defer.returnValue(list(hashes.items())) def _get_event_reference_hashes_txn(self, txn, event_id): """Get all the hashes for a given PDU. @@ -74,6 +84,10 @@ class SignatureStore(SQLBaseStore): txn.execute(query, (event_id, )) return {k: v for k, v in txn} + +class SignatureStore(SignatureWorkerStore): + """Persistence for event signatures and hashes""" + def _store_event_reference_hashes_txn(self, txn, events): """Store a hash for a PDU Args: @@ -87,7 +101,7 @@ class SignatureStore(SQLBaseStore): vals.append({ "event_id": event.event_id, "algorithm": ref_alg, - "hash": buffer(ref_hash_bytes), + "hash": db_binary_type(ref_hash_bytes), }) self._simple_insert_many_txn( diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 5673e4aa96..89a05c4618 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -13,16 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.caches import intern_string -from synapse.util.stringutils import to_ascii -from synapse.storage.engines import PostgresEngine +import logging +from collections import namedtuple + +from six import iteritems, itervalues +from six.moves import range from twisted.internet import defer -from collections import namedtuple -import logging +from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage.engines import PostgresEngine +from synapse.util.caches import get_cache_factor_for, intern_string +from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.caches.dictionary_cache import DictionaryCache +from synapse.util.stringutils import to_ascii + +from ._base import SQLBaseStore logger = logging.getLogger(__name__) @@ -40,45 +46,19 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt return len(self.delta_ids) if self.delta_ids else 0 -class StateStore(SQLBaseStore): - """ Keeps track of the state at a given event. - - This is done by the concept of `state groups`. Every event is a assigned - a state group (identified by an arbitrary string), which references a - collection of state events. The current state of an event is then the - collection of state events referenced by the event's state group. - - Hence, every change in the current state causes a new state group to be - generated. However, if no change happens (e.g., if we get a message event - with only one parent it inherits the state group from its parent.) - - There are three tables: - * `state_groups`: Stores group name, first event with in the group and - room id. - * `event_to_state_groups`: Maps events to state groups. - * `state_groups_state`: Maps state group to state events. +class StateGroupWorkerStore(SQLBaseStore): + """The parts of StateGroupStore that can be called from workers. """ STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - def __init__(self, hs): - super(StateStore, self).__init__(hs) - self.register_background_update_handler( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, - self._background_deduplicate_state, - ) - self.register_background_update_handler( - self.STATE_GROUP_INDEX_UPDATE_NAME, - self._background_index_state, - ) - self.register_background_index_update( - self.CURRENT_STATE_INDEX_UPDATE_NAME, - index_name="current_state_events_member_index", - table="current_state_events", - columns=["state_key"], - where_clause="type='m.room.member'", + def __init__(self, db_conn, hs): + super(StateGroupWorkerStore, self).__init__(db_conn, hs) + + self._state_group_cache = DictionaryCache( + "*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache") ) @cached(max_entries=100000, iterable=True) @@ -158,12 +138,26 @@ class StateStore(SQLBaseStore): event_ids, ) - groups = set(event_to_groups.itervalues()) + groups = set(itervalues(event_to_groups)) group_to_state = yield self._get_state_for_groups(groups) defer.returnValue(group_to_state) @defer.inlineCallbacks + def get_state_ids_for_group(self, state_group): + """Get the state IDs for the given state group + + Args: + state_group (int) + + Returns: + Deferred[dict]: Resolves to a map of (type, state_key) -> event_id + """ + group_to_state = yield self._get_state_for_groups((state_group,)) + + defer.returnValue(group_to_state[state_group]) + + @defer.inlineCallbacks def get_state_groups(self, room_id, event_ids): """ Get the state groups for the given list of event_ids @@ -176,199 +170,27 @@ class StateStore(SQLBaseStore): state_event_map = yield self.get_events( [ - ev_id for group_ids in group_to_ids.itervalues() - for ev_id in group_ids.itervalues() + ev_id for group_ids in itervalues(group_to_ids) + for ev_id in itervalues(group_ids) ], get_prev_content=False ) defer.returnValue({ group: [ - state_event_map[v] for v in event_id_map.itervalues() + state_event_map[v] for v in itervalues(event_id_map) if v in state_event_map ] - for group, event_id_map in group_to_ids.iteritems() + for group, event_id_map in iteritems(group_to_ids) }) - def _have_persisted_state_group_txn(self, txn, state_group): - txn.execute( - "SELECT count(*) FROM state_groups WHERE id = ?", - (state_group,) - ) - row = txn.fetchone() - return row and row[0] - - def _store_mult_state_groups_txn(self, txn, events_and_contexts): - state_groups = {} - for event, context in events_and_contexts: - if event.internal_metadata.is_outlier(): - continue - - if context.current_state_ids is None: - # AFAIK, this can never happen - logger.error( - "Non-outlier event %s had current_state_ids==None", - event.event_id) - continue - - # if the event was rejected, just give it the same state as its - # predecessor. - if context.rejected: - state_groups[event.event_id] = context.prev_group - continue - - state_groups[event.event_id] = context.state_group - - if self._have_persisted_state_group_txn(txn, context.state_group): - continue - - self._simple_insert_txn( - txn, - table="state_groups", - values={ - "id": context.state_group, - "room_id": event.room_id, - "event_id": event.event_id, - }, - ) - - # We persist as a delta if we can, while also ensuring the chain - # of deltas isn't tooo long, as otherwise read performance degrades. - if context.prev_group: - is_in_db = self._simple_select_one_onecol_txn( - txn, - table="state_groups", - keyvalues={"id": context.prev_group}, - retcol="id", - allow_none=True, - ) - if not is_in_db: - raise Exception( - "Trying to persist state with unpersisted prev_group: %r" - % (context.prev_group,) - ) - - potential_hops = self._count_state_group_hops_txn( - txn, context.prev_group - ) - if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - self._simple_insert_txn( - txn, - table="state_group_edges", - values={ - "state_group": context.state_group, - "prev_state_group": context.prev_group, - }, - ) - - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": context.state_group, - "room_id": event.room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in context.delta_ids.iteritems() - ], - ) - else: - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": context.state_group, - "room_id": event.room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in context.current_state_ids.iteritems() - ], - ) - - # Prefill the state group cache with this group. - # It's fine to use the sequence like this as the state group map - # is immutable. (If the map wasn't immutable then this prefill could - # race with another update) - txn.call_after( - self._state_group_cache.update, - self._state_group_cache.sequence, - key=context.state_group, - value=dict(context.current_state_ids), - full=True, - ) - - self._simple_insert_many_txn( - txn, - table="event_to_state_groups", - values=[ - { - "state_group": state_group_id, - "event_id": event_id, - } - for event_id, state_group_id in state_groups.iteritems() - ], - ) - - for event_id, state_group_id in state_groups.iteritems(): - txn.call_after( - self._get_state_group_for_event.prefill, - (event_id,), state_group_id - ) - - def _count_state_group_hops_txn(self, txn, state_group): - """Given a state group, count how many hops there are in the tree. - - This is used to ensure the delta chains don't get too long. - """ - if isinstance(self.database_engine, PostgresEngine): - sql = (""" - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT count(*) FROM state; - """) - - txn.execute(sql, (state_group,)) - row = txn.fetchone() - if row and row[0]: - return row[0] - else: - return 0 - else: - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - next_group = state_group - count = 0 - - while next_group: - next_group = self._simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - if next_group: - count += 1 - - return count - @defer.inlineCallbacks def _get_state_groups_from_groups(self, groups, types): """Returns dictionary state_group -> (dict of (type, state_key) -> event id) """ results = {} - chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)] + chunks = [groups[i:i + 100] for i in range(0, len(groups), 100)] for chunk in chunks: res = yield self.runInteraction( "_get_state_groups_from_groups", @@ -422,6 +244,9 @@ class StateStore(SQLBaseStore): ( "AND type = ? AND state_key = ?", (etype, state_key) + ) if state_key is not None else ( + "AND type = ?", + (etype,) ) for etype, state_key in types ] @@ -441,10 +266,19 @@ class StateStore(SQLBaseStore): key = (typ, state_key) results[group][key] = event_id else: + where_args = [] + where_clauses = [] + wildcard_types = False if types is not None: - where_clause = "AND (%s)" % ( - " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), - ) + for typ in types: + if typ[1] is None: + where_clauses.append("(type = ?)") + where_args.append(typ[0]) + wildcard_types = True + else: + where_clauses.append("(type = ? AND state_key = ?)") + where_args.extend([typ[0], typ[1]]) + where_clause = "AND (%s)" % (" OR ".join(where_clauses)) else: where_clause = "" @@ -461,7 +295,7 @@ class StateStore(SQLBaseStore): # after we finish deduping state, which requires this func) args = [next_group] if types: - args.extend(i for typ in types for i in typ) + args.extend(where_args) txn.execute( "SELECT type, state_key, event_id FROM state_groups_state" @@ -474,9 +308,17 @@ class StateStore(SQLBaseStore): if (typ, state_key) not in results[group] ) - # If the lengths match then we must have all the types, - # so no need to go walk further down the tree. - if types is not None and len(results[group]) == len(types): + # If the number of entries in the (type,state_key)->event_id dict + # matches the number of (type,state_keys) types we were searching + # for, then we must have found them all, so no need to go walk + # further down the tree... UNLESS our types filter contained + # wildcards (i.e. Nones) in which case we have to do an exhaustive + # search + if ( + types is not None and + not wildcard_types and + len(results[group]) == len(types) + ): break next_group = self._simple_select_one_onecol_txn( @@ -509,21 +351,21 @@ class StateStore(SQLBaseStore): event_ids, ) - groups = set(event_to_groups.itervalues()) + groups = set(itervalues(event_to_groups)) group_to_state = yield self._get_state_for_groups(groups, types) state_event_map = yield self.get_events( - [ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()], + [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], get_prev_content=False ) event_to_state = { event_id: { k: state_event_map[v] - for k, v in group_to_state[group].iteritems() + for k, v in iteritems(group_to_state[group]) if v in state_event_map } - for event_id, group in event_to_groups.iteritems() + for event_id, group in iteritems(event_to_groups) } defer.returnValue({event: event_to_state[event] for event in event_ids}) @@ -546,12 +388,12 @@ class StateStore(SQLBaseStore): event_ids, ) - groups = set(event_to_groups.itervalues()) + groups = set(itervalues(event_to_groups)) group_to_state = yield self._get_state_for_groups(groups, types) event_to_state = { event_id: group_to_state[group] - for event_id, group in event_to_groups.iteritems() + for event_id, group in iteritems(event_to_groups) } defer.returnValue({event: event_to_state[event] for event in event_ids}) @@ -665,7 +507,7 @@ class StateStore(SQLBaseStore): got_all = is_all or not missing_types return { - k: v for k, v in state_dict_ids.iteritems() + k: v for k, v in iteritems(state_dict_ids) if include(k[0], k[1]) }, missing_types, got_all @@ -685,10 +527,23 @@ class StateStore(SQLBaseStore): @defer.inlineCallbacks def _get_state_for_groups(self, groups, types=None): - """Given list of groups returns dict of group -> list of state events - with matching types. `types` is a list of `(type, state_key)`, where - a `state_key` of None matches all state_keys. If `types` is None then - all events are returned. + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + types (None|iterable[(str, None|str)]): + indicates the state type/keys required. If None, the whole + state is fetched and returned. + + Otherwise, each entry should be a `(type, state_key)` tuple to + include in the response. A `state_key` of None is a wildcard + meaning that we require all state with that type. + + Returns: + Deferred[dict[int, dict[(type, state_key), EventBase]]] + a dictionary mapping from state group to state dictionary. """ if types: types = frozenset(types) @@ -697,7 +552,7 @@ class StateStore(SQLBaseStore): if types is not None: for group in set(groups): state_dict_ids, _, got_all = self._get_some_state_from_cache( - group, types + group, types, ) results[group] = state_dict_ids @@ -718,32 +573,266 @@ class StateStore(SQLBaseStore): # Okay, so we have some missing_types, lets fetch them. cache_seq_num = self._state_group_cache.sequence + # the DictionaryCache knows if it has *all* the state, but + # does not know if it has all of the keys of a particular type, + # which makes wildcard lookups expensive unless we have a complete + # cache. Hence, if we are doing a wildcard lookup, populate the + # cache fully so that we can do an efficient lookup next time. + + if types and any(k is None for (t, k) in types): + types_to_fetch = None + else: + types_to_fetch = types + group_to_state_dict = yield self._get_state_groups_from_groups( - missing_groups, types + missing_groups, types_to_fetch, ) - # Now we want to update the cache with all the things we fetched - # from the database. - for group, group_state_dict in group_to_state_dict.iteritems(): + for group, group_state_dict in iteritems(group_to_state_dict): state_dict = results[group] - state_dict.update( - ((intern_string(k[0]), intern_string(k[1])), to_ascii(v)) - for k, v in group_state_dict.iteritems() - ) - + # update the result, filtering by `types`. + if types: + for k, v in iteritems(group_state_dict): + (typ, _) = k + if k in types or (typ, None) in types: + state_dict[k] = v + else: + state_dict.update(group_state_dict) + + # update the cache with all the things we fetched from the + # database. self._state_group_cache.update( cache_seq_num, key=group, - value=state_dict, - full=(types is None), - known_absent=types, + value=group_state_dict, + fetched_keys=types_to_fetch, ) defer.returnValue(results) - def get_next_state_group(self): - return self._state_groups_id_gen.get_next() + def store_state_group(self, event_id, room_id, prev_group, delta_ids, + current_state_ids): + """Store a new set of state, returning a newly assigned state group. + + Args: + event_id (str): The event ID for which the state was calculated + room_id (str) + prev_group (int|None): A previous state group for the room, optional. + delta_ids (dict|None): The delta between state at `prev_group` and + `current_state_ids`, if `prev_group` was given. Same format as + `current_state_ids`. + current_state_ids (dict): The state to store. Map of (type, state_key) + to event_id. + + Returns: + Deferred[int]: The state group ID + """ + def _store_state_group_txn(txn): + if current_state_ids is None: + # AFAIK, this can never happen + raise Exception("current_state_ids cannot be None") + + state_group = self.database_engine.get_next_state_group_id(txn) + + self._simple_insert_txn( + txn, + table="state_groups", + values={ + "id": state_group, + "room_id": room_id, + "event_id": event_id, + }, + ) + + # We persist as a delta if we can, while also ensuring the chain + # of deltas isn't tooo long, as otherwise read performance degrades. + if prev_group: + is_in_db = self._simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (prev_group,) + ) + + potential_hops = self._count_state_group_hops_txn( + txn, prev_group + ) + if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: + self._simple_insert_txn( + txn, + table="state_group_edges", + values={ + "state_group": state_group, + "prev_state_group": prev_group, + }, + ) + + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(delta_ids) + ], + ) + else: + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(current_state_ids) + ], + ) + + # Prefill the state group cache with this group. + # It's fine to use the sequence like this as the state group map + # is immutable. (If the map wasn't immutable then this prefill could + # race with another update) + txn.call_after( + self._state_group_cache.update, + self._state_group_cache.sequence, + key=state_group, + value=dict(current_state_ids), + ) + + return state_group + + return self.runInteraction("store_state_group", _store_state_group_txn) + + def _count_state_group_hops_txn(self, txn, state_group): + """Given a state group, count how many hops there are in the tree. + + This is used to ensure the delta chains don't get too long. + """ + if isinstance(self.database_engine, PostgresEngine): + sql = (""" + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT count(*) FROM state; + """) + + txn.execute(sql, (state_group,)) + row = txn.fetchone() + if row and row[0]: + return row[0] + else: + return 0 + else: + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + next_group = state_group + count = 0 + + while next_group: + next_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + if next_group: + count += 1 + + return count + + +class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): + """ Keeps track of the state at a given event. + + This is done by the concept of `state groups`. Every event is a assigned + a state group (identified by an arbitrary string), which references a + collection of state events. The current state of an event is then the + collection of state events referenced by the event's state group. + + Hence, every change in the current state causes a new state group to be + generated. However, if no change happens (e.g., if we get a message event + with only one parent it inherits the state group from its parent.) + + There are three tables: + * `state_groups`: Stores group name, first event with in the group and + room id. + * `event_to_state_groups`: Maps events to state groups. + * `state_groups_state`: Maps state group to state events. + """ + + STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" + STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" + CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" + + def __init__(self, db_conn, hs): + super(StateStore, self).__init__(db_conn, hs) + self.register_background_update_handler( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, + self._background_deduplicate_state, + ) + self.register_background_update_handler( + self.STATE_GROUP_INDEX_UPDATE_NAME, + self._background_index_state, + ) + self.register_background_index_update( + self.CURRENT_STATE_INDEX_UPDATE_NAME, + index_name="current_state_events_member_index", + table="current_state_events", + columns=["state_key"], + where_clause="type='m.room.member'", + ) + + def _store_event_state_mappings_txn(self, txn, events_and_contexts): + state_groups = {} + for event, context in events_and_contexts: + if event.internal_metadata.is_outlier(): + continue + + # if the event was rejected, just give it the same state as its + # predecessor. + if context.rejected: + state_groups[event.event_id] = context.prev_group + continue + + state_groups[event.event_id] = context.state_group + + self._simple_insert_many_txn( + txn, + table="event_to_state_groups", + values=[ + { + "state_group": state_group_id, + "event_id": event_id, + } + for event_id, state_group_id in iteritems(state_groups) + ], + ) + + for event_id, state_group_id in iteritems(state_groups): + txn.call_after( + self._get_state_group_for_event.prefill, + (event_id,), state_group_id + ) @defer.inlineCallbacks def _background_deduplicate_state(self, progress, batch_size): @@ -767,7 +856,7 @@ class StateStore(SQLBaseStore): def reindex_txn(txn): new_last_state_group = last_state_group - for count in xrange(batch_size): + for count in range(batch_size): txn.execute( "SELECT id, room_id FROM state_groups" " WHERE ? < id AND id <= ?" @@ -825,7 +914,7 @@ class StateStore(SQLBaseStore): # of keys delta_state = { - key: value for key, value in curr_state.iteritems() + key: value for key, value in iteritems(curr_state) if prev_state.get(key, None) != value } @@ -865,7 +954,7 @@ class StateStore(SQLBaseStore): "state_key": key[1], "event_id": state_id, } - for key, state_id in delta_state.iteritems() + for key, state_id in iteritems(delta_state) ], ) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index dddd5fc0e7..66856342f0 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -33,17 +33,20 @@ what sort order was used: and stream ordering columns respectively. """ -from twisted.internet import defer +import abc +import logging +from collections import namedtuple -from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached -from synapse.api.constants import EventTypes -from synapse.types import RoomStreamToken -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred -from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from six.moves import range -import logging +from twisted.internet import defer +from synapse.storage._base import SQLBaseStore +from synapse.storage.engines import PostgresEngine +from synapse.storage.events import EventsWorkerStore +from synapse.types import RoomStreamToken +from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) @@ -55,6 +58,12 @@ _STREAM_TOKEN = "stream" _TOPOLOGICAL_TOKEN = "topological" +# Used as return values for pagination APIs +_EventDictReturn = namedtuple("_EventDictReturn", ( + "event_id", "topological_ordering", "stream_ordering", +)) + + def lower_bound(token, engine, inclusive=False): inclusive = "=" if inclusive else "" if token.topological is None: @@ -143,81 +152,41 @@ def filter_to_clause(event_filter): return " AND ".join(clauses), args -class StreamStore(SQLBaseStore): - @defer.inlineCallbacks - def get_appservice_room_stream(self, service, from_key, to_key, limit=0): - # NB this lives here instead of appservice.py so we can reuse the - # 'private' StreamToken class in this file. - if limit: - limit = max(limit, MAX_STREAM_SIZE) - else: - limit = MAX_STREAM_SIZE - - # From and to keys should be integers from ordering. - from_id = RoomStreamToken.parse_stream_token(from_key) - to_id = RoomStreamToken.parse_stream_token(to_key) - - if from_key == to_key: - defer.returnValue(([], to_key)) - return +class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): + """This is an abstract base class where subclasses must implement + `get_room_max_stream_ordering` and `get_room_min_stream_ordering` + which can be called in the initializer. + """ - # select all the events between from/to with a sensible limit - sql = ( - "SELECT e.event_id, e.room_id, e.type, s.state_key, " - "e.stream_ordering FROM events AS e " - "LEFT JOIN state_events as s ON " - "e.event_id = s.event_id " - "WHERE e.stream_ordering > ? AND e.stream_ordering <= ? " - "ORDER BY stream_ordering ASC LIMIT %(limit)d " - ) % { - "limit": limit - } + __metaclass__ = abc.ABCMeta - def f(txn): - # pull out all the events between the tokens - txn.execute(sql, (from_id.stream, to_id.stream,)) - rows = self.cursor_to_dict(txn) - - # Logic: - # - We want ALL events which match the AS room_id regex - # - We want ALL events which match the rooms represented by the AS - # room_alias regex - # - We want ALL events for rooms that AS users have joined. - # This is currently supported via get_app_service_rooms (which is - # used for the Notifier listener rooms). We can't reasonably make a - # SQL query for these room IDs, so we'll pull all the events between - # from/to and filter in python. - rooms_for_as = self._get_app_service_rooms_txn(txn, service) - room_ids_for_as = [r.room_id for r in rooms_for_as] - - def app_service_interested(row): - if row["room_id"] in room_ids_for_as: - return True - - if row["type"] == EventTypes.Member: - if service.is_interested_in_user(row.get("state_key")): - return True - return False - - return [r for r in rows if app_service_interested(r)] - - rows = yield self.runInteraction("get_appservice_room_stream", f) + def __init__(self, db_conn, hs): + super(StreamWorkerStore, self).__init__(db_conn, hs) - ret = yield self._get_events( - [r["event_id"] for r in rows], - get_prev_content=True + events_max = self.get_room_max_stream_ordering() + event_cache_prefill, min_event_val = self._get_cache_dict( + db_conn, "events", + entity_column="room_id", + stream_column="stream_ordering", + max_value=events_max, + ) + self._events_stream_cache = StreamChangeCache( + "EventsRoomStreamChangeCache", min_event_val, + prefilled_cache=event_cache_prefill, + ) + self._membership_stream_cache = StreamChangeCache( + "MembershipStreamChangeCache", events_max, ) - self._set_before_and_after(ret, rows, topo_order=from_id is None) + self._stream_order_on_start = self.get_room_max_stream_ordering() - if rows: - key = "s%d" % max(r["stream_ordering"] for r in rows) - else: - # Assume we didn't get anything because there was nothing to - # get. - key = to_key + @abc.abstractmethod + def get_room_max_stream_ordering(self): + raise NotImplementedError() - defer.returnValue((ret, key)) + @abc.abstractmethod + def get_room_min_stream_ordering(self): + raise NotImplementedError() @defer.inlineCallbacks def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0, @@ -233,13 +202,14 @@ class StreamStore(SQLBaseStore): results = {} room_ids = list(room_ids) - for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)): - res = yield preserve_context_over_deferred(defer.gatherResults([ - preserve_fn(self.get_room_events_stream_for_room)( + for rm_ids in (room_ids[i:i + 20] for i in range(0, len(room_ids), 20)): + res = yield make_deferred_yieldable(defer.gatherResults([ + run_in_background( + self.get_room_events_stream_for_room, room_id, from_key, to_key, limit, order=order, ) for room_id in rm_ids - ])) + ], consumeErrors=True)) results.update(dict(zip(rm_ids, res))) defer.returnValue(results) @@ -261,54 +231,55 @@ class StreamStore(SQLBaseStore): @defer.inlineCallbacks def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0, order='DESC'): - # Note: If from_key is None then we return in topological order. This - # is because in that case we're using this as a "get the last few messages - # in a room" function, rather than "get new messages since last sync" - if from_key is not None: - from_id = RoomStreamToken.parse_stream_token(from_key).stream - else: - from_id = None - to_id = RoomStreamToken.parse_stream_token(to_key).stream + """Get new room events in stream ordering since `from_key`. + + Args: + room_id (str) + from_key (str): Token from which no events are returned before + to_key (str): Token from which no events are returned after. (This + is typically the current stream token) + limit (int): Maximum number of events to return + order (str): Either "DESC" or "ASC". Determines which events are + returned when the result is limited. If "DESC" then the most + recent `limit` events are returned, otherwise returns the + oldest `limit` events. + + Returns: + Deferred[tuple[list[FrozenEvent], str]]: Returns the list of + events (in ascending order) and the token from the start of + the chunk of events returned. + """ if from_key == to_key: defer.returnValue(([], from_key)) - if from_id: - has_changed = yield self._events_stream_cache.has_entity_changed( - room_id, from_id - ) - - if not has_changed: - defer.returnValue(([], from_key)) + from_id = RoomStreamToken.parse_stream_token(from_key).stream + to_id = RoomStreamToken.parse_stream_token(to_key).stream - def f(txn): - if from_id is not None: - sql = ( - "SELECT event_id, stream_ordering FROM events WHERE" - " room_id = ?" - " AND not outlier" - " AND stream_ordering > ? AND stream_ordering <= ?" - " ORDER BY stream_ordering %s LIMIT ?" - ) % (order,) - txn.execute(sql, (room_id, from_id, to_id, limit)) - else: - sql = ( - "SELECT event_id, stream_ordering FROM events WHERE" - " room_id = ?" - " AND not outlier" - " AND stream_ordering <= ?" - " ORDER BY topological_ordering %s, stream_ordering %s LIMIT ?" - ) % (order, order,) - txn.execute(sql, (room_id, to_id, limit)) + has_changed = yield self._events_stream_cache.has_entity_changed( + room_id, from_id + ) - rows = self.cursor_to_dict(txn) + if not has_changed: + defer.returnValue(([], from_key)) + def f(txn): + sql = ( + "SELECT event_id, stream_ordering FROM events WHERE" + " room_id = ?" + " AND not outlier" + " AND stream_ordering > ? AND stream_ordering <= ?" + " ORDER BY stream_ordering %s LIMIT ?" + ) % (order,) + txn.execute(sql, (room_id, from_id, to_id, limit)) + + rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] return rows rows = yield self.runInteraction("get_room_events_stream_for_room", f) ret = yield self._get_events( - [r["event_id"] for r in rows], + [r.event_id for r in rows], get_prev_content=True ) @@ -318,7 +289,7 @@ class StreamStore(SQLBaseStore): ret.reverse() if rows: - key = "s%d" % min(r["stream_ordering"] for r in rows) + key = "s%d" % min(r.stream_ordering for r in rows) else: # Assume we didn't get anything because there was nothing to # get. @@ -328,10 +299,7 @@ class StreamStore(SQLBaseStore): @defer.inlineCallbacks def get_membership_changes_for_user(self, user_id, from_key, to_key): - if from_key is not None: - from_id = RoomStreamToken.parse_stream_token(from_key).stream - else: - from_id = None + from_id = RoomStreamToken.parse_stream_token(from_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream if from_key == to_key: @@ -345,34 +313,24 @@ class StreamStore(SQLBaseStore): defer.returnValue([]) def f(txn): - if from_id is not None: - sql = ( - "SELECT m.event_id, stream_ordering FROM events AS e," - " room_memberships AS m" - " WHERE e.event_id = m.event_id" - " AND m.user_id = ?" - " AND e.stream_ordering > ? AND e.stream_ordering <= ?" - " ORDER BY e.stream_ordering ASC" - ) - txn.execute(sql, (user_id, from_id, to_id,)) - else: - sql = ( - "SELECT m.event_id, stream_ordering FROM events AS e," - " room_memberships AS m" - " WHERE e.event_id = m.event_id" - " AND m.user_id = ?" - " AND stream_ordering <= ?" - " ORDER BY stream_ordering ASC" - ) - txn.execute(sql, (user_id, to_id,)) - rows = self.cursor_to_dict(txn) + sql = ( + "SELECT m.event_id, stream_ordering FROM events AS e," + " room_memberships AS m" + " WHERE e.event_id = m.event_id" + " AND m.user_id = ?" + " AND e.stream_ordering > ? AND e.stream_ordering <= ?" + " ORDER BY e.stream_ordering ASC" + ) + txn.execute(sql, (user_id, from_id, to_id,)) + + rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] return rows rows = yield self.runInteraction("get_membership_changes_for_user", f) ret = yield self._get_events( - [r["event_id"] for r in rows], + [r.event_id for r in rows], get_prev_content=True ) @@ -381,96 +339,28 @@ class StreamStore(SQLBaseStore): defer.returnValue(ret) @defer.inlineCallbacks - def paginate_room_events(self, room_id, from_key, to_key=None, - direction='b', limit=-1, event_filter=None): - # Tokens really represent positions between elements, but we use - # the convention of pointing to the event before the gap. Hence - # we have a bit of asymmetry when it comes to equalities. - args = [False, room_id] - if direction == 'b': - order = "DESC" - bounds = upper_bound( - RoomStreamToken.parse(from_key), self.database_engine - ) - if to_key: - bounds = "%s AND %s" % (bounds, lower_bound( - RoomStreamToken.parse(to_key), self.database_engine - )) - else: - order = "ASC" - bounds = lower_bound( - RoomStreamToken.parse(from_key), self.database_engine - ) - if to_key: - bounds = "%s AND %s" % (bounds, upper_bound( - RoomStreamToken.parse(to_key), self.database_engine - )) - - filter_clause, filter_args = filter_to_clause(event_filter) - - if filter_clause: - bounds += " AND " + filter_clause - args.extend(filter_args) - - if int(limit) > 0: - args.append(int(limit)) - limit_str = " LIMIT ?" - else: - limit_str = "" - - sql = ( - "SELECT * FROM events" - " WHERE outlier = ? AND room_id = ? AND %(bounds)s" - " ORDER BY topological_ordering %(order)s," - " stream_ordering %(order)s %(limit)s" - ) % { - "bounds": bounds, - "order": order, - "limit": limit_str - } - - def f(txn): - txn.execute(sql, args) - - rows = self.cursor_to_dict(txn) - - if rows: - topo = rows[-1]["topological_ordering"] - toke = rows[-1]["stream_ordering"] - if direction == 'b': - # Tokens are positions between events. - # This token points *after* the last event in the chunk. - # We need it to point to the event before it in the chunk - # when we are going backwards so we subtract one from the - # stream part. - toke -= 1 - next_token = str(RoomStreamToken(topo, toke)) - else: - # TODO (erikj): We should work out what to do here instead. - next_token = to_key if to_key else from_key + def get_recent_events_for_room(self, room_id, limit, end_token): + """Get the most recent events in the room in topological ordering. - return rows, next_token, - - rows, token = yield self.runInteraction("paginate_room_events", f) - - events = yield self._get_events( - [r["event_id"] for r in rows], - get_prev_content=True - ) - - self._set_before_and_after(events, rows) + Args: + room_id (str) + limit (int) + end_token (str): The stream token representing now. - defer.returnValue((events, token)) + Returns: + Deferred[tuple[list[FrozenEvent], str]]: Returns a list of + events and a token pointing to the start of the returned + events. + The events returned are in ascending order. + """ - @defer.inlineCallbacks - def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None): rows, token = yield self.get_recent_event_ids_for_room( - room_id, limit, end_token, from_token + room_id, limit, end_token, ) logger.debug("stream before") events = yield self._get_events( - [r["event_id"] for r in rows], + [r.event_id for r in rows], get_prev_content=True ) logger.debug("stream after") @@ -479,59 +369,62 @@ class StreamStore(SQLBaseStore): defer.returnValue((events, token)) - @cached(num_args=4) - def get_recent_event_ids_for_room(self, room_id, limit, end_token, from_token=None): - end_token = RoomStreamToken.parse_stream_token(end_token) + @defer.inlineCallbacks + def get_recent_event_ids_for_room(self, room_id, limit, end_token): + """Get the most recent events in the room in topological ordering. - if from_token is None: - sql = ( - "SELECT stream_ordering, topological_ordering, event_id" - " FROM events" - " WHERE room_id = ? AND stream_ordering <= ? AND outlier = ?" - " ORDER BY topological_ordering DESC, stream_ordering DESC" - " LIMIT ?" - ) - else: - from_token = RoomStreamToken.parse_stream_token(from_token) - sql = ( - "SELECT stream_ordering, topological_ordering, event_id" - " FROM events" - " WHERE room_id = ? AND stream_ordering > ?" - " AND stream_ordering <= ? AND outlier = ?" - " ORDER BY topological_ordering DESC, stream_ordering DESC" - " LIMIT ?" - ) + Args: + room_id (str) + limit (int) + end_token (str): The stream token representing now. - def get_recent_events_for_room_txn(txn): - if from_token is None: - txn.execute(sql, (room_id, end_token.stream, False, limit,)) - else: - txn.execute(sql, ( - room_id, from_token.stream, end_token.stream, False, limit - )) + Returns: + Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of + _EventDictReturn and a token pointing to the start of the returned + events. + The events returned are in ascending order. + """ + # Allow a zero limit here, and no-op. + if limit == 0: + defer.returnValue(([], end_token)) - rows = self.cursor_to_dict(txn) + end_token = RoomStreamToken.parse(end_token) - rows.reverse() # As we selected with reverse ordering + rows, token = yield self.runInteraction( + "get_recent_event_ids_for_room", self._paginate_room_events_txn, + room_id, from_token=end_token, limit=limit, + ) - if rows: - # Tokens are positions between events. - # This token points *after* the last event in the chunk. - # We need it to point to the event before it in the chunk - # since we are going backwards so we subtract one from the - # stream part. - topo = rows[0]["topological_ordering"] - toke = rows[0]["stream_ordering"] - 1 - start_token = str(RoomStreamToken(topo, toke)) + # We want to return the results in ascending order. + rows.reverse() - token = (start_token, str(end_token)) - else: - token = (str(end_token), str(end_token)) + defer.returnValue((rows, token)) + + def get_room_event_after_stream_ordering(self, room_id, stream_ordering): + """Gets details of the first event in a room at or after a stream ordering + + Args: + room_id (str): + stream_ordering (int): - return rows, token + Returns: + Deferred[(int, int, str)]: + (stream ordering, topological ordering, event_id) + """ + def _f(txn): + sql = ( + "SELECT stream_ordering, topological_ordering, event_id" + " FROM events" + " WHERE room_id = ? AND stream_ordering >= ?" + " AND NOT outlier" + " ORDER BY stream_ordering" + " LIMIT 1" + ) + txn.execute(sql, (room_id, stream_ordering, )) + return txn.fetchone() return self.runInteraction( - "get_recent_events_for_room", get_recent_events_for_room_txn + "get_room_event_after_stream_ordering", _f, ) @defer.inlineCallbacks @@ -542,7 +435,7 @@ class StreamStore(SQLBaseStore): `room_id` causes it to return the current room specific topological token. """ - token = yield self._stream_id_gen.get_current_token() + token = yield self.get_room_max_stream_ordering() if room_id is None: defer.returnValue("s%d" % (token,)) else: @@ -552,12 +445,6 @@ class StreamStore(SQLBaseStore): ) defer.returnValue("t%d-%d" % (topo, token)) - def get_room_max_stream_ordering(self): - return self._stream_id_gen.get_current_token() - - def get_room_min_stream_ordering(self): - return self._backfill_id_gen.get_current_token() - def get_stream_token_for_event(self, event_id): """The stream token for an event Args: @@ -615,10 +502,20 @@ class StreamStore(SQLBaseStore): @staticmethod def _set_before_and_after(events, rows, topo_order=True): + """Inserts ordering information to events' internal metadata from + the DB rows. + + Args: + events (list[FrozenEvent]) + rows (list[_EventDictReturn]) + topo_order (bool): Whether the events were ordered topologically + or by stream ordering. If true then all rows should have a non + null topological_ordering. + """ for event, row in zip(events, rows): - stream = row["stream_ordering"] - if topo_order: - topo = event.depth + stream = row.stream_ordering + if topo_order and row.topological_ordering: + topo = row.topological_ordering else: topo = None internal = event.internal_metadata @@ -690,87 +587,27 @@ class StreamStore(SQLBaseStore): retcols=["stream_ordering", "topological_ordering"], ) - token = RoomStreamToken( - results["topological_ordering"], + # Paginating backwards includes the event at the token, but paginating + # forward doesn't. + before_token = RoomStreamToken( + results["topological_ordering"] - 1, results["stream_ordering"], ) - if isinstance(self.database_engine, Sqlite3Engine): - # SQLite3 doesn't optimise ``(x < a) OR (x = a AND y < b)`` - # So we give pass it to SQLite3 as the UNION ALL of the two queries. - - query_before = ( - "SELECT topological_ordering, stream_ordering, event_id FROM events" - " WHERE room_id = ? AND topological_ordering < ?" - " UNION ALL" - " SELECT topological_ordering, stream_ordering, event_id FROM events" - " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering < ?" - " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?" - ) - before_args = ( - room_id, token.topological, - room_id, token.topological, token.stream, - before_limit, - ) - - query_after = ( - "SELECT topological_ordering, stream_ordering, event_id FROM events" - " WHERE room_id = ? AND topological_ordering > ?" - " UNION ALL" - " SELECT topological_ordering, stream_ordering, event_id FROM events" - " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering > ?" - " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?" - ) - after_args = ( - room_id, token.topological, - room_id, token.topological, token.stream, - after_limit, - ) - else: - query_before = ( - "SELECT topological_ordering, stream_ordering, event_id FROM events" - " WHERE room_id = ? AND %s" - " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?" - ) % (upper_bound(token, self.database_engine, inclusive=False),) - - before_args = (room_id, before_limit) - - query_after = ( - "SELECT topological_ordering, stream_ordering, event_id FROM events" - " WHERE room_id = ? AND %s" - " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?" - ) % (lower_bound(token, self.database_engine, inclusive=False),) - - after_args = (room_id, after_limit) - - txn.execute(query_before, before_args) - - rows = self.cursor_to_dict(txn) - events_before = [r["event_id"] for r in rows] - - if rows: - start_token = str(RoomStreamToken( - rows[0]["topological_ordering"], - rows[0]["stream_ordering"] - 1, - )) - else: - start_token = str(RoomStreamToken( - token.topological, - token.stream - 1, - )) - - txn.execute(query_after, after_args) + after_token = RoomStreamToken( + results["topological_ordering"], + results["stream_ordering"], + ) - rows = self.cursor_to_dict(txn) - events_after = [r["event_id"] for r in rows] + rows, start_token = self._paginate_room_events_txn( + txn, room_id, before_token, direction='b', limit=before_limit, + ) + events_before = [r.event_id for r in rows] - if rows: - end_token = str(RoomStreamToken( - rows[-1]["topological_ordering"], - rows[-1]["stream_ordering"], - )) - else: - end_token = str(token) + rows, end_token = self._paginate_room_events_txn( + txn, room_id, after_token, direction='f', limit=after_limit, + ) + events_after = [r.event_id for r in rows] return { "before": { @@ -832,3 +669,139 @@ class StreamStore(SQLBaseStore): def has_room_changed_since(self, room_id, stream_id): return self._events_stream_cache.has_entity_changed(room_id, stream_id) + + def _paginate_room_events_txn(self, txn, room_id, from_token, to_token=None, + direction='b', limit=-1, event_filter=None): + """Returns list of events before or after a given token. + + Args: + txn + room_id (str) + from_token (RoomStreamToken): The token used to stream from + to_token (RoomStreamToken|None): A token which if given limits the + results to only those before + direction(char): Either 'b' or 'f' to indicate whether we are + paginating forwards or backwards from `from_key`. + limit (int): The maximum number of events to return. + event_filter (Filter|None): If provided filters the events to + those that match the filter. + + Returns: + Deferred[tuple[list[_EventDictReturn], str]]: Returns the results + as a list of _EventDictReturn and a token that points to the end + of the result set. + """ + + assert int(limit) >= 0 + + # Tokens really represent positions between elements, but we use + # the convention of pointing to the event before the gap. Hence + # we have a bit of asymmetry when it comes to equalities. + args = [False, room_id] + if direction == 'b': + order = "DESC" + bounds = upper_bound( + from_token, self.database_engine + ) + if to_token: + bounds = "%s AND %s" % (bounds, lower_bound( + to_token, self.database_engine + )) + else: + order = "ASC" + bounds = lower_bound( + from_token, self.database_engine + ) + if to_token: + bounds = "%s AND %s" % (bounds, upper_bound( + to_token, self.database_engine + )) + + filter_clause, filter_args = filter_to_clause(event_filter) + + if filter_clause: + bounds += " AND " + filter_clause + args.extend(filter_args) + + args.append(int(limit)) + + sql = ( + "SELECT event_id, topological_ordering, stream_ordering" + " FROM events" + " WHERE outlier = ? AND room_id = ? AND %(bounds)s" + " ORDER BY topological_ordering %(order)s," + " stream_ordering %(order)s LIMIT ?" + ) % { + "bounds": bounds, + "order": order, + } + + txn.execute(sql, args) + + rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn] + + if rows: + topo = rows[-1].topological_ordering + toke = rows[-1].stream_ordering + if direction == 'b': + # Tokens are positions between events. + # This token points *after* the last event in the chunk. + # We need it to point to the event before it in the chunk + # when we are going backwards so we subtract one from the + # stream part. + toke -= 1 + next_token = RoomStreamToken(topo, toke) + else: + # TODO (erikj): We should work out what to do here instead. + next_token = to_token if to_token else from_token + + return rows, str(next_token), + + @defer.inlineCallbacks + def paginate_room_events(self, room_id, from_key, to_key=None, + direction='b', limit=-1, event_filter=None): + """Returns list of events before or after a given token. + + Args: + room_id (str) + from_key (str): The token used to stream from + to_key (str|None): A token which if given limits the results to + only those before + direction(char): Either 'b' or 'f' to indicate whether we are + paginating forwards or backwards from `from_key`. + limit (int): The maximum number of events to return. Zero or less + means no limit. + event_filter (Filter|None): If provided filters the events to + those that match the filter. + + Returns: + tuple[list[dict], str]: Returns the results as a list of dicts and + a token that points to the end of the result set. The dicts have + the keys "event_id", "topological_ordering" and "stream_orderign". + """ + + from_key = RoomStreamToken.parse(from_key) + if to_key: + to_key = RoomStreamToken.parse(to_key) + + rows, token = yield self.runInteraction( + "paginate_room_events", self._paginate_room_events_txn, + room_id, from_key, to_key, direction, limit, event_filter, + ) + + events = yield self._get_events( + [r.event_id for r in rows], + get_prev_content=True + ) + + self._set_before_and_after(events, rows) + + defer.returnValue((events, token)) + + +class StreamStore(StreamWorkerStore): + def get_room_max_stream_ordering(self): + return self._stream_id_gen.get_current_token() + + def get_room_min_stream_ordering(self): + return self._backfill_id_gen.get_current_token() diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index bff73f3f04..0f657b2bd3 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,25 +14,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached -from twisted.internet import defer - -import ujson as json import logging -logger = logging.getLogger(__name__) +from six.moves import range +from canonicaljson import json -class TagsStore(SQLBaseStore): - def get_max_account_data_stream_id(self): - """Get the current max stream id for the private user data stream +from twisted.internet import defer + +from synapse.storage.account_data import AccountDataWorkerStore +from synapse.util.caches.descriptors import cached + +logger = logging.getLogger(__name__) - Returns: - A deferred int. - """ - return self._account_data_id_gen.get_current_token() +class TagsWorkerStore(AccountDataWorkerStore): @cached() def get_tags_for_user(self, user_id): """Get all the tags for a user. @@ -104,7 +101,7 @@ class TagsStore(SQLBaseStore): batch_size = 50 results = [] - for i in xrange(0, len(tag_ids), batch_size): + for i in range(0, len(tag_ids), batch_size): tags = yield self.runInteraction( "get_all_updated_tag_content", get_tag_content, @@ -170,6 +167,8 @@ class TagsStore(SQLBaseStore): row["tag"]: json.loads(row["content"]) for row in rows }) + +class TagsStore(TagsWorkerStore): @defer.inlineCallbacks def add_tag_to_room(self, user_id, room_id, tag, content): """Add a tag to a room for a user. diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 809fdd311f..c3bc94f56d 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -13,17 +13,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached +import logging +from collections import namedtuple + +import six + +from canonicaljson import encode_canonical_json, json from twisted.internet import defer -from canonicaljson import encode_canonical_json +from synapse.util.caches.descriptors import cached -from collections import namedtuple +from ._base import SQLBaseStore -import logging -import ujson as json +# py2 sqlite has buffer hardcoded as only binary type, so we must use it, +# despite being deprecated and removed in favor of memoryview +if six.PY2: + db_binary_type = buffer +else: + db_binary_type = memoryview logger = logging.getLogger(__name__) @@ -46,8 +54,8 @@ class TransactionStore(SQLBaseStore): """A collection of queries for handling PDUs. """ - def __init__(self, hs): - super(TransactionStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(TransactionStore, self).__init__(db_conn, hs) self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000) @@ -110,7 +118,7 @@ class TransactionStore(SQLBaseStore): "transaction_id": transaction_id, "origin": origin, "response_code": code, - "response_json": buffer(encode_canonical_json(response_dict)), + "response_json": db_binary_type(encode_canonical_json(response_dict)), "ts": self._clock.time_msec(), }, or_ignore=True, diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py index 2a4db3f03c..a8781b0e5d 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/user_directory.py @@ -13,17 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +import logging +import re -from ._base import SQLBaseStore +from six import iteritems + +from twisted.internet import defer -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.api.constants import EventTypes, JoinRules from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import get_domain_from_id, get_localpart_from_id +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -import re -import logging +from ._base import SQLBaseStore logger = logging.getLogger(__name__) @@ -63,7 +65,7 @@ class UserDirectoryStore(SQLBaseStore): user_ids (list(str)): Users to add """ yield self._simple_insert_many( - table="users_in_pubic_room", + table="users_in_public_rooms", values=[ { "user_id": user_id, @@ -100,7 +102,7 @@ class UserDirectoryStore(SQLBaseStore): user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id), profile.display_name, ) - for user_id, profile in users_with_profile.iteritems() + for user_id, profile in iteritems(users_with_profile) ) elif isinstance(self.database_engine, Sqlite3Engine): sql = """ @@ -112,7 +114,7 @@ class UserDirectoryStore(SQLBaseStore): user_id, "%s %s" % (user_id, p.display_name,) if p.display_name else user_id ) - for user_id, p in users_with_profile.iteritems() + for user_id, p in iteritems(users_with_profile) ) else: # This should be unreachable. @@ -130,7 +132,7 @@ class UserDirectoryStore(SQLBaseStore): "display_name": profile.display_name, "avatar_url": profile.avatar_url, } - for user_id, profile in users_with_profile.iteritems() + for user_id, profile in iteritems(users_with_profile) ] ) for user_id in users_with_profile: @@ -164,7 +166,7 @@ class UserDirectoryStore(SQLBaseStore): ) if isinstance(self.database_engine, PostgresEngine): - # We weight the loclpart most highly, then display name and finally + # We weight the localpart most highly, then display name and finally # server name if new_entry: sql = """ @@ -219,7 +221,7 @@ class UserDirectoryStore(SQLBaseStore): @defer.inlineCallbacks def update_user_in_public_user_list(self, user_id, room_id): yield self._simple_update_one( - table="users_in_pubic_room", + table="users_in_public_rooms", keyvalues={"user_id": user_id}, updatevalues={"room_id": room_id}, desc="update_user_in_public_user_list", @@ -240,7 +242,7 @@ class UserDirectoryStore(SQLBaseStore): ) self._simple_delete_txn( txn, - table="users_in_pubic_room", + table="users_in_public_rooms", keyvalues={"user_id": user_id}, ) txn.call_after( @@ -256,18 +258,18 @@ class UserDirectoryStore(SQLBaseStore): @defer.inlineCallbacks def remove_from_user_in_public_room(self, user_id): yield self._simple_delete( - table="users_in_pubic_room", + table="users_in_public_rooms", keyvalues={"user_id": user_id}, desc="remove_from_user_in_public_room", ) self.get_user_in_public_room.invalidate((user_id,)) def get_users_in_public_due_to_room(self, room_id): - """Get all user_ids that are in the room directory becuase they're + """Get all user_ids that are in the room directory because they're in the given room_id """ return self._simple_select_onecol( - table="users_in_pubic_room", + table="users_in_public_rooms", keyvalues={"room_id": room_id}, retcol="user_id", desc="get_users_in_public_due_to_room", @@ -275,7 +277,7 @@ class UserDirectoryStore(SQLBaseStore): @defer.inlineCallbacks def get_users_in_dir_due_to_room(self, room_id): - """Get all user_ids that are in the room directory becuase they're + """Get all user_ids that are in the room directory because they're in the given room_id """ user_ids_dir = yield self._simple_select_onecol( @@ -286,7 +288,7 @@ class UserDirectoryStore(SQLBaseStore): ) user_ids_pub = yield self._simple_select_onecol( - table="users_in_pubic_room", + table="users_in_public_rooms", keyvalues={"room_id": room_id}, retcol="user_id", desc="get_users_in_dir_due_to_room", @@ -317,6 +319,16 @@ class UserDirectoryStore(SQLBaseStore): rows = yield self._execute("get_all_rooms", None, sql) defer.returnValue([room_id for room_id, in rows]) + @defer.inlineCallbacks + def get_all_local_users(self): + """Get all local users + """ + sql = """ + SELECT name FROM users + """ + rows = yield self._execute("get_all_local_users", None, sql) + defer.returnValue([name for name, in rows]) + def add_users_who_share_room(self, room_id, share_private, user_id_tuples): """Insert entries into the users_who_share_rooms table. The first user should be a local user. @@ -514,7 +526,7 @@ class UserDirectoryStore(SQLBaseStore): def _delete_all_from_user_dir_txn(txn): txn.execute("DELETE FROM user_directory") txn.execute("DELETE FROM user_directory_search") - txn.execute("DELETE FROM users_in_pubic_room") + txn.execute("DELETE FROM users_in_public_rooms") txn.execute("DELETE FROM users_who_share_rooms") txn.call_after(self.get_user_in_directory.invalidate_all) txn.call_after(self.get_user_in_public_room.invalidate_all) @@ -537,7 +549,7 @@ class UserDirectoryStore(SQLBaseStore): @cached() def get_user_in_public_room(self, user_id): return self._simple_select_one( - table="users_in_pubic_room", + table="users_in_public_rooms", keyvalues={"user_id": user_id}, retcols=("room_id",), allow_none=True, @@ -629,6 +641,25 @@ class UserDirectoryStore(SQLBaseStore): ] } """ + + if self.hs.config.user_directory_search_all_users: + # make s.user_id null to keep the ordering algorithm happy + join_clause = """ + CROSS JOIN (SELECT NULL as user_id) AS s + """ + join_args = () + where_clause = "1=1" + else: + join_clause = """ + LEFT JOIN users_in_public_rooms AS p USING (user_id) + LEFT JOIN ( + SELECT other_user_id AS user_id FROM users_who_share_rooms + WHERE user_id = ? AND share_private + ) AS s USING (user_id) + """ + join_args = (user_id,) + where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)" + if isinstance(self.database_engine, PostgresEngine): full_query, exact_query, prefix_query = _parse_query_postgres(search_term) @@ -638,16 +669,12 @@ class UserDirectoryStore(SQLBaseStore): # The array of numbers are the weights for the various part of the # search: (domain, _, display name, localpart) sql = """ - SELECT d.user_id, display_name, avatar_url + SELECT d.user_id AS user_id, display_name, avatar_url FROM user_directory_search INNER JOIN user_directory AS d USING (user_id) - LEFT JOIN users_in_pubic_room AS p USING (user_id) - LEFT JOIN ( - SELECT other_user_id AS user_id FROM users_who_share_rooms - WHERE user_id = ? AND share_private - ) AS s USING (user_id) + %s WHERE - (s.user_id IS NOT NULL OR p.user_id IS NOT NULL) + %s AND vector @@ to_tsquery('english', ?) ORDER BY (CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END) @@ -671,30 +698,26 @@ class UserDirectoryStore(SQLBaseStore): display_name IS NULL, avatar_url IS NULL LIMIT ? - """ - args = (user_id, full_query, exact_query, prefix_query, limit + 1,) + """ % (join_clause, where_clause) + args = join_args + (full_query, exact_query, prefix_query, limit + 1,) elif isinstance(self.database_engine, Sqlite3Engine): search_query = _parse_query_sqlite(search_term) sql = """ - SELECT d.user_id, display_name, avatar_url + SELECT d.user_id AS user_id, display_name, avatar_url FROM user_directory_search INNER JOIN user_directory AS d USING (user_id) - LEFT JOIN users_in_pubic_room AS p USING (user_id) - LEFT JOIN ( - SELECT other_user_id AS user_id FROM users_who_share_rooms - WHERE user_id = ? AND share_private - ) AS s USING (user_id) + %s WHERE - (s.user_id IS NOT NULL OR p.user_id IS NOT NULL) + %s AND value MATCH ? ORDER BY rank(matchinfo(user_directory_search)) DESC, display_name IS NULL, avatar_url IS NULL LIMIT ? - """ - args = (user_id, search_query, limit + 1) + """ % (join_clause, where_clause) + args = join_args + (search_query, limit + 1) else: # This should be unreachable. raise Exception("Unrecognized database engine") @@ -723,7 +746,7 @@ def _parse_query_sqlite(search_term): # Pull out the individual words, discarding any non-word characters. results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) - return " & ".join("(%s* | %s)" % (result, result,) for result in results) + return " & ".join("(%s* OR %s)" % (result, result,) for result in results) def _parse_query_postgres(search_term): diff --git a/synapse/storage/user_erasure_store.py b/synapse/storage/user_erasure_store.py new file mode 100644 index 0000000000..be013f4427 --- /dev/null +++ b/synapse/storage/user_erasure_store.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import operator + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore +from synapse.util.caches.descriptors import cached, cachedList + + +class UserErasureWorkerStore(SQLBaseStore): + @cached() + def is_user_erased(self, user_id): + """ + Check if the given user id has requested erasure + + Args: + user_id (str): full user id to check + + Returns: + Deferred[bool]: True if the user has requested erasure + """ + return self._simple_select_onecol( + table="erased_users", + keyvalues={"user_id": user_id}, + retcol="1", + desc="is_user_erased", + ).addCallback(operator.truth) + + @cachedList( + cached_method_name="is_user_erased", + list_name="user_ids", + inlineCallbacks=True, + ) + def are_users_erased(self, user_ids): + """ + Checks which users in a list have requested erasure + + Args: + user_ids (iterable[str]): full user id to check + + Returns: + Deferred[dict[str, bool]]: + for each user, whether the user has requested erasure. + """ + # this serves the dual purpose of (a) making sure we can do len and + # iterate it multiple times, and (b) avoiding duplicates. + user_ids = tuple(set(user_ids)) + + def _get_erased_users(txn): + txn.execute( + "SELECT user_id FROM erased_users WHERE user_id IN (%s)" % ( + ",".join("?" * len(user_ids)) + ), + user_ids, + ) + return set(r[0] for r in txn) + + erased_users = yield self.runInteraction( + "are_users_erased", _get_erased_users, + ) + res = dict((u, u in erased_users) for u in user_ids) + defer.returnValue(res) + + +class UserErasureStore(UserErasureWorkerStore): + def mark_user_erased(self, user_id): + """Indicate that user_id wishes their message history to be erased. + + Args: + user_id (str): full user_id to be erased + """ + def f(txn): + # first check if they are already in the list + txn.execute( + "SELECT 1 FROM erased_users WHERE user_id = ?", + (user_id, ) + ) + if txn.fetchone(): + return + + # they are not already there: do the insert. + txn.execute( + "INSERT INTO erased_users (user_id) VALUES (?)", + (user_id, ) + ) + + self._invalidate_cache_and_stream( + txn, self.is_user_erased, (user_id,) + ) + return self.runInteraction("mark_user_erased", f) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 95031dc9ec..d6160d5e4d 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -13,9 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import deque import contextlib import threading +from collections import deque class IdGenerator(object): diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 4f089bfb94..451e4fa441 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.api.errors import SynapseError -from synapse.types import StreamToken - import logging +from synapse.api.errors import SynapseError +from synapse.http.servlet import parse_integer, parse_string +from synapse.types import StreamToken logger = logging.getLogger(__name__) @@ -57,48 +57,33 @@ class PaginationConfig(object): @classmethod def from_request(cls, request, raise_invalid_params=True, default_limit=None): - def get_param(name, default=None): - lst = request.args.get(name, []) - if len(lst) > 1: - raise SynapseError( - 400, "%s must be specified only once" % (name,) - ) - elif len(lst) == 1: - return lst[0] - else: - return default - - direction = get_param("dir", 'f') - if direction not in ['f', 'b']: - raise SynapseError(400, "'dir' parameter is invalid.") - - from_tok = get_param("from") - to_tok = get_param("to") + direction = parse_string(request, "dir", default='f', allowed_values=['f', 'b']) + + from_tok = parse_string(request, "from") + to_tok = parse_string(request, "to") try: if from_tok == "END": from_tok = None # For backwards compat. elif from_tok: from_tok = StreamToken.from_string(from_tok) - except: + except Exception: raise SynapseError(400, "'from' paramater is invalid") try: if to_tok: to_tok = StreamToken.from_string(to_tok) - except: + except Exception: raise SynapseError(400, "'to' paramater is invalid") - limit = get_param("limit", None) - if limit is not None and not limit.isdigit(): - raise SynapseError(400, "'limit' parameter must be an integer.") + limit = parse_integer(request, "limit", default=default_limit) - if limit is None: - limit = default_limit + if limit and limit < 0: + raise SynapseError(400, "Limit must be 0 or above") try: return PaginationConfig(from_tok, to_tok, direction, limit) - except: + except Exception: logger.exception("Failed to create pagination config") raise SynapseError(400, "Invalid request.") diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 91a59b0bae..e5220132a3 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -15,13 +15,12 @@ from twisted.internet import defer -from synapse.types import StreamToken - +from synapse.handlers.account_data import AccountDataEventSource from synapse.handlers.presence import PresenceEventSource +from synapse.handlers.receipts import ReceiptEventSource from synapse.handlers.room import RoomEventSource from synapse.handlers.typing import TypingNotificationEventSource -from synapse.handlers.receipts import ReceiptEventSource -from synapse.handlers.account_data import AccountDataEventSource +from synapse.types import StreamToken class EventSources(object): @@ -45,6 +44,7 @@ class EventSources(object): push_rules_key, _ = self.store.get_push_rules_stream_token() to_device_key = self.store.get_to_device_stream_token() device_list_key = self.store.get_device_stream_token() + groups_key = self.store.get_group_stream_token() token = StreamToken( room_key=( @@ -65,6 +65,7 @@ class EventSources(object): push_rules_key=push_rules_key, to_device_key=to_device_key, device_list_key=device_list_key, + groups_key=groups_key, ) defer.returnValue(token) @@ -73,6 +74,7 @@ class EventSources(object): push_rules_key, _ = self.store.get_push_rules_stream_token() to_device_key = self.store.get_to_device_stream_token() device_list_key = self.store.get_device_stream_token() + groups_key = self.store.get_group_stream_token() token = StreamToken( room_key=( @@ -93,5 +95,6 @@ class EventSources(object): push_rules_key=push_rules_key, to_device_key=to_device_key, device_list_key=device_list_key, + groups_key=groups_key, ) defer.returnValue(token) diff --git a/synapse/types.py b/synapse/types.py index 111948540d..08f058f714 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -12,26 +12,65 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import string +from collections import namedtuple from synapse.api.errors import SynapseError -from collections import namedtuple - -Requester = namedtuple("Requester", [ +class Requester(namedtuple("Requester", [ "user", "access_token_id", "is_guest", "device_id", "app_service", -]) -""" -Represents the user making a request +])): + """ + Represents the user making a request -Attributes: - user (UserID): id of the user making the request - access_token_id (int|None): *ID* of the access token used for this - request, or None if it came via the appservice API or similar - is_guest (bool): True if the user making this request is a guest user - device_id (str|None): device_id which was set at authentication time - app_service (ApplicationService|None): the AS requesting on behalf of the user -""" + Attributes: + user (UserID): id of the user making the request + access_token_id (int|None): *ID* of the access token used for this + request, or None if it came via the appservice API or similar + is_guest (bool): True if the user making this request is a guest user + device_id (str|None): device_id which was set at authentication time + app_service (ApplicationService|None): the AS requesting on behalf of the user + """ + + def serialize(self): + """Converts self to a type that can be serialized as JSON, and then + deserialized by `deserialize` + + Returns: + dict + """ + return { + "user_id": self.user.to_string(), + "access_token_id": self.access_token_id, + "is_guest": self.is_guest, + "device_id": self.device_id, + "app_server_id": self.app_service.id if self.app_service else None, + } + + @staticmethod + def deserialize(store, input): + """Converts a dict that was produced by `serialize` back into a + Requester. + + Args: + store (DataStore): Used to convert AS ID to AS object + input (dict): A dict produced by `serialize` + + Returns: + Requester + """ + appservice = None + if input["app_server_id"]: + appservice = store.get_app_service_by_id(input["app_server_id"]) + + return Requester( + user=UserID.from_string(input["user_id"]), + access_token_id=input["access_token_id"], + is_guest=input["is_guest"], + device_id=input["device_id"], + app_service=appservice, + ) def create_requester(user_id, access_token_id=None, is_guest=False, @@ -126,14 +165,10 @@ class DomainSpecificString( try: cls.from_string(s) return True - except: + except Exception: return False - __str__ = to_string - - @classmethod - def create(cls, localpart, domain,): - return cls(localpart=localpart, domain=domain) + __repr__ = to_string class UserID(DomainSpecificString): @@ -156,6 +191,43 @@ class EventID(DomainSpecificString): SIGIL = "$" +class GroupID(DomainSpecificString): + """Structure representing a group ID.""" + SIGIL = "+" + + @classmethod + def from_string(cls, s): + group_id = super(GroupID, cls).from_string(s) + if not group_id.localpart: + raise SynapseError( + 400, + "Group ID cannot be empty", + ) + + if contains_invalid_mxid_characters(group_id.localpart): + raise SynapseError( + 400, + "Group ID can only contain characters a-z, 0-9, or '=_-./'", + ) + + return group_id + + +mxid_localpart_allowed_characters = set("_-./=" + string.ascii_lowercase + string.digits) + + +def contains_invalid_mxid_characters(localpart): + """Check for characters not allowed in an mxid or groupid localpart + + Args: + localpart (basestring): the localpart to be checked + + Returns: + bool: True if there are any naughty characters + """ + return any(c not in mxid_localpart_allowed_characters for c in localpart) + + class StreamToken( namedtuple("Token", ( "room_key", @@ -166,6 +238,7 @@ class StreamToken( "push_rules_key", "to_device_key", "device_list_key", + "groups_key", )) ): _SEPARATOR = "_" @@ -178,7 +251,7 @@ class StreamToken( # i.e. old token from before receipt_key keys.append("0") return cls(*keys) - except: + except Exception: raise SynapseError(400, "Invalid Token") def to_string(self): @@ -204,6 +277,7 @@ class StreamToken( or (int(other.push_rules_key) < int(self.push_rules_key)) or (int(other.to_device_key) < int(self.to_device_key)) or (int(other.device_list_key) < int(self.device_list_key)) + or (int(other.groups_key) < int(self.groups_key)) ) def copy_and_advance(self, key, new_value): @@ -263,7 +337,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): if string[0] == 't': parts = string[1:].split('-', 1) return cls(topological=int(parts[0]), stream=int(parts[1])) - except: + except Exception: pass raise SynapseError(400, "Invalid token %r" % (string,)) @@ -272,7 +346,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): try: if string[0] == 's': return cls(topological=None, stream=int(string[1:])) - except: + except Exception: pass raise SynapseError(400, "Invalid token %r" % (string,)) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 2a2360ab5d..680ea928c7 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -13,20 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.api.errors import SynapseError -from synapse.util.logcontext import PreserveLoggingContext - -from twisted.internet import defer, reactor, task - -import time import logging +from itertools import islice -logger = logging.getLogger(__name__) +import attr +from twisted.internet import defer, task -class DeferredTimedOutError(SynapseError): - def __init__(self): - super(DeferredTimedOutError, self).__init__(504, "Timed out") +from synapse.util.logcontext import PreserveLoggingContext + +logger = logging.getLogger(__name__) def unwrapFirstError(failure): @@ -35,16 +31,27 @@ def unwrapFirstError(failure): return failure.value.subFailure +@attr.s class Clock(object): - """A small utility that obtains current time-of-day so that time may be - mocked during unit-tests. + """ + A Clock wraps a Twisted reactor and provides utilities on top of it. - TODO(paul): Also move the sleep() functionality into it + Args: + reactor: The Twisted reactor to use. """ + _reactor = attr.ib() + + @defer.inlineCallbacks + def sleep(self, seconds): + d = defer.Deferred() + with PreserveLoggingContext(): + self._reactor.callLater(seconds, d.callback, seconds) + res = yield d + defer.returnValue(res) def time(self): """Returns the current system time in seconds since epoch.""" - return time.time() + return self._reactor.seconds() def time_msec(self): """Returns the current system time in miliseconds since epoch.""" @@ -59,9 +66,10 @@ class Clock(object): f(function): The function to call repeatedly. msec(float): How long to wait between calls in milliseconds. """ - l = task.LoopingCall(f) - l.start(msec / 1000.0, now=False) - return l + call = task.LoopingCall(f) + call.clock = self._reactor + call.start(msec / 1000.0, now=False) + return call def call_later(self, delay, callback, *args, **kwargs): """Call something later @@ -77,61 +85,27 @@ class Clock(object): callback(*args, **kwargs) with PreserveLoggingContext(): - return reactor.callLater(delay, wrapped_callback, *args, **kwargs) + return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs) def cancel_call_later(self, timer, ignore_errs=False): try: timer.cancel() - except: + except Exception: if not ignore_errs: raise - def time_bound_deferred(self, given_deferred, time_out): - if given_deferred.called: - return given_deferred - - ret_deferred = defer.Deferred() - def timed_out_fn(): - e = DeferredTimedOutError() +def batch_iter(iterable, size): + """batch an iterable up into tuples with a maximum size - try: - ret_deferred.errback(e) - except: - pass + Args: + iterable (iterable): the iterable to slice + size (int): the maximum batch size - try: - given_deferred.cancel() - except: - pass - - timer = None - - def cancel(res): - try: - self.cancel_call_later(timer) - except: - pass - return res - - ret_deferred.addBoth(cancel) - - def success(res): - try: - ret_deferred.callback(res) - except: - pass - - return res - - def err(res): - try: - ret_deferred.errback(res) - except: - pass - - given_deferred.addCallbacks(callback=success, errback=err) - - timer = self.call_later(time_out, timed_out_fn) - - return ret_deferred + Returns: + an iterator over the chunks + """ + # make sure we can deal with iterables like lists too + sourceiter = iter(iterable) + # call islice until it returns an empty tuple + return iter(lambda: tuple(islice(sourceiter, size)), ()) diff --git a/synapse/util/async.py b/synapse/util/async.py index 1453faf0ef..a7094e2fb4 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,38 +13,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import collections +import logging +from contextlib import contextmanager + +from six.moves import range +from twisted.internet import defer +from twisted.internet.defer import CancelledError +from twisted.python import failure -from twisted.internet import defer, reactor +from synapse.util import Clock, logcontext, unwrapFirstError from .logcontext import ( - PreserveLoggingContext, preserve_fn, preserve_context_over_deferred, + PreserveLoggingContext, + make_deferred_yieldable, + run_in_background, ) -from synapse.util import unwrapFirstError - -from contextlib import contextmanager - -import logging logger = logging.getLogger(__name__) -@defer.inlineCallbacks -def sleep(seconds): - d = defer.Deferred() - with PreserveLoggingContext(): - reactor.callLater(seconds, d.callback, seconds) - res = yield d - defer.returnValue(res) - - -def run_on_reactor(): - """ This will cause the rest of the function to be invoked upon the next - iteration of the main loop - """ - return sleep(0) - - class ObservableDeferred(object): """Wraps a deferred object so that we can add observer deferreds. These observer deferreds do not affect the callback chain of the original @@ -53,6 +43,11 @@ class ObservableDeferred(object): Cancelling or otherwise resolving an observer will not affect the original ObservableDeferred. + + NB that it does not attempt to do anything with logcontexts; in general + you should probably make_deferred_yieldable the deferreds + returned by `observe`, and ensure that the original deferred runs its + callbacks in the sentinel logcontext. """ __slots__ = ["_deferred", "_observers", "_result"] @@ -68,7 +63,7 @@ class ObservableDeferred(object): try: # TODO: Handle errors here. self._observers.pop().callback(r) - except: + except Exception: pass return r @@ -78,7 +73,7 @@ class ObservableDeferred(object): try: # TODO: Handle errors here. self._observers.pop().errback(f) - except: + except Exception: pass if consumeErrors: @@ -151,77 +146,19 @@ def concurrently_execute(func, args, limit): def _concurrently_execute_inner(): try: while True: - yield func(it.next()) + yield func(next(it)) except StopIteration: pass - return preserve_context_over_deferred(defer.gatherResults([ - preserve_fn(_concurrently_execute_inner)() - for _ in xrange(limit) + return logcontext.make_deferred_yieldable(defer.gatherResults([ + run_in_background(_concurrently_execute_inner) + for _ in range(limit) ], consumeErrors=True)).addErrback(unwrapFirstError) class Linearizer(object): - """Linearizes access to resources based on a key. Useful to ensure only one - thing is happening at a time on a given resource. - - Example: - - with (yield linearizer.queue("test_key")): - # do some work. - - """ - def __init__(self, name=None): - if name is None: - self.name = id(self) - else: - self.name = name - self.key_to_defer = {} - - @defer.inlineCallbacks - def queue(self, key): - # If there is already a deferred in the queue, we pull it out so that - # we can wait on it later. - # Then we replace it with a deferred that we resolve *after* the - # context manager has exited. - # We only return the context manager after the previous deferred has - # resolved. - # This all has the net effect of creating a chain of deferreds that - # wait for the previous deferred before starting their work. - current_defer = self.key_to_defer.get(key) - - new_defer = defer.Deferred() - self.key_to_defer[key] = new_defer - - if current_defer: - logger.info( - "Waiting to acquire linearizer lock %r for key %r", self.name, key - ) - try: - with PreserveLoggingContext(): - yield current_defer - except: - logger.exception("Unexpected exception in Linearizer") - - logger.info("Acquired linearizer lock %r for key %r", self.name, key) - - @contextmanager - def _ctx_manager(): - try: - yield - finally: - logger.info("Releasing linearizer lock %r for key %r", self.name, key) - new_defer.callback(None) - current_d = self.key_to_defer.get(key) - if current_d is new_defer: - self.key_to_defer.pop(key, None) - - defer.returnValue(_ctx_manager()) - - -class Limiter(object): """Limits concurrent access to resources based on a key. Useful to ensure - only a few thing happen at a time on a given resource. + only a few things happen at a time on a given resource. Example: @@ -229,22 +166,31 @@ class Limiter(object): # do some work. """ - def __init__(self, max_count): + def __init__(self, name=None, max_count=1, clock=None): """ Args: - max_count(int): The maximum number of concurrent access + max_count(int): The maximum number of concurrent accesses """ + if name is None: + self.name = id(self) + else: + self.name = name + + if not clock: + from twisted.internet import reactor + clock = Clock(reactor) + self._clock = clock self.max_count = max_count # key_to_defer is a map from the key to a 2 element list where - # the first element is the number of things executing - # the second element is a list of deferreds for the things blocked from - # executing. + # the first element is the number of things executing, and + # the second element is an OrderedDict, where the keys are deferreds for the + # things blocked from executing. self.key_to_defer = {} @defer.inlineCallbacks def queue(self, key): - entry = self.key_to_defer.setdefault(key, [0, []]) + entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()]) # If the number of things executing is greater than the maximum # then add a deferred to the list of blocked items @@ -252,27 +198,71 @@ class Limiter(object): # this item so that it can continue executing. if entry[0] >= self.max_count: new_defer = defer.Deferred() - entry[1].append(new_defer) - with PreserveLoggingContext(): - yield new_defer + entry[1][new_defer] = 1 + + logger.info( + "Waiting to acquire linearizer lock %r for key %r", self.name, key, + ) + try: + yield make_deferred_yieldable(new_defer) + except Exception as e: + if isinstance(e, CancelledError): + logger.info( + "Cancelling wait for linearizer lock %r for key %r", + self.name, key, + ) + else: + logger.warn( + "Unexpected exception waiting for linearizer lock %r for key %r", + self.name, key, + ) + + # we just have to take ourselves back out of the queue. + del entry[1][new_defer] + raise + + logger.info("Acquired linearizer lock %r for key %r", self.name, key) + entry[0] += 1 + + # if the code holding the lock completes synchronously, then it + # will recursively run the next claimant on the list. That can + # relatively rapidly lead to stack exhaustion. This is essentially + # the same problem as http://twistedmatrix.com/trac/ticket/9304. + # + # In order to break the cycle, we add a cheeky sleep(0) here to + # ensure that we fall back to the reactor between each iteration. + # + # (This needs to happen while we hold the lock, and the context manager's exit + # code must be synchronous, so this is the only sensible place.) + yield self._clock.sleep(0) - entry[0] += 1 + else: + logger.info( + "Acquired uncontended linearizer lock %r for key %r", self.name, key, + ) + entry[0] += 1 @contextmanager def _ctx_manager(): try: yield finally: + logger.info("Releasing linearizer lock %r for key %r", self.name, key) + # We've finished executing so check if there are any things # blocked waiting to execute and start one of them entry[0] -= 1 - try: - entry[1].pop(0).callback(None) - except IndexError: - # If nothing else is executing for this key then remove it - # from the map - if entry[0] == 0: - self.key_to_defer.pop(key, None) + + if entry[1]: + (next_def, _) = entry[1].popitem(last=False) + + # we need to run the next thing in the sentinel context. + with PreserveLoggingContext(): + next_def.callback(None) + elif entry[0] == 0: + # We were the last thing for this key: remove it from the + # map. + del self.key_to_defer[key] defer.returnValue(_ctx_manager()) @@ -316,7 +306,7 @@ class ReadWriteLock(object): # We wait for the latest writer to finish writing. We can safely ignore # any existing readers... as they're readers. - yield curr_writer + yield make_deferred_yieldable(curr_writer) @contextmanager def _ctx_manager(): @@ -345,7 +335,7 @@ class ReadWriteLock(object): curr_readers.clear() self.key_to_current_writer[key] = new_defer - yield preserve_context_over_deferred(defer.gatherResults(to_wait_on)) + yield make_deferred_yieldable(defer.gatherResults(to_wait_on)) @contextmanager def _ctx_manager(): @@ -357,3 +347,69 @@ class ReadWriteLock(object): self.key_to_current_writer.pop(key) defer.returnValue(_ctx_manager()) + + +class DeferredTimeoutError(Exception): + """ + This error is raised by default when a L{Deferred} times out. + """ + + +def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None): + """ + Add a timeout to a deferred by scheduling it to be cancelled after + timeout seconds. + + This is essentially a backport of deferred.addTimeout, which was introduced + in twisted 16.5. + + If the deferred gets timed out, it errbacks with a DeferredTimeoutError, + unless a cancelable function was passed to its initialization or unless + a different on_timeout_cancel callable is provided. + + Args: + deferred (defer.Deferred): deferred to be timed out + timeout (Number): seconds to time out after + reactor (twisted.internet.reactor): the Twisted reactor to use + + on_timeout_cancel (callable): A callable which is called immediately + after the deferred times out, and not if this deferred is + otherwise cancelled before the timeout. + + It takes an arbitrary value, which is the value of the deferred at + that exact point in time (probably a CancelledError Failure), and + the timeout. + + The default callable (if none is provided) will translate a + CancelledError Failure into a DeferredTimeoutError. + """ + timed_out = [False] + + def time_it_out(): + timed_out[0] = True + deferred.cancel() + + delayed_call = reactor.callLater(timeout, time_it_out) + + def convert_cancelled(value): + if timed_out[0]: + to_call = on_timeout_cancel or _cancelled_to_timed_out_error + return to_call(value, timeout) + return value + + deferred.addBoth(convert_cancelled) + + def cancel_timeout(result): + # stop the pending call to cancel the deferred if it's been fired + if delayed_call.active(): + delayed_call.cancel() + return result + + deferred.addBoth(cancel_timeout) + + +def _cancelled_to_timed_out_error(value, timeout): + if isinstance(value, failure.Failure): + value.trap(CancelledError) + raise DeferredTimeoutError(timeout, "Deferred") + return value diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 4adae96681..7b065b195e 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -13,28 +13,87 @@ # See the License for the specific language governing permissions and # limitations under the License. -import synapse.metrics import os +import six +from six.moves import intern + +from prometheus_client.core import REGISTRY, Gauge, GaugeMetricFamily + CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5)) -metrics = synapse.metrics.get_metrics_for("synapse.util.caches") + +def get_cache_factor_for(cache_name): + env_var = "SYNAPSE_CACHE_FACTOR_" + cache_name.upper() + factor = os.environ.get(env_var) + if factor: + return float(factor) + + return CACHE_SIZE_FACTOR + caches_by_name = {} -# cache_counter = metrics.register_cache( -# "cache", -# lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()}, -# labels=["name"], -# ) - - -def register_cache(name, cache): - caches_by_name[name] = cache - return metrics.register_cache( - "cache", - lambda: len(cache), - name, - ) +collectors_by_name = {} + +cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"]) +cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"]) +cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name"]) +cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"]) + +response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"]) +response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"]) +response_cache_evicted = Gauge( + "synapse_util_caches_response_cache:evicted_size", "", ["name"] +) +response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"]) + + +def register_cache(cache_type, cache_name, cache): + + # Check if the metric is already registered. Unregister it, if so. + # This usually happens during tests, as at runtime these caches are + # effectively singletons. + metric_name = "cache_%s_%s" % (cache_type, cache_name) + if metric_name in collectors_by_name.keys(): + REGISTRY.unregister(collectors_by_name[metric_name]) + + class CacheMetric(object): + + hits = 0 + misses = 0 + evicted_size = 0 + + def inc_hits(self): + self.hits += 1 + + def inc_misses(self): + self.misses += 1 + + def inc_evictions(self, size=1): + self.evicted_size += size + + def describe(self): + return [] + + def collect(self): + if cache_type == "response_cache": + response_cache_size.labels(cache_name).set(len(cache)) + response_cache_hits.labels(cache_name).set(self.hits) + response_cache_evicted.labels(cache_name).set(self.evicted_size) + response_cache_total.labels(cache_name).set(self.hits + self.misses) + else: + cache_size.labels(cache_name).set(len(cache)) + cache_hits.labels(cache_name).set(self.hits) + cache_evicted.labels(cache_name).set(self.evicted_size) + cache_total.labels(cache_name).set(self.hits + self.misses) + + yield GaugeMetricFamily("__unused", "") + + metric = CacheMetric() + REGISTRY.register(metric) + caches_by_name[cache_name] = cache + collectors_by_name[metric_name] = metric + return metric KNOWN_KEYS = { @@ -66,7 +125,9 @@ def intern_string(string): return None try: - string = string.encode("ascii") + if six.PY2: + string = string.encode("ascii") + return intern(string) except UnicodeEncodeError: return string diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index af65bfe7b8..f8a07df6b8 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,25 +13,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools +import inspect import logging +import threading +from collections import namedtuple + +import six +from six import itervalues, string_types + +from twisted.internet import defer +from synapse.util import logcontext, unwrapFirstError from synapse.util.async import ObservableDeferred -from synapse.util import unwrapFirstError, logcontext -from synapse.util.caches import CACHE_SIZE_FACTOR +from synapse.util.caches import get_cache_factor_for from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry from synapse.util.stringutils import to_ascii from . import register_cache -from twisted.internet import defer -from collections import namedtuple - -import functools -import inspect -import threading - - logger = logging.getLogger(__name__) @@ -39,12 +41,11 @@ _CacheSentinel = object() class CacheEntry(object): __slots__ = [ - "deferred", "sequence", "callbacks", "invalidated" + "deferred", "callbacks", "invalidated" ] - def __init__(self, deferred, sequence, callbacks): + def __init__(self, deferred, callbacks): self.deferred = deferred - self.sequence = sequence self.callbacks = set(callbacks) self.invalidated = False @@ -62,7 +63,6 @@ class Cache(object): "max_entries", "name", "keylen", - "sequence", "thread", "metrics", "_pending_deferred_cache", @@ -75,13 +75,16 @@ class Cache(object): self.cache = LruCache( max_size=max_entries, keylen=keylen, cache_type=cache_type, size_callback=(lambda d: len(d)) if iterable else None, + evicted_callback=self._on_evicted, ) self.name = name self.keylen = keylen - self.sequence = 0 self.thread = None - self.metrics = register_cache(name, self.cache) + self.metrics = register_cache("cache", name, self.cache) + + def _on_evicted(self, evicted_count): + self.metrics.inc_evictions(evicted_count) def check_thread(self): expected_thread = self.thread @@ -109,11 +112,10 @@ class Cache(object): callbacks = [callback] if callback else [] val = self._pending_deferred_cache.get(key, _CacheSentinel) if val is not _CacheSentinel: - if val.sequence == self.sequence: - val.callbacks.update(callbacks) - if update_metrics: - self.metrics.inc_hits() - return val.deferred + val.callbacks.update(callbacks) + if update_metrics: + self.metrics.inc_hits() + return val.deferred val = self.cache.get(key, _CacheSentinel, callbacks=callbacks) if val is not _CacheSentinel: @@ -133,12 +135,9 @@ class Cache(object): self.check_thread() entry = CacheEntry( deferred=value, - sequence=self.sequence, callbacks=callbacks, ) - entry.callbacks.update(callbacks) - existing_entry = self._pending_deferred_cache.pop(key, None) if existing_entry: existing_entry.invalidate() @@ -146,13 +145,25 @@ class Cache(object): self._pending_deferred_cache[key] = entry def shuffle(result): - if self.sequence == entry.sequence: - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry is entry: - self.cache.set(key, result, entry.callbacks) - else: - entry.invalidate() + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry is entry: + self.cache.set(key, result, entry.callbacks) else: + # oops, the _pending_deferred_cache has been updated since + # we started our query, so we are out of date. + # + # Better put back whatever we took out. (We do it this way + # round, rather than peeking into the _pending_deferred_cache + # and then removing on a match, to make the common case faster) + if existing_entry is not None: + self._pending_deferred_cache[key] = existing_entry + + # we're not going to put this entry into the cache, so need + # to make sure that the invalidation callbacks are called. + # That was probably done when _pending_deferred_cache was + # updated, but it's possible that `set` was called without + # `invalidate` being previously called, in which case it may + # not have been. Either way, let's double-check now. entry.invalidate() return result @@ -164,25 +175,29 @@ class Cache(object): def invalidate(self, key): self.check_thread() + self.cache.pop(key, None) - # Increment the sequence number so that any SELECT statements that - # raced with the INSERT don't update the cache (SYN-369) - self.sequence += 1 + # if we have a pending lookup for this key, remove it from the + # _pending_deferred_cache, which will (a) stop it being returned + # for future queries and (b) stop it being persisted as a proper entry + # in self.cache. entry = self._pending_deferred_cache.pop(key, None) + + # run the invalidation callbacks now, rather than waiting for the + # deferred to resolve. if entry: entry.invalidate() - self.cache.pop(key, None) - def invalidate_many(self, key): self.check_thread() if not isinstance(key, tuple): raise TypeError( "The cache key must be a tuple not %r" % (type(key),) ) - self.sequence += 1 self.cache.del_multi(key) + # if we have a pending lookup for this key, remove it from the + # _pending_deferred_cache, as above entry_dict = self._pending_deferred_cache.pop(key, None) if entry_dict is not None: for entry in iterate_tree_cache_entry(entry_dict): @@ -190,8 +205,10 @@ class Cache(object): def invalidate_all(self): self.check_thread() - self.sequence += 1 self.cache.clear() + for entry in itervalues(self._pending_deferred_cache): + entry.invalidate() + self._pending_deferred_cache.clear() class _CacheDescriptorBase(object): @@ -294,7 +311,7 @@ class CacheDescriptor(_CacheDescriptorBase): orig, num_args=num_args, inlineCallbacks=inlineCallbacks, cache_context=cache_context) - max_entries = int(max_entries * CACHE_SIZE_FACTOR) + max_entries = int(max_entries * get_cache_factor_for(orig.__name__)) self.max_entries = max_entries self.tree = tree @@ -376,9 +393,10 @@ class CacheDescriptor(_CacheDescriptorBase): ret.addErrback(onErr) - # If our cache_key is a string, try to convert to ascii to save - # a bit of space in large caches - if isinstance(cache_key, basestring): + # If our cache_key is a string on py2, try to convert to ascii + # to save a bit of space in large caches. Py3 does this + # internally automatically. + if six.PY2 and isinstance(cache_key, string_types): cache_key = to_ascii(cache_key) result_d = ObservableDeferred(ret, consumeErrors=True) @@ -549,7 +567,7 @@ class CacheListDescriptor(_CacheDescriptorBase): return results return logcontext.make_deferred_yieldable(defer.gatherResults( - cached_defers.values(), + list(cached_defers.values()), consumeErrors=True, ).addCallback(update_results_dict).addErrback( unwrapFirstError diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index d4105822b3..6c0b5a4094 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.caches.lrucache import LruCache -from collections import namedtuple -from . import register_cache -import threading import logging +import threading +from collections import namedtuple +from synapse.util.caches.lrucache import LruCache + +from . import register_cache logger = logging.getLogger(__name__) @@ -55,7 +56,7 @@ class DictionaryCache(object): __slots__ = [] self.sentinel = Sentinel() - self.metrics = register_cache(name, self.cache) + self.metrics = register_cache("dictionary", name, self.cache) def check_thread(self): expected_thread = self.thread @@ -107,34 +108,37 @@ class DictionaryCache(object): self.sequence += 1 self.cache.clear() - def update(self, sequence, key, value, full=False, known_absent=None): + def update(self, sequence, key, value, fetched_keys=None): """Updates the entry in the cache Args: sequence - key - value (dict): The value to update the cache with. - full (bool): Whether the given value is the full dict, or just a - partial subset there of. If not full then any existing entries - for the key will be updated. - known_absent (set): Set of keys that we know don't exist in the full - dict. + key (K) + value (dict[X,Y]): The value to update the cache with. + fetched_keys (None|set[X]): All of the dictionary keys which were + fetched from the database. + + If None, this is the complete value for key K. Otherwise, it + is used to infer a list of keys which we know don't exist in + the full dict. """ self.check_thread() if self.sequence == sequence: # Only update the cache if the caches sequence number matches the # number that the cache had before the SELECT was started (SYN-369) - if known_absent is None: - known_absent = set() - if full: - self._insert(key, value, known_absent) + if fetched_keys is None: + self._insert(key, value, set()) else: - self._update_or_insert(key, value, known_absent) + self._update_or_insert(key, value, fetched_keys) def _update_or_insert(self, key, value, known_absent): - entry = self.cache.setdefault(key, DictionaryEntry(False, set(), {})) + # We pop and reinsert as we need to tell the cache the size may have + # changed + + entry = self.cache.pop(key, DictionaryEntry(False, set(), {})) entry.value.update(value) entry.known_absent.update(known_absent) + self.cache[key] = entry def _insert(self, key, value, known_absent): self.cache[key] = DictionaryEntry(True, known_absent, value) diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 6ad53a6390..465adc54a8 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.caches import register_cache - -from collections import OrderedDict import logging +from collections import OrderedDict +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util.caches import register_cache logger = logging.getLogger(__name__) @@ -52,19 +52,22 @@ class ExpiringCache(object): self._cache = OrderedDict() - self.metrics = register_cache(cache_name, self) - self.iterable = iterable self._size_estimate = 0 + self.metrics = register_cache("expiring", cache_name, self) + def start(self): if not self._expiry_ms: # Don't bother starting the loop if things never expire return def f(): - self._prune_cache() + run_as_background_process( + "prune_cache_%s" % self._cache_name, + self._prune_cache, + ) self._clock.looping_call(f, self._expiry_ms / 2) @@ -79,7 +82,11 @@ class ExpiringCache(object): while self._max_len and len(self) > self._max_len: _key, value = self._cache.popitem(last=False) if self.iterable: - self._size_estimate -= len(value.value) + removed_len = len(value.value) + self.metrics.inc_evictions(removed_len) + self._size_estimate -= removed_len + else: + self.metrics.inc_evictions() def __getitem__(self, key): try: diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index cf5fbb679c..b684f24e7b 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -14,8 +14,8 @@ # limitations under the License. -from functools import wraps import threading +from functools import wraps from synapse.util.caches.treecache import TreeCache @@ -49,7 +49,24 @@ class LruCache(object): Can also set callbacks on objects when getting/setting which are fired when that key gets invalidated/evicted. """ - def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None): + def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None, + evicted_callback=None): + """ + Args: + max_size (int): + + keylen (int): + + cache_type (type): + type of underlying cache to be used. Typically one of dict + or TreeCache. + + size_callback (func(V) -> int | None): + + evicted_callback (func(int)|None): + if not None, called on eviction with the size of the evicted + entry + """ cache = cache_type() self.cache = cache # Used for introspection. list_root = _Node(None, None, None, None) @@ -61,8 +78,10 @@ class LruCache(object): def evict(): while cache_len() > max_size: todelete = list_root.prev_node - delete_node(todelete) + evicted_len = delete_node(todelete) cache.pop(todelete.key, None) + if evicted_callback: + evicted_callback(evicted_len) def synchronized(f): @wraps(f) @@ -111,12 +130,15 @@ class LruCache(object): prev_node.next_node = next_node next_node.prev_node = prev_node + deleted_len = 1 if size_callback: - cached_cache_len[0] -= size_callback(node.value) + deleted_len = size_callback(node.value) + cached_cache_len[0] -= deleted_len for cb in node.callbacks: cb() node.callbacks.clear() + return deleted_len @synchronized def cache_get(key, default=None, callbacks=[]): @@ -132,14 +154,21 @@ class LruCache(object): def cache_set(key, value, callbacks=[]): node = cache.get(key, None) if node is not None: - if value != node.value: + # We sometimes store large objects, e.g. dicts, which cause + # the inequality check to take a long time. So let's only do + # the check if we have some callbacks to call. + if node.callbacks and value != node.value: for cb in node.callbacks: cb() node.callbacks.clear() - if size_callback: - cached_cache_len[0] -= size_callback(node.value) - cached_cache_len[0] += size_callback(value) + # We don't bother to protect this by value != node.value as + # generally size_callback will be cheap compared with equality + # checks. (For example, taking the size of two dicts is quicker + # than comparing them for equality.) + if size_callback: + cached_cache_len[0] -= size_callback(node.value) + cached_cache_len[0] += size_callback(value) node.callbacks.update(callbacks) diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index 00af539880..a8491b42d5 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -12,8 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging + +from twisted.internet import defer from synapse.util.async import ObservableDeferred +from synapse.util.caches import register_cache +from synapse.util.logcontext import make_deferred_yieldable, run_in_background + +logger = logging.getLogger(__name__) class ResponseCache(object): @@ -24,20 +31,69 @@ class ResponseCache(object): used rather than trying to compute a new response. """ - def __init__(self, hs, timeout_ms=0): + def __init__(self, hs, name, timeout_ms=0): self.pending_result_cache = {} # Requests that haven't finished yet. self.clock = hs.get_clock() self.timeout_sec = timeout_ms / 1000. + self._name = name + self._metrics = register_cache( + "response_cache", name, self + ) + + def size(self): + return len(self.pending_result_cache) + + def __len__(self): + return self.size() + def get(self, key): + """Look up the given key. + + Can return either a new Deferred (which also doesn't follow the synapse + logcontext rules), or, if the request has completed, the actual + result. You will probably want to make_deferred_yieldable the result. + + If there is no entry for the key, returns None. It is worth noting that + this means there is no way to distinguish a completed result of None + from an absent cache entry. + + Args: + key (hashable): + + Returns: + twisted.internet.defer.Deferred|None|E: None if there is no entry + for this key; otherwise either a deferred result or the result + itself. + """ result = self.pending_result_cache.get(key) if result is not None: + self._metrics.inc_hits() return result.observe() else: + self._metrics.inc_misses() return None def set(self, key, deferred): + """Set the entry for the given key to the given deferred. + + *deferred* should run its callbacks in the sentinel logcontext (ie, + you should wrap normal synapse deferreds with + logcontext.run_in_background). + + Can return either a new Deferred (which also doesn't follow the synapse + logcontext rules), or, if *deferred* was already complete, the actual + result. You will probably want to make_deferred_yieldable the result. + + Args: + key (hashable): + deferred (twisted.internet.defer.Deferred[T): + + Returns: + twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual + result. + """ result = ObservableDeferred(deferred, consumeErrors=True) self.pending_result_cache[key] = result @@ -53,3 +109,52 @@ class ResponseCache(object): result.addBoth(remove) return result.observe() + + def wrap(self, key, callback, *args, **kwargs): + """Wrap together a *get* and *set* call, taking care of logcontexts + + First looks up the key in the cache, and if it is present makes it + follow the synapse logcontext rules and returns it. + + Otherwise, makes a call to *callback(*args, **kwargs)*, which should + follow the synapse logcontext rules, and adds the result to the cache. + + Example usage: + + @defer.inlineCallbacks + def handle_request(request): + # etc + defer.returnValue(result) + + result = yield response_cache.wrap( + key, + handle_request, + request, + ) + + Args: + key (hashable): key to get/set in the cache + + callback (callable): function to call if the key is not found in + the cache + + *args: positional parameters to pass to the callback, if it is used + + **kwargs: named paramters to pass to the callback, if it is used + + Returns: + twisted.internet.defer.Deferred: yieldable result + """ + result = self.get(key) + if not result: + logger.info("[%s]: no cached result for [%s], calculating new one", + self._name, key) + d = run_in_background(callback, *args, **kwargs) + result = self.set(key, d) + elif not isinstance(result, defer.Deferred) or result.called: + logger.info("[%s]: using completed cached result for [%s]", + self._name, key) + else: + logger.info("[%s]: using incomplete cached result for [%s]", + self._name, key) + return make_deferred_yieldable(result) diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 941d873ab8..f2bde74dc5 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -13,12 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.caches import register_cache, CACHE_SIZE_FACTOR - - -from blist import sorteddict import logging +from sortedcontainers import SortedDict + +from synapse.util import caches logger = logging.getLogger(__name__) @@ -32,16 +31,18 @@ class StreamChangeCache(object): entities that may have changed since that position. If position key is too old then the cache will simply return all given entities. """ - def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache={}): - self._max_size = int(max_size * CACHE_SIZE_FACTOR) + + def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache=None): + self._max_size = int(max_size * caches.CACHE_SIZE_FACTOR) self._entity_to_key = {} - self._cache = sorteddict() + self._cache = SortedDict() self._earliest_known_stream_pos = current_stream_pos self.name = name - self.metrics = register_cache(self.name, self._cache) + self.metrics = caches.register_cache("cache", self.name, self._cache) - for entity, stream_pos in prefilled_cache.items(): - self.entity_has_changed(entity, stream_pos) + if prefilled_cache: + for entity, stream_pos in prefilled_cache.items(): + self.entity_has_changed(entity, stream_pos) def has_entity_changed(self, entity, stream_pos): """Returns True if the entity may have been updated since stream_pos @@ -65,22 +66,25 @@ class StreamChangeCache(object): return False def get_entities_changed(self, entities, stream_pos): - """Returns subset of entities that have had new things since the - given position. If the position is too old it will just return the given list. + """ + Returns subset of entities that have had new things since the given + position. Entities unknown to the cache will be returned. If the + position is too old it will just return the given list. """ assert type(stream_pos) is int if stream_pos >= self._earliest_known_stream_pos: - keys = self._cache.keys() - i = keys.bisect_right(stream_pos) + changed_entities = { + self._cache[k] for k in self._cache.islice( + start=self._cache.bisect_right(stream_pos), + ) + } - result = set( - self._cache[k] for k in keys[i:] - ).intersection(entities) + result = changed_entities.intersection(entities) self.metrics.inc_hits() else: - result = entities + result = set(entities) self.metrics.inc_misses() return result @@ -90,12 +94,13 @@ class StreamChangeCache(object): """ assert type(stream_pos) is int + if not self._cache: + # If we have no cache, nothing can have changed. + return False + if stream_pos >= self._earliest_known_stream_pos: self.metrics.inc_hits() - keys = self._cache.keys() - i = keys.bisect_right(stream_pos) - - return i < len(keys) + return self._cache.bisect_right(stream_pos) < len(self._cache) else: self.metrics.inc_misses() return True @@ -107,10 +112,8 @@ class StreamChangeCache(object): assert type(stream_pos) is int if stream_pos >= self._earliest_known_stream_pos: - keys = self._cache.keys() - i = keys.bisect_right(stream_pos) - - return [self._cache[k] for k in keys[i:]] + return [self._cache[k] for k in self._cache.islice( + start=self._cache.bisect_right(stream_pos))] else: return None @@ -129,8 +132,10 @@ class StreamChangeCache(object): self._entity_to_key[entity] = stream_pos while len(self._cache) > self._max_size: - k, r = self._cache.popitem() - self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) + k, r = self._cache.popitem(0) + self._earliest_known_stream_pos = max( + k, self._earliest_known_stream_pos, + ) self._entity_to_key.pop(r, None) def get_max_pos_of_last_change(self, entity): diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index fcc341a6b7..dd4c9e6067 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -1,3 +1,5 @@ +from six import itervalues + SENTINEL = object() @@ -49,7 +51,7 @@ class TreeCache(object): if popped is SENTINEL: return default - node_and_keys = zip(nodes, key) + node_and_keys = list(zip(nodes, key)) node_and_keys.reverse() node_and_keys.append((self.root, None)) @@ -76,7 +78,7 @@ def iterate_tree_cache_entry(d): can contain dicts. """ if isinstance(d, dict): - for value_d in d.itervalues(): + for value_d in itervalues(d): for value in iterate_tree_cache_entry(value_d): yield value else: diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index e68f94ce77..194da87639 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -13,32 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - -from synapse.util.logcontext import ( - PreserveLoggingContext, preserve_context_over_fn -) - -from synapse.util import unwrapFirstError - import logging +from twisted.internet import defer + +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) def user_left_room(distributor, user, room_id): - return preserve_context_over_fn( - distributor.fire, - "user_left_room", user=user, room_id=room_id - ) + distributor.fire("user_left_room", user=user, room_id=room_id) def user_joined_room(distributor, user, room_id): - return preserve_context_over_fn( - distributor.fire, - "user_joined_room", user=user, room_id=room_id - ) + distributor.fire("user_joined_room", user=user, room_id=room_id) class Distributor(object): @@ -52,9 +42,7 @@ class Distributor(object): model will do for today. """ - def __init__(self, suppress_failures=True): - self.suppress_failures = suppress_failures - + def __init__(self): self.signals = {} self.pre_registration = {} @@ -64,7 +52,6 @@ class Distributor(object): self.signals[name] = Signal( name, - suppress_failures=self.suppress_failures, ) if name in self.pre_registration: @@ -83,10 +70,18 @@ class Distributor(object): self.pre_registration[name].append(observer) def fire(self, name, *args, **kwargs): + """Dispatches the given signal to the registered observers. + + Runs the observers as a background process. Does not return a deferred. + """ if name not in self.signals: raise KeyError("%r does not have a signal named %s" % (self, name)) - return self.signals[name].fire(*args, **kwargs) + run_as_background_process( + name, + self.signals[name].fire, + *args, **kwargs + ) class Signal(object): @@ -99,9 +94,8 @@ class Signal(object): method into all of the observers. """ - def __init__(self, name, suppress_failures): + def __init__(self, name): self.name = name - self.suppress_failures = suppress_failures self.observers = [] def observe(self, observer): @@ -111,7 +105,6 @@ class Signal(object): Each observer callable may return a Deferred.""" self.observers.append(observer) - @defer.inlineCallbacks def fire(self, *args, **kwargs): """Invokes every callable in the observer list, passing in the args and kwargs. Exceptions thrown by observers are logged but ignored. It is @@ -129,22 +122,17 @@ class Signal(object): failure.type, failure.value, failure.getTracebackObject())) - if not self.suppress_failures: - return failure return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb) - with PreserveLoggingContext(): - deferreds = [ - do(observer) - for observer in self.observers - ] - - res = yield defer.gatherResults( - deferreds, consumeErrors=True - ).addErrback(unwrapFirstError) + deferreds = [ + run_in_background(do, o) + for o in self.observers + ] - defer.returnValue(res) + return make_deferred_yieldable(defer.gatherResults( + deferreds, consumeErrors=True, + )) def __repr__(self): return "<Signal name=%r>" % (self.name,) diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py new file mode 100644 index 0000000000..629ed44149 --- /dev/null +++ b/synapse/util/file_consumer.py @@ -0,0 +1,147 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from six.moves import queue + +from twisted.internet import threads + +from synapse.util.logcontext import make_deferred_yieldable, run_in_background + + +class BackgroundFileConsumer(object): + """A consumer that writes to a file like object. Supports both push + and pull producers + + Args: + file_obj (file): The file like object to write to. Closed when + finished. + reactor (twisted.internet.reactor): the Twisted reactor to use + """ + + # For PushProducers pause if we have this many unwritten slices + _PAUSE_ON_QUEUE_SIZE = 5 + # And resume once the size of the queue is less than this + _RESUME_ON_QUEUE_SIZE = 2 + + def __init__(self, file_obj, reactor): + self._file_obj = file_obj + + self._reactor = reactor + + # Producer we're registered with + self._producer = None + + # True if PushProducer, false if PullProducer + self.streaming = False + + # For PushProducers, indicates whether we've paused the producer and + # need to call resumeProducing before we get more data. + self._paused_producer = False + + # Queue of slices of bytes to be written. When producer calls + # unregister a final None is sent. + self._bytes_queue = queue.Queue() + + # Deferred that is resolved when finished writing + self._finished_deferred = None + + # If the _writer thread throws an exception it gets stored here. + self._write_exception = None + + def registerProducer(self, producer, streaming): + """Part of IConsumer interface + + Args: + producer (IProducer) + streaming (bool): True if push based producer, False if pull + based. + """ + if self._producer: + raise Exception("registerProducer called twice") + + self._producer = producer + self.streaming = streaming + self._finished_deferred = run_in_background( + threads.deferToThreadPool, + self._reactor, + self._reactor.getThreadPool(), + self._writer, + ) + if not streaming: + self._producer.resumeProducing() + + def unregisterProducer(self): + """Part of IProducer interface + """ + self._producer = None + if not self._finished_deferred.called: + self._bytes_queue.put_nowait(None) + + def write(self, bytes): + """Part of IProducer interface + """ + if self._write_exception: + raise self._write_exception + + if self._finished_deferred.called: + raise Exception("consumer has closed") + + self._bytes_queue.put_nowait(bytes) + + # If this is a PushProducer and the queue is getting behind + # then we pause the producer. + if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE: + self._paused_producer = True + self._producer.pauseProducing() + + def _writer(self): + """This is run in a background thread to write to the file. + """ + try: + while self._producer or not self._bytes_queue.empty(): + # If we've paused the producer check if we should resume the + # producer. + if self._producer and self._paused_producer: + if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE: + self._reactor.callFromThread(self._resume_paused_producer) + + bytes = self._bytes_queue.get() + + # If we get a None (or empty list) then that's a signal used + # to indicate we should check if we should stop. + if bytes: + self._file_obj.write(bytes) + + # If its a pull producer then we need to explicitly ask for + # more stuff. + if not self.streaming and self._producer: + self._reactor.callFromThread(self._producer.resumeProducing) + except Exception as e: + self._write_exception = e + raise + finally: + self._file_obj.close() + + def wait(self): + """Returns a deferred that resolves when finished writing to file + """ + return make_deferred_yieldable(self._finished_deferred) + + def _resume_paused_producer(self): + """Gets called if we should resume producing after being paused + """ + if self._paused_producer and self._producer: + self._paused_producer = False + self._producer.resumeProducing() diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 6322f0f55c..581c6052ac 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -13,18 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +from six import string_types + +from canonicaljson import json from frozendict import frozendict def freeze(o): - t = type(o) - if t is dict: + if isinstance(o, dict): return frozendict({k: freeze(v) for k, v in o.items()}) - if t is frozendict: + if isinstance(o, frozendict): return o - if t is str or t is unicode: + if isinstance(o, string_types): return o try: @@ -36,11 +38,10 @@ def freeze(o): def unfreeze(o): - t = type(o) - if t is dict or t is frozendict: + if isinstance(o, (dict, frozendict)): return dict({k: unfreeze(v) for k, v in o.items()}) - if t is str or t is unicode: + if isinstance(o, string_types): return o try: @@ -49,3 +50,21 @@ def unfreeze(o): pass return o + + +def _handle_frozendict(obj): + """Helper for EventEncoder. Makes frozendicts serializable by returning + the underlying dict + """ + if type(obj) is frozendict: + # fishing the protected dict out of the object is a bit nasty, + # but we don't really want the overhead of copying the dict. + return obj._dict + raise TypeError('Object of type %s is not JSON serializable' % + obj.__class__.__name__) + + +# A JSONEncoder which is capable of encoding frozendics without barfing +frozendict_json_encoder = json.JSONEncoder( + default=_handle_frozendict, +) diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py index 45be47159a..2d7ddc1cbe 100644 --- a/synapse/util/httpresourcetree.py +++ b/synapse/util/httpresourcetree.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.web.resource import Resource - import logging +from twisted.web.resource import NoResource + logger = logging.getLogger(__name__) @@ -40,12 +40,15 @@ def create_resource_tree(desired_tree, root_resource): # extra resources to existing nodes. See self._resource_id for the key. resource_mappings = {} for full_path, res in desired_tree.items(): + # twisted requires all resources to be bytes + full_path = full_path.encode("utf-8") + logger.info("Attaching %s to path %s", res, full_path) last_resource = root_resource - for path_seg in full_path.split('/')[1:-1]: + for path_seg in full_path.split(b'/')[1:-1]: if path_seg not in last_resource.listNames(): # resource doesn't exist, so make a "dummy resource" - child_resource = Resource() + child_resource = NoResource() last_resource.putChild(path_seg, child_resource) res_id = _resource_id(last_resource, path_seg) resource_mappings[res_id] = child_resource @@ -57,7 +60,7 @@ def create_resource_tree(desired_tree, root_resource): # =========================== # now attach the actual desired resource - last_path_seg = full_path.split('/')[-1] + last_path_seg = full_path.split(b'/')[-1] # if there is already a resource here, thieve its children and # replace it diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index 990216145e..8dcae50b39 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -22,10 +22,10 @@ them. See doc/log_contexts.rst for details on how this works. """ -from twisted.internet import defer - -import threading import logging +import threading + +from twisted.internet import defer logger = logging.getLogger(__name__) @@ -42,23 +42,128 @@ try: def get_thread_resource_usage(): return resource.getrusage(RUSAGE_THREAD) -except: +except Exception: # If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we # won't track resource usage by returning None. def get_thread_resource_usage(): return None +class ContextResourceUsage(object): + """Object for tracking the resources used by a log context + + Attributes: + ru_utime (float): user CPU time (in seconds) + ru_stime (float): system CPU time (in seconds) + db_txn_count (int): number of database transactions done + db_sched_duration_sec (float): amount of time spent waiting for a + database connection + db_txn_duration_sec (float): amount of time spent doing database + transactions (excluding scheduling time) + evt_db_fetch_count (int): number of events requested from the database + """ + + __slots__ = [ + "ru_stime", "ru_utime", + "db_txn_count", "db_txn_duration_sec", "db_sched_duration_sec", + "evt_db_fetch_count", + ] + + def __init__(self, copy_from=None): + """Create a new ContextResourceUsage + + Args: + copy_from (ContextResourceUsage|None): if not None, an object to + copy stats from + """ + if copy_from is None: + self.reset() + else: + self.ru_utime = copy_from.ru_utime + self.ru_stime = copy_from.ru_stime + self.db_txn_count = copy_from.db_txn_count + + self.db_txn_duration_sec = copy_from.db_txn_duration_sec + self.db_sched_duration_sec = copy_from.db_sched_duration_sec + self.evt_db_fetch_count = copy_from.evt_db_fetch_count + + def copy(self): + return ContextResourceUsage(copy_from=self) + + def reset(self): + self.ru_stime = 0. + self.ru_utime = 0. + self.db_txn_count = 0 + + self.db_txn_duration_sec = 0 + self.db_sched_duration_sec = 0 + self.evt_db_fetch_count = 0 + + def __repr__(self): + return ("<ContextResourceUsage ru_stime='%r', ru_utime='%r', " + "db_txn_count='%r', db_txn_duration_sec='%r', " + "db_sched_duration_sec='%r', evt_db_fetch_count='%r'>") % ( + self.ru_stime, + self.ru_utime, + self.db_txn_count, + self.db_txn_duration_sec, + self.db_sched_duration_sec, + self.evt_db_fetch_count,) + + def __iadd__(self, other): + """Add another ContextResourceUsage's stats to this one's. + + Args: + other (ContextResourceUsage): the other resource usage object + """ + self.ru_utime += other.ru_utime + self.ru_stime += other.ru_stime + self.db_txn_count += other.db_txn_count + self.db_txn_duration_sec += other.db_txn_duration_sec + self.db_sched_duration_sec += other.db_sched_duration_sec + self.evt_db_fetch_count += other.evt_db_fetch_count + return self + + def __isub__(self, other): + self.ru_utime -= other.ru_utime + self.ru_stime -= other.ru_stime + self.db_txn_count -= other.db_txn_count + self.db_txn_duration_sec -= other.db_txn_duration_sec + self.db_sched_duration_sec -= other.db_sched_duration_sec + self.evt_db_fetch_count -= other.evt_db_fetch_count + return self + + def __add__(self, other): + res = ContextResourceUsage(copy_from=self) + res += other + return res + + def __sub__(self, other): + res = ContextResourceUsage(copy_from=self) + res -= other + return res + + class LoggingContext(object): """Additional context for log formatting. Contexts are scoped within a "with" block. + + If a parent is given when creating a new context, then: + - logging fields are copied from the parent to the new context on entry + - when the new context exits, the cpu usage stats are copied from the + child to the parent + Args: name (str): Name for the context for debugging. + parent_context (LoggingContext|None): The parent of the new context """ __slots__ = [ - "previous_context", "name", "usage_start", "usage_end", "main_thread", - "__dict__", "tag", "alive", + "previous_context", "name", "parent_context", + "_resource_usage", + "usage_start", + "main_thread", "alive", + "request", "tag", ] thread_local = threading.local() @@ -80,32 +185,49 @@ class LoggingContext(object): def stop(self): pass - def add_database_transaction(self, duration_ms): + def add_database_transaction(self, duration_sec): + pass + + def add_database_scheduled(self, sched_sec): + pass + + def record_event_fetch(self, event_count): pass def __nonzero__(self): return False + __bool__ = __nonzero__ # python3 sentinel = Sentinel() - def __init__(self, name=None): + def __init__(self, name=None, parent_context=None): self.previous_context = LoggingContext.current_context() self.name = name - self.ru_stime = 0. - self.ru_utime = 0. - self.db_txn_count = 0 - self.db_txn_duration = 0. + + # track the resources used by this context so far + self._resource_usage = ContextResourceUsage() + + # If alive has the thread resource usage when the logcontext last + # became active. self.usage_start = None + self.main_thread = threading.current_thread() + self.request = None self.tag = "" self.alive = True + self.parent_context = parent_context + def __str__(self): return "%s@%x" % (self.name, id(self)) @classmethod def current_context(cls): - """Get the current logging context from thread local storage""" + """Get the current logging context from thread local storage + + Returns: + LoggingContext: the current logging context + """ return getattr(cls.thread_local, "current_context", cls.sentinel) @classmethod @@ -133,18 +255,22 @@ class LoggingContext(object): self.previous_context, old_context ) self.alive = True + + if self.parent_context is not None: + self.parent_context.copy_to(self) + return self def __exit__(self, type, value, traceback): """Restore the logging context in thread local storage to the state it was before this context was entered. Returns: - None to avoid suppressing any exeptions that were thrown. + None to avoid suppressing any exceptions that were thrown. """ current = self.set_current_context(self.previous_context) if current is not self: if current is self.sentinel: - logger.debug("Expected logging context %s has been lost", self) + logger.warn("Expected logging context %s has been lost", self) else: logger.warn( "Current logging context %s is not expected context %s", @@ -154,47 +280,91 @@ class LoggingContext(object): self.previous_context = None self.alive = False + # if we have a parent, pass our CPU usage stats on + if self.parent_context is not None: + self.parent_context._resource_usage += self._resource_usage + + # reset them in case we get entered again + self._resource_usage.reset() + def copy_to(self, record): - """Copy fields from this context to the record""" - for key, value in self.__dict__.items(): - setattr(record, key, value) + """Copy logging fields from this context to a log record or + another LoggingContext + """ - record.ru_utime, record.ru_stime = self.get_resource_usage() + # 'request' is the only field we currently use in the logger, so that's + # all we need to copy + record.request = self.request def start(self): if threading.current_thread() is not self.main_thread: + logger.warning("Started logcontext %s on different thread", self) return - if self.usage_start and self.usage_end: - self.ru_utime += self.usage_end.ru_utime - self.usage_start.ru_utime - self.ru_stime += self.usage_end.ru_stime - self.usage_start.ru_stime - self.usage_start = None - self.usage_end = None - + # If we haven't already started record the thread resource usage so + # far if not self.usage_start: self.usage_start = get_thread_resource_usage() def stop(self): if threading.current_thread() is not self.main_thread: + logger.warning("Stopped logcontext %s on different thread", self) + return + + # When we stop, let's record the cpu used since we started + if not self.usage_start: + logger.warning( + "Called stop on logcontext %s without calling start", self, + ) return - if self.usage_start: - self.usage_end = get_thread_resource_usage() + usage_end = get_thread_resource_usage() + + self._resource_usage.ru_utime += usage_end.ru_utime - self.usage_start.ru_utime + self._resource_usage.ru_stime += usage_end.ru_stime - self.usage_start.ru_stime + + self.usage_start = None def get_resource_usage(self): - ru_utime = self.ru_utime - ru_stime = self.ru_stime + """Get resources used by this logcontext so far. - if self.usage_start and threading.current_thread() is self.main_thread: + Returns: + ContextResourceUsage: a *copy* of the object tracking resource + usage so far + """ + # we always return a copy, for consistency + res = self._resource_usage.copy() + + # If we are on the correct thread and we're currently running then we + # can include resource usage so far. + is_main_thread = threading.current_thread() is self.main_thread + if self.alive and self.usage_start and is_main_thread: current = get_thread_resource_usage() - ru_utime += current.ru_utime - self.usage_start.ru_utime - ru_stime += current.ru_stime - self.usage_start.ru_stime + res.ru_utime += current.ru_utime - self.usage_start.ru_utime + res.ru_stime += current.ru_stime - self.usage_start.ru_stime - return ru_utime, ru_stime + return res - def add_database_transaction(self, duration_ms): - self.db_txn_count += 1 - self.db_txn_duration += duration_ms / 1000. + def add_database_transaction(self, duration_sec): + self._resource_usage.db_txn_count += 1 + self._resource_usage.db_txn_duration_sec += duration_sec + + def add_database_scheduled(self, sched_sec): + """Record a use of the database pool + + Args: + sched_sec (float): number of seconds it took us to get a + connection + """ + self._resource_usage.db_sched_duration_sec += sched_sec + + def record_event_fetch(self, event_count): + """Record a number of events being fetched from the db + + Args: + event_count (int): number of events being fetched + """ + self._resource_usage.evt_db_fetch_count += event_count class LoggingContextFilter(logging.Filter): @@ -248,7 +418,7 @@ class PreserveLoggingContext(object): context = LoggingContext.set_current_context(self.current_context) if context != self.new_context: - logger.debug( + logger.warn( "Unexpected logging context: %s is not %s", context, self.new_context, ) @@ -261,105 +431,62 @@ class PreserveLoggingContext(object): ) -class _PreservingContextDeferred(defer.Deferred): - """A deferred that ensures that all callbacks and errbacks are called with - the given logging context. - """ - def __init__(self, context): - self._log_context = context - defer.Deferred.__init__(self) - - def addCallbacks(self, callback, errback=None, - callbackArgs=None, callbackKeywords=None, - errbackArgs=None, errbackKeywords=None): - callback = self._wrap_callback(callback) - errback = self._wrap_callback(errback) - return defer.Deferred.addCallbacks( - self, callback, - errback=errback, - callbackArgs=callbackArgs, - callbackKeywords=callbackKeywords, - errbackArgs=errbackArgs, - errbackKeywords=errbackKeywords, - ) +def preserve_fn(f): + """Function decorator which wraps the function with run_in_background""" + def g(*args, **kwargs): + return run_in_background(f, *args, **kwargs) + return g - def _wrap_callback(self, f): - def g(res, *args, **kwargs): - with PreserveLoggingContext(self._log_context): - res = f(res, *args, **kwargs) - return res - return g +def run_in_background(f, *args, **kwargs): + """Calls a function, ensuring that the current context is restored after + return from the function, and that the sentinel context is set once the + deferred returned by the function completes. -def preserve_context_over_fn(fn, *args, **kwargs): - """Takes a function and invokes it with the given arguments, but removes - and restores the current logging context while doing so. + Useful for wrapping functions that return a deferred which you don't yield + on (for instance because you want to pass it to deferred.gatherResults()). - If the result is a deferred, call preserve_context_over_deferred before - returning it. + Note that if you completely discard the result, you should make sure that + `f` doesn't raise any deferred exceptions, otherwise a scary-looking + CRITICAL error about an unhandled error will be logged without much + indication about where it came from. """ - with PreserveLoggingContext(): - res = fn(*args, **kwargs) + current = LoggingContext.current_context() + try: + res = f(*args, **kwargs) + except: # noqa: E722 + # the assumption here is that the caller doesn't want to be disturbed + # by synchronous exceptions, so let's turn them into Failures. + return defer.fail() - if isinstance(res, defer.Deferred): - return preserve_context_over_deferred(res) - else: + if not isinstance(res, defer.Deferred): return res - -def preserve_context_over_deferred(deferred, context=None): - """Given a deferred wrap it such that any callbacks added later to it will - be invoked with the current context. - - Deprecated: this almost certainly doesn't do want you want, ie make - the deferred follow the synapse logcontext rules: try - ``make_deferred_yieldable`` instead. - """ - if context is None: - context = LoggingContext.current_context() - d = _PreservingContextDeferred(context) - deferred.chainDeferred(d) - return d - - -def preserve_fn(f): - """Wraps a function, to ensure that the current context is restored after - return from the function, and that the sentinel context is set once the - deferred returned by the funtion completes. - - Useful for wrapping functions that return a deferred which you don't yield - on. - """ - def reset_context(result): - LoggingContext.set_current_context(LoggingContext.sentinel) - return result - - def g(*args, **kwargs): - current = LoggingContext.current_context() - res = f(*args, **kwargs) - if isinstance(res, defer.Deferred) and not res.called: - # The function will have reset the context before returning, so - # we need to restore it now. - LoggingContext.set_current_context(current) - - # The original context will be restored when the deferred - # completes, but there is nothing waiting for it, so it will - # get leaked into the reactor or some other function which - # wasn't expecting it. We therefore need to reset the context - # here. - # - # (If this feels asymmetric, consider it this way: we are - # effectively forking a new thread of execution. We are - # probably currently within a ``with LoggingContext()`` block, - # which is supposed to have a single entry and exit point. But - # by spawning off another deferred, we are effectively - # adding a new exit point.) - res.addBoth(reset_context) + if res.called and not res.paused: + # The function should have maintained the logcontext, so we can + # optimise out the messing about return res - return g + + # The function may have reset the context before returning, so + # we need to restore it now. + ctx = LoggingContext.set_current_context(current) + + # The original context will be restored when the deferred + # completes, but there is nothing waiting for it, so it will + # get leaked into the reactor or some other function which + # wasn't expecting it. We therefore need to reset the context + # here. + # + # (If this feels asymmetric, consider it this way: we are + # effectively forking a new thread of execution. We are + # probably currently within a ``with LoggingContext()`` block, + # which is supposed to have a single entry and exit point. But + # by spawning off another deferred, we are effectively + # adding a new exit point.) + res.addBoth(_set_context_cb, ctx) + return res -@defer.inlineCallbacks def make_deferred_yieldable(deferred): """Given a deferred, make it follow the Synapse logcontext rules: @@ -371,11 +498,27 @@ def make_deferred_yieldable(deferred): returning a deferred. Then, when the deferred completes, restores the current logcontext before running callbacks/errbacks. - (This is more-or-less the opposite operation to preserve_fn.) + (This is more-or-less the opposite operation to run_in_background.) """ - with PreserveLoggingContext(): - r = yield deferred - defer.returnValue(r) + if not isinstance(deferred, defer.Deferred): + return deferred + + if deferred.called and not deferred.paused: + # it looks like this deferred is ready to run any callbacks we give it + # immediately. We may as well optimise out the logcontext faffery. + return deferred + + # ok, we can't be sure that a yield won't block, so let's reset the + # logcontext, and add a callback to the deferred to restore it. + prev_context = LoggingContext.set_current_context(LoggingContext.sentinel) + deferred.addBoth(_set_context_cb, prev_context) + return deferred + + +def _set_context_cb(result, context): + """A callback function which just sets the logging context""" + LoggingContext.set_current_context(context) + return result # modules to ignore in `logcontext_tracer` diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py new file mode 100644 index 0000000000..a46bc47ce3 --- /dev/null +++ b/synapse/util/logformatter.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import traceback + +from six import StringIO + + +class LogFormatter(logging.Formatter): + """Log formatter which gives more detail for exceptions + + This is the same as the standard log formatter, except that when logging + exceptions [typically via log.foo("msg", exc_info=1)], it prints the + sequence that led up to the point at which the exception was caught. + (Normally only stack frames between the point the exception was raised and + where it was caught are logged). + """ + def __init__(self, *args, **kwargs): + super(LogFormatter, self).__init__(*args, **kwargs) + + def formatException(self, ei): + sio = StringIO() + (typ, val, tb) = ei + + # log the stack above the exception capture point if possible, but + # check that we actually have an f_back attribute to work around + # https://twistedmatrix.com/trac/ticket/9305 + + if tb and hasattr(tb.tb_frame, 'f_back'): + sio.write("Capture point (most recent call last):\n") + traceback.print_stack(tb.tb_frame.f_back, None, sio) + + traceback.print_exception(typ, val, tb, None, sio) + s = sio.getvalue() + sio.close() + if s[-1:] == "\n": + s = s[:-1] + return s diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py index 3a83828d25..62a00189cc 100644 --- a/synapse/util/logutils.py +++ b/synapse/util/logutils.py @@ -14,13 +14,11 @@ # limitations under the License. -from inspect import getcallargs -from functools import wraps - -import logging import inspect +import logging import time - +from functools import wraps +from inspect import getcallargs _TIME_FUNC_ID = 0 @@ -96,7 +94,7 @@ def time_function(f): id = _TIME_FUNC_ID _TIME_FUNC_ID += 1 - start = time.clock() * 1000 + start = time.clock() try: _log_debug_as_f( @@ -107,10 +105,10 @@ def time_function(f): r = f(*args, **kwargs) finally: - end = time.clock() * 1000 + end = time.clock() _log_debug_as_f( f, - "[FUNC END] {%s-%d} %f", + "[FUNC END] {%s-%d} %.3f sec", (func_name, id, end - start,), ) diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index 97e0f00b67..14be3c7396 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.conch.manhole import ColoredManhole -from twisted.conch.insults import insults from twisted.conch import manhole_ssh -from twisted.cred import checkers, portal +from twisted.conch.insults import insults +from twisted.conch.manhole import ColoredManhole from twisted.conch.ssh.keys import Key +from twisted.cred import checkers, portal PUBLIC_KEY = ( "ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAGEArzJx8OYOnJmzf4tfBEvLi8DVPrJ3/c9k2I/Az" diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 4ea930d3e8..97f1267380 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -13,40 +13,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +import logging +from functools import wraps -from synapse.util.logcontext import LoggingContext -import synapse.metrics +from prometheus_client import Counter -from functools import wraps -import logging +from twisted.internet import defer +from synapse.util.logcontext import LoggingContext logger = logging.getLogger(__name__) +block_counter = Counter("synapse_util_metrics_block_count", "", ["block_name"]) -metrics = synapse.metrics.get_metrics_for(__name__) +block_timer = Counter("synapse_util_metrics_block_time_seconds", "", ["block_name"]) -block_timer = metrics.register_distribution( - "block_timer", - labels=["block_name"] -) +block_ru_utime = Counter( + "synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"]) -block_ru_utime = metrics.register_distribution( - "block_ru_utime", labels=["block_name"] -) +block_ru_stime = Counter( + "synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"]) -block_ru_stime = metrics.register_distribution( - "block_ru_stime", labels=["block_name"] -) +block_db_txn_count = Counter( + "synapse_util_metrics_block_db_txn_count", "", ["block_name"]) -block_db_txn_count = metrics.register_distribution( - "block_db_txn_count", labels=["block_name"] -) +# seconds spent waiting for db txns, excluding scheduling time, in this block +block_db_txn_duration = Counter( + "synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"]) -block_db_txn_duration = metrics.register_distribution( - "block_db_txn_duration", labels=["block_name"] -) +# seconds spent waiting for a db connection, in this block +block_db_sched_duration = Counter( + "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"]) def measure_func(name): @@ -63,8 +60,9 @@ def measure_func(name): class Measure(object): __slots__ = [ - "clock", "name", "start_context", "start", "new_context", "ru_utime", - "ru_stime", "db_txn_count", "db_txn_duration", "created_context" + "clock", "name", "start_context", "start", + "created_context", + "start_usage", ] def __init__(self, clock, name): @@ -75,23 +73,23 @@ class Measure(object): self.created_context = False def __enter__(self): - self.start = self.clock.time_msec() + self.start = self.clock.time() self.start_context = LoggingContext.current_context() if not self.start_context: self.start_context = LoggingContext("Measure") self.start_context.__enter__() self.created_context = True - self.ru_utime, self.ru_stime = self.start_context.get_resource_usage() - self.db_txn_count = self.start_context.db_txn_count - self.db_txn_duration = self.start_context.db_txn_duration + self.start_usage = self.start_context.get_resource_usage() def __exit__(self, exc_type, exc_val, exc_tb): if isinstance(exc_type, Exception) or not self.start_context: return - duration = self.clock.time_msec() - self.start - block_timer.inc_by(duration, self.name) + duration = self.clock.time() - self.start + + block_counter.labels(self.name).inc() + block_timer.labels(self.name).inc(duration) context = LoggingContext.current_context() @@ -106,16 +104,19 @@ class Measure(object): logger.warn("Expected context. (%r)", self.name) return - ru_utime, ru_stime = context.get_resource_usage() - - block_ru_utime.inc_by(ru_utime - self.ru_utime, self.name) - block_ru_stime.inc_by(ru_stime - self.ru_stime, self.name) - block_db_txn_count.inc_by( - context.db_txn_count - self.db_txn_count, self.name - ) - block_db_txn_duration.inc_by( - context.db_txn_duration - self.db_txn_duration, self.name - ) + current = context.get_resource_usage() + usage = current - self.start_usage + try: + block_ru_utime.labels(self.name).inc(usage.ru_utime) + block_ru_stime.labels(self.name).inc(usage.ru_stime) + block_db_txn_count.labels(self.name).inc(usage.db_txn_count) + block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec) + block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec) + except ValueError: + logger.warn( + "Failed to save metrics! OLD: %r, NEW: %r", + self.start_usage, current + ) if self.created_context: self.start_context.__exit__(exc_type, exc_val, exc_tb) diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py new file mode 100644 index 0000000000..4288312b8a --- /dev/null +++ b/synapse/util/module_loader.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib + +from synapse.config._base import ConfigError + + +def load_module(provider): + """ Loads a module with its config + Take a dict with keys 'module' (the module name) and 'config' + (the config dict). + + Returns + Tuple of (provider class, parsed config object) + """ + # We need to import the module, and then pick the class out of + # that, so we split based on the last dot. + module, clz = provider['module'].rsplit(".", 1) + module = importlib.import_module(module) + provider_class = getattr(module, clz) + + try: + provider_config = provider_class.parse_config(provider["config"]) + except Exception as e: + raise ConfigError( + "Failed to parse config for %r: %r" % (provider['module'], e) + ) + + return provider_class, provider_config diff --git a/synapse/util/msisdn.py b/synapse/util/msisdn.py index 607161e7f0..a6c30e5265 100644 --- a/synapse/util/msisdn.py +++ b/synapse/util/msisdn.py @@ -14,6 +14,7 @@ # limitations under the License. import phonenumbers + from synapse.api.errors import SynapseError diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 1101881a2d..7deb38f2a7 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -13,17 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - -from synapse.api.errors import LimitExceededError - -from synapse.util.async import sleep -from synapse.util.logcontext import preserve_fn - import collections import contextlib import logging +from twisted.internet import defer + +from synapse.api.errors import LimitExceededError +from synapse.util.logcontext import ( + PreserveLoggingContext, + make_deferred_yieldable, + run_in_background, +) logger = logging.getLogger(__name__) @@ -91,13 +92,22 @@ class _PerHostRatelimiter(object): self.window_size = window_size self.sleep_limit = sleep_limit - self.sleep_msec = sleep_msec + self.sleep_sec = sleep_msec / 1000.0 self.reject_limit = reject_limit self.concurrent_requests = concurrent_requests + # request_id objects for requests which have been slept self.sleeping_requests = set() + + # map from request_id object to Deferred for requests which are ready + # for processing but have been queued self.ready_request_queue = collections.OrderedDict() + + # request id objects for requests which are in progress self.current_processing = set() + + # times at which we have recently (within the last window_size ms) + # received requests. self.request_times = [] @contextlib.contextmanager @@ -116,11 +126,15 @@ class _PerHostRatelimiter(object): def _on_enter(self, request_id): time_now = self.clock.time_msec() + + # remove any entries from request_times which aren't within the window self.request_times[:] = [ r for r in self.request_times if time_now - r < self.window_size ] + # reject the request if we already have too many queued up (either + # sleeping or in the ready queue). queue_size = len(self.ready_request_queue) + len(self.sleeping_requests) if queue_size > self.reject_limit: raise LimitExceededError( @@ -133,9 +147,13 @@ class _PerHostRatelimiter(object): def queue_request(): if len(self.current_processing) > self.concurrent_requests: - logger.debug("Ratelimit [%s]: Queue req", id(request_id)) queue_defer = defer.Deferred() self.ready_request_queue[request_id] = queue_defer + logger.info( + "Ratelimiter: queueing request (queue now %i items)", + len(self.ready_request_queue), + ) + return queue_defer else: return defer.succeed(None) @@ -147,10 +165,9 @@ class _PerHostRatelimiter(object): if len(self.request_times) > self.sleep_limit: logger.debug( - "Ratelimit [%s]: sleeping req", - id(request_id), + "Ratelimiter: sleeping request for %f sec", self.sleep_sec, ) - ret_defer = preserve_fn(sleep)(self.sleep_msec / 1000.0) + ret_defer = run_in_background(self.clock.sleep, self.sleep_sec) self.sleeping_requests.add(request_id) @@ -176,6 +193,9 @@ class _PerHostRatelimiter(object): return r def on_err(r): + # XXX: why is this necessary? this is called before we start + # processing the request so why would the request be in + # current_processing? self.current_processing.discard(request_id) return r @@ -187,7 +207,7 @@ class _PerHostRatelimiter(object): ret_defer.addCallbacks(on_start, on_err) ret_defer.addBoth(on_both) - return ret_defer + return make_deferred_yieldable(ret_defer) def _on_exit(self, request_id): logger.debug( @@ -196,8 +216,10 @@ class _PerHostRatelimiter(object): ) self.current_processing.discard(request_id) try: - request_id, deferred = self.ready_request_queue.popitem() - self.current_processing.add(request_id) - deferred.callback(None) + # start processing the next item on the queue. + _, deferred = self.ready_request_queue.popitem(last=False) + + with PreserveLoggingContext(): + deferred.callback(None) except KeyError: pass diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 4fa9d1a03c..8a3a06fd74 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -12,20 +12,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import synapse.util.logcontext -from twisted.internet import defer - -from synapse.api.errors import CodeMessageException - import logging import random +from twisted.internet import defer + +import synapse.util.logcontext +from synapse.api.errors import CodeMessageException logger = logging.getLogger(__name__) class NotRetryingDestination(Exception): def __init__(self, retry_last_ts, retry_interval, destination): + """Raised by the limiter (and federation client) to indicate that we are + are deliberately not attempting to contact a given server. + + Args: + retry_last_ts (int): the unix ts in milliseconds of our last attempt + to contact the server. 0 indicates that the last attempt was + successful or that we've never actually attempted to connect. + retry_interval (int): the time in milliseconds to wait until the next + attempt. + destination (str): the domain in question + """ + msg = "Not retrying server %s." % (destination,) super(NotRetryingDestination, self).__init__(msg) @@ -189,10 +200,10 @@ class RetryDestinationLimiter(object): yield self.store.set_destination_retry_timings( self.destination, retry_last_ts, self.retry_interval ) - except: + except Exception: logger.exception( - "Failed to store set_destination_retry_timings", + "Failed to store destination_retry_timings", ) # we deliberately do this in the background. - synapse.util.logcontext.preserve_fn(store_retry_timings)() + synapse.util.logcontext.run_in_background(store_retry_timings) diff --git a/synapse/util/rlimit.py b/synapse/util/rlimit.py index f4a9abf83f..6c0f2bb0cf 100644 --- a/synapse/util/rlimit.py +++ b/synapse/util/rlimit.py @@ -13,9 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import resource import logging - +import resource logger = logging.getLogger("synapse.app.homeserver") diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 95a6168e16..43d9db67ec 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -16,18 +16,20 @@ import random import string +from six.moves import range + _string_with_symbols = ( string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" ) def random_string(length): - return ''.join(random.choice(string.ascii_letters) for _ in xrange(length)) + return ''.join(random.choice(string.ascii_letters) for _ in range(length)) def random_string_with_symbols(length): return ''.join( - random.choice(_string_with_symbols) for _ in xrange(length) + random.choice(_string_with_symbols) for _ in range(length) ) diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py new file mode 100644 index 0000000000..75efa0117b --- /dev/null +++ b/synapse/util/threepids.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re + +logger = logging.getLogger(__name__) + + +def check_3pid_allowed(hs, medium, address): + """Checks whether a given format of 3PID is allowed to be used on this HS + + Args: + hs (synapse.server.HomeServer): server + medium (str): 3pid medium - e.g. email, msisdn + address (str): address within that medium (e.g. "wotan@matrix.org") + msisdns need to first have been canonicalised + Returns: + bool: whether the 3PID medium/address is allowed to be added to this HS + """ + + if hs.config.allowed_local_3pids: + for constraint in hs.config.allowed_local_3pids: + logger.debug( + "Checking 3PID %s (%s) against %s (%s)", + address, medium, constraint['pattern'], constraint['medium'], + ) + if ( + medium == constraint['medium'] and + re.match(constraint['pattern'], address) + ): + return True + else: + return True + + return False diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py index 52086df465..1fbcd41115 100644 --- a/synapse/util/versionstring.py +++ b/synapse/util/versionstring.py @@ -14,9 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import subprocess -import os import logging +import os +import subprocess logger = logging.getLogger(__name__) diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index 7412fc57a4..7a9e45aca9 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from six.moves import range + class _Entry(object): __slots__ = ["end_key", "queue"] @@ -68,7 +70,7 @@ class WheelTimer(object): # Add empty entries between the end of the current list and when we want # to insert. This ensures there are no gaps. self.entries.extend( - _Entry(key) for key in xrange(last_key, then_key + 1) + _Entry(key) for key in range(last_key, then_key + 1) ) self.entries[-1].queue.append(obj) @@ -91,7 +93,4 @@ class WheelTimer(object): return ret def __len__(self): - l = 0 - for entry in self.entries: - l += len(entry.queue) - return l + return sum(len(entry.queue) for entry in self.entries) diff --git a/synapse/visibility.py b/synapse/visibility.py index c4dd9ae2c7..d4680863d3 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -13,14 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - -from synapse.api.constants import Membership, EventTypes +import logging +import operator -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from six import iteritems, itervalues +from six.moves import map -import logging +from twisted.internet import defer +from synapse.api.constants import EventTypes, Membership +from synapse.events.utils import prune_event +from synapse.types import get_domain_from_id logger = logging.getLogger(__name__) @@ -43,53 +46,66 @@ MEMBERSHIP_PRIORITY = ( @defer.inlineCallbacks -def filter_events_for_clients(store, user_tuples, events, event_id_to_state): - """ Returns dict of user_id -> list of events that user is allowed to - see. +def filter_events_for_client(store, user_id, events, is_peeking=False, + always_include_ids=frozenset()): + """ + Check which events a user is allowed to see Args: - user_tuples (str, bool): (user id, is_peeking) for each user to be - checked. is_peeking should be true if: - * the user is not currently a member of the room, and: - * the user has not been a member of the room since the - given events - events ([synapse.events.EventBase]): list of events to filter - """ - forgotten = yield preserve_context_over_deferred(defer.gatherResults([ - defer.maybeDeferred( - preserve_fn(store.who_forgot_in_room), - room_id, - ) - for room_id in frozenset(e.room_id for e in events) - ], consumeErrors=True)) + store (synapse.storage.DataStore): our datastore (can also be a worker + store) + user_id(str): user id to be checked + events(list[synapse.events.EventBase]): sequence of events to be checked + is_peeking(bool): should be True if: + * the user is not currently a member of the room, and: + * the user has not been a member of the room since the given + events + always_include_ids (set(event_id)): set of event ids to specifically + include (unless sender is ignored) - # Set of membership event_ids that have been forgotten - event_id_forgotten = frozenset( - row["event_id"] for rows in forgotten for row in rows + Returns: + Deferred[list[synapse.events.EventBase]] + """ + types = ( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, user_id), + ) + event_id_to_state = yield store.get_state_for_events( + frozenset(e.event_id for e in events), + types=types, ) - ignore_dict_content = yield store.get_global_account_data_by_type_for_users( - "m.ignored_user_list", user_ids=[user_id for user_id, _ in user_tuples] + ignore_dict_content = yield store.get_global_account_data_by_type_for_user( + "m.ignored_user_list", user_id, ) # FIXME: This will explode if people upload something incorrect. - ignore_dict = { - user_id: frozenset( - content.get("ignored_users", {}).keys() if content else [] - ) - for user_id, content in ignore_dict_content.items() - } + ignore_list = frozenset( + ignore_dict_content.get("ignored_users", {}).keys() + if ignore_dict_content else [] + ) + + erased_senders = yield store.are_users_erased((e.sender for e in events)) - def allowed(event, user_id, is_peeking, ignore_list): + def allowed(event): """ Args: event (synapse.events.EventBase): event to check - user_id (str) - is_peeking (bool) - ignore_list (list): list of users to ignore + + Returns: + None|EventBase: + None if the user cannot see this event at all + + a redacted copy of the event if they can only see a redacted + version + + the original event if they can see it as normal. """ if not event.is_state() and event.sender in ignore_list: - return False + return None + + if event.event_id in always_include_ids: + return event state = event_id_to_state[event.event_id] @@ -103,10 +119,6 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state): if visibility not in VISIBILITY_PRIORITY: visibility = "shared" - # if it was world_readable, it's easy: everyone can read it - if visibility == "world_readable": - return True - # Always allow history visibility events on boundaries. This is done # by setting the effective visibility to the least restrictive # of the old vs new. @@ -140,7 +152,7 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state): if membership == "leave" and ( prev_membership == "join" or prev_membership == "invite" ): - return True + return event new_priority = MEMBERSHIP_PRIORITY.index(membership) old_priority = MEMBERSHIP_PRIORITY.index(prev_membership) @@ -151,87 +163,203 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state): if membership is None: membership_event = state.get((EventTypes.Member, user_id), None) if membership_event: - if membership_event.event_id not in event_id_forgotten: - membership = membership_event.membership + membership = membership_event.membership # if the user was a member of the room at the time of the event, # they can see it. if membership == Membership.JOIN: - return True + return event + + # otherwise, it depends on the room visibility. if visibility == "joined": # we weren't a member at the time of the event, so we can't # see this event. - return False + return None elif visibility == "invited": # user can also see the event if they were *invited* at the time # of the event. - return membership == Membership.INVITE - - else: - # visibility is shared: user can also see the event if they have - # become a member since the event + return ( + event if membership == Membership.INVITE else None + ) + + elif visibility == "shared" and is_peeking: + # if the visibility is shared, users cannot see the event unless + # they have *subequently* joined the room (or were members at the + # time, of course) # # XXX: if the user has subsequently joined and then left again, # ideally we would share history up to the point they left. But - # we don't know when they left. - return not is_peeking + # we don't know when they left. We just treat it as though they + # never joined, and restrict access. + return None - defer.returnValue({ - user_id: [ - event - for event in events - if allowed(event, user_id, is_peeking, ignore_dict.get(user_id, [])) - ] - for user_id, is_peeking in user_tuples - }) + # the visibility is either shared or world_readable, and the user was + # not a member at the time. We allow it, provided the original sender + # has not requested their data to be erased, in which case, we return + # a redacted version. + if erased_senders[event.sender]: + return prune_event(event) + return event -@defer.inlineCallbacks -def filter_events_for_clients_context(store, user_tuples, events, event_id_to_context): - user_ids = set(u[0] for u in user_tuples) - event_id_to_state = {} - for event_id, context in event_id_to_context.items(): - state = yield store.get_events([ - e_id - for key, e_id in context.current_state_ids.iteritems() - if key == (EventTypes.RoomHistoryVisibility, "") - or (key[0] == EventTypes.Member and key[1] in user_ids) - ]) - event_id_to_state[event_id] = state - - res = yield filter_events_for_clients( - store, user_tuples, events, event_id_to_state - ) - defer.returnValue(res) + # check each event: gives an iterable[None|EventBase] + filtered_events = map(allowed, events) + # remove the None entries + filtered_events = filter(operator.truth, filtered_events) -@defer.inlineCallbacks -def filter_events_for_client(store, user_id, events, is_peeking=False): - """ - Check which events a user is allowed to see + # we turn it into a list before returning it. + defer.returnValue(list(filtered_events)) - Args: - user_id(str): user id to be checked - events([synapse.events.EventBase]): list of events to be checked - is_peeking(bool): should be True if: - * the user is not currently a member of the room, and: - * the user has not been a member of the room since the given - events - Returns: - [synapse.events.EventBase] - """ - types = ( - (EventTypes.RoomHistoryVisibility, ""), - (EventTypes.Member, user_id), +@defer.inlineCallbacks +def filter_events_for_server(store, server_name, events): + # Whatever else we do, we need to check for senders which have requested + # erasure of their data. + erased_senders = yield store.are_users_erased( + e.sender for e in events, ) - event_id_to_state = yield store.get_state_for_events( + + def redact_disallowed(event, state): + # if the sender has been gdpr17ed, always return a redacted + # copy of the event. + if erased_senders[event.sender]: + logger.info( + "Sender of %s has been erased, redacting", + event.event_id, + ) + return prune_event(event) + + # state will be None if we decided we didn't need to filter by + # room membership. + if not state: + return event + + history = state.get((EventTypes.RoomHistoryVisibility, ''), None) + if history: + visibility = history.content.get("history_visibility", "shared") + if visibility in ["invited", "joined"]: + # We now loop through all state events looking for + # membership states for the requesting server to determine + # if the server is either in the room or has been invited + # into the room. + for ev in itervalues(state): + if ev.type != EventTypes.Member: + continue + try: + domain = get_domain_from_id(ev.state_key) + except Exception: + continue + + if domain != server_name: + continue + + memtype = ev.membership + if memtype == Membership.JOIN: + return event + elif memtype == Membership.INVITE: + if visibility == "invited": + return event + else: + # server has no users in the room: redact + return prune_event(event) + + return event + + # Next lets check to see if all the events have a history visibility + # of "shared" or "world_readable". If thats the case then we don't + # need to check membership (as we know the server is in the room). + event_to_state_ids = yield store.get_state_ids_for_events( frozenset(e.event_id for e in events), - types=types + types=( + (EventTypes.RoomHistoryVisibility, ""), + ) ) - res = yield filter_events_for_clients( - store, [(user_id, is_peeking)], events, event_id_to_state + + visibility_ids = set() + for sids in itervalues(event_to_state_ids): + hist = sids.get((EventTypes.RoomHistoryVisibility, "")) + if hist: + visibility_ids.add(hist) + + # If we failed to find any history visibility events then the default + # is "shared" visiblity. + if not visibility_ids: + all_open = True + else: + event_map = yield store.get_events(visibility_ids) + all_open = all( + e.content.get("history_visibility") in (None, "shared", "world_readable") + for e in itervalues(event_map) + ) + + if all_open: + # all the history_visibility state affecting these events is open, so + # we don't need to filter by membership state. We *do* need to check + # for user erasure, though. + if erased_senders: + events = [ + redact_disallowed(e, None) + for e in events + ] + + defer.returnValue(events) + + # Ok, so we're dealing with events that have non-trivial visibility + # rules, so we need to also get the memberships of the room. + + # first, for each event we're wanting to return, get the event_ids + # of the history vis and membership state at those events. + event_to_state_ids = yield store.get_state_ids_for_events( + frozenset(e.event_id for e in events), + types=( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, None), + ) ) - defer.returnValue(res.get(user_id, [])) + + # We only want to pull out member events that correspond to the + # server's domain. + # + # event_to_state_ids contains lots of duplicates, so it turns out to be + # cheaper to build a complete set of unique + # ((type, state_key), event_id) tuples, and then filter out the ones we + # don't want. + # + state_key_to_event_id_set = { + e + for key_to_eid in itervalues(event_to_state_ids) + for e in key_to_eid.items() + } + + def include(typ, state_key): + if typ != EventTypes.Member: + return True + + # we avoid using get_domain_from_id here for efficiency. + idx = state_key.find(":") + if idx == -1: + return False + return state_key[idx + 1:] == server_name + + event_map = yield store.get_events([ + e_id + for key, e_id in state_key_to_event_id_set + if include(key[0], key[1]) + ]) + + event_to_state = { + e_id: { + key: event_map[inner_e_id] + for key, inner_e_id in iteritems(key_to_eid) + if inner_e_id in event_map + } + for e_id, key_to_eid in iteritems(event_to_state_ids) + } + + defer.returnValue([ + redact_disallowed(e, event_to_state[e.event_id]) + for e in events + ]) |