diff options
Diffstat (limited to 'synapse')
167 files changed, 6911 insertions, 2810 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index e62901b761..89ea9a9775 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -17,4 +17,14 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.33.3" +try: + from twisted.internet import protocol + from twisted.internet.protocol import Factory + from twisted.names.dns import DNSDatagramProtocol + protocol.Factory.noisy = False + Factory.noisy = False + DNSDatagramProtocol.noisy = False +except ImportError: + pass + +__version__ = "0.33.8" diff --git a/synapse/_scripts/__init__.py b/synapse/_scripts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/synapse/_scripts/__init__.py diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py new file mode 100644 index 0000000000..70cecde486 --- /dev/null +++ b/synapse/_scripts/register_new_matrix_user.py @@ -0,0 +1,215 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import argparse +import getpass +import hashlib +import hmac +import logging +import sys + +from six.moves import input + +import requests as _requests +import yaml + + +def request_registration( + user, + password, + server_location, + shared_secret, + admin=False, + requests=_requests, + _print=print, + exit=sys.exit, +): + + url = "%s/_matrix/client/r0/admin/register" % (server_location,) + + # Get the nonce + r = requests.get(url, verify=False) + + if r.status_code is not 200: + _print("ERROR! Received %d %s" % (r.status_code, r.reason)) + if 400 <= r.status_code < 500: + try: + _print(r.json()["error"]) + except Exception: + pass + return exit(1) + + nonce = r.json()["nonce"] + + mac = hmac.new(key=shared_secret.encode('utf8'), digestmod=hashlib.sha1) + + mac.update(nonce.encode('utf8')) + mac.update(b"\x00") + mac.update(user.encode('utf8')) + mac.update(b"\x00") + mac.update(password.encode('utf8')) + mac.update(b"\x00") + mac.update(b"admin" if admin else b"notadmin") + + mac = mac.hexdigest() + + data = { + "nonce": nonce, + "username": user, + "password": password, + "mac": mac, + "admin": admin, + } + + _print("Sending registration request...") + r = requests.post(url, json=data, verify=False) + + if r.status_code is not 200: + _print("ERROR! Received %d %s" % (r.status_code, r.reason)) + if 400 <= r.status_code < 500: + try: + _print(r.json()["error"]) + except Exception: + pass + return exit(1) + + _print("Success!") + + +def register_new_user(user, password, server_location, shared_secret, admin): + if not user: + try: + default_user = getpass.getuser() + except Exception: + default_user = None + + if default_user: + user = input("New user localpart [%s]: " % (default_user,)) + if not user: + user = default_user + else: + user = input("New user localpart: ") + + if not user: + print("Invalid user name") + sys.exit(1) + + if not password: + password = getpass.getpass("Password: ") + + if not password: + print("Password cannot be blank.") + sys.exit(1) + + confirm_password = getpass.getpass("Confirm password: ") + + if password != confirm_password: + print("Passwords do not match") + sys.exit(1) + + if admin is None: + admin = input("Make admin [no]: ") + if admin in ("y", "yes", "true"): + admin = True + else: + admin = False + + request_registration(user, password, server_location, shared_secret, bool(admin)) + + +def main(): + + logging.captureWarnings(True) + + parser = argparse.ArgumentParser( + description="Used to register new users with a given home server when" + " registration has been disabled. The home server must be" + " configured with the 'registration_shared_secret' option" + " set." + ) + parser.add_argument( + "-u", + "--user", + default=None, + help="Local part of the new user. Will prompt if omitted.", + ) + parser.add_argument( + "-p", + "--password", + default=None, + help="New password for user. Will prompt if omitted.", + ) + admin_group = parser.add_mutually_exclusive_group() + admin_group.add_argument( + "-a", + "--admin", + action="store_true", + help=( + "Register new user as an admin. " + "Will prompt if --no-admin is not set either." + ), + ) + admin_group.add_argument( + "--no-admin", + action="store_true", + help=( + "Register new user as a regular user. " + "Will prompt if --admin is not set either." + ), + ) + + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument( + "-c", + "--config", + type=argparse.FileType('r'), + help="Path to server config file. Used to read in shared secret.", + ) + + group.add_argument( + "-k", "--shared-secret", help="Shared secret as defined in server config file." + ) + + parser.add_argument( + "server_url", + default="https://localhost:8448", + nargs='?', + help="URL to use to talk to the home server. Defaults to " + " 'https://localhost:8448'.", + ) + + args = parser.parse_args() + + if "config" in args and args.config: + config = yaml.safe_load(args.config) + secret = config.get("registration_shared_secret", None) + if not secret: + print("No 'registration_shared_secret' defined in config.") + sys.exit(1) + else: + secret = args.shared_secret + + admin = None + if args.admin or args.no_admin: + admin = args.admin + + register_new_user(args.user, args.password, args.server_url, secret, admin) + + +if __name__ == "__main__": + main() diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 6502a6be7b..34382e4e3c 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -26,6 +26,7 @@ import synapse.types from synapse import event_auth from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.errors import AuthError, Codes, ResourceLimitError +from synapse.config.server import is_threepid_reserved from synapse.types import UserID from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache from synapse.util.caches.lrucache import LruCache @@ -775,34 +776,56 @@ class Auth(object): ) @defer.inlineCallbacks - def check_auth_blocking(self, user_id=None): + def check_auth_blocking(self, user_id=None, threepid=None): """Checks if the user should be rejected for some external reason, such as monthly active user limiting or global disable flag Args: user_id(str|None): If present, checks for presence against existing MAU cohort + + threepid(dict|None): If present, checks for presence against configured + reserved threepid. Used in cases where the user is trying register + with a MAU blocked server, normally they would be rejected but their + threepid is on the reserved list. user_id and + threepid should never be set at the same time. """ + + # Never fail an auth check for the server notices users + # This can be a problem where event creation is prohibited due to blocking + if user_id == self.hs.config.server_notices_mxid: + return + if self.hs.config.hs_disabled: raise ResourceLimitError( 403, self.hs.config.hs_disabled_message, - errcode=Codes.RESOURCE_LIMIT_EXCEED, - admin_uri=self.hs.config.admin_uri, + errcode=Codes.RESOURCE_LIMIT_EXCEEDED, + admin_contact=self.hs.config.admin_contact, limit_type=self.hs.config.hs_disabled_limit_type ) if self.hs.config.limit_usage_by_mau is True: - # If the user is already part of the MAU cohort + assert not (user_id and threepid) + + # If the user is already part of the MAU cohort or a trial user if user_id: timestamp = yield self.store.user_last_seen_monthly_active(user_id) if timestamp: return + + is_trial = yield self.store.is_trial_user(user_id) + if is_trial: + return + elif threepid: + # If the user does not exist yet, but is signing up with a + # reserved threepid then pass auth check + if is_threepid_reserved(self.hs.config, threepid): + return # Else if there is no room in the MAU bucket, bail current_mau = yield self.store.get_monthly_active_count() if current_mau >= self.hs.config.max_mau_value: raise ResourceLimitError( 403, "Monthly Active User Limit Exceeded", - - admin_uri=self.hs.config.admin_uri, - errcode=Codes.RESOURCE_LIMIT_EXCEED, + admin_contact=self.hs.config.admin_contact, + errcode=Codes.RESOURCE_LIMIT_EXCEEDED, limit_type="monthly_active_user" ) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 9d244059fc..9ad1939ed5 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -79,6 +79,7 @@ class EventTypes(object): Name = "m.room.name" ServerACL = "m.room.server_acl" + Pinned = "m.room.pinned_events" class RejectedReason(object): @@ -98,9 +99,17 @@ class ThirdPartyEntityKind(object): LOCATION = "location" +class RoomVersions(object): + V1 = "1" + VDH_TEST = "vdh-test-version" + + # the version we will give rooms which are created on this server -DEFAULT_ROOM_VERSION = "1" +DEFAULT_ROOM_VERSION = RoomVersions.V1 # vdh-test-version is a placeholder to get room versioning support working and tested # until we have a working v2. -KNOWN_ROOM_VERSIONS = {"1", "vdh-test-version"} +KNOWN_ROOM_VERSIONS = {RoomVersions.V1, RoomVersions.VDH_TEST} + +ServerNoticeMsgType = "m.server_notice" +ServerNoticeLimitReached = "m.server_notice.usage_limit_reached" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index e26001ab12..48b903374d 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -56,9 +56,10 @@ class Codes(object): SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED" CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN" CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM" - RESOURCE_LIMIT_EXCEED = "M_RESOURCE_LIMIT_EXCEED" + RESOURCE_LIMIT_EXCEEDED = "M_RESOURCE_LIMIT_EXCEEDED" UNSUPPORTED_ROOM_VERSION = "M_UNSUPPORTED_ROOM_VERSION" INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION" + WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION" class CodeMessageException(RuntimeError): @@ -238,11 +239,11 @@ class ResourceLimitError(SynapseError): """ def __init__( self, code, msg, - errcode=Codes.RESOURCE_LIMIT_EXCEED, - admin_uri=None, + errcode=Codes.RESOURCE_LIMIT_EXCEEDED, + admin_contact=None, limit_type=None, ): - self.admin_uri = admin_uri + self.admin_contact = admin_contact self.limit_type = limit_type super(ResourceLimitError, self).__init__(code, msg, errcode=errcode) @@ -250,7 +251,7 @@ class ResourceLimitError(SynapseError): return cs_error( self.msg, self.errcode, - admin_uri=self.admin_uri, + admin_contact=self.admin_contact, limit_type=self.limit_type ) @@ -312,6 +313,20 @@ class LimitExceededError(SynapseError): ) +class RoomKeysVersionError(SynapseError): + """A client has tried to upload to a non-current version of the room_keys store + """ + def __init__(self, current_version): + """ + Args: + current_version (str): the current version of the store they should have used + """ + super(RoomKeysVersionError, self).__init__( + 403, "Wrong room_keys version", Codes.WRONG_ROOM_KEYS_VERSION + ) + self.current_version = current_version + + class IncompatibleRoomVersionError(SynapseError): """A server is trying to join a room whose version it does not support.""" diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 186831e118..677c0bdd4c 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -172,7 +172,10 @@ USER_FILTER_SCHEMA = { # events a lot easier as we can then use a negative lookbehind # assertion to split '\.' If we allowed \\ then it would # incorrectly split '\\.' See synapse.events.utils.serialize_event - "pattern": "^((?!\\\).)*$" + # + # Note that because this is a regular expression, we have to escape + # each backslash in the pattern. + "pattern": r"^((?!\\\\).)*$" } } }, @@ -226,7 +229,7 @@ class Filtering(object): jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA, format_checker=FormatChecker()) except jsonschema.ValidationError as e: - raise SynapseError(400, e.message) + raise SynapseError(400, str(e)) class FilterCollection(object): @@ -251,6 +254,7 @@ class FilterCollection(object): "include_leave", False ) self.event_fields = filter_json.get("event_fields", []) + self.event_format = filter_json.get("event_format", "client") def __repr__(self): return "<FilterCollection %s>" % (json.dumps(self._filter_json),) diff --git a/synapse/api/urls.py b/synapse/api/urls.py index 71347912f1..6d9f1ca0ef 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -64,7 +64,7 @@ class ConsentURIBuilder(object): """ mac = hmac.new( key=self._hmac_secret, - msg=user_id, + msg=user_id.encode('ascii'), digestmod=sha256, ).hexdigest() consent_uri = "%s_matrix/consent?%s" % ( diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index 3b6b9368b8..c3afcc573b 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -24,7 +24,7 @@ try: python_dependencies.check_requirements() except python_dependencies.MissingRequirementError as e: message = "\n".join([ - "Missing Requirement: %s" % (e.message,), + "Missing Requirement: %s" % (str(e),), "To install run:", " pip install --upgrade --force \"%s\"" % (e.dependency,), "", diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 7c866e246a..18584226e9 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -17,6 +17,7 @@ import gc import logging import sys +import psutil from daemonize import Daemonize from twisted.internet import error, reactor @@ -24,12 +25,6 @@ from twisted.internet import error, reactor from synapse.util import PreserveLoggingContext from synapse.util.rlimit import change_resource_limit -try: - import affinity -except Exception: - affinity = None - - logger = logging.getLogger(__name__) @@ -89,15 +84,20 @@ def start_reactor( with PreserveLoggingContext(): logger.info("Running") if cpu_affinity is not None: - if not affinity: - quit_with_error( - "Missing package 'affinity' required for cpu_affinity\n" - "option\n\n" - "Install by running:\n\n" - " pip install affinity\n\n" - ) - logger.info("Setting CPU affinity to %s" % cpu_affinity) - affinity.set_process_affinity_mask(0, cpu_affinity) + # Turn the bitmask into bits, reverse it so we go from 0 up + mask_to_bits = bin(cpu_affinity)[2:][::-1] + + cpus = [] + cpu_num = 0 + + for i in mask_to_bits: + if i == "1": + cpus.append(cpu_num) + cpu_num += 1 + + p = psutil.Process() + p.cpu_affinity(cpus) + change_resource_limit(soft_file_limit) if gc_thresholds: gc.set_threshold(*gc_thresholds) diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 3348a8ec6d..8559e141af 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -51,10 +51,7 @@ class AppserviceSlaveStore( class AppserviceServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = AppserviceSlaveStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = AppserviceSlaveStore def _listen_http(self, listener_config): port = listener_config["port"] @@ -139,7 +136,7 @@ def start(config_options): "Synapse appservice", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.appservice" @@ -175,7 +172,6 @@ def start(config_options): def start(): ps.get_datastore().start_profiling() - ps.get_state_handler().start_caching() reactor.callWhenRunning(start) diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index ab79a45646..76aed8c60a 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -74,10 +74,7 @@ class ClientReaderSlavedStore( class ClientReaderServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = ClientReaderSlavedStore def _listen_http(self, listener_config): port = listener_config["port"] @@ -156,7 +153,7 @@ def start(config_options): "Synapse client reader", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.client_reader" @@ -184,7 +181,6 @@ def start(config_options): ss.start_listening(config.worker_listeners) def start(): - ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() reactor.callWhenRunning(start) diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index 03d39968a8..e4a68715aa 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -45,6 +45,11 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.rest.client.v1.profile import ( + ProfileAvatarURLRestServlet, + ProfileDisplaynameRestServlet, + ProfileRestServlet, +) from synapse.rest.client.v1.room import ( JoinRoomAliasServlet, RoomMembershipRestServlet, @@ -53,6 +58,7 @@ from synapse.rest.client.v1.room import ( ) from synapse.server import HomeServer from synapse.storage.engines import create_engine +from synapse.storage.user_directory import UserDirectoryStore from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole @@ -62,6 +68,9 @@ logger = logging.getLogger("synapse.app.event_creator") class EventCreatorSlavedStore( + # FIXME(#3714): We need to add UserDirectoryStore as we write directly + # rather than going via the correct worker. + UserDirectoryStore, DirectoryStore, SlavedTransactionStore, SlavedProfileStore, @@ -81,10 +90,7 @@ class EventCreatorSlavedStore( class EventCreatorServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = EventCreatorSlavedStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = EventCreatorSlavedStore def _listen_http(self, listener_config): port = listener_config["port"] @@ -101,6 +107,9 @@ class EventCreatorServer(HomeServer): RoomMembershipRestServlet(self).register(resource) RoomStateEventRestServlet(self).register(resource) JoinRoomAliasServlet(self).register(resource) + ProfileAvatarURLRestServlet(self).register(resource) + ProfileDisplaynameRestServlet(self).register(resource) + ProfileRestServlet(self).register(resource) resources.update({ "/_matrix/client/r0": resource, "/_matrix/client/unstable": resource, @@ -160,7 +169,7 @@ def start(config_options): "Synapse event creator", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.event_creator" @@ -169,6 +178,9 @@ def start(config_options): setup_logging(config, use_worker_options=True) + # This should only be done on the user directory worker or the master + config.update_user_directory = False + events.USE_FROZEN_DICTS = config.use_frozen_dicts database_engine = create_engine(config.database_config) @@ -190,7 +202,6 @@ def start(config_options): ss.start_listening(config.worker_listeners) def start(): - ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() reactor.callWhenRunning(start) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 7d8105778d..228a297fb8 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -72,10 +72,7 @@ class FederationReaderSlavedStore( class FederationReaderServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = FederationReaderSlavedStore def _listen_http(self, listener_config): port = listener_config["port"] @@ -143,7 +140,7 @@ def start(config_options): "Synapse federation reader", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.federation_reader" @@ -171,7 +168,6 @@ def start(config_options): ss.start_listening(config.worker_listeners) def start(): - ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() reactor.callWhenRunning(start) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index d59007099b..e9a99d76e1 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -78,10 +78,7 @@ class FederationSenderSlaveStore( class FederationSenderServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = FederationSenderSlaveStore def _listen_http(self, listener_config): port = listener_config["port"] @@ -163,7 +160,7 @@ def start(config_options): "Synapse federation sender", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.federation_sender" @@ -204,7 +201,6 @@ def start(config_options): def start(): ps.get_datastore().start_profiling() - ps.get_state_handler().start_caching() reactor.callWhenRunning(start) _base.start_worker_reactor("synapse-federation-sender", config) diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index 8d484c1cd4..f5c61dec5b 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -68,7 +68,7 @@ class PresenceStatusStubServlet(ClientV1RestServlet): "Authorization": auth_headers, } result = yield self.http_client.get_json( - self.main_uri + request.uri, + self.main_uri + request.uri.decode('ascii'), headers=headers, ) defer.returnValue((200, result)) @@ -125,7 +125,7 @@ class KeyUploadServlet(RestServlet): "Authorization": auth_headers, } result = yield self.http_client.post_json_get_json( - self.main_uri + request.uri, + self.main_uri + request.uri.decode('ascii'), body, headers=headers, ) @@ -148,10 +148,7 @@ class FrontendProxySlavedStore( class FrontendProxyServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = FrontendProxySlavedStore def _listen_http(self, listener_config): port = listener_config["port"] @@ -231,7 +228,7 @@ def start(config_options): "Synapse frontend proxy", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.frontend_proxy" @@ -261,7 +258,6 @@ def start(config_options): ss.start_listening(config.worker_listeners) def start(): - ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() reactor.callWhenRunning(start) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 005921dcf7..593e1e75db 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -20,6 +20,7 @@ import sys from six import iteritems +import psutil from prometheus_client import Gauge from twisted.application import service @@ -62,7 +63,7 @@ from synapse.rest.key.v1.server_key_resource import LocalKey from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.server import HomeServer -from synapse.storage import are_all_users_on_domain +from synapse.storage import DataStore, are_all_users_on_domain from synapse.storage.engines import IncorrectDatabaseSetup, create_engine from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database from synapse.util.caches import CACHE_SIZE_FACTOR @@ -111,6 +112,8 @@ def build_resource_for_web_client(hs): class SynapseHomeServer(HomeServer): + DATASTORE_CLASS = DataStore + def _listener_http(self, config, listener_config): port = listener_config["port"] bind_addresses = listener_config["bind_addresses"] @@ -299,12 +302,16 @@ class SynapseHomeServer(HomeServer): try: database_engine.check_database(db_conn.cursor()) except IncorrectDatabaseSetup as e: - quit_with_error(e.message) + quit_with_error(str(e)) # Gauges to expose monthly active user control metrics current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU") max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit") +registered_reserved_users_mau_gauge = Gauge( + "synapse_admin_mau:registered_reserved_users", + "Registered users with reserved threepids" +) def setup(config_options): @@ -322,7 +329,7 @@ def setup(config_options): config_options, ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) if not config: @@ -356,13 +363,13 @@ def setup(config_options): logger.info("Preparing database: %s...", config.database_config['name']) try: - db_conn = hs.get_db_conn(run_new_connection=False) - prepare_database(db_conn, database_engine, config=config) - database_engine.on_new_connection(db_conn) + with hs.get_db_conn(run_new_connection=False) as db_conn: + prepare_database(db_conn, database_engine, config=config) + database_engine.on_new_connection(db_conn) - hs.run_startup_checks(db_conn, database_engine) + hs.run_startup_checks(db_conn, database_engine) - db_conn.commit() + db_conn.commit() except UpgradeDatabaseException: sys.stderr.write( "\nFailed to upgrade database.\n" @@ -378,10 +385,8 @@ def setup(config_options): def start(): hs.get_pusherpool().start() - hs.get_state_handler().start_caching() hs.get_datastore().start_profiling() hs.get_datastore().start_doing_background_updates() - hs.get_federation_client().start_get_pdu_cache() reactor.callWhenRunning(start) @@ -451,6 +456,10 @@ def run(hs): stats["homeserver"] = hs.config.server_name stats["timestamp"] = now stats["uptime_seconds"] = uptime + version = sys.version_info + stats["python_version"] = "{}.{}.{}".format( + version.major, version.minor, version.micro + ) stats["total_users"] = yield hs.get_datastore().count_all_users() total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users() @@ -494,7 +503,6 @@ def run(hs): def performance_stats_init(): try: - import psutil process = psutil.Process() # Ensure we can fetch both, and make the initial request for cpu_percent # so the next request will use this as the initial point. @@ -502,12 +510,9 @@ def run(hs): process.cpu_percent(interval=None) logger.info("report_stats can use psutil") stats_process.append(process) - except (ImportError, AttributeError): - logger.warn( - "report_stats enabled but psutil is not installed or incorrect version." - " Disabling reporting of memory/cpu stats." - " Ensuring psutil is available will help matrix.org track performance" - " changes across releases." + except (AttributeError): + logger.warning( + "Unable to read memory/cpu stats. Disabling reporting." ) def generate_user_daily_visit_stats(): @@ -522,25 +527,35 @@ def run(hs): clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000) # monthly active user limiting functionality - clock.looping_call( - hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60 - ) - hs.get_datastore().reap_monthly_active_users() + def reap_monthly_active_users(): + return run_as_background_process( + "reap_monthly_active_users", + hs.get_datastore().reap_monthly_active_users, + ) + clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60) + reap_monthly_active_users() @defer.inlineCallbacks def generate_monthly_active_users(): - count = 0 + current_mau_count = 0 + reserved_count = 0 + store = hs.get_datastore() if hs.config.limit_usage_by_mau: - count = yield hs.get_datastore().get_monthly_active_count() - current_mau_gauge.set(float(count)) + current_mau_count = yield store.get_monthly_active_count() + reserved_count = yield store.get_registered_reserved_users_count() + current_mau_gauge.set(float(current_mau_count)) + registered_reserved_users_mau_gauge.set(float(reserved_count)) max_mau_gauge.set(float(hs.config.max_mau_value)) - hs.get_datastore().initialise_reserved_users( - hs.config.mau_limits_reserved_threepids - ) - generate_monthly_active_users() + def start_generate_monthly_active_users(): + return run_as_background_process( + "generate_monthly_active_users", + generate_monthly_active_users, + ) + + start_generate_monthly_active_users() if hs.config.limit_usage_by_mau: - clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000) + clock.looping_call(start_generate_monthly_active_users, 5 * 60 * 1000) # End of monthly active user settings if hs.config.report_stats: @@ -556,7 +571,7 @@ def run(hs): clock.call_later(5 * 60, start_phone_stats_home) if hs.config.daemonize and hs.config.print_pidfile: - print (hs.config.pid_file) + print(hs.config.pid_file) _base.start_reactor( "synapse-homeserver", diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index fd1f6cbf7e..acc0487adc 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -60,10 +60,7 @@ class MediaRepositorySlavedStore( class MediaRepositoryServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = MediaRepositorySlavedStore def _listen_http(self, listener_config): port = listener_config["port"] @@ -136,7 +133,7 @@ def start(config_options): "Synapse media repository", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.media_repository" @@ -171,7 +168,6 @@ def start(config_options): ss.start_listening(config.worker_listeners) def start(): - ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() reactor.callWhenRunning(start) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index a4fc7e91fa..83b0863f00 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -28,6 +28,7 @@ from synapse.config.logger import setup_logging from synapse.http.site import SynapseSite from synapse.metrics import RegistryProxy from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.replication.slave.storage._base import __func__ from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.pushers import SlavedPusherStore @@ -49,39 +50,36 @@ class PusherSlaveStore( SlavedAccountDataStore ): update_pusher_last_stream_ordering_and_success = ( - DataStore.update_pusher_last_stream_ordering_and_success.__func__ + __func__(DataStore.update_pusher_last_stream_ordering_and_success) ) update_pusher_failing_since = ( - DataStore.update_pusher_failing_since.__func__ + __func__(DataStore.update_pusher_failing_since) ) update_pusher_last_stream_ordering = ( - DataStore.update_pusher_last_stream_ordering.__func__ + __func__(DataStore.update_pusher_last_stream_ordering) ) get_throttle_params_by_room = ( - DataStore.get_throttle_params_by_room.__func__ + __func__(DataStore.get_throttle_params_by_room) ) set_throttle_params = ( - DataStore.set_throttle_params.__func__ + __func__(DataStore.set_throttle_params) ) get_time_of_last_push_action_before = ( - DataStore.get_time_of_last_push_action_before.__func__ + __func__(DataStore.get_time_of_last_push_action_before) ) get_profile_displayname = ( - DataStore.get_profile_displayname.__func__ + __func__(DataStore.get_profile_displayname) ) class PusherServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = PusherSlaveStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = PusherSlaveStore def remove_pusher(self, app_id, push_key, user_id): self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id) @@ -163,11 +161,11 @@ class PusherReplicationHandler(ReplicationClientHandler): else: yield self.start_pusher(row.user_id, row.app_id, row.pushkey) elif stream_name == "events": - self.pusher_pool.on_new_notifications( + yield self.pusher_pool.on_new_notifications( token, token, ) elif stream_name == "receipts": - self.pusher_pool.on_new_receipts( + yield self.pusher_pool.on_new_receipts( token, token, set(row.room_id for row in rows) ) except Exception: @@ -185,7 +183,7 @@ class PusherReplicationHandler(ReplicationClientHandler): def start_pusher(self, user_id, app_id, pushkey): key = "%s:%s" % (app_id, pushkey) logger.info("Starting pusher %r / %r", user_id, key) - return self.pusher_pool._refresh_pusher(app_id, pushkey, user_id) + return self.pusher_pool.start_pusher_by_id(app_id, pushkey, user_id) def start(config_options): @@ -194,7 +192,7 @@ def start(config_options): "Synapse pusher", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.pusher" @@ -231,7 +229,6 @@ def start(config_options): def start(): ps.get_pusherpool().start() ps.get_datastore().start_profiling() - ps.get_state_handler().start_caching() reactor.callWhenRunning(start) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 27e1998660..3926c7f263 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -33,7 +33,7 @@ from synapse.http.server import JsonResource from synapse.http.site import SynapseSite from synapse.metrics import RegistryProxy from synapse.metrics.resource import METRICS_PREFIX, MetricsResource -from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage._base import BaseSlavedStore, __func__ from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.client_ips import SlavedClientIpStore @@ -147,7 +147,7 @@ class SynchrotronPresence(object): and haven't come back yet. If there are poke the master about them. """ now = self.clock.time_msec() - for user_id, last_sync_ms in self.users_going_offline.items(): + for user_id, last_sync_ms in list(self.users_going_offline.items()): if now - last_sync_ms > 10 * 1000: self.users_going_offline.pop(user_id, None) self.send_user_sync(user_id, False, last_sync_ms) @@ -156,9 +156,9 @@ class SynchrotronPresence(object): # TODO Hows this supposed to work? pass - get_states = PresenceHandler.get_states.__func__ - get_state = PresenceHandler.get_state.__func__ - current_state_for_users = PresenceHandler.current_state_for_users.__func__ + get_states = __func__(PresenceHandler.get_states) + get_state = __func__(PresenceHandler.get_state) + current_state_for_users = __func__(PresenceHandler.current_state_for_users) def user_syncing(self, user_id, affect_presence): if affect_presence: @@ -208,7 +208,7 @@ class SynchrotronPresence(object): ) for row in rows] for state in states: - self.user_to_current_state[row.user_id] = state + self.user_to_current_state[state.user_id] = state stream_id = token yield self.notify_from_replication(states, stream_id) @@ -249,10 +249,7 @@ class SynchrotronApplicationService(object): class SynchrotronServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = SynchrotronSlavedStore def _listen_http(self, listener_config): port = listener_config["port"] @@ -413,7 +410,7 @@ def start(config_options): "Synapse synchrotron", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.synchrotron" @@ -438,7 +435,6 @@ def start(config_options): def start(): ss.get_datastore().start_profiling() - ss.get_state_handler().start_caching() reactor.callWhenRunning(start) diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py deleted file mode 100755 index d658f967ba..0000000000 --- a/synapse/app/synctl.py +++ /dev/null @@ -1,284 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import collections -import errno -import glob -import os -import os.path -import signal -import subprocess -import sys -import time - -from six import iteritems - -import yaml - -SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"] - -GREEN = "\x1b[1;32m" -YELLOW = "\x1b[1;33m" -RED = "\x1b[1;31m" -NORMAL = "\x1b[m" - - -def pid_running(pid): - try: - os.kill(pid, 0) - return True - except OSError as err: - if err.errno == errno.EPERM: - return True - return False - - -def write(message, colour=NORMAL, stream=sys.stdout): - if colour == NORMAL: - stream.write(message + "\n") - else: - stream.write(colour + message + NORMAL + "\n") - - -def abort(message, colour=RED, stream=sys.stderr): - write(message, colour, stream) - sys.exit(1) - - -def start(configfile): - write("Starting ...") - args = SYNAPSE - args.extend(["--daemonize", "-c", configfile]) - - try: - subprocess.check_call(args) - write("started synapse.app.homeserver(%r)" % - (configfile,), colour=GREEN) - except subprocess.CalledProcessError as e: - write( - "error starting (exit code: %d); see above for logs" % e.returncode, - colour=RED, - ) - - -def start_worker(app, configfile, worker_configfile): - args = [ - "python", "-B", - "-m", app, - "-c", configfile, - "-c", worker_configfile - ] - - try: - subprocess.check_call(args) - write("started %s(%r)" % (app, worker_configfile), colour=GREEN) - except subprocess.CalledProcessError as e: - write( - "error starting %s(%r) (exit code: %d); see above for logs" % ( - app, worker_configfile, e.returncode, - ), - colour=RED, - ) - - -def stop(pidfile, app): - if os.path.exists(pidfile): - pid = int(open(pidfile).read()) - try: - os.kill(pid, signal.SIGTERM) - write("stopped %s" % (app,), colour=GREEN) - except OSError as err: - if err.errno == errno.ESRCH: - write("%s not running" % (app,), colour=YELLOW) - elif err.errno == errno.EPERM: - abort("Cannot stop %s: Operation not permitted" % (app,)) - else: - abort("Cannot stop %s: Unknown error" % (app,)) - - -Worker = collections.namedtuple("Worker", [ - "app", "configfile", "pidfile", "cache_factor" -]) - - -def main(): - - parser = argparse.ArgumentParser() - - parser.add_argument( - "action", - choices=["start", "stop", "restart"], - help="whether to start, stop or restart the synapse", - ) - parser.add_argument( - "configfile", - nargs="?", - default="homeserver.yaml", - help="the homeserver config file, defaults to homeserver.yaml", - ) - parser.add_argument( - "-w", "--worker", - metavar="WORKERCONFIG", - help="start or stop a single worker", - ) - parser.add_argument( - "-a", "--all-processes", - metavar="WORKERCONFIGDIR", - help="start or stop all the workers in the given directory" - " and the main synapse process", - ) - - options = parser.parse_args() - - if options.worker and options.all_processes: - write( - 'Cannot use "--worker" with "--all-processes"', - stream=sys.stderr - ) - sys.exit(1) - - configfile = options.configfile - - if not os.path.exists(configfile): - write( - "No config file found\n" - "To generate a config file, run '%s -c %s --generate-config" - " --server-name=<server name>'\n" % ( - " ".join(SYNAPSE), options.configfile - ), - stream=sys.stderr, - ) - sys.exit(1) - - with open(configfile) as stream: - config = yaml.load(stream) - - pidfile = config["pid_file"] - cache_factor = config.get("synctl_cache_factor") - start_stop_synapse = True - - if cache_factor: - os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor) - - cache_factors = config.get("synctl_cache_factors", {}) - for cache_name, factor in iteritems(cache_factors): - os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor) - - worker_configfiles = [] - if options.worker: - start_stop_synapse = False - worker_configfile = options.worker - if not os.path.exists(worker_configfile): - write( - "No worker config found at %r" % (worker_configfile,), - stream=sys.stderr, - ) - sys.exit(1) - worker_configfiles.append(worker_configfile) - - if options.all_processes: - # To start the main synapse with -a you need to add a worker file - # with worker_app == "synapse.app.homeserver" - start_stop_synapse = False - worker_configdir = options.all_processes - if not os.path.isdir(worker_configdir): - write( - "No worker config directory found at %r" % (worker_configdir,), - stream=sys.stderr, - ) - sys.exit(1) - worker_configfiles.extend(sorted(glob.glob( - os.path.join(worker_configdir, "*.yaml") - ))) - - workers = [] - for worker_configfile in worker_configfiles: - with open(worker_configfile) as stream: - worker_config = yaml.load(stream) - worker_app = worker_config["worker_app"] - if worker_app == "synapse.app.homeserver": - # We need to special case all of this to pick up options that may - # be set in the main config file or in this worker config file. - worker_pidfile = ( - worker_config.get("pid_file") - or pidfile - ) - worker_cache_factor = worker_config.get("synctl_cache_factor") or cache_factor - daemonize = worker_config.get("daemonize") or config.get("daemonize") - assert daemonize, "Main process must have daemonize set to true" - - # The master process doesn't support using worker_* config. - for key in worker_config: - if key == "worker_app": # But we allow worker_app - continue - assert not key.startswith("worker_"), \ - "Main process cannot use worker_* config" - else: - worker_pidfile = worker_config["worker_pid_file"] - worker_daemonize = worker_config["worker_daemonize"] - assert worker_daemonize, "In config %r: expected '%s' to be True" % ( - worker_configfile, "worker_daemonize") - worker_cache_factor = worker_config.get("synctl_cache_factor") - workers.append(Worker( - worker_app, worker_configfile, worker_pidfile, worker_cache_factor, - )) - - action = options.action - - if action == "stop" or action == "restart": - for worker in workers: - stop(worker.pidfile, worker.app) - - if start_stop_synapse: - stop(pidfile, "synapse.app.homeserver") - - # Wait for synapse to actually shutdown before starting it again - if action == "restart": - running_pids = [] - if start_stop_synapse and os.path.exists(pidfile): - running_pids.append(int(open(pidfile).read())) - for worker in workers: - if os.path.exists(worker.pidfile): - running_pids.append(int(open(worker.pidfile).read())) - if len(running_pids) > 0: - write("Waiting for process to exit before restarting...") - for running_pid in running_pids: - while pid_running(running_pid): - time.sleep(0.2) - write("All processes exited; now restarting...") - - if action == "start" or action == "restart": - if start_stop_synapse: - # Check if synapse is already running - if os.path.exists(pidfile) and pid_running(int(open(pidfile).read())): - abort("synapse.app.homeserver already running") - start(configfile) - - for worker in workers: - if worker.cache_factor: - os.environ["SYNAPSE_CACHE_FACTOR"] = str(worker.cache_factor) - - start_worker(worker.app, configfile, worker.configfile) - - if cache_factor: - os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor) - else: - os.environ.pop("SYNAPSE_CACHE_FACTOR", None) - - -if __name__ == "__main__": - main() diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 1388a42b59..0a5f62b509 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -94,10 +94,7 @@ class UserDirectorySlaveStore( class UserDirectoryServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = UserDirectorySlaveStore def _listen_http(self, listener_config): port = listener_config["port"] @@ -191,7 +188,7 @@ def start(config_options): "Synapse user directory", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.user_dir" @@ -232,7 +229,6 @@ def start(config_options): def start(): ps.get_datastore().start_profiling() - ps.get_state_handler().start_caching() reactor.callWhenRunning(start) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 6980e5890e..9ccc5a80fc 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import urllib + +from six.moves import urllib from prometheus_client import Counter @@ -98,7 +99,7 @@ class ApplicationServiceApi(SimpleHttpClient): def query_user(self, service, user_id): if service.url is None: defer.returnValue(False) - uri = service.url + ("/users/%s" % urllib.quote(user_id)) + uri = service.url + ("/users/%s" % urllib.parse.quote(user_id)) response = None try: response = yield self.get_json(uri, { @@ -119,7 +120,7 @@ class ApplicationServiceApi(SimpleHttpClient): def query_alias(self, service, alias): if service.url is None: defer.returnValue(False) - uri = service.url + ("/rooms/%s" % urllib.quote(alias)) + uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias)) response = None try: response = yield self.get_json(uri, { @@ -153,7 +154,7 @@ class ApplicationServiceApi(SimpleHttpClient): service.url, APP_SERVICE_PREFIX, kind, - urllib.quote(protocol) + urllib.parse.quote(protocol) ) try: response = yield self.get_json(uri, fields) @@ -188,7 +189,7 @@ class ApplicationServiceApi(SimpleHttpClient): uri = "%s%s/thirdparty/protocol/%s" % ( service.url, APP_SERVICE_PREFIX, - urllib.quote(protocol) + urllib.parse.quote(protocol) ) try: info = yield self.get_json(uri, {}) @@ -228,7 +229,7 @@ class ApplicationServiceApi(SimpleHttpClient): txn_id = str(txn_id) uri = service.url + ("/transactions/%s" % - urllib.quote(txn_id)) + urllib.parse.quote(txn_id)) try: yield self.put_json( uri=uri, diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py index 58c97a70af..79fe9c3dac 100644 --- a/synapse/config/__main__.py +++ b/synapse/config/__main__.py @@ -25,10 +25,10 @@ if __name__ == "__main__": try: config = HomeServerConfig.load_config("", sys.argv[3:]) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) - print (getattr(config, key)) + print(getattr(config, key)) sys.exit(0) else: sys.stderr.write("Unknown command %r\n" % (action,)) diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 3d2e90dd5b..14dae65ea0 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -106,10 +106,7 @@ class Config(object): @classmethod def check_file(cls, file_path, config_name): if file_path is None: - raise ConfigError( - "Missing config for %s." - % (config_name,) - ) + raise ConfigError("Missing config for %s." % (config_name,)) try: os.stat(file_path) except OSError as e: @@ -128,9 +125,7 @@ class Config(object): if e.errno != errno.EEXIST: raise if not os.path.isdir(dir_path): - raise ConfigError( - "%s is not a directory" % (dir_path,) - ) + raise ConfigError("%s is not a directory" % (dir_path,)) return dir_path @classmethod @@ -156,21 +151,20 @@ class Config(object): return results def generate_config( - self, - config_dir_path, - server_name, - is_generating_file, - report_stats=None, + self, config_dir_path, server_name, is_generating_file, report_stats=None ): default_config = "# vim:ft=yaml\n" - default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all( - "default_config", - config_dir_path=config_dir_path, - server_name=server_name, - is_generating_file=is_generating_file, - report_stats=report_stats, - )) + default_config += "\n\n".join( + dedent(conf) + for conf in self.invoke_all( + "default_config", + config_dir_path=config_dir_path, + server_name=server_name, + is_generating_file=is_generating_file, + report_stats=report_stats, + ) + ) config = yaml.load(default_config) @@ -178,23 +172,22 @@ class Config(object): @classmethod def load_config(cls, description, argv): - config_parser = argparse.ArgumentParser( - description=description, - ) + config_parser = argparse.ArgumentParser(description=description) config_parser.add_argument( - "-c", "--config-path", + "-c", + "--config-path", action="append", metavar="CONFIG_FILE", help="Specify config file. Can be given multiple times and" - " may specify directories containing *.yaml files." + " may specify directories containing *.yaml files.", ) config_parser.add_argument( "--keys-directory", metavar="DIRECTORY", help="Where files such as certs and signing keys are stored when" - " their location is given explicitly in the config." - " Defaults to the directory containing the last config file", + " their location is given explicitly in the config." + " Defaults to the directory containing the last config file", ) config_args = config_parser.parse_args(argv) @@ -203,9 +196,7 @@ class Config(object): obj = cls() obj.read_config_files( - config_files, - keys_directory=config_args.keys_directory, - generate_keys=False, + config_files, keys_directory=config_args.keys_directory, generate_keys=False ) return obj @@ -213,38 +204,38 @@ class Config(object): def load_or_generate_config(cls, description, argv): config_parser = argparse.ArgumentParser(add_help=False) config_parser.add_argument( - "-c", "--config-path", + "-c", + "--config-path", action="append", metavar="CONFIG_FILE", help="Specify config file. Can be given multiple times and" - " may specify directories containing *.yaml files." + " may specify directories containing *.yaml files.", ) config_parser.add_argument( "--generate-config", action="store_true", - help="Generate a config file for the server name" + help="Generate a config file for the server name", ) config_parser.add_argument( "--report-stats", action="store", help="Whether the generated config reports anonymized usage statistics", - choices=["yes", "no"] + choices=["yes", "no"], ) config_parser.add_argument( "--generate-keys", action="store_true", - help="Generate any missing key files then exit" + help="Generate any missing key files then exit", ) config_parser.add_argument( "--keys-directory", metavar="DIRECTORY", help="Used with 'generate-*' options to specify where files such as" - " certs and signing keys should be stored in, unless explicitly" - " specified in the config." + " certs and signing keys should be stored in, unless explicitly" + " specified in the config.", ) config_parser.add_argument( - "-H", "--server-name", - help="The server name to generate a config file for" + "-H", "--server-name", help="The server name to generate a config file for" ) config_args, remaining_args = config_parser.parse_known_args(argv) @@ -257,8 +248,8 @@ class Config(object): if config_args.generate_config: if config_args.report_stats is None: config_parser.error( - "Please specify either --report-stats=yes or --report-stats=no\n\n" + - MISSING_REPORT_STATS_SPIEL + "Please specify either --report-stats=yes or --report-stats=no\n\n" + + MISSING_REPORT_STATS_SPIEL ) if not config_files: config_parser.error( @@ -287,26 +278,32 @@ class Config(object): config_dir_path=config_dir_path, server_name=server_name, report_stats=(config_args.report_stats == "yes"), - is_generating_file=True + is_generating_file=True, ) obj.invoke_all("generate_files", config) config_file.write(config_str) - print(( - "A config file has been generated in %r for server name" - " %r with corresponding SSL keys and self-signed" - " certificates. Please review this file and customise it" - " to your needs." - ) % (config_path, server_name)) + print( + ( + "A config file has been generated in %r for server name" + " %r with corresponding SSL keys and self-signed" + " certificates. Please review this file and customise it" + " to your needs." + ) + % (config_path, server_name) + ) print( "If this server name is incorrect, you will need to" " regenerate the SSL certificates" ) return else: - print(( - "Config file %r already exists. Generating any missing key" - " files." - ) % (config_path,)) + print( + ( + "Config file %r already exists. Generating any missing key" + " files." + ) + % (config_path,) + ) generate_keys = True parser = argparse.ArgumentParser( @@ -338,8 +335,7 @@ class Config(object): return obj - def read_config_files(self, config_files, keys_directory=None, - generate_keys=False): + def read_config_files(self, config_files, keys_directory=None, generate_keys=False): if not keys_directory: keys_directory = os.path.dirname(config_files[-1]) @@ -364,8 +360,9 @@ class Config(object): if "report_stats" not in config: raise ConfigError( - MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" + - MISSING_REPORT_STATS_SPIEL + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + + "\n" + + MISSING_REPORT_STATS_SPIEL ) if generate_keys: @@ -399,16 +396,16 @@ def find_config_files(search_paths): for entry in os.listdir(config_path): entry_path = os.path.join(config_path, entry) if not os.path.isfile(entry_path): - print ( - "Found subdirectory in config directory: %r. IGNORING." - ) % (entry_path, ) + err = "Found subdirectory in config directory: %r. IGNORING." + print(err % (entry_path,)) continue if not entry.endswith(".yaml"): - print ( - "Found file in config directory that does not" - " end in '.yaml': %r. IGNORING." - ) % (entry_path, ) + err = ( + "Found file in config directory that does not end in " + "'.yaml': %r. IGNORING." + ) + print(err % (entry_path,)) continue files.append(entry_path) diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index fe156b6930..93d70cff14 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -13,10 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + # This file can't be called email.py because if it is, we cannot: import email.utils +import logging +import os + +import pkg_resources -from ._base import Config +from ._base import Config, ConfigError + +logger = logging.getLogger(__name__) class EmailConfig(Config): @@ -38,7 +46,6 @@ class EmailConfig(Config): "smtp_host", "smtp_port", "notif_from", - "template_dir", "notif_template_html", "notif_template_text", ] @@ -62,9 +69,26 @@ class EmailConfig(Config): self.email_smtp_host = email_config["smtp_host"] self.email_smtp_port = email_config["smtp_port"] self.email_notif_from = email_config["notif_from"] - self.email_template_dir = email_config["template_dir"] self.email_notif_template_html = email_config["notif_template_html"] self.email_notif_template_text = email_config["notif_template_text"] + + template_dir = email_config.get("template_dir") + # we need an absolute path, because we change directory after starting (and + # we don't yet know what auxilliary templates like mail.css we will need). + # (Note that loading as package_resources with jinja.PackageLoader doesn't + # work for the same reason.) + if not template_dir: + template_dir = pkg_resources.resource_filename( + 'synapse', 'res/templates' + ) + template_dir = os.path.abspath(template_dir) + + for f in self.email_notif_template_text, self.email_notif_template_html: + p = os.path.join(template_dir, f) + if not os.path.isfile(p): + raise ConfigError("Unable to find email template file %s" % (p, )) + self.email_template_dir = template_dir + self.email_notif_for_new_users = email_config.get( "notif_for_new_users", True ) @@ -113,7 +137,9 @@ class EmailConfig(Config): # require_transport_security: False # notif_from: "Your Friendly %(app)s Home Server <noreply@example.com>" # app_name: Matrix - # template_dir: res/templates + # # if template_dir is unset, uses the example templates that are part of + # # the Synapse distribution. + # #template_dir: res/templates # notif_template_html: notif_mail.html # notif_template_text: notif_mail.txt # notif_for_new_users: True diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 2fd9c48abf..10dd40159f 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -21,7 +21,7 @@ from .consent_config import ConsentConfig from .database import DatabaseConfig from .emailconfig import EmailConfig from .groups import GroupsConfig -from .jwt import JWTConfig +from .jwt_config import JWTConfig from .key import KeyConfig from .logger import LoggingConfig from .metrics import MetricsConfig @@ -31,6 +31,7 @@ from .push import PushConfig from .ratelimiting import RatelimitConfig from .registration import RegistrationConfig from .repository import ContentRepositoryConfig +from .room_directory import RoomDirectoryConfig from .saml2 import SAML2Config from .server import ServerConfig from .server_notices_config import ServerNoticesConfig @@ -49,7 +50,7 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, WorkerConfig, PasswordAuthProviderConfig, PushConfig, SpamCheckerConfig, GroupsConfig, UserDirectoryConfig, ConsentConfig, - ServerNoticesConfig, + ServerNoticesConfig, RoomDirectoryConfig, ): pass diff --git a/synapse/config/jwt.py b/synapse/config/jwt_config.py index 51e7f7e003..51e7f7e003 100644 --- a/synapse/config/jwt.py +++ b/synapse/config/jwt_config.py diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 3f187adfc8..e9a936118d 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -227,7 +227,22 @@ def setup_logging(config, use_worker_options=False): # # However this may not be too much of a problem if we are just writing to a file. observer = STDLibLogObserver() + + def _log(event): + + if "log_text" in event: + if event["log_text"].startswith("DNSDatagramProtocol starting on "): + return + + if event["log_text"].startswith("(UDP Port "): + return + + if event["log_text"].startswith("Timing out client"): + return + + return observer(event) + globalLogBeginner.beginLoggingTo( - [observer], + [_log], redirectStandardIO=not config.no_redirect_stdio, ) diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 685c78bc7f..0f41a6602e 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -15,10 +15,10 @@ from distutils.util import strtobool +from synapse.config._base import Config, ConfigError +from synapse.types import RoomAlias from synapse.util.stringutils import random_string_with_symbols -from ._base import Config - class RegistrationConfig(Config): @@ -52,6 +52,10 @@ class RegistrationConfig(Config): ) self.auto_join_rooms = config.get("auto_join_rooms", []) + for room_alias in self.auto_join_rooms: + if not RoomAlias.is_valid(room_alias): + raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,)) + self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) self.disable_set_displayname = config.get("disable_set_displayname", False) self.disable_set_avatar_url = config.get("disable_set_avatar_url", False) @@ -150,6 +154,12 @@ class RegistrationConfig(Config): #auto_join_rooms: # - "#example:example.com" + # Where auto_join_rooms are specified, setting this flag ensures that the + # the rooms exist by creating them when the first user on the + # homeserver registers. + # Setting to false means that if the rooms are not manually created, + # users cannot be auto-joined since they do not exist. + autocreate_auto_join_rooms: true """ % locals() def add_arguments(self, parser): diff --git a/synapse/config/repository.py b/synapse/config/repository.py index fc909c1fac..06c62ab62c 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -178,7 +178,7 @@ class ContentRepositoryConfig(Config): def default_config(self, **kwargs): media_store = self.default_path("media_store") uploads_path = self.default_path("uploads") - return """ + return r""" # Directory where uploaded images and attachments are stored. media_store_path: "%(media_store)s" diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py new file mode 100644 index 0000000000..9da13ab11b --- /dev/null +++ b/synapse/config/room_directory.py @@ -0,0 +1,102 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.util import glob_to_regex + +from ._base import Config, ConfigError + + +class RoomDirectoryConfig(Config): + def read_config(self, config): + alias_creation_rules = config["alias_creation_rules"] + + self._alias_creation_rules = [ + _AliasRule(rule) + for rule in alias_creation_rules + ] + + def default_config(self, config_dir_path, server_name, **kwargs): + return """ + # The `alias_creation` option controls who's allowed to create aliases + # on this server. + # + # The format of this option is a list of rules that contain globs that + # match against user_id and the new alias (fully qualified with server + # name). The action in the first rule that matches is taken, which can + # currently either be "allow" or "deny". + # + # If no rules match the request is denied. + alias_creation_rules: + - user_id: "*" + alias: "*" + action: allow + """ + + def is_alias_creation_allowed(self, user_id, alias): + """Checks if the given user is allowed to create the given alias + + Args: + user_id (str) + alias (str) + + Returns: + boolean: True if user is allowed to crate the alias + """ + for rule in self._alias_creation_rules: + if rule.matches(user_id, alias): + return rule.action == "allow" + + return False + + +class _AliasRule(object): + def __init__(self, rule): + action = rule["action"] + user_id = rule["user_id"] + alias = rule["alias"] + + if action in ("allow", "deny"): + self.action = action + else: + raise ConfigError( + "alias_creation_rules rules can only have action of 'allow'" + " or 'deny'" + ) + + try: + self._user_id_regex = glob_to_regex(user_id) + self._alias_regex = glob_to_regex(alias) + except Exception as e: + raise ConfigError("Failed to parse glob into regex: %s", e) + + def matches(self, user_id, alias): + """Tests if this rule matches the given user_id and alias. + + Args: + user_id (str) + alias (str) + + Returns: + boolean + """ + + # Note: The regexes are anchored at both ends + if not self._user_id_regex.match(user_id): + return False + + if not self._alias_regex.match(alias): + return False + + return True diff --git a/synapse/config/server.py b/synapse/config/server.py index 68a612e594..c1c7c0105e 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -77,10 +77,15 @@ class ServerConfig(Config): self.max_mau_value = config.get( "max_mau_value", 0, ) + self.mau_limits_reserved_threepids = config.get( "mau_limit_reserved_threepids", [] ) + self.mau_trial_days = config.get( + "mau_trial_days", 0, + ) + # Options to disable HS self.hs_disabled = config.get("hs_disabled", False) self.hs_disabled_message = config.get("hs_disabled_message", "") @@ -88,7 +93,7 @@ class ServerConfig(Config): # Admin uri to direct users at should their instance become blocked # due to resource constraints - self.admin_uri = config.get("admin_uri", None) + self.admin_contact = config.get("admin_contact", None) # FIXME: federation_domain_whitelist needs sytests self.federation_domain_whitelist = None @@ -352,7 +357,7 @@ class ServerConfig(Config): # Homeserver blocking # # How to reach the server admin, used in ResourceLimitError - # admin_uri: 'mailto:admin@server.com' + # admin_contact: 'mailto:admin@server.com' # # Global block config # @@ -365,6 +370,7 @@ class ServerConfig(Config): # Enables monthly active user checking # limit_usage_by_mau: False # max_mau_value: 50 + # mau_trial_days: 2 # # Sometimes the server admin will want to ensure certain accounts are # never blocked by mau checking. These accounts are specified here. @@ -398,6 +404,23 @@ class ServerConfig(Config): " service on the given port.") +def is_threepid_reserved(config, threepid): + """Check the threepid against the reserved threepid config + Args: + config(ServerConfig) - to access server config attributes + threepid(dict) - The threepid to test for + + Returns: + boolean Is the threepid undertest reserved_user + """ + + for tp in config.mau_limits_reserved_threepids: + if (threepid['medium'] == tp['medium'] + and threepid['address'] == tp['address']): + return True + return False + + def read_gc_thresholds(thresholds): """Reads the three integer thresholds for garbage collection. Ensures that the thresholds are integers if thresholds are supplied. diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 1a391adec1..02b76dfcfb 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -123,6 +123,6 @@ class ClientTLSOptionsFactory(object): def get_options(self, host): return ClientTLSOptions( - host.decode('utf-8'), + host, CertificateOptions(verify=False).getContext() ) diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py index c20a32096a..080c81f14b 100644 --- a/synapse/crypto/keyclient.py +++ b/synapse/crypto/keyclient.py @@ -18,7 +18,9 @@ import logging from canonicaljson import json from twisted.internet import defer, reactor +from twisted.internet.error import ConnectError from twisted.internet.protocol import Factory +from twisted.names.error import DomainError from twisted.web.http import HTTPClient from synapse.http.endpoint import matrix_federation_endpoint @@ -47,12 +49,14 @@ def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1): server_response, server_certificate = yield protocol.remote_key defer.returnValue((server_response, server_certificate)) except SynapseKeyClientError as e: - logger.exception("Error getting key for %r" % (server_name,)) - if e.status.startswith("4"): + logger.warn("Error getting key for %r: %s", server_name, e) + if e.status.startswith(b"4"): # Don't retry for 4xx responses. raise IOError("Cannot get key for %r" % server_name) - except Exception as e: - logger.exception(e) + except (ConnectError, DomainError) as e: + logger.warn("Error getting key for %r: %s", server_name, e) + except Exception: + logger.exception("Error getting key for %r", server_name) raise IOError("Cannot get key for %r" % server_name) @@ -78,6 +82,12 @@ class SynapseKeyClientProtocol(HTTPClient): self._peer = self.transport.getPeer() logger.debug("Connected to %s", self._peer) + if not isinstance(self.path, bytes): + self.path = self.path.encode('ascii') + + if not isinstance(self.host, bytes): + self.host = self.host.encode('ascii') + self.sendCommand(b"GET", self.path) if self.host: self.sendHeader(b"Host", self.host) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 30e2742102..d89f94c219 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -16,9 +16,10 @@ import hashlib import logging -import urllib from collections import namedtuple +from six.moves import urllib + from signedjson.key import ( decode_verify_key_bytes, encode_verify_key_base64, @@ -40,6 +41,7 @@ from synapse.api.errors import Codes, SynapseError from synapse.crypto.keyclient import fetch_server_key from synapse.util import logcontext, unwrapFirstError from synapse.util.logcontext import ( + LoggingContext, PreserveLoggingContext, preserve_fn, run_in_background, @@ -216,23 +218,34 @@ class Keyring(object): servers have completed. Follows the synapse rules of logcontext preservation. """ + loop_count = 1 while True: wait_on = [ - self.key_downloads[server_name] + (server_name, self.key_downloads[server_name]) for server_name in server_names if server_name in self.key_downloads ] - if wait_on: - with PreserveLoggingContext(): - yield defer.DeferredList(wait_on) - else: + if not wait_on: break + logger.info( + "Waiting for existing lookups for %s to complete [loop %i]", + [w[0] for w in wait_on], loop_count, + ) + with PreserveLoggingContext(): + yield defer.DeferredList((w[1] for w in wait_on)) + + loop_count += 1 + + ctx = LoggingContext.current_context() def rm(r, server_name_): - self.key_downloads.pop(server_name_, None) + with PreserveLoggingContext(ctx): + logger.debug("Releasing key lookup lock on %s", server_name_) + self.key_downloads.pop(server_name_, None) return r for server_name, deferred in server_to_deferred.items(): + logger.debug("Got key lookup lock on %s", server_name) self.key_downloads[server_name] = deferred deferred.addBoth(rm, server_name) @@ -432,7 +445,7 @@ class Keyring(object): # an incoming request. query_response = yield self.client.post_json( destination=perspective_name, - path=b"/_matrix/key/v2/query", + path="/_matrix/key/v2/query", data={ u"server_keys": { server_name: { @@ -513,8 +526,8 @@ class Keyring(object): (response, tls_certificate) = yield fetch_server_key( server_name, self.hs.tls_client_options_factory, - path=(b"/_matrix/key/v2/server/%s" % ( - urllib.quote(requested_key_id), + path=("/_matrix/key/v2/server/%s" % ( + urllib.parse.quote(requested_key_id), )).encode("ascii"), ) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 6baeccca38..d4d4474847 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -98,9 +98,9 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True): creation_event = auth_events.get((EventTypes.Create, ""), None) if not creation_event: - raise SynapseError( + raise AuthError( 403, - "Room %r does not exist" % (event.room_id,) + "No create event in auth events", ) creating_domain = get_domain_from_id(event.room_id) @@ -155,10 +155,7 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True): if user_level < invite_level: raise AuthError( - 403, ( - "You cannot issue a third party invite for %s." % - (event.content.display_name,) - ) + 403, "You don't have permission to invite users", ) else: logger.debug("Allowing! %s", event) @@ -305,7 +302,7 @@ def _is_membership_change_allowed(event, auth_events): if user_level < invite_level: raise AuthError( - 403, "You cannot invite user %s." % target_user_id + 403, "You don't have permission to invite users", ) elif Membership.JOIN == membership: # Joins are valid iff caller == target and they were: @@ -693,7 +690,7 @@ def auth_types_for_event(event): auth_types = [] auth_types.append((EventTypes.PowerLevels, "", )) - auth_types.append((EventTypes.Member, event.user_id, )) + auth_types.append((EventTypes.Member, event.sender, )) auth_types.append((EventTypes.Create, "", )) if event.type == EventTypes.Member: diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 51f9084b90..12f1eb0a3e 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -13,13 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +from distutils.util import strtobool + +import six + from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze # Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents -# bugs where we accidentally share e.g. signature dicts. However, converting -# a dict to frozen_dicts is expensive. -USE_FROZEN_DICTS = True +# bugs where we accidentally share e.g. signature dicts. However, converting a +# dict to frozen_dicts is expensive. +# +# NOTE: This is overridden by the configuration by the Synapse worker apps, but +# for the sake of tests, it is set here while it cannot be configured on the +# homeserver object itself. +USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0")) class _EventInternalMetadata(object): @@ -147,6 +156,9 @@ class EventBase(object): def items(self): return list(self._event_dict.items()) + def keys(self): + return six.iterkeys(self._event_dict) + class FrozenEvent(EventBase): def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None): diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index c11798093d..b7ad729c63 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -13,17 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from collections import namedtuple import six from twisted.internet import defer +from twisted.internet.defer import DeferredList -from synapse.api.constants import MAX_DEPTH +from synapse.api.constants import MAX_DEPTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError from synapse.crypto.event_signing import check_event_content_hash from synapse.events import FrozenEvent from synapse.events.utils import prune_event from synapse.http.servlet import assert_params_in_dict +from synapse.types import get_domain_from_id from synapse.util import logcontext, unwrapFirstError logger = logging.getLogger(__name__) @@ -133,34 +136,45 @@ class FederationBase(object): * throws a SynapseError if the signature check failed. The deferreds run their callbacks in the sentinel logcontext. """ - - redacted_pdus = [ - prune_event(pdu) - for pdu in pdus - ] - - deferreds = self.keyring.verify_json_objects_for_server([ - (p.origin, p.get_pdu_json()) - for p in redacted_pdus - ]) + deferreds = _check_sigs_on_pdus(self.keyring, pdus) ctx = logcontext.LoggingContext.current_context() - def callback(_, pdu, redacted): + def callback(_, pdu): with logcontext.PreserveLoggingContext(ctx): if not check_event_content_hash(pdu): - logger.warn( - "Event content has been tampered, redacting %s: %s", - pdu.event_id, pdu.get_pdu_json() - ) - return redacted + # let's try to distinguish between failures because the event was + # redacted (which are somewhat expected) vs actual ball-tampering + # incidents. + # + # This is just a heuristic, so we just assume that if the keys are + # about the same between the redacted and received events, then the + # received event was probably a redacted copy (but we then use our + # *actual* redacted copy to be on the safe side.) + redacted_event = prune_event(pdu) + if ( + set(redacted_event.keys()) == set(pdu.keys()) and + set(six.iterkeys(redacted_event.content)) + == set(six.iterkeys(pdu.content)) + ): + logger.info( + "Event %s seems to have been redacted; using our redacted " + "copy", + pdu.event_id, + ) + else: + logger.warning( + "Event %s content has been tampered, redacting", + pdu.event_id, pdu.get_pdu_json(), + ) + return redacted_event if self.spam_checker.check_event_for_spam(pdu): logger.warn( "Event contains spam, redacting %s: %s", pdu.event_id, pdu.get_pdu_json() ) - return redacted + return prune_event(pdu) return pdu @@ -168,21 +182,121 @@ class FederationBase(object): failure.trap(SynapseError) with logcontext.PreserveLoggingContext(ctx): logger.warn( - "Signature check failed for %s", - pdu.event_id, + "Signature check failed for %s: %s", + pdu.event_id, failure.getErrorMessage(), ) return failure - for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus): + for deferred, pdu in zip(deferreds, pdus): deferred.addCallbacks( callback, errback, - callbackArgs=[pdu, redacted], + callbackArgs=[pdu], errbackArgs=[pdu], ) return deferreds +class PduToCheckSig(namedtuple("PduToCheckSig", [ + "pdu", "redacted_pdu_json", "event_id_domain", "sender_domain", "deferreds", +])): + pass + + +def _check_sigs_on_pdus(keyring, pdus): + """Check that the given events are correctly signed + + Args: + keyring (synapse.crypto.Keyring): keyring object to do the checks + pdus (Collection[EventBase]): the events to be checked + + Returns: + List[Deferred]: a Deferred for each event in pdus, which will either succeed if + the signatures are valid, or fail (with a SynapseError) if not. + """ + + # (currently this is written assuming the v1 room structure; we'll probably want a + # separate function for checking v2 rooms) + + # we want to check that the event is signed by: + # + # (a) the server which created the event_id + # + # (b) the sender's server. + # + # - except in the case of invites created from a 3pid invite, which are exempt + # from this check, because the sender has to match that of the original 3pid + # invite, but the event may come from a different HS, for reasons that I don't + # entirely grok (why do the senders have to match? and if they do, why doesn't the + # joining server ask the inviting server to do the switcheroo with + # exchange_third_party_invite?). + # + # That's pretty awful, since redacting such an invite will render it invalid + # (because it will then look like a regular invite without a valid signature), + # and signatures are *supposed* to be valid whether or not an event has been + # redacted. But this isn't the worst of the ways that 3pid invites are broken. + # + # let's start by getting the domain for each pdu, and flattening the event back + # to JSON. + pdus_to_check = [ + PduToCheckSig( + pdu=p, + redacted_pdu_json=prune_event(p).get_pdu_json(), + event_id_domain=get_domain_from_id(p.event_id), + sender_domain=get_domain_from_id(p.sender), + deferreds=[], + ) + for p in pdus + ] + + # first make sure that the event is signed by the event_id's domain + deferreds = keyring.verify_json_objects_for_server([ + (p.event_id_domain, p.redacted_pdu_json) + for p in pdus_to_check + ]) + + for p, d in zip(pdus_to_check, deferreds): + p.deferreds.append(d) + + # now let's look for events where the sender's domain is different to the + # event id's domain (normally only the case for joins/leaves), and add additional + # checks. + pdus_to_check_sender = [ + p for p in pdus_to_check + if p.sender_domain != p.event_id_domain and not _is_invite_via_3pid(p.pdu) + ] + + more_deferreds = keyring.verify_json_objects_for_server([ + (p.sender_domain, p.redacted_pdu_json) + for p in pdus_to_check_sender + ]) + + for p, d in zip(pdus_to_check_sender, more_deferreds): + p.deferreds.append(d) + + # replace lists of deferreds with single Deferreds + return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check] + + +def _flatten_deferred_list(deferreds): + """Given a list of one or more deferreds, either return the single deferred, or + combine into a DeferredList. + """ + if len(deferreds) > 1: + return DeferredList(deferreds, fireOnOneErrback=True, consumeErrors=True) + else: + assert len(deferreds) == 1 + return deferreds[0] + + +def _is_invite_via_3pid(event): + return ( + event.type == EventTypes.Member + and event.membership == Membership.INVITE + and "third_party_invite" in event.content + ) + + def event_from_pdu_json(pdu_json, outlier=False): """Construct a FrozenEvent from an event json received over federation diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index c9f3c2d352..d05ed91d64 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -66,6 +66,14 @@ class FederationClient(FederationBase): self.state = hs.get_state_handler() self.transport_layer = hs.get_federation_transport_client() + self._get_pdu_cache = ExpiringCache( + cache_name="get_pdu_cache", + clock=self._clock, + max_len=1000, + expiry_ms=120 * 1000, + reset_expiry_on_get=False, + ) + def _clear_tried_cache(self): """Clear pdu_destination_tried cache""" now = self._clock.time_msec() @@ -82,17 +90,6 @@ class FederationClient(FederationBase): if destination_dict: self.pdu_destination_tried[event_id] = destination_dict - def start_get_pdu_cache(self): - self._get_pdu_cache = ExpiringCache( - cache_name="get_pdu_cache", - clock=self._clock, - max_len=1000, - expiry_ms=120 * 1000, - reset_expiry_on_get=False, - ) - - self._get_pdu_cache.start() - @log_function def make_query(self, destination, query_type, args, retry_on_dns_fail=False, ignore_backoff=False): @@ -212,8 +209,6 @@ class FederationClient(FederationBase): Will attempt to get the PDU from each destination in the list until one succeeds. - This will persist the PDU locally upon receipt. - Args: destinations (list): Which home servers to query event_id (str): event to fetch @@ -229,10 +224,9 @@ class FederationClient(FederationBase): # TODO: Rate limit the number of times we try and get the same event. - if self._get_pdu_cache: - ev = self._get_pdu_cache.get(event_id) - if ev: - defer.returnValue(ev) + ev = self._get_pdu_cache.get(event_id) + if ev: + defer.returnValue(ev) pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) @@ -271,10 +265,10 @@ class FederationClient(FederationBase): event_id, destination, e, ) except NotRetryingDestination as e: - logger.info(e.message) + logger.info(str(e)) continue except FederationDeniedError as e: - logger.info(e.message) + logger.info(str(e)) continue except Exception as e: pdu_attempts[destination] = now @@ -285,7 +279,7 @@ class FederationClient(FederationBase): ) continue - if self._get_pdu_cache is not None and signed_pdu: + if signed_pdu: self._get_pdu_cache[event_id] = signed_pdu defer.returnValue(signed_pdu) @@ -293,8 +287,7 @@ class FederationClient(FederationBase): @defer.inlineCallbacks @log_function def get_state_for_room(self, destination, room_id, event_id): - """Requests all of the `current` state PDUs for a given room from - a remote home server. + """Requests all of the room state at a given event from a remote home server. Args: destination (str): The remote homeserver to query for the state. @@ -302,9 +295,10 @@ class FederationClient(FederationBase): event_id (str): The id of the event we want the state at. Returns: - Deferred: Results in a list of PDUs. + Deferred[Tuple[List[EventBase], List[EventBase]]]: + A list of events in the state, and a list of events in the auth chain + for the given event. """ - try: # First we try and ask for just the IDs, as thats far quicker if # we have most of the state and auth_chain already. @@ -510,7 +504,7 @@ class FederationClient(FederationBase): else: logger.warn( "Failed to %s via %s: %i %s", - description, destination, e.code, e.message, + description, destination, e.code, e.args[0], ) except Exception: logger.warn( @@ -875,7 +869,7 @@ class FederationClient(FederationBase): except Exception as e: logger.exception( "Failed to send_third_party_invite via %s: %s", - destination, e.message + destination, str(e) ) raise RuntimeError("Failed to send to any server.") diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 3e0cd294a1..0f9302a6a8 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import re import six from six import iteritems @@ -44,8 +43,10 @@ from synapse.replication.http.federation import ( ReplicationGetQueryRestServlet, ) from synapse.types import get_domain_from_id +from synapse.util import glob_to_regex from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache +from synapse.util.logcontext import nested_logging_context from synapse.util.logutils import log_function # when processing incoming transactions, we try to handle multiple rooms in @@ -99,7 +100,7 @@ class FederationServer(FederationBase): @defer.inlineCallbacks @log_function - def on_incoming_transaction(self, transaction_data): + def on_incoming_transaction(self, origin, transaction_data): # keep this as early as possible to make the calculated origin ts as # accurate as possible. request_time = self._clock.time_msec() @@ -108,34 +109,33 @@ class FederationServer(FederationBase): if not transaction.transaction_id: raise Exception("Transaction missing transaction_id") - if not transaction.origin: - raise Exception("Transaction missing origin") logger.debug("[%s] Got transaction", transaction.transaction_id) # use a linearizer to ensure that we don't process the same transaction # multiple times in parallel. with (yield self._transaction_linearizer.queue( - (transaction.origin, transaction.transaction_id), + (origin, transaction.transaction_id), )): result = yield self._handle_incoming_transaction( - transaction, request_time, + origin, transaction, request_time, ) defer.returnValue(result) @defer.inlineCallbacks - def _handle_incoming_transaction(self, transaction, request_time): + def _handle_incoming_transaction(self, origin, transaction, request_time): """ Process an incoming transaction and return the HTTP response Args: + origin (unicode): the server making the request transaction (Transaction): incoming transaction request_time (int): timestamp that the HTTP request arrived at Returns: Deferred[(int, object)]: http response code and body """ - response = yield self.transaction_actions.have_responded(transaction) + response = yield self.transaction_actions.have_responded(origin, transaction) if response: logger.debug( @@ -149,7 +149,7 @@ class FederationServer(FederationBase): received_pdus_counter.inc(len(transaction.pdus)) - origin_host, _ = parse_server_name(transaction.origin) + origin_host, _ = parse_server_name(origin) pdus_by_room = {} @@ -188,21 +188,22 @@ class FederationServer(FederationBase): for pdu in pdus_by_room[room_id]: event_id = pdu.event_id - try: - yield self._handle_received_pdu( - transaction.origin, pdu - ) - pdu_results[event_id] = {} - except FederationError as e: - logger.warn("Error handling PDU %s: %s", event_id, e) - pdu_results[event_id] = {"error": str(e)} - except Exception as e: - f = failure.Failure() - pdu_results[event_id] = {"error": str(e)} - logger.error( - "Failed to handle PDU %s: %s", - event_id, f.getTraceback().rstrip(), - ) + with nested_logging_context(event_id): + try: + yield self._handle_received_pdu( + origin, pdu + ) + pdu_results[event_id] = {} + except FederationError as e: + logger.warn("Error handling PDU %s: %s", event_id, e) + pdu_results[event_id] = {"error": str(e)} + except Exception as e: + f = failure.Failure() + pdu_results[event_id] = {"error": str(e)} + logger.error( + "Failed to handle PDU %s: %s", + event_id, f.getTraceback().rstrip(), + ) yield concurrently_execute( process_pdus_for_room, pdus_by_room.keys(), @@ -212,7 +213,7 @@ class FederationServer(FederationBase): if hasattr(transaction, "edus"): for edu in (Edu(**x) for x in transaction.edus): yield self.received_edu( - transaction.origin, + origin, edu.edu_type, edu.content ) @@ -224,6 +225,7 @@ class FederationServer(FederationBase): logger.debug("Returning: %s", str(response)) yield self.transaction_actions.set_response( + origin, transaction, 200, response ) @@ -505,19 +507,19 @@ class FederationServer(FederationBase): @defer.inlineCallbacks @log_function def on_get_missing_events(self, origin, room_id, earliest_events, - latest_events, limit, min_depth): + latest_events, limit): with (yield self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) yield self.check_server_matches_acl(origin_host, room_id) logger.info( "on_get_missing_events: earliest_events: %r, latest_events: %r," - " limit: %d, min_depth: %d", - earliest_events, latest_events, limit, min_depth + " limit: %d", + earliest_events, latest_events, limit, ) missing_events = yield self.handler.on_get_missing_events( - origin, room_id, earliest_events, latest_events, limit, min_depth + origin, room_id, earliest_events, latest_events, limit, ) if len(missing_events) < 5: @@ -618,7 +620,7 @@ class FederationServer(FederationBase): ) yield self.handler.on_receive_pdu( - origin, pdu, get_missing=True, sent_to_us_directly=True, + origin, pdu, sent_to_us_directly=True, ) def __str__(self): @@ -727,22 +729,10 @@ def _acl_entry_matches(server_name, acl_entry): if not isinstance(acl_entry, six.string_types): logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)) return False - regex = _glob_to_regex(acl_entry) + regex = glob_to_regex(acl_entry) return regex.match(server_name) -def _glob_to_regex(glob): - res = '' - for c in glob: - if c == '*': - res = res + '.*' - elif c == '?': - res = res + '.' - else: - res = res + re.escape(c) - return re.compile(res + "\\Z", re.IGNORECASE) - - class FederationHandlerRegistry(object): """Allows classes to register themselves as handlers for a given EDU or query type for incoming federation traffic. @@ -798,7 +788,7 @@ class FederationHandlerRegistry(object): yield handler(origin, content) except SynapseError as e: logger.info("Failed to handle edu %r: %r", edu_type, e) - except Exception as e: + except Exception: logger.exception("Failed to handle edu %r", edu_type) def on_query(self, query_type, args): @@ -838,9 +828,9 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): ) return self._send_edu( - edu_type=edu_type, - origin=origin, - content=content, + edu_type=edu_type, + origin=origin, + content=content, ) def on_query(self, query_type, args): @@ -851,6 +841,6 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): return handler(args) return self._get_query_client( - query_type=query_type, - args=args, + query_type=query_type, + args=args, ) diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 9146215c21..74ffd13b4f 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -36,7 +36,7 @@ class TransactionActions(object): self.store = datastore @log_function - def have_responded(self, transaction): + def have_responded(self, origin, transaction): """ Have we already responded to a transaction with the same id and origin? @@ -50,11 +50,11 @@ class TransactionActions(object): "transaction_id") return self.store.get_received_txn_response( - transaction.transaction_id, transaction.origin + transaction.transaction_id, origin ) @log_function - def set_response(self, transaction, code, response): + def set_response(self, origin, transaction, code, response): """ Persist how we responded to a transaction. Returns: @@ -66,7 +66,7 @@ class TransactionActions(object): return self.store.set_received_txn_response( transaction.transaction_id, - transaction.origin, + origin, code, response, ) diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 0bb468385d..6f5995735a 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -32,7 +32,7 @@ Events are replicated via a separate events stream. import logging from collections import namedtuple -from six import iteritems, itervalues +from six import iteritems from sortedcontainers import SortedDict @@ -117,7 +117,7 @@ class FederationRemoteSendQueue(object): user_ids = set( user_id - for uids in itervalues(self.presence_changed) + for uids in self.presence_changed.values() for user_id in uids ) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 94d7423d01..3fdd63be95 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -137,26 +137,6 @@ class TransactionQueue(object): self._processing_pending_presence = False - def can_send_to(self, destination): - """Can we send messages to the given server? - - We can't send messages to ourselves. If we are running on localhost - then we can only federation with other servers running on localhost. - Otherwise we only federate with servers on a public domain. - - Args: - destination(str): The server we are possibly trying to send to. - Returns: - bool: True if we can send to the server. - """ - - if destination == self.server_name: - return False - if self.server_name.startswith("localhost"): - return destination.startswith("localhost") - else: - return not destination.startswith("localhost") - def notify_new_events(self, current_id): """This gets called when we have some new events we might want to send out to other servers. @@ -279,10 +259,7 @@ class TransactionQueue(object): self._order += 1 destinations = set(destinations) - destinations = set( - dest for dest in destinations if self.can_send_to(dest) - ) - + destinations.discard(self.server_name) logger.debug("Sending to: %s", str(destinations)) if not destinations: @@ -358,7 +335,7 @@ class TransactionQueue(object): for destinations, states in hosts_and_states: for destination in destinations: - if not self.can_send_to(destination): + if destination == self.server_name: continue self.pending_presence_by_dest.setdefault( @@ -377,7 +354,8 @@ class TransactionQueue(object): content=content, ) - if not self.can_send_to(destination): + if destination == self.server_name: + logger.info("Not sending EDU to ourselves") return sent_edus_counter.inc() @@ -392,10 +370,8 @@ class TransactionQueue(object): self._attempt_new_transaction(destination) def send_device_messages(self, destination): - if destination == self.server_name or destination == "localhost": - return - - if not self.can_send_to(destination): + if destination == self.server_name: + logger.info("Not sending device update to ourselves") return self._attempt_new_transaction(destination) @@ -463,7 +439,19 @@ class TransactionQueue(object): # pending_transactions flag. pending_pdus = self.pending_pdus_by_dest.pop(destination, []) + + # We can only include at most 50 PDUs per transactions + pending_pdus, leftover_pdus = pending_pdus[:50], pending_pdus[50:] + if leftover_pdus: + self.pending_pdus_by_dest[destination] = leftover_pdus + pending_edus = self.pending_edus_by_dest.pop(destination, []) + + # We can only include at most 100 EDUs per transactions + pending_edus, leftover_edus = pending_edus[:100], pending_edus[100:] + if leftover_edus: + self.pending_edus_by_dest[destination] = leftover_edus + pending_presence = self.pending_presence_by_dest.pop(destination, {}) pending_edus.extend( @@ -645,14 +633,6 @@ class TransactionQueue(object): transaction, json_data_cb ) code = 200 - - if response: - for e_id, r in response.get("pdus", {}).items(): - if "error" in r: - logger.warn( - "Transaction returned error for %s: %s", - e_id, r, - ) except HttpResponseException as e: code = e.code response = e.response @@ -669,19 +649,24 @@ class TransactionQueue(object): destination, txn_id, code ) - logger.debug("TX [%s] Sent transaction", destination) - logger.debug("TX [%s] Marking as delivered...", destination) - yield self.transaction_actions.delivered( transaction, code, response ) - logger.debug("TX [%s] Marked as delivered", destination) + logger.debug("TX [%s] {%s} Marked as delivered", destination, txn_id) - if code != 200: + if code == 200: + for e_id, r in response.get("pdus", {}).items(): + if "error" in r: + logger.warn( + "TX [%s] {%s} Remote returned error for %s: %s", + destination, txn_id, e_id, r, + ) + else: for p in pdus: - logger.info( - "Failed to send event %s to %s", p.event_id, destination + logger.warn( + "TX [%s] {%s} Failed to send event %s", + destination, txn_id, p.event_id, ) success = False diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index b4fbe2c9d5..edba5a9808 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -15,7 +15,8 @@ # limitations under the License. import logging -import urllib + +from six.moves import urllib from twisted.internet import defer @@ -106,7 +107,7 @@ class TransportLayerClient(object): dest (str) room_id (str) event_tuples (list) - limt (int) + limit (int) Returns: Deferred: Results in a dict received from the remote homeserver. @@ -142,9 +143,17 @@ class TransportLayerClient(object): transaction (Transaction) Returns: - Deferred: Results of the deferred is a tuple in the form of - (response_code, response_body) where the response_body is a - python dict decoded from json + Deferred: Succeeds when we get a 2xx HTTP response. The result + will be the decoded JSON body. + + Fails with ``HTTPRequestException`` if we get an HTTP response + code >= 300. + + Fails with ``NotRetryingDestination`` if we are not yet ready + to retry this server. + + Fails with ``FederationDeniedError`` if this destination + is not on our federation whitelist """ logger.debug( "send_data dest=%s, txid=%s", @@ -169,11 +178,6 @@ class TransportLayerClient(object): backoff_on_404=True, # If we get a 404 the other side has gone ) - logger.debug( - "send_data dest=%s, txid=%s, got response: 200", - transaction.destination, transaction.transaction_id, - ) - defer.returnValue(response) @defer.inlineCallbacks @@ -951,4 +955,4 @@ def _create_path(prefix, path, *args): Returns: str """ - return prefix + path % tuple(urllib.quote(arg, "") for arg in args) + return prefix + path % tuple(urllib.parse.quote(arg, "") for arg in args) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 77969a4f38..7288d49074 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -90,8 +90,8 @@ class Authenticator(object): @defer.inlineCallbacks def authenticate_request(self, request, content): json_request = { - "method": request.method, - "uri": request.uri, + "method": request.method.decode('ascii'), + "uri": request.uri.decode('ascii'), "destination": self.server_name, "signatures": {}, } @@ -252,7 +252,7 @@ class BaseFederationServlet(object): by the callback method. None if the request has already been handled. """ content = None - if request.method in ["PUT", "POST"]: + if request.method in [b"PUT", b"POST"]: # TODO: Handle other method types? other content types? content = parse_json_object_from_request(request) @@ -261,10 +261,10 @@ class BaseFederationServlet(object): except NoAuthenticationError: origin = None if self.REQUIRE_AUTH: - logger.exception("authenticate_request failed") + logger.warn("authenticate_request failed: missing authentication") raise - except Exception: - logger.exception("authenticate_request failed") + except Exception as e: + logger.warn("authenticate_request failed: %s", e) raise if origin: @@ -353,7 +353,7 @@ class FederationSendServlet(BaseFederationServlet): try: code, response = yield self.handler.on_incoming_transaction( - transaction_data + origin, transaction_data, ) except Exception: logger.exception("on_incoming_transaction failed") @@ -386,7 +386,7 @@ class FederationStateServlet(BaseFederationServlet): return self.handler.on_context_state_request( origin, context, - query.get("event_id", [None])[0], + parse_string_from_args(query, "event_id", None), ) @@ -397,7 +397,7 @@ class FederationStateIdsServlet(BaseFederationServlet): return self.handler.on_state_ids_request( origin, room_id, - query.get("event_id", [None])[0], + parse_string_from_args(query, "event_id", None), ) @@ -405,14 +405,12 @@ class FederationBackfillServlet(BaseFederationServlet): PATH = "/backfill/(?P<context>[^/]*)/" def on_GET(self, origin, content, query, context): - versions = query["v"] - limits = query["limit"] + versions = [x.decode('ascii') for x in query[b"v"]] + limit = parse_integer_from_args(query, "limit", None) - if not limits: + if not limit: return defer.succeed((400, {"error": "Did not include limit param"})) - limit = int(limits[-1]) - return self.handler.on_backfill_request(origin, context, versions, limit) @@ -423,7 +421,7 @@ class FederationQueryServlet(BaseFederationServlet): def on_GET(self, origin, content, query, query_type): return self.handler.on_query_request( query_type, - {k: v[0].decode("utf-8") for k, v in query.items()} + {k.decode('utf8'): v[0].decode("utf-8") for k, v in query.items()} ) @@ -562,7 +560,6 @@ class FederationGetMissingEventsServlet(BaseFederationServlet): @defer.inlineCallbacks def on_POST(self, origin, content, query, room_id): limit = int(content.get("limit", 10)) - min_depth = int(content.get("min_depth", 0)) earliest_events = content.get("earliest_events", []) latest_events = content.get("latest_events", []) @@ -571,7 +568,6 @@ class FederationGetMissingEventsServlet(BaseFederationServlet): room_id=room_id, earliest_events=earliest_events, latest_events=latest_events, - min_depth=min_depth, limit=limit, ) @@ -630,14 +626,14 @@ class OpenIdUserInfo(BaseFederationServlet): @defer.inlineCallbacks def on_GET(self, origin, content, query): - token = query.get("access_token", [None])[0] + token = query.get(b"access_token", [None])[0] if token is None: defer.returnValue((401, { "errcode": "M_MISSING_TOKEN", "error": "Access Token required" })) return - user_id = yield self.handler.on_openid_userinfo(token) + user_id = yield self.handler.on_openid_userinfo(token.decode('ascii')) if user_id is None: defer.returnValue((401, { diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index f0f89af7dc..17eedf4dbf 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -28,6 +28,7 @@ from synapse.metrics import ( event_processing_loop_room_count, ) from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util import log_failure from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.metrics import Measure @@ -36,17 +37,6 @@ logger = logging.getLogger(__name__) events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "") -def log_failure(failure): - logger.error( - "Application Services Failure", - exc_info=( - failure.type, - failure.value, - failure.getTracebackObject() - ) - ) - - class ApplicationServicesHandler(object): def __init__(self, hs): @@ -112,7 +102,10 @@ class ApplicationServicesHandler(object): if not self.started_scheduler: def start_scheduler(): - return self.scheduler.start().addErrback(log_failure) + return self.scheduler.start().addErrback( + log_failure, "Application Services Failure", + ) + run_as_background_process("as_scheduler", start_scheduler) self.started_scheduler = True diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 4a81bd2ba9..329e3c7d71 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -22,7 +22,7 @@ import bcrypt import pymacaroons from canonicaljson import json -from twisted.internet import defer, threads +from twisted.internet import defer from twisted.web.client import PartialDownloadError import synapse.util.stringutils as stringutils @@ -37,8 +37,8 @@ from synapse.api.errors import ( ) from synapse.module_api import ModuleApi from synapse.types import UserID +from synapse.util import logcontext from synapse.util.caches.expiringcache import ExpiringCache -from synapse.util.logcontext import make_deferred_yieldable from ._base import BaseHandler @@ -884,40 +884,32 @@ class AuthHandler(BaseHandler): bcrypt.gensalt(self.bcrypt_rounds), ).decode('ascii') - return make_deferred_yieldable( - threads.deferToThreadPool( - self.hs.get_reactor(), self.hs.get_reactor().getThreadPool(), _do_hash - ), - ) + return logcontext.defer_to_thread(self.hs.get_reactor(), _do_hash) def validate_hash(self, password, stored_hash): """Validates that self.hash(password) == stored_hash. Args: password (unicode): Password to hash. - stored_hash (unicode): Expected hash value. + stored_hash (bytes): Expected hash value. Returns: Deferred(bool): Whether self.hash(password) == stored_hash. """ - def _do_validate_hash(): # Normalise the Unicode in the password pw = unicodedata.normalize("NFKC", password) return bcrypt.checkpw( pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"), - stored_hash.encode('utf8') + stored_hash ) if stored_hash: - return make_deferred_yieldable( - threads.deferToThreadPool( - self.hs.get_reactor(), - self.hs.get_reactor().getThreadPool(), - _do_validate_hash, - ), - ) + if not isinstance(stored_hash, bytes): + stored_hash = stored_hash.encode('ascii') + + return logcontext.defer_to_thread(self.hs.get_reactor(), _do_validate_hash) else: return defer.succeed(False) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 3bd59d2dd4..7f9435179a 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -17,8 +17,8 @@ import logging from twisted.internet import defer from synapse.api.errors import SynapseError +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import UserID, create_requester -from synapse.util.logcontext import run_in_background from ._base import BaseHandler @@ -125,7 +125,7 @@ class DeactivateAccountHandler(BaseHandler): None """ if not self._user_parter_running: - run_in_background(self._user_parter_loop) + run_as_background_process("user_parter_loop", self._user_parter_loop) @defer.inlineCallbacks def _user_parter_loop(self): diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index ef866da1b6..7d67bf803a 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -20,7 +20,14 @@ import string from twisted.internet import defer from synapse.api.constants import EventTypes -from synapse.api.errors import AuthError, CodeMessageException, Codes, SynapseError +from synapse.api.errors import ( + AuthError, + CodeMessageException, + Codes, + NotFoundError, + StoreError, + SynapseError, +) from synapse.types import RoomAlias, UserID, get_domain_from_id from ._base import BaseHandler @@ -36,6 +43,7 @@ class DirectoryHandler(BaseHandler): self.state = hs.get_state_handler() self.appservice_handler = hs.get_application_service_handler() self.event_creation_handler = hs.get_event_creation_handler() + self.config = hs.config self.federation = hs.get_federation_client() hs.get_federation_registry().register_query_handler( @@ -73,43 +81,75 @@ class DirectoryHandler(BaseHandler): ) @defer.inlineCallbacks - def create_association(self, user_id, room_alias, room_id, servers=None): - # association creation for human users - # TODO(erikj): Do user auth. + def create_association(self, requester, room_alias, room_id, servers=None, + send_event=True): + """Attempt to create a new alias - if not self.spam_checker.user_may_create_room_alias(user_id, room_alias): - raise SynapseError( - 403, "This user is not permitted to create this alias", - ) + Args: + requester (Requester) + room_alias (RoomAlias) + room_id (str) + servers (list[str]|None): List of servers that others servers + should try and join via + send_event (bool): Whether to send an updated m.room.aliases event - can_create = yield self.can_modify_alias( - room_alias, - user_id=user_id - ) - if not can_create: - raise SynapseError( - 400, "This alias is reserved by an application service.", - errcode=Codes.EXCLUSIVE - ) - yield self._create_association(room_alias, room_id, servers, creator=user_id) + Returns: + Deferred + """ - @defer.inlineCallbacks - def create_appservice_association(self, service, room_alias, room_id, - servers=None): - if not service.is_interested_in_alias(room_alias.to_string()): - raise SynapseError( - 400, "This application service has not reserved" - " this kind of alias.", errcode=Codes.EXCLUSIVE + user_id = requester.user.to_string() + + service = requester.app_service + if service: + if not service.is_interested_in_alias(room_alias.to_string()): + raise SynapseError( + 400, "This application service has not reserved" + " this kind of alias.", errcode=Codes.EXCLUSIVE + ) + else: + if not self.spam_checker.user_may_create_room_alias(user_id, room_alias): + raise AuthError( + 403, "This user is not permitted to create this alias", + ) + + if not self.config.is_alias_creation_allowed(user_id, room_alias.to_string()): + # Lets just return a generic message, as there may be all sorts of + # reasons why we said no. TODO: Allow configurable error messages + # per alias creation rule? + raise SynapseError( + 403, "Not allowed to create alias", + ) + + can_create = yield self.can_modify_alias( + room_alias, + user_id=user_id ) + if not can_create: + raise AuthError( + 400, "This alias is reserved by an application service.", + errcode=Codes.EXCLUSIVE + ) - # association creation for app services - yield self._create_association(room_alias, room_id, servers) + yield self._create_association(room_alias, room_id, servers, creator=user_id) + if send_event: + yield self.send_room_alias_update_event( + requester, + room_id + ) @defer.inlineCallbacks - def delete_association(self, requester, user_id, room_alias): + def delete_association(self, requester, room_alias): # association deletion for human users - can_delete = yield self._user_can_delete_alias(room_alias, user_id) + user_id = requester.user.to_string() + + try: + can_delete = yield self._user_can_delete_alias(room_alias, user_id) + except StoreError as e: + if e.code == 404: + raise NotFoundError("Unknown room alias") + raise + if not can_delete: raise AuthError( 403, "You don't have permission to delete the alias.", @@ -130,7 +170,6 @@ class DirectoryHandler(BaseHandler): try: yield self.send_room_alias_update_event( requester, - requester.user.to_string(), room_id ) @@ -248,7 +287,7 @@ class DirectoryHandler(BaseHandler): ) @defer.inlineCallbacks - def send_room_alias_update_event(self, requester, user_id, room_id): + def send_room_alias_update_event(self, requester, room_id): aliases = yield self.store.get_aliases_for_room(room_id) yield self.event_creation_handler.create_and_send_nonmember_event( @@ -257,7 +296,7 @@ class DirectoryHandler(BaseHandler): "type": EventTypes.Aliases, "state_key": self.hs.hostname, "room_id": room_id, - "sender": user_id, + "sender": requester.user.to_string(), "content": {"aliases": aliases}, }, ratelimit=False @@ -320,7 +359,7 @@ class DirectoryHandler(BaseHandler): def _user_can_delete_alias(self, alias, user_id): creator = yield self.store.get_room_alias_creator(alias.to_string()) - if creator and creator == user_id: + if creator is not None and creator == user_id: defer.returnValue(True) is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id)) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 5816bf8b4f..9dc46aa15f 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -330,7 +330,8 @@ class E2eKeysHandler(object): (algorithm, key_id, ex_json, key) ) else: - new_keys.append((algorithm, key_id, encode_canonical_json(key))) + new_keys.append(( + algorithm, key_id, encode_canonical_json(key).decode('ascii'))) yield self.store.add_e2e_one_time_keys( user_id, device_id, time_now, new_keys @@ -340,7 +341,7 @@ class E2eKeysHandler(object): def _exception_to_failure(e): if isinstance(e, CodeMessageException): return { - "status": e.code, "message": e.message, + "status": e.code, "message": str(e), } if isinstance(e, NotRetryingDestination): @@ -358,7 +359,7 @@ def _exception_to_failure(e): # Note that some Exceptions (notably twisted's ResponseFailed etc) don't # give a string for e.message, which json then fails to serialize. return { - "status": 503, "message": str(e.message), + "status": 503, "message": str(e), } diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py new file mode 100644 index 0000000000..5edb3cfe04 --- /dev/null +++ b/synapse/handlers/e2e_room_keys.py @@ -0,0 +1,289 @@ +# -*- coding: utf-8 -*- +# Copyright 2017, 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from six import iteritems + +from twisted.internet import defer + +from synapse.api.errors import RoomKeysVersionError, StoreError, SynapseError +from synapse.util.async_helpers import Linearizer + +logger = logging.getLogger(__name__) + + +class E2eRoomKeysHandler(object): + """ + Implements an optional realtime backup mechanism for encrypted E2E megolm room keys. + This gives a way for users to store and recover their megolm keys if they lose all + their clients. It should also extend easily to future room key mechanisms. + The actual payload of the encrypted keys is completely opaque to the handler. + """ + + def __init__(self, hs): + self.store = hs.get_datastore() + + # Used to lock whenever a client is uploading key data. This prevents collisions + # between clients trying to upload the details of a new session, given all + # clients belonging to a user will receive and try to upload a new session at + # roughly the same time. Also used to lock out uploads when the key is being + # changed. + self._upload_linearizer = Linearizer("upload_room_keys_lock") + + @defer.inlineCallbacks + def get_room_keys(self, user_id, version, room_id=None, session_id=None): + """Bulk get the E2E room keys for a given backup, optionally filtered to a given + room, or a given session. + See EndToEndRoomKeyStore.get_e2e_room_keys for full details. + + Args: + user_id(str): the user whose keys we're getting + version(str): the version ID of the backup we're getting keys from + room_id(string): room ID to get keys for, for None to get keys for all rooms + session_id(string): session ID to get keys for, for None to get keys for all + sessions + Returns: + A deferred list of dicts giving the session_data and message metadata for + these room keys. + """ + + # we deliberately take the lock to get keys so that changing the version + # works atomically + with (yield self._upload_linearizer.queue(user_id)): + results = yield self.store.get_e2e_room_keys( + user_id, version, room_id, session_id + ) + + if results['rooms'] == {}: + raise SynapseError(404, "No room_keys found") + + defer.returnValue(results) + + @defer.inlineCallbacks + def delete_room_keys(self, user_id, version, room_id=None, session_id=None): + """Bulk delete the E2E room keys for a given backup, optionally filtered to a given + room or a given session. + See EndToEndRoomKeyStore.delete_e2e_room_keys for full details. + + Args: + user_id(str): the user whose backup we're deleting + version(str): the version ID of the backup we're deleting + room_id(string): room ID to delete keys for, for None to delete keys for all + rooms + session_id(string): session ID to delete keys for, for None to delete keys + for all sessions + Returns: + A deferred of the deletion transaction + """ + + # lock for consistency with uploading + with (yield self._upload_linearizer.queue(user_id)): + yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id) + + @defer.inlineCallbacks + def upload_room_keys(self, user_id, version, room_keys): + """Bulk upload a list of room keys into a given backup version, asserting + that the given version is the current backup version. room_keys are merged + into the current backup as described in RoomKeysServlet.on_PUT(). + + Args: + user_id(str): the user whose backup we're setting + version(str): the version ID of the backup we're updating + room_keys(dict): a nested dict describing the room_keys we're setting: + + { + "rooms": { + "!abc:matrix.org": { + "sessions": { + "c0ff33": { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + } + } + } + } + + Raises: + SynapseError: with code 404 if there are no versions defined + RoomKeysVersionError: if the uploaded version is not the current version + """ + + # TODO: Validate the JSON to make sure it has the right keys. + + # XXX: perhaps we should use a finer grained lock here? + with (yield self._upload_linearizer.queue(user_id)): + + # Check that the version we're trying to upload is the current version + try: + version_info = yield self.store.get_e2e_room_keys_version_info(user_id) + except StoreError as e: + if e.code == 404: + raise SynapseError(404, "Version '%s' not found" % (version,)) + else: + raise + + if version_info['version'] != version: + # Check that the version we're trying to upload actually exists + try: + version_info = yield self.store.get_e2e_room_keys_version_info( + user_id, version, + ) + # if we get this far, the version must exist + raise RoomKeysVersionError(current_version=version_info['version']) + except StoreError as e: + if e.code == 404: + raise SynapseError(404, "Version '%s' not found" % (version,)) + else: + raise + + # go through the room_keys. + # XXX: this should/could be done concurrently, given we're in a lock. + for room_id, room in iteritems(room_keys['rooms']): + for session_id, session in iteritems(room['sessions']): + yield self._upload_room_key( + user_id, version, room_id, session_id, session + ) + + @defer.inlineCallbacks + def _upload_room_key(self, user_id, version, room_id, session_id, room_key): + """Upload a given room_key for a given room and session into a given + version of the backup. Merges the key with any which might already exist. + + Args: + user_id(str): the user whose backup we're setting + version(str): the version ID of the backup we're updating + room_id(str): the ID of the room whose keys we're setting + session_id(str): the session whose room_key we're setting + room_key(dict): the room_key being set + """ + + # get the room_key for this particular row + current_room_key = None + try: + current_room_key = yield self.store.get_e2e_room_key( + user_id, version, room_id, session_id + ) + except StoreError as e: + if e.code == 404: + pass + else: + raise + + if self._should_replace_room_key(current_room_key, room_key): + yield self.store.set_e2e_room_key( + user_id, version, room_id, session_id, room_key + ) + + @staticmethod + def _should_replace_room_key(current_room_key, room_key): + """ + Determine whether to replace a given current_room_key (if any) + with a newly uploaded room_key backup + + Args: + current_room_key (dict): Optional, the current room_key dict if any + room_key (dict): The new room_key dict which may or may not be fit to + replace the current_room_key + + Returns: + True if current_room_key should be replaced by room_key in the backup + """ + + if current_room_key: + # spelt out with if/elifs rather than nested boolean expressions + # purely for legibility. + + if room_key['is_verified'] and not current_room_key['is_verified']: + return True + elif ( + room_key['first_message_index'] < + current_room_key['first_message_index'] + ): + return True + elif room_key['forwarded_count'] < current_room_key['forwarded_count']: + return True + else: + return False + return True + + @defer.inlineCallbacks + def create_version(self, user_id, version_info): + """Create a new backup version. This automatically becomes the new + backup version for the user's keys; previous backups will no longer be + writeable to. + + Args: + user_id(str): the user whose backup version we're creating + version_info(dict): metadata about the new version being created + + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" + } + + Returns: + A deferred of a string that gives the new version number. + """ + + # TODO: Validate the JSON to make sure it has the right keys. + + # lock everyone out until we've switched version + with (yield self._upload_linearizer.queue(user_id)): + new_version = yield self.store.create_e2e_room_keys_version( + user_id, version_info + ) + defer.returnValue(new_version) + + @defer.inlineCallbacks + def get_version_info(self, user_id, version=None): + """Get the info about a given version of the user's backup + + Args: + user_id(str): the user whose current backup version we're querying + version(str): Optional; if None gives the most recent version + otherwise a historical one. + Raises: + StoreError: code 404 if the requested backup version doesn't exist + Returns: + A deferred of a info dict that gives the info about the new version. + + { + "version": "1234", + "algorithm": "m.megolm_backup.v1", + "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" + } + """ + + with (yield self._upload_linearizer.queue(user_id)): + res = yield self.store.get_e2e_room_keys_version_info(user_id, version) + defer.returnValue(res) + + @defer.inlineCallbacks + def delete_version(self, user_id, version=None): + """Deletes a given version of the user's e2e_room_keys backup + + Args: + user_id(str): the user whose current backup version we're deleting + version(str): the version id of the backup being deleted + Raises: + StoreError: code 404 if this backup version doesn't exist + """ + + with (yield self._upload_linearizer.queue(user_id)): + yield self.store.delete_e2e_room_keys_version(user_id, version) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3dd107a285..cd5b9bbb19 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -18,7 +18,6 @@ import itertools import logging -import sys import six from six import iteritems, itervalues @@ -54,7 +53,7 @@ from synapse.replication.http.federation import ( ReplicationFederationSendEventsRestServlet, ) from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet -from synapse.state import resolve_events_with_factory +from synapse.state import StateResolutionStore, resolve_events_with_store from synapse.types import UserID, get_domain_from_id from synapse.util import logcontext, unwrapFirstError from synapse.util.async_helpers import Linearizer @@ -69,6 +68,27 @@ from ._base import BaseHandler logger = logging.getLogger(__name__) +def shortstr(iterable, maxitems=5): + """If iterable has maxitems or fewer, return the stringification of a list + containing those items. + + Otherwise, return the stringification of a a list with the first maxitems items, + followed by "...". + + Args: + iterable (Iterable): iterable to truncate + maxitems (int): number of items to return before truncating + + Returns: + unicode + """ + + items = list(itertools.islice(iterable, maxitems + 1)) + if len(items) <= maxitems: + return str(items) + return u"[" + u", ".join(repr(r) for r in items[:maxitems]) + u", ...]" + + class FederationHandler(BaseHandler): """Handles events that originated from federation. Responsible for: @@ -85,7 +105,7 @@ class FederationHandler(BaseHandler): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastore() # type: synapse.storage.DataStore self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() self.server_name = hs.hostname @@ -114,9 +134,8 @@ class FederationHandler(BaseHandler): self._room_pdu_linearizer = Linearizer("fed_room_pdu") @defer.inlineCallbacks - @log_function def on_receive_pdu( - self, origin, pdu, get_missing=True, sent_to_us_directly=False, + self, origin, pdu, sent_to_us_directly=False, ): """ Process a PDU received via a federation /send/ transaction, or via backfill of missing prev_events @@ -125,14 +144,23 @@ class FederationHandler(BaseHandler): origin (str): server which initiated the /send/ transaction. Will be used to fetch missing events or state. pdu (FrozenEvent): received PDU - get_missing (bool): True if we should fetch missing prev_events + sent_to_us_directly (bool): True if this event was pushed to us; False if + we pulled it as the result of a missing prev_event. Returns (Deferred): completes with None """ + room_id = pdu.room_id + event_id = pdu.event_id + + logger.info( + "[%s %s] handling received PDU: %s", + room_id, event_id, pdu, + ) + # We reprocess pdus when we have seen them only as outliers existing = yield self.store.get_event( - pdu.event_id, + event_id, allow_none=True, allow_rejected=True, ) @@ -147,7 +175,7 @@ class FederationHandler(BaseHandler): ) ) if already_seen: - logger.debug("Already seen pdu %s", pdu.event_id) + logger.debug("[%s %s]: Already seen pdu", room_id, event_id) return # do some initial sanity-checking of the event. In particular, make @@ -156,6 +184,7 @@ class FederationHandler(BaseHandler): try: self._sanity_check_event(pdu) except SynapseError as err: + logger.warn("[%s %s] Received event failed sanity checks", room_id, event_id) raise FederationError( "ERROR", err.code, @@ -165,10 +194,12 @@ class FederationHandler(BaseHandler): # If we are currently in the process of joining this room, then we # queue up events for later processing. - if pdu.room_id in self.room_queues: - logger.info("Ignoring PDU %s for room %s from %s for now; join " - "in progress", pdu.event_id, pdu.room_id, origin) - self.room_queues[pdu.room_id].append((pdu, origin)) + if room_id in self.room_queues: + logger.info( + "[%s %s] Queuing PDU from %s for now: join in progress", + room_id, event_id, origin, + ) + self.room_queues[room_id].append((pdu, origin)) return # If we're no longer in the room just ditch the event entirely. This @@ -179,7 +210,7 @@ class FederationHandler(BaseHandler): # we should check if we *are* in fact in the room. If we are then we # can magically rejoin the room. is_in_room = yield self.auth.check_host_in_room( - pdu.room_id, + room_id, self.server_name ) if not is_in_room: @@ -188,8 +219,8 @@ class FederationHandler(BaseHandler): ) if was_in_room: logger.info( - "Ignoring PDU %s for room %s from %s as we've left the room!", - pdu.event_id, pdu.room_id, origin, + "[%s %s] Ignoring PDU from %s as we've left the room", + room_id, event_id, origin, ) defer.returnValue(None) @@ -204,8 +235,8 @@ class FederationHandler(BaseHandler): ) logger.debug( - "_handle_new_pdu min_depth for %s: %d", - pdu.room_id, min_depth + "[%s %s] min_depth: %d", + room_id, event_id, min_depth, ) prevs = {e_id for e_id, _ in pdu.prev_events} @@ -218,17 +249,18 @@ class FederationHandler(BaseHandler): # send to the clients. pdu.internal_metadata.outlier = True elif min_depth and pdu.depth > min_depth: - if get_missing and prevs - seen: + missing_prevs = prevs - seen + if sent_to_us_directly and missing_prevs: # If we're missing stuff, ensure we only fetch stuff one # at a time. logger.info( - "Acquiring lock for room %r to fetch %d missing events: %r...", - pdu.room_id, len(prevs - seen), list(prevs - seen)[:5], + "[%s %s] Acquiring room lock to fetch %d missing prev_events: %s", + room_id, event_id, len(missing_prevs), shortstr(missing_prevs), ) with (yield self._room_pdu_linearizer.queue(pdu.room_id)): logger.info( - "Acquired lock for room %r to fetch %d missing events", - pdu.room_id, len(prevs - seen), + "[%s %s] Acquired room lock to fetch %d missing prev_events", + room_id, event_id, len(missing_prevs), ) yield self._get_missing_events_for_pdu( @@ -241,68 +273,150 @@ class FederationHandler(BaseHandler): if not prevs - seen: logger.info( - "Found all missing prev events for %s", pdu.event_id + "[%s %s] Found all missing prev_events", + room_id, event_id, ) - elif prevs - seen: + elif missing_prevs: logger.info( - "Not fetching %d missing events for room %r,event %s: %r...", - len(prevs - seen), pdu.room_id, pdu.event_id, - list(prevs - seen)[:5], + "[%s %s] Not recursively fetching %d missing prev_events: %s", + room_id, event_id, len(missing_prevs), shortstr(missing_prevs), ) - if sent_to_us_directly and prevs - seen: - # If they have sent it to us directly, and the server - # isn't telling us about the auth events that it's - # made a message referencing, we explode - raise FederationError( - "ERROR", - 403, - ( - "Your server isn't divulging details about prev_events " - "referenced in this event." - ), - affected=pdu.event_id, - ) - elif prevs - seen: - # Calculate the state of the previous events, and - # de-conflict them to find the current state. - state_groups = [] + if prevs - seen: + # We've still not been able to get all of the prev_events for this event. + # + # In this case, we need to fall back to asking another server in the + # federation for the state at this event. That's ok provided we then + # resolve the state against other bits of the DAG before using it (which + # will ensure that you can't just take over a room by sending an event, + # withholding its prev_events, and declaring yourself to be an admin in + # the subsequent state request). + # + # Now, if we're pulling this event as a missing prev_event, then clearly + # this event is not going to become the only forward-extremity and we are + # guaranteed to resolve its state against our existing forward + # extremities, so that should be fine. + # + # On the other hand, if this event was pushed to us, it is possible for + # it to become the only forward-extremity in the room, and we would then + # trust its state to be the state for the whole room. This is very bad. + # Further, if the event was pushed to us, there is no excuse for us not to + # have all the prev_events. We therefore reject any such events. + # + # XXX this really feels like it could/should be merged with the above, + # but there is an interaction with min_depth that I'm not really + # following. + + if sent_to_us_directly: + logger.warn( + "[%s %s] Rejecting: failed to fetch %d prev events: %s", + room_id, event_id, len(prevs - seen), shortstr(prevs - seen) + ) + raise FederationError( + "ERROR", + 403, + ( + "Your server isn't divulging details about prev_events " + "referenced in this event." + ), + affected=pdu.event_id, + ) + + # Calculate the state after each of the previous events, and + # resolve them to find the correct state at the current event. auth_chains = set() + event_map = { + event_id: pdu, + } try: # Get the state of the events we know about - ours = yield self.store.get_state_groups(pdu.room_id, list(seen)) - state_groups.append(ours) + ours = yield self.store.get_state_groups_ids(room_id, seen) + + # state_maps is a list of mappings from (type, state_key) to event_id + # type: list[dict[tuple[str, str], str]] + state_maps = list(ours.values()) + + # we don't need this any more, let's delete it. + del ours # Ask the remote server for the states we don't # know about for p in prevs - seen: - state, got_auth_chain = ( - yield self.federation_client.get_state_for_room( - origin, pdu.room_id, p - ) - ) - auth_chains.update(got_auth_chain) - state_group = {(x.type, x.state_key): x.event_id for x in state} - state_groups.append(state_group) - - # Resolve any conflicting state - def fetch(ev_ids): - return self.store.get_events( - ev_ids, get_prev_content=False, check_redacted=False + logger.info( + "[%s %s] Requesting state at missing prev_event %s", + room_id, event_id, p, ) - state_map = yield resolve_events_with_factory( - state_groups, {pdu.event_id: pdu}, fetch + with logcontext.nested_logging_context(p): + # note that if any of the missing prevs share missing state or + # auth events, the requests to fetch those events are deduped + # by the get_pdu_cache in federation_client. + remote_state, got_auth_chain = ( + yield self.federation_client.get_state_for_room( + origin, room_id, p, + ) + ) + + # we want the state *after* p; get_state_for_room returns the + # state *before* p. + remote_event = yield self.federation_client.get_pdu( + [origin], p, outlier=True, + ) + + if remote_event is None: + raise Exception( + "Unable to get missing prev_event %s" % (p, ) + ) + + if remote_event.is_state(): + remote_state.append(remote_event) + + # XXX hrm I'm not convinced that duplicate events will compare + # for equality, so I'm not sure this does what the author + # hoped. + auth_chains.update(got_auth_chain) + + remote_state_map = { + (x.type, x.state_key): x.event_id for x in remote_state + } + state_maps.append(remote_state_map) + + for x in remote_state: + event_map[x.event_id] = x + + room_version = yield self.store.get_room_version(room_id) + state_map = yield resolve_events_with_store( + room_version, state_maps, event_map, + state_res_store=StateResolutionStore(self.store), + ) + + # We need to give _process_received_pdu the actual state events + # rather than event ids, so generate that now. + + # First though we need to fetch all the events that are in + # state_map, so we can build up the state below. + evs = yield self.store.get_events( + list(state_map.values()), + get_prev_content=False, + check_redacted=False, ) + event_map.update(evs) - state = (yield self.store.get_events(state_map.values())).values() + state = [ + event_map[e] for e in six.itervalues(state_map) + ] auth_chain = list(auth_chains) except Exception: + logger.warn( + "[%s %s] Error attempting to resolve state at missing " + "prev_events", + room_id, event_id, exc_info=True, + ) raise FederationError( "ERROR", 403, "We can't get valid state history.", - affected=pdu.event_id, + affected=event_id, ) yield self._process_received_pdu( @@ -321,15 +435,16 @@ class FederationHandler(BaseHandler): prevs (set(str)): List of event ids which we are missing min_depth (int): Minimum depth of events to return. """ - # We recalculate seen, since it may have changed. + + room_id = pdu.room_id + event_id = pdu.event_id + seen = yield self.store.have_seen_events(prevs) if not prevs - seen: return - latest = yield self.store.get_latest_event_ids_in_room( - pdu.room_id - ) + latest = yield self.store.get_latest_event_ids_in_room(room_id) # We add the prev events that we have seen to the latest # list to ensure the remote server doesn't give them to us @@ -337,8 +452,8 @@ class FederationHandler(BaseHandler): latest |= seen logger.info( - "Missing %d events for room %r pdu %s: %r...", - len(prevs - seen), pdu.room_id, pdu.event_id, list(prevs - seen)[:5] + "[%s %s]: Requesting missing events between %s and %s", + room_id, event_id, shortstr(latest), event_id, ) # XXX: we set timeout to 10s to help workaround @@ -359,49 +474,88 @@ class FederationHandler(BaseHandler): # apparently. # # see https://github.com/matrix-org/synapse/pull/1744 + # + # ---- + # + # Update richvdh 2018/09/18: There are a number of problems with timing this + # request out agressively on the client side: + # + # - it plays badly with the server-side rate-limiter, which starts tarpitting you + # if you send too many requests at once, so you end up with the server carefully + # working through the backlog of your requests, which you have already timed + # out. + # + # - for this request in particular, we now (as of + # https://github.com/matrix-org/synapse/pull/3456) reject any PDUs where the + # server can't produce a plausible-looking set of prev_events - so we becone + # much more likely to reject the event. + # + # - contrary to what it says above, we do *not* fall back to fetching fresh state + # for the room if get_missing_events times out. Rather, we give up processing + # the PDU whose prevs we are missing, which then makes it much more likely that + # we'll end up back here for the *next* PDU in the list, which exacerbates the + # problem. + # + # - the agressive 10s timeout was introduced to deal with incoming federation + # requests taking 8 hours to process. It's not entirely clear why that was going + # on; certainly there were other issues causing traffic storms which are now + # resolved, and I think in any case we may be more sensible about our locking + # now. We're *certainly* more sensible about our logging. + # + # All that said: Let's try increasing the timout to 60s and see what happens. missing_events = yield self.federation_client.get_missing_events( origin, - pdu.room_id, + room_id, earliest_events_ids=list(latest), latest_events=[pdu], limit=10, min_depth=min_depth, - timeout=10000, + timeout=60000, ) logger.info( - "Got %d events: %r...", - len(missing_events), [e.event_id for e in missing_events[:5]] + "[%s %s]: Got %d prev_events: %s", + room_id, event_id, len(missing_events), shortstr(missing_events), ) # We want to sort these by depth so we process them and # tell clients about them in order. missing_events.sort(key=lambda x: x.depth) - for e in missing_events: - logger.info("Handling found event %s", e.event_id) - try: - yield self.on_receive_pdu( - origin, - e, - get_missing=False - ) - except FederationError as e: - if e.code == 403: - logger.warn("Event %s failed history check.") - else: - raise + for ev in missing_events: + logger.info( + "[%s %s] Handling received prev_event %s", + room_id, event_id, ev.event_id, + ) + with logcontext.nested_logging_context(ev.event_id): + try: + yield self.on_receive_pdu( + origin, + ev, + sent_to_us_directly=False, + ) + except FederationError as e: + if e.code == 403: + logger.warn( + "[%s %s] Received prev_event %s failed history check.", + room_id, event_id, ev.event_id, + ) + else: + raise - @log_function @defer.inlineCallbacks - def _process_received_pdu(self, origin, pdu, state, auth_chain): + def _process_received_pdu(self, origin, event, state, auth_chain): """ Called when we have a new pdu. We need to do auth checks and put it through the StateHandler. """ - event = pdu + room_id = event.room_id + event_id = event.event_id - logger.debug("Processing event: %s", event) + logger.debug( + "[%s %s] Processing event: %s", + room_id, event_id, event, + ) # FIXME (erikj): Awful hack to make the case where we are not currently # in the room work @@ -410,15 +564,16 @@ class FederationHandler(BaseHandler): # event. if state and auth_chain and not event.internal_metadata.is_outlier(): is_in_room = yield self.auth.check_host_in_room( - event.room_id, + room_id, self.server_name ) else: is_in_room = True + if not is_in_room: logger.info( - "Got event for room we're not in: %r %r", - event.room_id, event.event_id + "[%s %s] Got event for room we're not in", + room_id, event_id, ) try: @@ -430,7 +585,7 @@ class FederationHandler(BaseHandler): "ERROR", e.code, e.msg, - affected=event.event_id, + affected=event_id, ) else: @@ -463,6 +618,10 @@ class FederationHandler(BaseHandler): }) seen_ids.add(e.event_id) + logger.info( + "[%s %s] persisting newly-received auth/state events %s", + room_id, event_id, [e["event"].event_id for e in event_infos] + ) yield self._handle_new_events(origin, event_infos) try: @@ -479,12 +638,12 @@ class FederationHandler(BaseHandler): affected=event.event_id, ) - room = yield self.store.get_room(event.room_id) + room = yield self.store.get_room(room_id) if not room: try: yield self.store.store_room( - room_id=event.room_id, + room_id=room_id, room_creator_user_id="", is_public=False, ) @@ -512,7 +671,7 @@ class FederationHandler(BaseHandler): if newly_joined: user = UserID.from_string(event.state_key) - yield self.user_joined_room(user, event.room_id) + yield self.user_joined_room(user, room_id) @log_function @defer.inlineCallbacks @@ -593,7 +752,7 @@ class FederationHandler(BaseHandler): required_auth = set( a_id - for event in events + state_events.values() + auth_events.values() + for event in events + list(state_events.values()) + list(auth_events.values()) for a_id, _ in event.auth_events ) auth_events.update({ @@ -801,7 +960,7 @@ class FederationHandler(BaseHandler): ) continue except NotRetryingDestination as e: - logger.info(e.message) + logger.info(str(e)) continue except FederationDeniedError as e: logger.info(e) @@ -1026,7 +1185,8 @@ class FederationHandler(BaseHandler): try: logger.info("Processing queued PDU %s which was received " "while we were joining %s", p.event_id, p.room_id) - yield self.on_receive_pdu(origin, p) + with logcontext.nested_logging_context(p.event_id): + yield self.on_receive_pdu(origin, p, sent_to_us_directly=True) except Exception as e: logger.warn( "Error handling queued PDU %s from %s: %s", @@ -1357,7 +1517,7 @@ class FederationHandler(BaseHandler): ) if state_groups: - _, state = state_groups.items().pop() + _, state = list(state_groups.items()).pop() results = state if event.is_state(): @@ -1429,12 +1589,10 @@ class FederationHandler(BaseHandler): else: defer.returnValue(None) - @log_function def get_min_depth_for_context(self, context): return self.store.get_min_depth(context) @defer.inlineCallbacks - @log_function def _handle_new_event(self, origin, event, state=None, auth_events=None, backfilled=False): context = yield self._prep_event( @@ -1443,6 +1601,9 @@ class FederationHandler(BaseHandler): auth_events=auth_events, ) + # reraise does not allow inlineCallbacks to preserve the stacktrace, so we + # hack around with a try/finally instead. + success = False try: if not event.internal_metadata.is_outlier() and not backfilled: yield self.action_generator.handle_push_actions_for_event( @@ -1453,15 +1614,13 @@ class FederationHandler(BaseHandler): [(event, context)], backfilled=backfilled, ) - except: # noqa: E722, as we reraise the exception this is fine. - tp, value, tb = sys.exc_info() - - logcontext.run_in_background( - self.store.remove_push_actions_from_staging, - event.event_id, - ) - - six.reraise(tp, value, tb) + success = True + finally: + if not success: + logcontext.run_in_background( + self.store.remove_push_actions_from_staging, + event.event_id, + ) defer.returnValue(context) @@ -1474,15 +1633,22 @@ class FederationHandler(BaseHandler): Notifies about the events where appropriate. """ - contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - logcontext.run_in_background( - self._prep_event, + + @defer.inlineCallbacks + def prep(ev_info): + event = ev_info["event"] + with logcontext.nested_logging_context(suffix=event.event_id): + res = yield self._prep_event( origin, - ev_info["event"], + event, state=ev_info.get("state"), auth_events=ev_info.get("auth_events"), ) + defer.returnValue(res) + + contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults( + [ + logcontext.run_in_background(prep, ev_info) for ev_info in event_infos ], consumeErrors=True, )) @@ -1634,8 +1800,8 @@ class FederationHandler(BaseHandler): ) except AuthError as e: logger.warn( - "Rejecting %s because %s", - event.event_id, e.msg + "[%s %s] Rejecting: %s", + event.room_id, event.event_id, e.msg ) context.rejected = RejectedReason.AUTH_ERROR @@ -1686,7 +1852,7 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def on_get_missing_events(self, origin, room_id, earliest_events, - latest_events, limit, min_depth): + latest_events, limit): in_room = yield self.auth.check_host_in_room( room_id, origin @@ -1695,14 +1861,12 @@ class FederationHandler(BaseHandler): raise AuthError(403, "Host not in room.") limit = min(limit, 20) - min_depth = max(min_depth, 0) missing_events = yield self.store.get_missing_events( room_id=room_id, earliest_events=earliest_events, latest_events=latest_events, limit=limit, - min_depth=min_depth, ) missing_events = yield filter_events_for_server( @@ -1828,7 +1992,10 @@ class FederationHandler(BaseHandler): (d.type, d.state_key): d for d in different_events if d }) - new_state = self.state_handler.resolve_events( + room_version = yield self.store.get_room_version(event.room_id) + + new_state = yield self.state_handler.resolve_events( + room_version, [list(local_view.values()), list(remote_view.values())], event ) @@ -2353,7 +2520,7 @@ class FederationHandler(BaseHandler): if not backfilled: # Never notify for backfilled events for event, _ in event_and_contexts: - self._notify_persisted_event(event, max_stream_id) + yield self._notify_persisted_event(event, max_stream_id) def _notify_persisted_event(self, event, max_stream_id): """Checks to see if notifier/pushers should be notified about the @@ -2386,7 +2553,7 @@ class FederationHandler(BaseHandler): extra_users=extra_users ) - self.pusher_pool.on_new_notifications( + return self.pusher_pool.on_new_notifications( event_stream_id, max_stream_id, ) diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 53e5e2648b..173315af6c 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -20,7 +20,7 @@ from six import iteritems from twisted.internet import defer -from synapse.api.errors import SynapseError +from synapse.api.errors import HttpResponseException, SynapseError from synapse.types import get_domain_from_id logger = logging.getLogger(__name__) @@ -37,9 +37,23 @@ def _create_rerouter(func_name): ) else: destination = get_domain_from_id(group_id) - return getattr(self.transport_client, func_name)( + d = getattr(self.transport_client, func_name)( destination, group_id, *args, **kwargs ) + + # Capture errors returned by the remote homeserver and + # re-throw specific errors as SynapseErrors. This is so + # when the remote end responds with things like 403 Not + # In Group, we can communicate that to the client instead + # of a 500. + def h(failure): + failure.trap(HttpResponseException) + e = failure.value + if e.code == 403: + raise e.to_synapse_error() + return failure + d.addErrback(h) + return d return f diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index e009395207..563bb3cea3 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -156,7 +156,7 @@ class InitialSyncHandler(BaseHandler): room_end_token = "s%d" % (event.stream_ordering,) deferred_room_state = run_in_background( self.store.get_state_for_events, - [event.event_id], None, + [event.event_id], ) deferred_room_state.addCallback( lambda states: states[event.event_id] @@ -301,7 +301,7 @@ class InitialSyncHandler(BaseHandler): def _room_initial_sync_parted(self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking): room_state = yield self.store.get_state_for_events( - [member_event_id], None + [member_event_id], ) room_state = room_state[member_event_id] diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e484061cc0..969e588e73 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -14,9 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import sys -import six from six import iteritems, itervalues, string_types from canonicaljson import encode_canonical_json, json @@ -37,6 +35,7 @@ from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.replication.http.send_event import ReplicationSendEventRestServlet +from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder @@ -82,7 +81,7 @@ class MessageHandler(object): elif membership == Membership.LEAVE: key = (event_type, state_key) room_state = yield self.store.get_state_for_events( - [membership_event_id], [key] + [membership_event_id], StateFilter.from_types([key]) ) data = room_state[membership_event_id].get(key) @@ -90,7 +89,7 @@ class MessageHandler(object): @defer.inlineCallbacks def get_state_events( - self, user_id, room_id, types=None, filtered_types=None, + self, user_id, room_id, state_filter=StateFilter.all(), at_token=None, is_guest=False, ): """Retrieve all state events for a given room. If the user is @@ -102,13 +101,8 @@ class MessageHandler(object): Args: user_id(str): The user requesting state events. room_id(str): The room ID to get all state events from. - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. at_token(StreamToken|None): the stream token of the at which we are requesting the stats. If the user is not allowed to view the state as of that stream token, we raise a 403 SynapseError. If None, returns the current @@ -141,7 +135,7 @@ class MessageHandler(object): event = last_events[0] if visible_events: room_state = yield self.store.get_state_for_events( - [event.event_id], types, filtered_types=filtered_types, + [event.event_id], state_filter=state_filter, ) room_state = room_state[event.event_id] else: @@ -160,12 +154,12 @@ class MessageHandler(object): if membership == Membership.JOIN: state_ids = yield self.store.get_filtered_current_state_ids( - room_id, types, filtered_types=filtered_types, + room_id, state_filter=state_filter, ) room_state = yield self.store.get_events(state_ids.values()) elif membership == Membership.LEAVE: room_state = yield self.store.get_state_for_events( - [membership_event_id], types, filtered_types=filtered_types, + [membership_event_id], state_filter=state_filter, ) room_state = room_state[membership_event_id] @@ -624,6 +618,9 @@ class EventCreationHandler(object): event, context ) + # reraise does not allow inlineCallbacks to preserve the stacktrace, so we + # hack around with a try/finally instead. + success = False try: # If we're a worker we need to hit out to the master. if self.config.worker_app: @@ -636,6 +633,7 @@ class EventCreationHandler(object): ratelimit=ratelimit, extra_users=extra_users, ) + success = True return yield self.persist_and_notify_client_event( @@ -645,17 +643,16 @@ class EventCreationHandler(object): ratelimit=ratelimit, extra_users=extra_users, ) - except: # noqa: E722, as we reraise the exception this is fine. - # Ensure that we actually remove the entries in the push actions - # staging area, if we calculated them. - tp, value, tb = sys.exc_info() - - run_in_background( - self.store.remove_push_actions_from_staging, - event.event_id, - ) - six.reraise(tp, value, tb) + success = True + finally: + if not success: + # Ensure that we actually remove the entries in the push actions + # staging area, if we calculated them. + run_in_background( + self.store.remove_push_actions_from_staging, + event.event_id, + ) @defer.inlineCallbacks def persist_and_notify_client_event( @@ -778,7 +775,7 @@ class EventCreationHandler(object): event, context=context ) - self.pusher_pool.on_new_notifications( + yield self.pusher_pool.on_new_notifications( event_stream_id, max_stream_id, ) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 5170d093e3..43f81bd607 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -21,6 +21,7 @@ from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError from synapse.events.utils import serialize_event +from synapse.storage.state import StateFilter from synapse.types import RoomStreamToken from synapse.util.async_helpers import ReadWriteLock from synapse.util.logcontext import run_in_background @@ -255,28 +256,19 @@ class PaginationHandler(object): if event_filter and event_filter.lazy_load_members(): # TODO: remove redundant members - types = [ - (EventTypes.Member, state_key) - for state_key in set( - event.sender # FIXME: we also care about invite targets etc. - for event in events - ) - ] + # FIXME: we also care about invite targets etc. + state_filter = StateFilter.from_types( + (EventTypes.Member, event.sender) + for event in events + ) state_ids = yield self.store.get_state_ids_for_event( - events[0].event_id, types=types, + events[0].event_id, state_filter=state_filter, ) if state_ids: state = yield self.store.get_events(list(state_ids.values())) - - if state: - state = yield filter_events_for_client( - self.store, - user_id, - state.values(), - is_peeking=(member_event_id is None), - ) + state = state.values() time_now = self.clock.time_msec() diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index c3506b0db0..65e29518f0 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -36,14 +36,18 @@ from ._base import BaseHandler logger = logging.getLogger(__name__) -class ProfileHandler(BaseHandler): - PROFILE_UPDATE_MS = 60 * 1000 - PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 +class BaseProfileHandler(BaseHandler): + """Handles fetching and updating user profile information. + + BaseProfileHandler can be instantiated directly on workers and will + delegate to master when necessary. The master process should use the + subclass MasterProfileHandler + """ PROFILE_REPLICATE_INTERVAL = 2 * 60 * 1000 def __init__(self, hs): - super(ProfileHandler, self).__init__(hs) + super(BaseProfileHandler, self).__init__(hs) self.federation = hs.get_federation_client() hs.get_federation_registry().register_query_handler( @@ -222,10 +226,8 @@ class ProfileHandler(BaseHandler): if e.code != 404: logger.exception("Failed to get displayname") raise - except Exception: - logger.exception("Failed to get displayname") - else: - defer.returnValue(result["displayname"]) + + defer.returnValue(result["displayname"]) @defer.inlineCallbacks def set_displayname(self, target_user, requester, new_displayname, by_admin=False): @@ -315,8 +317,6 @@ class ProfileHandler(BaseHandler): if e.code != 404: logger.exception("Failed to get avatar_url") raise - except Exception: - logger.exception("Failed to get avatar_url") defer.returnValue(result["avatar_url"]) @@ -408,9 +408,23 @@ class ProfileHandler(BaseHandler): except Exception as e: logger.warn( "Failed to update join event for room %s - %s", - room_id, str(e.message) + room_id, str(e) ) + +class MasterProfileHandler(BaseProfileHandler): + PROFILE_UPDATE_MS = 60 * 1000 + PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 + + def __init__(self, hs): + super(MasterProfileHandler, self).__init__(hs) + + assert hs.config.worker_app is None + + self.clock.looping_call( + self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS, + ) + def _start_update_remote_profile_cache(self): return run_as_background_process( "Update remote profile", self._update_remote_profile_cache, diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index a6f3181f09..4c2690ba26 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -119,7 +119,7 @@ class ReceiptsHandler(BaseHandler): "receipt_key", max_batch_id, rooms=affected_room_ids ) # Note that the min here shouldn't be relied upon to be accurate. - self.hs.get_pusherpool().on_new_receipts( + yield self.hs.get_pusherpool().on_new_receipts( min_batch_id, max_batch_id, affected_room_ids, ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 3e061c89dc..757eacd214 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -50,6 +50,7 @@ class RegistrationHandler(BaseHandler): self._auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() self.user_directory_handler = hs.get_user_directory_handler() + self.room_creation_handler = self.hs.get_room_creation_handler() self.captcha_client = CaptchaServerHttpClient(hs) self._next_generated_user_id = None @@ -126,6 +127,7 @@ class RegistrationHandler(BaseHandler): make_guest=False, display_name=None, admin=False, + threepid=None, ): """Registers a new client on the server. @@ -147,7 +149,7 @@ class RegistrationHandler(BaseHandler): RegistrationError if there was a problem registering. """ - yield self.auth.check_auth_blocking() + yield self.auth.check_auth_blocking(threepid=threepid) password_hash = None if password: password_hash = yield self.auth_handler().hash(password) @@ -232,9 +234,36 @@ class RegistrationHandler(BaseHandler): # auto-join the user to any rooms we're supposed to dump them into fake_requester = create_requester(user_id) + + # try to create the room if we're the first user on the server + should_auto_create_rooms = False + if self.hs.config.autocreate_auto_join_rooms: + count = yield self.store.count_all_users() + should_auto_create_rooms = count == 1 + for r in self.hs.config.auto_join_rooms: try: - yield self._join_user_to_room(fake_requester, r) + if should_auto_create_rooms: + room_alias = RoomAlias.from_string(r) + if self.hs.hostname != room_alias.domain: + logger.warning( + 'Cannot create room alias %s, ' + 'it does not match server domain', + r, + ) + else: + # create room expects the localpart of the room alias + room_alias_localpart = room_alias.localpart + yield self.room_creation_handler.create_room( + fake_requester, + config={ + "preset": "public_chat", + "room_alias_name": room_alias_localpart + }, + ratelimit=False, + ) + else: + yield self._join_user_to_room(fake_requester, r) except Exception as e: logger.error("Failed to join new user to %r: %r", r, e) @@ -552,4 +581,5 @@ class RegistrationHandler(BaseHandler): room_id=room_id, remote_room_hosts=remote_room_hosts, action="join", + ratelimit=False, ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 621b91d2ae..3870e433ec 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -33,6 +33,7 @@ from synapse.api.constants import ( RoomCreationPreset, ) from synapse.api.errors import AuthError, Codes, StoreError, SynapseError +from synapse.storage.state import StateFilter from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID from synapse.util import stringutils from synapse.visibility import filter_events_for_client @@ -192,10 +193,11 @@ class RoomCreationHandler(BaseHandler): if room_alias: directory_handler = self.hs.get_handlers().directory_handler yield directory_handler.create_association( - user_id=user_id, + requester=requester, room_id=room_id, room_alias=room_alias, servers=[self.hs.hostname], + send_event=False, ) preset_config = config.get( @@ -291,7 +293,7 @@ class RoomCreationHandler(BaseHandler): if room_alias: result["room_alias"] = room_alias.to_string() yield directory_handler.send_room_alias_update_event( - requester, user_id, room_id + requester, room_id ) defer.returnValue(result) @@ -499,23 +501,24 @@ class RoomContextHandler(object): else: last_event_id = event_id - types = None - filtered_types = None if event_filter and event_filter.lazy_load_members(): - members = set(ev.sender for ev in itertools.chain( - results["events_before"], - (results["event"],), - results["events_after"], - )) - filtered_types = [EventTypes.Member] - types = [(EventTypes.Member, member) for member in members] + state_filter = StateFilter.from_lazy_load_member_list( + ev.sender + for ev in itertools.chain( + results["events_before"], + (results["event"],), + results["events_after"], + ) + ) + else: + state_filter = StateFilter.all() # XXX: why do we return the state as of the last event rather than the # first? Shouldn't we be consistent with /sync? # https://github.com/matrix-org/matrix-doc/issues/687 state = yield self.store.get_state_for_events( - [last_event_id], types, filtered_types=filtered_types, + [last_event_id], state_filter=state_filter, ) results["state"] = list(state[last_event_id].values()) diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 37e41afd61..dc88620885 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -16,7 +16,7 @@ import logging from collections import namedtuple -from six import iteritems +from six import PY3, iteritems from six.moves import range import msgpack @@ -162,7 +162,7 @@ class RoomListHandler(BaseHandler): # Filter out rooms that we don't want to return rooms_to_scan = [ r for r in sorted_rooms - if r not in newly_unpublished and rooms_to_num_joined[room_id] > 0 + if r not in newly_unpublished and rooms_to_num_joined[r] > 0 ] total_room_count = len(rooms_to_scan) @@ -444,9 +444,16 @@ class RoomListNextBatch(namedtuple("RoomListNextBatch", ( @classmethod def from_token(cls, token): + if PY3: + # The argument raw=False is only available on new versions of + # msgpack, and only really needed on Python 3. Gate it behind + # a PY3 check to avoid causing issues on Debian-packaged versions. + decoded = msgpack.loads(decode_base64(token), raw=False) + else: + decoded = msgpack.loads(decode_base64(token)) return RoomListNextBatch(**{ cls.REVERSE_KEY_DICT[key]: val - for key, val in msgpack.loads(decode_base64(token)).items() + for key, val in decoded.items() }) def to_token(self): diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index fb94b5d7d4..07fd3e82fc 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -344,6 +344,7 @@ class RoomMemberHandler(object): latest_event_ids = ( event_id for (event_id, _, _) in prev_events_and_hashes ) + current_state_ids = yield self.state_handler.get_current_state_ids( room_id, latest_event_ids=latest_event_ids, ) @@ -582,6 +583,11 @@ class RoomMemberHandler(object): room_id = mapping["room_id"] servers = mapping["servers"] + # put the server which owns the alias at the front of the server list. + if room_alias.domain in servers: + servers.remove(room_alias.domain) + servers.insert(0, room_alias.domain) + defer.returnValue((RoomID.from_string(room_id), servers)) @defer.inlineCallbacks diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index c464adbd0b..80e7b15de8 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -24,6 +24,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError from synapse.api.filtering import Filter from synapse.events.utils import serialize_event +from synapse.storage.state import StateFilter from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -54,7 +55,7 @@ class SearchHandler(BaseHandler): batch_token = None if batch: try: - b = decode_base64(batch) + b = decode_base64(batch).decode('ascii') batch_group, batch_group_key, batch_token = b.split("\n") assert batch_group is not None @@ -258,18 +259,18 @@ class SearchHandler(BaseHandler): # it returns more from the same group (if applicable) rather # than reverting to searching all results again. if batch_group and batch_group_key: - global_next_batch = encode_base64("%s\n%s\n%s" % ( + global_next_batch = encode_base64(("%s\n%s\n%s" % ( batch_group, batch_group_key, pagination_token - )) + )).encode('ascii')) else: - global_next_batch = encode_base64("%s\n%s\n%s" % ( + global_next_batch = encode_base64(("%s\n%s\n%s" % ( "all", "", pagination_token - )) + )).encode('ascii')) for room_id, group in room_groups.items(): - group["next_batch"] = encode_base64("%s\n%s\n%s" % ( + group["next_batch"] = encode_base64(("%s\n%s\n%s" % ( "room_id", room_id, pagination_token - )) + )).encode('ascii')) allowed_events.extend(room_events) @@ -324,9 +325,12 @@ class SearchHandler(BaseHandler): else: last_event_id = event.event_id + state_filter = StateFilter.from_types( + [(EventTypes.Member, sender) for sender in senders] + ) + state = yield self.store.get_state_for_event( - last_event_id, - types=[(EventTypes.Member, sender) for sender in senders] + last_event_id, state_filter ) res["profile_info"] = { diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 113d711d4c..c802db9792 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -20,10 +20,14 @@ import logging from six import iteritems, itervalues +from prometheus_client import Counter + from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.push.clientformat import format_push_rules_for_user +from synapse.storage.roommember import MemberSummary +from synapse.storage.state import StateFilter from synapse.types import RoomStreamToken from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache @@ -35,6 +39,19 @@ from synapse.visibility import filter_events_for_client logger = logging.getLogger(__name__) + +# Counts the number of times we returned a non-empty sync. `type` is one of +# "initial_sync", "full_state_sync" or "incremental_sync", `lazy_loaded` is +# "true" or "false" depending on if the request asked for lazy loaded members or +# not. +non_empty_sync_counter = Counter( + "synapse_handlers_sync_nonempty_total", + "Count of non empty sync responses. type is initial_sync/full_state_sync" + "/incremental_sync. lazy_loaded indicates if lazy loaded members were " + "enabled for that request.", + ["type", "lazy_loaded"], +) + # Store the cache that tracks which lazy-loaded members have been sent to a given # client for no more than 30 minutes. LAZY_LOADED_MEMBERS_CACHE_MAX_AGE = 30 * 60 * 1000 @@ -226,14 +243,16 @@ class SyncHandler(object): @defer.inlineCallbacks def _wait_for_sync_for_user(self, sync_config, since_token, timeout, full_state): + if since_token is None: + sync_type = "initial_sync" + elif full_state: + sync_type = "full_state_sync" + else: + sync_type = "incremental_sync" + context = LoggingContext.current_context() if context: - if since_token is None: - context.tag = "initial_sync" - elif full_state: - context.tag = "full_state_sync" - else: - context.tag = "incremental_sync" + context.tag = sync_type if timeout == 0 or since_token is None or full_state: # we are going to return immediately, so don't bother calling @@ -241,7 +260,6 @@ class SyncHandler(object): result = yield self.current_sync_for_user( sync_config, since_token, full_state=full_state, ) - defer.returnValue(result) else: def current_sync_callback(before_token, after_token): return self.current_sync_for_user(sync_config, since_token) @@ -250,7 +268,15 @@ class SyncHandler(object): sync_config.user.to_string(), timeout, current_sync_callback, from_token=since_token, ) - defer.returnValue(result) + + if result: + if sync_config.filter_collection.lazy_load_members(): + lazy_loaded = "true" + else: + lazy_loaded = "false" + non_empty_sync_counter.labels(sync_type, lazy_loaded).inc() + + defer.returnValue(result) def current_sync_for_user(self, sync_config, since_token=None, full_state=False): @@ -444,25 +470,20 @@ class SyncHandler(object): )) @defer.inlineCallbacks - def get_state_after_event(self, event, types=None, filtered_types=None): + def get_state_after_event(self, event, state_filter=StateFilter.all()): """ Get the room state after the given event Args: event(synapse.events.EventBase): event of interest - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: A Deferred map from ((type, state_key)->Event) """ state_ids = yield self.store.get_state_ids_for_event( - event.event_id, types, filtered_types=filtered_types, + event.event_id, state_filter=state_filter, ) if event.is_state(): state_ids = state_ids.copy() @@ -470,18 +491,14 @@ class SyncHandler(object): defer.returnValue(state_ids) @defer.inlineCallbacks - def get_state_at(self, room_id, stream_position, types=None, filtered_types=None): + def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()): """ Get the room state at a particular stream position Args: room_id(str): room for which to get state stream_position(StreamToken): point at which to get state - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: A Deferred map from ((type, state_key)->Event) @@ -497,7 +514,7 @@ class SyncHandler(object): if last_events: last_event = last_events[-1] state = yield self.get_state_after_event( - last_event, types, filtered_types=filtered_types, + last_event, state_filter=state_filter, ) else: @@ -525,6 +542,8 @@ class SyncHandler(object): A deferred dict describing the room summary """ + # FIXME: we could/should get this from room_stats when matthew/stats lands + # FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305 last_events, _ = yield self.store.get_recent_event_ids_for_room( room_id, end_token=now_token.room_key, limit=1, @@ -536,45 +555,69 @@ class SyncHandler(object): last_event = last_events[-1] state_ids = yield self.store.get_state_ids_for_event( - last_event.event_id, [ - (EventTypes.Member, None), + last_event.event_id, + state_filter=StateFilter.from_types([ (EventTypes.Name, ''), (EventTypes.CanonicalAlias, ''), - ] + ]), ) - member_ids = { - state_key: event_id - for (t, state_key), event_id in state_ids.iteritems() - if t == EventTypes.Member - } + # this is heavily cached, thus: fast. + details = yield self.store.get_room_summary(room_id) + name_id = state_ids.get((EventTypes.Name, '')) canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, '')) summary = {} - - # FIXME: it feels very heavy to load up every single membership event - # just to calculate the counts. - member_events = yield self.store.get_events(member_ids.values()) - - joined_user_ids = [] - invited_user_ids = [] - - for ev in member_events.values(): - if ev.content.get("membership") == Membership.JOIN: - joined_user_ids.append(ev.state_key) - elif ev.content.get("membership") == Membership.INVITE: - invited_user_ids.append(ev.state_key) + empty_ms = MemberSummary([], 0) # TODO: only send these when they change. - summary["m.joined_member_count"] = len(joined_user_ids) - summary["m.invited_member_count"] = len(invited_user_ids) + summary["m.joined_member_count"] = ( + details.get(Membership.JOIN, empty_ms).count + ) + summary["m.invited_member_count"] = ( + details.get(Membership.INVITE, empty_ms).count + ) - if name_id or canonical_alias_id: - defer.returnValue(summary) + # if the room has a name or canonical_alias set, we can skip + # calculating heroes. we assume that if the event has contents, it'll + # be a valid name or canonical_alias - i.e. we're checking that they + # haven't been "deleted" by blatting {} over the top. + if name_id: + name = yield self.store.get_event(name_id, allow_none=True) + if name and name.content: + defer.returnValue(summary) + + if canonical_alias_id: + canonical_alias = yield self.store.get_event( + canonical_alias_id, allow_none=True, + ) + if canonical_alias and canonical_alias.content: + defer.returnValue(summary) - # FIXME: order by stream ordering, not alphabetic + joined_user_ids = [ + r[0] for r in details.get(Membership.JOIN, empty_ms).members + ] + invited_user_ids = [ + r[0] for r in details.get(Membership.INVITE, empty_ms).members + ] + gone_user_ids = ( + [r[0] for r in details.get(Membership.LEAVE, empty_ms).members] + + [r[0] for r in details.get(Membership.BAN, empty_ms).members] + ) + # FIXME: only build up a member_ids list for our heroes + member_ids = {} + for membership in ( + Membership.JOIN, + Membership.INVITE, + Membership.LEAVE, + Membership.BAN + ): + for user_id, event_id in details.get(membership, empty_ms).members: + member_ids[user_id] = event_id + + # FIXME: order by stream ordering rather than as returned by SQL me = sync_config.user.to_string() if (joined_user_ids or invited_user_ids): summary['m.heroes'] = sorted( @@ -586,7 +629,11 @@ class SyncHandler(object): )[0:5] else: summary['m.heroes'] = sorted( - [user_id for user_id in member_ids.keys() if user_id != me] + [ + user_id + for user_id in gone_user_ids + if user_id != me + ] )[0:5] if not sync_config.filter_collection.lazy_load_members(): @@ -663,8 +710,7 @@ class SyncHandler(object): with Measure(self.clock, "compute_state_delta"): - types = None - filtered_types = None + members_to_fetch = None lazy_load_members = sync_config.filter_collection.lazy_load_members() include_redundant_members = ( @@ -675,16 +721,21 @@ class SyncHandler(object): # We only request state for the members needed to display the # timeline: - types = [ - (EventTypes.Member, state_key) - for state_key in set( - event.sender # FIXME: we also care about invite targets etc. - for event in batch.events - ) - ] + members_to_fetch = set( + event.sender # FIXME: we also care about invite targets etc. + for event in batch.events + ) - # only apply the filtering to room members - filtered_types = [EventTypes.Member] + if full_state: + # always make sure we LL ourselves so we know we're in the room + # (if we are) to fix https://github.com/vector-im/riot-web/issues/7209 + # We only need apply this on full state syncs given we disabled + # LL for incr syncs in #3840. + members_to_fetch.add(sync_config.user.to_string()) + + state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch) + else: + state_filter = StateFilter.all() timeline_state = { (event.type, event.state_key): event.event_id @@ -694,19 +745,17 @@ class SyncHandler(object): if full_state: if batch: current_state_ids = yield self.store.get_state_ids_for_event( - batch.events[-1].event_id, types=types, - filtered_types=filtered_types, + batch.events[-1].event_id, state_filter=state_filter, ) state_ids = yield self.store.get_state_ids_for_event( - batch.events[0].event_id, types=types, - filtered_types=filtered_types, + batch.events[0].event_id, state_filter=state_filter, ) else: current_state_ids = yield self.get_state_at( - room_id, stream_position=now_token, types=types, - filtered_types=filtered_types, + room_id, stream_position=now_token, + state_filter=state_filter, ) state_ids = current_state_ids @@ -719,19 +768,31 @@ class SyncHandler(object): lazy_load_members=lazy_load_members, ) elif batch.limited: - state_at_previous_sync = yield self.get_state_at( - room_id, stream_position=since_token, types=types, - filtered_types=filtered_types, + state_at_timeline_start = yield self.store.get_state_ids_for_event( + batch.events[0].event_id, state_filter=state_filter, ) - current_state_ids = yield self.store.get_state_ids_for_event( - batch.events[-1].event_id, types=types, - filtered_types=filtered_types, + # for now, we disable LL for gappy syncs - see + # https://github.com/vector-im/riot-web/issues/7211#issuecomment-419976346 + # N.B. this slows down incr syncs as we are now processing way + # more state in the server than if we were LLing. + # + # We still have to filter timeline_start to LL entries (above) in order + # for _calculate_state's LL logic to work, as we have to include LL + # members for timeline senders in case they weren't loaded in the initial + # sync. We do this by (counterintuitively) by filtering timeline_start + # members to just be ones which were timeline senders, which then ensures + # all of the rest get included in the state block (if we need to know + # about them). + state_filter = StateFilter.all() + + state_at_previous_sync = yield self.get_state_at( + room_id, stream_position=since_token, + state_filter=state_filter, ) - state_at_timeline_start = yield self.store.get_state_ids_for_event( - batch.events[0].event_id, types=types, - filtered_types=filtered_types, + current_state_ids = yield self.store.get_state_ids_for_event( + batch.events[-1].event_id, state_filter=state_filter, ) state_ids = _calculate_state( @@ -739,15 +800,28 @@ class SyncHandler(object): timeline_start=state_at_timeline_start, previous=state_at_previous_sync, current=current_state_ids, + # we have to include LL members in case LL initial sync missed them lazy_load_members=lazy_load_members, ) else: state_ids = {} if lazy_load_members: - if types: + if members_to_fetch and batch.events: + # We're returning an incremental sync, with no + # "gap" since the previous sync, so normally there would be + # no state to return. + # But we're lazy-loading, so the client might need some more + # member events to understand the events in this timeline. + # So we fish out all the member events corresponding to the + # timeline here, and then dedupe any redundant ones below. + state_ids = yield self.store.get_state_ids_for_event( - batch.events[0].event_id, types=types, - filtered_types=filtered_types, + batch.events[0].event_id, + # we only want members! + state_filter=StateFilter.from_types( + (EventTypes.Member, member) + for member in members_to_fetch + ), ) if lazy_load_members and not include_redundant_members: @@ -767,7 +841,7 @@ class SyncHandler(object): logger.debug("filtering state from %r...", state_ids) state_ids = { t: event_id - for t, event_id in state_ids.iteritems() + for t, event_id in iteritems(state_ids) if cache.get(t[1]) != event_id } logger.debug("...to %r", state_ids) @@ -1574,6 +1648,19 @@ class SyncHandler(object): newly_joined_room=newly_joined, ) + # When we join the room (or the client requests full_state), we should + # send down any existing tags. Usually the user won't have tags in a + # newly joined room, unless either a) they've joined before or b) the + # tag was added by synapse e.g. for server notice rooms. + if full_state: + user_id = sync_result_builder.sync_config.user.to_string() + tags = yield self.store.get_tags_for_room(user_id, room_id) + + # If there aren't any tags, don't send the empty tags list down + # sync + if not tags: + tags = None + account_data_events = [] if tags is not None: account_data_events.append({ @@ -1602,10 +1689,24 @@ class SyncHandler(object): ) summary = {} + + # we include a summary in room responses when we're lazy loading + # members (as the client otherwise doesn't have enough info to form + # the name itself). if ( sync_config.filter_collection.lazy_load_members() and ( + # we recalulate the summary: + # if there are membership changes in the timeline, or + # if membership has changed during a gappy sync, or + # if this is an initial sync. any(ev.type == EventTypes.Member for ev in batch.events) or + ( + # XXX: this may include false positives in the form of LL + # members which have snuck into state + batch.limited and + any(t == EventTypes.Member for (t, k) in state) + ) or since_token is None ) ): @@ -1635,6 +1736,16 @@ class SyncHandler(object): unread_notifications["highlight_count"] = notifs["highlight_count"] sync_result_builder.joined.append(room_sync) + + if batch.limited and since_token: + user_id = sync_result_builder.sync_config.user.to_string() + logger.info( + "Incremental gappy sync of %s for user %s with %d state events" % ( + room_id, + user_id, + len(state), + ) + ) elif room_builder.rtype == "archived": room_sync = ArchivedSyncResult( room_id=room_id, @@ -1728,17 +1839,17 @@ def _calculate_state( event_id_to_key = { e: key for key, e in itertools.chain( - timeline_contains.items(), - previous.items(), - timeline_start.items(), - current.items(), + iteritems(timeline_contains), + iteritems(previous), + iteritems(timeline_start), + iteritems(current), ) } - c_ids = set(e for e in current.values()) - ts_ids = set(e for e in timeline_start.values()) - p_ids = set(e for e in previous.values()) - tc_ids = set(e for e in timeline_contains.values()) + c_ids = set(e for e in itervalues(current)) + ts_ids = set(e for e in itervalues(timeline_start)) + p_ids = set(e for e in itervalues(previous)) + tc_ids = set(e for e in itervalues(timeline_contains)) # If we are lazyloading room members, we explicitly add the membership events # for the senders in the timeline into the state block returned by /sync, @@ -1752,7 +1863,7 @@ def _calculate_state( if lazy_load_members: p_ids.difference_update( - e for t, e in timeline_start.iteritems() + e for t, e in iteritems(timeline_start) if t[0] == EventTypes.Member ) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 2d2d3d5a0d..c610933dd4 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -20,6 +20,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError, SynapseError from synapse.types import UserID, get_domain_from_id +from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.logcontext import run_in_background from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -68,6 +69,11 @@ class TypingHandler(object): # map room IDs to sets of users currently typing self._room_typing = {} + # caches which room_ids changed at which serials + self._typing_stream_change_cache = StreamChangeCache( + "TypingStreamChangeCache", self._latest_room_serial, + ) + self.clock.looping_call( self._handle_timeouts, 5000, @@ -218,6 +224,7 @@ class TypingHandler(object): for domain in set(get_domain_from_id(u) for u in users): if domain != self.server_name: + logger.debug("sending typing update to %s", domain) self.federation.send_edu( destination=domain, edu_type="m.typing", @@ -274,19 +281,29 @@ class TypingHandler(object): self._latest_room_serial += 1 self._room_serials[member.room_id] = self._latest_room_serial + self._typing_stream_change_cache.entity_has_changed( + member.room_id, self._latest_room_serial, + ) self.notifier.on_new_event( "typing_key", self._latest_room_serial, rooms=[member.room_id] ) def get_all_typing_updates(self, last_id, current_id): - # TODO: Work out a way to do this without scanning the entire state. if last_id == current_id: return [] + changed_rooms = self._typing_stream_change_cache.get_all_entities_changed( + last_id, + ) + + if changed_rooms is None: + changed_rooms = self._room_serials + rows = [] - for room_id, serial in self._room_serials.items(): - if last_id < serial and serial <= current_id: + for room_id in changed_rooms: + serial = self._room_serials[room_id] + if last_id < serial <= current_id: typing = self._room_typing[room_id] rows.append((serial, room_id, list(typing))) rows.sort() diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 37dda64587..f11b430126 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -20,6 +20,7 @@ from six import iteritems from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.roommember import ProfileInfo from synapse.types import get_localpart_from_id from synapse.util.metrics import Measure @@ -98,7 +99,6 @@ class UserDirectoryHandler(object): """ return self.store.search_user_dir(user_id, search_term, limit) - @defer.inlineCallbacks def notify_new_event(self): """Called when there may be more deltas to process """ @@ -108,17 +108,23 @@ class UserDirectoryHandler(object): if self._is_processing: return + @defer.inlineCallbacks + def process(): + try: + yield self._unsafe_process() + finally: + self._is_processing = False + self._is_processing = True - try: - yield self._unsafe_process() - finally: - self._is_processing = False + run_as_background_process("user_directory.notify_new_event", process) @defer.inlineCallbacks def handle_local_profile_change(self, user_id, profile): """Called to update index of our local user profiles when they change irrespective of any rooms the user may be in. """ + # FIXME(#3714): We should probably do this in the same worker as all + # the other changes. yield self.store.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url, None, ) @@ -127,6 +133,8 @@ class UserDirectoryHandler(object): def handle_user_deactivated(self, user_id): """Called when a user ID is deactivated """ + # FIXME(#3714): We should probably do this in the same worker as all + # the other changes. yield self.store.remove_from_user_dir(user_id) yield self.store.remove_from_user_in_public_room(user_id) diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index 58ef8d3ce4..a3f9e4f67c 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -38,12 +38,12 @@ def cancelled_to_request_timed_out_error(value, timeout): return value -ACCESS_TOKEN_RE = re.compile(br'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$') +ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$') def redact_uri(uri): """Strips access tokens from the uri replaces with <redacted>""" return ACCESS_TOKEN_RE.sub( - br'\1<redacted>\3', + r'\1<redacted>\3', uri ) diff --git a/synapse/http/client.py b/synapse/http/client.py index ab4fbf59b2..3d05f83b8c 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -13,24 +13,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import logging -import urllib -from six import StringIO +from six import text_type +from six.moves import urllib +import treq from canonicaljson import encode_canonical_json, json from prometheus_client import Counter from OpenSSL import SSL from OpenSSL.SSL import VERIFY_NONE -from twisted.internet import defer, protocol, reactor, ssl, task +from twisted.internet import defer, protocol, reactor, ssl from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.web._newclient import ResponseDone from twisted.web.client import ( Agent, BrowserLikeRedirectAgent, ContentDecoderAgent, - FileBodyProducer as TwistedFileBodyProducer, GzipDecoder, HTTPConnectionPool, PartialDownloadError, @@ -42,7 +43,7 @@ from twisted.web.http_headers import Headers from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.http import cancelled_to_request_timed_out_error, redact_uri from synapse.http.endpoint import SpiderEndpoint -from synapse.util.async_helpers import add_timeout_to_deferred +from synapse.util.async_helpers import timeout_deferred from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.logcontext import make_deferred_yieldable @@ -83,8 +84,10 @@ class SimpleHttpClient(object): if hs.config.user_agent_suffix: self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,) + self.user_agent = self.user_agent.encode('ascii') + @defer.inlineCallbacks - def request(self, method, uri, *args, **kwargs): + def request(self, method, uri, data=b'', headers=None): # A small wrapper around self.agent.request() so we can easily attach # counters to it outgoing_requests_counter.labels(method).inc() @@ -93,10 +96,10 @@ class SimpleHttpClient(object): logger.info("Sending request %s %s", method, redact_uri(uri)) try: - request_deferred = self.agent.request( - method, uri, *args, **kwargs + request_deferred = treq.request( + method, uri, agent=self.agent, data=data, headers=headers ) - add_timeout_to_deferred( + request_deferred = timeout_deferred( request_deferred, 60, self.hs.get_reactor(), cancelled_to_request_timed_out_error, ) @@ -112,7 +115,7 @@ class SimpleHttpClient(object): incoming_responses_counter.labels(method, "ERR").inc() logger.info( "Error sending request to %s %s: %s %s", - method, redact_uri(uri), type(e).__name__, e.message + method, redact_uri(uri), type(e).__name__, e.args[0] ) raise @@ -137,7 +140,8 @@ class SimpleHttpClient(object): # TODO: Do we ever want to log message contents? logger.debug("post_urlencoded_get_json args: %s", args) - query_bytes = urllib.urlencode(encode_urlencode_args(args), True) + query_bytes = urllib.parse.urlencode( + encode_urlencode_args(args), True).encode("utf8") actual_headers = { b"Content-Type": [b"application/x-www-form-urlencoded"], @@ -148,15 +152,14 @@ class SimpleHttpClient(object): response = yield self.request( "POST", - uri.encode("ascii"), + uri, headers=Headers(actual_headers), - bodyProducer=FileBodyProducer(StringIO(query_bytes)) + data=query_bytes ) - body = yield make_deferred_yieldable(readBody(response)) - if 200 <= response.code < 300: - defer.returnValue(json.loads(body)) + body = yield make_deferred_yieldable(treq.json_content(response)) + defer.returnValue(body) else: raise HttpResponseException(response.code, response.phrase, body) @@ -191,9 +194,9 @@ class SimpleHttpClient(object): response = yield self.request( "POST", - uri.encode("ascii"), + uri, headers=Headers(actual_headers), - bodyProducer=FileBodyProducer(StringIO(json_str)) + data=json_str ) body = yield make_deferred_yieldable(readBody(response)) @@ -248,7 +251,7 @@ class SimpleHttpClient(object): ValueError: if the response was not JSON """ if len(args): - query_bytes = urllib.urlencode(args, True) + query_bytes = urllib.parse.urlencode(args, True) uri = "%s?%s" % (uri, query_bytes) json_str = encode_canonical_json(json_body) @@ -262,9 +265,9 @@ class SimpleHttpClient(object): response = yield self.request( "PUT", - uri.encode("ascii"), + uri, headers=Headers(actual_headers), - bodyProducer=FileBodyProducer(StringIO(json_str)) + data=json_str ) body = yield make_deferred_yieldable(readBody(response)) @@ -293,7 +296,7 @@ class SimpleHttpClient(object): HttpResponseException on a non-2xx HTTP response. """ if len(args): - query_bytes = urllib.urlencode(args, True) + query_bytes = urllib.parse.urlencode(args, True) uri = "%s?%s" % (uri, query_bytes) actual_headers = { @@ -304,7 +307,7 @@ class SimpleHttpClient(object): response = yield self.request( "GET", - uri.encode("ascii"), + uri, headers=Headers(actual_headers), ) @@ -339,13 +342,14 @@ class SimpleHttpClient(object): response = yield self.request( "GET", - url.encode("ascii"), + url, headers=Headers(actual_headers), ) resp_headers = dict(response.headers.getAllRawHeaders()) - if 'Content-Length' in resp_headers and resp_headers['Content-Length'] > max_size: + if (b'Content-Length' in resp_headers and + int(resp_headers[b'Content-Length']) > max_size): logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) raise SynapseError( 502, @@ -378,7 +382,12 @@ class SimpleHttpClient(object): ) defer.returnValue( - (length, resp_headers, response.request.absoluteURI, response.code), + ( + length, + resp_headers, + response.request.absoluteURI.decode('ascii'), + response.code, + ), ) @@ -434,12 +443,12 @@ class CaptchaServerHttpClient(SimpleHttpClient): @defer.inlineCallbacks def post_urlencoded_get_raw(self, url, args={}): - query_bytes = urllib.urlencode(encode_urlencode_args(args), True) + query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True) response = yield self.request( "POST", - url.encode("ascii"), - bodyProducer=FileBodyProducer(StringIO(query_bytes)), + url, + data=query_bytes, headers=Headers({ b"Content-Type": [b"application/x-www-form-urlencoded"], b"User-Agent": [self.user_agent], @@ -463,9 +472,9 @@ class SpiderEndpointFactory(object): def endpointForURI(self, uri): logger.info("Getting endpoint for %s", uri.toBytes()) - if uri.scheme == "http": + if uri.scheme == b"http": endpoint_factory = HostnameEndpoint - elif uri.scheme == "https": + elif uri.scheme == b"https": tlsCreator = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port) def endpoint_factory(reactor, host, port, **kw): @@ -510,7 +519,7 @@ def encode_urlencode_args(args): def encode_urlencode_arg(arg): - if isinstance(arg, unicode): + if isinstance(arg, text_type): return arg.encode('utf-8') elif isinstance(arg, list): return [encode_urlencode_arg(i) for i in arg] @@ -542,26 +551,3 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory): def creatorForNetloc(self, hostname, port): return self - - -class FileBodyProducer(TwistedFileBodyProducer): - """Workaround for https://twistedmatrix.com/trac/ticket/8473 - - We override the pauseProducing and resumeProducing methods in twisted's - FileBodyProducer so that they do not raise exceptions if the task has - already completed. - """ - - def pauseProducing(self): - try: - super(FileBodyProducer, self).pauseProducing() - except task.TaskDone: - # task has already completed - pass - - def resumeProducing(self): - try: - super(FileBodyProducer, self).resumeProducing() - except task.NotPaused: - # task was not paused (probably because it had already completed) - pass diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index b0c9369519..91025037a3 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -108,7 +108,7 @@ def matrix_federation_endpoint(reactor, destination, tls_client_options_factory= Args: reactor: Twisted reactor. - destination (bytes): The name of the server to connect to. + destination (unicode): The name of the server to connect to. tls_client_options_factory (synapse.crypto.context_factory.ClientTLSOptionsFactory): Factory which generates TLS options for client connections. @@ -126,10 +126,17 @@ def matrix_federation_endpoint(reactor, destination, tls_client_options_factory= transport_endpoint = HostnameEndpoint default_port = 8008 else: + # the SNI string should be the same as the Host header, minus the port. + # as per https://github.com/matrix-org/synapse/issues/2525#issuecomment-336896777, + # the Host header and SNI should therefore be the server_name of the remote + # server. + tls_options = tls_client_options_factory.get_options(domain) + def transport_endpoint(reactor, host, port, timeout): return wrapClientTLS( - tls_client_options_factory.get_options(host), - HostnameEndpoint(reactor, host, port, timeout=timeout)) + tls_options, + HostnameEndpoint(reactor, host, port, timeout=timeout), + ) default_port = 8448 if port is None: diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 44b61e70a4..24b6110c20 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -17,19 +17,22 @@ import cgi import logging import random import sys -import urllib +from io import BytesIO -from six import string_types -from six.moves.urllib import parse as urlparse +from six import PY3, string_types +from six.moves import urllib -from canonicaljson import encode_canonical_json, json +import attr +import treq +from canonicaljson import encode_canonical_json from prometheus_client import Counter from signedjson.sign import sign_json -from twisted.internet import defer, protocol, reactor +from twisted.internet import defer, protocol from twisted.internet.error import DNSLookupError +from twisted.internet.task import _EPSILON, Cooperator from twisted.web._newclient import ResponseDone -from twisted.web.client import Agent, HTTPConnectionPool, readBody +from twisted.web.client import Agent, FileBodyProducer, HTTPConnectionPool from twisted.web.http_headers import Headers import synapse.metrics @@ -40,14 +43,12 @@ from synapse.api.errors import ( HttpResponseException, SynapseError, ) -from synapse.http import cancelled_to_request_timed_out_error from synapse.http.endpoint import matrix_federation_endpoint -from synapse.util import logcontext -from synapse.util.async_helpers import add_timeout_to_deferred +from synapse.util.async_helpers import timeout_deferred from synapse.util.logcontext import make_deferred_yieldable +from synapse.util.metrics import Measure logger = logging.getLogger(__name__) -outbound_logger = logging.getLogger("synapse.http.outbound") outgoing_requests_counter = Counter("synapse_http_matrixfederationclient_requests", "", ["method"]) @@ -58,20 +59,119 @@ incoming_responses_counter = Counter("synapse_http_matrixfederationclient_respon MAX_LONG_RETRIES = 10 MAX_SHORT_RETRIES = 3 +if PY3: + MAXINT = sys.maxsize +else: + MAXINT = sys.maxint + class MatrixFederationEndpointFactory(object): def __init__(self, hs): + self.reactor = hs.get_reactor() self.tls_client_options_factory = hs.tls_client_options_factory def endpointForURI(self, uri): - destination = uri.netloc + destination = uri.netloc.decode('ascii') return matrix_federation_endpoint( - reactor, destination, timeout=10, + self.reactor, destination, timeout=10, tls_client_options_factory=self.tls_client_options_factory ) +_next_id = 1 + + +@attr.s +class MatrixFederationRequest(object): + method = attr.ib() + """HTTP method + :type: str + """ + + path = attr.ib() + """HTTP path + :type: str + """ + + destination = attr.ib() + """The remote server to send the HTTP request to. + :type: str""" + + json = attr.ib(default=None) + """JSON to send in the body. + :type: dict|None + """ + + json_callback = attr.ib(default=None) + """A callback to generate the JSON. + :type: func|None + """ + + query = attr.ib(default=None) + """Query arguments. + :type: dict|None + """ + + txn_id = attr.ib(default=None) + """Unique ID for this request (for logging) + :type: str|None + """ + + def __attrs_post_init__(self): + global _next_id + self.txn_id = "%s-O-%s" % (self.method, _next_id) + _next_id = (_next_id + 1) % (MAXINT - 1) + + def get_json(self): + if self.json_callback: + return self.json_callback() + return self.json + + +@defer.inlineCallbacks +def _handle_json_response(reactor, timeout_sec, request, response): + """ + Reads the JSON body of a response, with a timeout + + Args: + reactor (IReactor): twisted reactor, for the timeout + timeout_sec (float): number of seconds to wait for response to complete + request (MatrixFederationRequest): the request that triggered the response + response (IResponse): response to the request + + Returns: + dict: parsed JSON response + """ + try: + check_content_type_is_json(response.headers) + + d = treq.json_content(response) + d = timeout_deferred( + d, + timeout=timeout_sec, + reactor=reactor, + ) + + body = yield make_deferred_yieldable(d) + except Exception as e: + logger.warn( + "{%s} [%s] Error reading response: %s", + request.txn_id, + request.destination, + e, + ) + raise + logger.info( + "{%s} [%s] Completed: %d %s", + request.txn_id, + request.destination, + response.code, + response.phrase.decode('ascii', errors='replace'), + ) + defer.returnValue(body) + + class MatrixFederationHttpClient(object): """HTTP client used to talk to other homeservers over the federation protocol. Send client certificates and signs requests. @@ -85,7 +185,9 @@ class MatrixFederationHttpClient(object): self.hs = hs self.signing_key = hs.config.signing_key[0] self.server_name = hs.hostname + reactor = hs.get_reactor() pool = HTTPConnectionPool(reactor) + pool.retryAutomatically = False pool.maxPersistentPerHost = 5 pool.cachedConnectionTimeout = 2 * 60 self.agent = Agent.usingEndpointFactory( @@ -93,34 +195,42 @@ class MatrixFederationHttpClient(object): ) self.clock = hs.get_clock() self._store = hs.get_datastore() - self.version_string = hs.version_string - self._next_id = 1 + self.version_string_bytes = hs.version_string.encode('ascii') + self.default_timeout = 60 - def _create_url(self, destination, path_bytes, param_bytes, query_bytes): - return urlparse.urlunparse( - ("matrix", destination, path_bytes, param_bytes, query_bytes, "") - ) + def schedule(x): + reactor.callLater(_EPSILON, x) + + self._cooperator = Cooperator(scheduler=schedule) @defer.inlineCallbacks - def _request(self, destination, method, path, - body_callback, headers_dict={}, param_bytes=b"", - query_bytes=b"", retry_on_dns_fail=True, - timeout=None, long_retries=False, - ignore_backoff=False, - backoff_on_404=False): - """ Creates and sends a request to the given server + def _send_request( + self, + request, + retry_on_dns_fail=True, + timeout=None, + long_retries=False, + ignore_backoff=False, + backoff_on_404=False + ): + """ + Sends a request to the given server. + Args: - destination (str): The remote server to send the HTTP request to. - method (str): HTTP method - path (str): The HTTP path + request (MatrixFederationRequest): details of request to be sent + + timeout (int|None): number of milliseconds to wait for the response headers + (including connecting to the server). 60s by default. + ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. + backoff_on_404 (bool): Back off if we get a 404 Returns: Deferred: resolves with the http response object on success. - Fails with ``HTTPRequestException``: if we get an HTTP response + Fails with ``HttpResponseException``: if we get an HTTP response code >= 300. Fails with ``NotRetryingDestination`` if we are not yet ready @@ -132,38 +242,39 @@ class MatrixFederationHttpClient(object): (May also fail with plenty of other Exceptions for things like DNS failures, connection failures, SSL failures.) """ + if timeout: + _sec_timeout = timeout / 1000 + else: + _sec_timeout = self.default_timeout + if ( - self.hs.config.federation_domain_whitelist and - destination not in self.hs.config.federation_domain_whitelist + self.hs.config.federation_domain_whitelist is not None and + request.destination not in self.hs.config.federation_domain_whitelist ): - raise FederationDeniedError(destination) + raise FederationDeniedError(request.destination) limiter = yield synapse.util.retryutils.get_retry_limiter( - destination, + request.destination, self.clock, self._store, backoff_on_404=backoff_on_404, ignore_backoff=ignore_backoff, ) - destination = destination.encode("ascii") - path_bytes = path.encode("ascii") - with limiter: - headers_dict[b"User-Agent"] = [self.version_string] - headers_dict[b"Host"] = [destination] - - url_bytes = self._create_url( - destination, path_bytes, param_bytes, query_bytes - ) - - txn_id = "%s-O-%s" % (method, self._next_id) - self._next_id = (self._next_id + 1) % (sys.maxint - 1) + method_bytes = request.method.encode("ascii") + destination_bytes = request.destination.encode("ascii") + path_bytes = request.path.encode("ascii") + if request.query: + query_bytes = encode_query_args(request.query) + else: + query_bytes = b"" - outbound_logger.info( - "{%s} [%s] Sending request: %s %s", - txn_id, destination, method, url_bytes - ) + headers_dict = { + b"User-Agent": [self.version_string_bytes], + b"Host": [destination_bytes], + } + with limiter: # XXX: Would be much nicer to retry only at the transaction-layer # (once we have reliable transactions in place) if long_retries: @@ -171,88 +282,120 @@ class MatrixFederationHttpClient(object): else: retries_left = MAX_SHORT_RETRIES - http_url_bytes = urlparse.urlunparse( - ("", "", path_bytes, param_bytes, query_bytes, "") - ) + url_bytes = urllib.parse.urlunparse(( + b"matrix", destination_bytes, + path_bytes, None, query_bytes, b"", + )) + url_str = url_bytes.decode('ascii') + + url_to_sign_bytes = urllib.parse.urlunparse(( + b"", b"", + path_bytes, None, query_bytes, b"", + )) - log_result = None - try: - while True: - producer = None - if body_callback: - producer = body_callback(method, http_url_bytes, headers_dict) - - try: - request_deferred = self.agent.request( - method, - url_bytes, - Headers(headers_dict), - producer + while True: + try: + json = request.get_json() + if json: + headers_dict[b"Content-Type"] = [b"application/json"] + self.sign_request( + destination_bytes, method_bytes, url_to_sign_bytes, + headers_dict, json, ) - add_timeout_to_deferred( - request_deferred, - timeout / 1000. if timeout else 60, - self.hs.get_reactor(), - cancelled_to_request_timed_out_error, + data = encode_canonical_json(json) + producer = FileBodyProducer( + BytesIO(data), + cooperator=self._cooperator, + ) + else: + producer = None + self.sign_request( + destination_bytes, method_bytes, url_to_sign_bytes, + headers_dict, ) + + logger.info( + "{%s} [%s] Sending request: %s %s", + request.txn_id, request.destination, request.method, + url_str, + ) + + # we don't want all the fancy cookie and redirect handling that + # treq.request gives: just use the raw Agent. + request_deferred = self.agent.request( + method_bytes, + url_bytes, + headers=Headers(headers_dict), + bodyProducer=producer, + ) + + request_deferred = timeout_deferred( + request_deferred, + timeout=_sec_timeout, + reactor=self.hs.get_reactor(), + ) + + with Measure(self.clock, "outbound_request"): response = yield make_deferred_yieldable( request_deferred, ) - log_result = "%d %s" % (response.code, response.phrase,) - break - except Exception as e: - if not retry_on_dns_fail and isinstance(e, DNSLookupError): - logger.warn( - "DNS Lookup failed to %s with %s", - destination, - e - ) - log_result = "DNS Lookup failed to %s with %s" % ( - destination, e - ) - raise - - logger.warn( - "{%s} Sending request failed to %s: %s %s: %s", - txn_id, - destination, - method, - url_bytes, - _flatten_response_never_received(e), + break + except Exception as e: + logger.warn( + "{%s} [%s] Request failed: %s %s: %s", + request.txn_id, + request.destination, + request.method, + url_str, + _flatten_response_never_received(e), + ) + + if not retry_on_dns_fail and isinstance(e, DNSLookupError): + raise + + if retries_left and not timeout: + if long_retries: + delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left) + delay = min(delay, 60) + delay *= random.uniform(0.8, 1.4) + else: + delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left) + delay = min(delay, 2) + delay *= random.uniform(0.8, 1.4) + + logger.debug( + "{%s} [%s] Waiting %ss before re-sending...", + request.txn_id, + request.destination, + delay, ) - log_result = _flatten_response_never_received(e) - - if retries_left and not timeout: - if long_retries: - delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left) - delay = min(delay, 60) - delay *= random.uniform(0.8, 1.4) - else: - delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left) - delay = min(delay, 2) - delay *= random.uniform(0.8, 1.4) - - yield self.clock.sleep(delay) - retries_left -= 1 - else: - raise - finally: - outbound_logger.info( - "{%s} [%s] Result: %s", - txn_id, - destination, - log_result, - ) + yield self.clock.sleep(delay) + retries_left -= 1 + else: + raise + + logger.info( + "{%s} [%s] Got response headers: %d %s", + request.txn_id, + request.destination, + response.code, + response.phrase.decode('ascii', errors='replace'), + ) if 200 <= response.code < 300: pass else: # :'( # Update transactions table? - with logcontext.PreserveLoggingContext(): - body = yield readBody(response) + d = treq.content(response) + d = timeout_deferred( + d, + timeout=_sec_timeout, + reactor=self.hs.get_reactor(), + ) + body = yield make_deferred_yieldable(d) raise HttpResponseException( response.code, response.phrase, body ) @@ -269,8 +412,9 @@ class MatrixFederationHttpClient(object): destination_is must be non-None. method (bytes): The HTTP method of the request url_bytes (bytes): The URI path of the request - headers_dict (dict): Dictionary of request headers to append to - content (bytes): The body of the request + headers_dict (dict[bytes, list[bytes]]): Dictionary of request headers to + append to + content (object): The body of the request destination_is (bytes): As 'destination', but if the destination is an identity server @@ -297,11 +441,11 @@ class MatrixFederationHttpClient(object): auth_headers = [] for key, sig in request["signatures"][self.server_name].items(): - auth_headers.append(bytes( + auth_headers.append(( "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % ( self.server_name, key, sig, - ) - )) + )).encode('ascii') + ) headers_dict[b"Authorization"] = auth_headers @@ -336,7 +480,7 @@ class MatrixFederationHttpClient(object): Deferred: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. - Fails with ``HTTPRequestException`` if we get an HTTP response + Fails with ``HttpResponseException`` if we get an HTTP response code >= 300. Fails with ``NotRetryingDestination`` if we are not yet ready @@ -346,38 +490,27 @@ class MatrixFederationHttpClient(object): is not on our federation whitelist """ - if not json_data_callback: - def json_data_callback(): - return data + request = MatrixFederationRequest( + method="PUT", + destination=destination, + path=path, + query=args, + json_callback=json_data_callback, + json=data, + ) - def body_callback(method, url_bytes, headers_dict): - json_data = json_data_callback() - self.sign_request( - destination, method, url_bytes, headers_dict, json_data - ) - producer = _JsonProducer(json_data) - return producer - - response = yield self._request( - destination, - "PUT", - path, - body_callback=body_callback, - headers_dict={"Content-Type": ["application/json"]}, - query_bytes=encode_query_args(args), + response = yield self._send_request( + request, long_retries=long_retries, timeout=timeout, ignore_backoff=ignore_backoff, backoff_on_404=backoff_on_404, ) - if 200 <= response.code < 300: - # We need to update the transactions table to say it was sent? - check_content_type_is_json(response.headers) - - with logcontext.PreserveLoggingContext(): - body = yield readBody(response) - defer.returnValue(json.loads(body)) + body = yield _handle_json_response( + self.hs.get_reactor(), self.default_timeout, request, response, + ) + defer.returnValue(body) @defer.inlineCallbacks def post_json(self, destination, path, data={}, long_retries=False, @@ -401,7 +534,7 @@ class MatrixFederationHttpClient(object): Deferred: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. - Fails with ``HTTPRequestException`` if we get an HTTP response + Fails with ``HttpResponseException`` if we get an HTTP response code >= 300. Fails with ``NotRetryingDestination`` if we are not yet ready @@ -411,32 +544,30 @@ class MatrixFederationHttpClient(object): is not on our federation whitelist """ - def body_callback(method, url_bytes, headers_dict): - self.sign_request( - destination, method, url_bytes, headers_dict, data - ) - return _JsonProducer(data) - - response = yield self._request( - destination, - "POST", - path, - query_bytes=encode_query_args(args), - body_callback=body_callback, - headers_dict={"Content-Type": ["application/json"]}, + request = MatrixFederationRequest( + method="POST", + destination=destination, + path=path, + query=args, + json=data, + ) + + response = yield self._send_request( + request, long_retries=long_retries, timeout=timeout, ignore_backoff=ignore_backoff, ) - if 200 <= response.code < 300: - # We need to update the transactions table to say it was sent? - check_content_type_is_json(response.headers) - - with logcontext.PreserveLoggingContext(): - body = yield readBody(response) + if timeout: + _sec_timeout = timeout / 1000 + else: + _sec_timeout = self.default_timeout - defer.returnValue(json.loads(body)) + body = yield _handle_json_response( + self.hs.get_reactor(), _sec_timeout, request, response, + ) + defer.returnValue(body) @defer.inlineCallbacks def get_json(self, destination, path, args=None, retry_on_dns_fail=True, @@ -458,7 +589,7 @@ class MatrixFederationHttpClient(object): Deferred: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. - Fails with ``HTTPRequestException`` if we get an HTTP response + Fails with ``HttpResponseException`` if we get an HTTP response code >= 300. Fails with ``NotRetryingDestination`` if we are not yet ready @@ -471,29 +602,24 @@ class MatrixFederationHttpClient(object): logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail) - def body_callback(method, url_bytes, headers_dict): - self.sign_request(destination, method, url_bytes, headers_dict) - return None + request = MatrixFederationRequest( + method="GET", + destination=destination, + path=path, + query=args, + ) - response = yield self._request( - destination, - "GET", - path, - query_bytes=encode_query_args(args), - body_callback=body_callback, + response = yield self._send_request( + request, retry_on_dns_fail=retry_on_dns_fail, timeout=timeout, ignore_backoff=ignore_backoff, ) - if 200 <= response.code < 300: - # We need to update the transactions table to say it was sent? - check_content_type_is_json(response.headers) - - with logcontext.PreserveLoggingContext(): - body = yield readBody(response) - - defer.returnValue(json.loads(body)) + body = yield _handle_json_response( + self.hs.get_reactor(), self.default_timeout, request, response, + ) + defer.returnValue(body) @defer.inlineCallbacks def delete_json(self, destination, path, long_retries=False, @@ -514,7 +640,7 @@ class MatrixFederationHttpClient(object): Deferred: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. - Fails with ``HTTPRequestException`` if we get an HTTP response + Fails with ``HttpResponseException`` if we get an HTTP response code >= 300. Fails with ``NotRetryingDestination`` if we are not yet ready @@ -523,26 +649,24 @@ class MatrixFederationHttpClient(object): Fails with ``FederationDeniedError`` if this destination is not on our federation whitelist """ + request = MatrixFederationRequest( + method="DELETE", + destination=destination, + path=path, + query=args, + ) - response = yield self._request( - destination, - "DELETE", - path, - query_bytes=encode_query_args(args), - headers_dict={"Content-Type": ["application/json"]}, + response = yield self._send_request( + request, long_retries=long_retries, timeout=timeout, ignore_backoff=ignore_backoff, ) - if 200 <= response.code < 300: - # We need to update the transactions table to say it was sent? - check_content_type_is_json(response.headers) - - with logcontext.PreserveLoggingContext(): - body = yield readBody(response) - - defer.returnValue(json.loads(body)) + body = yield _handle_json_response( + self.hs.get_reactor(), self.default_timeout, request, response, + ) + defer.returnValue(body) @defer.inlineCallbacks def get_file(self, destination, path, output_stream, args={}, @@ -560,7 +684,7 @@ class MatrixFederationHttpClient(object): Deferred: resolves with an (int,dict) tuple of the file length and a dict of the response headers. - Fails with ``HTTPRequestException`` if we get an HTTP response code + Fails with ``HttpResponseException`` if we get an HTTP response code >= 300 Fails with ``NotRetryingDestination`` if we are not yet ready @@ -569,26 +693,15 @@ class MatrixFederationHttpClient(object): Fails with ``FederationDeniedError`` if this destination is not on our federation whitelist """ + request = MatrixFederationRequest( + method="GET", + destination=destination, + path=path, + query=args, + ) - encoded_args = {} - for k, vs in args.items(): - if isinstance(vs, string_types): - vs = [vs] - encoded_args[k] = [v.encode("UTF-8") for v in vs] - - query_bytes = urllib.urlencode(encoded_args, True) - logger.debug("Query bytes: %s Retry DNS: %s", query_bytes, retry_on_dns_fail) - - def body_callback(method, url_bytes, headers_dict): - self.sign_request(destination, method, url_bytes, headers_dict) - return None - - response = yield self._request( - destination, - "GET", - path, - query_bytes=query_bytes, - body_callback=body_callback, + response = yield self._send_request( + request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff, ) @@ -596,14 +709,25 @@ class MatrixFederationHttpClient(object): headers = dict(response.headers.getAllRawHeaders()) try: - with logcontext.PreserveLoggingContext(): - length = yield _readBodyToFile( - response, output_stream, max_size - ) - except Exception: - logger.exception("Failed to download body") + d = _readBodyToFile(response, output_stream, max_size) + d.addTimeout(self.default_timeout, self.hs.get_reactor()) + length = yield make_deferred_yieldable(d) + except Exception as e: + logger.warn( + "{%s} [%s] Error reading response: %s", + request.txn_id, + request.destination, + e, + ) raise - + logger.info( + "{%s} [%s] Completed: %d %s [%d bytes]", + request.txn_id, + request.destination, + response.code, + response.phrase.decode('ascii', errors='replace'), + length, + ) defer.returnValue((length, headers)) @@ -639,30 +763,6 @@ def _readBodyToFile(response, stream, max_size): return d -class _JsonProducer(object): - """ Used by the twisted http client to create the HTTP body from json - """ - def __init__(self, jsn): - self.reset(jsn) - - def reset(self, jsn): - self.body = encode_canonical_json(jsn) - self.length = len(self.body) - - def startProducing(self, consumer): - consumer.write(self.body) - return defer.succeed(None) - - def pauseProducing(self): - pass - - def stopProducing(self): - pass - - def resumeProducing(self): - pass - - def _flatten_response_never_received(e): if hasattr(e, "reasons"): reasons = ", ".join( @@ -693,7 +793,7 @@ def check_content_type_is_json(headers): "No Content-Type header" ) - c_type = c_type[0] # only the first header + c_type = c_type[0].decode('ascii') # only the first header val, options = cgi.parse_header(c_type) if val != "application/json": raise RuntimeError( @@ -711,6 +811,6 @@ def encode_query_args(args): vs = [vs] encoded_args[k] = [v.encode("UTF-8") for v in vs] - query_bytes = urllib.urlencode(encoded_args, True) + query_bytes = urllib.parse.urlencode(encoded_args, True) - return query_bytes + return query_bytes.encode('utf8') diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py index 588e280571..62045a918b 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +import threading from prometheus_client.core import Counter, Histogram @@ -38,7 +39,8 @@ outgoing_responses_counter = Counter( ) response_timer = Histogram( - "synapse_http_server_response_time_seconds", "sec", + "synapse_http_server_response_time_seconds", + "sec", ["method", "servlet", "tag", "code"], ) @@ -78,15 +80,11 @@ response_size = Counter( # than when the response was written. in_flight_requests_ru_utime = Counter( - "synapse_http_server_in_flight_requests_ru_utime_seconds", - "", - ["method", "servlet"], + "synapse_http_server_in_flight_requests_ru_utime_seconds", "", ["method", "servlet"] ) in_flight_requests_ru_stime = Counter( - "synapse_http_server_in_flight_requests_ru_stime_seconds", - "", - ["method", "servlet"], + "synapse_http_server_in_flight_requests_ru_stime_seconds", "", ["method", "servlet"] ) in_flight_requests_db_txn_count = Counter( @@ -111,6 +109,9 @@ in_flight_requests_db_sched_duration = Counter( # The set of all in flight requests, set[RequestMetrics] _in_flight_requests = set() +# Protects the _in_flight_requests set from concurrent accesss +_in_flight_requests_lock = threading.Lock() + def _get_in_flight_counts(): """Returns a count of all in flight requests by (method, server_name) @@ -120,7 +121,8 @@ def _get_in_flight_counts(): """ # Cast to a list to prevent it changing while the Prometheus # thread is collecting metrics - reqs = list(_in_flight_requests) + with _in_flight_requests_lock: + reqs = list(_in_flight_requests) for rm in reqs: rm.update_metrics() @@ -129,7 +131,7 @@ def _get_in_flight_counts(): # type counts = {} for rm in reqs: - key = (rm.method, rm.name,) + key = (rm.method, rm.name) counts[key] = counts.get(key, 0) + 1 return counts @@ -154,10 +156,12 @@ class RequestMetrics(object): # to the "in flight" metrics. self._request_stats = self.start_context.get_resource_usage() - _in_flight_requests.add(self) + with _in_flight_requests_lock: + _in_flight_requests.add(self) - def stop(self, time_sec, request): - _in_flight_requests.discard(self) + def stop(self, time_sec, response_code, sent_bytes): + with _in_flight_requests_lock: + _in_flight_requests.discard(self) context = LoggingContext.current_context() @@ -168,39 +172,40 @@ class RequestMetrics(object): if context != self.start_context: logger.warn( "Context have unexpectedly changed %r, %r", - context, self.start_context + context, + self.start_context, ) return - response_code = str(request.code) + response_code = str(response_code) - outgoing_responses_counter.labels(request.method, response_code).inc() + outgoing_responses_counter.labels(self.method, response_code).inc() - response_count.labels(request.method, self.name, tag).inc() + response_count.labels(self.method, self.name, tag).inc() - response_timer.labels(request.method, self.name, tag, response_code).observe( + response_timer.labels(self.method, self.name, tag, response_code).observe( time_sec - self.start ) resource_usage = context.get_resource_usage() - response_ru_utime.labels(request.method, self.name, tag).inc( - resource_usage.ru_utime, + response_ru_utime.labels(self.method, self.name, tag).inc( + resource_usage.ru_utime ) - response_ru_stime.labels(request.method, self.name, tag).inc( - resource_usage.ru_stime, + response_ru_stime.labels(self.method, self.name, tag).inc( + resource_usage.ru_stime ) - response_db_txn_count.labels(request.method, self.name, tag).inc( + response_db_txn_count.labels(self.method, self.name, tag).inc( resource_usage.db_txn_count ) - response_db_txn_duration.labels(request.method, self.name, tag).inc( + response_db_txn_duration.labels(self.method, self.name, tag).inc( resource_usage.db_txn_duration_sec ) - response_db_sched_duration.labels(request.method, self.name, tag).inc( + response_db_sched_duration.labels(self.method, self.name, tag).inc( resource_usage.db_sched_duration_sec ) - response_size.labels(request.method, self.name, tag).inc(request.sentLength) + response_size.labels(self.method, self.name, tag).inc(sent_bytes) # We always call this at the end to ensure that we update the metrics # regardless of whether a call to /metrics while the request was in @@ -215,8 +220,15 @@ class RequestMetrics(object): diff = new_stats - self._request_stats self._request_stats = new_stats - in_flight_requests_ru_utime.labels(self.method, self.name).inc(diff.ru_utime) - in_flight_requests_ru_stime.labels(self.method, self.name).inc(diff.ru_stime) + # max() is used since rapid use of ru_stime/ru_utime can end up with the + # count going backwards due to NTP, time smearing, fine-grained + # correction, or floating points. Who knows, really? + in_flight_requests_ru_utime.labels(self.method, self.name).inc( + max(diff.ru_utime, 0) + ) + in_flight_requests_ru_stime.labels(self.method, self.name).inc( + max(diff.ru_stime, 0) + ) in_flight_requests_db_txn_count.labels(self.method, self.name).inc( diff.db_txn_count diff --git a/synapse/http/server.py b/synapse/http/server.py index 2d5c23e673..b4b25cab19 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -84,10 +84,21 @@ def wrap_json_request_handler(h): logger.info( "%s SynapseError: %s - %s", request, code, e.msg ) - respond_with_json( - request, code, e.error_dict(), send_cors=True, - pretty_print=_request_user_agent_is_curl(request), - ) + + # Only respond with an error response if we haven't already started + # writing, otherwise lets just kill the connection + if request.startedWriting: + if request.transport: + try: + request.transport.abortConnection() + except Exception: + # abortConnection throws if the connection is already closed + pass + else: + respond_with_json( + request, code, e.error_dict(), send_cors=True, + pretty_print=_request_user_agent_is_curl(request), + ) except Exception: # failure.Failure() fishes the original Failure out @@ -100,16 +111,26 @@ def wrap_json_request_handler(h): request, f.getTraceback().rstrip(), ) - respond_with_json( - request, - 500, - { - "error": "Internal server error", - "errcode": Codes.UNKNOWN, - }, - send_cors=True, - pretty_print=_request_user_agent_is_curl(request), - ) + # Only respond with an error response if we haven't already started + # writing, otherwise lets just kill the connection + if request.startedWriting: + if request.transport: + try: + request.transport.abortConnection() + except Exception: + # abortConnection throws if the connection is already closed + pass + else: + respond_with_json( + request, + 500, + { + "error": "Internal server error", + "errcode": Codes.UNKNOWN, + }, + send_cors=True, + pretty_print=_request_user_agent_is_curl(request), + ) return wrap_async_request_handler(wrapped_request_handler) diff --git a/synapse/http/site.py b/synapse/http/site.py index 88ed3714f9..e508c0bd4f 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -75,17 +75,35 @@ class SynapseRequest(Request): return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % ( self.__class__.__name__, id(self), - self.method, + self.get_method(), self.get_redacted_uri(), - self.clientproto, + self.clientproto.decode('ascii', errors='replace'), self.site.site_tag, ) def get_request_id(self): - return "%s-%i" % (self.method, self.request_seq) + return "%s-%i" % (self.get_method(), self.request_seq) def get_redacted_uri(self): - return redact_uri(self.uri) + uri = self.uri + if isinstance(uri, bytes): + uri = self.uri.decode('ascii') + return redact_uri(uri) + + def get_method(self): + """Gets the method associated with the request (or placeholder if not + method has yet been received). + + Note: This is necessary as the placeholder value in twisted is str + rather than bytes, so we need to sanitise `self.method`. + + Returns: + str + """ + method = self.method + if isinstance(method, bytes): + method = self.method.decode('ascii') + return method def get_user_agent(self): return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1] @@ -116,7 +134,7 @@ class SynapseRequest(Request): # dispatching to the handler, so that the handler # can update the servlet name in the request # metrics - requests_counter.labels(self.method, + requests_counter.labels(self.get_method(), self.request_metrics.name).inc() @contextlib.contextmanager @@ -204,14 +222,14 @@ class SynapseRequest(Request): self.start_time = time.time() self.request_metrics = RequestMetrics() self.request_metrics.start( - self.start_time, name=servlet_name, method=self.method, + self.start_time, name=servlet_name, method=self.get_method(), ) self.site.access_logger.info( "%s - %s - Received request: %s %s", self.getClientIP(), self.site.site_tag, - self.method, + self.get_method(), self.get_redacted_uri() ) @@ -277,15 +295,15 @@ class SynapseRequest(Request): int(usage.db_txn_count), self.sentLength, code, - self.method, + self.get_method(), self.get_redacted_uri(), - self.clientproto, + self.clientproto.decode('ascii', errors='replace'), user_agent, usage.evt_db_fetch_count, ) try: - self.request_metrics.stop(self.finish_time, self) + self.request_metrics.stop(self.finish_time, self.code, self.sentLength) except Exception as e: logger.warn("Failed to stop metrics: %r", e) @@ -305,7 +323,7 @@ class XForwardedForRequest(SynapseRequest): C{b"-"}. """ return self.requestHeaders.getRawHeaders( - b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip() + b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip().decode('ascii') class SynapseRequestFactory(object): diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 550f8443f7..59900aa5d1 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -18,8 +18,11 @@ import gc import logging import os import platform +import threading import time +import six + import attr from prometheus_client import Counter, Gauge, Histogram from prometheus_client.core import REGISTRY, GaugeMetricFamily @@ -68,7 +71,7 @@ class LaterGauge(object): return if isinstance(calls, dict): - for k, v in calls.items(): + for k, v in six.iteritems(calls): g.add_metric(k, v) else: g.add_metric([], calls) @@ -87,6 +90,109 @@ class LaterGauge(object): all_gauges[self.name] = self +class InFlightGauge(object): + """Tracks number of things (e.g. requests, Measure blocks, etc) in flight + at any given time. + + Each InFlightGauge will create a metric called `<name>_total` that counts + the number of in flight blocks, as well as a metrics for each item in the + given `sub_metrics` as `<name>_<sub_metric>` which will get updated by the + callbacks. + + Args: + name (str) + desc (str) + labels (list[str]) + sub_metrics (list[str]): A list of sub metrics that the callbacks + will update. + """ + + def __init__(self, name, desc, labels, sub_metrics): + self.name = name + self.desc = desc + self.labels = labels + self.sub_metrics = sub_metrics + + # Create a class which have the sub_metrics values as attributes, which + # default to 0 on initialization. Used to pass to registered callbacks. + self._metrics_class = attr.make_class( + "_MetricsEntry", + attrs={x: attr.ib(0) for x in sub_metrics}, + slots=True, + ) + + # Counts number of in flight blocks for a given set of label values + self._registrations = {} + + # Protects access to _registrations + self._lock = threading.Lock() + + self._register_with_collector() + + def register(self, key, callback): + """Registers that we've entered a new block with labels `key`. + + `callback` gets called each time the metrics are collected. The same + value must also be given to `unregister`. + + `callback` gets called with an object that has an attribute per + sub_metric, which should be updated with the necessary values. Note that + the metrics object is shared between all callbacks registered with the + same key. + + Note that `callback` may be called on a separate thread. + """ + with self._lock: + self._registrations.setdefault(key, set()).add(callback) + + def unregister(self, key, callback): + """Registers that we've exited a block with labels `key`. + """ + + with self._lock: + self._registrations.setdefault(key, set()).discard(callback) + + def collect(self): + """Called by prometheus client when it reads metrics. + + Note: may be called by a separate thread. + """ + in_flight = GaugeMetricFamily(self.name + "_total", self.desc, labels=self.labels) + + metrics_by_key = {} + + # We copy so that we don't mutate the list while iterating + with self._lock: + keys = list(self._registrations) + + for key in keys: + with self._lock: + callbacks = set(self._registrations[key]) + + in_flight.add_metric(key, len(callbacks)) + + metrics = self._metrics_class() + metrics_by_key[key] = metrics + for callback in callbacks: + callback(metrics) + + yield in_flight + + for name in self.sub_metrics: + gauge = GaugeMetricFamily("_".join([self.name, name]), "", labels=self.labels) + for key, metrics in six.iteritems(metrics_by_key): + gauge.add_metric(key, getattr(metrics, name)) + yield gauge + + def _register_with_collector(self): + if self.name in all_gauges.keys(): + logger.warning("%s already registered, reregistering" % (self.name,)) + REGISTRY.unregister(all_gauges.pop(self.name)) + + REGISTRY.register(self) + all_gauges[self.name] = self + + # # Detailed CPU metrics # diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index ce678d5f75..037f1c490e 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +import threading + import six from prometheus_client.core import REGISTRY, Counter, GaugeMetricFamily @@ -21,6 +24,9 @@ from twisted.internet import defer from synapse.util.logcontext import LoggingContext, PreserveLoggingContext +logger = logging.getLogger(__name__) + + _background_process_start_count = Counter( "synapse_background_process_start_count", "Number of background processes started", @@ -78,6 +84,9 @@ _background_process_counts = dict() # type: dict[str, int] # of process descriptions that no longer have any active processes. _background_processes = dict() # type: dict[str, set[_BackgroundProcess]] +# A lock that covers the above dicts +_bg_metrics_lock = threading.Lock() + class _Collector(object): """A custom metrics collector for the background process metrics. @@ -92,7 +101,15 @@ class _Collector(object): labels=["name"], ) - for desc, processes in six.iteritems(_background_processes): + # We copy the dict so that it doesn't change from underneath us. + # We also copy the process lists as that can also change + with _bg_metrics_lock: + _background_processes_copy = { + k: list(v) + for k, v in six.iteritems(_background_processes) + } + + for desc, processes in six.iteritems(_background_processes_copy): background_process_in_flight_count.add_metric( (desc,), len(processes), ) @@ -167,19 +184,28 @@ def run_as_background_process(desc, func, *args, **kwargs): """ @defer.inlineCallbacks def run(): - count = _background_process_counts.get(desc, 0) - _background_process_counts[desc] = count + 1 + with _bg_metrics_lock: + count = _background_process_counts.get(desc, 0) + _background_process_counts[desc] = count + 1 + _background_process_start_count.labels(desc).inc() with LoggingContext(desc) as context: context.request = "%s-%i" % (desc, count) proc = _BackgroundProcess(desc, context) - _background_processes.setdefault(desc, set()).add(proc) + + with _bg_metrics_lock: + _background_processes.setdefault(desc, set()).add(proc) + try: yield func(*args, **kwargs) + except Exception: + logger.exception("Background process '%s' threw an exception", desc) finally: proc.update_metrics() - _background_processes[desc].remove(proc) + + with _bg_metrics_lock: + _background_processes[desc].remove(proc) with PreserveLoggingContext(): return run() diff --git a/synapse/notifier.py b/synapse/notifier.py index 82f391481c..de02b1017e 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -24,13 +24,10 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError from synapse.handlers.presence import format_user_presence_state from synapse.metrics import LaterGauge +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import StreamToken -from synapse.util.async_helpers import ( - DeferredTimeoutError, - ObservableDeferred, - add_timeout_to_deferred, -) -from synapse.util.logcontext import PreserveLoggingContext, run_in_background +from synapse.util.async_helpers import ObservableDeferred, timeout_deferred +from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logutils import log_function from synapse.util.metrics import Measure from synapse.visibility import filter_events_for_client @@ -189,9 +186,9 @@ class Notifier(object): def count_listeners(): all_user_streams = set() - for x in self.room_to_user_streams.values(): + for x in list(self.room_to_user_streams.values()): all_user_streams |= x - for x in self.user_to_user_stream.values(): + for x in list(self.user_to_user_stream.values()): all_user_streams.add(x) return sum(stream.count_listeners() for stream in all_user_streams) @@ -199,7 +196,7 @@ class Notifier(object): LaterGauge( "synapse_notifier_rooms", "", [], - lambda: count(bool, self.room_to_user_streams.values()), + lambda: count(bool, list(self.room_to_user_streams.values())), ) LaterGauge( "synapse_notifier_users", "", [], @@ -252,7 +249,10 @@ class Notifier(object): def _on_new_room_event(self, event, room_stream_id, extra_users=[]): """Notify any user streams that are interested in this room event""" # poke any interested application service. - run_in_background(self._notify_app_services, room_stream_id) + run_as_background_process( + "notify_app_services", + self._notify_app_services, room_stream_id, + ) if self.federation_sender: self.federation_sender.notify_new_events(room_stream_id) @@ -337,7 +337,7 @@ class Notifier(object): # Now we wait for the _NotifierUserStream to be told there # is a new token. listener = user_stream.new_listener(prev_token) - add_timeout_to_deferred( + listener.deferred = timeout_deferred( listener.deferred, (end_time - now) / 1000., self.hs.get_reactor(), @@ -354,7 +354,7 @@ class Notifier(object): # Update the prev_token to the current_token since nothing # has happened between the old prev_token and the current_token prev_token = current_token - except DeferredTimeoutError: + except defer.TimeoutError: break except defer.CancelledError: break @@ -559,15 +559,16 @@ class Notifier(object): if end_time <= now: break - add_timeout_to_deferred( - listener.deferred.addTimeout, - (end_time - now) / 1000., - self.hs.get_reactor(), + listener.deferred = timeout_deferred( + listener.deferred, + timeout=(end_time - now) / 1000., + reactor=self.hs.get_reactor(), ) + try: with PreserveLoggingContext(): yield listener.deferred - except DeferredTimeoutError: + except defer.TimeoutError: break except defer.CancelledError: break diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index d746371420..f369124258 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -18,8 +18,7 @@ import logging from twisted.internet import defer from twisted.internet.error import AlreadyCalled, AlreadyCancelled -from synapse.util.logcontext import LoggingContext -from synapse.util.metrics import Measure +from synapse.metrics.background_process_metrics import run_as_background_process logger = logging.getLogger(__name__) @@ -71,18 +70,11 @@ class EmailPusher(object): # See httppusher self.max_stream_ordering = None - self.processing = False + self._is_processing = False - @defer.inlineCallbacks def on_started(self): if self.mailer is not None: - try: - self.throttle_params = yield self.store.get_throttle_params_by_room( - self.pusher_id - ) - yield self._process() - except Exception: - logger.exception("Error starting email pusher") + self._start_processing() def on_stop(self): if self.timed_call: @@ -92,43 +84,52 @@ class EmailPusher(object): pass self.timed_call = None - @defer.inlineCallbacks def on_new_notifications(self, min_stream_ordering, max_stream_ordering): self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering) - yield self._process() + self._start_processing() def on_new_receipts(self, min_stream_id, max_stream_id): # We could wake up and cancel the timer but there tend to be quite a # lot of read receipts so it's probably less work to just let the # timer fire - return defer.succeed(None) + pass - @defer.inlineCallbacks def on_timer(self): self.timed_call = None - yield self._process() + self._start_processing() + + def _start_processing(self): + if self._is_processing: + return + + run_as_background_process("emailpush.process", self._process) @defer.inlineCallbacks def _process(self): - if self.processing: - return + # we should never get here if we are already processing + assert not self._is_processing + + try: + self._is_processing = True + + if self.throttle_params is None: + # this is our first loop: load up the throttle params + self.throttle_params = yield self.store.get_throttle_params_by_room( + self.pusher_id + ) - with LoggingContext("emailpush._process"): - with Measure(self.clock, "emailpush._process"): + # if the max ordering changes while we're running _unsafe_process, + # call it again, and so on until we've caught up. + while True: + starting_max_ordering = self.max_stream_ordering try: - self.processing = True - # if the max ordering changes while we're running _unsafe_process, - # call it again, and so on until we've caught up. - while True: - starting_max_ordering = self.max_stream_ordering - try: - yield self._unsafe_process() - except Exception: - logger.exception("Exception processing notifs") - if self.max_stream_ordering == starting_max_ordering: - break - finally: - self.processing = False + yield self._unsafe_process() + except Exception: + logger.exception("Exception processing notifs") + if self.max_stream_ordering == starting_max_ordering: + break + finally: + self._is_processing = False @defer.inlineCallbacks def _unsafe_process(self): diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 81e18bcf7d..6bd703632d 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -15,17 +15,21 @@ # limitations under the License. import logging +import six + from prometheus_client import Counter from twisted.internet import defer from twisted.internet.error import AlreadyCalled, AlreadyCancelled +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import PusherConfigException -from synapse.util.logcontext import LoggingContext -from synapse.util.metrics import Measure from . import push_rule_evaluator, push_tools +if six.PY3: + long = int + logger = logging.getLogger(__name__) http_push_processed_counter = Counter("synapse_http_httppusher_http_pushes_processed", "") @@ -56,7 +60,7 @@ class HttpPusher(object): self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.failing_since = pusherdict['failing_since'] self.timed_call = None - self.processing = False + self._is_processing = False # This is the highest stream ordering we know it's safe to process. # When new events arrive, we'll be given a window of new events: we @@ -87,34 +91,27 @@ class HttpPusher(object): self.data_minus_url.update(self.data) del self.data_minus_url['url'] - @defer.inlineCallbacks def on_started(self): - try: - yield self._process() - except Exception: - logger.exception("Error starting http pusher") + self._start_processing() - @defer.inlineCallbacks def on_new_notifications(self, min_stream_ordering, max_stream_ordering): - self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering) - yield self._process() + self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering or 0) + self._start_processing() - @defer.inlineCallbacks def on_new_receipts(self, min_stream_id, max_stream_id): # Note that the min here shouldn't be relied upon to be accurate. # We could check the receipts are actually m.read receipts here, # but currently that's the only type of receipt anyway... - with LoggingContext("push.on_new_receipts"): - with Measure(self.clock, "push.on_new_receipts"): - badge = yield push_tools.get_badge_count( - self.hs.get_datastore(), self.user_id - ) - yield self._send_badge(badge) + run_as_background_process("http_pusher.on_new_receipts", self._update_badge) @defer.inlineCallbacks + def _update_badge(self): + badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) + yield self._send_badge(badge) + def on_timer(self): - yield self._process() + self._start_processing() def on_stop(self): if self.timed_call: @@ -124,27 +121,31 @@ class HttpPusher(object): pass self.timed_call = None + def _start_processing(self): + if self._is_processing: + return + + run_as_background_process("httppush.process", self._process) + @defer.inlineCallbacks def _process(self): - if self.processing: - return + # we should never get here if we are already processing + assert not self._is_processing - with LoggingContext("push._process"): - with Measure(self.clock, "push._process"): + try: + self._is_processing = True + # if the max ordering changes while we're running _unsafe_process, + # call it again, and so on until we've caught up. + while True: + starting_max_ordering = self.max_stream_ordering try: - self.processing = True - # if the max ordering changes while we're running _unsafe_process, - # call it again, and so on until we've caught up. - while True: - starting_max_ordering = self.max_stream_ordering - try: - yield self._unsafe_process() - except Exception: - logger.exception("Exception processing notifs") - if self.max_stream_ordering == starting_max_ordering: - break - finally: - self.processing = False + yield self._unsafe_process() + except Exception: + logger.exception("Exception processing notifs") + if self.max_stream_ordering == starting_max_ordering: + break + finally: + self._is_processing = False @defer.inlineCallbacks def _unsafe_process(self): diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index bfa6df7b68..16fb5e8471 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -17,10 +17,11 @@ import email.mime.multipart import email.utils import logging import time -import urllib from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText +from six.moves import urllib + import bleach import jinja2 @@ -440,7 +441,7 @@ class Mailer(object): def make_room_link(self, room_id): if self.hs.config.email_riot_base_url: - base_url = self.hs.config.email_riot_base_url + base_url = "%s/#/room" % (self.hs.config.email_riot_base_url) elif self.app_name == "Vector": # need /beta for Universal Links to work on iOS base_url = "https://vector.im/beta/#/room" @@ -474,7 +475,7 @@ class Mailer(object): # XXX: make r0 once API is stable return "%s_matrix/client/unstable/pushers/remove?%s" % ( self.hs.config.public_baseurl, - urllib.urlencode(params), + urllib.parse.urlencode(params), ) @@ -525,8 +526,7 @@ def load_jinja2_templates(config): Returns: (notif_template_html, notif_template_text) """ - logger.info("loading jinja2") - + logger.info("loading email templates from '%s'", config.email_template_dir) loader = jinja2.FileSystemLoader(config.email_template_dir) env = jinja2.Environment(loader=loader) env.filters["format_ts"] = format_ts_filter @@ -561,7 +561,7 @@ def _create_mxc_to_http_filter(config): return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( config.public_baseurl, serverAndMediaId, - urllib.urlencode(params), + urllib.parse.urlencode(params), fragment or "", ) diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 9f7d5ef217..5a4e73ccd6 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -20,24 +20,39 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push.pusher import PusherFactory -from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) class PusherPool: + """ + The pusher pool. This is responsible for dispatching notifications of new events to + the http and email pushers. + + It provides three methods which are designed to be called by the rest of the + application: `start`, `on_new_notifications`, and `on_new_receipts`: each of these + delegates to each of the relevant pushers. + + Note that it is expected that each pusher will have its own 'processing' loop which + will send out the notifications in the background, rather than blocking until the + notifications are sent; accordingly Pusher.on_started, Pusher.on_new_notifications and + Pusher.on_new_receipts are not expected to return deferreds. + """ def __init__(self, _hs): self.hs = _hs self.pusher_factory = PusherFactory(_hs) - self.start_pushers = _hs.config.start_pushers + self._should_start_pushers = _hs.config.start_pushers self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() self.pushers = {} - @defer.inlineCallbacks def start(self): - pushers = yield self.store.get_all_pushers() - self._start_pushers(pushers) + """Starts the pushers off in a background process. + """ + if not self._should_start_pushers: + logger.info("Not starting pushers because they are disabled in the config") + return + run_as_background_process("start_pushers", self._start_pushers) @defer.inlineCallbacks def add_pusher(self, user_id, access_token, kind, app_id, @@ -86,7 +101,7 @@ class PusherPool: last_stream_ordering=last_stream_ordering, profile_tag=profile_tag, ) - yield self._refresh_pusher(app_id, pushkey, user_id) + yield self.start_pusher_by_id(app_id, pushkey, user_id) @defer.inlineCallbacks def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey, @@ -123,45 +138,23 @@ class PusherPool: p['app_id'], p['pushkey'], p['user_name'], ) - def on_new_notifications(self, min_stream_id, max_stream_id): - run_as_background_process( - "on_new_notifications", - self._on_new_notifications, min_stream_id, max_stream_id, - ) - @defer.inlineCallbacks - def _on_new_notifications(self, min_stream_id, max_stream_id): + def on_new_notifications(self, min_stream_id, max_stream_id): try: users_affected = yield self.store.get_push_action_users_in_range( min_stream_id, max_stream_id ) - deferreds = [] - for u in users_affected: if u in self.pushers: for p in self.pushers[u].values(): - deferreds.append( - run_in_background( - p.on_new_notifications, - min_stream_id, max_stream_id, - ) - ) - - yield make_deferred_yieldable( - defer.gatherResults(deferreds, consumeErrors=True), - ) + p.on_new_notifications(min_stream_id, max_stream_id) + except Exception: logger.exception("Exception in pusher on_new_notifications") - def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): - run_as_background_process( - "on_new_receipts", - self._on_new_receipts, min_stream_id, max_stream_id, affected_room_ids, - ) - @defer.inlineCallbacks - def _on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): + def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): try: # Need to subtract 1 from the minimum because the lower bound here # is not inclusive @@ -171,26 +164,20 @@ class PusherPool: # This returns a tuple, user_id is at index 3 users_affected = set([r[3] for r in updated_receipts]) - deferreds = [] - for u in users_affected: if u in self.pushers: for p in self.pushers[u].values(): - deferreds.append( - run_in_background( - p.on_new_receipts, - min_stream_id, max_stream_id, - ) - ) - - yield make_deferred_yieldable( - defer.gatherResults(deferreds, consumeErrors=True), - ) + p.on_new_receipts(min_stream_id, max_stream_id) + except Exception: logger.exception("Exception in pusher on_new_receipts") @defer.inlineCallbacks - def _refresh_pusher(self, app_id, pushkey, user_id): + def start_pusher_by_id(self, app_id, pushkey, user_id): + """Look up the details for the given pusher, and start it""" + if not self._should_start_pushers: + return + resultlist = yield self.store.get_pushers_by_app_id_and_pushkey( app_id, pushkey ) @@ -201,33 +188,49 @@ class PusherPool: p = r if p: + self._start_pusher(p) - self._start_pushers([p]) + @defer.inlineCallbacks + def _start_pushers(self): + """Start all the pushers - def _start_pushers(self, pushers): - if not self.start_pushers: - logger.info("Not starting pushers because they are disabled in the config") - return + Returns: + Deferred + """ + pushers = yield self.store.get_all_pushers() logger.info("Starting %d pushers", len(pushers)) for pusherdict in pushers: - try: - p = self.pusher_factory.create_pusher(pusherdict) - except Exception: - logger.exception("Couldn't start a pusher: caught Exception") - continue - if p: - appid_pushkey = "%s:%s" % ( - pusherdict['app_id'], - pusherdict['pushkey'], - ) - byuser = self.pushers.setdefault(pusherdict['user_name'], {}) + self._start_pusher(pusherdict) + logger.info("Started pushers") - if appid_pushkey in byuser: - byuser[appid_pushkey].on_stop() - byuser[appid_pushkey] = p - run_in_background(p.on_started) + def _start_pusher(self, pusherdict): + """Start the given pusher - logger.info("Started pushers") + Args: + pusherdict (dict): + + Returns: + None + """ + try: + p = self.pusher_factory.create_pusher(pusherdict) + except Exception: + logger.exception("Couldn't start a pusher: caught Exception") + return + + if not p: + return + + appid_pushkey = "%s:%s" % ( + pusherdict['app_id'], + pusherdict['pushkey'], + ) + byuser = self.pushers.setdefault(pusherdict['user_name'], {}) + + if appid_pushkey in byuser: + byuser[appid_pushkey].on_stop() + byuser[appid_pushkey] = p + p.on_started() @defer.inlineCallbacks def remove_pusher(self, app_id, pushkey, user_id): diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 987eec3ef2..943876456b 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -33,31 +33,39 @@ logger = logging.getLogger(__name__) # [2] https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-dependencies REQUIREMENTS = { "jsonschema>=2.5.1": ["jsonschema>=2.5.1"], - "frozendict>=0.4": ["frozendict"], + "frozendict>=1": ["frozendict"], "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"], "canonicaljson>=1.1.3": ["canonicaljson>=1.1.3"], "signedjson>=1.0.0": ["signedjson>=1.0.0"], "pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"], - "service_identity>=1.0.0": ["service_identity>=1.0.0"], - "Twisted>=16.0.0": ["twisted>=16.0.0"], - - # We use crypto.get_elliptic_curve which is only supported in >=0.15 - "pyopenssl>=0.15": ["OpenSSL>=0.15"], - - "pyyaml": ["yaml"], - "pyasn1": ["pyasn1"], - "daemonize": ["daemonize"], - "bcrypt": ["bcrypt>=3.1.0"], - "pillow": ["PIL"], - "pydenticon": ["pydenticon"], - "sortedcontainers": ["sortedcontainers"], - "pysaml2>=3.0.0": ["saml2>=3.0.0"], - "pymacaroons-pynacl": ["pymacaroons"], - "msgpack-python>=0.3.0": ["msgpack"], + "service_identity>=16.0.0": ["service_identity>=16.0.0"], + "Twisted>=17.1.0": ["twisted>=17.1.0"], + "treq>=15.1": ["treq>=15.1"], + + # Twisted has required pyopenssl 16.0 since about Twisted 16.6. + "pyopenssl>=16.0.0": ["OpenSSL>=16.0.0"], + + "pyyaml>=3.11": ["yaml"], + "pyasn1>=0.1.9": ["pyasn1"], + "pyasn1-modules>=0.0.7": ["pyasn1_modules"], + "daemonize>=2.3.1": ["daemonize"], + "bcrypt>=3.1.0": ["bcrypt>=3.1.0"], + "pillow>=3.1.2": ["PIL"], + "pydenticon>=0.2": ["pydenticon"], + "sortedcontainers>=1.4.4": ["sortedcontainers"], + "psutil>=2.0.0": ["psutil>=2.0.0"], + "pysaml2>=3.0.0": ["saml2"], + "pymacaroons-pynacl>=0.9.3": ["pymacaroons"], + "msgpack-python>=0.4.2": ["msgpack"], "phonenumbers>=8.2.0": ["phonenumbers"], - "six": ["six"], - "prometheus_client": ["prometheus_client"], - "attrs": ["attr"], + "six>=1.10": ["six"], + + # prometheus_client 0.4.0 changed the format of counter metrics + # (cf https://github.com/matrix-org/synapse/issues/4001) + "prometheus_client>=0.0.18,<0.4.0": ["prometheus_client"], + + # we use attr.s(slots), which arrived in 16.0.0 + "attrs>=16.0.0": ["attr>=16.0.0"], "netaddr>=0.7.18": ["netaddr"], } @@ -72,12 +80,9 @@ CONDITIONAL_REQUIREMENTS = { "matrix-synapse-ldap3": { "matrix-synapse-ldap3>=0.1": ["ldap_auth_provider"], }, - "psutil": { - "psutil>=2.0.0": ["psutil>=2.0.0"], - }, - "affinity": { - "affinity": ["affinity"], - }, + "postgres": { + "psycopg2>=2.6": ["psycopg2"] + } } diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 3f7be74e02..2d81d49e9a 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -15,6 +15,8 @@ import logging +import six + from synapse.storage._base import SQLBaseStore from synapse.storage.engines import PostgresEngine @@ -23,6 +25,13 @@ from ._slaved_id_tracker import SlavedIdTracker logger = logging.getLogger(__name__) +def __func__(inp): + if six.PY3: + return inp + else: + return inp.__func__ + + class BaseSlavedStore(SQLBaseStore): def __init__(self, db_conn, hs): super(BaseSlavedStore, self).__init__(db_conn, hs) diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 87eaa53004..4f19fd35aa 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -17,7 +17,7 @@ from synapse.storage import DataStore from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.stream_change_cache import StreamChangeCache -from ._base import BaseSlavedStore +from ._base import BaseSlavedStore, __func__ from ._slaved_id_tracker import SlavedIdTracker @@ -43,11 +43,11 @@ class SlavedDeviceInboxStore(BaseSlavedStore): expiry_ms=30 * 60 * 1000, ) - get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__ - get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__ - get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__ - delete_messages_for_device = DataStore.delete_messages_for_device.__func__ - delete_device_msgs_for_remote = DataStore.delete_device_msgs_for_remote.__func__ + get_to_device_stream_token = __func__(DataStore.get_to_device_stream_token) + get_new_messages_for_device = __func__(DataStore.get_new_messages_for_device) + get_new_device_msgs_for_remote = __func__(DataStore.get_new_device_msgs_for_remote) + delete_messages_for_device = __func__(DataStore.delete_messages_for_device) + delete_device_msgs_for_remote = __func__(DataStore.delete_device_msgs_for_remote) def stream_positions(self): result = super(SlavedDeviceInboxStore, self).stream_positions() diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 8206a988f7..ec2fd561cc 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -17,7 +17,7 @@ from synapse.storage import DataStore from synapse.storage.end_to_end_keys import EndToEndKeyStore from synapse.util.caches.stream_change_cache import StreamChangeCache -from ._base import BaseSlavedStore +from ._base import BaseSlavedStore, __func__ from ._slaved_id_tracker import SlavedIdTracker @@ -38,14 +38,14 @@ class SlavedDeviceStore(BaseSlavedStore): "DeviceListFederationStreamChangeCache", device_list_max, ) - get_device_stream_token = DataStore.get_device_stream_token.__func__ - get_user_whose_devices_changed = DataStore.get_user_whose_devices_changed.__func__ - get_devices_by_remote = DataStore.get_devices_by_remote.__func__ - _get_devices_by_remote_txn = DataStore._get_devices_by_remote_txn.__func__ - _get_e2e_device_keys_txn = DataStore._get_e2e_device_keys_txn.__func__ - mark_as_sent_devices_by_remote = DataStore.mark_as_sent_devices_by_remote.__func__ + get_device_stream_token = __func__(DataStore.get_device_stream_token) + get_user_whose_devices_changed = __func__(DataStore.get_user_whose_devices_changed) + get_devices_by_remote = __func__(DataStore.get_devices_by_remote) + _get_devices_by_remote_txn = __func__(DataStore._get_devices_by_remote_txn) + _get_e2e_device_keys_txn = __func__(DataStore._get_e2e_device_keys_txn) + mark_as_sent_devices_by_remote = __func__(DataStore.mark_as_sent_devices_by_remote) _mark_as_sent_devices_by_remote_txn = ( - DataStore._mark_as_sent_devices_by_remote_txn.__func__ + __func__(DataStore._mark_as_sent_devices_by_remote_txn) ) count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"] diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 5777f07c8d..e933b170bb 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -16,7 +16,7 @@ from synapse.storage import DataStore from synapse.util.caches.stream_change_cache import StreamChangeCache -from ._base import BaseSlavedStore +from ._base import BaseSlavedStore, __func__ from ._slaved_id_tracker import SlavedIdTracker @@ -33,9 +33,9 @@ class SlavedGroupServerStore(BaseSlavedStore): "_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(), ) - get_groups_changes_for_user = DataStore.get_groups_changes_for_user.__func__ - get_group_stream_token = DataStore.get_group_stream_token.__func__ - get_all_groups_for_user = DataStore.get_all_groups_for_user.__func__ + get_groups_changes_for_user = __func__(DataStore.get_groups_changes_for_user) + get_group_stream_token = __func__(DataStore.get_group_stream_token) + get_all_groups_for_user = __func__(DataStore.get_all_groups_for_user) def stream_positions(self): result = super(SlavedGroupServerStore, self).stream_positions() diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py index 05ed168463..8032f53fec 100644 --- a/synapse/replication/slave/storage/keys.py +++ b/synapse/replication/slave/storage/keys.py @@ -16,7 +16,7 @@ from synapse.storage import DataStore from synapse.storage.keys import KeyStore -from ._base import BaseSlavedStore +from ._base import BaseSlavedStore, __func__ class SlavedKeyStore(BaseSlavedStore): @@ -24,11 +24,11 @@ class SlavedKeyStore(BaseSlavedStore): "_get_server_verify_key" ] - get_server_verify_keys = DataStore.get_server_verify_keys.__func__ - store_server_verify_key = DataStore.store_server_verify_key.__func__ + get_server_verify_keys = __func__(DataStore.get_server_verify_keys) + store_server_verify_key = __func__(DataStore.store_server_verify_key) - get_server_certificate = DataStore.get_server_certificate.__func__ - store_server_certificate = DataStore.store_server_certificate.__func__ + get_server_certificate = __func__(DataStore.get_server_certificate) + store_server_certificate = __func__(DataStore.store_server_certificate) - get_server_keys_json = DataStore.get_server_keys_json.__func__ - store_server_keys_json = DataStore.store_server_keys_json.__func__ + get_server_keys_json = __func__(DataStore.get_server_keys_json) + store_server_keys_json = __func__(DataStore.store_server_keys_json) diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 80b744082a..92447b00d4 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -17,7 +17,7 @@ from synapse.storage import DataStore from synapse.storage.presence import PresenceStore from synapse.util.caches.stream_change_cache import StreamChangeCache -from ._base import BaseSlavedStore +from ._base import BaseSlavedStore, __func__ from ._slaved_id_tracker import SlavedIdTracker @@ -34,8 +34,8 @@ class SlavedPresenceStore(BaseSlavedStore): "PresenceStreamChangeCache", self._presence_id_gen.get_current_token() ) - _get_active_presence = DataStore._get_active_presence.__func__ - take_presence_startup_info = DataStore.take_presence_startup_info.__func__ + _get_active_presence = __func__(DataStore._get_active_presence) + take_presence_startup_info = __func__(DataStore.take_presence_startup_info) _get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"] get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"] diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 74e892c104..5dc7b3fffc 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -590,9 +590,9 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): pending_commands = LaterGauge( "synapse_replication_tcp_protocol_pending_commands", "", - ["name", "conn_id"], + ["name"], lambda: { - (p.name, p.conn_id): len(p.pending_commands) for p in connected_connections + (p.name,): len(p.pending_commands) for p in connected_connections }, ) @@ -607,9 +607,9 @@ def transport_buffer_size(protocol): transport_send_buffer = LaterGauge( "synapse_replication_tcp_protocol_transport_send_buffer", "", - ["name", "conn_id"], + ["name"], lambda: { - (p.name, p.conn_id): transport_buffer_size(p) for p in connected_connections + (p.name,): transport_buffer_size(p) for p in connected_connections }, ) @@ -632,9 +632,9 @@ def transport_kernel_read_buffer_size(protocol, read=True): tcp_transport_kernel_send_buffer = LaterGauge( "synapse_replication_tcp_protocol_transport_kernel_send_buffer", "", - ["name", "conn_id"], + ["name"], lambda: { - (p.name, p.conn_id): transport_kernel_read_buffer_size(p, False) + (p.name,): transport_kernel_read_buffer_size(p, False) for p in connected_connections }, ) @@ -643,9 +643,9 @@ tcp_transport_kernel_send_buffer = LaterGauge( tcp_transport_kernel_read_buffer = LaterGauge( "synapse_replication_tcp_protocol_transport_kernel_read_buffer", "", - ["name", "conn_id"], + ["name"], lambda: { - (p.name, p.conn_id): transport_kernel_read_buffer_size(p, True) + (p.name,): transport_kernel_read_buffer_size(p, True) for p in connected_connections }, ) @@ -654,9 +654,9 @@ tcp_transport_kernel_read_buffer = LaterGauge( tcp_inbound_commands = LaterGauge( "synapse_replication_tcp_protocol_inbound_commands", "", - ["command", "name", "conn_id"], + ["command", "name"], lambda: { - (k[0], p.name, p.conn_id): count + (k[0], p.name,): count for p in connected_connections for k, count in iteritems(p.inbound_commands_counter) }, @@ -665,9 +665,9 @@ tcp_inbound_commands = LaterGauge( tcp_outbound_commands = LaterGauge( "synapse_replication_tcp_protocol_outbound_commands", "", - ["command", "name", "conn_id"], + ["command", "name"], lambda: { - (k[0], p.name, p.conn_id): count + (k[0], p.name,): count for p in connected_connections for k, count in iteritems(p.outbound_commands_counter) }, diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py index 55fe701c5c..c1e626be3f 100644 --- a/synapse/replication/tcp/streams.py +++ b/synapse/replication/tcp/streams.py @@ -196,7 +196,7 @@ class Stream(object): ) if len(rows) >= MAX_EVENTS_BEHIND: - raise Exception("stream %s has fallen behined" % (self.NAME)) + raise Exception("stream %s has fallen behind" % (self.NAME)) else: rows = yield self.update_function( from_token, current_token, diff --git a/synapse/res/templates/mail-Vector.css b/synapse/res/templates/mail-Vector.css new file mode 100644 index 0000000000..6a3e36eda1 --- /dev/null +++ b/synapse/res/templates/mail-Vector.css @@ -0,0 +1,7 @@ +.header { + border-bottom: 4px solid #e4f7ed ! important; +} + +.notif_link a, .footer a { + color: #76CFA6 ! important; +} diff --git a/synapse/res/templates/mail.css b/synapse/res/templates/mail.css new file mode 100644 index 0000000000..5ab3e1b06d --- /dev/null +++ b/synapse/res/templates/mail.css @@ -0,0 +1,156 @@ +body { + margin: 0px; +} + +pre, code { + word-break: break-word; + white-space: pre-wrap; +} + +#page { + font-family: 'Open Sans', Helvetica, Arial, Sans-Serif; + font-color: #454545; + font-size: 12pt; + width: 100%; + padding: 20px; +} + +#inner { + width: 640px; +} + +.header { + width: 100%; + height: 87px; + color: #454545; + border-bottom: 4px solid #e5e5e5; +} + +.logo { + text-align: right; + margin-left: 20px; +} + +.salutation { + padding-top: 10px; + font-weight: bold; +} + +.summarytext { +} + +.room { + width: 100%; + color: #454545; + border-bottom: 1px solid #e5e5e5; +} + +.room_header td { + padding-top: 38px; + padding-bottom: 10px; + border-bottom: 1px solid #e5e5e5; +} + +.room_name { + vertical-align: middle; + font-size: 18px; + font-weight: bold; +} + +.room_header h2 { + margin-top: 0px; + margin-left: 75px; + font-size: 20px; +} + +.room_avatar { + width: 56px; + line-height: 0px; + text-align: center; + vertical-align: middle; +} + +.room_avatar img { + width: 48px; + height: 48px; + object-fit: cover; + border-radius: 24px; +} + +.notif { + border-bottom: 1px solid #e5e5e5; + margin-top: 16px; + padding-bottom: 16px; +} + +.historical_message .sender_avatar { + opacity: 0.3; +} + +/* spell out opacity and historical_message class names for Outlook aka Word */ +.historical_message .sender_name { + color: #e3e3e3; +} + +.historical_message .message_time { + color: #e3e3e3; +} + +.historical_message .message_body { + color: #c7c7c7; +} + +.historical_message td, +.message td { + padding-top: 10px; +} + +.sender_avatar { + width: 56px; + text-align: center; + vertical-align: top; +} + +.sender_avatar img { + margin-top: -2px; + width: 32px; + height: 32px; + border-radius: 16px; +} + +.sender_name { + display: inline; + font-size: 13px; + color: #a2a2a2; +} + +.message_time { + text-align: right; + width: 100px; + font-size: 11px; + color: #a2a2a2; +} + +.message_body { +} + +.notif_link td { + padding-top: 10px; + padding-bottom: 10px; + font-weight: bold; +} + +.notif_link a, .footer a { + color: #454545; + text-decoration: none; +} + +.debug { + font-size: 10px; + color: #888; +} + +.footer { + margin-top: 20px; + text-align: center; +} \ No newline at end of file diff --git a/synapse/res/templates/notif.html b/synapse/res/templates/notif.html new file mode 100644 index 0000000000..88b921ca9c --- /dev/null +++ b/synapse/res/templates/notif.html @@ -0,0 +1,45 @@ +{% for message in notif.messages %} + <tr class="{{ "historical_message" if message.is_historical else "message" }}"> + <td class="sender_avatar"> + {% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %} + {% if message.sender_avatar_url %} + <img alt="" class="sender_avatar" src="{{ message.sender_avatar_url|mxc_to_http(32,32) }}" /> + {% else %} + {% if message.sender_hash % 3 == 0 %} + <img class="sender_avatar" src="https://vector.im/beta/img/76cfa6.png" /> + {% elif message.sender_hash % 3 == 1 %} + <img class="sender_avatar" src="https://vector.im/beta/img/50e2c2.png" /> + {% else %} + <img class="sender_avatar" src="https://vector.im/beta/img/f4c371.png" /> + {% endif %} + {% endif %} + {% endif %} + </td> + <td class="message_contents"> + {% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %} + <div class="sender_name">{% if message.msgtype == "m.emote" %}*{% endif %} {{ message.sender_name }}</div> + {% endif %} + <div class="message_body"> + {% if message.msgtype == "m.text" %} + {{ message.body_text_html }} + {% elif message.msgtype == "m.emote" %} + {{ message.body_text_html }} + {% elif message.msgtype == "m.notice" %} + {{ message.body_text_html }} + {% elif message.msgtype == "m.image" %} + <img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" /> + {% elif message.msgtype == "m.file" %} + <span class="filename">{{ message.body_text_plain }}</span> + {% endif %} + </div> + </td> + <td class="message_time">{{ message.ts|format_ts("%H:%M") }}</td> + </tr> +{% endfor %} +<tr class="notif_link"> + <td></td> + <td> + <a href="{{ notif.link }}">View {{ room.title }}</a> + </td> + <td></td> +</tr> diff --git a/synapse/res/templates/notif.txt b/synapse/res/templates/notif.txt new file mode 100644 index 0000000000..a37bee9833 --- /dev/null +++ b/synapse/res/templates/notif.txt @@ -0,0 +1,16 @@ +{% for message in notif.messages %} +{% if message.msgtype == "m.emote" %}* {% endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }}) +{% if message.msgtype == "m.text" %} +{{ message.body_text_plain }} +{% elif message.msgtype == "m.emote" %} +{{ message.body_text_plain }} +{% elif message.msgtype == "m.notice" %} +{{ message.body_text_plain }} +{% elif message.msgtype == "m.image" %} +{{ message.body_text_plain }} +{% elif message.msgtype == "m.file" %} +{{ message.body_text_plain }} +{% endif %} +{% endfor %} + +View {{ room.title }} at {{ notif.link }} diff --git a/synapse/res/templates/notif_mail.html b/synapse/res/templates/notif_mail.html new file mode 100644 index 0000000000..fcdb3109fe --- /dev/null +++ b/synapse/res/templates/notif_mail.html @@ -0,0 +1,55 @@ +<!doctype html> +<html lang="en"> + <head> + <style type="text/css"> + {% include 'mail.css' without context %} + {% include "mail-%s.css" % app_name ignore missing without context %} + </style> + </head> + <body> + <table id="page"> + <tr> + <td> </td> + <td id="inner"> + <table class="header"> + <tr> + <td> + <div class="salutation">Hi {{ user_display_name }},</div> + <div class="summarytext">{{ summary_text }}</div> + </td> + <td class="logo"> + {% if app_name == "Riot" %} + <img src="http://matrix.org/img/riot-logo-email.png" width="83" height="83" alt="[Riot]"/> + {% elif app_name == "Vector" %} + <img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/> + {% else %} + <img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/> + {% endif %} + </td> + </tr> + </table> + {% for room in rooms %} + {% include 'room.html' with context %} + {% endfor %} + <div class="footer"> + <a href="{{ unsubscribe_link }}">Unsubscribe</a> + <br/> + <br/> + <div class="debug"> + Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because + an event was received at {{ reason.received_at|format_ts("%c") }} + which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago, + {% if reason.last_sent_ts %} + and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }}, + which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago. + {% else %} + and we don't have a last time we sent a mail for this room. + {% endif %} + </div> + </div> + </td> + <td> </td> + </tr> + </table> + </body> +</html> diff --git a/synapse/res/templates/notif_mail.txt b/synapse/res/templates/notif_mail.txt new file mode 100644 index 0000000000..24843042a5 --- /dev/null +++ b/synapse/res/templates/notif_mail.txt @@ -0,0 +1,10 @@ +Hi {{ user_display_name }}, + +{{ summary_text }} + +{% for room in rooms %} +{% include 'room.txt' with context %} +{% endfor %} + +You can disable these notifications at {{ unsubscribe_link }} + diff --git a/synapse/res/templates/room.html b/synapse/res/templates/room.html new file mode 100644 index 0000000000..723c222d25 --- /dev/null +++ b/synapse/res/templates/room.html @@ -0,0 +1,33 @@ +<table class="room"> + <tr class="room_header"> + <td class="room_avatar"> + {% if room.avatar_url %} + <img alt="" src="{{ room.avatar_url|mxc_to_http(48,48) }}" /> + {% else %} + {% if room.hash % 3 == 0 %} + <img alt="" src="https://vector.im/beta/img/76cfa6.png" /> + {% elif room.hash % 3 == 1 %} + <img alt="" src="https://vector.im/beta/img/50e2c2.png" /> + {% else %} + <img alt="" src="https://vector.im/beta/img/f4c371.png" /> + {% endif %} + {% endif %} + </td> + <td class="room_name" colspan="2"> + {{ room.title }} + </td> + </tr> + {% if room.invite %} + <tr> + <td></td> + <td> + <a href="{{ room.link }}">Join the conversation.</a> + </td> + <td></td> + </tr> + {% else %} + {% for notif in room.notifs %} + {% include 'notif.html' with context %} + {% endfor %} + {% endif %} +</table> diff --git a/synapse/res/templates/room.txt b/synapse/res/templates/room.txt new file mode 100644 index 0000000000..84648c710e --- /dev/null +++ b/synapse/res/templates/room.txt @@ -0,0 +1,9 @@ +{{ room.title }} + +{% if room.invite %} + You've been invited, join at {{ room.link }} +{% else %} + {% for notif in room.notifs %} + {% include 'notif.txt' with context %} + {% endfor %} +{% endif %} diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 3418f06fd6..4856822a5d 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -46,6 +46,7 @@ from synapse.rest.client.v2_alpha import ( receipts, register, report_event, + room_keys, sendtodevice, sync, tags, @@ -102,6 +103,7 @@ class ClientRestResource(JsonResource): auth.register_servlets(hs, client_resource) receipts.register_servlets(hs, client_resource) read_marker.register_servlets(hs, client_resource) + room_keys.register_servlets(hs, client_resource) keys.register_servlets(hs, client_resource) tokenrefresh.register_servlets(hs, client_resource) tags.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index ad536ab570..41534b8c2a 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -101,7 +101,7 @@ class UserRegisterServlet(ClientV1RestServlet): nonce = self.hs.get_secrets().token_hex(64) self.nonces[nonce] = int(self.reactor.seconds()) - return (200, {"nonce": nonce.encode('ascii')}) + return (200, {"nonce": nonce}) @defer.inlineCallbacks def on_POST(self, request): @@ -164,7 +164,7 @@ class UserRegisterServlet(ClientV1RestServlet): key=self.hs.config.registration_shared_secret.encode(), digestmod=hashlib.sha1, ) - want_mac.update(nonce) + want_mac.update(nonce.encode('utf8')) want_mac.update(b"\x00") want_mac.update(username) want_mac.update(b"\x00") @@ -173,7 +173,10 @@ class UserRegisterServlet(ClientV1RestServlet): want_mac.update(b"admin" if admin else b"notadmin") want_mac = want_mac.hexdigest() - if not hmac.compare_digest(want_mac, got_mac.encode('ascii')): + if not hmac.compare_digest( + want_mac.encode('ascii'), + got_mac.encode('ascii') + ): raise SynapseError(403, "HMAC incorrect") # Reuse the parts of RegisterRestServlet to reduce code duplication diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 97733f3026..0220acf644 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -74,38 +74,11 @@ class ClientDirectoryServer(ClientV1RestServlet): if room is None: raise SynapseError(400, "Room does not exist") - dir_handler = self.handlers.directory_handler + requester = yield self.auth.get_user_by_req(request) - try: - # try to auth as a user - requester = yield self.auth.get_user_by_req(request) - try: - user_id = requester.user.to_string() - yield dir_handler.create_association( - user_id, room_alias, room_id, servers - ) - yield dir_handler.send_room_alias_update_event( - requester, - user_id, - room_id - ) - except SynapseError as e: - raise e - except Exception: - logger.exception("Failed to create association") - raise - except AuthError: - # try to auth as an application service - service = yield self.auth.get_appservice_by_req(request) - yield dir_handler.create_appservice_association( - service, room_alias, room_id, servers - ) - logger.info( - "Application service at %s created alias %s pointing to %s", - service.url, - room_alias.to_string(), - room_id - ) + yield self.handlers.directory_handler.create_association( + requester, room_alias, room_id, servers + ) defer.returnValue((200, {})) @@ -135,7 +108,7 @@ class ClientDirectoryServer(ClientV1RestServlet): room_alias = RoomAlias.from_string(room_alias) yield dir_handler.delete_association( - requester, user.to_string(), room_alias + requester, room_alias ) logger.info( diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 0f3a2e8b51..cd9b3bdbd1 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -45,20 +45,20 @@ class EventStreamRestServlet(ClientV1RestServlet): is_guest = requester.is_guest room_id = None if is_guest: - if "room_id" not in request.args: + if b"room_id" not in request.args: raise SynapseError(400, "Guest users must specify room_id param") - if "room_id" in request.args: - room_id = request.args["room_id"][0] + if b"room_id" in request.args: + room_id = request.args[b"room_id"][0].decode('ascii') pagin_config = PaginationConfig.from_request(request) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS - if "timeout" in request.args: + if b"timeout" in request.args: try: - timeout = int(request.args["timeout"][0]) + timeout = int(request.args[b"timeout"][0]) except ValueError: raise SynapseError(400, "timeout must be in milliseconds.") - as_client_event = "raw" not in request.args + as_client_event = b"raw" not in request.args chunk = yield self.event_stream_handler.get_stream( requester.user.to_string(), diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index fd5f85b53e..3ead75cb77 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -32,7 +32,7 @@ class InitialSyncRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) - as_client_event = "raw" not in request.args + as_client_event = b"raw" not in request.args pagination_config = PaginationConfig.from_request(request) include_archived = parse_boolean(request, "archived", default=False) content = yield self.initial_sync_handler.snapshot_all_rooms( diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index cb85fa1436..0010699d31 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -14,10 +14,9 @@ # limitations under the License. import logging -import urllib import xml.etree.ElementTree as ET -from six.moves.urllib import parse as urlparse +from six.moves import urllib from canonicaljson import json from saml2 import BINDING_HTTP_POST, config @@ -134,7 +133,7 @@ class LoginRestServlet(ClientV1RestServlet): LoginRestServlet.SAML2_TYPE): relay_state = "" if "relay_state" in login_submission: - relay_state = "&RelayState=" + urllib.quote( + relay_state = "&RelayState=" + urllib.parse.quote( login_submission["relay_state"]) result = { "uri": "%s%s" % (self.idp_redirect_url, relay_state) @@ -366,7 +365,7 @@ class SAML2RestServlet(ClientV1RestServlet): (user_id, token) = yield handler.register_saml2(username) # Forward to the RelayState callback along with ava if 'RelayState' in request.args: - request.redirect(urllib.unquote( + request.redirect(urllib.parse.unquote( request.args['RelayState'][0]) + '?status=authenticated&access_token=' + token + '&user_id=' + user_id + '&ava=' + @@ -377,7 +376,7 @@ class SAML2RestServlet(ClientV1RestServlet): "user_id": user_id, "token": token, "ava": saml2_auth.ava})) elif 'RelayState' in request.args: - request.redirect(urllib.unquote( + request.redirect(urllib.parse.unquote( request.args['RelayState'][0]) + '?status=not_authenticated') finish_request(request) @@ -390,21 +389,22 @@ class CasRedirectServlet(ClientV1RestServlet): def __init__(self, hs): super(CasRedirectServlet, self).__init__(hs) - self.cas_server_url = hs.config.cas_server_url - self.cas_service_url = hs.config.cas_service_url + self.cas_server_url = hs.config.cas_server_url.encode('ascii') + self.cas_service_url = hs.config.cas_service_url.encode('ascii') def on_GET(self, request): args = request.args - if "redirectUrl" not in args: + if b"redirectUrl" not in args: return (400, "Redirect URL not specified for CAS auth") - client_redirect_url_param = urllib.urlencode({ - "redirectUrl": args["redirectUrl"][0] - }) - hs_redirect_url = self.cas_service_url + "/_matrix/client/api/v1/login/cas/ticket" - service_param = urllib.urlencode({ - "service": "%s?%s" % (hs_redirect_url, client_redirect_url_param) - }) - request.redirect("%s/login?%s" % (self.cas_server_url, service_param)) + client_redirect_url_param = urllib.parse.urlencode({ + b"redirectUrl": args[b"redirectUrl"][0] + }).encode('ascii') + hs_redirect_url = (self.cas_service_url + + b"/_matrix/client/api/v1/login/cas/ticket") + service_param = urllib.parse.urlencode({ + b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param) + }).encode('ascii') + request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param)) finish_request(request) @@ -422,11 +422,11 @@ class CasTicketServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - client_redirect_url = request.args["redirectUrl"][0] + client_redirect_url = request.args[b"redirectUrl"][0] http_client = self.hs.get_simple_http_client() uri = self.cas_server_url + "/proxyValidate" args = { - "ticket": request.args["ticket"], + "ticket": request.args[b"ticket"][0].decode('ascii'), "service": self.cas_service_url } try: @@ -471,11 +471,11 @@ class CasTicketServlet(ClientV1RestServlet): finish_request(request) def add_login_token_to_redirect_url(self, url, token): - url_parts = list(urlparse.urlparse(url)) - query = dict(urlparse.parse_qsl(url_parts[4])) + url_parts = list(urllib.parse.urlparse(url)) + query = dict(urllib.parse.parse_qsl(url_parts[4])) query.update({"loginToken": token}) - url_parts[4] = urllib.urlencode(query) - return urlparse.urlunparse(url_parts) + url_parts[4] = urllib.parse.urlencode(query).encode('ascii') + return urllib.parse.urlunparse(url_parts) def parse_cas_response(self, cas_response_body): user = None diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 6e95d9bec2..9382b1f124 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -46,7 +46,7 @@ class PushRuleRestServlet(ClientV1RestServlet): try: priority_class = _priority_class_from_spec(spec) except InvalidRuleException as e: - raise SynapseError(400, e.message) + raise SynapseError(400, str(e)) requester = yield self.auth.get_user_by_req(request) @@ -73,7 +73,7 @@ class PushRuleRestServlet(ClientV1RestServlet): content, ) except InvalidRuleException as e: - raise SynapseError(400, e.message) + raise SynapseError(400, str(e)) before = parse_string(request, "before") if before: @@ -95,9 +95,9 @@ class PushRuleRestServlet(ClientV1RestServlet): ) self.notify_user(user_id) except InconsistentRuleException as e: - raise SynapseError(400, e.message) + raise SynapseError(400, str(e)) except RuleNotFoundException as e: - raise SynapseError(400, e.message) + raise SynapseError(400, str(e)) defer.returnValue((200, {})) @@ -142,10 +142,10 @@ class PushRuleRestServlet(ClientV1RestServlet): PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR ) - if path[0] == '': + if path[0] == b'': defer.returnValue((200, rules)) - elif path[0] == 'global': - path = path[1:] + elif path[0] == b'global': + path = [x.decode('ascii') for x in path[1:]] result = _filter_ruleset_with_path(rules['global'], path) defer.returnValue((200, result)) else: @@ -192,10 +192,10 @@ class PushRuleRestServlet(ClientV1RestServlet): def _rule_spec_from_path(path): if len(path) < 2: raise UnrecognizedRequestError() - if path[0] != 'pushrules': + if path[0] != b'pushrules': raise UnrecognizedRequestError() - scope = path[1] + scope = path[1].decode('ascii') path = path[2:] if scope != 'global': raise UnrecognizedRequestError() @@ -203,13 +203,13 @@ def _rule_spec_from_path(path): if len(path) == 0: raise UnrecognizedRequestError() - template = path[0] + template = path[0].decode('ascii') path = path[1:] if len(path) == 0 or len(path[0]) == 0: raise UnrecognizedRequestError() - rule_id = path[0] + rule_id = path[0].decode('ascii') spec = { 'scope': scope, @@ -220,7 +220,7 @@ def _rule_spec_from_path(path): path = path[1:] if len(path) > 0 and len(path[0]) > 0: - spec['attr'] = path[0] + spec['attr'] = path[0].decode('ascii') return spec diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 182a68b1e2..b84f0260f2 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -59,7 +59,7 @@ class PushersRestServlet(ClientV1RestServlet): ] for p in pushers: - for k, v in p.items(): + for k, v in list(p.items()): if k not in allowed_keys: del p[k] @@ -126,7 +126,7 @@ class PushersSetRestServlet(ClientV1RestServlet): profile_tag=content.get('profile_tag', ""), ) except PusherConfigException as pce: - raise SynapseError(400, "Config Error: " + pce.message, + raise SynapseError(400, "Config Error: " + str(pce), errcode=Codes.MISSING_PARAM) self.notifier.on_new_replication_data() diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index fcc1091760..fcfe7857f6 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -33,6 +33,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) +from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID @@ -207,7 +208,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): "sender": requester.user.to_string(), } - if 'ts' in request.args and requester.app_service: + if b'ts' in request.args and requester.app_service: event_dict['origin_server_ts'] = parse_integer(request, "ts", 0) event = yield self.event_creation_hander.create_and_send_nonmember_event( @@ -255,7 +256,9 @@ class JoinRoomAliasServlet(ClientV1RestServlet): if RoomID.is_valid(room_identifier): room_id = room_identifier try: - remote_room_hosts = request.args["server_name"] + remote_room_hosts = [ + x.decode('ascii') for x in request.args[b"server_name"] + ] except Exception: remote_room_hosts = None elif RoomAlias.is_valid(room_identifier): @@ -407,7 +410,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet): room_id=room_id, user_id=requester.user.to_string(), at_token=at_token, - types=[(EventTypes.Member, None)], + state_filter=StateFilter.from_types([(EventTypes.Member, None)]), ) chunk = [] @@ -461,10 +464,10 @@ class RoomMessageListRestServlet(ClientV1RestServlet): pagination_config = PaginationConfig.from_request( request, default_limit=10, ) - as_client_event = "raw" not in request.args - filter_bytes = parse_string(request, "filter") + as_client_event = b"raw" not in request.args + filter_bytes = parse_string(request, b"filter", encoding=None) if filter_bytes: - filter_json = urlparse.unquote(filter_bytes).decode("UTF-8") + filter_json = urlparse.unquote(filter_bytes.decode("UTF-8")) event_filter = Filter(json.loads(filter_json)) else: event_filter = None @@ -531,7 +534,7 @@ class RoomEventServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, event_id): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) event = yield self.event_handler.get_event(requester.user, room_id, event_id) time_now = self.clock.time_msec() @@ -560,7 +563,7 @@ class RoomEventContextServlet(ClientV1RestServlet): # picking the API shape for symmetry with /messages filter_bytes = parse_string(request, "filter") if filter_bytes: - filter_json = urlparse.unquote(filter_bytes).decode("UTF-8") + filter_json = urlparse.unquote(filter_bytes) event_filter = Filter(json.loads(filter_json)) else: event_filter = None diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 62f4c3d93e..53da905eea 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -42,7 +42,11 @@ class VoipRestServlet(ClientV1RestServlet): expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000 username = "%d:%s" % (expiry, requester.user.to_string()) - mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1) + mac = hmac.new( + turnSecret.encode(), + msg=username.encode(), + digestmod=hashlib.sha1 + ) # We need to use standard padded base64 encoding here # encode_base64 because we need to add the standard padding to get the # same result as the TURN server. diff --git a/synapse/rest/client/v1_only/register.py b/synapse/rest/client/v1_only/register.py index 5e99cffbcb..dadb376b02 100644 --- a/synapse/rest/client/v1_only/register.py +++ b/synapse/rest/client/v1_only/register.py @@ -23,6 +23,7 @@ from twisted.internet import defer import synapse.util.stringutils as stringutils from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError +from synapse.config.server import is_threepid_reserved from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request from synapse.rest.client.v1.base import ClientV1RestServlet from synapse.types import create_requester @@ -281,12 +282,20 @@ class RegisterRestServlet(ClientV1RestServlet): register_json["user"].encode("utf-8") if "user" in register_json else None ) + threepid = None + if session.get(LoginType.EMAIL_IDENTITY): + threepid = session["threepidCreds"] handler = self.handlers.registration_handler (user_id, token) = yield handler.register( localpart=desired_user_id, - password=password + password=password, + threepid=threepid, ) + # Necessary due to auth checks prior to the threepid being + # written to the db + if is_threepid_reserved(self.hs.config, threepid): + yield self.store.upsert_monthly_active_user(user_id) if session[LoginType.EMAIL_IDENTITY]: logger.debug("Binding emails %s to %s" % ( diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index dad0433ba1..ea84729915 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -53,7 +53,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): if not (yield check_3pid_allowed(self.hs, "email", body['email'])): raise SynapseError( - 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + 403, + "Your email domain is not authorized on this server", + Codes.THREEPID_DENIED, ) existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( @@ -89,7 +91,9 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet): if not (yield check_3pid_allowed(self.hs, "msisdn", msisdn)): raise SynapseError( - 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + 403, + "Account phone numbers are not authorized on this server", + Codes.THREEPID_DENIED, ) existingUid = yield self.datastore.get_user_id_by_threepid( @@ -241,7 +245,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): if not (yield check_3pid_allowed(self.hs, "email", body['email'])): raise SynapseError( - 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + 403, + "Your email domain is not authorized on this server", + Codes.THREEPID_DENIED, ) existingUid = yield self.datastore.get_user_id_by_threepid( @@ -276,7 +282,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): if not (yield check_3pid_allowed(self.hs, "msisdn", msisdn)): raise SynapseError( - 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + 403, + "Account phone numbers are not authorized on this server", + Codes.THREEPID_DENIED, ) existingUid = yield self.datastore.get_user_id_by_threepid( diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index bd8b5f4afa..693b303881 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -99,7 +99,7 @@ class AuthRestServlet(RestServlet): cannot be handled in the normal flow (with requests to the same endpoint). Current use is for web fallback auth. """ - PATTERNS = client_v2_patterns("/auth/(?P<stagetype>[\w\.]*)/fallback/web") + PATTERNS = client_v2_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web") def __init__(self, hs): super(AuthRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 7bbebd54ab..c7c8287882 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -28,6 +28,7 @@ import synapse import synapse.types from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError +from synapse.config.server import is_threepid_reserved from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -76,7 +77,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): if not (yield check_3pid_allowed(self.hs, "email", body['email'])): raise SynapseError( - 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + 403, + "Your email domain is not authorized to register on this server", + Codes.THREEPID_DENIED, ) existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( @@ -116,7 +119,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): if not (yield check_3pid_allowed(self.hs, "msisdn", msisdn)): raise SynapseError( - 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + 403, + "Phone numbers are not authorized to register on this server", + Codes.THREEPID_DENIED, ) existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( @@ -369,7 +374,9 @@ class RegisterRestServlet(RestServlet): if not (yield check_3pid_allowed(self.hs, medium, address)): raise SynapseError( - 403, "Third party identifier is not allowed", + 403, + "Third party identifiers (email/phone numbers)" + + " are not authorized on this server", Codes.THREEPID_DENIED, ) @@ -475,13 +482,22 @@ class RegisterRestServlet(RestServlet): if desired_username is not None: desired_username = desired_username.lower() + threepid = None + if auth_result: + threepid = auth_result.get(LoginType.EMAIL_IDENTITY) + (registered_user_id, _) = yield self.registration_handler.register( localpart=desired_username, password=new_password, guest_access_token=guest_access_token, generate_token=False, display_name=desired_display_name, + threepid=threepid, ) + # Necessary due to auth checks prior to the threepid being + # written to the db + if is_threepid_reserved(self.hs.config, threepid): + yield self.store.upsert_monthly_active_user(registered_user_id) # remember that we've now registered that user account, and with # what user ID (since the user may not have specified) diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py new file mode 100644 index 0000000000..45b5817d8b --- /dev/null +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -0,0 +1,372 @@ +# -*- coding: utf-8 -*- +# Copyright 2017, 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import ( + RestServlet, + parse_json_object_from_request, + parse_string, +) + +from ._base import client_v2_patterns + +logger = logging.getLogger(__name__) + + +class RoomKeysServlet(RestServlet): + PATTERNS = client_v2_patterns( + "/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$" + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(RoomKeysServlet, self).__init__() + self.auth = hs.get_auth() + self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, room_id, session_id): + """ + Uploads one or more encrypted E2E room keys for backup purposes. + room_id: the ID of the room the keys are for (optional) + session_id: the ID for the E2E room keys for the room (optional) + version: the version of the user's backup which this data is for. + the version must already have been created via the /room_keys/version API. + + Each session has: + * first_message_index: a numeric index indicating the oldest message + encrypted by this session. + * forwarded_count: how many times the uploading client claims this key + has been shared (forwarded) + * is_verified: whether the client that uploaded the keys claims they + were sent by a device which they've verified + * session_data: base64-encrypted data describing the session. + + Returns 200 OK on success with body {} + Returns 403 Forbidden if the version in question is not the most recently + created version (i.e. if this is an old client trying to write to a stale backup) + Returns 404 Not Found if the version in question doesn't exist + + The API is designed to be otherwise agnostic to the room_key encryption + algorithm being used. Sessions are merged with existing ones in the + backup using the heuristics: + * is_verified sessions always win over unverified sessions + * older first_message_index always win over newer sessions + * lower forwarded_count always wins over higher forwarded_count + + We trust the clients not to lie and corrupt their own backups. + It also means that if your access_token is stolen, the attacker could + delete your backup. + + POST /room_keys/keys/!abc:matrix.org/c0ff33?version=1 HTTP/1.1 + Content-Type: application/json + + { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + + Or... + + POST /room_keys/keys/!abc:matrix.org?version=1 HTTP/1.1 + Content-Type: application/json + + { + "sessions": { + "c0ff33": { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + } + } + + Or... + + POST /room_keys/keys?version=1 HTTP/1.1 + Content-Type: application/json + + { + "rooms": { + "!abc:matrix.org": { + "sessions": { + "c0ff33": { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + } + } + } + } + """ + requester = yield self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + body = parse_json_object_from_request(request) + version = parse_string(request, "version") + + if session_id: + body = { + "sessions": { + session_id: body + } + } + + if room_id: + body = { + "rooms": { + room_id: body + } + } + + yield self.e2e_room_keys_handler.upload_room_keys( + user_id, version, body + ) + defer.returnValue((200, {})) + + @defer.inlineCallbacks + def on_GET(self, request, room_id, session_id): + """ + Retrieves one or more encrypted E2E room keys for backup purposes. + Symmetric with the PUT version of the API. + + room_id: the ID of the room to retrieve the keys for (optional) + session_id: the ID for the E2E room keys to retrieve the keys for (optional) + version: the version of the user's backup which this data is for. + the version must already have been created via the /change_secret API. + + Returns as follows: + + GET /room_keys/keys/!abc:matrix.org/c0ff33?version=1 HTTP/1.1 + { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + + Or... + + GET /room_keys/keys/!abc:matrix.org?version=1 HTTP/1.1 + { + "sessions": { + "c0ff33": { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + } + } + + Or... + + GET /room_keys/keys?version=1 HTTP/1.1 + { + "rooms": { + "!abc:matrix.org": { + "sessions": { + "c0ff33": { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + } + } + } + } + """ + requester = yield self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + version = parse_string(request, "version") + + room_keys = yield self.e2e_room_keys_handler.get_room_keys( + user_id, version, room_id, session_id + ) + + if session_id: + room_keys = room_keys['rooms'][room_id]['sessions'][session_id] + elif room_id: + room_keys = room_keys['rooms'][room_id] + + defer.returnValue((200, room_keys)) + + @defer.inlineCallbacks + def on_DELETE(self, request, room_id, session_id): + """ + Deletes one or more encrypted E2E room keys for a user for backup purposes. + + DELETE /room_keys/keys/!abc:matrix.org/c0ff33?version=1 + HTTP/1.1 200 OK + {} + + room_id: the ID of the room whose keys to delete (optional) + session_id: the ID for the E2E session to delete (optional) + version: the version of the user's backup which this data is for. + the version must already have been created via the /change_secret API. + """ + + requester = yield self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + version = parse_string(request, "version") + + yield self.e2e_room_keys_handler.delete_room_keys( + user_id, version, room_id, session_id + ) + defer.returnValue((200, {})) + + +class RoomKeysNewVersionServlet(RestServlet): + PATTERNS = client_v2_patterns( + "/room_keys/version$" + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(RoomKeysNewVersionServlet, self).__init__() + self.auth = hs.get_auth() + self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() + + @defer.inlineCallbacks + def on_POST(self, request): + """ + Create a new backup version for this user's room_keys with the given + info. The version is allocated by the server and returned to the user + in the response. This API is intended to be used whenever the user + changes the encryption key for their backups, ensuring that backups + encrypted with different keys don't collide. + + It takes out an exclusive lock on this user's room_key backups, to ensure + clients only upload to the current backup. + + The algorithm passed in the version info is a reverse-DNS namespaced + identifier to describe the format of the encrypted backupped keys. + + The auth_data is { user_id: "user_id", nonce: <random string> } + encrypted using the algorithm and current encryption key described above. + + POST /room_keys/version + Content-Type: application/json + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" + } + + HTTP/1.1 200 OK + Content-Type: application/json + { + "version": 12345 + } + """ + requester = yield self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + info = parse_json_object_from_request(request) + + new_version = yield self.e2e_room_keys_handler.create_version( + user_id, info + ) + defer.returnValue((200, {"version": new_version})) + + # we deliberately don't have a PUT /version, as these things really should + # be immutable to avoid people footgunning + + +class RoomKeysVersionServlet(RestServlet): + PATTERNS = client_v2_patterns( + "/room_keys/version(/(?P<version>[^/]+))?$" + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(RoomKeysVersionServlet, self).__init__() + self.auth = hs.get_auth() + self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() + + @defer.inlineCallbacks + def on_GET(self, request, version): + """ + Retrieve the version information about a given version of the user's + room_keys backup. If the version part is missing, returns info about the + most current backup version (if any) + + It takes out an exclusive lock on this user's room_key backups, to ensure + clients only upload to the current backup. + + Returns 404 if the given version does not exist. + + GET /room_keys/version/12345 HTTP/1.1 + { + "version": "12345", + "algorithm": "m.megolm_backup.v1", + "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" + } + """ + requester = yield self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + + try: + info = yield self.e2e_room_keys_handler.get_version_info( + user_id, version + ) + except SynapseError as e: + if e.code == 404: + raise SynapseError(404, "No backup found", Codes.NOT_FOUND) + defer.returnValue((200, info)) + + @defer.inlineCallbacks + def on_DELETE(self, request, version): + """ + Delete the information about a given version of the user's + room_keys backup. If the version part is missing, deletes the most + current backup version (if any). Doesn't delete the actual room data. + + DELETE /room_keys/version/12345 HTTP/1.1 + HTTP/1.1 200 OK + {} + """ + if version is None: + raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND) + + requester = yield self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + + yield self.e2e_room_keys_handler.delete_version( + user_id, version + ) + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + RoomKeysServlet(hs).register(http_server) + RoomKeysVersionServlet(hs).register(http_server) + RoomKeysNewVersionServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 1275baa1ba..0251146722 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection from synapse.events.utils import ( format_event_for_client_v2_without_room_id, + format_event_raw, serialize_event, ) from synapse.handlers.presence import format_user_presence_state @@ -88,7 +89,7 @@ class SyncRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): - if "from" in request.args: + if b"from" in request.args: # /events used to use 'from', but /sync uses 'since'. # Lets be helpful and whine if we see a 'from'. raise SynapseError( @@ -175,17 +176,28 @@ class SyncRestServlet(RestServlet): @staticmethod def encode_response(time_now, sync_result, access_token_id, filter): + if filter.event_format == 'client': + event_formatter = format_event_for_client_v2_without_room_id + elif filter.event_format == 'federation': + event_formatter = format_event_raw + else: + raise Exception("Unknown event format %s" % (filter.event_format, )) + joined = SyncRestServlet.encode_joined( - sync_result.joined, time_now, access_token_id, filter.event_fields + sync_result.joined, time_now, access_token_id, + filter.event_fields, + event_formatter, ) invited = SyncRestServlet.encode_invited( sync_result.invited, time_now, access_token_id, + event_formatter, ) archived = SyncRestServlet.encode_archived( sync_result.archived, time_now, access_token_id, filter.event_fields, + event_formatter, ) return { @@ -228,7 +240,7 @@ class SyncRestServlet(RestServlet): } @staticmethod - def encode_joined(rooms, time_now, token_id, event_fields): + def encode_joined(rooms, time_now, token_id, event_fields, event_formatter): """ Encode the joined rooms in a sync result @@ -240,7 +252,9 @@ class SyncRestServlet(RestServlet): token_id(int): ID of the user's auth token - used for namespacing of transaction IDs event_fields(list<str>): List of event fields to include. If empty, - all fields will be returned. + all fields will be returned. + event_formatter (func[dict]): function to convert from federation format + to client format Returns: dict[str, dict[str, object]]: the joined rooms list, in our response format @@ -248,13 +262,14 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = SyncRestServlet.encode_room( - room, time_now, token_id, only_fields=event_fields + room, time_now, token_id, joined=True, only_fields=event_fields, + event_formatter=event_formatter, ) return joined @staticmethod - def encode_invited(rooms, time_now, token_id): + def encode_invited(rooms, time_now, token_id, event_formatter): """ Encode the invited rooms in a sync result @@ -264,7 +279,9 @@ class SyncRestServlet(RestServlet): time_now(int): current time - used as a baseline for age calculations token_id(int): ID of the user's auth token - used for namespacing - of transaction IDs + of transaction IDs + event_formatter (func[dict]): function to convert from federation format + to client format Returns: dict[str, dict[str, object]]: the invited rooms list, in our @@ -274,7 +291,7 @@ class SyncRestServlet(RestServlet): for room in rooms: invite = serialize_event( room.invite, time_now, token_id=token_id, - event_format=format_event_for_client_v2_without_room_id, + event_format=event_formatter, is_invite=True, ) unsigned = dict(invite.get("unsigned", {})) @@ -288,7 +305,7 @@ class SyncRestServlet(RestServlet): return invited @staticmethod - def encode_archived(rooms, time_now, token_id, event_fields): + def encode_archived(rooms, time_now, token_id, event_fields, event_formatter): """ Encode the archived rooms in a sync result @@ -300,7 +317,9 @@ class SyncRestServlet(RestServlet): token_id(int): ID of the user's auth token - used for namespacing of transaction IDs event_fields(list<str>): List of event fields to include. If empty, - all fields will be returned. + all fields will be returned. + event_formatter (func[dict]): function to convert from federation format + to client format Returns: dict[str, dict[str, object]]: The invited rooms list, in our response format @@ -308,13 +327,18 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = SyncRestServlet.encode_room( - room, time_now, token_id, joined=False, only_fields=event_fields + room, time_now, token_id, joined=False, + only_fields=event_fields, + event_formatter=event_formatter, ) return joined @staticmethod - def encode_room(room, time_now, token_id, joined=True, only_fields=None): + def encode_room( + room, time_now, token_id, joined, + only_fields, event_formatter, + ): """ Args: room (JoinedSyncResult|ArchivedSyncResult): sync result for a @@ -326,14 +350,15 @@ class SyncRestServlet(RestServlet): joined (bool): True if the user is joined to this room - will mean we handle ephemeral events only_fields(list<str>): Optional. The list of event fields to include. + event_formatter (func[dict]): function to convert from federation format + to client format Returns: dict[str, object]: the room, encoded in our response format """ def serialize(event): - # TODO(mjark): Respect formatting requirements in the filter. return serialize_event( event, time_now, token_id=token_id, - event_format=format_event_for_client_v2_without_room_id, + event_format=event_formatter, only_event_fields=only_fields, ) diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index d9d379182e..b9b5d07677 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -79,7 +79,7 @@ class ThirdPartyUserServlet(RestServlet): yield self.auth.get_user_by_req(request, allow_guest=True) fields = request.args - fields.pop("access_token", None) + fields.pop(b"access_token", None) results = yield self.appservice_handler.query_3pe( ThirdPartyEntityKind.USER, protocol, fields @@ -102,7 +102,7 @@ class ThirdPartyLocationServlet(RestServlet): yield self.auth.get_user_by_req(request, allow_guest=True) fields = request.args - fields.pop("access_token", None) + fields.pop(b"access_token", None) results = yield self.appservice_handler.query_3pe( ThirdPartyEntityKind.LOCATION, protocol, fields diff --git a/synapse/rest/key/v1/server_key_resource.py b/synapse/rest/key/v1/server_key_resource.py index b9ee6e1c13..38eb2ee23f 100644 --- a/synapse/rest/key/v1/server_key_resource.py +++ b/synapse/rest/key/v1/server_key_resource.py @@ -88,5 +88,5 @@ class LocalKey(Resource): ) def getChild(self, name, request): - if name == '': + if name == b'': return self diff --git a/synapse/rest/key/v2/__init__.py b/synapse/rest/key/v2/__init__.py index 3491fd2118..cb5abcf826 100644 --- a/synapse/rest/key/v2/__init__.py +++ b/synapse/rest/key/v2/__init__.py @@ -22,5 +22,5 @@ from .remote_key_resource import RemoteKey class KeyApiV2Resource(Resource): def __init__(self, hs): Resource.__init__(self) - self.putChild("server", LocalKey(hs)) - self.putChild("query", RemoteKey(hs)) + self.putChild(b"server", LocalKey(hs)) + self.putChild(b"query", RemoteKey(hs)) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 7d67e4b064..eb8782aa6e 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -103,7 +103,7 @@ class RemoteKey(Resource): def async_render_GET(self, request): if len(request.postpath) == 1: server, = request.postpath - query = {server: {}} + query = {server.decode('ascii'): {}} elif len(request.postpath) == 2: server, key_id = request.postpath minimum_valid_until_ts = parse_integer( @@ -112,11 +112,12 @@ class RemoteKey(Resource): arguments = {} if minimum_valid_until_ts is not None: arguments["minimum_valid_until_ts"] = minimum_valid_until_ts - query = {server: {key_id: arguments}} + query = {server.decode('ascii'): {key_id.decode('ascii'): arguments}} else: raise SynapseError( 404, "Not found %r" % request.postpath, Codes.NOT_FOUND ) + yield self.query_keys(request, query, query_remote_on_cache_miss=True) def render_POST(self, request): @@ -135,6 +136,7 @@ class RemoteKey(Resource): @defer.inlineCallbacks def query_keys(self, request, query, query_remote_on_cache_miss=False): logger.info("Handling query for keys %r", query) + store_queries = [] for server_name, key_ids in query.items(): if ( diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py index f255f2883f..5a426ff2f6 100644 --- a/synapse/rest/media/v0/content_repository.py +++ b/synapse/rest/media/v0/content_repository.py @@ -56,7 +56,7 @@ class ContentRepoResource(resource.Resource): # servers. # TODO: A little crude here, we could do this better. - filename = request.path.split('/')[-1] + filename = request.path.decode('ascii').split('/')[-1] # be paranoid filename = re.sub("[^0-9A-z.-_]", "", filename) @@ -78,7 +78,7 @@ class ContentRepoResource(resource.Resource): # select private. don't bother setting Expires as all our matrix # clients are smart enough to be happy with Cache-Control (right?) request.setHeader( - "Cache-Control", "public,max-age=86400,s-maxage=86400" + b"Cache-Control", b"public,max-age=86400,s-maxage=86400" ) d = FileSender().beginFileTransfer(f, request) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 65f4bd2910..76e479afa3 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -15,9 +15,8 @@ import logging import os -import urllib -from six.moves.urllib import parse as urlparse +from six.moves import urllib from twisted.internet import defer from twisted.protocols.basic import FileSender @@ -35,10 +34,15 @@ def parse_media_id(request): # This allows users to append e.g. /test.png to the URL. Useful for # clients that parse the URL to see content type. server_name, media_id = request.postpath[:2] + + if isinstance(server_name, bytes): + server_name = server_name.decode('utf-8') + media_id = media_id.decode('utf8') + file_name = None if len(request.postpath) > 2: try: - file_name = urlparse.unquote(request.postpath[-1]).decode("utf-8") + file_name = urllib.parse.unquote(request.postpath[-1].decode("utf-8")) except UnicodeDecodeError: pass return server_name, media_id, file_name @@ -93,22 +97,18 @@ def add_file_headers(request, media_type, file_size, upload_name): file_size (int): Size in bytes of the media, if known. upload_name (str): The name of the requested file, if any. """ + def _quote(x): + return urllib.parse.quote(x.encode("utf-8")) + request.setHeader(b"Content-Type", media_type.encode("UTF-8")) if upload_name: if is_ascii(upload_name): - request.setHeader( - b"Content-Disposition", - b"inline; filename=%s" % ( - urllib.quote(upload_name.encode("utf-8")), - ), - ) + disposition = ("inline; filename=%s" % (_quote(upload_name),)).encode("ascii") else: - request.setHeader( - b"Content-Disposition", - b"inline; filename*=utf-8''%s" % ( - urllib.quote(upload_name.encode("utf-8")), - ), - ) + disposition = ( + "inline; filename*=utf-8''%s" % (_quote(upload_name),)).encode("ascii") + + request.setHeader(b"Content-Disposition", disposition) # cache for at least a day. # XXX: we might want to turn this off for data we don't want to diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index fbfa85f74f..f911b120b1 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -47,12 +47,13 @@ class DownloadResource(Resource): def _async_render_GET(self, request): set_cors_headers(request) request.setHeader( - "Content-Security-Policy", - "default-src 'none';" - " script-src 'none';" - " plugin-types application/pdf;" - " style-src 'unsafe-inline';" - " object-src 'self';" + b"Content-Security-Policy", + b"default-src 'none';" + b" script-src 'none';" + b" plugin-types application/pdf;" + b" style-src 'unsafe-inline';" + b" media-src 'self';" + b" object-src 'self';" ) server_name, media_id, name = parse_media_id(request) if server_name == self.server_name: diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 241c972070..08b1867fab 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -20,12 +20,12 @@ import logging import os import shutil -from six import iteritems +from six import PY3, iteritems from six.moves.urllib import parse as urlparse import twisted.internet.error import twisted.web.http -from twisted.internet import defer, threads +from twisted.internet import defer from twisted.web.resource import Resource from synapse.api.errors import ( @@ -36,8 +36,8 @@ from synapse.api.errors import ( ) from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util import logcontext from synapse.util.async_helpers import Linearizer -from synapse.util.logcontext import make_deferred_yieldable from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import is_ascii, random_string @@ -397,13 +397,13 @@ class MediaRepository(object): yield finish() - media_type = headers["Content-Type"][0] + media_type = headers[b"Content-Type"][0].decode('ascii') time_now_ms = self.clock.time_msec() - content_disposition = headers.get("Content-Disposition", None) + content_disposition = headers.get(b"Content-Disposition", None) if content_disposition: - _, params = cgi.parse_header(content_disposition[0],) + _, params = cgi.parse_header(content_disposition[0].decode('ascii'),) upload_name = None # First check if there is a valid UTF-8 filename @@ -419,9 +419,13 @@ class MediaRepository(object): upload_name = upload_name_ascii if upload_name: - upload_name = urlparse.unquote(upload_name) + if PY3: + upload_name = urlparse.unquote(upload_name) + else: + upload_name = urlparse.unquote(upload_name.encode('ascii')) try: - upload_name = upload_name.decode("utf-8") + if isinstance(upload_name, bytes): + upload_name = upload_name.decode("utf-8") except UnicodeDecodeError: upload_name = None else: @@ -488,10 +492,11 @@ class MediaRepository(object): )) thumbnailer = Thumbnailer(input_path) - t_byte_source = yield make_deferred_yieldable(threads.deferToThread( + t_byte_source = yield logcontext.defer_to_thread( + self.hs.get_reactor(), self._generate_thumbnail, thumbnailer, t_width, t_height, t_method, t_type - )) + ) if t_byte_source: try: @@ -530,10 +535,11 @@ class MediaRepository(object): )) thumbnailer = Thumbnailer(input_path) - t_byte_source = yield make_deferred_yieldable(threads.deferToThread( + t_byte_source = yield logcontext.defer_to_thread( + self.hs.get_reactor(), self._generate_thumbnail, thumbnailer, t_width, t_height, t_method, t_type - )) + ) if t_byte_source: try: @@ -616,15 +622,17 @@ class MediaRepository(object): for (t_width, t_height, t_type), t_method in iteritems(thumbnails): # Generate the thumbnail if t_method == "crop": - t_byte_source = yield make_deferred_yieldable(threads.deferToThread( + t_byte_source = yield logcontext.defer_to_thread( + self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type, - )) + ) elif t_method == "scale": - t_byte_source = yield make_deferred_yieldable(threads.deferToThread( + t_byte_source = yield logcontext.defer_to_thread( + self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type, - )) + ) else: logger.error("Unrecognized method: %r", t_method) continue @@ -755,14 +763,15 @@ class MediaRepositoryResource(Resource): Resource.__init__(self) media_repo = hs.get_media_repository() - self.putChild("upload", UploadResource(hs, media_repo)) - self.putChild("download", DownloadResource(hs, media_repo)) - self.putChild("thumbnail", ThumbnailResource( + + self.putChild(b"upload", UploadResource(hs, media_repo)) + self.putChild(b"download", DownloadResource(hs, media_repo)) + self.putChild(b"thumbnail", ThumbnailResource( hs, media_repo, media_repo.media_storage, )) - self.putChild("identicon", IdenticonResource()) + self.putChild(b"identicon", IdenticonResource()) if hs.config.url_preview_enabled: - self.putChild("preview_url", PreviewUrlResource( + self.putChild(b"preview_url", PreviewUrlResource( hs, media_repo, media_repo.media_storage, )) - self.putChild("config", MediaConfigResource(hs)) + self.putChild(b"config", MediaConfigResource(hs)) diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index a6189224ee..896078fe76 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -21,9 +21,10 @@ import sys import six -from twisted.internet import defer, threads +from twisted.internet import defer from twisted.protocols.basic import FileSender +from synapse.util import logcontext from synapse.util.file_consumer import BackgroundFileConsumer from synapse.util.logcontext import make_deferred_yieldable @@ -64,9 +65,10 @@ class MediaStorage(object): with self.store_into_file(file_info) as (f, fname, finish_cb): # Write to the main repository - yield make_deferred_yieldable(threads.deferToThread( + yield logcontext.defer_to_thread( + self.hs.get_reactor(), _write_file_synchronously, source, f, - )) + ) yield finish_cb() defer.returnValue(fname) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 778ef97337..1a7bfd6b56 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -79,7 +79,6 @@ class PreviewUrlResource(Resource): # don't spider URLs more often than once an hour expiry_ms=60 * 60 * 1000, ) - self._cache.start() self._cleaner_loop = self.clock.looping_call( self._start_expire_url_cache_data, 10 * 1000, @@ -261,7 +260,7 @@ class PreviewUrlResource(Resource): logger.debug("Calculated OG for %s as %s" % (url, og)) - jsonog = json.dumps(og) + jsonog = json.dumps(og).encode('utf8') # store OG in history-aware DB cache yield self.store.store_url_cache( @@ -301,20 +300,20 @@ class PreviewUrlResource(Resource): logger.warn("Error downloading %s: %r", url, e) raise SynapseError( 500, "Failed to download content: %s" % ( - traceback.format_exception_only(sys.exc_type, e), + traceback.format_exception_only(sys.exc_info()[0], e), ), Codes.UNKNOWN, ) yield finish() try: - if "Content-Type" in headers: - media_type = headers["Content-Type"][0] + if b"Content-Type" in headers: + media_type = headers[b"Content-Type"][0].decode('ascii') else: media_type = "application/octet-stream" time_now_ms = self.clock.time_msec() - content_disposition = headers.get("Content-Disposition", None) + content_disposition = headers.get(b"Content-Disposition", None) if content_disposition: _, params = cgi.parse_header(content_disposition[0],) download_name = None @@ -597,10 +596,13 @@ def _iterate_over_text(tree, *tags_to_ignore): # to be returned. elements = iter([tree]) while True: - el = next(elements) + el = next(elements, None) + if el is None: + return + if isinstance(el, string_types): yield el - elif el is not None and el.tag not in tags_to_ignore: + elif el.tag not in tags_to_ignore: # el.text is the text before the first child, so we can immediately # return it if the text exists. if el.text: @@ -672,7 +674,7 @@ def summarize_paragraphs(text_nodes, min_size=200, max_size=500): # This splits the paragraph into words, but keeping the # (preceeding) whitespace intact so we can easily concat # words back together. - for match in re.finditer("\s*\S+", description): + for match in re.finditer(r"\s*\S+", description): word = match.group() # Keep adding words while the total length is less than diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 7b9f8b4d79..5aa03031f6 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -17,9 +17,10 @@ import logging import os import shutil -from twisted.internet import defer, threads +from twisted.internet import defer from synapse.config._base import Config +from synapse.util import logcontext from synapse.util.logcontext import run_in_background from .media_storage import FileResponder @@ -120,7 +121,8 @@ class FileStorageProviderBackend(StorageProvider): if not os.path.exists(dirname): os.makedirs(dirname) - return threads.deferToThread( + return logcontext.defer_to_thread( + self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname, ) diff --git a/synapse/server.py b/synapse/server.py index 26228d8c72..cf6b872cbd 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -19,6 +19,7 @@ # partial one for unit test mocking. # Imports required for the default HomeServer() implementation +import abc import logging from twisted.enterprise import adbapi @@ -50,13 +51,14 @@ from synapse.handlers.deactivate_account import DeactivateAccountHandler from synapse.handlers.device import DeviceHandler from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.e2e_keys import E2eKeysHandler +from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.groups_local import GroupsLocalHandler from synapse.handlers.initial_sync import InitialSyncHandler from synapse.handlers.message import EventCreationHandler, MessageHandler from synapse.handlers.pagination import PaginationHandler from synapse.handlers.presence import PresenceHandler -from synapse.handlers.profile import ProfileHandler +from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.room import RoomContextHandler, RoomCreationHandler @@ -81,7 +83,6 @@ from synapse.server_notices.server_notices_manager import ServerNoticesManager from synapse.server_notices.server_notices_sender import ServerNoticesSender from synapse.server_notices.worker_server_notices_sender import WorkerServerNoticesSender from synapse.state import StateHandler, StateResolutionHandler -from synapse.storage import DataStore from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor @@ -111,6 +112,8 @@ class HomeServer(object): config (synapse.config.homeserver.HomeserverConfig): """ + __metaclass__ = abc.ABCMeta + DEPENDENCIES = [ 'http_client', 'db_pool', @@ -128,6 +131,7 @@ class HomeServer(object): 'auth_handler', 'device_handler', 'e2e_keys_handler', + 'e2e_room_keys_handler', 'event_handler', 'event_stream_handler', 'initial_sync_handler', @@ -172,6 +176,11 @@ class HomeServer(object): 'room_context_handler', ] + # This is overridden in derived application classes + # (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be + # instantiated during setup() for future return by get_datastore() + DATASTORE_CLASS = abc.abstractproperty() + def __init__(self, hostname, reactor=None, **kwargs): """ Args: @@ -188,13 +197,17 @@ class HomeServer(object): self.distributor = Distributor() self.ratelimiter = Ratelimiter() + self.datastore = None + # Other kwargs are explicit dependencies for depname in kwargs: setattr(self, depname, kwargs[depname]) def setup(self): logger.info("Setting up.") - self.datastore = DataStore(self.get_db_conn(), self) + with self.get_db_conn() as conn: + self.datastore = self.DATASTORE_CLASS(conn, self) + conn.commit() logger.info("Finished setting up.") def get_reactor(self): @@ -289,6 +302,9 @@ class HomeServer(object): def build_e2e_keys_handler(self): return E2eKeysHandler(self) + def build_e2e_room_keys_handler(self): + return E2eRoomKeysHandler(self) + def build_application_service_api(self): return ApplicationServiceApi(self) @@ -308,7 +324,10 @@ class HomeServer(object): return InitialSyncHandler(self) def build_profile_handler(self): - return ProfileHandler(self) + if self.config.worker_app: + return BaseProfileHandler(self) + else: + return MasterProfileHandler(self) def build_event_creation_handler(self): return EventCreationHandler(self) diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py new file mode 100644 index 0000000000..af15cba0ee --- /dev/null +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -0,0 +1,203 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from six import iteritems + +from twisted.internet import defer + +from synapse.api.constants import ( + EventTypes, + ServerNoticeLimitReached, + ServerNoticeMsgType, +) +from synapse.api.errors import AuthError, ResourceLimitError, SynapseError +from synapse.server_notices.server_notices_manager import SERVER_NOTICE_ROOM_TAG + +logger = logging.getLogger(__name__) + + +class ResourceLimitsServerNotices(object): + """ Keeps track of whether the server has reached it's resource limit and + ensures that the client is kept up to date. + """ + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ + self._server_notices_manager = hs.get_server_notices_manager() + self._store = hs.get_datastore() + self._auth = hs.get_auth() + self._config = hs.config + self._resouce_limited = False + self._message_handler = hs.get_message_handler() + self._state = hs.get_state_handler() + + self._notifier = hs.get_notifier() + + @defer.inlineCallbacks + def maybe_send_server_notice_to_user(self, user_id): + """Check if we need to send a notice to this user, this will be true in + two cases. + 1. The server has reached its limit does not reflect this + 2. The room state indicates that the server has reached its limit when + actually the server is fine + + Args: + user_id (str): user to check + + Returns: + Deferred + """ + if self._config.hs_disabled is True: + return + + if self._config.limit_usage_by_mau is False: + return + + if not self._server_notices_manager.is_enabled(): + # Don't try and send server notices unles they've been enabled + return + + timestamp = yield self._store.user_last_seen_monthly_active(user_id) + if timestamp is None: + # This user will be blocked from receiving the notice anyway. + # In practice, not sure we can ever get here + return + + # Determine current state of room + + room_id = yield self._server_notices_manager.get_notice_room_for_user(user_id) + + if not room_id: + logger.warn("Failed to get server notices room") + return + + yield self._check_and_set_tags(user_id, room_id) + currently_blocked, ref_events = yield self._is_room_currently_blocked(room_id) + + try: + # Normally should always pass in user_id if you have it, but in + # this case are checking what would happen to other users if they + # were to arrive. + try: + yield self._auth.check_auth_blocking() + is_auth_blocking = False + except ResourceLimitError as e: + is_auth_blocking = True + event_content = e.msg + event_limit_type = e.limit_type + + if currently_blocked and not is_auth_blocking: + # Room is notifying of a block, when it ought not to be. + # Remove block notification + content = { + "pinned": ref_events + } + yield self._server_notices_manager.send_notice( + user_id, content, EventTypes.Pinned, '', + ) + + elif not currently_blocked and is_auth_blocking: + # Room is not notifying of a block, when it ought to be. + # Add block notification + content = { + 'body': event_content, + 'msgtype': ServerNoticeMsgType, + 'server_notice_type': ServerNoticeLimitReached, + 'admin_contact': self._config.admin_contact, + 'limit_type': event_limit_type + } + event = yield self._server_notices_manager.send_notice( + user_id, content, EventTypes.Message, + ) + + content = { + "pinned": [ + event.event_id, + ] + } + yield self._server_notices_manager.send_notice( + user_id, content, EventTypes.Pinned, '', + ) + + except SynapseError as e: + logger.error("Error sending resource limits server notice: %s", e) + + @defer.inlineCallbacks + def _check_and_set_tags(self, user_id, room_id): + """ + Since server notices rooms were originally not with tags, + important to check that tags have been set correctly + Args: + user_id(str): the user in question + room_id(str): the server notices room for that user + """ + tags = yield self._store.get_tags_for_room(user_id, room_id) + need_to_set_tag = True + if tags: + if SERVER_NOTICE_ROOM_TAG in tags: + # tag already present, nothing to do here + need_to_set_tag = False + if need_to_set_tag: + max_id = yield self._store.add_tag_to_room( + user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} + ) + self._notifier.on_new_event( + "account_data_key", max_id, users=[user_id] + ) + + @defer.inlineCallbacks + def _is_room_currently_blocked(self, room_id): + """ + Determines if the room is currently blocked + + Args: + room_id(str): The room id of the server notices room + + Returns: + + bool: Is the room currently blocked + list: The list of pinned events that are unrelated to limit blocking + This list can be used as a convenience in the case where the block + is to be lifted and the remaining pinned event references need to be + preserved + """ + currently_blocked = False + pinned_state_event = None + try: + pinned_state_event = yield self._state.get_current_state( + room_id, event_type=EventTypes.Pinned + ) + except AuthError: + # The user has yet to join the server notices room + pass + + referenced_events = [] + if pinned_state_event is not None: + referenced_events = list(pinned_state_event.content.get('pinned', [])) + + events = yield self._store.get_events(referenced_events) + for event_id, event in iteritems(events): + if event.type != EventTypes.Message: + continue + if event.content.get("msgtype") == ServerNoticeMsgType: + currently_blocked = True + # remove event in case we need to disable blocking later on. + if event_id in referenced_events: + referenced_events.remove(event.event_id) + + defer.returnValue((currently_blocked, referenced_events)) diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index a26deace53..c5cc6d728e 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -22,6 +22,8 @@ from synapse.util.caches.descriptors import cachedInlineCallbacks logger = logging.getLogger(__name__) +SERVER_NOTICE_ROOM_TAG = "m.server_notice" + class ServerNoticesManager(object): def __init__(self, hs): @@ -37,6 +39,8 @@ class ServerNoticesManager(object): self._event_creation_handler = hs.get_event_creation_handler() self._is_mine_id = hs.is_mine_id + self._notifier = hs.get_notifier() + def is_enabled(self): """Checks if server notices are enabled on this server. @@ -46,7 +50,10 @@ class ServerNoticesManager(object): return self._config.server_notices_mxid is not None @defer.inlineCallbacks - def send_notice(self, user_id, event_content): + def send_notice( + self, user_id, event_content, + type=EventTypes.Message, state_key=None + ): """Send a notice to the given user Creates the server notices room, if none exists. @@ -54,9 +61,11 @@ class ServerNoticesManager(object): Args: user_id (str): mxid of user to send event to. event_content (dict): content of event to send + type(EventTypes): type of event + is_state_event(bool): Is the event a state event Returns: - Deferred[None] + Deferred[FrozenEvent] """ room_id = yield self.get_notice_room_for_user(user_id) @@ -65,15 +74,20 @@ class ServerNoticesManager(object): logger.info("Sending server notice to %s", user_id) - yield self._event_creation_handler.create_and_send_nonmember_event( - requester, { - "type": EventTypes.Message, - "room_id": room_id, - "sender": system_mxid, - "content": event_content, - }, - ratelimit=False, + event_dict = { + "type": type, + "room_id": room_id, + "sender": system_mxid, + "content": event_content, + } + + if state_key is not None: + event_dict['state_key'] = state_key + + res = yield self._event_creation_handler.create_and_send_nonmember_event( + requester, event_dict, ratelimit=False, ) + defer.returnValue(res) @cachedInlineCallbacks() def get_notice_room_for_user(self, user_id): @@ -142,5 +156,12 @@ class ServerNoticesManager(object): ) room_id = info['room_id'] + max_id = yield self._store.add_tag_to_room( + user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}, + ) + self._notifier.on_new_event( + "account_data_key", max_id, users=[user_id] + ) + logger.info("Created server notices room %s for %s", room_id, user_id) defer.returnValue(room_id) diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py index 5d23965f34..6121b2f267 100644 --- a/synapse/server_notices/server_notices_sender.py +++ b/synapse/server_notices/server_notices_sender.py @@ -12,7 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + from synapse.server_notices.consent_server_notices import ConsentServerNotices +from synapse.server_notices.resource_limits_server_notices import ( + ResourceLimitsServerNotices, +) class ServerNoticesSender(object): @@ -25,34 +30,34 @@ class ServerNoticesSender(object): Args: hs (synapse.server.HomeServer): """ - # todo: it would be nice to make this more dynamic - self._consent_server_notices = ConsentServerNotices(hs) + self._server_notices = ( + ConsentServerNotices(hs), + ResourceLimitsServerNotices(hs) + ) + @defer.inlineCallbacks def on_user_syncing(self, user_id): """Called when the user performs a sync operation. Args: user_id (str): mxid of user who synced - - Returns: - Deferred """ - return self._consent_server_notices.maybe_send_server_notice_to_user( - user_id, - ) + for sn in self._server_notices: + yield sn.maybe_send_server_notice_to_user( + user_id, + ) + @defer.inlineCallbacks def on_user_ip(self, user_id): """Called on the master when a worker process saw a client request. Args: user_id (str): mxid - - Returns: - Deferred """ # The synchrotrons use a stubbed version of ServerNoticesSender, so # we check for notices to send to the user in on_user_ip as well as # in on_user_syncing - return self._consent_server_notices.maybe_send_server_notice_to_user( - user_id, - ) + for sn in self._server_notices: + yield sn.maybe_send_server_notice_to_user( + user_id, + ) diff --git a/synapse/state.py b/synapse/state/__init__.py index 0f2bedb694..9b40b18d5b 100644 --- a/synapse/state.py +++ b/synapse/state/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,21 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import hashlib import logging from collections import namedtuple -from six import iteritems, iterkeys, itervalues +from six import iteritems, itervalues +import attr from frozendict import frozendict from twisted.internet import defer -from synapse import event_auth -from synapse.api.constants import EventTypes -from synapse.api.errors import AuthError +from synapse.api.constants import EventTypes, RoomVersions from synapse.events.snapshot import EventContext +from synapse.state import v1, v2 from synapse.util.async_helpers import Linearizer from synapse.util.caches import get_cache_factor_for from synapse.util.caches.expiringcache import ExpiringCache @@ -97,10 +96,6 @@ class StateHandler(object): self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() - def start_caching(self): - # TODO: remove this shim - self._state_resolution_handler.start_caching() - @defer.inlineCallbacks def get_current_state(self, room_id, event_type=None, state_key="", latest_event_ids=None): @@ -264,6 +259,7 @@ class StateHandler(object): defer.returnValue(context) logger.debug("calling resolve_state_groups from compute_event_context") + entry = yield self.resolve_state_groups_for_events( event.room_id, [e for e, _ in event.prev_events], ) @@ -338,8 +334,11 @@ class StateHandler(object): event, resolves conflicts between them and returns them. Args: - room_id (str): - event_ids (list[str]): + room_id (str) + event_ids (list[str]) + explicit_room_version (str|None): If set uses the the given room + version to choose the resolution algorithm. If None, then + checks the database for room version. Returns: Deferred[_StateCacheEntry]: resolved state @@ -353,7 +352,12 @@ class StateHandler(object): room_id, event_ids ) - if len(state_groups_ids) == 1: + if len(state_groups_ids) == 0: + defer.returnValue(_StateCacheEntry( + state={}, + state_group=None, + )) + elif len(state_groups_ids) == 1: name, state_list = list(state_groups_ids.items()).pop() prev_group, delta_ids = yield self.store.get_state_group_delta(name) @@ -365,17 +369,16 @@ class StateHandler(object): delta_ids=delta_ids, )) + room_version = yield self.store.get_room_version(room_id) + result = yield self._state_resolution_handler.resolve_state_groups( - room_id, state_groups_ids, None, self._state_map_factory, + room_id, room_version, state_groups_ids, None, + state_res_store=StateResolutionStore(self.store), ) defer.returnValue(result) - def _state_map_factory(self, ev_ids): - return self.store.get_events( - ev_ids, get_prev_content=False, check_redacted=False, - ) - - def resolve_events(self, state_sets, event): + @defer.inlineCallbacks + def resolve_events(self, room_version, state_sets, event): logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) @@ -391,13 +394,17 @@ class StateHandler(object): } with Measure(self.clock, "state._resolve_events"): - new_state = resolve_events_with_state_map(state_set_ids, state_map) + new_state = yield resolve_events_with_store( + room_version, state_set_ids, + event_map=state_map, + state_res_store=StateResolutionStore(self.store), + ) new_state = { key: state_map[ev_id] for key, ev_id in iteritems(new_state) } - return new_state + defer.returnValue(new_state) class StateResolutionHandler(object): @@ -413,9 +420,6 @@ class StateResolutionHandler(object): self._state_cache = None self.resolve_linearizer = Linearizer(name="state_resolve_lock") - def start_caching(self): - logger.debug("start_caching") - self._state_cache = ExpiringCache( cache_name="state_cache", clock=self.clock, @@ -425,12 +429,10 @@ class StateResolutionHandler(object): reset_expiry_on_get=True, ) - self._state_cache.start() - @defer.inlineCallbacks @log_function def resolve_state_groups( - self, room_id, state_groups_ids, event_map, state_map_factory, + self, room_id, room_version, state_groups_ids, event_map, state_res_store, ): """Resolves conflicts between a set of state groups @@ -439,6 +441,7 @@ class StateResolutionHandler(object): Args: room_id (str): room we are resolving for (used for logging) + room_version (str): version of the room state_groups_ids (dict[int, dict[(str, str), str]]): map from state group id to the state in that state group (where 'state' is a map from state key to event id) @@ -447,9 +450,11 @@ class StateResolutionHandler(object): a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing - events will be requested via state_map_factory. + events will be requested via state_res_store. - If None, all events will be fetched via state_map_factory. + If None, all events will be fetched via state_res_store. + + state_res_store (StateResolutionStore) Returns: Deferred[_StateCacheEntry]: resolved state @@ -473,10 +478,10 @@ class StateResolutionHandler(object): # start by assuming we won't have any conflicted state, and build up the new # state map by iterating through the state groups. If we discover a conflict, - # we give up and instead use `resolve_events_with_factory`. + # we give up and instead use `resolve_events_with_store`. # # XXX: is this actually worthwhile, or should we just let - # resolve_events_with_factory do it? + # resolve_events_with_store do it? new_state = {} conflicted_state = False for st in itervalues(state_groups_ids): @@ -491,10 +496,11 @@ class StateResolutionHandler(object): if conflicted_state: logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): - new_state = yield resolve_events_with_factory( + new_state = yield resolve_events_with_store( + room_version, list(itervalues(state_groups_ids)), event_map=event_map, - state_map_factory=state_map_factory, + state_res_store=state_res_store, ) # if the new state matches any of the input state groups, we can @@ -575,94 +581,11 @@ def _make_state_cache_entry( ) -def _ordered_events(events): - def key_func(e): - return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest() - - return sorted(events, key=key_func) - - -def resolve_events_with_state_map(state_sets, state_map): - """ - Args: - state_sets(list): List of dicts of (type, state_key) -> event_id, - which are the different state groups to resolve. - state_map(dict): a dict from event_id to event, for all events in - state_sets. - - Returns - dict[(str, str), str]: - a map from (type, state_key) to event_id. +def resolve_events_with_store(room_version, state_sets, event_map, state_res_store): """ - if len(state_sets) == 1: - return state_sets[0] - - unconflicted_state, conflicted_state = _seperate( - state_sets, - ) - - auth_events = _create_auth_events_from_maps( - unconflicted_state, conflicted_state, state_map - ) - - return _resolve_with_state( - unconflicted_state, conflicted_state, auth_events, state_map - ) - - -def _seperate(state_sets): - """Takes the state_sets and figures out which keys are conflicted and - which aren't. i.e., which have multiple different event_ids associated - with them in different state sets. - Args: - state_sets(iterable[dict[(str, str), str]]): - List of dicts of (type, state_key) -> event_id, which are the - different state groups to resolve. - - Returns: - (dict[(str, str), str], dict[(str, str), set[str]]): - A tuple of (unconflicted_state, conflicted_state), where: - - unconflicted_state is a dict mapping (type, state_key)->event_id - for unconflicted state keys. - - conflicted_state is a dict mapping (type, state_key) to a set of - event ids for conflicted state keys. - """ - state_set_iterator = iter(state_sets) - unconflicted_state = dict(next(state_set_iterator)) - conflicted_state = {} - - for state_set in state_set_iterator: - for key, value in iteritems(state_set): - # Check if there is an unconflicted entry for the state key. - unconflicted_value = unconflicted_state.get(key) - if unconflicted_value is None: - # There isn't an unconflicted entry so check if there is a - # conflicted entry. - ls = conflicted_state.get(key) - if ls is None: - # There wasn't a conflicted entry so haven't seen this key before. - # Therefore it isn't conflicted yet. - unconflicted_state[key] = value - else: - # This key is already conflicted, add our value to the conflict set. - ls.add(value) - elif unconflicted_value != value: - # If the unconflicted value is not the same as our value then we - # have a new conflict. So move the key from the unconflicted_state - # to the conflicted state. - conflicted_state[key] = {value, unconflicted_value} - unconflicted_state.pop(key, None) - - return unconflicted_state, conflicted_state - + room_version(str): Version of the room -@defer.inlineCallbacks -def resolve_events_with_factory(state_sets, event_map, state_map_factory): - """ - Args: state_sets(list): List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. @@ -674,193 +597,74 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory): If None, all events will be fetched via state_map_factory. - state_map_factory(func): will be called - with a list of event_ids that are needed, and should return with - a Deferred of dict of event_id to event. + state_res_store (StateResolutionStore) Returns Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. """ - if len(state_sets) == 1: - defer.returnValue(state_sets[0]) - - unconflicted_state, conflicted_state = _seperate( - state_sets, - ) - - needed_events = set( - event_id - for event_ids in itervalues(conflicted_state) - for event_id in event_ids - ) - if event_map is not None: - needed_events -= set(iterkeys(event_map)) - - logger.info("Asking for %d conflicted events", len(needed_events)) - - # dict[str, FrozenEvent]: a map from state event id to event. Only includes - # the state events which are in conflict (and those in event_map) - state_map = yield state_map_factory(needed_events) - if event_map is not None: - state_map.update(event_map) - - # get the ids of the auth events which allow us to authenticate the - # conflicted state, picking only from the unconflicting state. - # - # dict[(str, str), str]: a map from state key to event id - auth_events = _create_auth_events_from_maps( - unconflicted_state, conflicted_state, state_map - ) - - new_needed_events = set(itervalues(auth_events)) - new_needed_events -= needed_events - if event_map is not None: - new_needed_events -= set(iterkeys(event_map)) - - logger.info("Asking for %d auth events", len(new_needed_events)) - - state_map_new = yield state_map_factory(new_needed_events) - state_map.update(state_map_new) - - defer.returnValue(_resolve_with_state( - unconflicted_state, conflicted_state, auth_events, state_map - )) - - -def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map): - auth_events = {} - for event_ids in itervalues(conflicted_state): - for event_id in event_ids: - if event_id in state_map: - keys = event_auth.auth_types_for_event(state_map[event_id]) - for key in keys: - if key not in auth_events: - event_id = unconflicted_state.get(key, None) - if event_id: - auth_events[key] = event_id - return auth_events - - -def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids, - state_map): - conflicted_state = {} - for key, event_ids in iteritems(conflicted_state_ids): - events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map] - if len(events) > 1: - conflicted_state[key] = events - elif len(events) == 1: - unconflicted_state_ids[key] = events[0].event_id - - auth_events = { - key: state_map[ev_id] - for key, ev_id in iteritems(auth_event_ids) - if ev_id in state_map - } - - try: - resolved_state = _resolve_state_events( - conflicted_state, auth_events + if room_version == RoomVersions.V1: + return v1.resolve_events_with_store( + state_sets, event_map, state_res_store.get_events, + ) + elif room_version == RoomVersions.VDH_TEST: + return v2.resolve_events_with_store( + state_sets, event_map, state_res_store, + ) + else: + # This should only happen if we added a version but forgot to add it to + # the list above. + raise Exception( + "No state resolution algorithm defined for version %r" % (room_version,) ) - except Exception: - logger.exception("Failed to resolve state") - raise - new_state = unconflicted_state_ids - for key, event in iteritems(resolved_state): - new_state[key] = event.event_id - return new_state +@attr.s +class StateResolutionStore(object): + """Interface that allows state resolution algorithms to access the database + in well defined way. + Args: + store (DataStore) + """ -def _resolve_state_events(conflicted_state, auth_events): - """ This is where we actually decide which of the conflicted state to - use. + store = attr.ib() - We resolve conflicts in the following order: - 1. power levels - 2. join rules - 3. memberships - 4. other events. - """ - resolved_state = {} - if POWER_KEY in conflicted_state: - events = conflicted_state[POWER_KEY] - logger.debug("Resolving conflicted power levels %r", events) - resolved_state[POWER_KEY] = _resolve_auth_events( - events, auth_events) - - auth_events.update(resolved_state) - - for key, events in iteritems(conflicted_state): - if key[0] == EventTypes.JoinRules: - logger.debug("Resolving conflicted join rules %r", events) - resolved_state[key] = _resolve_auth_events( - events, - auth_events - ) + def get_events(self, event_ids, allow_rejected=False): + """Get events from the database - auth_events.update(resolved_state) + Args: + event_ids (list): The event_ids of the events to fetch + allow_rejected (bool): If True return rejected events. - for key, events in iteritems(conflicted_state): - if key[0] == EventTypes.Member: - logger.debug("Resolving conflicted member lists %r", events) - resolved_state[key] = _resolve_auth_events( - events, - auth_events - ) + Returns: + Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. + """ - auth_events.update(resolved_state) + return self.store.get_events( + event_ids, + check_redacted=False, + get_prev_content=False, + allow_rejected=allow_rejected, + ) - for key, events in iteritems(conflicted_state): - if key not in resolved_state: - logger.debug("Resolving conflicted state %r:%r", key, events) - resolved_state[key] = _resolve_normal_events( - events, auth_events - ) + def get_auth_chain(self, event_ids): + """Gets the full auth chain for a set of events (including rejected + events). - return resolved_state + Includes the given event IDs in the result. + Note that: + 1. All events must be state events. + 2. For v1 rooms this may not have the full auth chain in the + presence of rejected events -def _resolve_auth_events(events, auth_events): - reverse = [i for i in reversed(_ordered_events(events))] + Args: + event_ids (list): The event IDs of the events to fetch the auth + chain for. Must be state events. - auth_keys = set( - key - for event in events - for key in event_auth.auth_types_for_event(event) - ) + Returns: + Deferred[list[str]]: List of event IDs of the auth chain. + """ - new_auth_events = {} - for key in auth_keys: - auth_event = auth_events.get(key, None) - if auth_event: - new_auth_events[key] = auth_event - - auth_events = new_auth_events - - prev_event = reverse[0] - for event in reverse[1:]: - auth_events[(prev_event.type, prev_event.state_key)] = prev_event - try: - # The signatures have already been checked at this point - event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) - prev_event = event - except AuthError: - return prev_event - - return event - - -def _resolve_normal_events(events, auth_events): - for event in _ordered_events(events): - try: - # The signatures have already been checked at this point - event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) - return event - except AuthError: - pass - - # Use the last event (the one with the least depth) if they all fail - # the auth check. - return event + return self.store.get_auth_chain_ids(event_ids, include_given=True) diff --git a/synapse/state/v1.py b/synapse/state/v1.py new file mode 100644 index 0000000000..70a981f4a2 --- /dev/null +++ b/synapse/state/v1.py @@ -0,0 +1,303 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +import logging + +from six import iteritems, iterkeys, itervalues + +from twisted.internet import defer + +from synapse import event_auth +from synapse.api.constants import EventTypes +from synapse.api.errors import AuthError + +logger = logging.getLogger(__name__) + + +POWER_KEY = (EventTypes.PowerLevels, "") + + +@defer.inlineCallbacks +def resolve_events_with_store(state_sets, event_map, state_map_factory): + """ + Args: + state_sets(list): List of dicts of (type, state_key) -> event_id, + which are the different state groups to resolve. + + event_map(dict[str,FrozenEvent]|None): + a dict from event_id to event, for any events that we happen to + have in flight (eg, those currently being persisted). This will be + used as a starting point fof finding the state we need; any missing + events will be requested via state_map_factory. + + If None, all events will be fetched via state_map_factory. + + state_map_factory(func): will be called + with a list of event_ids that are needed, and should return with + a Deferred of dict of event_id to event. + + Returns + Deferred[dict[(str, str), str]]: + a map from (type, state_key) to event_id. + """ + if len(state_sets) == 1: + defer.returnValue(state_sets[0]) + + unconflicted_state, conflicted_state = _seperate( + state_sets, + ) + + needed_events = set( + event_id + for event_ids in itervalues(conflicted_state) + for event_id in event_ids + ) + needed_event_count = len(needed_events) + if event_map is not None: + needed_events -= set(iterkeys(event_map)) + + logger.info( + "Asking for %d/%d conflicted events", + len(needed_events), + needed_event_count, + ) + + # dict[str, FrozenEvent]: a map from state event id to event. Only includes + # the state events which are in conflict (and those in event_map) + state_map = yield state_map_factory(needed_events) + if event_map is not None: + state_map.update(event_map) + + # get the ids of the auth events which allow us to authenticate the + # conflicted state, picking only from the unconflicting state. + # + # dict[(str, str), str]: a map from state key to event id + auth_events = _create_auth_events_from_maps( + unconflicted_state, conflicted_state, state_map + ) + + new_needed_events = set(itervalues(auth_events)) + new_needed_event_count = len(new_needed_events) + new_needed_events -= needed_events + if event_map is not None: + new_needed_events -= set(iterkeys(event_map)) + + logger.info( + "Asking for %d/%d auth events", + len(new_needed_events), + new_needed_event_count, + ) + + state_map_new = yield state_map_factory(new_needed_events) + state_map.update(state_map_new) + + defer.returnValue(_resolve_with_state( + unconflicted_state, conflicted_state, auth_events, state_map + )) + + +def _seperate(state_sets): + """Takes the state_sets and figures out which keys are conflicted and + which aren't. i.e., which have multiple different event_ids associated + with them in different state sets. + + Args: + state_sets(iterable[dict[(str, str), str]]): + List of dicts of (type, state_key) -> event_id, which are the + different state groups to resolve. + + Returns: + (dict[(str, str), str], dict[(str, str), set[str]]): + A tuple of (unconflicted_state, conflicted_state), where: + + unconflicted_state is a dict mapping (type, state_key)->event_id + for unconflicted state keys. + + conflicted_state is a dict mapping (type, state_key) to a set of + event ids for conflicted state keys. + """ + state_set_iterator = iter(state_sets) + unconflicted_state = dict(next(state_set_iterator)) + conflicted_state = {} + + for state_set in state_set_iterator: + for key, value in iteritems(state_set): + # Check if there is an unconflicted entry for the state key. + unconflicted_value = unconflicted_state.get(key) + if unconflicted_value is None: + # There isn't an unconflicted entry so check if there is a + # conflicted entry. + ls = conflicted_state.get(key) + if ls is None: + # There wasn't a conflicted entry so haven't seen this key before. + # Therefore it isn't conflicted yet. + unconflicted_state[key] = value + else: + # This key is already conflicted, add our value to the conflict set. + ls.add(value) + elif unconflicted_value != value: + # If the unconflicted value is not the same as our value then we + # have a new conflict. So move the key from the unconflicted_state + # to the conflicted state. + conflicted_state[key] = {value, unconflicted_value} + unconflicted_state.pop(key, None) + + return unconflicted_state, conflicted_state + + +def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map): + auth_events = {} + for event_ids in itervalues(conflicted_state): + for event_id in event_ids: + if event_id in state_map: + keys = event_auth.auth_types_for_event(state_map[event_id]) + for key in keys: + if key not in auth_events: + event_id = unconflicted_state.get(key, None) + if event_id: + auth_events[key] = event_id + return auth_events + + +def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids, + state_map): + conflicted_state = {} + for key, event_ids in iteritems(conflicted_state_ids): + events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map] + if len(events) > 1: + conflicted_state[key] = events + elif len(events) == 1: + unconflicted_state_ids[key] = events[0].event_id + + auth_events = { + key: state_map[ev_id] + for key, ev_id in iteritems(auth_event_ids) + if ev_id in state_map + } + + try: + resolved_state = _resolve_state_events( + conflicted_state, auth_events + ) + except Exception: + logger.exception("Failed to resolve state") + raise + + new_state = unconflicted_state_ids + for key, event in iteritems(resolved_state): + new_state[key] = event.event_id + + return new_state + + +def _resolve_state_events(conflicted_state, auth_events): + """ This is where we actually decide which of the conflicted state to + use. + + We resolve conflicts in the following order: + 1. power levels + 2. join rules + 3. memberships + 4. other events. + """ + resolved_state = {} + if POWER_KEY in conflicted_state: + events = conflicted_state[POWER_KEY] + logger.debug("Resolving conflicted power levels %r", events) + resolved_state[POWER_KEY] = _resolve_auth_events( + events, auth_events) + + auth_events.update(resolved_state) + + for key, events in iteritems(conflicted_state): + if key[0] == EventTypes.JoinRules: + logger.debug("Resolving conflicted join rules %r", events) + resolved_state[key] = _resolve_auth_events( + events, + auth_events + ) + + auth_events.update(resolved_state) + + for key, events in iteritems(conflicted_state): + if key[0] == EventTypes.Member: + logger.debug("Resolving conflicted member lists %r", events) + resolved_state[key] = _resolve_auth_events( + events, + auth_events + ) + + auth_events.update(resolved_state) + + for key, events in iteritems(conflicted_state): + if key not in resolved_state: + logger.debug("Resolving conflicted state %r:%r", key, events) + resolved_state[key] = _resolve_normal_events( + events, auth_events + ) + + return resolved_state + + +def _resolve_auth_events(events, auth_events): + reverse = [i for i in reversed(_ordered_events(events))] + + auth_keys = set( + key + for event in events + for key in event_auth.auth_types_for_event(event) + ) + + new_auth_events = {} + for key in auth_keys: + auth_event = auth_events.get(key, None) + if auth_event: + new_auth_events[key] = auth_event + + auth_events = new_auth_events + + prev_event = reverse[0] + for event in reverse[1:]: + auth_events[(prev_event.type, prev_event.state_key)] = prev_event + try: + # The signatures have already been checked at this point + event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) + prev_event = event + except AuthError: + return prev_event + + return event + + +def _resolve_normal_events(events, auth_events): + for event in _ordered_events(events): + try: + # The signatures have already been checked at this point + event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) + return event + except AuthError: + pass + + # Use the last event (the one with the least depth) if they all fail + # the auth check. + return event + + +def _ordered_events(events): + def key_func(e): + return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest() + + return sorted(events, key=key_func) diff --git a/synapse/state/v2.py b/synapse/state/v2.py new file mode 100644 index 0000000000..5d06f7e928 --- /dev/null +++ b/synapse/state/v2.py @@ -0,0 +1,544 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import heapq +import itertools +import logging + +from six import iteritems, itervalues + +from twisted.internet import defer + +from synapse import event_auth +from synapse.api.constants import EventTypes +from synapse.api.errors import AuthError + +logger = logging.getLogger(__name__) + + +@defer.inlineCallbacks +def resolve_events_with_store(state_sets, event_map, state_res_store): + """Resolves the state using the v2 state resolution algorithm + + Args: + state_sets(list): List of dicts of (type, state_key) -> event_id, + which are the different state groups to resolve. + + event_map(dict[str,FrozenEvent]|None): + a dict from event_id to event, for any events that we happen to + have in flight (eg, those currently being persisted). This will be + used as a starting point fof finding the state we need; any missing + events will be requested via state_res_store. + + If None, all events will be fetched via state_res_store. + + state_res_store (StateResolutionStore) + + Returns + Deferred[dict[(str, str), str]]: + a map from (type, state_key) to event_id. + """ + + logger.debug("Computing conflicted state") + + # First split up the un/conflicted state + unconflicted_state, conflicted_state = _seperate(state_sets) + + if not conflicted_state: + defer.returnValue(unconflicted_state) + + logger.debug("%d conflicted state entries", len(conflicted_state)) + logger.debug("Calculating auth chain difference") + + # Also fetch all auth events that appear in only some of the state sets' + # auth chains. + auth_diff = yield _get_auth_chain_difference( + state_sets, event_map, state_res_store, + ) + + full_conflicted_set = set(itertools.chain( + itertools.chain.from_iterable(itervalues(conflicted_state)), + auth_diff, + )) + + events = yield state_res_store.get_events([ + eid for eid in full_conflicted_set + if eid not in event_map + ], allow_rejected=True) + event_map.update(events) + + full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map) + + logger.debug("%d full_conflicted_set entries", len(full_conflicted_set)) + + # Get and sort all the power events (kicks/bans/etc) + power_events = ( + eid for eid in full_conflicted_set + if _is_power_event(event_map[eid]) + ) + + sorted_power_events = yield _reverse_topological_power_sort( + power_events, + event_map, + state_res_store, + full_conflicted_set, + ) + + logger.debug("sorted %d power events", len(sorted_power_events)) + + # Now sequentially auth each one + resolved_state = yield _iterative_auth_checks( + sorted_power_events, unconflicted_state, event_map, + state_res_store, + ) + + logger.debug("resolved power events") + + # OK, so we've now resolved the power events. Now sort the remaining + # events using the mainline of the resolved power level. + + leftover_events = [ + ev_id + for ev_id in full_conflicted_set + if ev_id not in sorted_power_events + ] + + logger.debug("sorting %d remaining events", len(leftover_events)) + + pl = resolved_state.get((EventTypes.PowerLevels, ""), None) + leftover_events = yield _mainline_sort( + leftover_events, pl, event_map, state_res_store, + ) + + logger.debug("resolving remaining events") + + resolved_state = yield _iterative_auth_checks( + leftover_events, resolved_state, event_map, + state_res_store, + ) + + logger.debug("resolved") + + # We make sure that unconflicted state always still applies. + resolved_state.update(unconflicted_state) + + logger.debug("done") + + defer.returnValue(resolved_state) + + +@defer.inlineCallbacks +def _get_power_level_for_sender(event_id, event_map, state_res_store): + """Return the power level of the sender of the given event according to + their auth events. + + Args: + event_id (str) + event_map (dict[str,FrozenEvent]) + state_res_store (StateResolutionStore) + + Returns: + Deferred[int] + """ + event = yield _get_event(event_id, event_map, state_res_store) + + pl = None + for aid, _ in event.auth_events: + aev = yield _get_event(aid, event_map, state_res_store) + if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): + pl = aev + break + + if pl is None: + # Couldn't find power level. Check if they're the creator of the room + for aid, _ in event.auth_events: + aev = yield _get_event(aid, event_map, state_res_store) + if (aev.type, aev.state_key) == (EventTypes.Create, ""): + if aev.content.get("creator") == event.sender: + defer.returnValue(100) + break + defer.returnValue(0) + + level = pl.content.get("users", {}).get(event.sender) + if level is None: + level = pl.content.get("users_default", 0) + + if level is None: + defer.returnValue(0) + else: + defer.returnValue(int(level)) + + +@defer.inlineCallbacks +def _get_auth_chain_difference(state_sets, event_map, state_res_store): + """Compare the auth chains of each state set and return the set of events + that only appear in some but not all of the auth chains. + + Args: + state_sets (list) + event_map (dict[str,FrozenEvent]) + state_res_store (StateResolutionStore) + + Returns: + Deferred[set[str]]: Set of event IDs + """ + common = set(itervalues(state_sets[0])).intersection( + *(itervalues(s) for s in state_sets[1:]) + ) + + auth_sets = [] + for state_set in state_sets: + auth_ids = set( + eid + for key, eid in iteritems(state_set) + if (key[0] in ( + EventTypes.Member, + EventTypes.ThirdPartyInvite, + ) or key in ( + (EventTypes.PowerLevels, ''), + (EventTypes.Create, ''), + (EventTypes.JoinRules, ''), + )) and eid not in common + ) + + auth_chain = yield state_res_store.get_auth_chain(auth_ids) + auth_ids.update(auth_chain) + + auth_sets.append(auth_ids) + + intersection = set(auth_sets[0]).intersection(*auth_sets[1:]) + union = set().union(*auth_sets) + + defer.returnValue(union - intersection) + + +def _seperate(state_sets): + """Return the unconflicted and conflicted state. This is different than in + the original algorithm, as this defines a key to be conflicted if one of + the state sets doesn't have that key. + + Args: + state_sets (list) + + Returns: + tuple[dict, dict]: A tuple of unconflicted and conflicted state. The + conflicted state dict is a map from type/state_key to set of event IDs + """ + unconflicted_state = {} + conflicted_state = {} + + for key in set(itertools.chain.from_iterable(state_sets)): + event_ids = set(state_set.get(key) for state_set in state_sets) + if len(event_ids) == 1: + unconflicted_state[key] = event_ids.pop() + else: + event_ids.discard(None) + conflicted_state[key] = event_ids + + return unconflicted_state, conflicted_state + + +def _is_power_event(event): + """Return whether or not the event is a "power event", as defined by the + v2 state resolution algorithm + + Args: + event (FrozenEvent) + + Returns: + boolean + """ + if (event.type, event.state_key) in ( + (EventTypes.PowerLevels, ""), + (EventTypes.JoinRules, ""), + (EventTypes.Create, ""), + ): + return True + + if event.type == EventTypes.Member: + if event.membership in ('leave', 'ban'): + return event.sender != event.state_key + + return False + + +@defer.inlineCallbacks +def _add_event_and_auth_chain_to_graph(graph, event_id, event_map, + state_res_store, auth_diff): + """Helper function for _reverse_topological_power_sort that add the event + and its auth chain (that is in the auth diff) to the graph + + Args: + graph (dict[str, set[str]]): A map from event ID to the events auth + event IDs + event_id (str): Event to add to the graph + event_map (dict[str,FrozenEvent]) + state_res_store (StateResolutionStore) + auth_diff (set[str]): Set of event IDs that are in the auth difference. + """ + + state = [event_id] + while state: + eid = state.pop() + graph.setdefault(eid, set()) + + event = yield _get_event(eid, event_map, state_res_store) + for aid, _ in event.auth_events: + if aid in auth_diff: + if aid not in graph: + state.append(aid) + + graph.setdefault(eid, set()).add(aid) + + +@defer.inlineCallbacks +def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_diff): + """Returns a list of the event_ids sorted by reverse topological ordering, + and then by power level and origin_server_ts + + Args: + event_ids (list[str]): The events to sort + event_map (dict[str,FrozenEvent]) + state_res_store (StateResolutionStore) + auth_diff (set[str]): Set of event IDs that are in the auth difference. + + Returns: + Deferred[list[str]]: The sorted list + """ + + graph = {} + for event_id in event_ids: + yield _add_event_and_auth_chain_to_graph( + graph, event_id, event_map, state_res_store, auth_diff, + ) + + event_to_pl = {} + for event_id in graph: + pl = yield _get_power_level_for_sender(event_id, event_map, state_res_store) + event_to_pl[event_id] = pl + + def _get_power_order(event_id): + ev = event_map[event_id] + pl = event_to_pl[event_id] + + return -pl, ev.origin_server_ts, event_id + + # Note: graph is modified during the sort + it = lexicographical_topological_sort( + graph, + key=_get_power_order, + ) + sorted_events = list(it) + + defer.returnValue(sorted_events) + + +@defer.inlineCallbacks +def _iterative_auth_checks(event_ids, base_state, event_map, state_res_store): + """Sequentially apply auth checks to each event in given list, updating the + state as it goes along. + + Args: + event_ids (list[str]): Ordered list of events to apply auth checks to + base_state (dict[tuple[str, str], str]): The set of state to start with + event_map (dict[str,FrozenEvent]) + state_res_store (StateResolutionStore) + + Returns: + Deferred[dict[tuple[str, str], str]]: Returns the final updated state + """ + resolved_state = base_state.copy() + + for event_id in event_ids: + event = event_map[event_id] + + auth_events = {} + for aid, _ in event.auth_events: + ev = yield _get_event(aid, event_map, state_res_store) + + if ev.rejected_reason is None: + auth_events[(ev.type, ev.state_key)] = ev + + for key in event_auth.auth_types_for_event(event): + if key in resolved_state: + ev_id = resolved_state[key] + ev = yield _get_event(ev_id, event_map, state_res_store) + + if ev.rejected_reason is None: + auth_events[key] = event_map[ev_id] + + try: + event_auth.check( + event, auth_events, + do_sig_check=False, + do_size_check=False + ) + + resolved_state[(event.type, event.state_key)] = event_id + except AuthError: + pass + + defer.returnValue(resolved_state) + + +@defer.inlineCallbacks +def _mainline_sort(event_ids, resolved_power_event_id, event_map, + state_res_store): + """Returns a sorted list of event_ids sorted by mainline ordering based on + the given event resolved_power_event_id + + Args: + event_ids (list[str]): Events to sort + resolved_power_event_id (str): The final resolved power level event ID + event_map (dict[str,FrozenEvent]) + state_res_store (StateResolutionStore) + + Returns: + Deferred[list[str]]: The sorted list + """ + mainline = [] + pl = resolved_power_event_id + while pl: + mainline.append(pl) + pl_ev = yield _get_event(pl, event_map, state_res_store) + auth_events = pl_ev.auth_events + pl = None + for aid, _ in auth_events: + ev = yield _get_event(aid, event_map, state_res_store) + if (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""): + pl = aid + break + + mainline_map = {ev_id: i + 1 for i, ev_id in enumerate(reversed(mainline))} + + event_ids = list(event_ids) + + order_map = {} + for ev_id in event_ids: + depth = yield _get_mainline_depth_for_event( + event_map[ev_id], mainline_map, + event_map, state_res_store, + ) + order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id) + + event_ids.sort(key=lambda ev_id: order_map[ev_id]) + + defer.returnValue(event_ids) + + +@defer.inlineCallbacks +def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_store): + """Get the mainline depths for the given event based on the mainline map + + Args: + event (FrozenEvent) + mainline_map (dict[str, int]): Map from event_id to mainline depth for + events in the mainline. + event_map (dict[str,FrozenEvent]) + state_res_store (StateResolutionStore) + + Returns: + Deferred[int] + """ + + # We do an iterative search, replacing `event with the power level in its + # auth events (if any) + while event: + depth = mainline_map.get(event.event_id) + if depth is not None: + defer.returnValue(depth) + + auth_events = event.auth_events + event = None + + for aid, _ in auth_events: + aev = yield _get_event(aid, event_map, state_res_store) + if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): + event = aev + break + + # Didn't find a power level auth event, so we just return 0 + defer.returnValue(0) + + +@defer.inlineCallbacks +def _get_event(event_id, event_map, state_res_store): + """Helper function to look up event in event_map, falling back to looking + it up in the store + + Args: + event_id (str) + event_map (dict[str,FrozenEvent]) + state_res_store (StateResolutionStore) + + Returns: + Deferred[FrozenEvent] + """ + if event_id not in event_map: + events = yield state_res_store.get_events([event_id], allow_rejected=True) + event_map.update(events) + defer.returnValue(event_map[event_id]) + + +def lexicographical_topological_sort(graph, key): + """Performs a lexicographic reverse topological sort on the graph. + + This returns a reverse topological sort (i.e. if node A references B then B + appears before A in the sort), with ties broken lexicographically based on + return value of the `key` function. + + NOTE: `graph` is modified during the sort. + + Args: + graph (dict[str, set[str]]): A representation of the graph where each + node is a key in the dict and its value are the nodes edges. + key (func): A function that takes a node and returns a value that is + comparable and used to order nodes + + Yields: + str: The next node in the topological sort + """ + + # Note, this is basically Kahn's algorithm except we look at nodes with no + # outgoing edges, c.f. + # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm + outdegree_map = graph + reverse_graph = {} + + # Lists of nodes with zero out degree. Is actually a tuple of + # `(key(node), node)` so that sorting does the right thing + zero_outdegree = [] + + for node, edges in iteritems(graph): + if len(edges) == 0: + zero_outdegree.append((key(node), node)) + + reverse_graph.setdefault(node, set()) + for edge in edges: + reverse_graph.setdefault(edge, set()).add(node) + + # heapq is a built in implementation of a sorted queue. + heapq.heapify(zero_outdegree) + + while zero_outdegree: + _, node = heapq.heappop(zero_outdegree) + + for parent in reverse_graph[node]: + out = outdegree_map[parent] + out.discard(node) + if len(out) == 0: + heapq.heappush(zero_outdegree, (key(parent), parent)) + + yield node diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 23b4a8d76d..53c685c173 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -30,6 +30,7 @@ from .appservice import ApplicationServiceStore, ApplicationServiceTransactionSt from .client_ips import ClientIpStore from .deviceinbox import DeviceInboxStore from .directory import DirectoryStore +from .e2e_room_keys import EndToEndRoomKeyStore from .end_to_end_keys import EndToEndKeyStore from .engines import PostgresEngine from .event_federation import EventFederationStore @@ -77,6 +78,7 @@ class DataStore(RoomMemberStore, RoomStore, ApplicationServiceTransactionStore, ReceiptsStore, EndToEndKeyStore, + EndToEndRoomKeyStore, SearchStore, TagsStore, AccountDataStore, diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 08dffd774f..d9d0255d0b 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -17,9 +17,10 @@ import sys import threading import time -from six import iteritems, iterkeys, itervalues -from six.moves import intern, range +from six import PY2, iteritems, iterkeys, itervalues +from six.moves import builtins, intern, range +from canonicaljson import json from prometheus_client import Histogram from twisted.internet import defer @@ -1216,3 +1217,32 @@ class _RollbackButIsFineException(Exception): something went wrong. """ pass + + +def db_to_json(db_content): + """ + Take some data from a database row and return a JSON-decoded object. + + Args: + db_content (memoryview|buffer|bytes|bytearray|unicode) + """ + # psycopg2 on Python 3 returns memoryview objects, which we need to + # cast to bytes to decode + if isinstance(db_content, memoryview): + db_content = db_content.tobytes() + + # psycopg2 on Python 2 returns buffer objects, which we need to cast to + # bytes to decode + if PY2 and isinstance(db_content, builtins.buffer): + db_content = bytes(db_content) + + # Decode it to a Unicode string before feeding it to json.loads, so we + # consistenty get a Unicode-containing object out. + if isinstance(db_content, (bytes, bytearray)): + db_content = db_content.decode('utf8') + + try: + return json.loads(db_content) + except Exception: + logging.warning("Tried to decode '%r' as JSON and failed", db_content) + raise diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 8fc678fa67..9ad17b7c25 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -119,21 +119,25 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): for entry in iteritems(to_update): (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry - self._simple_upsert_txn( - txn, - table="user_ips", - keyvalues={ - "user_id": user_id, - "access_token": access_token, - "ip": ip, - "user_agent": user_agent, - "device_id": device_id, - }, - values={ - "last_seen": last_seen, - }, - lock=False, - ) + try: + self._simple_upsert_txn( + txn, + table="user_ips", + keyvalues={ + "user_id": user_id, + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "device_id": device_id, + }, + values={ + "last_seen": last_seen, + }, + lock=False, + ) + except Exception as e: + # Failed to upsert, log and continue + logger.error("Failed to insert client IP %r: %r", entry, e) @defer.inlineCallbacks def get_last_client_ip_by_device(self, user_id, device_id): diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index 73646da025..e06b0bc56d 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -169,7 +169,7 @@ class DeviceInboxStore(BackgroundUpdateStore): local_by_user_then_device = {} for user_id, messages_by_device in messages_by_user_then_device.items(): messages_json_for_user = {} - devices = messages_by_device.keys() + devices = list(messages_by_device.keys()) if len(devices) == 1 and devices[0] == "*": # Handle wildcard device_ids. sql = ( diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index c0943ecf91..d10ff9e4b9 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -24,7 +24,7 @@ from synapse.api.errors import StoreError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList -from ._base import Cache, SQLBaseStore +from ._base import Cache, SQLBaseStore, db_to_json logger = logging.getLogger(__name__) @@ -411,7 +411,7 @@ class DeviceStore(SQLBaseStore): if device is not None: key_json = device.get("key_json", None) if key_json: - result["keys"] = json.loads(key_json) + result["keys"] = db_to_json(key_json) device_display_name = device.get("device_display_name", None) if device_display_name: result["device_display_name"] = device_display_name @@ -466,7 +466,7 @@ class DeviceStore(SQLBaseStore): retcol="content", desc="_get_cached_user_device", ) - defer.returnValue(json.loads(content)) + defer.returnValue(db_to_json(content)) @cachedInlineCallbacks() def _get_cached_devices_for_user(self, user_id): @@ -479,7 +479,7 @@ class DeviceStore(SQLBaseStore): desc="_get_cached_devices_for_user", ) defer.returnValue({ - device["device_id"]: json.loads(device["content"]) + device["device_id"]: db_to_json(device["content"]) for device in devices }) @@ -511,7 +511,7 @@ class DeviceStore(SQLBaseStore): key_json = device.get("key_json", None) if key_json: - result["keys"] = json.loads(key_json) + result["keys"] = db_to_json(key_json) device_display_name = device.get("device_display_name", None) if device_display_name: result["device_display_name"] = device_display_name diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 808194236a..61a029a53c 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -75,7 +75,6 @@ class DirectoryWorkerStore(SQLBaseStore): }, retcol="creator", desc="get_room_alias_creator", - allow_none=True ) @cached(max_entries=5000) @@ -91,7 +90,7 @@ class DirectoryWorkerStore(SQLBaseStore): class DirectoryStore(DirectoryWorkerStore): @defer.inlineCallbacks def create_room_alias_association(self, room_alias, room_id, servers, creator=None): - """ Creates an associatin between a room alias and room_id/servers + """ Creates an association between a room alias and room_id/servers Args: room_alias (RoomAlias) diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py new file mode 100644 index 0000000000..f25ded2295 --- /dev/null +++ b/synapse/storage/e2e_room_keys.py @@ -0,0 +1,320 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from twisted.internet import defer + +from synapse.api.errors import StoreError + +from ._base import SQLBaseStore + + +class EndToEndRoomKeyStore(SQLBaseStore): + + @defer.inlineCallbacks + def get_e2e_room_key(self, user_id, version, room_id, session_id): + """Get the encrypted E2E room key for a given session from a given + backup version of room_keys. We only store the 'best' room key for a given + session at a given time, as determined by the handler. + + Args: + user_id(str): the user whose backup we're querying + version(str): the version ID of the backup for the set of keys we're querying + room_id(str): the ID of the room whose keys we're querying. + This is a bit redundant as it's implied by the session_id, but + we include for consistency with the rest of the API. + session_id(str): the session whose room_key we're querying. + + Returns: + A deferred dict giving the session_data and message metadata for + this room key. + """ + + row = yield self._simple_select_one( + table="e2e_room_keys", + keyvalues={ + "user_id": user_id, + "version": version, + "room_id": room_id, + "session_id": session_id, + }, + retcols=( + "first_message_index", + "forwarded_count", + "is_verified", + "session_data", + ), + desc="get_e2e_room_key", + ) + + row["session_data"] = json.loads(row["session_data"]) + + defer.returnValue(row) + + @defer.inlineCallbacks + def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key): + """Replaces or inserts the encrypted E2E room key for a given session in + a given backup + + Args: + user_id(str): the user whose backup we're setting + version(str): the version ID of the backup we're updating + room_id(str): the ID of the room whose keys we're setting + session_id(str): the session whose room_key we're setting + room_key(dict): the room_key being set + Raises: + StoreError + """ + + yield self._simple_upsert( + table="e2e_room_keys", + keyvalues={ + "user_id": user_id, + "room_id": room_id, + "session_id": session_id, + }, + values={ + "version": version, + "first_message_index": room_key['first_message_index'], + "forwarded_count": room_key['forwarded_count'], + "is_verified": room_key['is_verified'], + "session_data": json.dumps(room_key['session_data']), + }, + lock=False, + ) + + @defer.inlineCallbacks + def get_e2e_room_keys( + self, user_id, version, room_id=None, session_id=None + ): + """Bulk get the E2E room keys for a given backup, optionally filtered to a given + room, or a given session. + + Args: + user_id(str): the user whose backup we're querying + version(str): the version ID of the backup for the set of keys we're querying + room_id(str): Optional. the ID of the room whose keys we're querying, if any. + If not specified, we return the keys for all the rooms in the backup. + session_id(str): Optional. the session whose room_key we're querying, if any. + If specified, we also require the room_id to be specified. + If not specified, we return all the keys in this version of + the backup (or for the specified room) + + Returns: + A deferred list of dicts giving the session_data and message metadata for + these room keys. + """ + + keyvalues = { + "user_id": user_id, + "version": version, + } + if room_id: + keyvalues['room_id'] = room_id + if session_id: + keyvalues['session_id'] = session_id + + rows = yield self._simple_select_list( + table="e2e_room_keys", + keyvalues=keyvalues, + retcols=( + "user_id", + "room_id", + "session_id", + "first_message_index", + "forwarded_count", + "is_verified", + "session_data", + ), + desc="get_e2e_room_keys", + ) + + sessions = {'rooms': {}} + for row in rows: + room_entry = sessions['rooms'].setdefault(row['room_id'], {"sessions": {}}) + room_entry['sessions'][row['session_id']] = { + "first_message_index": row["first_message_index"], + "forwarded_count": row["forwarded_count"], + "is_verified": row["is_verified"], + "session_data": json.loads(row["session_data"]), + } + + defer.returnValue(sessions) + + @defer.inlineCallbacks + def delete_e2e_room_keys( + self, user_id, version, room_id=None, session_id=None + ): + """Bulk delete the E2E room keys for a given backup, optionally filtered to a given + room or a given session. + + Args: + user_id(str): the user whose backup we're deleting from + version(str): the version ID of the backup for the set of keys we're deleting + room_id(str): Optional. the ID of the room whose keys we're deleting, if any. + If not specified, we delete the keys for all the rooms in the backup. + session_id(str): Optional. the session whose room_key we're querying, if any. + If specified, we also require the room_id to be specified. + If not specified, we delete all the keys in this version of + the backup (or for the specified room) + + Returns: + A deferred of the deletion transaction + """ + + keyvalues = { + "user_id": user_id, + "version": version, + } + if room_id: + keyvalues['room_id'] = room_id + if session_id: + keyvalues['session_id'] = session_id + + yield self._simple_delete( + table="e2e_room_keys", + keyvalues=keyvalues, + desc="delete_e2e_room_keys", + ) + + @staticmethod + def _get_current_version(txn, user_id): + txn.execute( + "SELECT MAX(version) FROM e2e_room_keys_versions " + "WHERE user_id=? AND deleted=0", + (user_id,) + ) + row = txn.fetchone() + if not row: + raise StoreError(404, 'No current backup version') + return row[0] + + def get_e2e_room_keys_version_info(self, user_id, version=None): + """Get info metadata about a version of our room_keys backup. + + Args: + user_id(str): the user whose backup we're querying + version(str): Optional. the version ID of the backup we're querying about + If missing, we return the information about the current version. + Raises: + StoreError: with code 404 if there are no e2e_room_keys_versions present + Returns: + A deferred dict giving the info metadata for this backup version + """ + + def _get_e2e_room_keys_version_info_txn(txn): + if version is None: + this_version = self._get_current_version(txn, user_id) + else: + this_version = version + + result = self._simple_select_one_txn( + txn, + table="e2e_room_keys_versions", + keyvalues={ + "user_id": user_id, + "version": this_version, + "deleted": 0, + }, + retcols=( + "version", + "algorithm", + "auth_data", + ), + ) + result["auth_data"] = json.loads(result["auth_data"]) + return result + + return self.runInteraction( + "get_e2e_room_keys_version_info", + _get_e2e_room_keys_version_info_txn + ) + + def create_e2e_room_keys_version(self, user_id, info): + """Atomically creates a new version of this user's e2e_room_keys store + with the given version info. + + Args: + user_id(str): the user whose backup we're creating a version + info(dict): the info about the backup version to be created + + Returns: + A deferred string for the newly created version ID + """ + + def _create_e2e_room_keys_version_txn(txn): + txn.execute( + "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?", + (user_id,) + ) + current_version = txn.fetchone()[0] + if current_version is None: + current_version = '0' + + new_version = str(int(current_version) + 1) + + self._simple_insert_txn( + txn, + table="e2e_room_keys_versions", + values={ + "user_id": user_id, + "version": new_version, + "algorithm": info["algorithm"], + "auth_data": json.dumps(info["auth_data"]), + }, + ) + + return new_version + + return self.runInteraction( + "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn + ) + + def delete_e2e_room_keys_version(self, user_id, version=None): + """Delete a given backup version of the user's room keys. + Doesn't delete their actual key data. + + Args: + user_id(str): the user whose backup version we're deleting + version(str): Optional. the version ID of the backup version we're deleting + If missing, we delete the current backup version info. + Raises: + StoreError: with code 404 if there are no e2e_room_keys_versions present, + or if the version requested doesn't exist. + """ + + def _delete_e2e_room_keys_version_txn(txn): + if version is None: + this_version = self._get_current_version(txn, user_id) + else: + this_version = version + + return self._simple_update_one_txn( + txn, + table="e2e_room_keys_versions", + keyvalues={ + "user_id": user_id, + "version": this_version, + }, + updatevalues={ + "deleted": 1, + } + ) + + return self.runInteraction( + "delete_e2e_room_keys_version", + _delete_e2e_room_keys_version_txn + ) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 523b4360c3..1f1721e820 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -14,13 +14,13 @@ # limitations under the License. from six import iteritems -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.util.caches.descriptors import cached -from ._base import SQLBaseStore +from ._base import SQLBaseStore, db_to_json class EndToEndKeyStore(SQLBaseStore): @@ -90,7 +90,7 @@ class EndToEndKeyStore(SQLBaseStore): for user_id, device_keys in iteritems(results): for device_id, device_info in iteritems(device_keys): - device_info["keys"] = json.loads(device_info.pop("key_json")) + device_info["keys"] = db_to_json(device_info.pop("key_json")) defer.returnValue(results) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 8a0386c1a4..42225f8a2a 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -41,13 +41,18 @@ class PostgresEngine(object): db_conn.set_isolation_level( self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ ) + + # Set the bytea output to escape, vs the default of hex + cursor = db_conn.cursor() + cursor.execute("SET bytea_output TO escape") + # Asynchronous commit, don't wait for the server to call fsync before # ending the transaction. # https://www.postgresql.org/docs/current/static/wal-async-commit.html if not self.synchronous_commit: - cursor = db_conn.cursor() cursor.execute("SET synchronous_commit TO OFF") - cursor.close() + + cursor.close() def is_deadlock(self, error): if isinstance(error, self.module.DatabaseError): diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 24345b20a6..3faca2a042 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -376,33 +376,25 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, @defer.inlineCallbacks def get_missing_events(self, room_id, earliest_events, latest_events, - limit, min_depth): + limit): ids = yield self.runInteraction( "get_missing_events", self._get_missing_events, - room_id, earliest_events, latest_events, limit, min_depth + room_id, earliest_events, latest_events, limit, ) - events = yield self._get_events(ids) - - events = sorted( - [ev for ev in events if ev.depth >= min_depth], - key=lambda e: e.depth, - ) - - defer.returnValue(events[:limit]) + defer.returnValue(events) def _get_missing_events(self, txn, room_id, earliest_events, latest_events, - limit, min_depth): - - earliest_events = set(earliest_events) - front = set(latest_events) - earliest_events + limit): - event_results = set() + seen_events = set(earliest_events) + front = set(latest_events) - seen_events + event_results = [] query = ( "SELECT prev_event_id FROM event_edges " - "WHERE event_id = ? AND is_state = ? " + "WHERE room_id = ? AND event_id = ? AND is_state = ? " "LIMIT ?" ) @@ -411,18 +403,20 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, for event_id in front: txn.execute( query, - (event_id, False, limit - len(event_results)) + (room_id, event_id, False, limit - len(event_results)) ) - for e_id, in txn: - new_front.add(e_id) + new_results = set(t[0] for t in txn) - seen_events - new_front -= earliest_events - new_front -= event_results + new_front |= new_results + seen_events |= new_results + event_results.extend(new_results) front = new_front - event_results |= new_front + # we built the list working backwards from latest_events; we now need to + # reverse it so that the events are approximately chronological. + event_results.reverse() return event_results diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 025a7fb6d9..8881b009df 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -19,7 +19,7 @@ import logging from collections import OrderedDict, deque, namedtuple from functools import wraps -from six import iteritems +from six import iteritems, text_type from six.moves import range from canonicaljson import json @@ -34,10 +34,12 @@ from synapse.api.errors import SynapseError from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.state import StateResolutionStore from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.event_federation import EventFederationStore from synapse.storage.events_worker import EventsWorkerStore from synapse.types import RoomStreamToken, get_domain_from_id +from synapse.util import batch_iter from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.frozenutils import frozendict_json_encoder @@ -386,12 +388,10 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore ) for room_id, ev_ctx_rm in iteritems(events_by_room): - # Work out new extremities by recursively adding and removing - # the new events. latest_event_ids = yield self.get_latest_event_ids_in_room( room_id ) - new_latest_event_ids = yield self._calculate_new_extremeties( + new_latest_event_ids = yield self._calculate_new_extremities( room_id, ev_ctx_rm, latest_event_ids ) @@ -400,6 +400,12 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore # No change in extremities, so no change in state continue + # there should always be at least one forward extremity. + # (except during the initial persistence of the send_join + # results, in which case there will be no existing + # extremities, so we'll `continue` above and skip this bit.) + assert new_latest_event_ids, "No forward extremities left!" + new_forward_extremeties[room_id] = new_latest_event_ids len_1 = ( @@ -517,44 +523,79 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore ) @defer.inlineCallbacks - def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids): - """Calculates the new forward extremeties for a room given events to + def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids): + """Calculates the new forward extremities for a room given events to persist. Assumes that we are only persisting events for one room at a time. """ - new_latest_event_ids = set(latest_event_ids) - # First, add all the new events to the list - new_latest_event_ids.update( - event.event_id for event, ctx in event_contexts + + # we're only interested in new events which aren't outliers and which aren't + # being rejected. + new_events = [ + event for event, ctx in event_contexts if not event.internal_metadata.is_outlier() and not ctx.rejected + ] + + # start with the existing forward extremities + result = set(latest_event_ids) + + # add all the new events to the list + result.update( + event.event_id for event in new_events ) - # Now remove all events that are referenced by the to-be-added events - new_latest_event_ids.difference_update( + + # Now remove all events which are prev_events of any of the new events + result.difference_update( e_id - for event, ctx in event_contexts + for event in new_events for e_id, _ in event.prev_events - if not event.internal_metadata.is_outlier() and not ctx.rejected ) - # And finally remove any events that are referenced by previously added - # events. - rows = yield self._simple_select_many_batch( - table="event_edges", - column="prev_event_id", - iterable=list(new_latest_event_ids), - retcols=["prev_event_id"], - keyvalues={ - "is_state": False, - }, - desc="_calculate_new_extremeties", - ) + # Finally, remove any events which are prev_events of any existing events. + existing_prevs = yield self._get_events_which_are_prevs(result) + result.difference_update(existing_prevs) - new_latest_event_ids.difference_update( - row["prev_event_id"] for row in rows - ) + defer.returnValue(result) + + @defer.inlineCallbacks + def _get_events_which_are_prevs(self, event_ids): + """Filter the supplied list of event_ids to get those which are prev_events of + existing (non-outlier/rejected) events. + + Args: + event_ids (Iterable[str]): event ids to filter + + Returns: + Deferred[List[str]]: filtered event ids + """ + results = [] + + def _get_events(txn, batch): + sql = """ + SELECT prev_event_id + FROM event_edges + INNER JOIN events USING (event_id) + LEFT JOIN rejections USING (event_id) + WHERE + prev_event_id IN (%s) + AND NOT events.outlier + AND rejections.event_id IS NULL + """ % ( + ",".join("?" for _ in batch), + ) + + txn.execute(sql, batch) + results.extend(r[0] for r in txn) + + for chunk in batch_iter(event_ids, 100): + yield self.runInteraction( + "_get_events_which_are_prevs", + _get_events, + chunk, + ) - defer.returnValue(new_latest_event_ids) + defer.returnValue(results) @defer.inlineCallbacks def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids, @@ -586,10 +627,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore the new current state is only returned if we've already calculated it. """ - - if not new_latest_event_ids: - return - # map from state_group to ((type, key) -> event_id) state map state_groups_map = {} @@ -695,19 +732,17 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore # Ok, we need to defer to the state handler to resolve our state sets. - def get_events(ev_ids): - return self.get_events( - ev_ids, get_prev_content=False, check_redacted=False, - ) - state_groups = { sg: state_groups_map[sg] for sg in new_state_groups } events_map = {ev.event_id: ev for ev, _ in events_context} + room_version = yield self.get_room_version(room_id) + logger.debug("calling resolve_state_groups from preserve_events") res = yield self._state_resolution_handler.resolve_state_groups( - room_id, state_groups, events_map, get_events + room_id, room_version, state_groups, events_map, + state_res_store=StateResolutionStore(self) ) defer.returnValue((res.state, None)) @@ -816,6 +851,27 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore # Insert into event_to_state_groups. self._store_event_state_mappings_txn(txn, events_and_contexts) + # We want to store event_auth mappings for rejected events, as they're + # used in state res v2. + # This is only necessary if the rejected event appears in an accepted + # event's auth chain, but its easier for now just to store them (and + # it doesn't take much storage compared to storing the entire event + # anyway). + self._simple_insert_many_txn( + txn, + table="event_auth", + values=[ + { + "event_id": event.event_id, + "room_id": event.room_id, + "auth_id": auth_id, + } + for event, _ in events_and_contexts + for auth_id, _ in event.auth_events + if event.is_state() + ], + ) + # _store_rejected_events_txn filters out any events which were # rejected, and returns the filtered list. events_and_contexts = self._store_rejected_events_txn( @@ -928,6 +984,10 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore ) self._invalidate_cache_and_stream( + txn, self.get_room_summary, (room_id,) + ) + + self._invalidate_cache_and_stream( txn, self.get_current_state_ids, (room_id,) ) @@ -1218,7 +1278,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore "sender": event.sender, "contains_url": ( "url" in event.content - and isinstance(event.content["url"], basestring) + and isinstance(event.content["url"], text_type) ), } for event, _ in events_and_contexts @@ -1287,21 +1347,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore txn, event.room_id, event.redacts ) - self._simple_insert_many_txn( - txn, - table="event_auth", - values=[ - { - "event_id": event.event_id, - "room_id": event.room_id, - "auth_id": auth_id, - } - for event, _ in events_and_contexts - for auth_id, _ in event.auth_events - if event.is_state() - ], - ) - # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. self._handle_mult_prev_events( @@ -1527,7 +1572,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore contains_url = "url" in content if contains_url: - contains_url &= isinstance(content["url"], basestring) + contains_url &= isinstance(content["url"], text_type) except (KeyError, AttributeError): # If the event is missing a necessary field then # skip over it. @@ -1884,20 +1929,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore ")" ) - # create an index on should_delete because later we'll be looking for - # the should_delete / shouldn't_delete subsets - txn.execute( - "CREATE INDEX events_to_purge_should_delete" - " ON events_to_purge(should_delete)", - ) - - # We do joins against events_to_purge for e.g. calculating state - # groups to purge, etc., so lets make an index. - txn.execute( - "CREATE INDEX events_to_purge_id" - " ON events_to_purge(event_id)", - ) - # First ensure that we're not about to delete all the forward extremeties txn.execute( "SELECT e.event_id, e.depth FROM events as e " @@ -1908,9 +1939,9 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore (room_id,) ) rows = txn.fetchall() - max_depth = max(row[0] for row in rows) + max_depth = max(row[1] for row in rows) - if max_depth <= token.topological: + if max_depth < token.topological: # We need to ensure we don't delete all the events from the database # otherwise we wouldn't be able to send any events (due to not # having any backwards extremeties) @@ -1924,19 +1955,45 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore should_delete_params = () if not delete_local_events: should_delete_expr += " AND event_id NOT LIKE ?" - should_delete_params += ("%:" + self.hs.hostname, ) + + # We include the parameter twice since we use the expression twice + should_delete_params += ( + "%:" + self.hs.hostname, + "%:" + self.hs.hostname, + ) should_delete_params += (room_id, token.topological) + # Note that we insert events that are outliers and aren't going to be + # deleted, as nothing will happen to them. txn.execute( "INSERT INTO events_to_purge" " SELECT event_id, %s" " FROM events AS e LEFT JOIN state_events USING (event_id)" - " WHERE e.room_id = ? AND topological_ordering < ?" % ( + " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?" + % ( + should_delete_expr, should_delete_expr, ), should_delete_params, ) + + # We create the indices *after* insertion as that's a lot faster. + + # create an index on should_delete because later we'll be looking for + # the should_delete / shouldn't_delete subsets + txn.execute( + "CREATE INDEX events_to_purge_should_delete" + " ON events_to_purge(should_delete)", + ) + + # We do joins against events_to_purge for e.g. calculating state + # groups to purge, etc., so lets make an index. + txn.execute( + "CREATE INDEX events_to_purge_id" + " ON events_to_purge(event_id)", + ) + txn.execute( "SELECT event_id, should_delete FROM events_to_purge" ) @@ -2032,7 +2089,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore for sg in remaining_state_groups: logger.info("[purge] de-delta-ing remaining state group %s", sg) curr_state = self._get_state_groups_from_groups_txn( - txn, [sg], types=None + txn, [sg], ) curr_state = curr_state[sg] diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 59822178ff..a8326f5296 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import itertools import logging from collections import namedtuple @@ -265,7 +266,7 @@ class EventsWorkerStore(SQLBaseStore): """ with Measure(self._clock, "_fetch_event_list"): try: - event_id_lists = zip(*event_list)[0] + event_id_lists = list(zip(*event_list))[0] event_ids = [ item for sublist in event_id_lists for item in sublist ] @@ -299,14 +300,14 @@ class EventsWorkerStore(SQLBaseStore): logger.exception("do_fetch") # We only want to resolve deferreds from the main thread - def fire(evs): + def fire(evs, exc): for _, d in evs: if not d.called: with PreserveLoggingContext(): - d.errback(e) + d.errback(exc) with PreserveLoggingContext(): - self.hs.get_reactor().callFromThread(fire, event_list) + self.hs.get_reactor().callFromThread(fire, event_list, e) @defer.inlineCallbacks def _enqueue_events(self, events, check_redacted=True, allow_rejected=False): diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py index 2d5896c5b4..6ddcc909bf 100644 --- a/synapse/storage/filtering.py +++ b/synapse/storage/filtering.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.api.errors import Codes, SynapseError from synapse.util.caches.descriptors import cachedInlineCallbacks -from ._base import SQLBaseStore +from ._base import SQLBaseStore, db_to_json class FilteringStore(SQLBaseStore): @@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore): desc="get_user_filter", ) - defer.returnValue(json.loads(bytes(def_json).decode("utf-8"))) + defer.returnValue(db_to_json(def_json)) def add_user_filter(self, user_localpart, user_filter): def_json = encode_canonical_json(user_filter) diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index f547977600..8af17921e3 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) # py2 sqlite has buffer hardcoded as only binary type, so we must use it, # despite being deprecated and removed in favor of memoryview if six.PY2: - db_binary_type = buffer + db_binary_type = six.moves.builtins.buffer else: db_binary_type = memoryview @@ -134,6 +134,7 @@ class KeyStore(SQLBaseStore): """ key_id = "%s:%s" % (verify_key.alg, verify_key.version) + # XXX fix this to not need a lock (#3819) def _txn(txn): self._simple_upsert_txn( txn, diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 06f9a75a97..cf4104dc2e 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -33,20 +33,29 @@ class MonthlyActiveUsersStore(SQLBaseStore): self._clock = hs.get_clock() self.hs = hs self.reserved_users = () + # Do not add more reserved users than the total allowable number + self._initialise_reserved_users( + dbconn.cursor(), + hs.config.mau_limits_reserved_threepids[:self.hs.config.max_mau_value], + ) - @defer.inlineCallbacks - def initialise_reserved_users(self, threepids): - # TODO Why can't I do this in init? - store = self.hs.get_datastore() + def _initialise_reserved_users(self, txn, threepids): + """Ensures that reserved threepids are accounted for in the MAU table, should + be called on start up. + + Args: + txn (cursor): + threepids (list[dict]): List of threepid dicts to reserve + """ reserved_user_list = [] - # Do not add more reserved users than the total allowable number - for tp in threepids[:self.hs.config.max_mau_value]: - user_id = yield store.get_user_id_by_threepid( + for tp in threepids: + user_id = self.get_user_id_by_threepid_txn( + txn, tp["medium"], tp["address"] ) if user_id: - yield self.upsert_monthly_active_user(user_id) + self.upsert_monthly_active_user_txn(txn, user_id) reserved_user_list.append(user_id) else: logger.warning( @@ -56,8 +65,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): @defer.inlineCallbacks def reap_monthly_active_users(self): - """ - Cleans out monthly active user table to ensure that no stale + """Cleans out monthly active user table to ensure that no stale entries exist. Returns: @@ -147,16 +155,63 @@ class MonthlyActiveUsersStore(SQLBaseStore): return count return self.runInteraction("count_users", _count_users) + @defer.inlineCallbacks + def get_registered_reserved_users_count(self): + """Of the reserved threepids defined in config, how many are associated + with registered users? + + Returns: + Defered[int]: Number of real reserved users + """ + count = 0 + for tp in self.hs.config.mau_limits_reserved_threepids: + user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + tp["medium"], tp["address"] + ) + if user_id: + count = count + 1 + defer.returnValue(count) + + @defer.inlineCallbacks def upsert_monthly_active_user(self, user_id): + """Updates or inserts the user into the monthly active user table, which + is used to track the current MAU usage of the server + + Args: + user_id (str): user to add/update """ - Updates or inserts monthly active user member - Arguments: - user_id (str): user to add/update - Deferred[bool]: True if a new entry was created, False if an - existing one was updated. + is_insert = yield self.runInteraction( + "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, + user_id + ) + + if is_insert: + self.user_last_seen_monthly_active.invalidate((user_id,)) + self.get_monthly_active_count.invalidate(()) + + def upsert_monthly_active_user_txn(self, txn, user_id): + """Updates or inserts monthly active user member + + Note that, after calling this method, it will generally be necessary + to invalidate the caches on user_last_seen_monthly_active and + get_monthly_active_count. We can't do that here, because we are running + in a database thread rather than the main thread, and we can't call + txn.call_after because txn may not be a LoggingTransaction. + + Args: + txn (cursor): + user_id (str): user to add/update + + Returns: + bool: True if a new entry was created, False if an + existing one was updated. """ - is_insert = self._simple_upsert( - desc="upsert_monthly_active_user", + # Am consciously deciding to lock the table on the basis that is ought + # never be a big table and alternative approaches (batching multiple + # upserts into a single txn) introduced a lot of extra complexity. + # See https://github.com/matrix-org/synapse/issues/3854 for more + is_insert = self._simple_upsert_txn( + txn, table="monthly_active_users", keyvalues={ "user_id": user_id, @@ -164,11 +219,9 @@ class MonthlyActiveUsersStore(SQLBaseStore): values={ "timestamp": int(self._clock.time_msec()), }, - lock=False, ) - if is_insert: - self.user_last_seen_monthly_active.invalidate((user_id,)) - self.get_monthly_active_count.invalidate(()) + + return is_insert @cached(num_args=1) def user_last_seen_monthly_active(self, user_id): @@ -199,7 +252,16 @@ class MonthlyActiveUsersStore(SQLBaseStore): Args: user_id(str): the user_id to query """ + if self.hs.config.limit_usage_by_mau: + # Trial users and guests should not be included as part of MAU group + is_guest = yield self.is_guest(user_id) + if is_guest: + return + is_trial = yield self.is_trial_user(user_id) + if is_trial: + return + last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id) now = self.hs.get_clock().time_msec() diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py index e6848c70a0..10133f0a4a 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py @@ -123,8 +123,6 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_from_remote_profile_cache", ) - -class ProfileStore(ProfileWorkerStore): def set_profile_displayname(self, user_localpart, new_displayname, batchnum): return self._simple_upsert( table="profiles", @@ -165,6 +163,8 @@ class ProfileStore(ProfileWorkerStore): lock=False # we can do this because user_id has a unique index ) + +class ProfileStore(ProfileWorkerStore): def add_remote_profile_cache(self, user_id, displayname, avatar_url): """Ensure we are caching the remote user's profiles. diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 8443bd4c1b..2743b52bad 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -15,7 +15,8 @@ # limitations under the License. import logging -import types + +import six from canonicaljson import encode_canonical_json, json @@ -27,6 +28,11 @@ from ._base import SQLBaseStore logger = logging.getLogger(__name__) +if six.PY2: + db_binary_type = six.moves.builtins.buffer +else: + db_binary_type = memoryview + class PusherWorkerStore(SQLBaseStore): def _decode_pushers_rows(self, rows): @@ -34,18 +40,18 @@ class PusherWorkerStore(SQLBaseStore): dataJson = r['data'] r['data'] = None try: - if isinstance(dataJson, types.BufferType): + if isinstance(dataJson, db_binary_type): dataJson = str(dataJson).decode("UTF8") r['data'] = json.loads(dataJson) except Exception as e: logger.warn( "Invalid JSON in data for pusher %d: %s, %s", - r['id'], dataJson, e.message, + r['id'], dataJson, e.args[0], ) pass - if isinstance(r['pushkey'], types.BufferType): + if isinstance(r['pushkey'], db_binary_type): r['pushkey'] = str(r['pushkey']).decode("UTF8") return rows diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 50706519aa..65061f4c61 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -26,6 +26,11 @@ from synapse.util.caches.descriptors import cached, cachedInlineCallbacks class RegistrationWorkerStore(SQLBaseStore): + def __init__(self, db_conn, hs): + super(RegistrationWorkerStore, self).__init__(db_conn, hs) + + self.config = hs.config + @cached() def get_user_by_id(self, user_id): return self._simple_select_one( @@ -36,12 +41,33 @@ class RegistrationWorkerStore(SQLBaseStore): retcols=[ "name", "password_hash", "is_guest", "consent_version", "consent_server_notice_sent", - "appservice_id", + "appservice_id", "creation_ts", ], allow_none=True, desc="get_user_by_id", ) + @defer.inlineCallbacks + def is_trial_user(self, user_id): + """Checks if user is in the "trial" period, i.e. within the first + N days of registration defined by `mau_trial_days` config + + Args: + user_id (str) + + Returns: + Deferred[bool] + """ + + info = yield self.get_user_by_id(user_id) + if not info: + defer.returnValue(False) + + now = self.clock.time_msec() + trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000 + is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms + defer.returnValue(is_trial) + @cached() def get_user_by_access_token(self, token): """Get a user from the given access token. @@ -436,17 +462,44 @@ class RegistrationStore(RegistrationWorkerStore, @defer.inlineCallbacks def get_user_id_by_threepid(self, medium, address): - ret = yield self._simple_select_one( + """Returns user id from threepid + + Args: + medium (str): threepid medium e.g. email + address (str): threepid address e.g. me@example.com + + Returns: + Deferred[str|None]: user id or None if no user id/threepid mapping exists + """ + user_id = yield self.runInteraction( + "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, + medium, address + ) + defer.returnValue(user_id) + + def get_user_id_by_threepid_txn(self, txn, medium, address): + """Returns user id from threepid + + Args: + txn (cursor): + medium (str): threepid medium e.g. email + address (str): threepid address e.g. me@example.com + + Returns: + str|None: user id or None if no user id/threepid mapping exists + """ + ret = self._simple_select_one_txn( + txn, "user_threepids", { "medium": medium, "address": address }, - ['user_id'], True, 'get_user_id_by_threepid' + ['user_id'], True ) if ret: - defer.returnValue(ret['user_id']) - defer.returnValue(None) + return ret['user_id'] + return None def user_delete_threepid(self, user_id, medium, address): return self._simple_delete( @@ -529,7 +582,7 @@ class RegistrationStore(RegistrationWorkerStore, def _find_next_generated_user_id(txn): txn.execute("SELECT name FROM users") - regex = re.compile("^@(\d+):") + regex = re.compile(r"^@(\d+):") found = set() diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 3378fc77d1..61013b8919 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -186,6 +186,35 @@ class RoomWorkerStore(SQLBaseStore): desc="is_room_blocked", ) + @cachedInlineCallbacks(max_entries=10000) + def get_ratelimit_for_user(self, user_id): + """Check if there are any overrides for ratelimiting for the given + user + + Args: + user_id (str) + + Returns: + RatelimitOverride if there is an override, else None. If the contents + of RatelimitOverride are None or 0 then ratelimitng has been + disabled for that user entirely. + """ + row = yield self._simple_select_one( + table="ratelimit_override", + keyvalues={"user_id": user_id}, + retcols=("messages_per_second", "burst_count"), + allow_none=True, + desc="get_ratelimit_for_user", + ) + + if row: + defer.returnValue(RatelimitOverride( + messages_per_second=row["messages_per_second"], + burst_count=row["burst_count"], + )) + else: + defer.returnValue(None) + class RoomStore(RoomWorkerStore, SearchStore): @@ -469,35 +498,6 @@ class RoomStore(RoomWorkerStore, SearchStore): "get_all_new_public_rooms", get_all_new_public_rooms ) - @cachedInlineCallbacks(max_entries=10000) - def get_ratelimit_for_user(self, user_id): - """Check if there are any overrides for ratelimiting for the given - user - - Args: - user_id (str) - - Returns: - RatelimitOverride if there is an override, else None. If the contents - of RatelimitOverride are None or 0 then ratelimitng has been - disabled for that user entirely. - """ - row = yield self._simple_select_one( - table="ratelimit_override", - keyvalues={"user_id": user_id}, - retcols=("messages_per_second", "burst_count"), - allow_none=True, - desc="get_ratelimit_for_user", - ) - - if row: - defer.returnValue(RatelimitOverride( - messages_per_second=row["messages_per_second"], - burst_count=row["burst_count"], - )) - else: - defer.returnValue(None) - @defer.inlineCallbacks def block_room(self, room_id, user_id): yield self._simple_insert( diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 9b4e6d6aa8..0707f9a86a 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -51,6 +51,12 @@ ProfileInfo = namedtuple( "ProfileInfo", ("avatar_url", "display_name") ) +# "members" points to a truncated list of (user_id, event_id) tuples for users of +# a given membership type, suitable for use in calculating heroes for a room. +# "count" points to the total numberr of users of a given membership type. +MemberSummary = namedtuple( + "MemberSummary", ("members", "count") +) _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" @@ -82,6 +88,65 @@ class RoomMemberWorkerStore(EventsWorkerStore): return [to_ascii(r[0]) for r in txn] return self.runInteraction("get_users_in_room", f) + @cached(max_entries=100000) + def get_room_summary(self, room_id): + """ Get the details of a room roughly suitable for use by the room + summary extension to /sync. Useful when lazy loading room members. + Args: + room_id (str): The room ID to query + Returns: + Deferred[dict[str, MemberSummary]: + dict of membership states, pointing to a MemberSummary named tuple. + """ + + def _get_room_summary_txn(txn): + # first get counts. + # We do this all in one transaction to keep the cache small. + # FIXME: get rid of this when we have room_stats + sql = """ + SELECT count(*), m.membership FROM room_memberships as m + INNER JOIN current_state_events as c + ON m.event_id = c.event_id + AND m.room_id = c.room_id + AND m.user_id = c.state_key + WHERE c.type = 'm.room.member' AND c.room_id = ? + GROUP BY m.membership + """ + + txn.execute(sql, (room_id,)) + res = {} + for count, membership in txn: + summary = res.setdefault(to_ascii(membership), MemberSummary([], count)) + + # we order by membership and then fairly arbitrarily by event_id so + # heroes are consistent + sql = """ + SELECT m.user_id, m.membership, m.event_id + FROM room_memberships as m + INNER JOIN current_state_events as c + ON m.event_id = c.event_id + AND m.room_id = c.room_id + AND m.user_id = c.state_key + WHERE c.type = 'm.room.member' AND c.room_id = ? + ORDER BY + CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC, + m.event_id ASC + LIMIT ? + """ + + # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user. + txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6)) + for user_id, membership, event_id in txn: + summary = res[to_ascii(membership)] + # we will always have a summary for this membership type at this + # point given the summary currently contains the counts. + members = summary.members + members.append((to_ascii(user_id), to_ascii(event_id))) + + return res + + return self.runInteraction("get_room_summary", _get_room_summary_txn) + @cached() def get_invited_rooms_for_user(self, user_id): """ Get all the rooms the user is invited to diff --git a/synapse/storage/schema/delta/51/e2e_room_keys.sql b/synapse/storage/schema/delta/51/e2e_room_keys.sql new file mode 100644 index 0000000000..c0e66a697d --- /dev/null +++ b/synapse/storage/schema/delta/51/e2e_room_keys.sql @@ -0,0 +1,39 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- users' optionally backed up encrypted e2e sessions +CREATE TABLE e2e_room_keys ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + session_id TEXT NOT NULL, + version TEXT NOT NULL, + first_message_index INT, + forwarded_count INT, + is_verified BOOLEAN, + session_data TEXT NOT NULL +); + +CREATE UNIQUE INDEX e2e_room_keys_idx ON e2e_room_keys(user_id, room_id, session_id); + +-- the metadata for each generation of encrypted e2e session backups +CREATE TABLE e2e_room_keys_versions ( + user_id TEXT NOT NULL, + version TEXT NOT NULL, + algorithm TEXT NOT NULL, + auth_data TEXT NOT NULL, + deleted SMALLINT DEFAULT 0 NOT NULL +); + +CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions(user_id, version); diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index 5623391f6e..158e9dbe7b 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -27,7 +27,7 @@ from ._base import SQLBaseStore # py2 sqlite has buffer hardcoded as only binary type, so we must use it, # despite being deprecated and removed in favor of memoryview if six.PY2: - db_binary_type = buffer + db_binary_type = six.moves.builtins.buffer else: db_binary_type = memoryview diff --git a/synapse/storage/state.py b/synapse/storage/state.py index dd03c4168b..ef65929bb2 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -19,6 +19,8 @@ from collections import namedtuple from six import iteritems, itervalues from six.moves import range +import attr + from twisted.internet import defer from synapse.api.constants import EventTypes @@ -48,6 +50,318 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt return len(self.delta_ids) if self.delta_ids else 0 +@attr.s(slots=True) +class StateFilter(object): + """A filter used when querying for state. + + Attributes: + types (dict[str, set[str]|None]): Map from type to set of state keys (or + None). This specifies which state_keys for the given type to fetch + from the DB. If None then all events with that type are fetched. If + the set is empty then no events with that type are fetched. + include_others (bool): Whether to fetch events with types that do not + appear in `types`. + """ + + types = attr.ib() + include_others = attr.ib(default=False) + + def __attrs_post_init__(self): + # If `include_others` is set we canonicalise the filter by removing + # wildcards from the types dictionary + if self.include_others: + self.types = { + k: v for k, v in iteritems(self.types) + if v is not None + } + + @staticmethod + def all(): + """Creates a filter that fetches everything. + + Returns: + StateFilter + """ + return StateFilter(types={}, include_others=True) + + @staticmethod + def none(): + """Creates a filter that fetches nothing. + + Returns: + StateFilter + """ + return StateFilter(types={}, include_others=False) + + @staticmethod + def from_types(types): + """Creates a filter that only fetches the given types + + Args: + types (Iterable[tuple[str, str|None]]): A list of type and state + keys to fetch. A state_key of None fetches everything for + that type + + Returns: + StateFilter + """ + type_dict = {} + for typ, s in types: + if typ in type_dict: + if type_dict[typ] is None: + continue + + if s is None: + type_dict[typ] = None + continue + + type_dict.setdefault(typ, set()).add(s) + + return StateFilter(types=type_dict) + + @staticmethod + def from_lazy_load_member_list(members): + """Creates a filter that returns all non-member events, plus the member + events for the given users + + Args: + members (iterable[str]): Set of user IDs + + Returns: + StateFilter + """ + return StateFilter( + types={EventTypes.Member: set(members)}, + include_others=True, + ) + + def return_expanded(self): + """Creates a new StateFilter where type wild cards have been removed + (except for memberships). The returned filter is a superset of the + current one, i.e. anything that passes the current filter will pass + the returned filter. + + This helps the caching as the DictionaryCache knows if it has *all* the + state, but does not know if it has all of the keys of a particular type, + which makes wildcard lookups expensive unless we have a complete cache. + Hence, if we are doing a wildcard lookup, populate the cache fully so + that we can do an efficient lookup next time. + + Note that since we have two caches, one for membership events and one for + other events, we can be a bit more clever than simply returning + `StateFilter.all()` if `has_wildcards()` is True. + + We return a StateFilter where: + 1. the list of membership events to return is the same + 2. if there is a wildcard that matches non-member events we + return all non-member events + + Returns: + StateFilter + """ + + if self.is_full(): + # If we're going to return everything then there's nothing to do + return self + + if not self.has_wildcards(): + # If there are no wild cards, there's nothing to do + return self + + if EventTypes.Member in self.types: + get_all_members = self.types[EventTypes.Member] is None + else: + get_all_members = self.include_others + + has_non_member_wildcard = self.include_others or any( + state_keys is None + for t, state_keys in iteritems(self.types) + if t != EventTypes.Member + ) + + if not has_non_member_wildcard: + # If there are no non-member wild cards we can just return ourselves + return self + + if get_all_members: + # We want to return everything. + return StateFilter.all() + else: + # We want to return all non-members, but only particular + # memberships + return StateFilter( + types={EventTypes.Member: self.types[EventTypes.Member]}, + include_others=True, + ) + + def make_sql_filter_clause(self): + """Converts the filter to an SQL clause. + + For example: + + f = StateFilter.from_types([("m.room.create", "")]) + clause, args = f.make_sql_filter_clause() + clause == "(type = ? AND state_key = ?)" + args == ['m.room.create', ''] + + + Returns: + tuple[str, list]: The SQL string (may be empty) and arguments. An + empty SQL string is returned when the filter matches everything + (i.e. is "full"). + """ + + where_clause = "" + where_args = [] + + if self.is_full(): + return where_clause, where_args + + if not self.include_others and not self.types: + # i.e. this is an empty filter, so we need to return a clause that + # will match nothing + return "1 = 2", [] + + # First we build up a lost of clauses for each type/state_key combo + clauses = [] + for etype, state_keys in iteritems(self.types): + if state_keys is None: + clauses.append("(type = ?)") + where_args.append(etype) + continue + + for state_key in state_keys: + clauses.append("(type = ? AND state_key = ?)") + where_args.extend((etype, state_key)) + + # This will match anything that appears in `self.types` + where_clause = " OR ".join(clauses) + + # If we want to include stuff that's not in the types dict then we add + # a `OR type NOT IN (...)` clause to the end. + if self.include_others: + if where_clause: + where_clause += " OR " + + where_clause += "type NOT IN (%s)" % ( + ",".join(["?"] * len(self.types)), + ) + where_args.extend(self.types) + + return where_clause, where_args + + def max_entries_returned(self): + """Returns the maximum number of entries this filter will return if + known, otherwise returns None. + + For example a simple state filter asking for `("m.room.create", "")` + will return 1, whereas the default state filter will return None. + + This is used to bail out early if the right number of entries have been + fetched. + """ + if self.has_wildcards(): + return None + + return len(self.concrete_types()) + + def filter_state(self, state_dict): + """Returns the state filtered with by this StateFilter + + Args: + state (dict[tuple[str, str], Any]): The state map to filter + + Returns: + dict[tuple[str, str], Any]: The filtered state map + """ + if self.is_full(): + return dict(state_dict) + + filtered_state = {} + for k, v in iteritems(state_dict): + typ, state_key = k + if typ in self.types: + state_keys = self.types[typ] + if state_keys is None or state_key in state_keys: + filtered_state[k] = v + elif self.include_others: + filtered_state[k] = v + + return filtered_state + + def is_full(self): + """Whether this filter fetches everything or not + + Returns: + bool + """ + return self.include_others and not self.types + + def has_wildcards(self): + """Whether the filter includes wildcards or is attempting to fetch + specific state. + + Returns: + bool + """ + + return ( + self.include_others + or any( + state_keys is None + for state_keys in itervalues(self.types) + ) + ) + + def concrete_types(self): + """Returns a list of concrete type/state_keys (i.e. not None) that + will be fetched. This will be a complete list if `has_wildcards` + returns False, but otherwise will be a subset (or even empty). + + Returns: + list[tuple[str,str]] + """ + return [ + (t, s) + for t, state_keys in iteritems(self.types) + if state_keys is not None + for s in state_keys + ] + + def get_member_split(self): + """Return the filter split into two: one which assumes it's exclusively + matching against member state, and one which assumes it's matching + against non member state. + + This is useful due to the returned filters giving correct results for + `is_full()`, `has_wildcards()`, etc, when operating against maps that + either exclusively contain member events or only contain non-member + events. (Which is the case when dealing with the member vs non-member + state caches). + + Returns: + tuple[StateFilter, StateFilter]: The member and non member filters + """ + + if EventTypes.Member in self.types: + state_keys = self.types[EventTypes.Member] + if state_keys is None: + member_filter = StateFilter.all() + else: + member_filter = StateFilter({EventTypes.Member: state_keys}) + elif self.include_others: + member_filter = StateFilter.all() + else: + member_filter = StateFilter.none() + + non_member_filter = StateFilter( + types={k: v for k, v in iteritems(self.types) if k != EventTypes.Member}, + include_others=self.include_others, + ) + + return member_filter, non_member_filter + + # this inherits from EventsWorkerStore because it calls self.get_events class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers. @@ -60,8 +374,43 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def __init__(self, db_conn, hs): super(StateGroupWorkerStore, self).__init__(db_conn, hs) + # Originally the state store used a single DictionaryCache to cache the + # event IDs for the state types in a given state group to avoid hammering + # on the state_group* tables. + # + # The point of using a DictionaryCache is that it can cache a subset + # of the state events for a given state group (i.e. a subset of the keys for a + # given dict which is an entry in the cache for a given state group ID). + # + # However, this poses problems when performing complicated queries + # on the store - for instance: "give me all the state for this group, but + # limit members to this subset of users", as DictionaryCache's API isn't + # rich enough to say "please cache any of these fields, apart from this subset". + # This is problematic when lazy loading members, which requires this behaviour, + # as without it the cache has no choice but to speculatively load all + # state events for the group, which negates the efficiency being sought. + # + # Rather than overcomplicating DictionaryCache's API, we instead split the + # state_group_cache into two halves - one for tracking non-member events, + # and the other for tracking member_events. This means that lazy loading + # queries can be made in a cache-friendly manner by querying both caches + # separately and then merging the result. So for the example above, you + # would query the members cache for a specific subset of state keys + # (which DictionaryCache will handle efficiently and fine) and the non-members + # cache for all state (which DictionaryCache will similarly handle fine) + # and then just merge the results together. + # + # We size the non-members cache to be smaller than the members cache as the + # vast majority of state in Matrix (today) is member events. + self._state_group_cache = DictionaryCache( - "*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache") + "*stateGroupCache*", + # TODO: this hasn't been tuned yet + 50000 * get_cache_factor_for("stateGroupCache") + ) + self._state_group_members_cache = DictionaryCache( + "*stateGroupMembersCache*", + 500000 * get_cache_factor_for("stateGroupMembersCache") ) @defer.inlineCallbacks @@ -117,61 +466,41 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) # FIXME: how should this be cached? - def get_filtered_current_state_ids(self, room_id, types, filtered_types=None): + def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()): """Get the current state event of a given type for a room based on the current_state_events table. This may not be as up-to-date as the result of doing a fresh state resolution as per state_handler.get_current_state + Args: room_id (str) - types (list[(Str, (Str|None))]): List of (type, state_key) tuples - which are used to filter the state fetched. `state_key` may be - None, which matches any `state_key` - filtered_types (list[Str]|None): List of types to apply the above filter to. + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: - deferred: dict of (type, state_key) -> event + Deferred[dict[tuple[str, str], str]]: Map from type/state_key to + event ID. """ - include_other_types = False if filtered_types is None else True - def _get_filtered_current_state_ids_txn(txn): results = {} - sql = """SELECT type, state_key, event_id FROM current_state_events - WHERE room_id = ? %s""" - # Turns out that postgres doesn't like doing a list of OR's and - # is about 1000x slower, so we just issue a query for each specific - # type seperately. - if types: - clause_to_args = [ - ( - "AND type = ? AND state_key = ?", - (etype, state_key) - ) if state_key is not None else ( - "AND type = ?", - (etype,) - ) - for etype, state_key in types - ] - - if include_other_types: - unique_types = set(filtered_types) - clause_to_args.append( - ( - "AND type <> ? " * len(unique_types), - list(unique_types) - ) - ) - else: - # If types is None we fetch all the state, and so just use an - # empty where clause with no extra args. - clause_to_args = [("", [])] - for where_clause, where_args in clause_to_args: - args = [room_id] - args.extend(where_args) - txn.execute(sql % (where_clause,), args) - for row in txn: - typ, state_key, event_id = row - key = (intern_string(typ), intern_string(state_key)) - results[key] = event_id + sql = """ + SELECT type, state_key, event_id FROM current_state_events + WHERE room_id = ? + """ + + where_clause, where_args = state_filter.make_sql_filter_clause() + + if where_clause: + sql += " AND (%s)" % (where_clause,) + + args = [room_id] + args.extend(where_args) + txn.execute(sql, args) + for row in txn: + typ, state_key, event_id = row + key = (intern_string(typ), intern_string(state_key)) + results[key] = event_id + return results return self.runInteraction( @@ -220,7 +549,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) @defer.inlineCallbacks - def get_state_groups_ids(self, room_id, event_ids): + def get_state_groups_ids(self, _room_id, event_ids): + """Get the event IDs of all the state for the state groups for the given events + + Args: + _room_id (str): id of the room for these events + event_ids (iterable[str]): ids of the events + + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ if not event_ids: defer.returnValue({}) @@ -235,7 +574,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): @defer.inlineCallbacks def get_state_ids_for_group(self, state_group): - """Get the state IDs for the given state group + """Get the event IDs of all the state in the given state group Args: state_group (int) @@ -251,7 +590,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def get_state_groups(self, room_id, event_ids): """ Get the state groups for the given list of event_ids - The return value is a dict mapping group names to lists of events. + Returns: + Deferred[dict[int, list[EventBase]]]: + dict of state_group_id -> list of state events. """ if not event_ids: defer.returnValue({}) @@ -275,18 +616,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): }) @defer.inlineCallbacks - def _get_state_groups_from_groups(self, groups, types): + def _get_state_groups_from_groups(self, groups, state_filter): """Returns the state groups for a given set of groups, filtering on types of state events. Args: groups(list[int]): list of state group IDs to query - types (Iterable[str, str|None]|None): list of 2-tuples of the form - (`type`, `state_key`), where a `state_key` of `None` matches all - state_keys for the `type`. If None, all types are returned. - + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: - dictionary state_group -> (dict of (type, state_key) -> event id) + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) """ results = {} @@ -294,19 +634,23 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): for chunk in chunks: res = yield self.runInteraction( "_get_state_groups_from_groups", - self._get_state_groups_from_groups_txn, chunk, types, + self._get_state_groups_from_groups_txn, chunk, state_filter, ) results.update(res) defer.returnValue(results) def _get_state_groups_from_groups_txn( - self, txn, groups, types=None, + self, txn, groups, state_filter=StateFilter.all(), ): results = {group: {} for group in groups} - if types is not None: - types = list(set(types)) # deduplicate types list + where_clause, where_args = state_filter.make_sql_filter_clause() + + # Unless the filter clause is empty, we're going to append it after an + # existing where clause + if where_clause: + where_clause = " AND (%s)" % (where_clause,) if isinstance(self.database_engine, PostgresEngine): # Temporarily disable sequential scans in this transaction. This is @@ -322,69 +666,33 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # group for the given type, state_key. # This may return multiple rows per (type, state_key), but last_value # should be the same. - sql = (""" + sql = """ WITH RECURSIVE state(state_group) AS ( VALUES(?::bigint) UNION ALL SELECT prev_state_group FROM state_group_edges e, state s WHERE s.state_group = e.state_group ) - SELECT type, state_key, last_value(event_id) OVER ( + SELECT DISTINCT type, state_key, last_value(event_id) OVER ( PARTITION BY type, state_key ORDER BY state_group ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING ) AS event_id FROM state_groups_state WHERE state_group IN ( SELECT state_group FROM state ) - %s - """) - - # Turns out that postgres doesn't like doing a list of OR's and - # is about 1000x slower, so we just issue a query for each specific - # type seperately. - if types is not None: - clause_to_args = [ - ( - "AND type = ? AND state_key = ?", - (etype, state_key) - ) if state_key is not None else ( - "AND type = ?", - (etype,) - ) - for etype, state_key in types - ] - else: - # If types is None we fetch all the state, and so just use an - # empty where clause with no extra args. - clause_to_args = [("", [])] + """ - for where_clause, where_args in clause_to_args: - for group in groups: - args = [group] - args.extend(where_args) + for group in groups: + args = [group] + args.extend(where_args) - txn.execute(sql % (where_clause,), args) - for row in txn: - typ, state_key, event_id = row - key = (typ, state_key) - results[group][key] = event_id + txn.execute(sql + where_clause, args) + for row in txn: + typ, state_key, event_id = row + key = (typ, state_key) + results[group][key] = event_id else: - where_args = [] - where_clauses = [] - wildcard_types = False - if types is not None: - for typ in types: - if typ[1] is None: - where_clauses.append("(type = ?)") - where_args.append(typ[0]) - wildcard_types = True - else: - where_clauses.append("(type = ? AND state_key = ?)") - where_args.extend([typ[0], typ[1]]) - - where_clause = "AND (%s)" % (" OR ".join(where_clauses)) - else: - where_clause = "" + max_entries_returned = state_filter.max_entries_returned() # We don't use WITH RECURSIVE on sqlite3 as there are distributions # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) @@ -398,12 +706,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # without the right indices (which we can't add until # after we finish deduping state, which requires this func) args = [next_group] - if types: - args.extend(where_args) + args.extend(where_args) txn.execute( "SELECT type, state_key, event_id FROM state_groups_state" - " WHERE state_group = ? %s" % (where_clause,), + " WHERE state_group = ? " + where_clause, args ) results[group].update( @@ -419,9 +726,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # wildcards (i.e. Nones) in which case we have to do an exhaustive # search if ( - types is not None and - not wildcard_types and - len(results[group]) == len(types) + max_entries_returned is not None and + len(results[group]) == max_entries_returned ): break @@ -436,20 +742,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return results @defer.inlineCallbacks - def get_state_for_events(self, event_ids, types, filtered_types=None): + def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): """Given a list of event_ids and type tuples, return a list of state - dicts for each event. The state dicts will only have the type/state_keys - that are in the `types` list. + dicts for each event. Args: event_ids (list[string]) - types (list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: deferred: A dict of (event_id) -> (type, state_key) -> [state_events] @@ -459,7 +759,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, types, filtered_types) + group_to_state = yield self._get_state_for_groups(groups, state_filter) state_event_map = yield self.get_events( [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], @@ -478,20 +778,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): defer.returnValue({event: event_to_state[event] for event in event_ids}) @defer.inlineCallbacks - def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None): + def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): """ Get the state dicts corresponding to a list of events, containing the event_ids of the state events (as opposed to the events themselves) Args: event_ids(list(str)): events whose state should be returned - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: A deferred dict from event_id -> (type, state_key) -> event_id @@ -501,7 +796,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, types, filtered_types) + group_to_state = yield self._get_state_for_groups(groups, state_filter) event_to_state = { event_id: group_to_state[group] @@ -511,45 +806,35 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): defer.returnValue({event: event_to_state[event] for event in event_ids}) @defer.inlineCallbacks - def get_state_for_event(self, event_id, types=None, filtered_types=None): + def get_state_for_event(self, event_id, state_filter=StateFilter.all()): """ Get the state dict corresponding to a particular event Args: event_id(str): event whose state should be returned - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: A deferred dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_for_events([event_id], types, filtered_types) + state_map = yield self.get_state_for_events([event_id], state_filter) defer.returnValue(state_map[event_id]) @defer.inlineCallbacks - def get_state_ids_for_event(self, event_id, types=None, filtered_types=None): + def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): """ Get the state dict corresponding to a particular event Args: event_id(str): event whose state should be returned - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: A deferred dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_ids_for_events([event_id], types, filtered_types) + state_map = yield self.get_state_ids_for_events([event_id], state_filter) defer.returnValue(state_map[event_id]) @cached(max_entries=50000) @@ -580,179 +865,207 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): defer.returnValue({row["event_id"]: row["state_group"] for row in rows}) - def _get_some_state_from_cache(self, group, types, filtered_types=None): + def _get_state_for_group_using_cache(self, cache, group, state_filter): """Checks if group is in cache. See `_get_state_for_groups` Args: + cache(DictionaryCache): the state group cache to use group(int): The state group to lookup - types(list[str, str|None]): List of 2-tuples of the form - (`type`, `state_key`), where a `state_key` of `None` matches all - state_keys for the `type`. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool indicating if we successfully retrieved all requests state from the cache, if False we need to query the DB for the missing state. """ - is_all, known_absent, state_dict_ids = self._state_group_cache.get(group) + is_all, known_absent, state_dict_ids = cache.get(group) - type_to_key = {} + if is_all or state_filter.is_full(): + # Either we have everything or want everything, either way + # `is_all` tells us whether we've gotten everything. + return state_filter.filter_state(state_dict_ids), is_all - # tracks whether any of ourrequested types are missing from the cache + # tracks whether any of our requested types are missing from the cache missing_types = False - for typ, state_key in types: - key = (typ, state_key) - - if ( - state_key is None or - (filtered_types is not None and typ not in filtered_types) - ): - type_to_key[typ] = None - # we mark the type as missing from the cache because - # when the cache was populated it might have been done with a - # restricted set of state_keys, so the wildcard will not work - # and the cache may be incomplete. - missing_types = True - else: - if type_to_key.get(typ, object()) is not None: - type_to_key.setdefault(typ, set()).add(state_key) - + if state_filter.has_wildcards(): + # We don't know if we fetched all the state keys for the types in + # the filter that are wildcards, so we have to assume that we may + # have missed some. + missing_types = True + else: + # There aren't any wild cards, so `concrete_types()` returns the + # complete list of event types we're wanting. + for key in state_filter.concrete_types(): if key not in state_dict_ids and key not in known_absent: missing_types = True + break - sentinel = object() - - def include(typ, state_key): - valid_state_keys = type_to_key.get(typ, sentinel) - if valid_state_keys is sentinel: - return filtered_types is not None and typ not in filtered_types - if valid_state_keys is None: - return True - if state_key in valid_state_keys: - return True - return False - - got_all = is_all - if not got_all: - # the cache is incomplete. We may still have got all the results we need, if - # we don't have any wildcards in the match list. - if not missing_types and filtered_types is None: - got_all = True - - return { - k: v for k, v in iteritems(state_dict_ids) - if include(k[0], k[1]) - }, got_all - - def _get_all_state_from_cache(self, group): - """Checks if group is in cache. See `_get_state_for_groups` + return state_filter.filter_state(state_dict_ids), not missing_types - Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool - indicating if we successfully retrieved all requests state from the - cache, if False we need to query the DB for the missing state. + @defer.inlineCallbacks + def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key Args: - group: The state group to lookup + groups (iterable[int]): list of state groups for which we want + to get the state. + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) """ - is_all, _, state_dict_ids = self._state_group_cache.get(group) - return state_dict_ids, is_all + member_filter, non_member_filter = state_filter.get_member_split() - @defer.inlineCallbacks - def _get_state_for_groups(self, groups, types=None, filtered_types=None): + # Now we look them up in the member and non-member caches + non_member_state, incomplete_groups_nm, = ( + yield self._get_state_for_groups_using_cache( + groups, self._state_group_cache, + state_filter=non_member_filter, + ) + ) + + member_state, incomplete_groups_m, = ( + yield self._get_state_for_groups_using_cache( + groups, self._state_group_members_cache, + state_filter=member_filter, + ) + ) + + state = dict(non_member_state) + for group in groups: + state[group].update(member_state[group]) + + # Now fetch any missing groups from the database + + incomplete_groups = incomplete_groups_m | incomplete_groups_nm + + if not incomplete_groups: + defer.returnValue(state) + + cache_sequence_nm = self._state_group_cache.sequence + cache_sequence_m = self._state_group_members_cache.sequence + + # Help the cache hit ratio by expanding the filter a bit + db_state_filter = state_filter.return_expanded() + + group_to_state_dict = yield self._get_state_groups_from_groups( + list(incomplete_groups), + state_filter=db_state_filter, + ) + + # Now lets update the caches + self._insert_into_cache( + group_to_state_dict, + db_state_filter, + cache_seq_num_members=cache_sequence_m, + cache_seq_num_non_members=cache_sequence_nm, + ) + + # And finally update the result dict, by filtering out any extra + # stuff we pulled out of the database. + for group, group_state_dict in iteritems(group_to_state_dict): + # We just replace any existing entries, as we will have loaded + # everything we need from the database anyway. + state[group] = state_filter.filter_state(group_state_dict) + + defer.returnValue(state) + + def _get_state_for_groups_using_cache( + self, groups, cache, state_filter, + ): """Gets the state at each of a list of state groups, optionally - filtering by type/state_key + filtering by type/state_key, querying from a specific cache. Args: groups (iterable[int]): list of state groups for which we want to get the state. - types (None|iterable[(str, None|str)]): - indicates the state type/keys required. If None, the whole - state is fetched and returned. - - Otherwise, each entry should be a `(type, state_key)` tuple to - include in the response. A `state_key` of None is a wildcard - meaning that we require all state with that type. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + cache (DictionaryCache): the cache of group ids to state dicts which + we will pass through - either the normal state cache or the specific + members state cache. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: - Deferred[dict[int, dict[(type, state_key), EventBase]]] - a dictionary mapping from state group to state dictionary. + tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of + dict of state_group_id -> (dict of (type, state_key) -> event id) + of entries in the cache, and the state group ids either missing + from the cache or incomplete. """ - if types: - types = frozenset(types) results = {} - missing_groups = [] - if types is not None: - for group in set(groups): - state_dict_ids, got_all = self._get_some_state_from_cache( - group, types, filtered_types - ) - results[group] = state_dict_ids + incomplete_groups = set() + for group in set(groups): + state_dict_ids, got_all = self._get_state_for_group_using_cache( + cache, group, state_filter + ) + results[group] = state_dict_ids - if not got_all: - missing_groups.append(group) - else: - for group in set(groups): - state_dict_ids, got_all = self._get_all_state_from_cache( - group - ) + if not got_all: + incomplete_groups.add(group) + + return results, incomplete_groups - results[group] = state_dict_ids + def _insert_into_cache(self, group_to_state_dict, state_filter, + cache_seq_num_members, cache_seq_num_non_members): + """Inserts results from querying the database into the relevant cache. - if not got_all: - missing_groups.append(group) + Args: + group_to_state_dict (dict): The new entries pulled from database. + Map from state group to state dict + state_filter (StateFilter): The state filter used to fetch state + from the database. + cache_seq_num_members (int): Sequence number of member cache since + last lookup in cache + cache_seq_num_non_members (int): Sequence number of member cache since + last lookup in cache + """ - if missing_groups: - # Okay, so we have some missing_types, lets fetch them. - cache_seq_num = self._state_group_cache.sequence + # We need to work out which types we've fetched from the DB for the + # member vs non-member caches. This should be as accurate as possible, + # but can be an underestimate (e.g. when we have wild cards) - # the DictionaryCache knows if it has *all* the state, but - # does not know if it has all of the keys of a particular type, - # which makes wildcard lookups expensive unless we have a complete - # cache. Hence, if we are doing a wildcard lookup, populate the - # cache fully so that we can do an efficient lookup next time. + member_filter, non_member_filter = state_filter.get_member_split() + if member_filter.is_full(): + # We fetched all member events + member_types = None + else: + # `concrete_types()` will only return a subset when there are wild + # cards in the filter, but that's fine. + member_types = member_filter.concrete_types() - if filtered_types or (types and any(k is None for (t, k) in types)): - types_to_fetch = None - else: - types_to_fetch = types + if non_member_filter.is_full(): + # We fetched all non member events + non_member_types = None + else: + non_member_types = non_member_filter.concrete_types() - group_to_state_dict = yield self._get_state_groups_from_groups( - missing_groups, types_to_fetch - ) + for group, group_state_dict in iteritems(group_to_state_dict): + state_dict_members = {} + state_dict_non_members = {} - for group, group_state_dict in iteritems(group_to_state_dict): - state_dict = results[group] - - # update the result, filtering by `types`. - if types: - for k, v in iteritems(group_state_dict): - (typ, _) = k - if ( - (k in types or (typ, None) in types) or - (filtered_types and typ not in filtered_types) - ): - state_dict[k] = v + for k, v in iteritems(group_state_dict): + if k[0] == EventTypes.Member: + state_dict_members[k] = v else: - state_dict.update(group_state_dict) - - # update the cache with all the things we fetched from the - # database. - self._state_group_cache.update( - cache_seq_num, - key=group, - value=group_state_dict, - fetched_keys=types_to_fetch, - ) + state_dict_non_members[k] = v - defer.returnValue(results) + self._state_group_members_cache.update( + cache_seq_num_members, + key=group, + value=state_dict_members, + fetched_keys=member_types, + ) + + self._state_group_cache.update( + cache_seq_num_non_members, + key=group, + value=state_dict_non_members, + fetched_keys=non_member_types, + ) def store_state_group(self, event_id, room_id, prev_group, delta_ids, current_state_ids): @@ -847,15 +1160,33 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ], ) - # Prefill the state group cache with this group. + # Prefill the state group caches with this group. # It's fine to use the sequence like this as the state group map # is immutable. (If the map wasn't immutable then this prefill could # race with another update) + + current_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] == EventTypes.Member + } + txn.call_after( + self._state_group_members_cache.update, + self._state_group_members_cache.sequence, + key=state_group, + value=dict(current_member_state_ids), + ) + + current_non_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] != EventTypes.Member + } txn.call_after( self._state_group_cache.update, self._state_group_cache.sequence, key=state_group, - value=dict(current_state_ids), + value=dict(current_non_member_state_ids), ) return state_group @@ -1043,12 +1374,12 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): continue prev_state = self._get_state_groups_from_groups_txn( - txn, [prev_group], types=None + txn, [prev_group], ) prev_state = prev_state[prev_group] curr_state = self._get_state_groups_from_groups_txn( - txn, [state_group], types=None + txn, [state_group], ) curr_state = curr_state[state_group] diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 4c296d72c0..d6cfdba519 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -630,7 +630,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): @defer.inlineCallbacks def get_all_new_events_stream(self, from_id, current_id, limit): - """Get all new events""" + """Get all new events + + Returns all events with from_id < stream_ordering <= current_id. + + Args: + from_id (int): the stream_ordering of the last event we processed + current_id (int): the stream_ordering of the most recently processed event + limit (int): the maximum number of events to return + + Returns: + Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where + `next_id` is the next value to pass as `from_id` (it will either be the + stream_ordering of the last returned event, or, if fewer than `limit` events + were found, `current_id`. + """ def get_all_new_events_stream_txn(txn): sql = ( diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 428e7fa36e..d8bf953ec0 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -18,19 +18,19 @@ from collections import namedtuple import six -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.caches.descriptors import cached +from synapse.util.caches.expiringcache import ExpiringCache -from ._base import SQLBaseStore +from ._base import SQLBaseStore, db_to_json # py2 sqlite has buffer hardcoded as only binary type, so we must use it, # despite being deprecated and removed in favor of memoryview if six.PY2: - db_binary_type = buffer + db_binary_type = six.moves.builtins.buffer else: db_binary_type = memoryview @@ -50,6 +50,8 @@ _UpdateTransactionRow = namedtuple( ) ) +SENTINEL = object() + class TransactionStore(SQLBaseStore): """A collection of queries for handling PDUs. @@ -60,6 +62,12 @@ class TransactionStore(SQLBaseStore): self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) + self._destination_retry_cache = ExpiringCache( + cache_name="get_destination_retry_timings", + clock=self._clock, + expiry_ms=5 * 60 * 1000, + ) + def get_received_txn_response(self, transaction_id, origin): """For an incoming transaction from a given origin, check if we have already responded to it. If so, return the response code and response @@ -95,7 +103,8 @@ class TransactionStore(SQLBaseStore): ) if result and result["response_code"]: - return result["response_code"], json.loads(str(result["response_json"])) + return result["response_code"], db_to_json(result["response_json"]) + else: return None @@ -155,7 +164,7 @@ class TransactionStore(SQLBaseStore): """ pass - @cached(max_entries=10000) + @defer.inlineCallbacks def get_destination_retry_timings(self, destination): """Gets the current retry timings (if any) for a given destination. @@ -166,10 +175,20 @@ class TransactionStore(SQLBaseStore): None if not retrying Otherwise a dict for the retry scheme """ - return self.runInteraction( + + result = self._destination_retry_cache.get(destination, SENTINEL) + if result is not SENTINEL: + defer.returnValue(result) + + result = yield self.runInteraction( "get_destination_retry_timings", self._get_destination_retry_timings, destination) + # We don't hugely care about race conditions between getting and + # invalidating the cache, since we time out fairly quickly anyway. + self._destination_retry_cache[destination] = result + defer.returnValue(result) + def _get_destination_retry_timings(self, txn, destination): result = self._simple_select_one_txn( txn, @@ -197,8 +216,7 @@ class TransactionStore(SQLBaseStore): retry_interval (int) - how long until next retry in ms """ - # XXX: we could chose to not bother persisting this if our cache thinks - # this is a NOOP + self._destination_retry_cache.pop(destination, None) return self.runInteraction( "set_destination_retry_timings", self._set_destination_retry_timings, @@ -211,10 +229,6 @@ class TransactionStore(SQLBaseStore): retry_last_ts, retry_interval): self.database_engine.lock_table(txn, "destinations") - self._invalidate_cache_and_stream( - txn, self.get_destination_retry_timings, (destination,) - ) - # We need to be careful here as the data may have changed from under us # due to a worker setting the timings. diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 680ea928c7..0ae7e2ef3b 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +import re from itertools import islice import attr @@ -68,7 +69,10 @@ class Clock(object): """ call = task.LoopingCall(f) call.clock = self._reactor - call.start(msec / 1000.0, now=False) + d = call.start(msec / 1000.0, now=False) + d.addErrback( + log_failure, "Looping call died", consumeErrors=False, + ) return call def call_later(self, delay, callback, *args, **kwargs): @@ -109,3 +113,53 @@ def batch_iter(iterable, size): sourceiter = iter(iterable) # call islice until it returns an empty tuple return iter(lambda: tuple(islice(sourceiter, size)), ()) + + +def log_failure(failure, msg, consumeErrors=True): + """Creates a function suitable for passing to `Deferred.addErrback` that + logs any failures that occur. + + Args: + msg (str): Message to log + consumeErrors (bool): If true consumes the failure, otherwise passes + on down the callback chain + + Returns: + func(Failure) + """ + + logger.error( + msg, + exc_info=( + failure.type, + failure.value, + failure.getTracebackObject() + ) + ) + + if not consumeErrors: + return failure + + +def glob_to_regex(glob): + """Converts a glob to a compiled regex object. + + The regex is anchored at the beginning and end of the string. + + Args: + glob (str) + + Returns: + re.RegexObject + """ + res = '' + for c in glob: + if c == '*': + res = res + '.*' + elif c == '?': + res = res + '.' + else: + res = res + re.escape(c) + + # \A anchors at start of string, \Z at end of string + return re.compile(r"\A" + res + r"\Z", re.IGNORECASE) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 9b3f2f4b96..ec7b2c9672 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -374,29 +374,25 @@ class ReadWriteLock(object): defer.returnValue(_ctx_manager()) -class DeferredTimeoutError(Exception): - """ - This error is raised by default when a L{Deferred} times out. - """ - +def _cancelled_to_timed_out_error(value, timeout): + if isinstance(value, failure.Failure): + value.trap(CancelledError) + raise defer.TimeoutError(timeout, "Deferred") + return value -def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None): - """ - Add a timeout to a deferred by scheduling it to be cancelled after - timeout seconds. - This is essentially a backport of deferred.addTimeout, which was introduced - in twisted 16.5. +def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): + """The in built twisted `Deferred.addTimeout` fails to time out deferreds + that have a canceller that throws exceptions. This method creates a new + deferred that wraps and times out the given deferred, correctly handling + the case where the given deferred's canceller throws. - If the deferred gets timed out, it errbacks with a DeferredTimeoutError, - unless a cancelable function was passed to its initialization or unless - a different on_timeout_cancel callable is provided. + NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred Args: - deferred (defer.Deferred): deferred to be timed out - timeout (Number): seconds to time out after - reactor (twisted.internet.reactor): the Twisted reactor to use - + deferred (Deferred) + timeout (float): Timeout in seconds + reactor (twisted.internet.reactor): The twisted reactor to use on_timeout_cancel (callable): A callable which is called immediately after the deferred times out, and not if this deferred is otherwise cancelled before the timeout. @@ -406,13 +402,26 @@ def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None): the timeout. The default callable (if none is provided) will translate a - CancelledError Failure into a DeferredTimeoutError. + CancelledError Failure into a defer.TimeoutError. + + Returns: + Deferred """ + + new_d = defer.Deferred() + timed_out = [False] def time_it_out(): timed_out[0] = True - deferred.cancel() + + try: + deferred.cancel() + except: # noqa: E722, if we throw any exception it'll break time outs + logger.exception("Canceller failed during timeout") + + if not new_d.called: + new_d.errback(defer.TimeoutError(timeout, "Deferred")) delayed_call = reactor.callLater(timeout, time_it_out) @@ -432,9 +441,14 @@ def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None): deferred.addBoth(cancel_timeout) + def success_cb(val): + if not new_d.called: + new_d.callback(val) -def _cancelled_to_timed_out_error(value, timeout): - if isinstance(value, failure.Failure): - value.trap(CancelledError) - raise DeferredTimeoutError(timeout, "Deferred") - return value + def failure_cb(val): + if not new_d.called: + new_d.errback(val) + + deferred.addCallbacks(success_cb, failure_cb) + + return new_d diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 7b065b195e..f37d5bec08 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os import six @@ -20,6 +21,8 @@ from six.moves import intern from prometheus_client.core import REGISTRY, Gauge, GaugeMetricFamily +logger = logging.getLogger(__name__) + CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5)) @@ -76,16 +79,20 @@ def register_cache(cache_type, cache_name, cache): return [] def collect(self): - if cache_type == "response_cache": - response_cache_size.labels(cache_name).set(len(cache)) - response_cache_hits.labels(cache_name).set(self.hits) - response_cache_evicted.labels(cache_name).set(self.evicted_size) - response_cache_total.labels(cache_name).set(self.hits + self.misses) - else: - cache_size.labels(cache_name).set(len(cache)) - cache_hits.labels(cache_name).set(self.hits) - cache_evicted.labels(cache_name).set(self.evicted_size) - cache_total.labels(cache_name).set(self.hits + self.misses) + try: + if cache_type == "response_cache": + response_cache_size.labels(cache_name).set(len(cache)) + response_cache_hits.labels(cache_name).set(self.hits) + response_cache_evicted.labels(cache_name).set(self.evicted_size) + response_cache_total.labels(cache_name).set(self.hits + self.misses) + else: + cache_size.labels(cache_name).set(len(cache)) + cache_hits.labels(cache_name).set(self.hits) + cache_evicted.labels(cache_name).set(self.evicted_size) + cache_total.labels(cache_name).set(self.hits + self.misses) + except Exception as e: + logger.warn("Error calculating metrics for %s: %s", cache_name, e) + raise yield GaugeMetricFamily("__unused", "") diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index ce85b2ae11..f369780277 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -16,12 +16,17 @@ import logging from collections import OrderedDict +from six import iteritems, itervalues + from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.caches import register_cache logger = logging.getLogger(__name__) +SENTINEL = object() + + class ExpiringCache(object): def __init__(self, cache_name, clock, max_len=0, expiry_ms=0, reset_expiry_on_get=False, iterable=False): @@ -54,11 +59,8 @@ class ExpiringCache(object): self.iterable = iterable - self._size_estimate = 0 - self.metrics = register_cache("expiring", cache_name, self) - def start(self): if not self._expiry_ms: # Don't bother starting the loop if things never expire return @@ -75,16 +77,11 @@ class ExpiringCache(object): now = self._clock.time_msec() self._cache[key] = _CacheEntry(now, value) - if self.iterable: - self._size_estimate += len(value) - # Evict if there are now too many items while self._max_len and len(self) > self._max_len: _key, value = self._cache.popitem(last=False) if self.iterable: - removed_len = len(value.value) - self.metrics.inc_evictions(removed_len) - self._size_estimate -= removed_len + self.metrics.inc_evictions(len(value.value)) else: self.metrics.inc_evictions() @@ -101,6 +98,21 @@ class ExpiringCache(object): return entry.value + def pop(self, key, default=SENTINEL): + """Removes and returns the value with the given key from the cache. + + If the key isn't in the cache then `default` will be returned if + specified, otherwise `KeyError` will get raised. + + Identical functionality to `dict.pop(..)`. + """ + + value = self._cache.pop(key, default) + if value is SENTINEL: + raise KeyError(key) + + return value + def __contains__(self, key): return key in self._cache @@ -128,14 +140,16 @@ class ExpiringCache(object): keys_to_delete = set() - for key, cache_entry in self._cache.items(): + for key, cache_entry in iteritems(self._cache): if now - cache_entry.time > self._expiry_ms: keys_to_delete.add(key) for k in keys_to_delete: value = self._cache.pop(k) if self.iterable: - self._size_estimate -= len(value.value) + self.metrics.inc_evictions(len(value.value)) + else: + self.metrics.inc_evictions() logger.debug( "[%s] _prune_cache before: %d, after len: %d", @@ -144,12 +158,14 @@ class ExpiringCache(object): def __len__(self): if self.iterable: - return self._size_estimate + return sum(len(entry.value) for entry in itervalues(self._cache)) else: return len(self._cache) class _CacheEntry(object): + __slots__ = ["time", "value"] + def __init__(self, time, value): self.time = time self.value = value diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index f2bde74dc5..625aedc940 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -15,6 +15,8 @@ import logging +from six import integer_types + from sortedcontainers import SortedDict from synapse.util import caches @@ -47,7 +49,7 @@ class StreamChangeCache(object): def has_entity_changed(self, entity, stream_pos): """Returns True if the entity may have been updated since stream_pos """ - assert type(stream_pos) is int or type(stream_pos) is long + assert type(stream_pos) in integer_types if stream_pos < self._earliest_known_stream_pos: self.metrics.inc_misses() diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index a0c2d37610..4c6e92beb8 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -25,7 +25,7 @@ See doc/log_contexts.rst for details on how this works. import logging import threading -from twisted.internet import defer +from twisted.internet import defer, threads logger = logging.getLogger(__name__) @@ -200,7 +200,7 @@ class LoggingContext(object): sentinel = Sentinel() - def __init__(self, name=None, parent_context=None): + def __init__(self, name=None, parent_context=None, request=None): self.previous_context = LoggingContext.current_context() self.name = name @@ -218,6 +218,13 @@ class LoggingContext(object): self.parent_context = parent_context + if self.parent_context is not None: + self.parent_context.copy_to(self) + + if request is not None: + # the request param overrides the request from the parent context + self.request = request + def __str__(self): return "%s@%x" % (self.name, id(self)) @@ -256,9 +263,6 @@ class LoggingContext(object): ) self.alive = True - if self.parent_context is not None: - self.parent_context.copy_to(self) - return self def __exit__(self, type, value, traceback): @@ -439,6 +443,35 @@ class PreserveLoggingContext(object): ) +def nested_logging_context(suffix, parent_context=None): + """Creates a new logging context as a child of another. + + The nested logging context will have a 'request' made up of the parent context's + request, plus the given suffix. + + CPU/db usage stats will be added to the parent context's on exit. + + Normal usage looks like: + + with nested_logging_context(suffix): + # ... do stuff + + Args: + suffix (str): suffix to add to the parent context's 'request'. + parent_context (LoggingContext|None): parent context. Will use the current context + if None. + + Returns: + LoggingContext: new logging context. + """ + if parent_context is None: + parent_context = LoggingContext.current_context() + return LoggingContext( + parent_context=parent_context, + request=parent_context.request + "-" + suffix, + ) + + def preserve_fn(f): """Function decorator which wraps the function with run_in_background""" def g(*args, **kwargs): @@ -529,58 +562,76 @@ def _set_context_cb(result, context): return result -# modules to ignore in `logcontext_tracer` -_to_ignore = [ - "synapse.util.logcontext", - "synapse.http.server", - "synapse.storage._base", - "synapse.util.async_helpers", -] +def defer_to_thread(reactor, f, *args, **kwargs): + """ + Calls the function `f` using a thread from the reactor's default threadpool and + returns the result as a Deferred. + + Creates a new logcontext for `f`, which is created as a child of the current + logcontext (so its CPU usage metrics will get attributed to the current + logcontext). `f` should preserve the logcontext it is given. + The result deferred follows the Synapse logcontext rules: you should `yield` + on it. + + Args: + reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread + the Deferred will be invoked, and whose threadpool we should use for the + function. -def logcontext_tracer(frame, event, arg): - """A tracer that logs whenever a logcontext "unexpectedly" changes within - a function. Probably inaccurate. + Normally this will be hs.get_reactor(). - Use by calling `sys.settrace(logcontext_tracer)` in the main thread. + f (callable): The function to call. + + args: positional arguments to pass to f. + + kwargs: keyword arguments to pass to f. + + Returns: + Deferred: A Deferred which fires a callback with the result of `f`, or an + errback if `f` throws an exception. """ - if event == 'call': - name = frame.f_globals["__name__"] - if name.startswith("synapse"): - if name == "synapse.util.logcontext": - if frame.f_code.co_name in ["__enter__", "__exit__"]: - tracer = frame.f_back.f_trace - if tracer: - tracer.just_changed = True - - tracer = frame.f_trace - if tracer: - return tracer - - if not any(name.startswith(ig) for ig in _to_ignore): - return LineTracer() - - -class LineTracer(object): - __slots__ = ["context", "just_changed"] - - def __init__(self): - self.context = LoggingContext.current_context() - self.just_changed = False - - def __call__(self, frame, event, arg): - if event in 'line': - if self.just_changed: - self.context = LoggingContext.current_context() - self.just_changed = False - else: - c = LoggingContext.current_context() - if c != self.context: - logger.info( - "Context changed! %s -> %s, %s, %s", - self.context, c, - frame.f_code.co_filename, frame.f_lineno - ) - self.context = c + return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs) - return self + +def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): + """ + A wrapper for twisted.internet.threads.deferToThreadpool, which handles + logcontexts correctly. + + Calls the function `f` using a thread from the given threadpool and returns + the result as a Deferred. + + Creates a new logcontext for `f`, which is created as a child of the current + logcontext (so its CPU usage metrics will get attributed to the current + logcontext). `f` should preserve the logcontext it is given. + + The result deferred follows the Synapse logcontext rules: you should `yield` + on it. + + Args: + reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread + the Deferred will be invoked. Normally this will be hs.get_reactor(). + + threadpool (twisted.python.threadpool.ThreadPool): The threadpool to use for + running `f`. Normally this will be hs.get_reactor().getThreadPool(). + + f (callable): The function to call. + + args: positional arguments to pass to f. + + kwargs: keyword arguments to pass to f. + + Returns: + Deferred: A Deferred which fires a callback with the result of `f`, or an + errback if `f` throws an exception. + """ + logcontext = LoggingContext.current_context() + + def g(): + with LoggingContext(parent_context=logcontext): + return f(*args, **kwargs) + + return make_deferred_yieldable( + threads.deferToThreadPool(reactor, threadpool, g) + ) diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index 14be3c7396..9cb7e9c9ab 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -19,22 +19,40 @@ from twisted.conch.ssh.keys import Key from twisted.cred import checkers, portal PUBLIC_KEY = ( - "ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAGEArzJx8OYOnJmzf4tfBEvLi8DVPrJ3/c9k2I/Az" - "64fxjHf9imyRJbixtQhlH9lfNjUIx+4LmrJH5QNRsFporcHDKOTwTTYLh5KmRpslkYHRivcJS" - "kbh/C+BR3utDS555mV" + "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDHhGATaW4KhE23+7nrH4jFx3yLq9OjaEs5" + "XALqeK+7385NlLja3DE/DO9mGhnd9+bAy39EKT3sTV6+WXQ4yD0TvEEyUEMtjWkSEm6U32+C" + "DaS3TW/vPBUMeJQwq+Ydcif1UlnpXrDDTamD0AU9VaEvHq+3HAkipqn0TGpKON6aqk4vauDx" + "oXSsV5TXBVrxP/y7HpMOpU4GUWsaaacBTKKNnUaQB4UflvydaPJUuwdaCUJGTMjbhWrjVfK+" + "jslseSPxU6XvrkZMyCr4znxvuDxjMk1RGIdO7v+rbBMLEgqtSMNqJbYeVCnj2CFgc3fcTcld" + "X2uOJDrJb/WRlHulthCh" ) PRIVATE_KEY = """-----BEGIN RSA PRIVATE KEY----- -MIIByAIBAAJhAK8ycfDmDpyZs3+LXwRLy4vA1T6yd/3PZNiPwM+uH8Yx3/YpskSW -4sbUIZR/ZXzY1CMfuC5qyR+UDUbBaaK3Bwyjk8E02C4eSpkabJZGB0Yr3CUpG4fw -vgUd7rQ0ueeZlQIBIwJgbh+1VZfr7WftK5lu7MHtqE1S1vPWZQYE3+VUn8yJADyb -Z4fsZaCrzW9lkIqXkE3GIY+ojdhZhkO1gbG0118sIgphwSWKRxK0mvh6ERxKqIt1 -xJEJO74EykXZV4oNJ8sjAjEA3J9r2ZghVhGN6V8DnQrTk24Td0E8hU8AcP0FVP+8 -PQm/g/aXf2QQkQT+omdHVEJrAjEAy0pL0EBH6EVS98evDCBtQw22OZT52qXlAwZ2 -gyTriKFVoqjeEjt3SZKKqXHSApP/AjBLpF99zcJJZRq2abgYlf9lv1chkrWqDHUu -DZttmYJeEfiFBBavVYIF1dOlZT0G8jMCMBc7sOSZodFnAiryP+Qg9otSBjJ3bQML -pSTqy7c3a2AScC/YyOwkDaICHnnD3XyjMwIxALRzl0tQEKMXs6hH8ToUdlLROCrP -EhQ0wahUTCk1gKA4uPD6TMTChavbh4K63OvbKg== +MIIEpQIBAAKCAQEAx4RgE2luCoRNt/u56x+Ixcd8i6vTo2hLOVwC6nivu9/OTZS4 +2twxPwzvZhoZ3ffmwMt/RCk97E1evll0OMg9E7xBMlBDLY1pEhJulN9vgg2kt01v +7zwVDHiUMKvmHXIn9VJZ6V6ww02pg9AFPVWhLx6vtxwJIqap9ExqSjjemqpOL2rg +8aF0rFeU1wVa8T/8ux6TDqVOBlFrGmmnAUyijZ1GkAeFH5b8nWjyVLsHWglCRkzI +24Vq41Xyvo7JbHkj8VOl765GTMgq+M58b7g8YzJNURiHTu7/q2wTCxIKrUjDaiW2 +HlQp49ghYHN33E3JXV9rjiQ6yW/1kZR7pbYQoQIDAQABAoIBAQC8KJ0q8Wzzwh5B +esa1dQHZ8+4DEsL/Amae66VcVwD0X3cCN1W2IZ7X5W0Ij2kBqr8V51RYhcR+S+Ek +BtzSiBUBvbKGrqcMGKaUgomDIMzai99hd0gvCCyZnEW1OQhFkNkaRNXCfqiZJ27M +fqvSUiU2eOwh9fCvmxoA6Of8o3FbzcJ+1GMcobWRllDtLmj6lgVbDzuA+0jC5daB +9Tj1pBzu3wn3ufxiS+gBnJ+7NcXH3E73lqCcPa2ufbZ1haxfiGCnRIhFXuQDgxFX +vKdEfDgtvas6r1ahGbc+b/q8E8fZT7cABuIU4yfOORK+MhpyWbvoyyzuVGKj3PKt +KSPJu5CZAoGBAOkoJfAVyYteqKcmGTanGqQnAY43CaYf6GdSPX/jg+JmKZg0zqMC +jWZUtPb93i+jnOInbrnuHOiHAxI8wmhEPed28H2lC/LU8PzlqFkZXKFZ4vLOhhRB +/HeHCFIDosPFlohWi3b+GAjD7sXgnIuGmnXWe2ea/TS3yersifDEoKKjAoGBANsQ +gJX2cJv1c3jhdgcs8vAt5zIOKcCLTOr/QPmVf/kxjNgndswcKHwsxE/voTO9q+TF +v/6yCSTxAdjuKz1oIYWgi/dZo82bBKWxNRpgrGviU3/zwxiHlyIXUhzQu78q3VS/ +7S1XVbc7qMV++XkYKHPVD+nVG/gGzFxumX7MLXfrAoGBAJit9cn2OnjNj9uFE1W6 +r7N254ndeLAUjPe73xH0RtTm2a4WRopwjW/JYIetTuYbWgyujc+robqTTuuOZjAp +H/CG7o0Ym251CypQqaFO/l2aowclPp/dZhpPjp9GSjuxFBZLtiBB3DNBOwbRQzIK +/vLTdRQvZkgzYkI4i0vjNt3JAoGBANP8HSKBLymMlShlrSx2b8TB9tc2Y2riohVJ +2ttqs0M2kt/dGJWdrgOz4mikL+983Olt/0P9juHDoxEEMK2kpcPEv40lnmBpYU7h +s8yJvnBLvJe2EJYdJ8AipyAhUX1FgpbvfxmASP8eaUxsegeXvBWTGWojAoS6N2o+ +0KSl+l3vAoGAFqm0gO9f/Q1Se60YQd4l2PZeMnJFv0slpgHHUwegmd6wJhOD7zJ1 +CkZcXwiv7Nog7AI9qKJEUXLjoqL+vJskBzSOqU3tcd670YQMi1aXSXJqYE202K7o +EddTrx3TNpr1D5m/f+6mnXWrc8u9y1+GNx9yz889xMjIBTBI9KqaaOs= -----END RSA PRIVATE KEY-----""" @@ -52,6 +70,8 @@ def manhole(username, password, globals): Returns: twisted.internet.protocol.Factory: A factory to pass to ``listenTCP`` """ + if not isinstance(password, bytes): + password = password.encode('ascii') checker = checkers.InMemoryUsernamePasswordDatabaseDontUse( **{username: password} @@ -64,7 +84,7 @@ def manhole(username, password, globals): ) factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) - factory.publicKeys['ssh-rsa'] = Key.fromString(PUBLIC_KEY) - factory.privateKeys['ssh-rsa'] = Key.fromString(PRIVATE_KEY) + factory.publicKeys[b'ssh-rsa'] = Key.fromString(PUBLIC_KEY) + factory.privateKeys[b'ssh-rsa'] = Key.fromString(PRIVATE_KEY) return factory diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 97f1267380..4b4ac5f6c7 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -20,6 +20,7 @@ from prometheus_client import Counter from twisted.internet import defer +from synapse.metrics import InFlightGauge from synapse.util.logcontext import LoggingContext logger = logging.getLogger(__name__) @@ -45,6 +46,13 @@ block_db_txn_duration = Counter( block_db_sched_duration = Counter( "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"]) +# Tracks the number of blocks currently active +in_flight = InFlightGauge( + "synapse_util_metrics_block_in_flight", "", + labels=["block_name"], + sub_metrics=["real_time_max", "real_time_sum"], +) + def measure_func(name): def wrapper(func): @@ -82,10 +90,14 @@ class Measure(object): self.start_usage = self.start_context.get_resource_usage() + in_flight.register((self.name,), self._update_in_flight) + def __exit__(self, exc_type, exc_val, exc_tb): if isinstance(exc_type, Exception) or not self.start_context: return + in_flight.unregister((self.name,), self._update_in_flight) + duration = self.clock.time() - self.start block_counter.labels(self.name).inc() @@ -120,3 +132,13 @@ class Measure(object): if self.created_context: self.start_context.__exit__(exc_type, exc_val, exc_tb) + + def _update_in_flight(self, metrics): + """Gets called when processing in flight metrics + """ + duration = self.clock.time() - self.start + + metrics.real_time_max = max(metrics.real_time_max, duration) + metrics.real_time_sum += duration + + # TODO: Add other in flight metrics. diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 8a3a06fd74..26cce7d197 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -188,7 +188,7 @@ class RetryDestinationLimiter(object): else: self.retry_interval = self.min_retry_interval - logger.debug( + logger.info( "Connection to %s was unsuccessful (%s(%s)); backoff now %i", self.destination, exc_type, exc_val, self.retry_interval ) diff --git a/synapse/visibility.py b/synapse/visibility.py index d4680863d3..0281a7c919 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -23,6 +23,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.events.utils import prune_event +from synapse.storage.state import StateFilter from synapse.types import get_domain_from_id logger = logging.getLogger(__name__) @@ -72,7 +73,7 @@ def filter_events_for_client(store, user_id, events, is_peeking=False, ) event_id_to_state = yield store.get_state_for_events( frozenset(e.event_id for e in events), - types=types, + state_filter=StateFilter.from_types(types), ) ignore_dict_content = yield store.get_global_account_data_by_type_for_user( @@ -219,7 +220,7 @@ def filter_events_for_server(store, server_name, events): # Whatever else we do, we need to check for senders which have requested # erasure of their data. erased_senders = yield store.are_users_erased( - e.sender for e in events, + (e.sender for e in events), ) def redact_disallowed(event, state): @@ -273,8 +274,8 @@ def filter_events_for_server(store, server_name, events): # need to check membership (as we know the server is in the room). event_to_state_ids = yield store.get_state_ids_for_events( frozenset(e.event_id for e in events), - types=( - (EventTypes.RoomHistoryVisibility, ""), + state_filter=StateFilter.from_types( + types=((EventTypes.RoomHistoryVisibility, ""),), ) ) @@ -314,9 +315,11 @@ def filter_events_for_server(store, server_name, events): # of the history vis and membership state at those events. event_to_state_ids = yield store.get_state_ids_for_events( frozenset(e.event_id for e in events), - types=( - (EventTypes.RoomHistoryVisibility, ""), - (EventTypes.Member, None), + state_filter=StateFilter.from_types( + types=( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, None), + ), ) ) @@ -324,14 +327,13 @@ def filter_events_for_server(store, server_name, events): # server's domain. # # event_to_state_ids contains lots of duplicates, so it turns out to be - # cheaper to build a complete set of unique - # ((type, state_key), event_id) tuples, and then filter out the ones we - # don't want. + # cheaper to build a complete event_id => (type, state_key) dict, and then + # filter out the ones we don't want # - state_key_to_event_id_set = { - e + event_id_to_state_key = { + event_id: key for key_to_eid in itervalues(event_to_state_ids) - for e in key_to_eid.items() + for key, event_id in iteritems(key_to_eid) } def include(typ, state_key): @@ -346,7 +348,7 @@ def filter_events_for_server(store, server_name, events): event_map = yield store.get_events([ e_id - for key, e_id in state_key_to_event_id_set + for e_id, key in iteritems(event_id_to_state_key) if include(key[0], key[1]) ]) |