diff options
Diffstat (limited to 'synapse')
103 files changed, 2769 insertions, 916 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index a14d578e36..e62901b761 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -17,4 +17,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.33.2" +__version__ = "0.33.3" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 5bbbe8e2e7..6502a6be7b 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -25,7 +25,7 @@ from twisted.internet import defer import synapse.types from synapse import event_auth from synapse.api.constants import EventTypes, JoinRules, Membership -from synapse.api.errors import AuthError, Codes +from synapse.api.errors import AuthError, Codes, ResourceLimitError from synapse.types import UserID from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache from synapse.util.caches.lrucache import LruCache @@ -211,9 +211,9 @@ class Auth(object): user_agent = request.requestHeaders.getRawHeaders( b"User-Agent", default=[b""] - )[0] + )[0].decode('ascii', 'surrogateescape') if user and access_token and ip_addr: - self.store.insert_client_ip( + yield self.store.insert_client_ip( user_id=user.to_string(), access_token=access_token, ip=ip_addr, @@ -682,7 +682,7 @@ class Auth(object): Returns: bool: False if no access_token was given, True otherwise. """ - query_params = request.args.get("access_token") + query_params = request.args.get(b"access_token") auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") return bool(query_params) or bool(auth_headers) @@ -698,7 +698,7 @@ class Auth(object): 401 since some of the old clients depended on auth errors returning 403. Returns: - str: The access_token + unicode: The access_token Raises: AuthError: If there isn't an access_token in the request. """ @@ -720,9 +720,9 @@ class Auth(object): "Too many Authorization headers.", errcode=Codes.MISSING_TOKEN, ) - parts = auth_headers[0].split(" ") - if parts[0] == "Bearer" and len(parts) == 2: - return parts[1] + parts = auth_headers[0].split(b" ") + if parts[0] == b"Bearer" and len(parts) == 2: + return parts[1].decode('ascii') else: raise AuthError( token_not_found_http_status, @@ -738,7 +738,7 @@ class Auth(object): errcode=Codes.MISSING_TOKEN ) - return query_params[0] + return query_params[0].decode('ascii') @defer.inlineCallbacks def check_in_room_or_world_readable(self, room_id, user_id): @@ -773,3 +773,36 @@ class Auth(object): raise AuthError( 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN ) + + @defer.inlineCallbacks + def check_auth_blocking(self, user_id=None): + """Checks if the user should be rejected for some external reason, + such as monthly active user limiting or global disable flag + + Args: + user_id(str|None): If present, checks for presence against existing + MAU cohort + """ + if self.hs.config.hs_disabled: + raise ResourceLimitError( + 403, self.hs.config.hs_disabled_message, + errcode=Codes.RESOURCE_LIMIT_EXCEED, + admin_uri=self.hs.config.admin_uri, + limit_type=self.hs.config.hs_disabled_limit_type + ) + if self.hs.config.limit_usage_by_mau is True: + # If the user is already part of the MAU cohort + if user_id: + timestamp = yield self.store.user_last_seen_monthly_active(user_id) + if timestamp: + return + # Else if there is no room in the MAU bucket, bail + current_mau = yield self.store.get_monthly_active_count() + if current_mau >= self.hs.config.max_mau_value: + raise ResourceLimitError( + 403, "Monthly Active User Limit Exceeded", + + admin_uri=self.hs.config.admin_uri, + errcode=Codes.RESOURCE_LIMIT_EXCEED, + limit_type="monthly_active_user" + ) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 4df930c8d1..b0da506f6d 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -94,3 +95,11 @@ class RoomCreationPreset(object): class ThirdPartyEntityKind(object): USER = "user" LOCATION = "location" + + +# the version we will give rooms which are created on this server +DEFAULT_ROOM_VERSION = "1" + +# vdh-test-version is a placeholder to get room versioning support working and tested +# until we have a working v2. +KNOWN_ROOM_VERSIONS = {"1", "vdh-test-version"} diff --git a/synapse/api/errors.py b/synapse/api/errors.py index b41d595059..e26001ab12 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -55,7 +56,9 @@ class Codes(object): SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED" CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN" CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM" - MAU_LIMIT_EXCEEDED = "M_MAU_LIMIT_EXCEEDED" + RESOURCE_LIMIT_EXCEED = "M_RESOURCE_LIMIT_EXCEED" + UNSUPPORTED_ROOM_VERSION = "M_UNSUPPORTED_ROOM_VERSION" + INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION" class CodeMessageException(RuntimeError): @@ -228,6 +231,30 @@ class AuthError(SynapseError): super(AuthError, self).__init__(*args, **kwargs) +class ResourceLimitError(SynapseError): + """ + Any error raised when there is a problem with resource usage. + For instance, the monthly active user limit for the server has been exceeded + """ + def __init__( + self, code, msg, + errcode=Codes.RESOURCE_LIMIT_EXCEED, + admin_uri=None, + limit_type=None, + ): + self.admin_uri = admin_uri + self.limit_type = limit_type + super(ResourceLimitError, self).__init__(code, msg, errcode=errcode) + + def error_dict(self): + return cs_error( + self.msg, + self.errcode, + admin_uri=self.admin_uri, + limit_type=self.limit_type + ) + + class EventSizeError(SynapseError): """An error raised when an event is too big.""" @@ -285,6 +312,27 @@ class LimitExceededError(SynapseError): ) +class IncompatibleRoomVersionError(SynapseError): + """A server is trying to join a room whose version it does not support.""" + + def __init__(self, room_version): + super(IncompatibleRoomVersionError, self).__init__( + code=400, + msg="Your homeserver does not support the features required to " + "join this room", + errcode=Codes.INCOMPATIBLE_ROOM_VERSION, + ) + + self._room_version = room_version + + def error_dict(self): + return cs_error( + self.msg, + self.errcode, + room_version=self._room_version, + ) + + def cs_error(msg, code=Codes.UNKNOWN, **kwargs): """ Utility method for constructing an error response for client-server interactions. diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 06cc8d90b8..3bb5b3da37 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -72,7 +72,7 @@ class Ratelimiter(object): return allowed, time_allowed def prune_message_counts(self, time_now_s): - for user_id in self.message_counts.keys(): + for user_id in list(self.message_counts.keys()): message_count, time_start, msg_rate_hz = ( self.message_counts[user_id] ) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 391bd14c5c..7c866e246a 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -140,7 +140,7 @@ def listen_metrics(bind_addresses, port): logger.info("Metrics now reporting on %s:%d", host, port) -def listen_tcp(bind_addresses, port, factory, backlog=50): +def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50): """ Create a TCP socket for a port and several addresses """ @@ -156,7 +156,9 @@ def listen_tcp(bind_addresses, port, factory, backlog=50): check_bind_error(e, address, bind_addresses) -def listen_ssl(bind_addresses, port, factory, context_factory, backlog=50): +def listen_ssl( + bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50 +): """ Create an SSL socket for a port and several addresses """ diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 9a37384fb7..3348a8ec6d 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -117,8 +117,9 @@ class ASReplicationHandler(ReplicationClientHandler): super(ASReplicationHandler, self).__init__(hs.get_datastore()) self.appservice_handler = hs.get_application_service_handler() + @defer.inlineCallbacks def on_rdata(self, stream_name, token, rows): - super(ASReplicationHandler, self).on_rdata(stream_name, token, rows) + yield super(ASReplicationHandler, self).on_rdata(stream_name, token, rows) if stream_name == "events": max_stream_id = self.store.get_room_max_stream_ordering() diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index e2c91123db..ab79a45646 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -39,7 +39,7 @@ from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.keys import SlavedKeyStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.room import RoomStore -from synapse.replication.slave.storage.transactions import TransactionStore +from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.client.v1.room import ( JoinedRoomMemberListRestServlet, @@ -66,7 +66,7 @@ class ClientReaderSlavedStore( DirectoryStore, SlavedApplicationServiceStore, SlavedRegistrationStore, - TransactionStore, + SlavedTransactionStore, SlavedClientIpStore, BaseSlavedStore, ): @@ -168,11 +168,13 @@ def start(config_options): database_engine = create_engine(config.database_config) tls_server_context_factory = context_factory.ServerContextFactory(config) + tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config) ss = ClientReaderServer( config.server_name, db_config=config.database_config, tls_server_context_factory=tls_server_context_factory, + tls_client_options_factory=tls_client_options_factory, config=config, version_string="Synapse/" + get_version_string(synapse), database_engine=database_engine, diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index 374f115644..03d39968a8 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -43,7 +43,7 @@ from synapse.replication.slave.storage.pushers import SlavedPusherStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.room import RoomStore -from synapse.replication.slave.storage.transactions import TransactionStore +from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.client.v1.room import ( JoinRoomAliasServlet, @@ -63,7 +63,7 @@ logger = logging.getLogger("synapse.app.event_creator") class EventCreatorSlavedStore( DirectoryStore, - TransactionStore, + SlavedTransactionStore, SlavedProfileStore, SlavedAccountDataStore, SlavedPusherStore, @@ -174,11 +174,13 @@ def start(config_options): database_engine = create_engine(config.database_config) tls_server_context_factory = context_factory.ServerContextFactory(config) + tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config) ss = EventCreatorServer( config.server_name, db_config=config.database_config, tls_server_context_factory=tls_server_context_factory, + tls_client_options_factory=tls_client_options_factory, config=config, version_string="Synapse/" + get_version_string(synapse), database_engine=database_engine, diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 7af00b8bcf..7d8105778d 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -32,11 +32,17 @@ from synapse.http.site import SynapseSite from synapse.metrics import RegistryProxy from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage.account_data import SlavedAccountDataStore +from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.keys import SlavedKeyStore +from synapse.replication.slave.storage.profile import SlavedProfileStore +from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore +from synapse.replication.slave.storage.pushers import SlavedPusherStore +from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.room import RoomStore -from synapse.replication.slave.storage.transactions import TransactionStore +from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer from synapse.storage.engines import create_engine @@ -49,11 +55,17 @@ logger = logging.getLogger("synapse.app.federation_reader") class FederationReaderSlavedStore( + SlavedAccountDataStore, + SlavedProfileStore, + SlavedApplicationServiceStore, + SlavedPusherStore, + SlavedPushRuleStore, + SlavedReceiptsStore, SlavedEventStore, SlavedKeyStore, RoomStore, DirectoryStore, - TransactionStore, + SlavedTransactionStore, BaseSlavedStore, ): pass @@ -143,11 +155,13 @@ def start(config_options): database_engine = create_engine(config.database_config) tls_server_context_factory = context_factory.ServerContextFactory(config) + tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config) ss = FederationReaderServer( config.server_name, db_config=config.database_config, tls_server_context_factory=tls_server_context_factory, + tls_client_options_factory=tls_client_options_factory, config=config, version_string="Synapse/" + get_version_string(synapse), database_engine=database_engine, diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 18469013fa..d59007099b 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -36,11 +36,11 @@ from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.presence import SlavedPresenceStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.slave.storage.transactions import TransactionStore +from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer from synapse.storage.engines import create_engine -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole @@ -50,7 +50,7 @@ logger = logging.getLogger("synapse.app.federation_sender") class FederationSenderSlaveStore( - SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore, + SlavedDeviceInboxStore, SlavedTransactionStore, SlavedReceiptsStore, SlavedEventStore, SlavedRegistrationStore, SlavedDeviceStore, SlavedPresenceStore, ): def __init__(self, db_conn, hs): @@ -144,8 +144,9 @@ class FederationSenderReplicationHandler(ReplicationClientHandler): super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore()) self.send_handler = FederationSenderHandler(hs, self) + @defer.inlineCallbacks def on_rdata(self, stream_name, token, rows): - super(FederationSenderReplicationHandler, self).on_rdata( + yield super(FederationSenderReplicationHandler, self).on_rdata( stream_name, token, rows ) self.send_handler.process_replication_rows(stream_name, token, rows) @@ -186,11 +187,13 @@ def start(config_options): config.send_federation = True tls_server_context_factory = context_factory.ServerContextFactory(config) + tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config) ps = FederationSenderServer( config.server_name, db_config=config.database_config, tls_server_context_factory=tls_server_context_factory, + tls_client_options_factory=tls_client_options_factory, config=config, version_string="Synapse/" + get_version_string(synapse), database_engine=database_engine, diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index b5f78f4640..8d484c1cd4 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -38,6 +38,7 @@ from synapse.replication.slave.storage.client_ips import SlavedClientIpStore from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.rest.client.v1.base import ClientV1RestServlet, client_path_patterns from synapse.rest.client.v2_alpha._base import client_v2_patterns from synapse.server import HomeServer from synapse.storage.engines import create_engine @@ -49,6 +50,35 @@ from synapse.util.versionstring import get_version_string logger = logging.getLogger("synapse.app.frontend_proxy") +class PresenceStatusStubServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status") + + def __init__(self, hs): + super(PresenceStatusStubServlet, self).__init__(hs) + self.http_client = hs.get_simple_http_client() + self.auth = hs.get_auth() + self.main_uri = hs.config.worker_main_http_uri + + @defer.inlineCallbacks + def on_GET(self, request, user_id): + # Pass through the auth headers, if any, in case the access token + # is there. + auth_headers = request.requestHeaders.getRawHeaders("Authorization", []) + headers = { + "Authorization": auth_headers, + } + result = yield self.http_client.get_json( + self.main_uri + request.uri, + headers=headers, + ) + defer.returnValue((200, result)) + + @defer.inlineCallbacks + def on_PUT(self, request, user_id): + yield self.auth.get_user_by_req(request) + defer.returnValue((200, {})) + + class KeyUploadServlet(RestServlet): PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") @@ -135,6 +165,12 @@ class FrontendProxyServer(HomeServer): elif name == "client": resource = JsonResource(self, canonical_json=False) KeyUploadServlet(self).register(resource) + + # If presence is disabled, use the stub servlet that does + # not allow sending presence + if not self.config.use_presence: + PresenceStatusStubServlet(self).register(resource) + resources.update({ "/_matrix/client/r0": resource, "/_matrix/client/unstable": resource, @@ -153,7 +189,8 @@ class FrontendProxyServer(HomeServer): listener_config, root_resource, self.version_string, - ) + ), + reactor=self.get_reactor() ) logger.info("Synapse client reader now listening on port %d", port) @@ -208,11 +245,13 @@ def start(config_options): database_engine = create_engine(config.database_config) tls_server_context_factory = context_factory.ServerContextFactory(config) + tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config) ss = FrontendProxyServer( config.server_name, db_config=config.database_config, tls_server_context_factory=tls_server_context_factory, + tls_client_options_factory=tls_client_options_factory, config=config, version_string="Synapse/" + get_version_string(synapse), database_engine=database_engine, diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index fba51c26e8..005921dcf7 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -303,8 +303,8 @@ class SynapseHomeServer(HomeServer): # Gauges to expose monthly active user control metrics -current_mau_gauge = Gauge("synapse_admin_current_mau", "Current MAU") -max_mau_value_gauge = Gauge("synapse_admin_max_mau_value", "MAU Limit") +current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU") +max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit") def setup(config_options): @@ -338,6 +338,7 @@ def setup(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts tls_server_context_factory = context_factory.ServerContextFactory(config) + tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config) database_engine = create_engine(config.database_config) config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection @@ -346,6 +347,7 @@ def setup(config_options): config.server_name, db_config=config.database_config, tls_server_context_factory=tls_server_context_factory, + tls_client_options_factory=tls_client_options_factory, config=config, version_string="Synapse/" + get_version_string(synapse), database_engine=database_engine, @@ -519,17 +521,27 @@ def run(hs): # table will decrease clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000) + # monthly active user limiting functionality + clock.looping_call( + hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60 + ) + hs.get_datastore().reap_monthly_active_users() + @defer.inlineCallbacks def generate_monthly_active_users(): count = 0 if hs.config.limit_usage_by_mau: - count = yield hs.get_datastore().count_monthly_users() + count = yield hs.get_datastore().get_monthly_active_count() current_mau_gauge.set(float(count)) - max_mau_value_gauge.set(float(hs.config.max_mau_value)) + max_mau_gauge.set(float(hs.config.max_mau_value)) + hs.get_datastore().initialise_reserved_users( + hs.config.mau_limits_reserved_threepids + ) generate_monthly_active_users() if hs.config.limit_usage_by_mau: clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000) + # End of monthly active user settings if hs.config.report_stats: logger.info("Scheduling stats reporting for 3 hour intervals") diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index 749bbf37d0..fd1f6cbf7e 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -34,7 +34,7 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.client_ips import SlavedClientIpStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.slave.storage.transactions import TransactionStore +from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.server import HomeServer @@ -52,7 +52,7 @@ class MediaRepositorySlavedStore( SlavedApplicationServiceStore, SlavedRegistrationStore, SlavedClientIpStore, - TransactionStore, + SlavedTransactionStore, BaseSlavedStore, MediaRepositoryStore, ): @@ -155,11 +155,13 @@ def start(config_options): database_engine = create_engine(config.database_config) tls_server_context_factory = context_factory.ServerContextFactory(config) + tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config) ss = MediaRepositoryServer( config.server_name, db_config=config.database_config, tls_server_context_factory=tls_server_context_factory, + tls_client_options_factory=tls_client_options_factory, config=config, version_string="Synapse/" + get_version_string(synapse), database_engine=database_engine, diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index 9295a51d5b..a4fc7e91fa 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -148,8 +148,9 @@ class PusherReplicationHandler(ReplicationClientHandler): self.pusher_pool = hs.get_pusherpool() + @defer.inlineCallbacks def on_rdata(self, stream_name, token, rows): - super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows) + yield super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows) run_in_background(self.poke_pushers, stream_name, token, rows) @defer.inlineCallbacks @@ -162,11 +163,11 @@ class PusherReplicationHandler(ReplicationClientHandler): else: yield self.start_pusher(row.user_id, row.app_id, row.pushkey) elif stream_name == "events": - yield self.pusher_pool.on_new_notifications( + self.pusher_pool.on_new_notifications( token, token, ) elif stream_name == "receipts": - yield self.pusher_pool.on_new_receipts( + self.pusher_pool.on_new_receipts( token, token, set(row.room_id for row in rows) ) except Exception: diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index e201f18efd..27e1998660 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -114,7 +114,10 @@ class SynchrotronPresence(object): logger.info("Presence process_id is %r", self.process_id) def send_user_sync(self, user_id, is_syncing, last_sync_ms): - self.hs.get_tcp_replication().send_user_sync(user_id, is_syncing, last_sync_ms) + if self.hs.config.use_presence: + self.hs.get_tcp_replication().send_user_sync( + user_id, is_syncing, last_sync_ms + ) def mark_as_coming_online(self, user_id): """A user has started syncing. Send a UserSync to the master, unless they @@ -211,10 +214,13 @@ class SynchrotronPresence(object): yield self.notify_from_replication(states, stream_id) def get_currently_syncing_users(self): - return [ - user_id for user_id, count in iteritems(self.user_to_num_current_syncs) - if count > 0 - ] + if self.hs.config.use_presence: + return [ + user_id for user_id, count in iteritems(self.user_to_num_current_syncs) + if count > 0 + ] + else: + return set() class SynchrotronTyping(object): @@ -332,8 +338,9 @@ class SyncReplicationHandler(ReplicationClientHandler): self.presence_handler = hs.get_presence_handler() self.notifier = hs.get_notifier() + @defer.inlineCallbacks def on_rdata(self, stream_name, token, rows): - super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) + yield super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) run_in_background(self.process_and_notify, stream_name, token, rows) def get_streams_to_replicate(self): diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 637a89530a..1388a42b59 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -169,8 +169,9 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler): super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore()) self.user_directory = hs.get_user_directory_handler() + @defer.inlineCallbacks def on_rdata(self, stream_name, token, rows): - super(UserDirectoryReplicationHandler, self).on_rdata( + yield super(UserDirectoryReplicationHandler, self).on_rdata( stream_name, token, rows ) if stream_name == "current_state_deltas": @@ -214,11 +215,13 @@ def start(config_options): config.update_user_directory = True tls_server_context_factory = context_factory.ServerContextFactory(config) + tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config) ps = UserDirectoryServer( config.server_name, db_config=config.database_config, tls_server_context_factory=tls_server_context_factory, + tls_client_options_factory=tls_client_options_factory, config=config, version_string="Synapse/" + get_version_string(synapse), database_engine=database_engine, diff --git a/synapse/config/logger.py b/synapse/config/logger.py index a87b11a1df..3f187adfc8 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -168,7 +168,8 @@ def setup_logging(config, use_worker_options=False): if log_file: # TODO: Customisable file size / backup count handler = logging.handlers.RotatingFileHandler( - log_file, maxBytes=(1000 * 1000 * 100), backupCount=3 + log_file, maxBytes=(1000 * 1000 * 100), backupCount=3, + encoding='utf8' ) def sighup(signum, stack): @@ -193,9 +194,8 @@ def setup_logging(config, use_worker_options=False): def sighup(signum, stack): # it might be better to use a file watcher or something for this. - logging.info("Reloading log config from %s due to SIGHUP", - log_config) load_log_config() + logging.info("Reloaded log config from %s due to SIGHUP", log_config) load_log_config() diff --git a/synapse/config/server.py b/synapse/config/server.py index 6a471a0a5e..68a612e594 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -49,6 +49,9 @@ class ServerConfig(Config): # "disable" federation self.send_federation = config.get("send_federation", True) + # Whether to enable user presence. + self.use_presence = config.get("use_presence", True) + # Whether to update the user directory or not. This should be set to # false only if we are updating the user directory in a worker self.update_user_directory = config.get("update_user_directory", True) @@ -69,12 +72,24 @@ class ServerConfig(Config): # Options to control access by tracking MAU self.limit_usage_by_mau = config.get("limit_usage_by_mau", False) + self.max_mau_value = 0 if self.limit_usage_by_mau: self.max_mau_value = config.get( "max_mau_value", 0, ) - else: - self.max_mau_value = 0 + self.mau_limits_reserved_threepids = config.get( + "mau_limit_reserved_threepids", [] + ) + + # Options to disable HS + self.hs_disabled = config.get("hs_disabled", False) + self.hs_disabled_message = config.get("hs_disabled_message", "") + self.hs_disabled_limit_type = config.get("hs_disabled_limit_type", "") + + # Admin uri to direct users at should their instance become blocked + # due to resource constraints + self.admin_uri = config.get("admin_uri", None) + # FIXME: federation_domain_whitelist needs sytests self.federation_domain_whitelist = None federation_domain_whitelist = config.get( @@ -238,6 +253,9 @@ class ServerConfig(Config): # hard limit. soft_file_limit: 0 + # Set to false to disable presence tracking on this homeserver. + use_presence: true + # The GC threshold parameters to pass to `gc.set_threshold`, if defined # gc_thresholds: [700, 10, 10] @@ -329,6 +347,32 @@ class ServerConfig(Config): # - port: 9000 # bind_addresses: ['::1', '127.0.0.1'] # type: manhole + + + # Homeserver blocking + # + # How to reach the server admin, used in ResourceLimitError + # admin_uri: 'mailto:admin@server.com' + # + # Global block config + # + # hs_disabled: False + # hs_disabled_message: 'Human readable reason for why the HS is blocked' + # hs_disabled_limit_type: 'error code(str), to help clients decode reason' + # + # Monthly Active User Blocking + # + # Enables monthly active user checking + # limit_usage_by_mau: False + # max_mau_value: 50 + # + # Sometimes the server admin will want to ensure certain accounts are + # never blocked by mau checking. These accounts are specified here. + # + # mau_limit_reserved_threepids: + # - medium: 'email' + # address: 'reserved_user@example.com' + """ % locals() def read_arguments(self, args): diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index a1e1d0d33a..1a391adec1 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -11,19 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging +from zope.interface import implementer + from OpenSSL import SSL, crypto -from twisted.internet import ssl from twisted.internet._sslverify import _defaultCurveName +from twisted.internet.interfaces import IOpenSSLClientConnectionCreator +from twisted.internet.ssl import CertificateOptions, ContextFactory +from twisted.python.failure import Failure logger = logging.getLogger(__name__) -class ServerContextFactory(ssl.ContextFactory): +class ServerContextFactory(ContextFactory): """Factory for PyOpenSSL SSL contexts that are used to handle incoming - connections and to make connections to remote servers.""" + connections.""" def __init__(self, config): self._context = SSL.Context(SSL.SSLv23_METHOD) @@ -48,3 +51,78 @@ class ServerContextFactory(ssl.ContextFactory): def getContext(self): return self._context + + +def _idnaBytes(text): + """ + Convert some text typed by a human into some ASCII bytes. This is a + copy of twisted.internet._idna._idnaBytes. For documentation, see the + twisted documentation. + """ + try: + import idna + except ImportError: + return text.encode("idna") + else: + return idna.encode(text) + + +def _tolerateErrors(wrapped): + """ + Wrap up an info_callback for pyOpenSSL so that if something goes wrong + the error is immediately logged and the connection is dropped if possible. + This is a copy of twisted.internet._sslverify._tolerateErrors. For + documentation, see the twisted documentation. + """ + + def infoCallback(connection, where, ret): + try: + return wrapped(connection, where, ret) + except: # noqa: E722, taken from the twisted implementation + f = Failure() + logger.exception("Error during info_callback") + connection.get_app_data().failVerification(f) + + return infoCallback + + +@implementer(IOpenSSLClientConnectionCreator) +class ClientTLSOptions(object): + """ + Client creator for TLS without certificate identity verification. This is a + copy of twisted.internet._sslverify.ClientTLSOptions with the identity + verification left out. For documentation, see the twisted documentation. + """ + + def __init__(self, hostname, ctx): + self._ctx = ctx + self._hostname = hostname + self._hostnameBytes = _idnaBytes(hostname) + ctx.set_info_callback( + _tolerateErrors(self._identityVerifyingInfoCallback) + ) + + def clientConnectionForTLS(self, tlsProtocol): + context = self._ctx + connection = SSL.Connection(context, None) + connection.set_app_data(tlsProtocol) + return connection + + def _identityVerifyingInfoCallback(self, connection, where, ret): + if where & SSL.SSL_CB_HANDSHAKE_START: + connection.set_tlsext_host_name(self._hostnameBytes) + + +class ClientTLSOptionsFactory(object): + """Factory for Twisted ClientTLSOptions that are used to make connections + to remote servers for federation.""" + + def __init__(self, config): + # We don't use config options yet + pass + + def get_options(self, host): + return ClientTLSOptions( + host.decode('utf-8'), + CertificateOptions(verify=False).getContext() + ) diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py index 668b4f517d..c20a32096a 100644 --- a/synapse/crypto/keyclient.py +++ b/synapse/crypto/keyclient.py @@ -30,14 +30,14 @@ KEY_API_V1 = b"/_matrix/key/v1/" @defer.inlineCallbacks -def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1): +def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1): """Fetch the keys for a remote server.""" factory = SynapseKeyClientFactory() factory.path = path factory.host = server_name endpoint = matrix_federation_endpoint( - reactor, server_name, ssl_context_factory, timeout=30 + reactor, server_name, tls_client_options_factory, timeout=30 ) for i in range(5): diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index e95b9fb43e..30e2742102 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -512,7 +512,7 @@ class Keyring(object): continue (response, tls_certificate) = yield fetch_server_key( - server_name, self.hs.tls_server_context_factory, + server_name, self.hs.tls_client_options_factory, path=(b"/_matrix/key/v2/server/%s" % ( urllib.quote(requested_key_id), )).encode("ascii"), @@ -655,7 +655,7 @@ class Keyring(object): # Try to fetch the key from the remote server. (response, tls_certificate) = yield fetch_server_key( - server_name, self.hs.tls_server_context_factory + server_name, self.hs.tls_client_options_factory ) # Check the response. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index b32f64e729..6baeccca38 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -20,7 +20,7 @@ from signedjson.key import decode_verify_key_bytes from signedjson.sign import SignatureVerifyException, verify_signed_json from unpaddedbase64 import decode_base64 -from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventTypes, JoinRules, Membership from synapse.api.errors import AuthError, EventSizeError, SynapseError from synapse.types import UserID, get_domain_from_id @@ -83,6 +83,14 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True): 403, "Creation event's room_id domain does not match sender's" ) + + room_version = event.content.get("room_version", "1") + if room_version not in KNOWN_ROOM_VERSIONS: + raise AuthError( + 403, + "room appears to have unsupported version %s" % ( + room_version, + )) # FIXME logger.debug("Allowing! %s", event) return diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 7550e11b6e..c9f3c2d352 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -25,7 +25,7 @@ from prometheus_client import Counter from twisted.internet import defer -from synapse.api.constants import Membership +from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventTypes, Membership from synapse.api.errors import ( CodeMessageException, FederationDeniedError, @@ -518,10 +518,10 @@ class FederationClient(FederationBase): description, destination, exc_info=1, ) - raise RuntimeError("Failed to %s via any server", description) + raise RuntimeError("Failed to %s via any server" % (description, )) def make_membership_event(self, destinations, room_id, user_id, membership, - content={},): + content, params): """ Creates an m.room.member event, with context, without participating in the room. @@ -537,8 +537,10 @@ class FederationClient(FederationBase): user_id (str): The user whose membership is being evented. membership (str): The "membership" property of the event. Must be one of "join" or "leave". - content (object): Any additional data to put into the content field + content (dict): Any additional data to put into the content field of the event. + params (dict[str, str|Iterable[str]]): Query parameters to include in the + request. Return: Deferred: resolves to a tuple of (origin (str), event (object)) where origin is the remote homeserver which generated the event. @@ -558,10 +560,12 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def send_request(destination): ret = yield self.transport_layer.make_membership_event( - destination, room_id, user_id, membership + destination, room_id, user_id, membership, params, ) - pdu_dict = ret["event"] + pdu_dict = ret.get("event", None) + if not isinstance(pdu_dict, dict): + raise InvalidResponseError("Bad 'event' field in response") logger.debug("Got response to make_%s: %s", membership, pdu_dict) @@ -605,6 +609,26 @@ class FederationClient(FederationBase): Fails with a ``RuntimeError`` if no servers were reachable. """ + def check_authchain_validity(signed_auth_chain): + for e in signed_auth_chain: + if e.type == EventTypes.Create: + create_event = e + break + else: + raise InvalidResponseError( + "no %s in auth chain" % (EventTypes.Create,), + ) + + # the room version should be sane. + room_version = create_event.content.get("room_version", "1") + if room_version not in KNOWN_ROOM_VERSIONS: + # This shouldn't be possible, because the remote server should have + # rejected the join attempt during make_join. + raise InvalidResponseError( + "room appears to have unsupported version %s" % ( + room_version, + )) + @defer.inlineCallbacks def send_request(destination): time_now = self._clock.time_msec() @@ -661,7 +685,7 @@ class FederationClient(FederationBase): for s in signed_state: s.internal_metadata = copy.deepcopy(s.internal_metadata) - auth_chain.sort(key=lambda e: e.depth) + check_authchain_validity(signed_auth) defer.returnValue({ "state": signed_state, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index bf89d568af..3e0cd294a1 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -27,14 +27,24 @@ from twisted.internet.abstract import isIPAddress from twisted.python import failure from synapse.api.constants import EventTypes -from synapse.api.errors import AuthError, FederationError, NotFoundError, SynapseError +from synapse.api.errors import ( + AuthError, + FederationError, + IncompatibleRoomVersionError, + NotFoundError, + SynapseError, +) from synapse.crypto.event_signing import compute_event_signature from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction from synapse.http.endpoint import parse_server_name +from synapse.replication.http.federation import ( + ReplicationFederationSendEduRestServlet, + ReplicationGetQueryRestServlet, +) from synapse.types import get_domain_from_id -from synapse.util import async +from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache from synapse.util.logutils import log_function @@ -61,8 +71,8 @@ class FederationServer(FederationBase): self.auth = hs.get_auth() self.handler = hs.get_handlers().federation_handler - self._server_linearizer = async.Linearizer("fed_server") - self._transaction_linearizer = async.Linearizer("fed_txn_handler") + self._server_linearizer = Linearizer("fed_server") + self._transaction_linearizer = Linearizer("fed_txn_handler") self.transaction_actions = TransactionActions(self.store) @@ -194,7 +204,7 @@ class FederationServer(FederationBase): event_id, f.getTraceback().rstrip(), ) - yield async.concurrently_execute( + yield concurrently_execute( process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT, ) @@ -323,12 +333,21 @@ class FederationServer(FederationBase): defer.returnValue((200, resp)) @defer.inlineCallbacks - def on_make_join_request(self, origin, room_id, user_id): + def on_make_join_request(self, origin, room_id, user_id, supported_versions): origin_host, _ = parse_server_name(origin) yield self.check_server_matches_acl(origin_host, room_id) + + room_version = yield self.store.get_room_version(room_id) + if room_version not in supported_versions: + logger.warn("Room version %s not in %s", room_version, supported_versions) + raise IncompatibleRoomVersionError(room_version=room_version) + pdu = yield self.handler.on_make_join_request(room_id, user_id) time_now = self._clock.time_msec() - defer.returnValue({"event": pdu.get_pdu_json(time_now)}) + defer.returnValue({ + "event": pdu.get_pdu_json(time_now), + "room_version": room_version, + }) @defer.inlineCallbacks def on_invite_request(self, origin, content): @@ -745,6 +764,8 @@ class FederationHandlerRegistry(object): if edu_type in self.edu_handlers: raise KeyError("Already have an EDU handler for %s" % (edu_type,)) + logger.info("Registering federation EDU handler for %r", edu_type) + self.edu_handlers[edu_type] = handler def register_query_handler(self, query_type, handler): @@ -763,6 +784,8 @@ class FederationHandlerRegistry(object): "Already have a Query handler for %s" % (query_type,) ) + logger.info("Registering federation query handler for %r", query_type) + self.query_handlers[query_type] = handler @defer.inlineCallbacks @@ -785,3 +808,49 @@ class FederationHandlerRegistry(object): raise NotFoundError("No handler for Query type '%s'" % (query_type,)) return handler(args) + + +class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): + """A FederationHandlerRegistry for worker processes. + + When receiving EDU or queries it will check if an appropriate handler has + been registered on the worker, if there isn't one then it calls off to the + master process. + """ + + def __init__(self, hs): + self.config = hs.config + self.http_client = hs.get_simple_http_client() + self.clock = hs.get_clock() + + self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs) + self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs) + + super(ReplicationFederationHandlerRegistry, self).__init__() + + def on_edu(self, edu_type, origin, content): + """Overrides FederationHandlerRegistry + """ + handler = self.edu_handlers.get(edu_type) + if handler: + return super(ReplicationFederationHandlerRegistry, self).on_edu( + edu_type, origin, content, + ) + + return self._send_edu( + edu_type=edu_type, + origin=origin, + content=content, + ) + + def on_query(self, query_type, args): + """Overrides FederationHandlerRegistry + """ + handler = self.query_handlers.get(query_type) + if handler: + return handler(args) + + return self._get_query_client( + query_type=query_type, + args=args, + ) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 78f9d40a3a..94d7423d01 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -26,6 +26,8 @@ from synapse.api.errors import FederationDeniedError, HttpResponseException from synapse.handlers.presence import format_user_presence_state, get_interested_remotes from synapse.metrics import ( LaterGauge, + event_processing_loop_counter, + event_processing_loop_room_count, events_processed_counter, sent_edus_counter, sent_transactions_counter, @@ -56,6 +58,7 @@ class TransactionQueue(object): """ def __init__(self, hs): + self.hs = hs self.server_name = hs.hostname self.store = hs.get_datastore() @@ -253,7 +256,13 @@ class TransactionQueue(object): synapse.metrics.event_processing_last_ts.labels( "federation_sender").set(ts) - events_processed_counter.inc(len(events)) + events_processed_counter.inc(len(events)) + + event_processing_loop_room_count.labels( + "federation_sender" + ).inc(len(events_by_room)) + + event_processing_loop_counter.labels("federation_sender").inc() synapse.metrics.event_processing_positions.labels( "federation_sender").set(next_token) @@ -300,6 +309,9 @@ class TransactionQueue(object): Args: states (list(UserPresenceState)) """ + if not self.hs.config.use_presence: + # No-op if presence is disabled. + return # First we queue up the new presence by user ID, so multiple presence # updates in quick successtion are correctly handled diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 4529d454af..b4fbe2c9d5 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -195,7 +195,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function - def make_membership_event(self, destination, room_id, user_id, membership): + def make_membership_event(self, destination, room_id, user_id, membership, params): """Asks a remote server to build and sign us a membership event Note that this does not append any events to any graphs. @@ -205,6 +205,8 @@ class TransportLayerClient(object): room_id (str): room to join/leave user_id (str): user to be joined/left membership (str): one of join/leave + params (dict[str, str|Iterable[str]]): Query parameters to include in the + request. Returns: Deferred: Succeeds when we get a 2xx HTTP response. The result @@ -241,6 +243,7 @@ class TransportLayerClient(object): content = yield self.client.get_json( destination=destination, path=path, + args=params, retry_on_dns_fail=retry_on_dns_fail, timeout=20000, ignore_backoff=ignore_backoff, diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index eae5f2b427..77969a4f38 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -190,6 +190,41 @@ def _parse_auth_header(header_bytes): class BaseFederationServlet(object): + """Abstract base class for federation servlet classes. + + The servlet object should have a PATH attribute which takes the form of a regexp to + match against the request path (excluding the /federation/v1 prefix). + + The servlet should also implement one or more of on_GET, on_POST, on_PUT, to match + the appropriate HTTP method. These methods have the signature: + + on_<METHOD>(self, origin, content, query, **kwargs) + + With arguments: + + origin (unicode|None): The authenticated server_name of the calling server, + unless REQUIRE_AUTH is set to False and authentication failed. + + content (unicode|None): decoded json body of the request. None if the + request was a GET. + + query (dict[bytes, list[bytes]]): Query params from the request. url-decoded + (ie, '+' and '%xx' are decoded) but note that it is *not* utf8-decoded + yet. + + **kwargs (dict[unicode, unicode]): the dict mapping keys to path + components as specified in the path match regexp. + + Returns: + Deferred[(int, object)|None]: either (response code, response object) to + return a JSON response, or None if the request has already been handled. + + Raises: + SynapseError: to return an error code + + Exception: other exceptions will be caught, logged, and a 500 will be + returned. + """ REQUIRE_AUTH = True def __init__(self, handler, authenticator, ratelimiter, server_name): @@ -204,6 +239,18 @@ class BaseFederationServlet(object): @defer.inlineCallbacks @functools.wraps(func) def new_func(request, *args, **kwargs): + """ A callback which can be passed to HttpServer.RegisterPaths + + Args: + request (twisted.web.http.Request): + *args: unused? + **kwargs (dict[unicode, unicode]): the dict mapping keys to path + components as specified in the path match regexp. + + Returns: + Deferred[(int, object)|None]: (response code, response object) as returned + by the callback method. None if the request has already been handled. + """ content = None if request.method in ["PUT", "POST"]: # TODO: Handle other method types? other content types? @@ -384,9 +431,31 @@ class FederationMakeJoinServlet(BaseFederationServlet): PATH = "/make_join/(?P<context>[^/]*)/(?P<user_id>[^/]*)" @defer.inlineCallbacks - def on_GET(self, origin, content, query, context, user_id): + def on_GET(self, origin, _content, query, context, user_id): + """ + Args: + origin (unicode): The authenticated server_name of the calling server + + _content (None): (GETs don't have bodies) + + query (dict[bytes, list[bytes]]): Query params from the request. + + **kwargs (dict[unicode, unicode]): the dict mapping keys to path + components as specified in the path match regexp. + + Returns: + Deferred[(int, object)|None]: either (response code, response object) to + return a JSON response, or None if the request has already been handled. + """ + versions = query.get(b'ver') + if versions is not None: + supported_versions = [v.decode("utf-8") for v in versions] + else: + supported_versions = ["1"] + content = yield self.handler.on_make_join_request( origin, context, user_id, + supported_versions=supported_versions, ) defer.returnValue((200, content)) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index ee41aed69e..f0f89af7dc 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -23,6 +23,10 @@ from twisted.internet import defer import synapse from synapse.api.constants import EventTypes +from synapse.metrics import ( + event_processing_loop_counter, + event_processing_loop_room_count, +) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.metrics import Measure @@ -136,6 +140,12 @@ class ApplicationServicesHandler(object): events_processed_counter.inc(len(events)) + event_processing_loop_room_count.labels( + "appservice_sender" + ).inc(len(events_by_room)) + + event_processing_loop_counter.labels("appservice_sender").inc() + synapse.metrics.event_processing_lag.labels( "appservice_sender").set(now - ts) synapse.metrics.event_processing_last_ts.labels( diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 184eef09d0..4a81bd2ba9 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -520,7 +520,7 @@ class AuthHandler(BaseHandler): """ logger.info("Logging in user %s on device %s", user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id) - yield self._check_mau_limits() + yield self.auth.check_auth_blocking(user_id) # the device *should* have been registered before we got here; however, # it's possible we raced against a DELETE operation. The thing we @@ -734,7 +734,6 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def validate_short_term_login_token_and_get_user_id(self, login_token): - yield self._check_mau_limits() auth_api = self.hs.get_auth() user_id = None try: @@ -743,6 +742,7 @@ class AuthHandler(BaseHandler): auth_api.validate_macaroon(macaroon, "login", True, user_id) except Exception: raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) + yield self.auth.check_auth_blocking(user_id) defer.returnValue(user_id) @defer.inlineCallbacks @@ -828,12 +828,26 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def delete_threepid(self, user_id, medium, address): + """Attempts to unbind the 3pid on the identity servers and deletes it + from the local database. + + Args: + user_id (str) + medium (str) + address (str) + + Returns: + Deferred[bool]: Returns True if successfully unbound the 3pid on + the identity server, False if identity server doesn't support the + unbind API. + """ + # 'Canonicalise' email addresses as per above if medium == 'email': address = address.lower() identity_handler = self.hs.get_handlers().identity_handler - yield identity_handler.unbind_threepid( + result = yield identity_handler.try_unbind_threepid( user_id, { 'medium': medium, @@ -841,10 +855,10 @@ class AuthHandler(BaseHandler): }, ) - ret = yield self.store.user_delete_threepid( + yield self.store.user_delete_threepid( user_id, medium, address, ) - defer.returnValue(ret) + defer.returnValue(result) def _save_session(self, session): # TODO: Persistent storage @@ -907,19 +921,6 @@ class AuthHandler(BaseHandler): else: return defer.succeed(False) - @defer.inlineCallbacks - def _check_mau_limits(self): - """ - Ensure that if mau blocking is enabled that invalid users cannot - log in. - """ - if self.hs.config.limit_usage_by_mau is True: - current_mau = yield self.store.count_monthly_users() - if current_mau >= self.hs.config.max_mau_value: - raise AuthError( - 403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED - ) - @attr.s class MacaroonGenerator(object): diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index b3c5a9ee64..b078df4a76 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -51,7 +51,8 @@ class DeactivateAccountHandler(BaseHandler): erase_data (bool): whether to GDPR-erase the user's data Returns: - Deferred + Deferred[bool]: True if identity server supports removing + threepids, otherwise False. """ # FIXME: Theoretically there is a race here wherein user resets # password using threepid. @@ -60,16 +61,22 @@ class DeactivateAccountHandler(BaseHandler): # leave the user still active so they can try again. # Ideally we would prevent password resets and then do this in the # background thread. + + # This will be set to false if the identity server doesn't support + # unbinding + identity_server_supports_unbinding = True + threepids = yield self.store.user_get_threepids(user_id) for threepid in threepids: try: - yield self._identity_handler.unbind_threepid( + result = yield self._identity_handler.try_unbind_threepid( user_id, { 'medium': threepid['medium'], 'address': threepid['address'], }, ) + identity_server_supports_unbinding &= result except Exception: # Do we want this to be a fatal error or should we carry on? logger.exception("Failed to remove threepid from ID server") @@ -103,6 +110,8 @@ class DeactivateAccountHandler(BaseHandler): # parts users from rooms (if it isn't already running) self._start_user_parting() + defer.returnValue(identity_server_supports_unbinding) + def _start_user_parting(self): """ Start the process that goes through the table of users diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 2d44f15da3..9e017116a9 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -23,7 +23,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import FederationDeniedError from synapse.types import RoomStreamToken, get_domain_from_id from synapse.util import stringutils -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import measure_func from synapse.util.retryutils import NotRetryingDestination diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 533b82c783..3dd107a285 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -30,7 +30,12 @@ from unpaddedbase64 import decode_base64 from twisted.internet import defer -from synapse.api.constants import EventTypes, Membership, RejectedReason +from synapse.api.constants import ( + KNOWN_ROOM_VERSIONS, + EventTypes, + Membership, + RejectedReason, +) from synapse.api.errors import ( AuthError, CodeMessageException, @@ -44,10 +49,15 @@ from synapse.crypto.event_signing import ( compute_event_signature, ) from synapse.events.validator import EventValidator +from synapse.replication.http.federation import ( + ReplicationCleanRoomRestServlet, + ReplicationFederationSendEventsRestServlet, +) +from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import resolve_events_with_factory from synapse.types import UserID, get_domain_from_id from synapse.util import logcontext, unwrapFirstError -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room from synapse.util.frozenutils import unfreeze from synapse.util.logutils import log_function @@ -86,6 +96,18 @@ class FederationHandler(BaseHandler): self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() self._server_notices_mxid = hs.config.server_notices_mxid + self.config = hs.config + self.http_client = hs.get_simple_http_client() + + self._send_events_to_master = ( + ReplicationFederationSendEventsRestServlet.make_client(hs) + ) + self._notify_user_membership_change = ( + ReplicationUserJoinedLeftRoomRestServlet.make_client(hs) + ) + self._clean_room_for_join_client = ( + ReplicationCleanRoomRestServlet.make_client(hs) + ) # When joining a room we need to queue any events for that room up self.room_queues = {} @@ -922,6 +944,9 @@ class FederationHandler(BaseHandler): joinee, "join", content, + params={ + "ver": KNOWN_ROOM_VERSIONS, + }, ) # This shouldn't happen, because the RoomMemberHandler has a @@ -1150,7 +1175,7 @@ class FederationHandler(BaseHandler): ) context = yield self.state_handler.compute_event_context(event) - yield self._persist_events([(event, context)]) + yield self.persist_events_and_notify([(event, context)]) defer.returnValue(event) @@ -1181,19 +1206,20 @@ class FederationHandler(BaseHandler): ) context = yield self.state_handler.compute_event_context(event) - yield self._persist_events([(event, context)]) + yield self.persist_events_and_notify([(event, context)]) defer.returnValue(event) @defer.inlineCallbacks def _make_and_verify_event(self, target_hosts, room_id, user_id, membership, - content={},): + content={}, params=None): origin, pdu = yield self.federation_client.make_membership_event( target_hosts, room_id, user_id, membership, content, + params=params, ) logger.debug("Got response to make_%s: %s", membership, pdu) @@ -1423,7 +1449,7 @@ class FederationHandler(BaseHandler): event, context ) - yield self._persist_events( + yield self.persist_events_and_notify( [(event, context)], backfilled=backfilled, ) @@ -1461,7 +1487,7 @@ class FederationHandler(BaseHandler): ], consumeErrors=True, )) - yield self._persist_events( + yield self.persist_events_and_notify( [ (ev_info["event"], context) for ev_info, context in zip(event_infos, contexts) @@ -1549,7 +1575,7 @@ class FederationHandler(BaseHandler): raise events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR - yield self._persist_events( + yield self.persist_events_and_notify( [ (e, events_to_context[e.event_id]) for e in itertools.chain(auth_events, state) @@ -1560,7 +1586,7 @@ class FederationHandler(BaseHandler): event, old_state=state ) - yield self._persist_events( + yield self.persist_events_and_notify( [(event, new_event_context)], ) @@ -2288,7 +2314,7 @@ class FederationHandler(BaseHandler): for revocation. """ try: - response = yield self.hs.get_simple_http_client().get_json( + response = yield self.http_client.get_json( url, {"public_key": public_key} ) @@ -2301,7 +2327,7 @@ class FederationHandler(BaseHandler): raise AuthError(403, "Third party certificate was invalid") @defer.inlineCallbacks - def _persist_events(self, event_and_contexts, backfilled=False): + def persist_events_and_notify(self, event_and_contexts, backfilled=False): """Persists events and tells the notifier/pushers about them, if necessary. @@ -2313,14 +2339,21 @@ class FederationHandler(BaseHandler): Returns: Deferred """ - max_stream_id = yield self.store.persist_events( - event_and_contexts, - backfilled=backfilled, - ) + if self.config.worker_app: + yield self._send_events_to_master( + store=self.store, + event_and_contexts=event_and_contexts, + backfilled=backfilled + ) + else: + max_stream_id = yield self.store.persist_events( + event_and_contexts, + backfilled=backfilled, + ) - if not backfilled: # Never notify for backfilled events - for event, _ in event_and_contexts: - self._notify_persisted_event(event, max_stream_id) + if not backfilled: # Never notify for backfilled events + for event, _ in event_and_contexts: + self._notify_persisted_event(event, max_stream_id) def _notify_persisted_event(self, event, max_stream_id): """Checks to see if notifier/pushers should be notified about the @@ -2353,15 +2386,30 @@ class FederationHandler(BaseHandler): extra_users=extra_users ) - logcontext.run_in_background( - self.pusher_pool.on_new_notifications, + self.pusher_pool.on_new_notifications( event_stream_id, max_stream_id, ) def _clean_room_for_join(self, room_id): - return self.store.clean_room_for_join(room_id) + """Called to clean up any data in DB for a given room, ready for the + server to join the room. + + Args: + room_id (str) + """ + if self.config.worker_app: + return self._clean_room_for_join_client(room_id) + else: + return self.store.clean_room_for_join(room_id) def user_joined_room(self, user, room_id): """Called when a new user has joined the room """ - return user_joined_room(self.distributor, user, room_id) + if self.config.worker_app: + return self._notify_user_membership_change( + room_id=room_id, + user_id=user.to_string(), + change="joined", + ) + else: + return user_joined_room(self.distributor, user, room_id) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 1d36d967c3..5feb3f22a6 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -137,15 +137,19 @@ class IdentityHandler(BaseHandler): defer.returnValue(data) @defer.inlineCallbacks - def unbind_threepid(self, mxid, threepid): - """ - Removes a binding from an identity server + def try_unbind_threepid(self, mxid, threepid): + """Removes a binding from an identity server + Args: mxid (str): Matrix user ID of binding to be removed threepid (dict): Dict with medium & address of binding to be removed + Raises: + SynapseError: If we failed to contact the identity server + Returns: - Deferred[bool]: True on success, otherwise False + Deferred[bool]: True on success, otherwise False if the identity + server doesn't support unbinding """ logger.debug("unbinding threepid %r from %s", threepid, mxid) if not self.trusted_id_servers: @@ -175,11 +179,21 @@ class IdentityHandler(BaseHandler): content=content, destination_is=id_server, ) - yield self.http_client.post_json_get_json( - url, - content, - headers, - ) + try: + yield self.http_client.post_json_get_json( + url, + content, + headers, + ) + except HttpResponseException as e: + if e.code in (400, 404, 501,): + # The remote server probably doesn't support unbinding (yet) + logger.warn("Received %d response while unbinding threepid", e.code) + defer.returnValue(False) + else: + logger.error("Failed to unbind threepid on identity server: %s", e) + raise SynapseError(502, "Failed to contact identity server") + defer.returnValue(True) @defer.inlineCallbacks diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 40e7580a61..e009395207 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -25,7 +25,7 @@ from synapse.handlers.presence import format_user_presence_state from synapse.streams.config import PaginationConfig from synapse.types import StreamToken, UserID from synapse.util import unwrapFirstError -from synapse.util.async import concurrently_execute +from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.visibility import filter_events_for_client @@ -372,6 +372,10 @@ class InitialSyncHandler(BaseHandler): @defer.inlineCallbacks def get_presence(): + # If presence is disabled, return an empty list + if not self.hs.config.use_presence: + defer.returnValue([]) + states = yield presence_handler.get_states( [m.user_id for m in room_members], as_event=True, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 39d7724778..e484061cc0 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -25,17 +25,24 @@ from twisted.internet import defer from twisted.internet.defer import succeed from synapse.api.constants import MAX_DEPTH, EventTypes, Membership -from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError +from synapse.api.errors import ( + AuthError, + Codes, + ConsentNotGivenError, + NotFoundError, + SynapseError, +) from synapse.api.urls import ConsentURIBuilder from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator -from synapse.replication.http.send_event import send_event_to_master +from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.types import RoomAlias, UserID -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.logcontext import run_in_background from synapse.util.metrics import measure_func +from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -82,28 +89,85 @@ class MessageHandler(object): defer.returnValue(data) @defer.inlineCallbacks - def get_state_events(self, user_id, room_id, is_guest=False): + def get_state_events( + self, user_id, room_id, types=None, filtered_types=None, + at_token=None, is_guest=False, + ): """Retrieve all state events for a given room. If the user is joined to the room then return the current state. If the user has - left the room return the state events from when they left. + left the room return the state events from when they left. If an explicit + 'at' parameter is passed, return the state events as of that event, if + visible. Args: user_id(str): The user requesting state events. room_id(str): The room ID to get all state events from. + types(list[(str, str|None)]|None): List of (type, state_key) tuples + which are used to filter the state fetched. If `state_key` is None, + all events are returned of the given type. + May be None, which matches any key. + filtered_types(list[str]|None): Only apply filtering via `types` to this + list of event types. Other types of events are returned unfiltered. + If None, `types` filtering is applied to all events. + at_token(StreamToken|None): the stream token of the at which we are requesting + the stats. If the user is not allowed to view the state as of that + stream token, we raise a 403 SynapseError. If None, returns the current + state based on the current_state_events table. + is_guest(bool): whether this user is a guest Returns: A list of dicts representing state events. [{}, {}, {}] + Raises: + NotFoundError (404) if the at token does not yield an event + + AuthError (403) if the user doesn't have permission to view + members of this room. """ - membership, membership_event_id = yield self.auth.check_in_room_or_world_readable( - room_id, user_id - ) + if at_token: + # FIXME this claims to get the state at a stream position, but + # get_recent_events_for_room operates by topo ordering. This therefore + # does not reliably give you the state at the given stream position. + # (https://github.com/matrix-org/synapse/issues/3305) + last_events, _ = yield self.store.get_recent_events_for_room( + room_id, end_token=at_token.room_key, limit=1, + ) - if membership == Membership.JOIN: - room_state = yield self.state.get_current_state(room_id) - elif membership == Membership.LEAVE: - room_state = yield self.store.get_state_for_events( - [membership_event_id], None + if not last_events: + raise NotFoundError("Can't find event for token %s" % (at_token, )) + + visible_events = yield filter_events_for_client( + self.store, user_id, last_events, ) - room_state = room_state[membership_event_id] + + event = last_events[0] + if visible_events: + room_state = yield self.store.get_state_for_events( + [event.event_id], types, filtered_types=filtered_types, + ) + room_state = room_state[event.event_id] + else: + raise AuthError( + 403, + "User %s not allowed to view events in room %s at token %s" % ( + user_id, room_id, at_token, + ) + ) + else: + membership, membership_event_id = ( + yield self.auth.check_in_room_or_world_readable( + room_id, user_id, + ) + ) + + if membership == Membership.JOIN: + state_ids = yield self.store.get_filtered_current_state_ids( + room_id, types, filtered_types=filtered_types, + ) + room_state = yield self.store.get_events(state_ids.values()) + elif membership == Membership.LEAVE: + room_state = yield self.store.get_state_for_events( + [membership_event_id], types, filtered_types=filtered_types, + ) + room_state = room_state[membership_event_id] now = self.clock.time_msec() defer.returnValue( @@ -171,7 +235,7 @@ class EventCreationHandler(object): self.notifier = hs.get_notifier() self.config = hs.config - self.http_client = hs.get_simple_http_client() + self.send_event_to_master = ReplicationSendEventRestServlet.make_client(hs) # This is only used to get at ratelimit function, and maybe_kick_guest_users self.base_handler = BaseHandler(hs) @@ -212,10 +276,14 @@ class EventCreationHandler(object): where *hashes* is a map from algorithm to hash. If None, they will be requested from the database. - + Raises: + ResourceLimitError if server is blocked to some resource being + exceeded Returns: Tuple of created event (FrozenEvent), Context """ + yield self.auth.check_auth_blocking(requester.user.to_string()) + builder = self.event_builder_factory.new(event_dict) self.validator.validate_new(builder) @@ -559,12 +627,9 @@ class EventCreationHandler(object): try: # If we're a worker we need to hit out to the master. if self.config.worker_app: - yield send_event_to_master( - clock=self.hs.get_clock(), + yield self.send_event_to_master( + event_id=event.event_id, store=self.store, - client=self.http_client, - host=self.config.worker_replication_host, - port=self.config.worker_replication_http_port, requester=requester, event=event, context=context, @@ -713,11 +778,8 @@ class EventCreationHandler(object): event, context=context ) - # this intentionally does not yield: we don't care about the result - # and don't need to wait for it. - run_in_background( - self.pusher_pool.on_new_notifications, - event_stream_id, max_stream_id + self.pusher_pool.on_new_notifications( + event_stream_id, max_stream_id, ) def _notify(): diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index b2849783ed..5170d093e3 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -18,11 +18,11 @@ import logging from twisted.internet import defer from twisted.python.failure import Failure -from synapse.api.constants import Membership +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError from synapse.events.utils import serialize_event from synapse.types import RoomStreamToken -from synapse.util.async import ReadWriteLock +from synapse.util.async_helpers import ReadWriteLock from synapse.util.logcontext import run_in_background from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client @@ -251,6 +251,33 @@ class PaginationHandler(object): is_peeking=(member_event_id is None), ) + state = None + if event_filter and event_filter.lazy_load_members(): + # TODO: remove redundant members + + types = [ + (EventTypes.Member, state_key) + for state_key in set( + event.sender # FIXME: we also care about invite targets etc. + for event in events + ) + ] + + state_ids = yield self.store.get_state_ids_for_event( + events[0].event_id, types=types, + ) + + if state_ids: + state = yield self.store.get_events(list(state_ids.values())) + + if state: + state = yield filter_events_for_client( + self.store, + user_id, + state.values(), + is_peeking=(member_event_id is None), + ) + time_now = self.clock.time_msec() chunk = { @@ -262,4 +289,10 @@ class PaginationHandler(object): "end": next_token.to_string(), } + if state: + chunk["state"] = [ + serialize_event(e, time_now, as_client_event) + for e in state + ] + defer.returnValue(chunk) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 3732830194..ba3856674d 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -36,7 +36,7 @@ from synapse.api.errors import SynapseError from synapse.metrics import LaterGauge from synapse.storage.presence import UserPresenceState from synapse.types import UserID, get_domain_from_id -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.logcontext import run_in_background from synapse.util.logutils import log_function @@ -95,6 +95,7 @@ class PresenceHandler(object): Args: hs (synapse.server.HomeServer): """ + self.hs = hs self.is_mine = hs.is_mine self.is_mine_id = hs.is_mine_id self.clock = hs.get_clock() @@ -230,6 +231,10 @@ class PresenceHandler(object): earlier than they should when synapse is restarted. This affect of this is some spurious presence changes that will self-correct. """ + # If the DB pool has already terminated, don't try updating + if not self.hs.get_db_pool().running: + return + logger.info( "Performing _on_shutdown. Persisting %d unpersisted changes", len(self.user_to_current_state) @@ -390,6 +395,10 @@ class PresenceHandler(object): """We've seen the user do something that indicates they're interacting with the app. """ + # If presence is disabled, no-op + if not self.hs.config.use_presence: + return + user_id = user.to_string() bump_active_time_counter.inc() @@ -419,6 +428,11 @@ class PresenceHandler(object): Useful for streams that are not associated with an actual client that is being used by a user. """ + # Override if it should affect the user's presence, if presence is + # disabled. + if not self.hs.config.use_presence: + affect_presence = False + if affect_presence: curr_sync = self.user_to_num_current_syncs.get(user_id, 0) self.user_to_num_current_syncs[user_id] = curr_sync + 1 @@ -464,13 +478,16 @@ class PresenceHandler(object): Returns: set(str): A set of user_id strings. """ - syncing_user_ids = { - user_id for user_id, count in self.user_to_num_current_syncs.items() - if count - } - for user_ids in self.external_process_to_current_syncs.values(): - syncing_user_ids.update(user_ids) - return syncing_user_ids + if self.hs.config.use_presence: + syncing_user_ids = { + user_id for user_id, count in self.user_to_num_current_syncs.items() + if count + } + for user_ids in self.external_process_to_current_syncs.values(): + syncing_user_ids.update(user_ids) + return syncing_user_ids + else: + return set() @defer.inlineCallbacks def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec): diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index 995460f82a..32108568c6 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -17,7 +17,7 @@ import logging from twisted.internet import defer -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from ._base import BaseHandler diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index cb905a3903..a6f3181f09 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -18,7 +18,6 @@ from twisted.internet import defer from synapse.types import get_domain_from_id from synapse.util import logcontext -from synapse.util.logcontext import PreserveLoggingContext from ._base import BaseHandler @@ -116,16 +115,15 @@ class ReceiptsHandler(BaseHandler): affected_room_ids = list(set([r["room_id"] for r in receipts])) - with PreserveLoggingContext(): - self.notifier.on_new_event( - "receipt_key", max_batch_id, rooms=affected_room_ids - ) - # Note that the min here shouldn't be relied upon to be accurate. - self.hs.get_pusherpool().on_new_receipts( - min_batch_id, max_batch_id, affected_room_ids - ) + self.notifier.on_new_event( + "receipt_key", max_batch_id, rooms=affected_room_ids + ) + # Note that the min here shouldn't be relied upon to be accurate. + self.hs.get_pusherpool().on_new_receipts( + min_batch_id, max_batch_id, affected_room_ids, + ) - defer.returnValue(True) + defer.returnValue(True) @logcontext.preserve_fn # caller should not yield on this @defer.inlineCallbacks diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 289704b241..f03ee1476b 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -28,7 +28,7 @@ from synapse.api.errors import ( ) from synapse.http.client import CaptchaServerHttpClient from synapse.types import RoomAlias, RoomID, UserID, create_requester -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.threepids import check_3pid_allowed from ._base import BaseHandler @@ -144,7 +144,8 @@ class RegistrationHandler(BaseHandler): Raises: RegistrationError if there was a problem registering. """ - yield self._check_mau_limits() + + yield self.auth.check_auth_blocking() password_hash = None if password: password_hash = yield self.auth_handler().hash(password) @@ -289,7 +290,7 @@ class RegistrationHandler(BaseHandler): 400, "User ID can only contain characters a-z, 0-9, or '=_-./'", ) - yield self._check_mau_limits() + yield self.auth.check_auth_blocking() user = UserID(localpart, self.hs.hostname) user_id = user.to_string() @@ -439,7 +440,7 @@ class RegistrationHandler(BaseHandler): """ if localpart is None: raise SynapseError(400, "Request must include user id") - yield self._check_mau_limits() + yield self.auth.check_auth_blocking() need_register = True try: @@ -533,16 +534,3 @@ class RegistrationHandler(BaseHandler): remote_room_hosts=remote_room_hosts, action="join", ) - - @defer.inlineCallbacks - def _check_mau_limits(self): - """ - Do not accept registrations if monthly active user limits exceeded - and limiting is enabled - """ - if self.hs.config.limit_usage_by_mau is True: - current_mau = yield self.store.count_monthly_users() - if current_mau >= self.hs.config.max_mau_value: - raise RegistrationError( - 403, "MAU Limit Exceeded", Codes.MAU_LIMIT_EXCEEDED - ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 7b7804d9b2..c3f820b975 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -21,9 +21,17 @@ import math import string from collections import OrderedDict +from six import string_types + from twisted.internet import defer -from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset +from synapse.api.constants import ( + DEFAULT_ROOM_VERSION, + KNOWN_ROOM_VERSIONS, + EventTypes, + JoinRules, + RoomCreationPreset, +) from synapse.api.errors import AuthError, Codes, StoreError, SynapseError from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID from synapse.util import stringutils @@ -90,15 +98,34 @@ class RoomCreationHandler(BaseHandler): Raises: SynapseError if the room ID couldn't be stored, or something went horribly wrong. + ResourceLimitError if server is blocked to some resource being + exceeded """ user_id = requester.user.to_string() + self.auth.check_auth_blocking(user_id) + if not self.spam_checker.user_may_create_room(user_id): raise SynapseError(403, "You are not permitted to create rooms") if ratelimit: yield self.ratelimit(requester) + room_version = config.get("room_version", DEFAULT_ROOM_VERSION) + if not isinstance(room_version, string_types): + raise SynapseError( + 400, + "room_version must be a string", + Codes.BAD_JSON, + ) + + if room_version not in KNOWN_ROOM_VERSIONS: + raise SynapseError( + 400, + "Your homeserver does not support this room version", + Codes.UNSUPPORTED_ROOM_VERSION, + ) + if "room_alias_name" in config: for wchar in string.whitespace: if wchar in config["room_alias_name"]: @@ -184,6 +211,9 @@ class RoomCreationHandler(BaseHandler): creation_content = config.get("creation_content", {}) + # override any attempt to set room versions via the creation_content + creation_content["room_version"] = room_version + room_member_handler = self.hs.get_room_member_handler() yield self._send_events_for_new_room( diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 828229f5c3..37e41afd61 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -26,7 +26,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules from synapse.types import ThirdPartyInstanceID -from synapse.util.async import concurrently_execute +from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.response_cache import ResponseCache diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 0d4a3f4677..fb94b5d7d4 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -30,7 +30,7 @@ import synapse.types from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, Codes, SynapseError from synapse.types import RoomID, UserID -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room logger = logging.getLogger(__name__) diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index 22d8b4b0d3..acc6eb8099 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -20,16 +20,24 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.handlers.room_member import RoomMemberHandler from synapse.replication.http.membership import ( - get_or_register_3pid_guest, - notify_user_membership_change, - remote_join, - remote_reject_invite, + ReplicationRegister3PIDGuestRestServlet as Repl3PID, + ReplicationRemoteJoinRestServlet as ReplRemoteJoin, + ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite, + ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft, ) logger = logging.getLogger(__name__) class RoomMemberWorkerHandler(RoomMemberHandler): + def __init__(self, hs): + super(RoomMemberWorkerHandler, self).__init__(hs) + + self._get_register_3pid_client = Repl3PID.make_client(hs) + self._remote_join_client = ReplRemoteJoin.make_client(hs) + self._remote_reject_client = ReplRejectInvite.make_client(hs) + self._notify_change_client = ReplJoinedLeft.make_client(hs) + @defer.inlineCallbacks def _remote_join(self, requester, remote_room_hosts, room_id, user, content): """Implements RoomMemberHandler._remote_join @@ -37,10 +45,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): if len(remote_room_hosts) == 0: raise SynapseError(404, "No known servers") - ret = yield remote_join( - self.simple_http_client, - host=self.config.worker_replication_host, - port=self.config.worker_replication_http_port, + ret = yield self._remote_join_client( requester=requester, remote_room_hosts=remote_room_hosts, room_id=room_id, @@ -55,10 +60,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target): """Implements RoomMemberHandler._remote_reject_invite """ - return remote_reject_invite( - self.simple_http_client, - host=self.config.worker_replication_host, - port=self.config.worker_replication_http_port, + return self._remote_reject_client( requester=requester, remote_room_hosts=remote_room_hosts, room_id=room_id, @@ -68,10 +70,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): def _user_joined_room(self, target, room_id): """Implements RoomMemberHandler._user_joined_room """ - return notify_user_membership_change( - self.simple_http_client, - host=self.config.worker_replication_host, - port=self.config.worker_replication_http_port, + return self._notify_change_client( user_id=target.to_string(), room_id=room_id, change="joined", @@ -80,10 +79,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): def _user_left_room(self, target, room_id): """Implements RoomMemberHandler._user_left_room """ - return notify_user_membership_change( - self.simple_http_client, - host=self.config.worker_replication_host, - port=self.config.worker_replication_http_port, + return self._notify_change_client( user_id=target.to_string(), room_id=room_id, change="left", @@ -92,10 +88,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id): """Implements RoomMemberHandler.get_or_register_3pid_guest """ - return get_or_register_3pid_guest( - self.simple_http_client, - host=self.config.worker_replication_host, - port=self.config.worker_replication_http_port, + return self._get_register_3pid_client( requester=requester, medium=medium, address=address, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index dff1f67dcb..648debc8aa 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -25,7 +25,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.push.clientformat import format_push_rules_for_user from synapse.types import RoomStreamToken -from synapse.util.async import concurrently_execute +from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.lrucache import LruCache from synapse.util.caches.response_cache import ResponseCache @@ -75,6 +75,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ "ephemeral", "account_data", "unread_notifications", + "summary", ])): __slots__ = [] @@ -184,6 +185,7 @@ class SyncResult(collections.namedtuple("SyncResult", [ class SyncHandler(object): def __init__(self, hs): + self.hs_config = hs.config self.store = hs.get_datastore() self.notifier = hs.get_notifier() self.presence_handler = hs.get_presence_handler() @@ -191,6 +193,7 @@ class SyncHandler(object): self.clock = hs.get_clock() self.response_cache = ResponseCache(hs, "sync") self.state = hs.get_state_handler() + self.auth = hs.get_auth() # ExpiringCache((User, Device)) -> LruCache(state_key => event_id) self.lazy_loaded_members_cache = ExpiringCache( @@ -198,19 +201,27 @@ class SyncHandler(object): max_len=0, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, ) + @defer.inlineCallbacks def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, full_state=False): """Get the sync for a client if we have new data for it now. Otherwise wait for new data to arrive on the server. If the timeout expires, then return an empty sync result. Returns: - A Deferred SyncResult. + Deferred[SyncResult] """ - return self.response_cache.wrap( + # If the user is not part of the mau group, then check that limits have + # not been exceeded (if not part of the group by this point, almost certain + # auth_blocking will occur) + user_id = sync_config.user.to_string() + yield self.auth.check_auth_blocking(user_id) + + res = yield self.response_cache.wrap( sync_config.request_key, self._wait_for_sync_for_user, sync_config, since_token, timeout, full_state, ) + defer.returnValue(res) @defer.inlineCallbacks def _wait_for_sync_for_user(self, sync_config, since_token, timeout, @@ -495,9 +506,141 @@ class SyncHandler(object): defer.returnValue(state) @defer.inlineCallbacks + def compute_summary(self, room_id, sync_config, batch, state, now_token): + """ Works out a room summary block for this room, summarising the number + of joined members in the room, and providing the 'hero' members if the + room has no name so clients can consistently name rooms. Also adds + state events to 'state' if needed to describe the heroes. + + Args: + room_id(str): + sync_config(synapse.handlers.sync.SyncConfig): + batch(synapse.handlers.sync.TimelineBatch): The timeline batch for + the room that will be sent to the user. + state(dict): dict of (type, state_key) -> Event as returned by + compute_state_delta + now_token(str): Token of the end of the current batch. + + Returns: + A deferred dict describing the room summary + """ + + # FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305 + last_events, _ = yield self.store.get_recent_event_ids_for_room( + room_id, end_token=now_token.room_key, limit=1, + ) + + if not last_events: + defer.returnValue(None) + return + + last_event = last_events[-1] + state_ids = yield self.store.get_state_ids_for_event( + last_event.event_id, [ + (EventTypes.Member, None), + (EventTypes.Name, ''), + (EventTypes.CanonicalAlias, ''), + ] + ) + + member_ids = { + state_key: event_id + for (t, state_key), event_id in state_ids.iteritems() + if t == EventTypes.Member + } + name_id = state_ids.get((EventTypes.Name, '')) + canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, '')) + + summary = {} + + # FIXME: it feels very heavy to load up every single membership event + # just to calculate the counts. + member_events = yield self.store.get_events(member_ids.values()) + + joined_user_ids = [] + invited_user_ids = [] + + for ev in member_events.values(): + if ev.content.get("membership") == Membership.JOIN: + joined_user_ids.append(ev.state_key) + elif ev.content.get("membership") == Membership.INVITE: + invited_user_ids.append(ev.state_key) + + # TODO: only send these when they change. + summary["m.joined_member_count"] = len(joined_user_ids) + summary["m.invited_member_count"] = len(invited_user_ids) + + if name_id or canonical_alias_id: + defer.returnValue(summary) + + # FIXME: order by stream ordering, not alphabetic + + me = sync_config.user.to_string() + if (joined_user_ids or invited_user_ids): + summary['m.heroes'] = sorted( + [ + user_id + for user_id in (joined_user_ids + invited_user_ids) + if user_id != me + ] + )[0:5] + else: + summary['m.heroes'] = sorted( + [user_id for user_id in member_ids.keys() if user_id != me] + )[0:5] + + if not sync_config.filter_collection.lazy_load_members(): + defer.returnValue(summary) + + # ensure we send membership events for heroes if needed + cache_key = (sync_config.user.to_string(), sync_config.device_id) + cache = self.get_lazy_loaded_members_cache(cache_key) + + # track which members the client should already know about via LL: + # Ones which are already in state... + existing_members = set( + user_id for (typ, user_id) in state.keys() + if typ == EventTypes.Member + ) + + # ...or ones which are in the timeline... + for ev in batch.events: + if ev.type == EventTypes.Member: + existing_members.add(ev.state_key) + + # ...and then ensure any missing ones get included in state. + missing_hero_event_ids = [ + member_ids[hero_id] + for hero_id in summary['m.heroes'] + if ( + cache.get(hero_id) != member_ids[hero_id] and + hero_id not in existing_members + ) + ] + + missing_hero_state = yield self.store.get_events(missing_hero_event_ids) + missing_hero_state = missing_hero_state.values() + + for s in missing_hero_state: + cache.set(s.state_key, s.event_id) + state[(EventTypes.Member, s.state_key)] = s + + defer.returnValue(summary) + + def get_lazy_loaded_members_cache(self, cache_key): + cache = self.lazy_loaded_members_cache.get(cache_key) + if cache is None: + logger.debug("creating LruCache for %r", cache_key) + cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE) + self.lazy_loaded_members_cache[cache_key] = cache + else: + logger.debug("found LruCache for %r", cache_key) + return cache + + @defer.inlineCallbacks def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token, full_state): - """ Works out the differnce in state between the start of the timeline + """ Works out the difference in state between the start of the timeline and the previous sync. Args: @@ -511,7 +654,7 @@ class SyncHandler(object): full_state(bool): Whether to force returning the full state. Returns: - A deferred new event dictionary + A deferred dict of (type, state_key) -> Event """ # TODO(mjark) Check if the state events were received by the server # after the previous sync, since we need to include those state @@ -609,13 +752,7 @@ class SyncHandler(object): if lazy_load_members and not include_redundant_members: cache_key = (sync_config.user.to_string(), sync_config.device_id) - cache = self.lazy_loaded_members_cache.get(cache_key) - if cache is None: - logger.debug("creating LruCache for %r", cache_key) - cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE) - self.lazy_loaded_members_cache[cache_key] = cache - else: - logger.debug("found LruCache for %r", cache_key) + cache = self.get_lazy_loaded_members_cache(cache_key) # if it's a new sync sequence, then assume the client has had # amnesia and doesn't want any recent lazy-loaded members @@ -724,7 +861,7 @@ class SyncHandler(object): since_token is None and sync_config.filter_collection.blocks_all_presence() ) - if not block_all_presence_data: + if self.hs_config.use_presence and not block_all_presence_data: yield self._generate_sync_entry_for_presence( sync_result_builder, newly_joined_rooms, newly_joined_users ) @@ -1416,7 +1553,6 @@ class SyncHandler(object): if events == [] and tags is None: return - since_token = sync_result_builder.since_token now_token = sync_result_builder.now_token sync_config = sync_result_builder.sync_config @@ -1459,6 +1595,18 @@ class SyncHandler(object): full_state=full_state ) + summary = {} + if ( + sync_config.filter_collection.lazy_load_members() and + ( + any(ev.type == EventTypes.Member for ev in batch.events) or + since_token is None + ) + ): + summary = yield self.compute_summary( + room_id, sync_config, batch, state, now_token + ) + if room_builder.rtype == "joined": unread_notifications = {} room_sync = JoinedSyncResult( @@ -1468,6 +1616,7 @@ class SyncHandler(object): ephemeral=ephemeral, account_data=account_data_events, unread_notifications=unread_notifications, + summary=summary, ) if room_sync or always_include: diff --git a/synapse/http/client.py b/synapse/http/client.py index 3771e0b3f6..ab4fbf59b2 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -42,7 +42,7 @@ from twisted.web.http_headers import Headers from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.http import cancelled_to_request_timed_out_error, redact_uri from synapse.http.endpoint import SpiderEndpoint -from synapse.util.async import add_timeout_to_deferred +from synapse.util.async_helpers import add_timeout_to_deferred from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.logcontext import make_deferred_yieldable diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index d65daa72bb..b0c9369519 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -26,7 +26,6 @@ from twisted.names.error import DNSNameError, DomainError logger = logging.getLogger(__name__) - SERVER_CACHE = {} # our record of an individual server which can be tried to reach a destination. @@ -103,15 +102,16 @@ def parse_and_validate_server_name(server_name): return host, port -def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, +def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=None, timeout=None): """Construct an endpoint for the given matrix destination. Args: reactor: Twisted reactor. destination (bytes): The name of the server to connect to. - ssl_context_factory (twisted.internet.ssl.ContextFactory): Factory - which generates SSL contexts to use for TLS. + tls_client_options_factory + (synapse.crypto.context_factory.ClientTLSOptionsFactory): + Factory which generates TLS options for client connections. timeout (int): connection timeout in seconds """ @@ -122,13 +122,13 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, if timeout is not None: endpoint_kw_args.update(timeout=timeout) - if ssl_context_factory is None: + if tls_client_options_factory is None: transport_endpoint = HostnameEndpoint default_port = 8008 else: def transport_endpoint(reactor, host, port, timeout): return wrapClientTLS( - ssl_context_factory, + tls_client_options_factory.get_options(host), HostnameEndpoint(reactor, host, port, timeout=timeout)) default_port = 8448 diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index bf1aa29502..44b61e70a4 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -43,7 +43,7 @@ from synapse.api.errors import ( from synapse.http import cancelled_to_request_timed_out_error from synapse.http.endpoint import matrix_federation_endpoint from synapse.util import logcontext -from synapse.util.async import add_timeout_to_deferred +from synapse.util.async_helpers import add_timeout_to_deferred from synapse.util.logcontext import make_deferred_yieldable logger = logging.getLogger(__name__) @@ -61,14 +61,14 @@ MAX_SHORT_RETRIES = 3 class MatrixFederationEndpointFactory(object): def __init__(self, hs): - self.tls_server_context_factory = hs.tls_server_context_factory + self.tls_client_options_factory = hs.tls_client_options_factory def endpointForURI(self, uri): destination = uri.netloc return matrix_federation_endpoint( reactor, destination, timeout=10, - ssl_context_factory=self.tls_server_context_factory + tls_client_options_factory=self.tls_client_options_factory ) @@ -439,7 +439,7 @@ class MatrixFederationHttpClient(object): defer.returnValue(json.loads(body)) @defer.inlineCallbacks - def get_json(self, destination, path, args={}, retry_on_dns_fail=True, + def get_json(self, destination, path, args=None, retry_on_dns_fail=True, timeout=None, ignore_backoff=False): """ GETs some json from the given host homeserver and path @@ -447,7 +447,7 @@ class MatrixFederationHttpClient(object): destination (str): The remote server to send the HTTP request to. path (str): The HTTP path. - args (dict): A dictionary used to create query strings, defaults to + args (dict|None): A dictionary used to create query strings, defaults to None. timeout (int): How long to try (in ms) the destination for before giving up. None indicates no timeout and that the request will @@ -702,6 +702,9 @@ def check_content_type_is_json(headers): def encode_query_args(args): + if args is None: + return b"" + encoded_args = {} for k, vs in args.items(): if isinstance(vs, string_types): diff --git a/synapse/http/server.py b/synapse/http/server.py index 6dacb31037..2d5c23e673 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -25,8 +25,9 @@ from canonicaljson import encode_canonical_json, encode_pretty_printed_json, jso from twisted.internet import defer from twisted.python import failure -from twisted.web import resource, server +from twisted.web import resource from twisted.web.server import NOT_DONE_YET +from twisted.web.static import NoRangeStaticProducer from twisted.web.util import redirectTo import synapse.events @@ -37,10 +38,13 @@ from synapse.api.errors import ( SynapseError, UnrecognizedRequestError, ) -from synapse.http.request_metrics import requests_counter from synapse.util.caches import intern_dict -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext -from synapse.util.metrics import Measure +from synapse.util.logcontext import preserve_fn + +if PY3: + from io import BytesIO +else: + from cStringIO import StringIO as BytesIO logger = logging.getLogger(__name__) @@ -60,11 +64,10 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html> def wrap_json_request_handler(h): """Wraps a request handler method with exception handling. - Also adds logging as per wrap_request_handler_with_logging. + Also does the wrapping with request.processing as per wrap_async_request_handler. The handler method must have a signature of "handle_foo(self, request)", - where "self" must have a "clock" attribute (and "request" must be a - SynapseRequest). + where "request" must be a SynapseRequest. The handler must return a deferred. If the deferred succeeds we assume that a response has been sent. If the deferred fails with a SynapseError we use @@ -108,24 +111,23 @@ def wrap_json_request_handler(h): pretty_print=_request_user_agent_is_curl(request), ) - return wrap_request_handler_with_logging(wrapped_request_handler) + return wrap_async_request_handler(wrapped_request_handler) def wrap_html_request_handler(h): """Wraps a request handler method with exception handling. - Also adds logging as per wrap_request_handler_with_logging. + Also does the wrapping with request.processing as per wrap_async_request_handler. The handler method must have a signature of "handle_foo(self, request)", - where "self" must have a "clock" attribute (and "request" must be a - SynapseRequest). + where "request" must be a SynapseRequest. """ def wrapped_request_handler(self, request): d = defer.maybeDeferred(h, self, request) d.addErrback(_return_html_error, request) return d - return wrap_request_handler_with_logging(wrapped_request_handler) + return wrap_async_request_handler(wrapped_request_handler) def _return_html_error(f, request): @@ -170,46 +172,26 @@ def _return_html_error(f, request): finish_request(request) -def wrap_request_handler_with_logging(h): - """Wraps a request handler to provide logging and metrics +def wrap_async_request_handler(h): + """Wraps an async request handler so that it calls request.processing. + + This helps ensure that work done by the request handler after the request is completed + is correctly recorded against the request metrics/logs. The handler method must have a signature of "handle_foo(self, request)", - where "self" must have a "clock" attribute (and "request" must be a - SynapseRequest). + where "request" must be a SynapseRequest. - As well as calling `request.processing` (which will log the response and - duration for this request), the wrapped request handler will insert the - request id into the logging context. + The handler may return a deferred, in which case the completion of the request isn't + logged until the deferred completes. """ @defer.inlineCallbacks - def wrapped_request_handler(self, request): - """ - Args: - self: - request (synapse.http.site.SynapseRequest): - """ + def wrapped_async_request_handler(self, request): + with request.processing(): + yield h(self, request) - request_id = request.get_request_id() - with LoggingContext(request_id) as request_context: - request_context.request = request_id - with Measure(self.clock, "wrapped_request_handler"): - # we start the request metrics timer here with an initial stab - # at the servlet name. For most requests that name will be - # JsonResource (or a subclass), and JsonResource._async_render - # will update it once it picks a servlet. - servlet_name = self.__class__.__name__ - with request.processing(servlet_name): - with PreserveLoggingContext(request_context): - d = defer.maybeDeferred(h, self, request) - - # record the arrival of the request *after* - # dispatching to the handler, so that the handler - # can update the servlet name in the request - # metrics - requests_counter.labels(request.method, - request.request_metrics.name).inc() - yield d - return wrapped_request_handler + # we need to preserve_fn here, because the synchronous render method won't yield for + # us (obviously) + return preserve_fn(wrapped_async_request_handler) class HttpServer(object): @@ -272,7 +254,7 @@ class JsonResource(HttpServer, resource.Resource): """ This gets called by twisted every time someone sends us a request. """ self._async_render(request) - return server.NOT_DONE_YET + return NOT_DONE_YET @wrap_json_request_handler @defer.inlineCallbacks @@ -413,8 +395,7 @@ def respond_with_json(request, code, json_object, send_cors=False, return if pretty_print: - json_bytes = (encode_pretty_printed_json(json_object) + "\n" - ).encode("utf-8") + json_bytes = encode_pretty_printed_json(json_object) + b"\n" else: if canonical_json or synapse.events.USE_FROZEN_DICTS: # canonicaljson already encodes to bytes @@ -450,8 +431,12 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False, if send_cors: set_cors_headers(request) - request.write(json_bytes) - finish_request(request) + # todo: we can almost certainly avoid this copy and encode the json straight into + # the bytesIO, but it would involve faffing around with string->bytes wrappers. + bytes_io = BytesIO(json_bytes) + + producer = NoRangeStaticProducer(request, bytes_io) + producer.start() return NOT_DONE_YET diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 69f7085291..a1e4b88e6d 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -29,7 +29,7 @@ def parse_integer(request, name, default=None, required=False): Args: request: the twisted HTTP request. - name (str): the name of the query parameter. + name (bytes/unicode): the name of the query parameter. default (int|None): value to use if the parameter is absent, defaults to None. required (bool): whether to raise a 400 SynapseError if the @@ -46,6 +46,10 @@ def parse_integer(request, name, default=None, required=False): def parse_integer_from_args(args, name, default=None, required=False): + + if not isinstance(name, bytes): + name = name.encode('ascii') + if name in args: try: return int(args[name][0]) @@ -65,7 +69,7 @@ def parse_boolean(request, name, default=None, required=False): Args: request: the twisted HTTP request. - name (str): the name of the query parameter. + name (bytes/unicode): the name of the query parameter. default (bool|None): value to use if the parameter is absent, defaults to None. required (bool): whether to raise a 400 SynapseError if the @@ -83,11 +87,15 @@ def parse_boolean(request, name, default=None, required=False): def parse_boolean_from_args(args, name, default=None, required=False): + + if not isinstance(name, bytes): + name = name.encode('ascii') + if name in args: try: return { - "true": True, - "false": False, + b"true": True, + b"false": False, }[args[name][0]] except Exception: message = ( @@ -104,21 +112,29 @@ def parse_boolean_from_args(args, name, default=None, required=False): def parse_string(request, name, default=None, required=False, - allowed_values=None, param_type="string"): - """Parse a string parameter from the request query string. + allowed_values=None, param_type="string", encoding='ascii'): + """ + Parse a string parameter from the request query string. + + If encoding is not None, the content of the query param will be + decoded to Unicode using the encoding, otherwise it will be encoded Args: request: the twisted HTTP request. - name (str): the name of the query parameter. - default (str|None): value to use if the parameter is absent, defaults - to None. + name (bytes/unicode): the name of the query parameter. + default (bytes/unicode|None): value to use if the parameter is absent, + defaults to None. Must be bytes if encoding is None. required (bool): whether to raise a 400 SynapseError if the parameter is absent, defaults to False. - allowed_values (list[str]): List of allowed values for the string, - or None if any value is allowed, defaults to None + allowed_values (list[bytes/unicode]): List of allowed values for the + string, or None if any value is allowed, defaults to None. Must be + the same type as name, if given. + encoding: The encoding to decode the name to, and decode the string + content with. Returns: - str|None: A string value or the default. + bytes/unicode|None: A string value or the default. Unicode if encoding + was given, bytes otherwise. Raises: SynapseError if the parameter is absent and required, or if the @@ -126,14 +142,22 @@ def parse_string(request, name, default=None, required=False, is not one of those allowed values. """ return parse_string_from_args( - request.args, name, default, required, allowed_values, param_type, + request.args, name, default, required, allowed_values, param_type, encoding ) def parse_string_from_args(args, name, default=None, required=False, - allowed_values=None, param_type="string"): + allowed_values=None, param_type="string", encoding='ascii'): + + if not isinstance(name, bytes): + name = name.encode('ascii') + if name in args: value = args[name][0] + + if encoding: + value = value.decode(encoding) + if allowed_values is not None and value not in allowed_values: message = "Query parameter %r must be one of [%s]" % ( name, ", ".join(repr(v) for v in allowed_values) @@ -146,6 +170,10 @@ def parse_string_from_args(args, name, default=None, required=False, message = "Missing %s query parameter %r" % (param_type, name) raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) else: + + if encoding and isinstance(default, bytes): + return default.decode(encoding) + return default diff --git a/synapse/http/site.py b/synapse/http/site.py index 5fd30a4c2c..88ed3714f9 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import contextlib import logging import time @@ -19,8 +18,8 @@ import time from twisted.web.server import Request, Site from synapse.http import redact_uri -from synapse.http.request_metrics import RequestMetrics -from synapse.util.logcontext import ContextResourceUsage, LoggingContext +from synapse.http.request_metrics import RequestMetrics, requests_counter +from synapse.util.logcontext import LoggingContext, PreserveLoggingContext logger = logging.getLogger(__name__) @@ -34,25 +33,43 @@ class SynapseRequest(Request): It extends twisted's twisted.web.server.Request, and adds: * Unique request ID + * A log context associated with the request * Redaction of access_token query-params in __repr__ * Logging at start and end * Metrics to record CPU, wallclock and DB time by endpoint. - It provides a method `processing` which should be called by the Resource - which is handling the request, and returns a context manager. + It also provides a method `processing`, which returns a context manager. If this + method is called, the request won't be logged until the context manager is closed; + this is useful for asynchronous request handlers which may go on processing the + request even after the client has disconnected. + Attributes: + logcontext(LoggingContext) : the log context for this request """ def __init__(self, site, channel, *args, **kw): Request.__init__(self, channel, *args, **kw) self.site = site - self._channel = channel + self._channel = channel # this is used by the tests self.authenticated_entity = None self.start_time = 0 + # we can't yet create the logcontext, as we don't know the method. + self.logcontext = None + global _next_request_seq self.request_seq = _next_request_seq _next_request_seq += 1 + # whether an asynchronous request handler has called processing() + self._is_processing = False + + # the time when the asynchronous request handler completed its processing + self._processing_finished_time = None + + # what time we finished sending the response to the client (or the connection + # dropped) + self.finish_time = None + def __repr__(self): # We overwrite this so that we don't log ``access_token`` return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % ( @@ -74,11 +91,116 @@ class SynapseRequest(Request): return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1] def render(self, resrc): + # this is called once a Resource has been found to serve the request; in our + # case the Resource in question will normally be a JsonResource. + + # create a LogContext for this request + request_id = self.get_request_id() + logcontext = self.logcontext = LoggingContext(request_id) + logcontext.request = request_id + # override the Server header which is set by twisted self.setHeader("Server", self.site.server_version_string) - return Request.render(self, resrc) + + with PreserveLoggingContext(self.logcontext): + # we start the request metrics timer here with an initial stab + # at the servlet name. For most requests that name will be + # JsonResource (or a subclass), and JsonResource._async_render + # will update it once it picks a servlet. + servlet_name = resrc.__class__.__name__ + self._started_processing(servlet_name) + + Request.render(self, resrc) + + # record the arrival of the request *after* + # dispatching to the handler, so that the handler + # can update the servlet name in the request + # metrics + requests_counter.labels(self.method, + self.request_metrics.name).inc() + + @contextlib.contextmanager + def processing(self): + """Record the fact that we are processing this request. + + Returns a context manager; the correct way to use this is: + + @defer.inlineCallbacks + def handle_request(request): + with request.processing("FooServlet"): + yield really_handle_the_request() + + Once the context manager is closed, the completion of the request will be logged, + and the various metrics will be updated. + """ + if self._is_processing: + raise RuntimeError("Request is already processing") + self._is_processing = True + + try: + yield + except Exception: + # this should already have been caught, and sent back to the client as a 500. + logger.exception("Asynchronous messge handler raised an uncaught exception") + finally: + # the request handler has finished its work and either sent the whole response + # back, or handed over responsibility to a Producer. + + self._processing_finished_time = time.time() + self._is_processing = False + + # if we've already sent the response, log it now; otherwise, we wait for the + # response to be sent. + if self.finish_time is not None: + self._finished_processing() + + def finish(self): + """Called when all response data has been written to this Request. + + Overrides twisted.web.server.Request.finish to record the finish time and do + logging. + """ + self.finish_time = time.time() + Request.finish(self) + if not self._is_processing: + with PreserveLoggingContext(self.logcontext): + self._finished_processing() + + def connectionLost(self, reason): + """Called when the client connection is closed before the response is written. + + Overrides twisted.web.server.Request.connectionLost to record the finish time and + do logging. + """ + self.finish_time = time.time() + Request.connectionLost(self, reason) + + # we only get here if the connection to the client drops before we send + # the response. + # + # It's useful to log it here so that we can get an idea of when + # the client disconnects. + with PreserveLoggingContext(self.logcontext): + logger.warn( + "Error processing request %r: %s %s", self, reason.type, reason.value, + ) + + if not self._is_processing: + self._finished_processing() def _started_processing(self, servlet_name): + """Record the fact that we are processing this request. + + This will log the request's arrival. Once the request completes, + be sure to call finished_processing. + + Args: + servlet_name (str): the name of the servlet which will be + processing this request. This is used in the metrics. + + It is possible to update this afterwards by updating + self.request_metrics.name. + """ self.start_time = time.time() self.request_metrics = RequestMetrics() self.request_metrics.start( @@ -94,18 +216,32 @@ class SynapseRequest(Request): ) def _finished_processing(self): - try: - context = LoggingContext.current_context() - usage = context.get_resource_usage() - except Exception: - usage = ContextResourceUsage() + """Log the completion of this request and update the metrics + """ + + if self.logcontext is None: + # this can happen if the connection closed before we read the + # headers (so render was never called). In that case we'll already + # have logged a warning, so just bail out. + return + + usage = self.logcontext.get_resource_usage() + + if self._processing_finished_time is None: + # we completed the request without anything calling processing() + self._processing_finished_time = time.time() - end_time = time.time() + # the time between receiving the request and the request handler finishing + processing_time = self._processing_finished_time - self.start_time + + # the time between the request handler finishing and the response being sent + # to the client (nb may be negative) + response_send_time = self.finish_time - self._processing_finished_time # need to decode as it could be raw utf-8 bytes # from a IDN servname in an auth header authenticated_entity = self.authenticated_entity - if authenticated_entity is not None: + if authenticated_entity is not None and isinstance(authenticated_entity, bytes): authenticated_entity = authenticated_entity.decode("utf-8", "replace") # ...or could be raw utf-8 bytes in the User-Agent header. @@ -116,22 +252,31 @@ class SynapseRequest(Request): user_agent = self.get_user_agent() if user_agent is not None: user_agent = user_agent.decode("utf-8", "replace") + else: + user_agent = "-" + + code = str(self.code) + if not self.finished: + # we didn't send the full response before we gave up (presumably because + # the connection dropped) + code += "!" self.site.access_logger.info( "%s - %s - {%s}" - " Processed request: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)" + " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)" " %sB %s \"%s %s %s\" \"%s\" [%d dbevts]", self.getClientIP(), self.site.site_tag, authenticated_entity, - end_time - self.start_time, + processing_time, + response_send_time, usage.ru_utime, usage.ru_stime, usage.db_sched_duration_sec, usage.db_txn_duration_sec, int(usage.db_txn_count), self.sentLength, - self.code, + code, self.method, self.get_redacted_uri(), self.clientproto, @@ -140,38 +285,10 @@ class SynapseRequest(Request): ) try: - self.request_metrics.stop(end_time, self) + self.request_metrics.stop(self.finish_time, self) except Exception as e: logger.warn("Failed to stop metrics: %r", e) - @contextlib.contextmanager - def processing(self, servlet_name): - """Record the fact that we are processing this request. - - Returns a context manager; the correct way to use this is: - - @defer.inlineCallbacks - def handle_request(request): - with request.processing("FooServlet"): - yield really_handle_the_request() - - This will log the request's arrival. Once the context manager is - closed, the completion of the request will be logged, and the various - metrics will be updated. - - Args: - servlet_name (str): the name of the servlet which will be - processing this request. This is used in the metrics. - - It is possible to update this afterwards by updating - self.request_metrics.servlet_name. - """ - # TODO: we should probably just move this into render() and finish(), - # to save having to call a separate method. - self._started_processing(servlet_name) - yield - self._finished_processing() - class XForwardedForRequest(SynapseRequest): def __init__(self, *args, **kw): @@ -217,7 +334,7 @@ class SynapseSite(Site): proxied = config.get("x_forwarded", False) self.requestFactory = SynapseRequestFactory(self, proxied) self.access_logger = logging.getLogger(logger_name) - self.server_version_string = server_version_string + self.server_version_string = server_version_string.encode('ascii') def log(self, request): pass diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index a9158fc066..550f8443f7 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -174,6 +174,19 @@ sent_transactions_counter = Counter("synapse_federation_client_sent_transactions events_processed_counter = Counter("synapse_federation_client_events_processed", "") +event_processing_loop_counter = Counter( + "synapse_event_processing_loop_count", + "Event processing loop iterations", + ["name"], +) + +event_processing_loop_room_count = Counter( + "synapse_event_processing_loop_room_count", + "Rooms seen per event processing loop iteration", + ["name"], +) + + # Used to track where various components have processed in the event stream, # e.g. federation sending, appservice sending, etc. event_processing_positions = Gauge("synapse_event_processing_positions", "", ["name"]) diff --git a/synapse/notifier.py b/synapse/notifier.py index e650c3e494..82f391481c 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -25,7 +25,7 @@ from synapse.api.errors import AuthError from synapse.handlers.presence import format_user_presence_state from synapse.metrics import LaterGauge from synapse.types import StreamToken -from synapse.util.async import ( +from synapse.util.async_helpers import ( DeferredTimeoutError, ObservableDeferred, add_timeout_to_deferred, diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 1d14d3639c..8f9a76147f 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -26,7 +26,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.event_auth import get_user_power_level from synapse.state import POWER_KEY -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.caches import register_cache from synapse.util.caches.descriptors import cached diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 9d601208fd..bfa6df7b68 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -35,7 +35,7 @@ from synapse.push.presentable_names import ( name_from_member_event, ) from synapse.types import UserID -from synapse.util.async import concurrently_execute +from synapse.util.async_helpers import concurrently_execute from synapse.visibility import filter_events_for_client logger = logging.getLogger(__name__) diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 36bb5bbc65..9f7d5ef217 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -18,6 +18,7 @@ import logging from twisted.internet import defer +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push.pusher import PusherFactory from synapse.util.logcontext import make_deferred_yieldable, run_in_background @@ -122,8 +123,14 @@ class PusherPool: p['app_id'], p['pushkey'], p['user_name'], ) - @defer.inlineCallbacks def on_new_notifications(self, min_stream_id, max_stream_id): + run_as_background_process( + "on_new_notifications", + self._on_new_notifications, min_stream_id, max_stream_id, + ) + + @defer.inlineCallbacks + def _on_new_notifications(self, min_stream_id, max_stream_id): try: users_affected = yield self.store.get_push_action_users_in_range( min_stream_id, max_stream_id @@ -147,8 +154,14 @@ class PusherPool: except Exception: logger.exception("Exception in pusher on_new_notifications") - @defer.inlineCallbacks def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): + run_as_background_process( + "on_new_receipts", + self._on_new_receipts, min_stream_id, max_stream_id, affected_room_ids, + ) + + @defer.inlineCallbacks + def _on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): try: # Need to subtract 1 from the minimum because the lower bound here # is not inclusive diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index 589ee94c66..19f214281e 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -14,7 +14,7 @@ # limitations under the License. from synapse.http.server import JsonResource -from synapse.replication.http import membership, send_event +from synapse.replication.http import federation, membership, send_event REPLICATION_PREFIX = "/_synapse/replication" @@ -27,3 +27,4 @@ class ReplicationRestResource(JsonResource): def register_servlets(self, hs): send_event.register_servlets(hs, self) membership.register_servlets(hs, self) + federation.register_servlets(hs, self) diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py new file mode 100644 index 0000000000..5e5376cf58 --- /dev/null +++ b/synapse/replication/http/_base.py @@ -0,0 +1,215 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import logging +import re + +from six.moves import urllib + +from twisted.internet import defer + +from synapse.api.errors import CodeMessageException, HttpResponseException +from synapse.util.caches.response_cache import ResponseCache +from synapse.util.stringutils import random_string + +logger = logging.getLogger(__name__) + + +class ReplicationEndpoint(object): + """Helper base class for defining new replication HTTP endpoints. + + This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..` + (with an `/:txn_id` prefix for cached requests.), where NAME is a name, + PATH_ARGS are a tuple of parameters to be encoded in the URL. + + For example, if `NAME` is "send_event" and `PATH_ARGS` is `("event_id",)`, + with `CACHE` set to true then this generates an endpoint: + + /_synapse/replication/send_event/:event_id/:txn_id + + For POST/PUT requests the payload is serialized to json and sent as the + body, while for GET requests the payload is added as query parameters. See + `_serialize_payload` for details. + + Incoming requests are handled by overriding `_handle_request`. Servers + must call `register` to register the path with the HTTP server. + + Requests can be sent by calling the client returned by `make_client`. + + Attributes: + NAME (str): A name for the endpoint, added to the path as well as used + in logging and metrics. + PATH_ARGS (tuple[str]): A list of parameters to be added to the path. + Adding parameters to the path (rather than payload) can make it + easier to follow along in the log files. + METHOD (str): The method of the HTTP request, defaults to POST. Can be + one of POST, PUT or GET. If GET then the payload is sent as query + parameters rather than a JSON body. + CACHE (bool): Whether server should cache the result of the request/ + If true then transparently adds a txn_id to all requests, and + `_handle_request` must return a Deferred. + RETRY_ON_TIMEOUT(bool): Whether or not to retry the request when a 504 + is received. + """ + + __metaclass__ = abc.ABCMeta + + NAME = abc.abstractproperty() + PATH_ARGS = abc.abstractproperty() + + METHOD = "POST" + CACHE = True + RETRY_ON_TIMEOUT = True + + def __init__(self, hs): + if self.CACHE: + self.response_cache = ResponseCache( + hs, "repl." + self.NAME, + timeout_ms=30 * 60 * 1000, + ) + + assert self.METHOD in ("PUT", "POST", "GET") + + @abc.abstractmethod + def _serialize_payload(**kwargs): + """Static method that is called when creating a request. + + Concrete implementations should have explicit parameters (rather than + kwargs) so that an appropriate exception is raised if the client is + called with unexpected parameters. All PATH_ARGS must appear in + argument list. + + Returns: + Deferred[dict]|dict: If POST/PUT request then dictionary must be + JSON serialisable, otherwise must be appropriate for adding as + query args. + """ + return {} + + @abc.abstractmethod + def _handle_request(self, request, **kwargs): + """Handle incoming request. + + This is called with the request object and PATH_ARGS. + + Returns: + Deferred[dict]: A JSON serialisable dict to be used as response + body of request. + """ + pass + + @classmethod + def make_client(cls, hs): + """Create a client that makes requests. + + Returns a callable that accepts the same parameters as `_serialize_payload`. + """ + clock = hs.get_clock() + host = hs.config.worker_replication_host + port = hs.config.worker_replication_http_port + + client = hs.get_simple_http_client() + + @defer.inlineCallbacks + def send_request(**kwargs): + data = yield cls._serialize_payload(**kwargs) + + url_args = [urllib.parse.quote(kwargs[name]) for name in cls.PATH_ARGS] + + if cls.CACHE: + txn_id = random_string(10) + url_args.append(txn_id) + + if cls.METHOD == "POST": + request_func = client.post_json_get_json + elif cls.METHOD == "PUT": + request_func = client.put_json + elif cls.METHOD == "GET": + request_func = client.get_json + else: + # We have already asserted in the constructor that a + # compatible was picked, but lets be paranoid. + raise Exception( + "Unknown METHOD on %s replication endpoint" % (cls.NAME,) + ) + + uri = "http://%s:%s/_synapse/replication/%s/%s" % ( + host, port, cls.NAME, "/".join(url_args) + ) + + try: + # We keep retrying the same request for timeouts. This is so that we + # have a good idea that the request has either succeeded or failed on + # the master, and so whether we should clean up or not. + while True: + try: + result = yield request_func(uri, data) + break + except CodeMessageException as e: + if e.code != 504 or not cls.RETRY_ON_TIMEOUT: + raise + + logger.warn("%s request timed out", cls.NAME) + + # If we timed out we probably don't need to worry about backing + # off too much, but lets just wait a little anyway. + yield clock.sleep(1) + except HttpResponseException as e: + # We convert to SynapseError as we know that it was a SynapseError + # on the master process that we should send to the client. (And + # importantly, not stack traces everywhere) + raise e.to_synapse_error() + + defer.returnValue(result) + + return send_request + + def register(self, http_server): + """Called by the server to register this as a handler to the + appropriate path. + """ + + url_args = list(self.PATH_ARGS) + handler = self._handle_request + method = self.METHOD + + if self.CACHE: + handler = self._cached_handler + url_args.append("txn_id") + + args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args) + pattern = re.compile("^/_synapse/replication/%s/%s$" % ( + self.NAME, + args + )) + + http_server.register_paths(method, [pattern], handler) + + def _cached_handler(self, request, txn_id, **kwargs): + """Called on new incoming requests when caching is enabled. Checks + if there is a cached response for the request and returns that, + otherwise calls `_handle_request` and caches its response. + """ + # We just use the txn_id here, but we probably also want to use the + # other PATH_ARGS as well. + + assert self.CACHE + + return self.response_cache.wrap( + txn_id, + self._handle_request, + request, **kwargs + ) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py new file mode 100644 index 0000000000..64a79da162 --- /dev/null +++ b/synapse/replication/http/federation.py @@ -0,0 +1,259 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.events import FrozenEvent +from synapse.events.snapshot import EventContext +from synapse.http.servlet import parse_json_object_from_request +from synapse.replication.http._base import ReplicationEndpoint +from synapse.util.metrics import Measure + +logger = logging.getLogger(__name__) + + +class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): + """Handles events newly received from federation, including persisting and + notifying. + + The API looks like: + + POST /_synapse/replication/fed_send_events/:txn_id + + { + "events": [{ + "event": { .. serialized event .. }, + "internal_metadata": { .. serialized internal_metadata .. }, + "rejected_reason": .., // The event.rejected_reason field + "context": { .. serialized event context .. }, + }], + "backfilled": false + """ + + NAME = "fed_send_events" + PATH_ARGS = () + + def __init__(self, hs): + super(ReplicationFederationSendEventsRestServlet, self).__init__(hs) + + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self.federation_handler = hs.get_handlers().federation_handler + + @staticmethod + @defer.inlineCallbacks + def _serialize_payload(store, event_and_contexts, backfilled): + """ + Args: + store + event_and_contexts (list[tuple[FrozenEvent, EventContext]]) + backfilled (bool): Whether or not the events are the result of + backfilling + """ + event_payloads = [] + for event, context in event_and_contexts: + serialized_context = yield context.serialize(event, store) + + event_payloads.append({ + "event": event.get_pdu_json(), + "internal_metadata": event.internal_metadata.get_dict(), + "rejected_reason": event.rejected_reason, + "context": serialized_context, + }) + + payload = { + "events": event_payloads, + "backfilled": backfilled, + } + + defer.returnValue(payload) + + @defer.inlineCallbacks + def _handle_request(self, request): + with Measure(self.clock, "repl_fed_send_events_parse"): + content = parse_json_object_from_request(request) + + backfilled = content["backfilled"] + + event_payloads = content["events"] + + event_and_contexts = [] + for event_payload in event_payloads: + event_dict = event_payload["event"] + internal_metadata = event_payload["internal_metadata"] + rejected_reason = event_payload["rejected_reason"] + event = FrozenEvent(event_dict, internal_metadata, rejected_reason) + + context = yield EventContext.deserialize( + self.store, event_payload["context"], + ) + + event_and_contexts.append((event, context)) + + logger.info( + "Got %d events from federation", + len(event_and_contexts), + ) + + yield self.federation_handler.persist_events_and_notify( + event_and_contexts, backfilled, + ) + + defer.returnValue((200, {})) + + +class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): + """Handles EDUs newly received from federation, including persisting and + notifying. + + Request format: + + POST /_synapse/replication/fed_send_edu/:edu_type/:txn_id + + { + "origin": ..., + "content: { ... } + } + """ + + NAME = "fed_send_edu" + PATH_ARGS = ("edu_type",) + + def __init__(self, hs): + super(ReplicationFederationSendEduRestServlet, self).__init__(hs) + + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self.registry = hs.get_federation_registry() + + @staticmethod + def _serialize_payload(edu_type, origin, content): + return { + "origin": origin, + "content": content, + } + + @defer.inlineCallbacks + def _handle_request(self, request, edu_type): + with Measure(self.clock, "repl_fed_send_edu_parse"): + content = parse_json_object_from_request(request) + + origin = content["origin"] + edu_content = content["content"] + + logger.info( + "Got %r edu from %s", + edu_type, origin, + ) + + result = yield self.registry.on_edu(edu_type, origin, edu_content) + + defer.returnValue((200, result)) + + +class ReplicationGetQueryRestServlet(ReplicationEndpoint): + """Handle responding to queries from federation. + + Request format: + + POST /_synapse/replication/fed_query/:query_type + + { + "args": { ... } + } + """ + + NAME = "fed_query" + PATH_ARGS = ("query_type",) + + # This is a query, so let's not bother caching + CACHE = False + + def __init__(self, hs): + super(ReplicationGetQueryRestServlet, self).__init__(hs) + + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self.registry = hs.get_federation_registry() + + @staticmethod + def _serialize_payload(query_type, args): + """ + Args: + query_type (str) + args (dict): The arguments received for the given query type + """ + return { + "args": args, + } + + @defer.inlineCallbacks + def _handle_request(self, request, query_type): + with Measure(self.clock, "repl_fed_query_parse"): + content = parse_json_object_from_request(request) + + args = content["args"] + + logger.info( + "Got %r query", + query_type, + ) + + result = yield self.registry.on_query(query_type, args) + + defer.returnValue((200, result)) + + +class ReplicationCleanRoomRestServlet(ReplicationEndpoint): + """Called to clean up any data in DB for a given room, ready for the + server to join the room. + + Request format: + + POST /_synapse/replication/fed_query/:fed_cleanup_room/:txn_id + + {} + """ + + NAME = "fed_cleanup_room" + PATH_ARGS = ("room_id",) + + def __init__(self, hs): + super(ReplicationCleanRoomRestServlet, self).__init__(hs) + + self.store = hs.get_datastore() + + @staticmethod + def _serialize_payload(room_id, args): + """ + Args: + room_id (str) + """ + return {} + + @defer.inlineCallbacks + def _handle_request(self, request, room_id): + yield self.store.clean_room_for_join(room_id) + + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + ReplicationFederationSendEventsRestServlet(hs).register(http_server) + ReplicationFederationSendEduRestServlet(hs).register(http_server) + ReplicationGetQueryRestServlet(hs).register(http_server) + ReplicationCleanRoomRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 7a3cfb159c..e58bebf12a 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -14,182 +14,63 @@ # limitations under the License. import logging -import re from twisted.internet import defer -from synapse.api.errors import HttpResponseException -from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.servlet import parse_json_object_from_request +from synapse.replication.http._base import ReplicationEndpoint from synapse.types import Requester, UserID from synapse.util.distributor import user_joined_room, user_left_room logger = logging.getLogger(__name__) -@defer.inlineCallbacks -def remote_join(client, host, port, requester, remote_room_hosts, - room_id, user_id, content): - """Ask the master to do a remote join for the given user to the given room +class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): + """Does a remote join for the given user to the given room - Args: - client (SimpleHttpClient) - host (str): host of master - port (int): port on master listening for HTTP replication - requester (Requester) - remote_room_hosts (list[str]): Servers to try and join via - room_id (str) - user_id (str) - content (dict): The event content to use for the join event + Request format: - Returns: - Deferred - """ - uri = "http://%s:%s/_synapse/replication/remote_join" % (host, port) - - payload = { - "requester": requester.serialize(), - "remote_room_hosts": remote_room_hosts, - "room_id": room_id, - "user_id": user_id, - "content": content, - } - - try: - result = yield client.post_json_get_json(uri, payload) - except HttpResponseException as e: - # We convert to SynapseError as we know that it was a SynapseError - # on the master process that we should send to the client. (And - # importantly, not stack traces everywhere) - raise e.to_synapse_error() - defer.returnValue(result) - - -@defer.inlineCallbacks -def remote_reject_invite(client, host, port, requester, remote_room_hosts, - room_id, user_id): - """Ask master to reject the invite for the user and room. - - Args: - client (SimpleHttpClient) - host (str): host of master - port (int): port on master listening for HTTP replication - requester (Requester) - remote_room_hosts (list[str]): Servers to try and reject via - room_id (str) - user_id (str) - - Returns: - Deferred - """ - uri = "http://%s:%s/_synapse/replication/remote_reject_invite" % (host, port) - - payload = { - "requester": requester.serialize(), - "remote_room_hosts": remote_room_hosts, - "room_id": room_id, - "user_id": user_id, - } - - try: - result = yield client.post_json_get_json(uri, payload) - except HttpResponseException as e: - # We convert to SynapseError as we know that it was a SynapseError - # on the master process that we should send to the client. (And - # importantly, not stack traces everywhere) - raise e.to_synapse_error() - defer.returnValue(result) - - -@defer.inlineCallbacks -def get_or_register_3pid_guest(client, host, port, requester, - medium, address, inviter_user_id): - """Ask the master to get/create a guest account for given 3PID. - - Args: - client (SimpleHttpClient) - host (str): host of master - port (int): port on master listening for HTTP replication - requester (Requester) - medium (str) - address (str) - inviter_user_id (str): The user ID who is trying to invite the - 3PID - - Returns: - Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the - 3PID guest account. - """ + POST /_synapse/replication/remote_join/:room_id/:user_id - uri = "http://%s:%s/_synapse/replication/get_or_register_3pid_guest" % (host, port) - - payload = { - "requester": requester.serialize(), - "medium": medium, - "address": address, - "inviter_user_id": inviter_user_id, - } - - try: - result = yield client.post_json_get_json(uri, payload) - except HttpResponseException as e: - # We convert to SynapseError as we know that it was a SynapseError - # on the master process that we should send to the client. (And - # importantly, not stack traces everywhere) - raise e.to_synapse_error() - defer.returnValue(result) - - -@defer.inlineCallbacks -def notify_user_membership_change(client, host, port, user_id, room_id, change): - """Notify master that a user has joined or left the room - - Args: - client (SimpleHttpClient) - host (str): host of master - port (int): port on master listening for HTTP replication. - user_id (str) - room_id (str) - change (str): Either "join" or "left" - - Returns: - Deferred + { + "requester": ..., + "remote_room_hosts": [...], + "content": { ... } + } """ - assert change in ("joined", "left") - - uri = "http://%s:%s/_synapse/replication/user_%s_room" % (host, port, change) - - payload = { - "user_id": user_id, - "room_id": room_id, - } - - try: - result = yield client.post_json_get_json(uri, payload) - except HttpResponseException as e: - # We convert to SynapseError as we know that it was a SynapseError - # on the master process that we should send to the client. (And - # importantly, not stack traces everywhere) - raise e.to_synapse_error() - defer.returnValue(result) - -class ReplicationRemoteJoinRestServlet(RestServlet): - PATTERNS = [re.compile("^/_synapse/replication/remote_join$")] + NAME = "remote_join" + PATH_ARGS = ("room_id", "user_id",) def __init__(self, hs): - super(ReplicationRemoteJoinRestServlet, self).__init__() + super(ReplicationRemoteJoinRestServlet, self).__init__(hs) self.federation_handler = hs.get_handlers().federation_handler self.store = hs.get_datastore() self.clock = hs.get_clock() + @staticmethod + def _serialize_payload(requester, room_id, user_id, remote_room_hosts, + content): + """ + Args: + requester(Requester) + room_id (str) + user_id (str) + remote_room_hosts (list[str]): Servers to try and join via + content(dict): The event content to use for the join event + """ + return { + "requester": requester.serialize(), + "remote_room_hosts": remote_room_hosts, + "content": content, + } + @defer.inlineCallbacks - def on_POST(self, request): + def _handle_request(self, request, room_id, user_id): content = parse_json_object_from_request(request) remote_room_hosts = content["remote_room_hosts"] - room_id = content["room_id"] - user_id = content["user_id"] event_content = content["content"] requester = Requester.deserialize(self.store, content["requester"]) @@ -212,23 +93,48 @@ class ReplicationRemoteJoinRestServlet(RestServlet): defer.returnValue((200, {})) -class ReplicationRemoteRejectInviteRestServlet(RestServlet): - PATTERNS = [re.compile("^/_synapse/replication/remote_reject_invite$")] +class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): + """Rejects the invite for the user and room. + + Request format: + + POST /_synapse/replication/remote_reject_invite/:room_id/:user_id + + { + "requester": ..., + "remote_room_hosts": [...], + } + """ + + NAME = "remote_reject_invite" + PATH_ARGS = ("room_id", "user_id",) def __init__(self, hs): - super(ReplicationRemoteRejectInviteRestServlet, self).__init__() + super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs) self.federation_handler = hs.get_handlers().federation_handler self.store = hs.get_datastore() self.clock = hs.get_clock() + @staticmethod + def _serialize_payload(requester, room_id, user_id, remote_room_hosts): + """ + Args: + requester(Requester) + room_id (str) + user_id (str) + remote_room_hosts (list[str]): Servers to try and reject via + """ + return { + "requester": requester.serialize(), + "remote_room_hosts": remote_room_hosts, + } + @defer.inlineCallbacks - def on_POST(self, request): + def _handle_request(self, request, room_id, user_id): content = parse_json_object_from_request(request) remote_room_hosts = content["remote_room_hosts"] - room_id = content["room_id"] - user_id = content["user_id"] requester = Requester.deserialize(self.store, content["requester"]) @@ -264,18 +170,50 @@ class ReplicationRemoteRejectInviteRestServlet(RestServlet): defer.returnValue((200, ret)) -class ReplicationRegister3PIDGuestRestServlet(RestServlet): - PATTERNS = [re.compile("^/_synapse/replication/get_or_register_3pid_guest$")] +class ReplicationRegister3PIDGuestRestServlet(ReplicationEndpoint): + """Gets/creates a guest account for given 3PID. + + Request format: + + POST /_synapse/replication/get_or_register_3pid_guest/ + + { + "requester": ..., + "medium": ..., + "address": ..., + "inviter_user_id": ... + } + """ + + NAME = "get_or_register_3pid_guest" + PATH_ARGS = () def __init__(self, hs): - super(ReplicationRegister3PIDGuestRestServlet, self).__init__() + super(ReplicationRegister3PIDGuestRestServlet, self).__init__(hs) self.registeration_handler = hs.get_handlers().registration_handler self.store = hs.get_datastore() self.clock = hs.get_clock() + @staticmethod + def _serialize_payload(requester, medium, address, inviter_user_id): + """ + Args: + requester(Requester) + medium (str) + address (str) + inviter_user_id (str): The user ID who is trying to invite the + 3PID + """ + return { + "requester": requester.serialize(), + "medium": medium, + "address": address, + "inviter_user_id": inviter_user_id, + } + @defer.inlineCallbacks - def on_POST(self, request): + def _handle_request(self, request): content = parse_json_object_from_request(request) medium = content["medium"] @@ -296,23 +234,41 @@ class ReplicationRegister3PIDGuestRestServlet(RestServlet): defer.returnValue((200, ret)) -class ReplicationUserJoinedLeftRoomRestServlet(RestServlet): - PATTERNS = [re.compile("^/_synapse/replication/user_(?P<change>joined|left)_room$")] +class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): + """Notifies that a user has joined or left the room + + Request format: + + POST /_synapse/replication/membership_change/:room_id/:user_id/:change + + {} + """ + + NAME = "membership_change" + PATH_ARGS = ("room_id", "user_id", "change") + CACHE = False # No point caching as should return instantly. def __init__(self, hs): - super(ReplicationUserJoinedLeftRoomRestServlet, self).__init__() + super(ReplicationUserJoinedLeftRoomRestServlet, self).__init__(hs) self.registeration_handler = hs.get_handlers().registration_handler self.store = hs.get_datastore() self.clock = hs.get_clock() self.distributor = hs.get_distributor() - def on_POST(self, request, change): - content = parse_json_object_from_request(request) + @staticmethod + def _serialize_payload(room_id, user_id, change): + """ + Args: + room_id (str) + user_id (str) + change (str): Either "joined" or "left" + """ + assert change in ("joined", "left",) - user_id = content["user_id"] - room_id = content["room_id"] + return {} + def _handle_request(self, request, room_id, user_id, change): logger.info("user membership change: %s in %s", user_id, room_id) user = UserID.from_string(user_id) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index d3509dc288..5b52c91650 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -14,86 +14,26 @@ # limitations under the License. import logging -import re from twisted.internet import defer -from synapse.api.errors import CodeMessageException, HttpResponseException from synapse.events import FrozenEvent from synapse.events.snapshot import EventContext -from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.servlet import parse_json_object_from_request +from synapse.replication.http._base import ReplicationEndpoint from synapse.types import Requester, UserID -from synapse.util.caches.response_cache import ResponseCache from synapse.util.metrics import Measure logger = logging.getLogger(__name__) -@defer.inlineCallbacks -def send_event_to_master(clock, store, client, host, port, requester, event, context, - ratelimit, extra_users): - """Send event to be handled on the master - - Args: - clock (synapse.util.Clock) - store (DataStore) - client (SimpleHttpClient) - host (str): host of master - port (int): port on master listening for HTTP replication - requester (Requester) - event (FrozenEvent) - context (EventContext) - ratelimit (bool) - extra_users (list(UserID)): Any extra users to notify about event - """ - uri = "http://%s:%s/_synapse/replication/send_event/%s" % ( - host, port, event.event_id, - ) - - serialized_context = yield context.serialize(event, store) - - payload = { - "event": event.get_pdu_json(), - "internal_metadata": event.internal_metadata.get_dict(), - "rejected_reason": event.rejected_reason, - "context": serialized_context, - "requester": requester.serialize(), - "ratelimit": ratelimit, - "extra_users": [u.to_string() for u in extra_users], - } - - try: - # We keep retrying the same request for timeouts. This is so that we - # have a good idea that the request has either succeeded or failed on - # the master, and so whether we should clean up or not. - while True: - try: - result = yield client.put_json(uri, payload) - break - except CodeMessageException as e: - if e.code != 504: - raise - - logger.warn("send_event request timed out") - - # If we timed out we probably don't need to worry about backing - # off too much, but lets just wait a little anyway. - yield clock.sleep(1) - except HttpResponseException as e: - # We convert to SynapseError as we know that it was a SynapseError - # on the master process that we should send to the client. (And - # importantly, not stack traces everywhere) - raise e.to_synapse_error() - defer.returnValue(result) - - -class ReplicationSendEventRestServlet(RestServlet): +class ReplicationSendEventRestServlet(ReplicationEndpoint): """Handles events newly created on workers, including persisting and notifying. The API looks like: - POST /_synapse/replication/send_event/:event_id + POST /_synapse/replication/send_event/:event_id/:txn_id { "event": { .. serialized event .. }, @@ -105,27 +45,47 @@ class ReplicationSendEventRestServlet(RestServlet): "extra_users": [], } """ - PATTERNS = [re.compile("^/_synapse/replication/send_event/(?P<event_id>[^/]+)$")] + NAME = "send_event" + PATH_ARGS = ("event_id",) def __init__(self, hs): - super(ReplicationSendEventRestServlet, self).__init__() + super(ReplicationSendEventRestServlet, self).__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.store = hs.get_datastore() self.clock = hs.get_clock() - # The responses are tiny, so we may as well cache them for a while - self.response_cache = ResponseCache(hs, "send_event", timeout_ms=30 * 60 * 1000) + @staticmethod + @defer.inlineCallbacks + def _serialize_payload(event_id, store, event, context, requester, + ratelimit, extra_users): + """ + Args: + event_id (str) + store (DataStore) + requester (Requester) + event (FrozenEvent) + context (EventContext) + ratelimit (bool) + extra_users (list(UserID)): Any extra users to notify about event + """ + + serialized_context = yield context.serialize(event, store) + + payload = { + "event": event.get_pdu_json(), + "internal_metadata": event.internal_metadata.get_dict(), + "rejected_reason": event.rejected_reason, + "context": serialized_context, + "requester": requester.serialize(), + "ratelimit": ratelimit, + "extra_users": [u.to_string() for u in extra_users], + } - def on_PUT(self, request, event_id): - return self.response_cache.wrap( - event_id, - self._handle_request, - request - ) + defer.returnValue(payload) @defer.inlineCallbacks - def _handle_request(self, request): + def _handle_request(self, request, event_id): with Measure(self.clock, "repl_send_event_parse"): content = parse_json_object_from_request(request) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index bdb5eee4af..4830c68f35 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -44,8 +44,8 @@ class SlavedEventStore(EventFederationWorkerStore, RoomMemberWorkerStore, EventPushActionsWorkerStore, StreamWorkerStore, - EventsWorkerStore, StateGroupWorkerStore, + EventsWorkerStore, SignatureWorkerStore, UserErasureWorkerStore, BaseSlavedStore): diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py index 9c9a5eadd9..3527beb3c9 100644 --- a/synapse/replication/slave/storage/transactions.py +++ b/synapse/replication/slave/storage/transactions.py @@ -13,19 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage import DataStore from synapse.storage.transactions import TransactionStore from ._base import BaseSlavedStore -class TransactionStore(BaseSlavedStore): - get_destination_retry_timings = TransactionStore.__dict__[ - "get_destination_retry_timings" - ] - _get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__ - set_destination_retry_timings = DataStore.set_destination_retry_timings.__func__ - _set_destination_retry_timings = DataStore._set_destination_retry_timings.__func__ - - prep_send_transaction = DataStore.prep_send_transaction.__func__ - delivered_txn = DataStore.delivered_txn.__func__ +class SlavedTransactionStore(TransactionStore, BaseSlavedStore): + pass diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 970e94313e..cbe9645817 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -107,7 +107,7 @@ class ReplicationClientHandler(object): Can be overriden in subclasses to handle more. """ logger.info("Received rdata %s -> %s", stream_name, token) - self.store.process_replication_rows(stream_name, token, rows) + return self.store.process_replication_rows(stream_name, token, rows) def on_position(self, stream_name, token): """Called when we get new position data. By default this just pokes @@ -115,7 +115,7 @@ class ReplicationClientHandler(object): Can be overriden in subclasses to handle more. """ - self.store.process_replication_rows(stream_name, token, []) + return self.store.process_replication_rows(stream_name, token, []) def on_sync(self, data): """When we received a SYNC we wake up any deferreds that were waiting diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index f3908df642..327556f6a1 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -59,6 +59,12 @@ class Command(object): """ return self.data + def get_logcontext_id(self): + """Get a suitable string for the logcontext when processing this command""" + + # by default, we just use the command name. + return self.NAME + class ServerCommand(Command): """Sent by the server on new connection and includes the server_name. @@ -116,6 +122,9 @@ class RdataCommand(Command): _json_encoder.encode(self.row), )) + def get_logcontext_id(self): + return "RDATA-" + self.stream_name + class PositionCommand(Command): """Sent by the client to tell the client the stream postition without @@ -190,6 +199,9 @@ class ReplicateCommand(Command): def to_line(self): return " ".join((self.stream_name, str(self.token),)) + def get_logcontext_id(self): + return "REPLICATE-" + self.stream_name + class UserSyncCommand(Command): """Sent by the client to inform the server that a user has started or diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index dec5ac0913..74e892c104 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -63,6 +63,8 @@ from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure from synapse.metrics import LaterGauge +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.stringutils import random_string from .commands import ( @@ -222,7 +224,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # Now lets try and call on_<CMD_NAME> function try: - getattr(self, "on_%s" % (cmd_name,))(cmd) + run_as_background_process( + "replication-" + cmd.get_logcontext_id(), + getattr(self, "on_%s" % (cmd_name,)), + cmd, + ) except Exception: logger.exception("[%s] Failed to handle line: %r", self.id(), line) @@ -387,7 +393,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.name = cmd.data def on_USER_SYNC(self, cmd): - self.streamer.on_user_sync( + return self.streamer.on_user_sync( self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms, ) @@ -397,22 +403,33 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): if stream_name == "ALL": # Subscribe to all streams we're publishing to. - for stream in iterkeys(self.streamer.streams_by_name): - self.subscribe_to_stream(stream, token) + deferreds = [ + run_in_background( + self.subscribe_to_stream, + stream, token, + ) + for stream in iterkeys(self.streamer.streams_by_name) + ] + + return make_deferred_yieldable( + defer.gatherResults(deferreds, consumeErrors=True) + ) else: - self.subscribe_to_stream(stream_name, token) + return self.subscribe_to_stream(stream_name, token) def on_FEDERATION_ACK(self, cmd): - self.streamer.federation_ack(cmd.token) + return self.streamer.federation_ack(cmd.token) def on_REMOVE_PUSHER(self, cmd): - self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) + return self.streamer.on_remove_pusher( + cmd.app_id, cmd.push_key, cmd.user_id, + ) def on_INVALIDATE_CACHE(self, cmd): - self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) + return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) def on_USER_IP(self, cmd): - self.streamer.on_user_ip( + return self.streamer.on_user_ip( cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id, cmd.last_seen, ) @@ -542,14 +559,13 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # Check if this is the last of a batch of updates rows = self.pending_batches.pop(stream_name, []) rows.append(row) - - self.handler.on_rdata(stream_name, cmd.token, rows) + return self.handler.on_rdata(stream_name, cmd.token, rows) def on_POSITION(self, cmd): - self.handler.on_position(cmd.stream_name, cmd.token) + return self.handler.on_position(cmd.stream_name, cmd.token) def on_SYNC(self, cmd): - self.handler.on_sync(cmd.data) + return self.handler.on_sync(cmd.data) def replicate(self, stream_name, token): """Send the subscription request to the server diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 00b1b3066e..48c17f1b6d 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -17,7 +17,7 @@ to ensure idempotency when performing PUTs using the REST API.""" import logging -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) @@ -53,7 +53,7 @@ class HttpTransactionCache(object): str: A transaction key """ token = self.auth.get_access_token_from_request(request) - return request.path + "/" + token + return request.path.decode('utf8') + "/" + token def fetch_or_execute_request(self, request, fn, *args, **kwargs): """A helper function for fetch_or_execute which extracts diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 80d625eecc..ad536ab570 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -391,10 +391,17 @@ class DeactivateAccountRestServlet(ClientV1RestServlet): if not is_admin: raise AuthError(403, "You are not a server admin") - yield self._deactivate_account_handler.deactivate_account( + result = yield self._deactivate_account_handler.deactivate_account( target_user_id, erase, ) - defer.returnValue((200, {})) + if result: + id_server_unbind_result = "success" + else: + id_server_unbind_result = "no-support" + + defer.returnValue((200, { + "id_server_unbind_result": id_server_unbind_result, + })) class ShutdownRoomRestServlet(ClientV1RestServlet): diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index a14f0c807e..b5a6d6aebf 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -84,7 +84,8 @@ class PresenceStatusRestServlet(ClientV1RestServlet): except Exception: raise SynapseError(400, "Unable to parse state") - yield self.presence_handler.set_state(user, state) + if self.hs.config.use_presence: + yield self.presence_handler.set_state(user, state) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index fa5989e74e..fcc1091760 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -34,7 +34,7 @@ from synapse.http.servlet import ( parse_string, ) from synapse.streams.config import PaginationConfig -from synapse.types import RoomAlias, RoomID, ThirdPartyInstanceID, UserID +from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID from .base import ClientV1RestServlet, client_path_patterns @@ -384,15 +384,39 @@ class RoomMemberListRestServlet(ClientV1RestServlet): def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) requester = yield self.auth.get_user_by_req(request) - events = yield self.message_handler.get_state_events( + handler = self.message_handler + + # request the state as of a given event, as identified by a stream token, + # for consistency with /messages etc. + # useful for getting the membership in retrospect as of a given /sync + # response. + at_token_string = parse_string(request, "at") + if at_token_string is None: + at_token = None + else: + at_token = StreamToken.from_string(at_token_string) + + # let you filter down on particular memberships. + # XXX: this may not be the best shape for this API - we could pass in a filter + # instead, except filters aren't currently aware of memberships. + # See https://github.com/matrix-org/matrix-doc/issues/1337 for more details. + membership = parse_string(request, "membership") + not_membership = parse_string(request, "not_membership") + + events = yield handler.get_state_events( room_id=room_id, user_id=requester.user.to_string(), + at_token=at_token, + types=[(EventTypes.Member, None)], ) chunk = [] for event in events: - if event["type"] != EventTypes.Member: + if ( + (membership and event['content'].get("membership") != membership) or + (not_membership and event['content'].get("membership") == not_membership) + ): continue chunk.append(event) @@ -401,6 +425,8 @@ class RoomMemberListRestServlet(ClientV1RestServlet): })) +# deprecated in favour of /members?membership=join? +# except it does custom AS logic and has a simpler return format class JoinedRoomMemberListRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$") diff --git a/synapse/rest/client/v1_only/register.py b/synapse/rest/client/v1_only/register.py index 3439c3c6d4..5e99cffbcb 100644 --- a/synapse/rest/client/v1_only/register.py +++ b/synapse/rest/client/v1_only/register.py @@ -129,12 +129,9 @@ class RegisterRestServlet(ClientV1RestServlet): login_type = register_json["type"] is_application_server = login_type == LoginType.APPLICATION_SERVICE - is_using_shared_secret = login_type == LoginType.SHARED_SECRET - can_register = ( self.enable_registration or is_application_server - or is_using_shared_secret ) if not can_register: raise SynapseError(403, "Registration has been disabled") @@ -144,7 +141,6 @@ class RegisterRestServlet(ClientV1RestServlet): LoginType.PASSWORD: self._do_password, LoginType.EMAIL_IDENTITY: self._do_email_identity, LoginType.APPLICATION_SERVICE: self._do_app_service, - LoginType.SHARED_SECRET: self._do_shared_secret, } session_info = self._get_session_info(request, session) @@ -325,56 +321,6 @@ class RegisterRestServlet(ClientV1RestServlet): "home_server": self.hs.hostname, }) - @defer.inlineCallbacks - def _do_shared_secret(self, request, register_json, session): - assert_params_in_dict(register_json, ["mac", "user", "password"]) - - if not self.hs.config.registration_shared_secret: - raise SynapseError(400, "Shared secret registration is not enabled") - - user = register_json["user"].encode("utf-8") - password = register_json["password"].encode("utf-8") - admin = register_json.get("admin", None) - - # Its important to check as we use null bytes as HMAC field separators - if b"\x00" in user: - raise SynapseError(400, "Invalid user") - if b"\x00" in password: - raise SynapseError(400, "Invalid password") - - # str() because otherwise hmac complains that 'unicode' does not - # have the buffer interface - got_mac = str(register_json["mac"]) - - want_mac = hmac.new( - key=self.hs.config.registration_shared_secret.encode(), - digestmod=sha1, - ) - want_mac.update(user) - want_mac.update(b"\x00") - want_mac.update(password) - want_mac.update(b"\x00") - want_mac.update(b"admin" if admin else b"notadmin") - want_mac = want_mac.hexdigest() - - if compare_digest(want_mac, got_mac): - handler = self.handlers.registration_handler - user_id, token = yield handler.register( - localpart=user.lower(), - password=password, - admin=bool(admin), - ) - self._remove_session(session) - defer.returnValue({ - "user_id": user_id, - "access_token": token, - "home_server": self.hs.hostname, - }) - else: - raise SynapseError( - 403, "HMAC incorrect", - ) - class CreateUserRestServlet(ClientV1RestServlet): """Handles user creation via a server-to-server interface diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index eeae466d82..372648cafd 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -209,10 +209,17 @@ class DeactivateAccountRestServlet(RestServlet): yield self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request), ) - yield self._deactivate_account_handler.deactivate_account( + result = yield self._deactivate_account_handler.deactivate_account( requester.user.to_string(), erase, ) - defer.returnValue((200, {})) + if result: + id_server_unbind_result = "success" + else: + id_server_unbind_result = "no-support" + + defer.returnValue((200, { + "id_server_unbind_result": id_server_unbind_result, + })) class EmailThreepidRequestTokenRestServlet(RestServlet): @@ -364,7 +371,7 @@ class ThreepidDeleteRestServlet(RestServlet): user_id = requester.user.to_string() try: - yield self.auth_handler.delete_threepid( + ret = yield self.auth_handler.delete_threepid( user_id, body['medium'], body['address'] ) except Exception: @@ -374,7 +381,14 @@ class ThreepidDeleteRestServlet(RestServlet): logger.exception("Failed to remove threepid") raise SynapseError(500, "Failed to remove threepid") - defer.returnValue((200, {})) + if ret: + id_server_unbind_result = "success" + else: + id_server_unbind_result = "no-support" + + defer.returnValue((200, { + "id_server_unbind_result": id_server_unbind_result, + })) class WhoamiRestServlet(RestServlet): diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 8aa06faf23..1275baa1ba 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -370,6 +370,7 @@ class SyncRestServlet(RestServlet): ephemeral_events = room.ephemeral result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications + result["summary"] = room.summary return result diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 6ac2987b98..29e62bfcdd 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -27,11 +27,22 @@ class VersionsRestServlet(RestServlet): def on_GET(self, request): return (200, { "versions": [ + # XXX: at some point we need to decide whether we need to include + # the previous version numbers, given we've defined r0.3.0 to be + # backwards compatible with r0.2.0. But need to check how + # conscientious we've been in compatibility, and decide whether the + # middle number is the major revision when at 0.X.Y (as opposed to + # X.Y.Z). And we need to decide whether it's fair to make clients + # parse the version string to figure out what's going on. "r0.0.1", "r0.1.0", "r0.2.0", "r0.3.0", - ] + ], + # as per MSC1497: + "unstable_features": { + "m.lazy_load_members": True, + } }) diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 147ff7d79b..7362e1858d 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -140,7 +140,7 @@ class ConsentResource(Resource): version = parse_string(request, "v", default=self._default_consent_version) username = parse_string(request, "u", required=True) - userhmac = parse_string(request, "h", required=True) + userhmac = parse_string(request, "h", required=True, encoding=None) self._check_hash(username, userhmac) @@ -175,7 +175,7 @@ class ConsentResource(Resource): """ version = parse_string(request, "v", required=True) username = parse_string(request, "u", required=True) - userhmac = parse_string(request, "h", required=True) + userhmac = parse_string(request, "h", required=True, encoding=None) self._check_hash(username, userhmac) @@ -210,9 +210,18 @@ class ConsentResource(Resource): finish_request(request) def _check_hash(self, userid, userhmac): + """ + Args: + userid (unicode): + userhmac (bytes): + + Raises: + SynapseError if the hash doesn't match + + """ want_mac = hmac.new( key=self._hmac_secret, - msg=userid, + msg=userid.encode('utf-8'), digestmod=sha256, ).hexdigest() diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py new file mode 100644 index 0000000000..d6605b6027 --- /dev/null +++ b/synapse/rest/media/v1/config_resource.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 Will Hunt <will@half-shot.uk> +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from twisted.internet import defer +from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET + +from synapse.http.server import respond_with_json, wrap_json_request_handler + + +class MediaConfigResource(Resource): + isLeaf = True + + def __init__(self, hs): + Resource.__init__(self) + config = hs.get_config() + self.clock = hs.get_clock() + self.auth = hs.get_auth() + self.limits_dict = { + "m.upload.size": config.max_upload_size, + } + + def render_GET(self, request): + self._async_render_GET(request) + return NOT_DONE_YET + + @wrap_json_request_handler + @defer.inlineCallbacks + def _async_render_GET(self, request): + yield self.auth.get_user_by_req(request) + respond_with_json(request, 200, self.limits_dict) + + def render_OPTIONS(self, request): + respond_with_json(request, 200, {}, send_cors=True) + return NOT_DONE_YET diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 8fb413d825..241c972070 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -36,12 +36,13 @@ from synapse.api.errors import ( ) from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.logcontext import make_deferred_yieldable from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import is_ascii, random_string from ._base import FileInfo, respond_404, respond_with_responder +from .config_resource import MediaConfigResource from .download_resource import DownloadResource from .filepath import MediaFilePaths from .identicon_resource import IdenticonResource @@ -754,7 +755,6 @@ class MediaRepositoryResource(Resource): Resource.__init__(self) media_repo = hs.get_media_repository() - self.putChild("upload", UploadResource(hs, media_repo)) self.putChild("download", DownloadResource(hs, media_repo)) self.putChild("thumbnail", ThumbnailResource( @@ -765,3 +765,4 @@ class MediaRepositoryResource(Resource): self.putChild("preview_url", PreviewUrlResource( hs, media_repo, media_repo.media_storage, )) + self.putChild("config", MediaConfigResource(hs)) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 27aa0def2f..778ef97337 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -42,7 +42,7 @@ from synapse.http.server import ( ) from synapse.http.servlet import parse_integer, parse_string from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.stringutils import is_ascii, random_string diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 9b22d204a6..c1240e1963 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -55,7 +55,7 @@ class UploadResource(Resource): requester = yield self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have # already been uploaded to a tmp file at this point - content_length = request.getHeader("Content-Length") + content_length = request.getHeader(b"Content-Length").decode('ascii') if content_length is None: raise SynapseError( msg="Request must specify a Content-Length", code=400 @@ -66,10 +66,10 @@ class UploadResource(Resource): code=413, ) - upload_name = parse_string(request, "filename") + upload_name = parse_string(request, b"filename", encoding=None) if upload_name: try: - upload_name = upload_name.decode('UTF-8') + upload_name = upload_name.decode('utf8') except UnicodeDecodeError: raise SynapseError( msg="Invalid UTF-8 filename parameter: %r" % (upload_name), @@ -78,8 +78,8 @@ class UploadResource(Resource): headers = request.requestHeaders - if headers.hasHeader("Content-Type"): - media_type = headers.getRawHeaders(b"Content-Type")[0] + if headers.hasHeader(b"Content-Type"): + media_type = headers.getRawHeaders(b"Content-Type")[0].decode('ascii') else: raise SynapseError( msg="Upload request missing 'Content-Type'", diff --git a/synapse/secrets.py b/synapse/secrets.py index f05e9ea535..f6280f951c 100644 --- a/synapse/secrets.py +++ b/synapse/secrets.py @@ -38,4 +38,4 @@ else: return os.urandom(nbytes) def token_hex(self, nbytes=32): - return binascii.hexlify(self.token_bytes(nbytes)) + return binascii.hexlify(self.token_bytes(nbytes)).decode('ascii') diff --git a/synapse/server.py b/synapse/server.py index 140be9ebe8..26228d8c72 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -36,6 +36,7 @@ from synapse.federation.federation_client import FederationClient from synapse.federation.federation_server import ( FederationHandlerRegistry, FederationServer, + ReplicationFederationHandlerRegistry, ) from synapse.federation.send_queue import FederationRemoteSendQueue from synapse.federation.transaction_queue import TransactionQueue @@ -423,7 +424,10 @@ class HomeServer(object): return RoomMemberMasterHandler(self) def build_federation_registry(self): - return FederationHandlerRegistry() + if self.config.worker_app: + return ReplicationFederationHandlerRegistry(self) + else: + return FederationHandlerRegistry() def build_server_notices_manager(self): if self.config.worker_app: diff --git a/synapse/state.py b/synapse/state.py index e1092b97a9..0f2bedb694 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -28,8 +28,8 @@ from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.events.snapshot import EventContext -from synapse.util.async import Linearizer -from synapse.util.caches import CACHE_SIZE_FACTOR +from synapse.util.async_helpers import Linearizer +from synapse.util.caches import get_cache_factor_for from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function from synapse.util.metrics import Measure @@ -40,7 +40,7 @@ logger = logging.getLogger(__name__) KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) -SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR) +SIZE_OF_CACHE = 100000 * get_cache_factor_for("state_cache") EVICTION_TIMEOUT_SECONDS = 60 * 60 diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 134e4a80f1..23b4a8d76d 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -39,6 +39,7 @@ from .filtering import FilteringStore from .group_server import GroupServerStore from .keys import KeyStore from .media_repository import MediaRepositoryStore +from .monthly_active_users import MonthlyActiveUsersStore from .openid import OpenIdStore from .presence import PresenceStore, UserPresenceState from .profile import ProfileStore @@ -87,6 +88,7 @@ class DataStore(RoomMemberStore, RoomStore, UserDirectoryStore, GroupServerStore, UserErasureStore, + MonthlyActiveUsersStore, ): def __init__(self, db_conn, hs): @@ -94,7 +96,6 @@ class DataStore(RoomMemberStore, RoomStore, self._clock = hs.get_clock() self.database_engine = hs.database_engine - self.db_conn = db_conn self._stream_id_gen = StreamIdGenerator( db_conn, "events", "stream_ordering", extra_tables=[("local_invites", "stream_id")] @@ -267,31 +268,6 @@ class DataStore(RoomMemberStore, RoomStore, return self.runInteraction("count_users", _count_users) - def count_monthly_users(self): - """Counts the number of users who used this homeserver in the last 30 days - - This method should be refactored with count_daily_users - the only - reason not to is waiting on definition of mau - - Returns: - Defered[int] - """ - def _count_monthly_users(txn): - thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - sql = """ - SELECT COALESCE(count(*), 0) FROM ( - SELECT user_id FROM user_ips - WHERE last_seen > ? - GROUP BY user_id - ) u - """ - - txn.execute(sql, (thirty_days_ago,)) - count, = txn.fetchone() - return count - - return self.runInteraction("count_monthly_users", _count_monthly_users) - def count_r30_users(self): """ Counts the number of 30 day retained users, defined as:- diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 44f37b4c1e..08dffd774f 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -1150,17 +1150,16 @@ class SQLBaseStore(object): defer.returnValue(retval) def get_user_count_txn(self, txn): - """Get a total number of registerd users in the users list. + """Get a total number of registered users in the users list. Args: txn : Transaction object Returns: - defer.Deferred: resolves to int + int : number of users """ sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;" txn.execute(sql_count) - count = txn.fetchone()[0] - defer.returnValue(count) + return txn.fetchone()[0] def _simple_search_list(self, table, term, col, retcols, desc="_simple_search_list"): diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index b8cefd43d6..8fc678fa67 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -35,6 +35,7 @@ LAST_SEEN_GRANULARITY = 120 * 1000 class ClientIpStore(background_updates.BackgroundUpdateStore): def __init__(self, db_conn, hs): + self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, @@ -74,6 +75,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): "before", "shutdown", self._update_client_ips_batch ) + @defer.inlineCallbacks def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id, now=None): if not now: @@ -84,7 +86,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): last_seen = self.client_ip_last_seen.get(key) except KeyError: last_seen = None - + yield self.populate_monthly_active_users(user_id) # Rate-limited inserts if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: return @@ -94,6 +96,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): self._batch_row_update[key] = (user_agent, device_id, now) def _update_client_ips_batch(self): + + # If the DB pool has already terminated, don't try updating + if not self.hs.get_db_pool().running: + return + def update(): to_update = self._batch_row_update self._batch_row_update = {} diff --git a/synapse/storage/events.py b/synapse/storage/events.py index e8e5a0fe44..025a7fb6d9 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -38,7 +38,7 @@ from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.event_federation import EventFederationStore from synapse.storage.events_worker import EventsWorkerStore from synapse.types import RoomStreamToken, get_domain_from_id -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable @@ -485,9 +485,14 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore new_forward_extremeties=new_forward_extremeties, ) persist_event_counter.inc(len(chunk)) - synapse.metrics.event_persisted_position.set( - chunk[-1][0].internal_metadata.stream_ordering, - ) + + if not backfilled: + # backfilled events have negative stream orderings, so we don't + # want to set the event_persisted_position to that. + synapse.metrics.event_persisted_position.set( + chunk[-1][0].internal_metadata.stream_ordering, + ) + for event, context in chunk: if context.app_service: origin_type = "local" @@ -1431,88 +1436,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore ) @defer.inlineCallbacks - def have_events_in_timeline(self, event_ids): - """Given a list of event ids, check if we have already processed and - stored them as non outliers. - """ - rows = yield self._simple_select_many_batch( - table="events", - retcols=("event_id",), - column="event_id", - iterable=list(event_ids), - keyvalues={"outlier": False}, - desc="have_events_in_timeline", - ) - - defer.returnValue(set(r["event_id"] for r in rows)) - - @defer.inlineCallbacks - def have_seen_events(self, event_ids): - """Given a list of event ids, check if we have already processed them. - - Args: - event_ids (iterable[str]): - - Returns: - Deferred[set[str]]: The events we have already seen. - """ - results = set() - - def have_seen_events_txn(txn, chunk): - sql = ( - "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" - % (",".join("?" * len(chunk)), ) - ) - txn.execute(sql, chunk) - for (event_id, ) in txn: - results.add(event_id) - - # break the input up into chunks of 100 - input_iterator = iter(event_ids) - for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), - []): - yield self.runInteraction( - "have_seen_events", - have_seen_events_txn, - chunk, - ) - defer.returnValue(results) - - def get_seen_events_with_rejections(self, event_ids): - """Given a list of event ids, check if we rejected them. - - Args: - event_ids (list[str]) - - Returns: - Deferred[dict[str, str|None): - Has an entry for each event id we already have seen. Maps to - the rejected reason string if we rejected the event, else maps - to None. - """ - if not event_ids: - return defer.succeed({}) - - def f(txn): - sql = ( - "SELECT e.event_id, reason FROM events as e " - "LEFT JOIN rejections as r ON e.event_id = r.event_id " - "WHERE e.event_id = ?" - ) - - res = {} - for event_id in event_ids: - txn.execute(sql, (event_id,)) - row = txn.fetchone() - if row: - _, rejected = row - res[event_id] = rejected - - return res - - return self.runInteraction("get_rejection_reasons", f) - - @defer.inlineCallbacks def count_daily_messages(self): """ Returns an estimate of the number of messages sent in the last day. @@ -1988,7 +1911,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore max_depth = max(row[0] for row in rows) if max_depth <= token.topological: - # We need to ensure we don't delete all the events from the datanase + # We need to ensure we don't delete all the events from the database # otherwise we wouldn't be able to send any events (due to not # having any backwards extremeties) raise SynapseError( diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 9b4cfeb899..59822178ff 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import itertools import logging from collections import namedtuple @@ -442,3 +443,85 @@ class EventsWorkerStore(SQLBaseStore): self._get_event_cache.prefill((original_ev.event_id,), cache_entry) defer.returnValue(cache_entry) + + @defer.inlineCallbacks + def have_events_in_timeline(self, event_ids): + """Given a list of event ids, check if we have already processed and + stored them as non outliers. + """ + rows = yield self._simple_select_many_batch( + table="events", + retcols=("event_id",), + column="event_id", + iterable=list(event_ids), + keyvalues={"outlier": False}, + desc="have_events_in_timeline", + ) + + defer.returnValue(set(r["event_id"] for r in rows)) + + @defer.inlineCallbacks + def have_seen_events(self, event_ids): + """Given a list of event ids, check if we have already processed them. + + Args: + event_ids (iterable[str]): + + Returns: + Deferred[set[str]]: The events we have already seen. + """ + results = set() + + def have_seen_events_txn(txn, chunk): + sql = ( + "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" + % (",".join("?" * len(chunk)), ) + ) + txn.execute(sql, chunk) + for (event_id, ) in txn: + results.add(event_id) + + # break the input up into chunks of 100 + input_iterator = iter(event_ids) + for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), + []): + yield self.runInteraction( + "have_seen_events", + have_seen_events_txn, + chunk, + ) + defer.returnValue(results) + + def get_seen_events_with_rejections(self, event_ids): + """Given a list of event ids, check if we rejected them. + + Args: + event_ids (list[str]) + + Returns: + Deferred[dict[str, str|None): + Has an entry for each event id we already have seen. Maps to + the rejected reason string if we rejected the event, else maps + to None. + """ + if not event_ids: + return defer.succeed({}) + + def f(txn): + sql = ( + "SELECT e.event_id, reason FROM events as e " + "LEFT JOIN rejections as r ON e.event_id = r.event_id " + "WHERE e.event_id = ?" + ) + + res = {} + for event_id in event_ids: + txn.execute(sql, (event_id,)) + row = txn.fetchone() + if row: + _, rejected = row + res[event_id] = rejected + + return res + + return self.runInteraction("get_rejection_reasons", f) diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py new file mode 100644 index 0000000000..06f9a75a97 --- /dev/null +++ b/synapse/storage/monthly_active_users.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.internet import defer + +from synapse.util.caches.descriptors import cached + +from ._base import SQLBaseStore + +logger = logging.getLogger(__name__) + +# Number of msec of granularity to store the monthly_active_user timestamp +# This means it is not necessary to update the table on every request +LAST_SEEN_GRANULARITY = 60 * 60 * 1000 + + +class MonthlyActiveUsersStore(SQLBaseStore): + def __init__(self, dbconn, hs): + super(MonthlyActiveUsersStore, self).__init__(None, hs) + self._clock = hs.get_clock() + self.hs = hs + self.reserved_users = () + + @defer.inlineCallbacks + def initialise_reserved_users(self, threepids): + # TODO Why can't I do this in init? + store = self.hs.get_datastore() + reserved_user_list = [] + + # Do not add more reserved users than the total allowable number + for tp in threepids[:self.hs.config.max_mau_value]: + user_id = yield store.get_user_id_by_threepid( + tp["medium"], tp["address"] + ) + if user_id: + yield self.upsert_monthly_active_user(user_id) + reserved_user_list.append(user_id) + else: + logger.warning( + "mau limit reserved threepid %s not found in db" % tp + ) + self.reserved_users = tuple(reserved_user_list) + + @defer.inlineCallbacks + def reap_monthly_active_users(self): + """ + Cleans out monthly active user table to ensure that no stale + entries exist. + + Returns: + Deferred[] + """ + def _reap_users(txn): + # Purge stale users + + thirty_days_ago = ( + int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) + ) + query_args = [thirty_days_ago] + base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?" + + # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres + # when len(reserved_users) == 0. Works fine on sqlite. + if len(self.reserved_users) > 0: + # questionmarks is a hack to overcome sqlite not supporting + # tuples in 'WHERE IN %s' + questionmarks = '?' * len(self.reserved_users) + + query_args.extend(self.reserved_users) + sql = base_sql + """ AND user_id NOT IN ({})""".format( + ','.join(questionmarks) + ) + else: + sql = base_sql + + txn.execute(sql, query_args) + + # If MAU user count still exceeds the MAU threshold, then delete on + # a least recently active basis. + # Note it is not possible to write this query using OFFSET due to + # incompatibilities in how sqlite and postgres support the feature. + # sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present + # While Postgres does not require 'LIMIT', but also does not support + # negative LIMIT values. So there is no way to write it that both can + # support + safe_guard = self.hs.config.max_mau_value - len(self.reserved_users) + # Must be greater than zero for postgres + safe_guard = safe_guard if safe_guard > 0 else 0 + query_args = [safe_guard] + + base_sql = """ + DELETE FROM monthly_active_users + WHERE user_id NOT IN ( + SELECT user_id FROM monthly_active_users + ORDER BY timestamp DESC + LIMIT ? + ) + """ + # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres + # when len(reserved_users) == 0. Works fine on sqlite. + if len(self.reserved_users) > 0: + query_args.extend(self.reserved_users) + sql = base_sql + """ AND user_id NOT IN ({})""".format( + ','.join(questionmarks) + ) + else: + sql = base_sql + txn.execute(sql, query_args) + + yield self.runInteraction("reap_monthly_active_users", _reap_users) + # It seems poor to invalidate the whole cache, Postgres supports + # 'Returning' which would allow me to invalidate only the + # specific users, but sqlite has no way to do this and instead + # I would need to SELECT and the DELETE which without locking + # is racy. + # Have resolved to invalidate the whole cache for now and do + # something about it if and when the perf becomes significant + self.user_last_seen_monthly_active.invalidate_all() + self.get_monthly_active_count.invalidate_all() + + @cached(num_args=0) + def get_monthly_active_count(self): + """Generates current count of monthly active users + + Returns: + Defered[int]: Number of current monthly active users + """ + + def _count_users(txn): + sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" + + txn.execute(sql) + count, = txn.fetchone() + return count + return self.runInteraction("count_users", _count_users) + + def upsert_monthly_active_user(self, user_id): + """ + Updates or inserts monthly active user member + Arguments: + user_id (str): user to add/update + Deferred[bool]: True if a new entry was created, False if an + existing one was updated. + """ + is_insert = self._simple_upsert( + desc="upsert_monthly_active_user", + table="monthly_active_users", + keyvalues={ + "user_id": user_id, + }, + values={ + "timestamp": int(self._clock.time_msec()), + }, + lock=False, + ) + if is_insert: + self.user_last_seen_monthly_active.invalidate((user_id,)) + self.get_monthly_active_count.invalidate(()) + + @cached(num_args=1) + def user_last_seen_monthly_active(self, user_id): + """ + Checks if a given user is part of the monthly active user group + Arguments: + user_id (str): user to add/update + Return: + Deferred[int] : timestamp since last seen, None if never seen + + """ + + return(self._simple_select_one_onecol( + table="monthly_active_users", + keyvalues={ + "user_id": user_id, + }, + retcol="timestamp", + allow_none=True, + desc="user_last_seen_monthly_active", + )) + + @defer.inlineCallbacks + def populate_monthly_active_users(self, user_id): + """Checks on the state of monthly active user limits and optionally + add the user to the monthly active tables + + Args: + user_id(str): the user_id to query + """ + if self.hs.config.limit_usage_by_mau: + last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id) + now = self.hs.get_clock().time_msec() + + # We want to reduce to the total number of db writes, and are happy + # to trade accuracy of timestamp in order to lighten load. This means + # We always insert new users (where MAU threshold has not been reached), + # but only update if we have not previously seen the user for + # LAST_SEEN_GRANULARITY ms + if last_seen_timestamp is None: + count = yield self.get_monthly_active_count() + if count < self.hs.config.max_mau_value: + yield self.upsert_monthly_active_user(user_id) + elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY: + yield self.upsert_monthly_active_user(user_id) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index b290f834b3..b364719312 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 50 +SCHEMA_VERSION = 51 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 3147fb6827..3378fc77d1 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -41,6 +41,22 @@ RatelimitOverride = collections.namedtuple( class RoomWorkerStore(SQLBaseStore): + def get_room(self, room_id): + """Retrieve a room. + + Args: + room_id (str): The ID of the room to retrieve. + Returns: + A namedtuple containing the room information, or an empty list. + """ + return self._simple_select_one( + table="rooms", + keyvalues={"room_id": room_id}, + retcols=("room_id", "is_public", "creator"), + desc="get_room", + allow_none=True, + ) + def get_public_room_ids(self): return self._simple_select_onecol( table="rooms", @@ -215,22 +231,6 @@ class RoomStore(RoomWorkerStore, SearchStore): logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") - def get_room(self, room_id): - """Retrieve a room. - - Args: - room_id (str): The ID of the room to retrieve. - Returns: - A namedtuple containing the room information, or an empty list. - """ - return self._simple_select_one( - table="rooms", - keyvalues={"room_id": room_id}, - retcols=("room_id", "is_public", "creator"), - desc="get_room", - allow_none=True, - ) - @defer.inlineCallbacks def set_room_is_public(self, room_id, is_public): def set_room_is_public_txn(txn, next_id): diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 10dce21cea..9b4e6d6aa8 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -26,7 +26,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.storage.events_worker import EventsWorkerStore from synapse.types import get_domain_from_id -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.stringutils import to_ascii diff --git a/synapse/storage/schema/delta/51/monthly_active_users.sql b/synapse/storage/schema/delta/51/monthly_active_users.sql new file mode 100644 index 0000000000..c9d537d5a3 --- /dev/null +++ b/synapse/storage/schema/delta/51/monthly_active_users.sql @@ -0,0 +1,27 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- a table of monthly active users, for use where blocking based on mau limits +CREATE TABLE monthly_active_users ( + user_id TEXT NOT NULL, + -- Last time we saw the user. Not guaranteed to be accurate due to rate limiting + -- on updates, Granularity of updates governed by + -- synapse.storage.monthly_active_users.LAST_SEEN_GRANULARITY + -- Measured in ms since epoch. + timestamp BIGINT NOT NULL +); + +CREATE UNIQUE INDEX monthly_active_users_users ON monthly_active_users(user_id); +CREATE INDEX monthly_active_users_time_stamp ON monthly_active_users(timestamp); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index b27b3ae144..dd03c4168b 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -21,15 +21,17 @@ from six.moves import range from twisted.internet import defer +from synapse.api.constants import EventTypes +from synapse.api.errors import NotFoundError +from synapse.storage._base import SQLBaseStore from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.engines import PostgresEngine +from synapse.storage.events_worker import EventsWorkerStore from synapse.util.caches import get_cache_factor_for, intern_string from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.stringutils import to_ascii -from ._base import SQLBaseStore - logger = logging.getLogger(__name__) @@ -46,7 +48,8 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt return len(self.delta_ids) if self.delta_ids else 0 -class StateGroupWorkerStore(SQLBaseStore): +# this inherits from EventsWorkerStore because it calls self.get_events +class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers. """ @@ -61,6 +64,30 @@ class StateGroupWorkerStore(SQLBaseStore): "*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache") ) + @defer.inlineCallbacks + def get_room_version(self, room_id): + """Get the room_version of a given room + + Args: + room_id (str) + + Returns: + Deferred[str] + + Raises: + NotFoundError if the room is unknown + """ + # for now we do this by looking at the create event. We may want to cache this + # more intelligently in future. + state_ids = yield self.get_current_state_ids(room_id) + create_id = state_ids.get((EventTypes.Create, "")) + + if not create_id: + raise NotFoundError("Unknown room") + + create_event = yield self.get_event(create_id) + defer.returnValue(create_event.content.get("room_version", "1")) + @cached(max_entries=100000, iterable=True) def get_current_state_ids(self, room_id): """Get the current state event ids for a room based on the @@ -89,6 +116,69 @@ class StateGroupWorkerStore(SQLBaseStore): _get_current_state_ids_txn, ) + # FIXME: how should this be cached? + def get_filtered_current_state_ids(self, room_id, types, filtered_types=None): + """Get the current state event of a given type for a room based on the + current_state_events table. This may not be as up-to-date as the result + of doing a fresh state resolution as per state_handler.get_current_state + Args: + room_id (str) + types (list[(Str, (Str|None))]): List of (type, state_key) tuples + which are used to filter the state fetched. `state_key` may be + None, which matches any `state_key` + filtered_types (list[Str]|None): List of types to apply the above filter to. + Returns: + deferred: dict of (type, state_key) -> event + """ + + include_other_types = False if filtered_types is None else True + + def _get_filtered_current_state_ids_txn(txn): + results = {} + sql = """SELECT type, state_key, event_id FROM current_state_events + WHERE room_id = ? %s""" + # Turns out that postgres doesn't like doing a list of OR's and + # is about 1000x slower, so we just issue a query for each specific + # type seperately. + if types: + clause_to_args = [ + ( + "AND type = ? AND state_key = ?", + (etype, state_key) + ) if state_key is not None else ( + "AND type = ?", + (etype,) + ) + for etype, state_key in types + ] + + if include_other_types: + unique_types = set(filtered_types) + clause_to_args.append( + ( + "AND type <> ? " * len(unique_types), + list(unique_types) + ) + ) + else: + # If types is None we fetch all the state, and so just use an + # empty where clause with no extra args. + clause_to_args = [("", [])] + for where_clause, where_args in clause_to_args: + args = [room_id] + args.extend(where_args) + txn.execute(sql % (where_clause,), args) + for row in txn: + typ, state_key, event_id = row + key = (intern_string(typ), intern_string(state_key)) + results[key] = event_id + return results + + return self.runInteraction( + "get_filtered_current_state_ids", + _get_filtered_current_state_ids_txn, + ) + @cached(max_entries=10000, iterable=True) def get_state_group_delta(self, state_group): """Given a state group try to return a previous group and a delta between @@ -362,8 +452,7 @@ class StateGroupWorkerStore(SQLBaseStore): If None, `types` filtering is applied to all events. Returns: - deferred: A list of dicts corresponding to the event_ids given. - The dicts are mappings from (type, state_key) -> state_events + deferred: A dict of (event_id) -> (type, state_key) -> [state_events] """ event_to_groups = yield self._get_state_group_for_events( event_ids, @@ -391,7 +480,8 @@ class StateGroupWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None): """ - Get the state dicts corresponding to a list of events + Get the state dicts corresponding to a list of events, containing the event_ids + of the state events (as opposed to the events themselves) Args: event_ids(list(str)): events whose state should be returned @@ -404,7 +494,7 @@ class StateGroupWorkerStore(SQLBaseStore): If None, `types` filtering is applied to all events. Returns: - A deferred dict from event_id -> (type, state_key) -> state_event + A deferred dict from event_id -> (type, state_key) -> event_id """ event_to_groups = yield self._get_state_group_for_events( event_ids, diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index b9f2b74ac6..4c296d72c0 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -348,7 +348,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): end_token (str): The stream token representing now. Returns: - Deferred[tuple[list[FrozenEvent], str]]: Returns a list of + Deferred[tuple[list[FrozenEvent], str]]: Returns a list of events and a token pointing to the start of the returned events. The events returned are in ascending order. @@ -379,7 +379,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): end_token (str): The stream token representing now. Returns: - Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of + Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of _EventDictReturn and a token pointing to the start of the returned events. The events returned are in ascending order. diff --git a/synapse/util/async.py b/synapse/util/async_helpers.py index a7094e2fb4..9b3f2f4b96 100644 --- a/synapse/util/async.py +++ b/synapse/util/async_helpers.py @@ -188,62 +188,30 @@ class Linearizer(object): # things blocked from executing. self.key_to_defer = {} - @defer.inlineCallbacks def queue(self, key): + # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly. + # (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not + # propagated inside inlineCallbacks until Twisted 18.7) entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()]) # If the number of things executing is greater than the maximum # then add a deferred to the list of blocked items - # When on of the things currently executing finishes it will callback + # When one of the things currently executing finishes it will callback # this item so that it can continue executing. if entry[0] >= self.max_count: - new_defer = defer.Deferred() - entry[1][new_defer] = 1 - - logger.info( - "Waiting to acquire linearizer lock %r for key %r", self.name, key, - ) - try: - yield make_deferred_yieldable(new_defer) - except Exception as e: - if isinstance(e, CancelledError): - logger.info( - "Cancelling wait for linearizer lock %r for key %r", - self.name, key, - ) - else: - logger.warn( - "Unexpected exception waiting for linearizer lock %r for key %r", - self.name, key, - ) - - # we just have to take ourselves back out of the queue. - del entry[1][new_defer] - raise - - logger.info("Acquired linearizer lock %r for key %r", self.name, key) - entry[0] += 1 - - # if the code holding the lock completes synchronously, then it - # will recursively run the next claimant on the list. That can - # relatively rapidly lead to stack exhaustion. This is essentially - # the same problem as http://twistedmatrix.com/trac/ticket/9304. - # - # In order to break the cycle, we add a cheeky sleep(0) here to - # ensure that we fall back to the reactor between each iteration. - # - # (This needs to happen while we hold the lock, and the context manager's exit - # code must be synchronous, so this is the only sensible place.) - yield self._clock.sleep(0) - + res = self._await_lock(key) else: logger.info( "Acquired uncontended linearizer lock %r for key %r", self.name, key, ) entry[0] += 1 + res = defer.succeed(None) + + # once we successfully get the lock, we need to return a context manager which + # will release the lock. @contextmanager - def _ctx_manager(): + def _ctx_manager(_): try: yield finally: @@ -264,7 +232,64 @@ class Linearizer(object): # map. del self.key_to_defer[key] - defer.returnValue(_ctx_manager()) + res.addCallback(_ctx_manager) + return res + + def _await_lock(self, key): + """Helper for queue: adds a deferred to the queue + + Assumes that we've already checked that we've reached the limit of the number + of lock-holders we allow. Creates a new deferred which is added to the list, and + adds some management around cancellations. + + Returns the deferred, which will callback once we have secured the lock. + + """ + entry = self.key_to_defer[key] + + logger.info( + "Waiting to acquire linearizer lock %r for key %r", self.name, key, + ) + + new_defer = make_deferred_yieldable(defer.Deferred()) + entry[1][new_defer] = 1 + + def cb(_r): + logger.info("Acquired linearizer lock %r for key %r", self.name, key) + entry[0] += 1 + + # if the code holding the lock completes synchronously, then it + # will recursively run the next claimant on the list. That can + # relatively rapidly lead to stack exhaustion. This is essentially + # the same problem as http://twistedmatrix.com/trac/ticket/9304. + # + # In order to break the cycle, we add a cheeky sleep(0) here to + # ensure that we fall back to the reactor between each iteration. + # + # (This needs to happen while we hold the lock, and the context manager's exit + # code must be synchronous, so this is the only sensible place.) + return self._clock.sleep(0) + + def eb(e): + logger.info("defer %r got err %r", new_defer, e) + if isinstance(e, CancelledError): + logger.info( + "Cancelling wait for linearizer lock %r for key %r", + self.name, key, + ) + + else: + logger.warn( + "Unexpected exception waiting for linearizer lock %r for key %r", + self.name, key, + ) + + # we just have to take ourselves back out of the queue. + del entry[1][new_defer] + return e + + new_defer.addCallbacks(cb, eb) + return new_defer class ReadWriteLock(object): diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 861c24809c..187510576a 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -25,7 +25,7 @@ from six import itervalues, string_types from twisted.internet import defer from synapse.util import logcontext, unwrapFirstError -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches import get_cache_factor_for from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index a8491b42d5..afb03b2e1b 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -16,7 +16,7 @@ import logging from twisted.internet import defer -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches import register_cache from synapse.util.logcontext import make_deferred_yieldable, run_in_background diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py index d03678b8c8..8318db8d2c 100644 --- a/synapse/util/caches/snapshot_cache.py +++ b/synapse/util/caches/snapshot_cache.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred class SnapshotCache(object): diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index 8dcae50b39..a0c2d37610 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -385,7 +385,13 @@ class LoggingContextFilter(logging.Filter): context = LoggingContext.current_context() for key, value in self.defaults.items(): setattr(record, key, value) - context.copy_to(record) + + # context should never be None, but if it somehow ends up being, then + # we end up in a death spiral of infinite loops, so let's check, for + # robustness' sake. + if context is not None: + context.copy_to(record) + return True @@ -396,7 +402,9 @@ class PreserveLoggingContext(object): __slots__ = ["current_context", "new_context", "has_parent"] - def __init__(self, new_context=LoggingContext.sentinel): + def __init__(self, new_context=None): + if new_context is None: + new_context = LoggingContext.sentinel self.new_context = new_context def __enter__(self): @@ -526,7 +534,7 @@ _to_ignore = [ "synapse.util.logcontext", "synapse.http.server", "synapse.storage._base", - "synapse.util.async", + "synapse.util.async_helpers", ] diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py index 62a00189cc..ef31458226 100644 --- a/synapse/util/logutils.py +++ b/synapse/util/logutils.py @@ -20,6 +20,8 @@ import time from functools import wraps from inspect import getcallargs +from six import PY3 + _TIME_FUNC_ID = 0 @@ -28,8 +30,12 @@ def _log_debug_as_f(f, msg, msg_args): logger = logging.getLogger(name) if logger.isEnabledFor(logging.DEBUG): - lineno = f.func_code.co_firstlineno - pathname = f.func_code.co_filename + if PY3: + lineno = f.__code__.co_firstlineno + pathname = f.__code__.co_filename + else: + lineno = f.func_code.co_firstlineno + pathname = f.func_code.co_filename record = logging.LogRecord( name=name, diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 43d9db67ec..6f318c6a29 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -16,6 +16,7 @@ import random import string +from six import PY3 from six.moves import range _string_with_symbols = ( @@ -34,6 +35,17 @@ def random_string_with_symbols(length): def is_ascii(s): + + if PY3: + if isinstance(s, bytes): + try: + s.decode('ascii').encode('ascii') + except UnicodeDecodeError: + return False + except UnicodeEncodeError: + return False + return True + try: s.encode("ascii") except UnicodeEncodeError: @@ -49,6 +61,9 @@ def to_ascii(s): If given None then will return None. """ + if PY3: + return s + if s is None: return None diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py index 1fbcd41115..3baba3225a 100644 --- a/synapse/util/versionstring.py +++ b/synapse/util/versionstring.py @@ -30,7 +30,7 @@ def get_version_string(module): ['git', 'rev-parse', '--abbrev-ref', 'HEAD'], stderr=null, cwd=cwd, - ).strip() + ).strip().decode('ascii') git_branch = "b=" + git_branch except subprocess.CalledProcessError: git_branch = "" @@ -40,7 +40,7 @@ def get_version_string(module): ['git', 'describe', '--exact-match'], stderr=null, cwd=cwd, - ).strip() + ).strip().decode('ascii') git_tag = "t=" + git_tag except subprocess.CalledProcessError: git_tag = "" @@ -50,7 +50,7 @@ def get_version_string(module): ['git', 'rev-parse', '--short', 'HEAD'], stderr=null, cwd=cwd, - ).strip() + ).strip().decode('ascii') except subprocess.CalledProcessError: git_commit = "" @@ -60,7 +60,7 @@ def get_version_string(module): ['git', 'describe', '--dirty=' + dirty_string], stderr=null, cwd=cwd, - ).strip().endswith(dirty_string) + ).strip().decode('ascii').endswith(dirty_string) git_dirty = "dirty" if is_dirty else "" except subprocess.CalledProcessError: @@ -77,8 +77,8 @@ def get_version_string(module): "%s (%s)" % ( module.__version__, git_version, ) - ).encode("ascii") + ) except Exception as e: logger.info("Failed to check for git repository: %s", e) - return module.__version__.encode("ascii") + return module.__version__ |