diff options
Diffstat (limited to 'synapse')
115 files changed, 2315 insertions, 1523 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index e62901b761..43c5821ade 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -17,4 +17,14 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.33.3" +try: + from twisted.internet import protocol + from twisted.internet.protocol import Factory + from twisted.names.dns import DNSDatagramProtocol + protocol.Factory.noisy = False + Factory.noisy = False + DNSDatagramProtocol.noisy = False +except ImportError: + pass + +__version__ = "0.33.6" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 8d2aa5870a..34382e4e3c 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -26,6 +26,7 @@ import synapse.types from synapse import event_auth from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.errors import AuthError, Codes, ResourceLimitError +from synapse.config.server import is_threepid_reserved from synapse.types import UserID from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache from synapse.util.caches.lrucache import LruCache @@ -775,13 +776,19 @@ class Auth(object): ) @defer.inlineCallbacks - def check_auth_blocking(self, user_id=None): + def check_auth_blocking(self, user_id=None, threepid=None): """Checks if the user should be rejected for some external reason, such as monthly active user limiting or global disable flag Args: user_id(str|None): If present, checks for presence against existing MAU cohort + + threepid(dict|None): If present, checks for presence against configured + reserved threepid. Used in cases where the user is trying register + with a MAU blocked server, normally they would be rejected but their + threepid is on the reserved list. user_id and + threepid should never be set at the same time. """ # Never fail an auth check for the server notices users @@ -793,10 +800,12 @@ class Auth(object): raise ResourceLimitError( 403, self.hs.config.hs_disabled_message, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, - admin_uri=self.hs.config.admin_uri, + admin_contact=self.hs.config.admin_contact, limit_type=self.hs.config.hs_disabled_limit_type ) if self.hs.config.limit_usage_by_mau is True: + assert not (user_id and threepid) + # If the user is already part of the MAU cohort or a trial user if user_id: timestamp = yield self.store.user_last_seen_monthly_active(user_id) @@ -806,13 +815,17 @@ class Auth(object): is_trial = yield self.store.is_trial_user(user_id) if is_trial: return + elif threepid: + # If the user does not exist yet, but is signing up with a + # reserved threepid then pass auth check + if is_threepid_reserved(self.hs.config, threepid): + return # Else if there is no room in the MAU bucket, bail current_mau = yield self.store.get_monthly_active_count() if current_mau >= self.hs.config.max_mau_value: raise ResourceLimitError( 403, "Monthly Active User Limit Exceeded", - - admin_uri=self.hs.config.admin_uri, + admin_contact=self.hs.config.admin_contact, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, limit_type="monthly_active_user" ) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 140dbfe8b8..4124469442 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -240,10 +240,10 @@ class ResourceLimitError(SynapseError): def __init__( self, code, msg, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, - admin_uri=None, + admin_contact=None, limit_type=None, ): - self.admin_uri = admin_uri + self.admin_contact = admin_contact self.limit_type = limit_type super(ResourceLimitError, self).__init__(code, msg, errcode=errcode) @@ -251,7 +251,7 @@ class ResourceLimitError(SynapseError): return cs_error( self.msg, self.errcode, - admin_uri=self.admin_uri, + admin_contact=self.admin_contact, limit_type=self.limit_type ) diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 186831e118..eed8c67e6a 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -226,7 +226,7 @@ class Filtering(object): jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA, format_checker=FormatChecker()) except jsonschema.ValidationError as e: - raise SynapseError(400, e.message) + raise SynapseError(400, str(e)) class FilterCollection(object): @@ -251,6 +251,7 @@ class FilterCollection(object): "include_leave", False ) self.event_fields = filter_json.get("event_fields", []) + self.event_format = filter_json.get("event_format", "client") def __repr__(self): return "<FilterCollection %s>" % (json.dumps(self._filter_json),) diff --git a/synapse/api/urls.py b/synapse/api/urls.py index 71347912f1..6d9f1ca0ef 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -64,7 +64,7 @@ class ConsentURIBuilder(object): """ mac = hmac.new( key=self._hmac_secret, - msg=user_id, + msg=user_id.encode('ascii'), digestmod=sha256, ).hexdigest() consent_uri = "%s_matrix/consent?%s" % ( diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index 3b6b9368b8..c3afcc573b 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -24,7 +24,7 @@ try: python_dependencies.check_requirements() except python_dependencies.MissingRequirementError as e: message = "\n".join([ - "Missing Requirement: %s" % (e.message,), + "Missing Requirement: %s" % (str(e),), "To install run:", " pip install --upgrade --force \"%s\"" % (e.dependency,), "", diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 3348a8ec6d..8559e141af 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -51,10 +51,7 @@ class AppserviceSlaveStore( class AppserviceServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = AppserviceSlaveStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = AppserviceSlaveStore def _listen_http(self, listener_config): port = listener_config["port"] @@ -139,7 +136,7 @@ def start(config_options): "Synapse appservice", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.appservice" @@ -175,7 +172,6 @@ def start(config_options): def start(): ps.get_datastore().start_profiling() - ps.get_state_handler().start_caching() reactor.callWhenRunning(start) diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index ab79a45646..76aed8c60a 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -74,10 +74,7 @@ class ClientReaderSlavedStore( class ClientReaderServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = ClientReaderSlavedStore def _listen_http(self, listener_config): port = listener_config["port"] @@ -156,7 +153,7 @@ def start(config_options): "Synapse client reader", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.client_reader" @@ -184,7 +181,6 @@ def start(config_options): ss.start_listening(config.worker_listeners) def start(): - ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() reactor.callWhenRunning(start) diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index a34c89fa99..9060ab14f6 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -90,10 +90,7 @@ class EventCreatorSlavedStore( class EventCreatorServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = EventCreatorSlavedStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = EventCreatorSlavedStore def _listen_http(self, listener_config): port = listener_config["port"] @@ -172,7 +169,7 @@ def start(config_options): "Synapse event creator", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.event_creator" @@ -202,7 +199,6 @@ def start(config_options): ss.start_listening(config.worker_listeners) def start(): - ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() reactor.callWhenRunning(start) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 7d8105778d..228a297fb8 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -72,10 +72,7 @@ class FederationReaderSlavedStore( class FederationReaderServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = FederationReaderSlavedStore def _listen_http(self, listener_config): port = listener_config["port"] @@ -143,7 +140,7 @@ def start(config_options): "Synapse federation reader", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.federation_reader" @@ -171,7 +168,6 @@ def start(config_options): ss.start_listening(config.worker_listeners) def start(): - ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() reactor.callWhenRunning(start) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index d59007099b..e9a99d76e1 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -78,10 +78,7 @@ class FederationSenderSlaveStore( class FederationSenderServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = FederationSenderSlaveStore def _listen_http(self, listener_config): port = listener_config["port"] @@ -163,7 +160,7 @@ def start(config_options): "Synapse federation sender", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.federation_sender" @@ -204,7 +201,6 @@ def start(config_options): def start(): ps.get_datastore().start_profiling() - ps.get_state_handler().start_caching() reactor.callWhenRunning(start) _base.start_worker_reactor("synapse-federation-sender", config) diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index 8d484c1cd4..fc4b25de1c 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -148,10 +148,7 @@ class FrontendProxySlavedStore( class FrontendProxyServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = FrontendProxySlavedStore def _listen_http(self, listener_config): port = listener_config["port"] @@ -231,7 +228,7 @@ def start(config_options): "Synapse frontend proxy", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.frontend_proxy" @@ -261,7 +258,6 @@ def start(config_options): ss.start_listening(config.worker_listeners) def start(): - ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() reactor.callWhenRunning(start) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 005921dcf7..e3f0d99a3f 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -62,7 +62,7 @@ from synapse.rest.key.v1.server_key_resource import LocalKey from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.server import HomeServer -from synapse.storage import are_all_users_on_domain +from synapse.storage import DataStore, are_all_users_on_domain from synapse.storage.engines import IncorrectDatabaseSetup, create_engine from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database from synapse.util.caches import CACHE_SIZE_FACTOR @@ -111,6 +111,8 @@ def build_resource_for_web_client(hs): class SynapseHomeServer(HomeServer): + DATASTORE_CLASS = DataStore + def _listener_http(self, config, listener_config): port = listener_config["port"] bind_addresses = listener_config["bind_addresses"] @@ -299,12 +301,16 @@ class SynapseHomeServer(HomeServer): try: database_engine.check_database(db_conn.cursor()) except IncorrectDatabaseSetup as e: - quit_with_error(e.message) + quit_with_error(str(e)) # Gauges to expose monthly active user control metrics current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU") max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit") +registered_reserved_users_mau_gauge = Gauge( + "synapse_admin_mau:registered_reserved_users", + "Registered users with reserved threepids" +) def setup(config_options): @@ -322,7 +328,7 @@ def setup(config_options): config_options, ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) if not config: @@ -356,13 +362,13 @@ def setup(config_options): logger.info("Preparing database: %s...", config.database_config['name']) try: - db_conn = hs.get_db_conn(run_new_connection=False) - prepare_database(db_conn, database_engine, config=config) - database_engine.on_new_connection(db_conn) + with hs.get_db_conn(run_new_connection=False) as db_conn: + prepare_database(db_conn, database_engine, config=config) + database_engine.on_new_connection(db_conn) - hs.run_startup_checks(db_conn, database_engine) + hs.run_startup_checks(db_conn, database_engine) - db_conn.commit() + db_conn.commit() except UpgradeDatabaseException: sys.stderr.write( "\nFailed to upgrade database.\n" @@ -378,10 +384,8 @@ def setup(config_options): def start(): hs.get_pusherpool().start() - hs.get_state_handler().start_caching() hs.get_datastore().start_profiling() hs.get_datastore().start_doing_background_updates() - hs.get_federation_client().start_get_pdu_cache() reactor.callWhenRunning(start) @@ -451,6 +455,10 @@ def run(hs): stats["homeserver"] = hs.config.server_name stats["timestamp"] = now stats["uptime_seconds"] = uptime + version = sys.version_info + stats["python_version"] = "{}.{}.{}".format( + version.major, version.minor, version.micro + ) stats["total_users"] = yield hs.get_datastore().count_all_users() total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users() @@ -529,10 +537,14 @@ def run(hs): @defer.inlineCallbacks def generate_monthly_active_users(): - count = 0 + current_mau_count = 0 + reserved_count = 0 + store = hs.get_datastore() if hs.config.limit_usage_by_mau: - count = yield hs.get_datastore().get_monthly_active_count() - current_mau_gauge.set(float(count)) + current_mau_count = yield store.get_monthly_active_count() + reserved_count = yield store.get_registered_reserved_users_count() + current_mau_gauge.set(float(current_mau_count)) + registered_reserved_users_mau_gauge.set(float(reserved_count)) max_mau_gauge.set(float(hs.config.max_mau_value)) hs.get_datastore().initialise_reserved_users( diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index fd1f6cbf7e..acc0487adc 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -60,10 +60,7 @@ class MediaRepositorySlavedStore( class MediaRepositoryServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = MediaRepositorySlavedStore def _listen_http(self, listener_config): port = listener_config["port"] @@ -136,7 +133,7 @@ def start(config_options): "Synapse media repository", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.media_repository" @@ -171,7 +168,6 @@ def start(config_options): ss.start_listening(config.worker_listeners) def start(): - ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() reactor.callWhenRunning(start) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index a4fc7e91fa..630dcda478 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -78,10 +78,7 @@ class PusherSlaveStore( class PusherServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = PusherSlaveStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = PusherSlaveStore def remove_pusher(self, app_id, push_key, user_id): self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id) @@ -194,7 +191,7 @@ def start(config_options): "Synapse pusher", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.pusher" @@ -231,7 +228,6 @@ def start(config_options): def start(): ps.get_pusherpool().start() ps.get_datastore().start_profiling() - ps.get_state_handler().start_caching() reactor.callWhenRunning(start) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 27e1998660..9a7fc6ee9d 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -249,10 +249,7 @@ class SynchrotronApplicationService(object): class SynchrotronServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = SynchrotronSlavedStore def _listen_http(self, listener_config): port = listener_config["port"] @@ -413,7 +410,7 @@ def start(config_options): "Synapse synchrotron", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.synchrotron" @@ -438,7 +435,6 @@ def start(config_options): def start(): ss.get_datastore().start_profiling() - ss.get_state_handler().start_caching() reactor.callWhenRunning(start) diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py deleted file mode 100755 index d658f967ba..0000000000 --- a/synapse/app/synctl.py +++ /dev/null @@ -1,284 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import collections -import errno -import glob -import os -import os.path -import signal -import subprocess -import sys -import time - -from six import iteritems - -import yaml - -SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"] - -GREEN = "\x1b[1;32m" -YELLOW = "\x1b[1;33m" -RED = "\x1b[1;31m" -NORMAL = "\x1b[m" - - -def pid_running(pid): - try: - os.kill(pid, 0) - return True - except OSError as err: - if err.errno == errno.EPERM: - return True - return False - - -def write(message, colour=NORMAL, stream=sys.stdout): - if colour == NORMAL: - stream.write(message + "\n") - else: - stream.write(colour + message + NORMAL + "\n") - - -def abort(message, colour=RED, stream=sys.stderr): - write(message, colour, stream) - sys.exit(1) - - -def start(configfile): - write("Starting ...") - args = SYNAPSE - args.extend(["--daemonize", "-c", configfile]) - - try: - subprocess.check_call(args) - write("started synapse.app.homeserver(%r)" % - (configfile,), colour=GREEN) - except subprocess.CalledProcessError as e: - write( - "error starting (exit code: %d); see above for logs" % e.returncode, - colour=RED, - ) - - -def start_worker(app, configfile, worker_configfile): - args = [ - "python", "-B", - "-m", app, - "-c", configfile, - "-c", worker_configfile - ] - - try: - subprocess.check_call(args) - write("started %s(%r)" % (app, worker_configfile), colour=GREEN) - except subprocess.CalledProcessError as e: - write( - "error starting %s(%r) (exit code: %d); see above for logs" % ( - app, worker_configfile, e.returncode, - ), - colour=RED, - ) - - -def stop(pidfile, app): - if os.path.exists(pidfile): - pid = int(open(pidfile).read()) - try: - os.kill(pid, signal.SIGTERM) - write("stopped %s" % (app,), colour=GREEN) - except OSError as err: - if err.errno == errno.ESRCH: - write("%s not running" % (app,), colour=YELLOW) - elif err.errno == errno.EPERM: - abort("Cannot stop %s: Operation not permitted" % (app,)) - else: - abort("Cannot stop %s: Unknown error" % (app,)) - - -Worker = collections.namedtuple("Worker", [ - "app", "configfile", "pidfile", "cache_factor" -]) - - -def main(): - - parser = argparse.ArgumentParser() - - parser.add_argument( - "action", - choices=["start", "stop", "restart"], - help="whether to start, stop or restart the synapse", - ) - parser.add_argument( - "configfile", - nargs="?", - default="homeserver.yaml", - help="the homeserver config file, defaults to homeserver.yaml", - ) - parser.add_argument( - "-w", "--worker", - metavar="WORKERCONFIG", - help="start or stop a single worker", - ) - parser.add_argument( - "-a", "--all-processes", - metavar="WORKERCONFIGDIR", - help="start or stop all the workers in the given directory" - " and the main synapse process", - ) - - options = parser.parse_args() - - if options.worker and options.all_processes: - write( - 'Cannot use "--worker" with "--all-processes"', - stream=sys.stderr - ) - sys.exit(1) - - configfile = options.configfile - - if not os.path.exists(configfile): - write( - "No config file found\n" - "To generate a config file, run '%s -c %s --generate-config" - " --server-name=<server name>'\n" % ( - " ".join(SYNAPSE), options.configfile - ), - stream=sys.stderr, - ) - sys.exit(1) - - with open(configfile) as stream: - config = yaml.load(stream) - - pidfile = config["pid_file"] - cache_factor = config.get("synctl_cache_factor") - start_stop_synapse = True - - if cache_factor: - os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor) - - cache_factors = config.get("synctl_cache_factors", {}) - for cache_name, factor in iteritems(cache_factors): - os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor) - - worker_configfiles = [] - if options.worker: - start_stop_synapse = False - worker_configfile = options.worker - if not os.path.exists(worker_configfile): - write( - "No worker config found at %r" % (worker_configfile,), - stream=sys.stderr, - ) - sys.exit(1) - worker_configfiles.append(worker_configfile) - - if options.all_processes: - # To start the main synapse with -a you need to add a worker file - # with worker_app == "synapse.app.homeserver" - start_stop_synapse = False - worker_configdir = options.all_processes - if not os.path.isdir(worker_configdir): - write( - "No worker config directory found at %r" % (worker_configdir,), - stream=sys.stderr, - ) - sys.exit(1) - worker_configfiles.extend(sorted(glob.glob( - os.path.join(worker_configdir, "*.yaml") - ))) - - workers = [] - for worker_configfile in worker_configfiles: - with open(worker_configfile) as stream: - worker_config = yaml.load(stream) - worker_app = worker_config["worker_app"] - if worker_app == "synapse.app.homeserver": - # We need to special case all of this to pick up options that may - # be set in the main config file or in this worker config file. - worker_pidfile = ( - worker_config.get("pid_file") - or pidfile - ) - worker_cache_factor = worker_config.get("synctl_cache_factor") or cache_factor - daemonize = worker_config.get("daemonize") or config.get("daemonize") - assert daemonize, "Main process must have daemonize set to true" - - # The master process doesn't support using worker_* config. - for key in worker_config: - if key == "worker_app": # But we allow worker_app - continue - assert not key.startswith("worker_"), \ - "Main process cannot use worker_* config" - else: - worker_pidfile = worker_config["worker_pid_file"] - worker_daemonize = worker_config["worker_daemonize"] - assert worker_daemonize, "In config %r: expected '%s' to be True" % ( - worker_configfile, "worker_daemonize") - worker_cache_factor = worker_config.get("synctl_cache_factor") - workers.append(Worker( - worker_app, worker_configfile, worker_pidfile, worker_cache_factor, - )) - - action = options.action - - if action == "stop" or action == "restart": - for worker in workers: - stop(worker.pidfile, worker.app) - - if start_stop_synapse: - stop(pidfile, "synapse.app.homeserver") - - # Wait for synapse to actually shutdown before starting it again - if action == "restart": - running_pids = [] - if start_stop_synapse and os.path.exists(pidfile): - running_pids.append(int(open(pidfile).read())) - for worker in workers: - if os.path.exists(worker.pidfile): - running_pids.append(int(open(worker.pidfile).read())) - if len(running_pids) > 0: - write("Waiting for process to exit before restarting...") - for running_pid in running_pids: - while pid_running(running_pid): - time.sleep(0.2) - write("All processes exited; now restarting...") - - if action == "start" or action == "restart": - if start_stop_synapse: - # Check if synapse is already running - if os.path.exists(pidfile) and pid_running(int(open(pidfile).read())): - abort("synapse.app.homeserver already running") - start(configfile) - - for worker in workers: - if worker.cache_factor: - os.environ["SYNAPSE_CACHE_FACTOR"] = str(worker.cache_factor) - - start_worker(worker.app, configfile, worker.configfile) - - if cache_factor: - os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor) - else: - os.environ.pop("SYNAPSE_CACHE_FACTOR", None) - - -if __name__ == "__main__": - main() diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 1388a42b59..0a5f62b509 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -94,10 +94,7 @@ class UserDirectorySlaveStore( class UserDirectoryServer(HomeServer): - def setup(self): - logger.info("Setting up.") - self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self) - logger.info("Finished setting up.") + DATASTORE_CLASS = UserDirectorySlaveStore def _listen_http(self, listener_config): port = listener_config["port"] @@ -191,7 +188,7 @@ def start(config_options): "Synapse user directory", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.user_dir" @@ -232,7 +229,6 @@ def start(config_options): def start(): ps.get_datastore().start_profiling() - ps.get_state_handler().start_caching() reactor.callWhenRunning(start) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 6980e5890e..9ccc5a80fc 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import urllib + +from six.moves import urllib from prometheus_client import Counter @@ -98,7 +99,7 @@ class ApplicationServiceApi(SimpleHttpClient): def query_user(self, service, user_id): if service.url is None: defer.returnValue(False) - uri = service.url + ("/users/%s" % urllib.quote(user_id)) + uri = service.url + ("/users/%s" % urllib.parse.quote(user_id)) response = None try: response = yield self.get_json(uri, { @@ -119,7 +120,7 @@ class ApplicationServiceApi(SimpleHttpClient): def query_alias(self, service, alias): if service.url is None: defer.returnValue(False) - uri = service.url + ("/rooms/%s" % urllib.quote(alias)) + uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias)) response = None try: response = yield self.get_json(uri, { @@ -153,7 +154,7 @@ class ApplicationServiceApi(SimpleHttpClient): service.url, APP_SERVICE_PREFIX, kind, - urllib.quote(protocol) + urllib.parse.quote(protocol) ) try: response = yield self.get_json(uri, fields) @@ -188,7 +189,7 @@ class ApplicationServiceApi(SimpleHttpClient): uri = "%s%s/thirdparty/protocol/%s" % ( service.url, APP_SERVICE_PREFIX, - urllib.quote(protocol) + urllib.parse.quote(protocol) ) try: info = yield self.get_json(uri, {}) @@ -228,7 +229,7 @@ class ApplicationServiceApi(SimpleHttpClient): txn_id = str(txn_id) uri = service.url + ("/transactions/%s" % - urllib.quote(txn_id)) + urllib.parse.quote(txn_id)) try: yield self.put_json( uri=uri, diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py index 58c97a70af..8fccf573ee 100644 --- a/synapse/config/__main__.py +++ b/synapse/config/__main__.py @@ -25,7 +25,7 @@ if __name__ == "__main__": try: config = HomeServerConfig.load_config("", sys.argv[3:]) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) print (getattr(config, key)) diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 2fd9c48abf..b8d5690f2b 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -21,7 +21,7 @@ from .consent_config import ConsentConfig from .database import DatabaseConfig from .emailconfig import EmailConfig from .groups import GroupsConfig -from .jwt import JWTConfig +from .jwt_config import JWTConfig from .key import KeyConfig from .logger import LoggingConfig from .metrics import MetricsConfig diff --git a/synapse/config/jwt.py b/synapse/config/jwt_config.py index 51e7f7e003..51e7f7e003 100644 --- a/synapse/config/jwt.py +++ b/synapse/config/jwt_config.py diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 3f187adfc8..e9a936118d 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -227,7 +227,22 @@ def setup_logging(config, use_worker_options=False): # # However this may not be too much of a problem if we are just writing to a file. observer = STDLibLogObserver() + + def _log(event): + + if "log_text" in event: + if event["log_text"].startswith("DNSDatagramProtocol starting on "): + return + + if event["log_text"].startswith("(UDP Port "): + return + + if event["log_text"].startswith("Timing out client"): + return + + return observer(event) + globalLogBeginner.beginLoggingTo( - [observer], + [_log], redirectStandardIO=not config.no_redirect_stdio, ) diff --git a/synapse/config/server.py b/synapse/config/server.py index 8eecd28e7d..c1c7c0105e 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -93,7 +93,7 @@ class ServerConfig(Config): # Admin uri to direct users at should their instance become blocked # due to resource constraints - self.admin_uri = config.get("admin_uri", None) + self.admin_contact = config.get("admin_contact", None) # FIXME: federation_domain_whitelist needs sytests self.federation_domain_whitelist = None @@ -357,7 +357,7 @@ class ServerConfig(Config): # Homeserver blocking # # How to reach the server admin, used in ResourceLimitError - # admin_uri: 'mailto:admin@server.com' + # admin_contact: 'mailto:admin@server.com' # # Global block config # @@ -404,6 +404,23 @@ class ServerConfig(Config): " service on the given port.") +def is_threepid_reserved(config, threepid): + """Check the threepid against the reserved threepid config + Args: + config(ServerConfig) - to access server config attributes + threepid(dict) - The threepid to test for + + Returns: + boolean Is the threepid undertest reserved_user + """ + + for tp in config.mau_limits_reserved_threepids: + if (threepid['medium'] == tp['medium'] + and threepid['address'] == tp['address']): + return True + return False + + def read_gc_thresholds(thresholds): """Reads the three integer thresholds for garbage collection. Ensures that the thresholds are integers if thresholds are supplied. diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 1a391adec1..02b76dfcfb 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -123,6 +123,6 @@ class ClientTLSOptionsFactory(object): def get_options(self, host): return ClientTLSOptions( - host.decode('utf-8'), + host, CertificateOptions(verify=False).getContext() ) diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py index e94400b8e2..57d4665e84 100644 --- a/synapse/crypto/keyclient.py +++ b/synapse/crypto/keyclient.py @@ -50,7 +50,7 @@ def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1): defer.returnValue((server_response, server_certificate)) except SynapseKeyClientError as e: logger.warn("Error getting key for %r: %s", server_name, e) - if e.status.startswith("4"): + if e.status.startswith(b"4"): # Don't retry for 4xx responses. raise IOError("Cannot get key for %r" % server_name) except (ConnectError, DomainError) as e: @@ -82,6 +82,12 @@ class SynapseKeyClientProtocol(HTTPClient): self._peer = self.transport.getPeer() logger.debug("Connected to %s", self._peer) + if not isinstance(self.path, bytes): + self.path = self.path.encode('ascii') + + if not isinstance(self.host, bytes): + self.host = self.host.encode('ascii') + self.sendCommand(b"GET", self.path) if self.host: self.sendHeader(b"Host", self.host) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 30e2742102..d89f94c219 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -16,9 +16,10 @@ import hashlib import logging -import urllib from collections import namedtuple +from six.moves import urllib + from signedjson.key import ( decode_verify_key_bytes, encode_verify_key_base64, @@ -40,6 +41,7 @@ from synapse.api.errors import Codes, SynapseError from synapse.crypto.keyclient import fetch_server_key from synapse.util import logcontext, unwrapFirstError from synapse.util.logcontext import ( + LoggingContext, PreserveLoggingContext, preserve_fn, run_in_background, @@ -216,23 +218,34 @@ class Keyring(object): servers have completed. Follows the synapse rules of logcontext preservation. """ + loop_count = 1 while True: wait_on = [ - self.key_downloads[server_name] + (server_name, self.key_downloads[server_name]) for server_name in server_names if server_name in self.key_downloads ] - if wait_on: - with PreserveLoggingContext(): - yield defer.DeferredList(wait_on) - else: + if not wait_on: break + logger.info( + "Waiting for existing lookups for %s to complete [loop %i]", + [w[0] for w in wait_on], loop_count, + ) + with PreserveLoggingContext(): + yield defer.DeferredList((w[1] for w in wait_on)) + + loop_count += 1 + + ctx = LoggingContext.current_context() def rm(r, server_name_): - self.key_downloads.pop(server_name_, None) + with PreserveLoggingContext(ctx): + logger.debug("Releasing key lookup lock on %s", server_name_) + self.key_downloads.pop(server_name_, None) return r for server_name, deferred in server_to_deferred.items(): + logger.debug("Got key lookup lock on %s", server_name) self.key_downloads[server_name] = deferred deferred.addBoth(rm, server_name) @@ -432,7 +445,7 @@ class Keyring(object): # an incoming request. query_response = yield self.client.post_json( destination=perspective_name, - path=b"/_matrix/key/v2/query", + path="/_matrix/key/v2/query", data={ u"server_keys": { server_name: { @@ -513,8 +526,8 @@ class Keyring(object): (response, tls_certificate) = yield fetch_server_key( server_name, self.hs.tls_client_options_factory, - path=(b"/_matrix/key/v2/server/%s" % ( - urllib.quote(requested_key_id), + path=("/_matrix/key/v2/server/%s" % ( + urllib.parse.quote(requested_key_id), )).encode("ascii"), ) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 6baeccca38..af3eee95b9 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -98,9 +98,9 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True): creation_event = auth_events.get((EventTypes.Create, ""), None) if not creation_event: - raise SynapseError( + raise AuthError( 403, - "Room %r does not exist" % (event.room_id,) + "No create event in auth events", ) creating_domain = get_domain_from_id(event.room_id) @@ -155,10 +155,7 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True): if user_level < invite_level: raise AuthError( - 403, ( - "You cannot issue a third party invite for %s." % - (event.content.display_name,) - ) + 403, "You don't have permission to invite users", ) else: logger.debug("Allowing! %s", event) @@ -305,7 +302,7 @@ def _is_membership_change_allowed(event, auth_events): if user_level < invite_level: raise AuthError( - 403, "You cannot invite user %s." % target_user_id + 403, "You don't have permission to invite users", ) elif Membership.JOIN == membership: # Joins are valid iff caller == target and they were: diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 51f9084b90..12f1eb0a3e 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -13,13 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +from distutils.util import strtobool + +import six + from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze # Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents -# bugs where we accidentally share e.g. signature dicts. However, converting -# a dict to frozen_dicts is expensive. -USE_FROZEN_DICTS = True +# bugs where we accidentally share e.g. signature dicts. However, converting a +# dict to frozen_dicts is expensive. +# +# NOTE: This is overridden by the configuration by the Synapse worker apps, but +# for the sake of tests, it is set here while it cannot be configured on the +# homeserver object itself. +USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0")) class _EventInternalMetadata(object): @@ -147,6 +156,9 @@ class EventBase(object): def items(self): return list(self._event_dict.items()) + def keys(self): + return six.iterkeys(self._event_dict) + class FrozenEvent(EventBase): def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None): diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index c11798093d..b7ad729c63 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -13,17 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from collections import namedtuple import six from twisted.internet import defer +from twisted.internet.defer import DeferredList -from synapse.api.constants import MAX_DEPTH +from synapse.api.constants import MAX_DEPTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError from synapse.crypto.event_signing import check_event_content_hash from synapse.events import FrozenEvent from synapse.events.utils import prune_event from synapse.http.servlet import assert_params_in_dict +from synapse.types import get_domain_from_id from synapse.util import logcontext, unwrapFirstError logger = logging.getLogger(__name__) @@ -133,34 +136,45 @@ class FederationBase(object): * throws a SynapseError if the signature check failed. The deferreds run their callbacks in the sentinel logcontext. """ - - redacted_pdus = [ - prune_event(pdu) - for pdu in pdus - ] - - deferreds = self.keyring.verify_json_objects_for_server([ - (p.origin, p.get_pdu_json()) - for p in redacted_pdus - ]) + deferreds = _check_sigs_on_pdus(self.keyring, pdus) ctx = logcontext.LoggingContext.current_context() - def callback(_, pdu, redacted): + def callback(_, pdu): with logcontext.PreserveLoggingContext(ctx): if not check_event_content_hash(pdu): - logger.warn( - "Event content has been tampered, redacting %s: %s", - pdu.event_id, pdu.get_pdu_json() - ) - return redacted + # let's try to distinguish between failures because the event was + # redacted (which are somewhat expected) vs actual ball-tampering + # incidents. + # + # This is just a heuristic, so we just assume that if the keys are + # about the same between the redacted and received events, then the + # received event was probably a redacted copy (but we then use our + # *actual* redacted copy to be on the safe side.) + redacted_event = prune_event(pdu) + if ( + set(redacted_event.keys()) == set(pdu.keys()) and + set(six.iterkeys(redacted_event.content)) + == set(six.iterkeys(pdu.content)) + ): + logger.info( + "Event %s seems to have been redacted; using our redacted " + "copy", + pdu.event_id, + ) + else: + logger.warning( + "Event %s content has been tampered, redacting", + pdu.event_id, pdu.get_pdu_json(), + ) + return redacted_event if self.spam_checker.check_event_for_spam(pdu): logger.warn( "Event contains spam, redacting %s: %s", pdu.event_id, pdu.get_pdu_json() ) - return redacted + return prune_event(pdu) return pdu @@ -168,21 +182,121 @@ class FederationBase(object): failure.trap(SynapseError) with logcontext.PreserveLoggingContext(ctx): logger.warn( - "Signature check failed for %s", - pdu.event_id, + "Signature check failed for %s: %s", + pdu.event_id, failure.getErrorMessage(), ) return failure - for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus): + for deferred, pdu in zip(deferreds, pdus): deferred.addCallbacks( callback, errback, - callbackArgs=[pdu, redacted], + callbackArgs=[pdu], errbackArgs=[pdu], ) return deferreds +class PduToCheckSig(namedtuple("PduToCheckSig", [ + "pdu", "redacted_pdu_json", "event_id_domain", "sender_domain", "deferreds", +])): + pass + + +def _check_sigs_on_pdus(keyring, pdus): + """Check that the given events are correctly signed + + Args: + keyring (synapse.crypto.Keyring): keyring object to do the checks + pdus (Collection[EventBase]): the events to be checked + + Returns: + List[Deferred]: a Deferred for each event in pdus, which will either succeed if + the signatures are valid, or fail (with a SynapseError) if not. + """ + + # (currently this is written assuming the v1 room structure; we'll probably want a + # separate function for checking v2 rooms) + + # we want to check that the event is signed by: + # + # (a) the server which created the event_id + # + # (b) the sender's server. + # + # - except in the case of invites created from a 3pid invite, which are exempt + # from this check, because the sender has to match that of the original 3pid + # invite, but the event may come from a different HS, for reasons that I don't + # entirely grok (why do the senders have to match? and if they do, why doesn't the + # joining server ask the inviting server to do the switcheroo with + # exchange_third_party_invite?). + # + # That's pretty awful, since redacting such an invite will render it invalid + # (because it will then look like a regular invite without a valid signature), + # and signatures are *supposed* to be valid whether or not an event has been + # redacted. But this isn't the worst of the ways that 3pid invites are broken. + # + # let's start by getting the domain for each pdu, and flattening the event back + # to JSON. + pdus_to_check = [ + PduToCheckSig( + pdu=p, + redacted_pdu_json=prune_event(p).get_pdu_json(), + event_id_domain=get_domain_from_id(p.event_id), + sender_domain=get_domain_from_id(p.sender), + deferreds=[], + ) + for p in pdus + ] + + # first make sure that the event is signed by the event_id's domain + deferreds = keyring.verify_json_objects_for_server([ + (p.event_id_domain, p.redacted_pdu_json) + for p in pdus_to_check + ]) + + for p, d in zip(pdus_to_check, deferreds): + p.deferreds.append(d) + + # now let's look for events where the sender's domain is different to the + # event id's domain (normally only the case for joins/leaves), and add additional + # checks. + pdus_to_check_sender = [ + p for p in pdus_to_check + if p.sender_domain != p.event_id_domain and not _is_invite_via_3pid(p.pdu) + ] + + more_deferreds = keyring.verify_json_objects_for_server([ + (p.sender_domain, p.redacted_pdu_json) + for p in pdus_to_check_sender + ]) + + for p, d in zip(pdus_to_check_sender, more_deferreds): + p.deferreds.append(d) + + # replace lists of deferreds with single Deferreds + return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check] + + +def _flatten_deferred_list(deferreds): + """Given a list of one or more deferreds, either return the single deferred, or + combine into a DeferredList. + """ + if len(deferreds) > 1: + return DeferredList(deferreds, fireOnOneErrback=True, consumeErrors=True) + else: + assert len(deferreds) == 1 + return deferreds[0] + + +def _is_invite_via_3pid(event): + return ( + event.type == EventTypes.Member + and event.membership == Membership.INVITE + and "third_party_invite" in event.content + ) + + def event_from_pdu_json(pdu_json, outlier=False): """Construct a FrozenEvent from an event json received over federation diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index c9f3c2d352..d05ed91d64 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -66,6 +66,14 @@ class FederationClient(FederationBase): self.state = hs.get_state_handler() self.transport_layer = hs.get_federation_transport_client() + self._get_pdu_cache = ExpiringCache( + cache_name="get_pdu_cache", + clock=self._clock, + max_len=1000, + expiry_ms=120 * 1000, + reset_expiry_on_get=False, + ) + def _clear_tried_cache(self): """Clear pdu_destination_tried cache""" now = self._clock.time_msec() @@ -82,17 +90,6 @@ class FederationClient(FederationBase): if destination_dict: self.pdu_destination_tried[event_id] = destination_dict - def start_get_pdu_cache(self): - self._get_pdu_cache = ExpiringCache( - cache_name="get_pdu_cache", - clock=self._clock, - max_len=1000, - expiry_ms=120 * 1000, - reset_expiry_on_get=False, - ) - - self._get_pdu_cache.start() - @log_function def make_query(self, destination, query_type, args, retry_on_dns_fail=False, ignore_backoff=False): @@ -212,8 +209,6 @@ class FederationClient(FederationBase): Will attempt to get the PDU from each destination in the list until one succeeds. - This will persist the PDU locally upon receipt. - Args: destinations (list): Which home servers to query event_id (str): event to fetch @@ -229,10 +224,9 @@ class FederationClient(FederationBase): # TODO: Rate limit the number of times we try and get the same event. - if self._get_pdu_cache: - ev = self._get_pdu_cache.get(event_id) - if ev: - defer.returnValue(ev) + ev = self._get_pdu_cache.get(event_id) + if ev: + defer.returnValue(ev) pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) @@ -271,10 +265,10 @@ class FederationClient(FederationBase): event_id, destination, e, ) except NotRetryingDestination as e: - logger.info(e.message) + logger.info(str(e)) continue except FederationDeniedError as e: - logger.info(e.message) + logger.info(str(e)) continue except Exception as e: pdu_attempts[destination] = now @@ -285,7 +279,7 @@ class FederationClient(FederationBase): ) continue - if self._get_pdu_cache is not None and signed_pdu: + if signed_pdu: self._get_pdu_cache[event_id] = signed_pdu defer.returnValue(signed_pdu) @@ -293,8 +287,7 @@ class FederationClient(FederationBase): @defer.inlineCallbacks @log_function def get_state_for_room(self, destination, room_id, event_id): - """Requests all of the `current` state PDUs for a given room from - a remote home server. + """Requests all of the room state at a given event from a remote home server. Args: destination (str): The remote homeserver to query for the state. @@ -302,9 +295,10 @@ class FederationClient(FederationBase): event_id (str): The id of the event we want the state at. Returns: - Deferred: Results in a list of PDUs. + Deferred[Tuple[List[EventBase], List[EventBase]]]: + A list of events in the state, and a list of events in the auth chain + for the given event. """ - try: # First we try and ask for just the IDs, as thats far quicker if # we have most of the state and auth_chain already. @@ -510,7 +504,7 @@ class FederationClient(FederationBase): else: logger.warn( "Failed to %s via %s: %i %s", - description, destination, e.code, e.message, + description, destination, e.code, e.args[0], ) except Exception: logger.warn( @@ -875,7 +869,7 @@ class FederationClient(FederationBase): except Exception as e: logger.exception( "Failed to send_third_party_invite via %s: %s", - destination, e.message + destination, str(e) ) raise RuntimeError("Failed to send to any server.") diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 3e0cd294a1..819e8f7331 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -46,6 +46,7 @@ from synapse.replication.http.federation import ( from synapse.types import get_domain_from_id from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache +from synapse.util.logcontext import nested_logging_context from synapse.util.logutils import log_function # when processing incoming transactions, we try to handle multiple rooms in @@ -99,7 +100,7 @@ class FederationServer(FederationBase): @defer.inlineCallbacks @log_function - def on_incoming_transaction(self, transaction_data): + def on_incoming_transaction(self, origin, transaction_data): # keep this as early as possible to make the calculated origin ts as # accurate as possible. request_time = self._clock.time_msec() @@ -108,34 +109,33 @@ class FederationServer(FederationBase): if not transaction.transaction_id: raise Exception("Transaction missing transaction_id") - if not transaction.origin: - raise Exception("Transaction missing origin") logger.debug("[%s] Got transaction", transaction.transaction_id) # use a linearizer to ensure that we don't process the same transaction # multiple times in parallel. with (yield self._transaction_linearizer.queue( - (transaction.origin, transaction.transaction_id), + (origin, transaction.transaction_id), )): result = yield self._handle_incoming_transaction( - transaction, request_time, + origin, transaction, request_time, ) defer.returnValue(result) @defer.inlineCallbacks - def _handle_incoming_transaction(self, transaction, request_time): + def _handle_incoming_transaction(self, origin, transaction, request_time): """ Process an incoming transaction and return the HTTP response Args: + origin (unicode): the server making the request transaction (Transaction): incoming transaction request_time (int): timestamp that the HTTP request arrived at Returns: Deferred[(int, object)]: http response code and body """ - response = yield self.transaction_actions.have_responded(transaction) + response = yield self.transaction_actions.have_responded(origin, transaction) if response: logger.debug( @@ -149,7 +149,7 @@ class FederationServer(FederationBase): received_pdus_counter.inc(len(transaction.pdus)) - origin_host, _ = parse_server_name(transaction.origin) + origin_host, _ = parse_server_name(origin) pdus_by_room = {} @@ -188,21 +188,22 @@ class FederationServer(FederationBase): for pdu in pdus_by_room[room_id]: event_id = pdu.event_id - try: - yield self._handle_received_pdu( - transaction.origin, pdu - ) - pdu_results[event_id] = {} - except FederationError as e: - logger.warn("Error handling PDU %s: %s", event_id, e) - pdu_results[event_id] = {"error": str(e)} - except Exception as e: - f = failure.Failure() - pdu_results[event_id] = {"error": str(e)} - logger.error( - "Failed to handle PDU %s: %s", - event_id, f.getTraceback().rstrip(), - ) + with nested_logging_context(event_id): + try: + yield self._handle_received_pdu( + origin, pdu + ) + pdu_results[event_id] = {} + except FederationError as e: + logger.warn("Error handling PDU %s: %s", event_id, e) + pdu_results[event_id] = {"error": str(e)} + except Exception as e: + f = failure.Failure() + pdu_results[event_id] = {"error": str(e)} + logger.error( + "Failed to handle PDU %s: %s", + event_id, f.getTraceback().rstrip(), + ) yield concurrently_execute( process_pdus_for_room, pdus_by_room.keys(), @@ -212,7 +213,7 @@ class FederationServer(FederationBase): if hasattr(transaction, "edus"): for edu in (Edu(**x) for x in transaction.edus): yield self.received_edu( - transaction.origin, + origin, edu.edu_type, edu.content ) @@ -224,6 +225,7 @@ class FederationServer(FederationBase): logger.debug("Returning: %s", str(response)) yield self.transaction_actions.set_response( + origin, transaction, 200, response ) @@ -618,7 +620,7 @@ class FederationServer(FederationBase): ) yield self.handler.on_receive_pdu( - origin, pdu, get_missing=True, sent_to_us_directly=True, + origin, pdu, sent_to_us_directly=True, ) def __str__(self): @@ -838,9 +840,9 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): ) return self._send_edu( - edu_type=edu_type, - origin=origin, - content=content, + edu_type=edu_type, + origin=origin, + content=content, ) def on_query(self, query_type, args): @@ -851,6 +853,6 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): return handler(args) return self._get_query_client( - query_type=query_type, - args=args, + query_type=query_type, + args=args, ) diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 9146215c21..74ffd13b4f 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -36,7 +36,7 @@ class TransactionActions(object): self.store = datastore @log_function - def have_responded(self, transaction): + def have_responded(self, origin, transaction): """ Have we already responded to a transaction with the same id and origin? @@ -50,11 +50,11 @@ class TransactionActions(object): "transaction_id") return self.store.get_received_txn_response( - transaction.transaction_id, transaction.origin + transaction.transaction_id, origin ) @log_function - def set_response(self, transaction, code, response): + def set_response(self, origin, transaction, code, response): """ Persist how we responded to a transaction. Returns: @@ -66,7 +66,7 @@ class TransactionActions(object): return self.store.set_received_txn_response( transaction.transaction_id, - transaction.origin, + origin, code, response, ) diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 0bb468385d..6f5995735a 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -32,7 +32,7 @@ Events are replicated via a separate events stream. import logging from collections import namedtuple -from six import iteritems, itervalues +from six import iteritems from sortedcontainers import SortedDict @@ -117,7 +117,7 @@ class FederationRemoteSendQueue(object): user_ids = set( user_id - for uids in itervalues(self.presence_changed) + for uids in self.presence_changed.values() for user_id in uids ) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 94d7423d01..98b5950800 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -137,26 +137,6 @@ class TransactionQueue(object): self._processing_pending_presence = False - def can_send_to(self, destination): - """Can we send messages to the given server? - - We can't send messages to ourselves. If we are running on localhost - then we can only federation with other servers running on localhost. - Otherwise we only federate with servers on a public domain. - - Args: - destination(str): The server we are possibly trying to send to. - Returns: - bool: True if we can send to the server. - """ - - if destination == self.server_name: - return False - if self.server_name.startswith("localhost"): - return destination.startswith("localhost") - else: - return not destination.startswith("localhost") - def notify_new_events(self, current_id): """This gets called when we have some new events we might want to send out to other servers. @@ -279,10 +259,7 @@ class TransactionQueue(object): self._order += 1 destinations = set(destinations) - destinations = set( - dest for dest in destinations if self.can_send_to(dest) - ) - + destinations.discard(self.server_name) logger.debug("Sending to: %s", str(destinations)) if not destinations: @@ -358,7 +335,7 @@ class TransactionQueue(object): for destinations, states in hosts_and_states: for destination in destinations: - if not self.can_send_to(destination): + if destination == self.server_name: continue self.pending_presence_by_dest.setdefault( @@ -377,7 +354,8 @@ class TransactionQueue(object): content=content, ) - if not self.can_send_to(destination): + if destination == self.server_name: + logger.info("Not sending EDU to ourselves") return sent_edus_counter.inc() @@ -392,10 +370,8 @@ class TransactionQueue(object): self._attempt_new_transaction(destination) def send_device_messages(self, destination): - if destination == self.server_name or destination == "localhost": - return - - if not self.can_send_to(destination): + if destination == self.server_name: + logger.info("Not sending device update to ourselves") return self._attempt_new_transaction(destination) @@ -463,7 +439,19 @@ class TransactionQueue(object): # pending_transactions flag. pending_pdus = self.pending_pdus_by_dest.pop(destination, []) + + # We can only include at most 50 PDUs per transactions + pending_pdus, leftover_pdus = pending_pdus[:50], pending_pdus[50:] + if leftover_pdus: + self.pending_pdus_by_dest[destination] = leftover_pdus + pending_edus = self.pending_edus_by_dest.pop(destination, []) + + # We can only include at most 100 EDUs per transactions + pending_edus, leftover_edus = pending_edus[:100], pending_edus[100:] + if leftover_edus: + self.pending_edus_by_dest[destination] = leftover_edus + pending_presence = self.pending_presence_by_dest.pop(destination, {}) pending_edus.extend( diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 1054441ca5..2ab973d6c8 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -15,7 +15,8 @@ # limitations under the License. import logging -import urllib + +from six.moves import urllib from twisted.internet import defer @@ -951,4 +952,4 @@ def _create_path(prefix, path, *args): Returns: str """ - return prefix + path % tuple(urllib.quote(arg, "") for arg in args) + return prefix + path % tuple(urllib.parse.quote(arg, "") for arg in args) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 7a993fd1cf..2f874b4838 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -90,8 +90,8 @@ class Authenticator(object): @defer.inlineCallbacks def authenticate_request(self, request, content): json_request = { - "method": request.method, - "uri": request.uri, + "method": request.method.decode('ascii'), + "uri": request.uri.decode('ascii'), "destination": self.server_name, "signatures": {}, } @@ -252,7 +252,7 @@ class BaseFederationServlet(object): by the callback method. None if the request has already been handled. """ content = None - if request.method in ["PUT", "POST"]: + if request.method in [b"PUT", b"POST"]: # TODO: Handle other method types? other content types? content = parse_json_object_from_request(request) @@ -353,7 +353,7 @@ class FederationSendServlet(BaseFederationServlet): try: code, response = yield self.handler.on_incoming_transaction( - transaction_data + origin, transaction_data, ) except Exception: logger.exception("on_incoming_transaction failed") @@ -386,7 +386,7 @@ class FederationStateServlet(BaseFederationServlet): return self.handler.on_context_state_request( origin, context, - query.get("event_id", [None])[0], + parse_string_from_args(query, "event_id", None), ) @@ -397,7 +397,7 @@ class FederationStateIdsServlet(BaseFederationServlet): return self.handler.on_state_ids_request( origin, room_id, - query.get("event_id", [None])[0], + parse_string_from_args(query, "event_id", None), ) @@ -405,14 +405,12 @@ class FederationBackfillServlet(BaseFederationServlet): PATH = "/backfill/(?P<context>[^/]*)/" def on_GET(self, origin, content, query, context): - versions = query["v"] - limits = query["limit"] + versions = [x.decode('ascii') for x in query[b"v"]] + limit = parse_integer_from_args(query, "limit", None) - if not limits: + if not limit: return defer.succeed((400, {"error": "Did not include limit param"})) - limit = int(limits[-1]) - return self.handler.on_backfill_request(origin, context, versions, limit) @@ -423,7 +421,7 @@ class FederationQueryServlet(BaseFederationServlet): def on_GET(self, origin, content, query, query_type): return self.handler.on_query_request( query_type, - {k: v[0].decode("utf-8") for k, v in query.items()} + {k.decode('utf8'): v[0].decode("utf-8") for k, v in query.items()} ) @@ -630,14 +628,14 @@ class OpenIdUserInfo(BaseFederationServlet): @defer.inlineCallbacks def on_GET(self, origin, content, query): - token = query.get("access_token", [None])[0] + token = query.get(b"access_token", [None])[0] if token is None: defer.returnValue((401, { "errcode": "M_MISSING_TOKEN", "error": "Access Token required" })) return - user_id = yield self.handler.on_openid_userinfo(token) + user_id = yield self.handler.on_openid_userinfo(token.decode('ascii')) if user_id is None: defer.returnValue((401, { diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 4a81bd2ba9..2a5eab124f 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -895,22 +895,24 @@ class AuthHandler(BaseHandler): Args: password (unicode): Password to hash. - stored_hash (unicode): Expected hash value. + stored_hash (bytes): Expected hash value. Returns: Deferred(bool): Whether self.hash(password) == stored_hash. """ - def _do_validate_hash(): # Normalise the Unicode in the password pw = unicodedata.normalize("NFKC", password) return bcrypt.checkpw( pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"), - stored_hash.encode('utf8') + stored_hash ) if stored_hash: + if not isinstance(stored_hash, bytes): + stored_hash = stored_hash.encode('ascii') + return make_deferred_yieldable( threads.deferToThreadPool( self.hs.get_reactor(), diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index ef866da1b6..18741c5fac 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -20,7 +20,14 @@ import string from twisted.internet import defer from synapse.api.constants import EventTypes -from synapse.api.errors import AuthError, CodeMessageException, Codes, SynapseError +from synapse.api.errors import ( + AuthError, + CodeMessageException, + Codes, + NotFoundError, + StoreError, + SynapseError, +) from synapse.types import RoomAlias, UserID, get_domain_from_id from ._base import BaseHandler @@ -109,7 +116,13 @@ class DirectoryHandler(BaseHandler): def delete_association(self, requester, user_id, room_alias): # association deletion for human users - can_delete = yield self._user_can_delete_alias(room_alias, user_id) + try: + can_delete = yield self._user_can_delete_alias(room_alias, user_id) + except StoreError as e: + if e.code == 404: + raise NotFoundError("Unknown room alias") + raise + if not can_delete: raise AuthError( 403, "You don't have permission to delete the alias.", @@ -320,7 +333,7 @@ class DirectoryHandler(BaseHandler): def _user_can_delete_alias(self, alias, user_id): creator = yield self.store.get_room_alias_creator(alias.to_string()) - if creator and creator == user_id: + if creator is not None and creator == user_id: defer.returnValue(True) is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id)) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 5816bf8b4f..9dc46aa15f 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -330,7 +330,8 @@ class E2eKeysHandler(object): (algorithm, key_id, ex_json, key) ) else: - new_keys.append((algorithm, key_id, encode_canonical_json(key))) + new_keys.append(( + algorithm, key_id, encode_canonical_json(key).decode('ascii'))) yield self.store.add_e2e_one_time_keys( user_id, device_id, time_now, new_keys @@ -340,7 +341,7 @@ class E2eKeysHandler(object): def _exception_to_failure(e): if isinstance(e, CodeMessageException): return { - "status": e.code, "message": e.message, + "status": e.code, "message": str(e), } if isinstance(e, NotRetryingDestination): @@ -358,7 +359,7 @@ def _exception_to_failure(e): # Note that some Exceptions (notably twisted's ResponseFailed etc) don't # give a string for e.message, which json then fails to serialize. return { - "status": 503, "message": str(e.message), + "status": 503, "message": str(e), } diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0ebf0fd188..45d955e6f5 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -18,7 +18,6 @@ import itertools import logging -import sys import six from six import iteritems, itervalues @@ -69,6 +68,27 @@ from ._base import BaseHandler logger = logging.getLogger(__name__) +def shortstr(iterable, maxitems=5): + """If iterable has maxitems or fewer, return the stringification of a list + containing those items. + + Otherwise, return the stringification of a a list with the first maxitems items, + followed by "...". + + Args: + iterable (Iterable): iterable to truncate + maxitems (int): number of items to return before truncating + + Returns: + unicode + """ + + items = list(itertools.islice(iterable, maxitems + 1)) + if len(items) <= maxitems: + return str(items) + return u"[" + u", ".join(repr(r) for r in items[:maxitems]) + u", ...]" + + class FederationHandler(BaseHandler): """Handles events that originated from federation. Responsible for: @@ -85,7 +105,7 @@ class FederationHandler(BaseHandler): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastore() # type: synapse.storage.DataStore self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() self.server_name = hs.hostname @@ -114,9 +134,8 @@ class FederationHandler(BaseHandler): self._room_pdu_linearizer = Linearizer("fed_room_pdu") @defer.inlineCallbacks - @log_function def on_receive_pdu( - self, origin, pdu, get_missing=True, sent_to_us_directly=False, + self, origin, pdu, sent_to_us_directly=False, ): """ Process a PDU received via a federation /send/ transaction, or via backfill of missing prev_events @@ -125,14 +144,23 @@ class FederationHandler(BaseHandler): origin (str): server which initiated the /send/ transaction. Will be used to fetch missing events or state. pdu (FrozenEvent): received PDU - get_missing (bool): True if we should fetch missing prev_events + sent_to_us_directly (bool): True if this event was pushed to us; False if + we pulled it as the result of a missing prev_event. Returns (Deferred): completes with None """ + room_id = pdu.room_id + event_id = pdu.event_id + + logger.info( + "[%s %s] handling received PDU: %s", + room_id, event_id, pdu, + ) + # We reprocess pdus when we have seen them only as outliers existing = yield self.store.get_event( - pdu.event_id, + event_id, allow_none=True, allow_rejected=True, ) @@ -147,7 +175,7 @@ class FederationHandler(BaseHandler): ) ) if already_seen: - logger.debug("Already seen pdu %s", pdu.event_id) + logger.debug("[%s %s]: Already seen pdu", room_id, event_id) return # do some initial sanity-checking of the event. In particular, make @@ -156,6 +184,7 @@ class FederationHandler(BaseHandler): try: self._sanity_check_event(pdu) except SynapseError as err: + logger.warn("[%s %s] Received event failed sanity checks", room_id, event_id) raise FederationError( "ERROR", err.code, @@ -165,10 +194,12 @@ class FederationHandler(BaseHandler): # If we are currently in the process of joining this room, then we # queue up events for later processing. - if pdu.room_id in self.room_queues: - logger.info("Ignoring PDU %s for room %s from %s for now; join " - "in progress", pdu.event_id, pdu.room_id, origin) - self.room_queues[pdu.room_id].append((pdu, origin)) + if room_id in self.room_queues: + logger.info( + "[%s %s] Queuing PDU from %s for now: join in progress", + room_id, event_id, origin, + ) + self.room_queues[room_id].append((pdu, origin)) return # If we're no longer in the room just ditch the event entirely. This @@ -179,7 +210,7 @@ class FederationHandler(BaseHandler): # we should check if we *are* in fact in the room. If we are then we # can magically rejoin the room. is_in_room = yield self.auth.check_host_in_room( - pdu.room_id, + room_id, self.server_name ) if not is_in_room: @@ -188,8 +219,8 @@ class FederationHandler(BaseHandler): ) if was_in_room: logger.info( - "Ignoring PDU %s for room %s from %s as we've left the room!", - pdu.event_id, pdu.room_id, origin, + "[%s %s] Ignoring PDU from %s as we've left the room", + room_id, event_id, origin, ) defer.returnValue(None) @@ -204,8 +235,8 @@ class FederationHandler(BaseHandler): ) logger.debug( - "_handle_new_pdu min_depth for %s: %d", - pdu.room_id, min_depth + "[%s %s] min_depth: %d", + room_id, event_id, min_depth, ) prevs = {e_id for e_id, _ in pdu.prev_events} @@ -218,17 +249,18 @@ class FederationHandler(BaseHandler): # send to the clients. pdu.internal_metadata.outlier = True elif min_depth and pdu.depth > min_depth: - if get_missing and prevs - seen: + missing_prevs = prevs - seen + if sent_to_us_directly and missing_prevs: # If we're missing stuff, ensure we only fetch stuff one # at a time. logger.info( - "Acquiring lock for room %r to fetch %d missing events: %r...", - pdu.room_id, len(prevs - seen), list(prevs - seen)[:5], + "[%s %s] Acquiring room lock to fetch %d missing prev_events: %s", + room_id, event_id, len(missing_prevs), shortstr(missing_prevs), ) with (yield self._room_pdu_linearizer.queue(pdu.room_id)): logger.info( - "Acquired lock for room %r to fetch %d missing events", - pdu.room_id, len(prevs - seen), + "[%s %s] Acquired room lock to fetch %d missing prev_events", + room_id, event_id, len(missing_prevs), ) yield self._get_missing_events_for_pdu( @@ -241,69 +273,150 @@ class FederationHandler(BaseHandler): if not prevs - seen: logger.info( - "Found all missing prev events for %s", pdu.event_id + "[%s %s] Found all missing prev_events", + room_id, event_id, ) - elif prevs - seen: + elif missing_prevs: logger.info( - "Not fetching %d missing events for room %r,event %s: %r...", - len(prevs - seen), pdu.room_id, pdu.event_id, - list(prevs - seen)[:5], + "[%s %s] Not recursively fetching %d missing prev_events: %s", + room_id, event_id, len(missing_prevs), shortstr(missing_prevs), ) - if sent_to_us_directly and prevs - seen: - # If they have sent it to us directly, and the server - # isn't telling us about the auth events that it's - # made a message referencing, we explode - raise FederationError( - "ERROR", - 403, - ( - "Your server isn't divulging details about prev_events " - "referenced in this event." - ), - affected=pdu.event_id, - ) - elif prevs - seen: - # Calculate the state of the previous events, and - # de-conflict them to find the current state. - state_groups = [] + if prevs - seen: + # We've still not been able to get all of the prev_events for this event. + # + # In this case, we need to fall back to asking another server in the + # federation for the state at this event. That's ok provided we then + # resolve the state against other bits of the DAG before using it (which + # will ensure that you can't just take over a room by sending an event, + # withholding its prev_events, and declaring yourself to be an admin in + # the subsequent state request). + # + # Now, if we're pulling this event as a missing prev_event, then clearly + # this event is not going to become the only forward-extremity and we are + # guaranteed to resolve its state against our existing forward + # extremities, so that should be fine. + # + # On the other hand, if this event was pushed to us, it is possible for + # it to become the only forward-extremity in the room, and we would then + # trust its state to be the state for the whole room. This is very bad. + # Further, if the event was pushed to us, there is no excuse for us not to + # have all the prev_events. We therefore reject any such events. + # + # XXX this really feels like it could/should be merged with the above, + # but there is an interaction with min_depth that I'm not really + # following. + + if sent_to_us_directly: + logger.warn( + "[%s %s] Failed to fetch %d prev events: rejecting", + room_id, event_id, len(prevs - seen), + ) + raise FederationError( + "ERROR", + 403, + ( + "Your server isn't divulging details about prev_events " + "referenced in this event." + ), + affected=pdu.event_id, + ) + + # Calculate the state after each of the previous events, and + # resolve them to find the correct state at the current event. auth_chains = set() + event_map = { + event_id: pdu, + } try: # Get the state of the events we know about - ours = yield self.store.get_state_groups(pdu.room_id, list(seen)) - state_groups.append(ours) + ours = yield self.store.get_state_groups_ids(room_id, seen) + + # state_maps is a list of mappings from (type, state_key) to event_id + # type: list[dict[tuple[str, str], str]] + state_maps = list(ours.values()) + + # we don't need this any more, let's delete it. + del ours # Ask the remote server for the states we don't # know about for p in prevs - seen: - state, got_auth_chain = ( - yield self.federation_client.get_state_for_room( - origin, pdu.room_id, p - ) + logger.info( + "[%s %s] Requesting state at missing prev_event %s", + room_id, event_id, p, ) - auth_chains.update(got_auth_chain) - state_group = {(x.type, x.state_key): x.event_id for x in state} - state_groups.append(state_group) + + with logcontext.nested_logging_context(p): + # note that if any of the missing prevs share missing state or + # auth events, the requests to fetch those events are deduped + # by the get_pdu_cache in federation_client. + remote_state, got_auth_chain = ( + yield self.federation_client.get_state_for_room( + origin, room_id, p, + ) + ) + + # we want the state *after* p; get_state_for_room returns the + # state *before* p. + remote_event = yield self.federation_client.get_pdu( + [origin], p, outlier=True, + ) + + if remote_event is None: + raise Exception( + "Unable to get missing prev_event %s" % (p, ) + ) + + if remote_event.is_state(): + remote_state.append(remote_event) + + # XXX hrm I'm not convinced that duplicate events will compare + # for equality, so I'm not sure this does what the author + # hoped. + auth_chains.update(got_auth_chain) + + remote_state_map = { + (x.type, x.state_key): x.event_id for x in remote_state + } + state_maps.append(remote_state_map) + + for x in remote_state: + event_map[x.event_id] = x # Resolve any conflicting state + @defer.inlineCallbacks def fetch(ev_ids): - return self.store.get_events( - ev_ids, get_prev_content=False, check_redacted=False + fetched = yield self.store.get_events( + ev_ids, get_prev_content=False, check_redacted=False, ) + # add any events we fetch here to the `event_map` so that we + # can use them to build the state event list below. + event_map.update(fetched) + defer.returnValue(fetched) - room_version = yield self.store.get_room_version(pdu.room_id) + room_version = yield self.store.get_room_version(room_id) state_map = yield resolve_events_with_factory( - room_version, state_groups, {pdu.event_id: pdu}, fetch + room_version, state_maps, event_map, fetch, ) - state = (yield self.store.get_events(state_map.values())).values() + # we need to give _process_received_pdu the actual state events + # rather than event ids, so generate that now. + state = [ + event_map[e] for e in six.itervalues(state_map) + ] auth_chain = list(auth_chains) except Exception: + logger.warn( + "[%s %s] Error attempting to resolve state at missing " + "prev_events", + room_id, event_id, exc_info=True, + ) raise FederationError( "ERROR", 403, "We can't get valid state history.", - affected=pdu.event_id, + affected=event_id, ) yield self._process_received_pdu( @@ -322,15 +435,16 @@ class FederationHandler(BaseHandler): prevs (set(str)): List of event ids which we are missing min_depth (int): Minimum depth of events to return. """ - # We recalculate seen, since it may have changed. + + room_id = pdu.room_id + event_id = pdu.event_id + seen = yield self.store.have_seen_events(prevs) if not prevs - seen: return - latest = yield self.store.get_latest_event_ids_in_room( - pdu.room_id - ) + latest = yield self.store.get_latest_event_ids_in_room(room_id) # We add the prev events that we have seen to the latest # list to ensure the remote server doesn't give them to us @@ -338,8 +452,8 @@ class FederationHandler(BaseHandler): latest |= seen logger.info( - "Missing %d events for room %r pdu %s: %r...", - len(prevs - seen), pdu.room_id, pdu.event_id, list(prevs - seen)[:5] + "[%s %s]: Requesting %d prev_events: %s", + room_id, event_id, len(prevs - seen), shortstr(prevs - seen) ) # XXX: we set timeout to 10s to help workaround @@ -360,49 +474,88 @@ class FederationHandler(BaseHandler): # apparently. # # see https://github.com/matrix-org/synapse/pull/1744 + # + # ---- + # + # Update richvdh 2018/09/18: There are a number of problems with timing this + # request out agressively on the client side: + # + # - it plays badly with the server-side rate-limiter, which starts tarpitting you + # if you send too many requests at once, so you end up with the server carefully + # working through the backlog of your requests, which you have already timed + # out. + # + # - for this request in particular, we now (as of + # https://github.com/matrix-org/synapse/pull/3456) reject any PDUs where the + # server can't produce a plausible-looking set of prev_events - so we becone + # much more likely to reject the event. + # + # - contrary to what it says above, we do *not* fall back to fetching fresh state + # for the room if get_missing_events times out. Rather, we give up processing + # the PDU whose prevs we are missing, which then makes it much more likely that + # we'll end up back here for the *next* PDU in the list, which exacerbates the + # problem. + # + # - the agressive 10s timeout was introduced to deal with incoming federation + # requests taking 8 hours to process. It's not entirely clear why that was going + # on; certainly there were other issues causing traffic storms which are now + # resolved, and I think in any case we may be more sensible about our locking + # now. We're *certainly* more sensible about our logging. + # + # All that said: Let's try increasing the timout to 60s and see what happens. missing_events = yield self.federation_client.get_missing_events( origin, - pdu.room_id, + room_id, earliest_events_ids=list(latest), latest_events=[pdu], limit=10, min_depth=min_depth, - timeout=10000, + timeout=60000, ) logger.info( - "Got %d events: %r...", - len(missing_events), [e.event_id for e in missing_events[:5]] + "[%s %s]: Got %d prev_events: %s", + room_id, event_id, len(missing_events), shortstr(missing_events), ) # We want to sort these by depth so we process them and # tell clients about them in order. missing_events.sort(key=lambda x: x.depth) - for e in missing_events: - logger.info("Handling found event %s", e.event_id) - try: - yield self.on_receive_pdu( - origin, - e, - get_missing=False - ) - except FederationError as e: - if e.code == 403: - logger.warn("Event %s failed history check.") - else: - raise + for ev in missing_events: + logger.info( + "[%s %s] Handling received prev_event %s", + room_id, event_id, ev.event_id, + ) + with logcontext.nested_logging_context(ev.event_id): + try: + yield self.on_receive_pdu( + origin, + ev, + sent_to_us_directly=False, + ) + except FederationError as e: + if e.code == 403: + logger.warn( + "[%s %s] Received prev_event %s failed history check.", + room_id, event_id, ev.event_id, + ) + else: + raise - @log_function @defer.inlineCallbacks - def _process_received_pdu(self, origin, pdu, state, auth_chain): + def _process_received_pdu(self, origin, event, state, auth_chain): """ Called when we have a new pdu. We need to do auth checks and put it through the StateHandler. """ - event = pdu + room_id = event.room_id + event_id = event.event_id - logger.debug("Processing event: %s", event) + logger.debug( + "[%s %s] Processing event: %s", + room_id, event_id, event, + ) # FIXME (erikj): Awful hack to make the case where we are not currently # in the room work @@ -411,15 +564,16 @@ class FederationHandler(BaseHandler): # event. if state and auth_chain and not event.internal_metadata.is_outlier(): is_in_room = yield self.auth.check_host_in_room( - event.room_id, + room_id, self.server_name ) else: is_in_room = True + if not is_in_room: logger.info( - "Got event for room we're not in: %r %r", - event.room_id, event.event_id + "[%s %s] Got event for room we're not in", + room_id, event_id, ) try: @@ -431,7 +585,7 @@ class FederationHandler(BaseHandler): "ERROR", e.code, e.msg, - affected=event.event_id, + affected=event_id, ) else: @@ -464,6 +618,10 @@ class FederationHandler(BaseHandler): }) seen_ids.add(e.event_id) + logger.info( + "[%s %s] persisting newly-received auth/state events %s", + room_id, event_id, [e["event"].event_id for e in event_infos] + ) yield self._handle_new_events(origin, event_infos) try: @@ -480,12 +638,12 @@ class FederationHandler(BaseHandler): affected=event.event_id, ) - room = yield self.store.get_room(event.room_id) + room = yield self.store.get_room(room_id) if not room: try: yield self.store.store_room( - room_id=event.room_id, + room_id=room_id, room_creator_user_id="", is_public=False, ) @@ -513,7 +671,7 @@ class FederationHandler(BaseHandler): if newly_joined: user = UserID.from_string(event.state_key) - yield self.user_joined_room(user, event.room_id) + yield self.user_joined_room(user, room_id) @log_function @defer.inlineCallbacks @@ -594,7 +752,7 @@ class FederationHandler(BaseHandler): required_auth = set( a_id - for event in events + state_events.values() + auth_events.values() + for event in events + list(state_events.values()) + list(auth_events.values()) for a_id, _ in event.auth_events ) auth_events.update({ @@ -802,7 +960,7 @@ class FederationHandler(BaseHandler): ) continue except NotRetryingDestination as e: - logger.info(e.message) + logger.info(str(e)) continue except FederationDeniedError as e: logger.info(e) @@ -1027,7 +1185,8 @@ class FederationHandler(BaseHandler): try: logger.info("Processing queued PDU %s which was received " "while we were joining %s", p.event_id, p.room_id) - yield self.on_receive_pdu(origin, p) + with logcontext.nested_logging_context(p.event_id): + yield self.on_receive_pdu(origin, p, sent_to_us_directly=True) except Exception as e: logger.warn( "Error handling queued PDU %s from %s: %s", @@ -1358,7 +1517,7 @@ class FederationHandler(BaseHandler): ) if state_groups: - _, state = state_groups.items().pop() + _, state = list(state_groups.items()).pop() results = state if event.is_state(): @@ -1430,12 +1589,10 @@ class FederationHandler(BaseHandler): else: defer.returnValue(None) - @log_function def get_min_depth_for_context(self, context): return self.store.get_min_depth(context) @defer.inlineCallbacks - @log_function def _handle_new_event(self, origin, event, state=None, auth_events=None, backfilled=False): context = yield self._prep_event( @@ -1444,6 +1601,9 @@ class FederationHandler(BaseHandler): auth_events=auth_events, ) + # reraise does not allow inlineCallbacks to preserve the stacktrace, so we + # hack around with a try/finally instead. + success = False try: if not event.internal_metadata.is_outlier() and not backfilled: yield self.action_generator.handle_push_actions_for_event( @@ -1454,15 +1614,13 @@ class FederationHandler(BaseHandler): [(event, context)], backfilled=backfilled, ) - except: # noqa: E722, as we reraise the exception this is fine. - tp, value, tb = sys.exc_info() - - logcontext.run_in_background( - self.store.remove_push_actions_from_staging, - event.event_id, - ) - - six.reraise(tp, value, tb) + success = True + finally: + if not success: + logcontext.run_in_background( + self.store.remove_push_actions_from_staging, + event.event_id, + ) defer.returnValue(context) @@ -1475,15 +1633,22 @@ class FederationHandler(BaseHandler): Notifies about the events where appropriate. """ - contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - logcontext.run_in_background( - self._prep_event, + + @defer.inlineCallbacks + def prep(ev_info): + event = ev_info["event"] + with logcontext.nested_logging_context(suffix=event.event_id): + res = yield self._prep_event( origin, - ev_info["event"], + event, state=ev_info.get("state"), auth_events=ev_info.get("auth_events"), ) + defer.returnValue(res) + + contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults( + [ + logcontext.run_in_background(prep, ev_info) for ev_info in event_infos ], consumeErrors=True, )) @@ -1635,8 +1800,8 @@ class FederationHandler(BaseHandler): ) except AuthError as e: logger.warn( - "Rejecting %s because %s", - event.event_id, e.msg + "[%s %s] Rejecting: %s", + event.room_id, event.event_id, e.msg ) context.rejected = RejectedReason.AUTH_ERROR @@ -1831,7 +1996,7 @@ class FederationHandler(BaseHandler): room_version = yield self.store.get_room_version(event.room_id) - new_state = self.state_handler.resolve_events( + new_state = yield self.state_handler.resolve_events( room_version, [list(local_view.values()), list(remote_view.values())], event diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e484061cc0..4954b23a0d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -14,9 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import sys -import six from six import iteritems, itervalues, string_types from canonicaljson import encode_canonical_json, json @@ -624,6 +622,9 @@ class EventCreationHandler(object): event, context ) + # reraise does not allow inlineCallbacks to preserve the stacktrace, so we + # hack around with a try/finally instead. + success = False try: # If we're a worker we need to hit out to the master. if self.config.worker_app: @@ -636,6 +637,7 @@ class EventCreationHandler(object): ratelimit=ratelimit, extra_users=extra_users, ) + success = True return yield self.persist_and_notify_client_event( @@ -645,17 +647,16 @@ class EventCreationHandler(object): ratelimit=ratelimit, extra_users=extra_users, ) - except: # noqa: E722, as we reraise the exception this is fine. - # Ensure that we actually remove the entries in the push actions - # staging area, if we calculated them. - tp, value, tb = sys.exc_info() - - run_in_background( - self.store.remove_push_actions_from_staging, - event.event_id, - ) - six.reraise(tp, value, tb) + success = True + finally: + if not success: + # Ensure that we actually remove the entries in the push actions + # staging area, if we calculated them. + run_in_background( + self.store.remove_push_actions_from_staging, + event.event_id, + ) @defer.inlineCallbacks def persist_and_notify_client_event( diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 5170d093e3..a155b6e938 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -269,14 +269,7 @@ class PaginationHandler(object): if state_ids: state = yield self.store.get_events(list(state_ids.values())) - - if state: - state = yield filter_events_for_client( - self.store, - user_id, - state.values(), - is_peeking=(member_event_id is None), - ) + state = state.values() time_now = self.clock.time_msec() diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 75b8b7ce6a..1dfbde84fd 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -142,10 +142,8 @@ class BaseProfileHandler(BaseHandler): if e.code != 404: logger.exception("Failed to get displayname") raise - except Exception: - logger.exception("Failed to get displayname") - else: - defer.returnValue(result["displayname"]) + + defer.returnValue(result["displayname"]) @defer.inlineCallbacks def set_displayname(self, target_user, requester, new_displayname, by_admin=False): @@ -199,8 +197,6 @@ class BaseProfileHandler(BaseHandler): if e.code != 404: logger.exception("Failed to get avatar_url") raise - except Exception: - logger.exception("Failed to get avatar_url") defer.returnValue(result["avatar_url"]) @@ -278,7 +274,7 @@ class BaseProfileHandler(BaseHandler): except Exception as e: logger.warn( "Failed to update join event for room %s - %s", - room_id, str(e.message) + room_id, str(e) ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index f03ee1476b..da914c46ff 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -125,6 +125,7 @@ class RegistrationHandler(BaseHandler): guest_access_token=None, make_guest=False, admin=False, + threepid=None, ): """Registers a new client on the server. @@ -145,7 +146,7 @@ class RegistrationHandler(BaseHandler): RegistrationError if there was a problem registering. """ - yield self.auth.check_auth_blocking() + yield self.auth.check_auth_blocking(threepid=threepid) password_hash = None if password: password_hash = yield self.auth_handler().hash(password) @@ -533,4 +534,5 @@ class RegistrationHandler(BaseHandler): room_id=room_id, remote_room_hosts=remote_room_hosts, action="join", + ratelimit=False, ) diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 37e41afd61..38e1737ec9 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -162,7 +162,7 @@ class RoomListHandler(BaseHandler): # Filter out rooms that we don't want to return rooms_to_scan = [ r for r in sorted_rooms - if r not in newly_unpublished and rooms_to_num_joined[room_id] > 0 + if r not in newly_unpublished and rooms_to_num_joined[r] > 0 ] total_room_count = len(rooms_to_scan) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index f643619047..07fd3e82fc 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -583,6 +583,11 @@ class RoomMemberHandler(object): room_id = mapping["room_id"] servers = mapping["servers"] + # put the server which owns the alias at the front of the server list. + if room_alias.domain in servers: + servers.remove(room_alias.domain) + servers.insert(0, room_alias.domain) + defer.returnValue((RoomID.from_string(room_id), servers)) @defer.inlineCallbacks diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index c464adbd0b..0c1d52fd11 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -54,7 +54,7 @@ class SearchHandler(BaseHandler): batch_token = None if batch: try: - b = decode_base64(batch) + b = decode_base64(batch).decode('ascii') batch_group, batch_group_key, batch_token = b.split("\n") assert batch_group is not None @@ -258,18 +258,18 @@ class SearchHandler(BaseHandler): # it returns more from the same group (if applicable) rather # than reverting to searching all results again. if batch_group and batch_group_key: - global_next_batch = encode_base64("%s\n%s\n%s" % ( + global_next_batch = encode_base64(("%s\n%s\n%s" % ( batch_group, batch_group_key, pagination_token - )) + )).encode('ascii')) else: - global_next_batch = encode_base64("%s\n%s\n%s" % ( + global_next_batch = encode_base64(("%s\n%s\n%s" % ( "all", "", pagination_token - )) + )).encode('ascii')) for room_id, group in room_groups.items(): - group["next_batch"] = encode_base64("%s\n%s\n%s" % ( + group["next_batch"] = encode_base64(("%s\n%s\n%s" % ( "room_id", room_id, pagination_token - )) + )).encode('ascii')) allowed_events.extend(room_events) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 648debc8aa..67b8ca28c7 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -24,6 +24,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.push.clientformat import format_push_rules_for_user +from synapse.storage.roommember import MemberSummary from synapse.types import RoomStreamToken from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache @@ -525,6 +526,8 @@ class SyncHandler(object): A deferred dict describing the room summary """ + # FIXME: we could/should get this from room_stats when matthew/stats lands + # FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305 last_events, _ = yield self.store.get_recent_event_ids_for_room( room_id, end_token=now_token.room_key, limit=1, @@ -537,44 +540,67 @@ class SyncHandler(object): last_event = last_events[-1] state_ids = yield self.store.get_state_ids_for_event( last_event.event_id, [ - (EventTypes.Member, None), (EventTypes.Name, ''), (EventTypes.CanonicalAlias, ''), ] ) - member_ids = { - state_key: event_id - for (t, state_key), event_id in state_ids.iteritems() - if t == EventTypes.Member - } + # this is heavily cached, thus: fast. + details = yield self.store.get_room_summary(room_id) + name_id = state_ids.get((EventTypes.Name, '')) canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, '')) summary = {} - - # FIXME: it feels very heavy to load up every single membership event - # just to calculate the counts. - member_events = yield self.store.get_events(member_ids.values()) - - joined_user_ids = [] - invited_user_ids = [] - - for ev in member_events.values(): - if ev.content.get("membership") == Membership.JOIN: - joined_user_ids.append(ev.state_key) - elif ev.content.get("membership") == Membership.INVITE: - invited_user_ids.append(ev.state_key) + empty_ms = MemberSummary([], 0) # TODO: only send these when they change. - summary["m.joined_member_count"] = len(joined_user_ids) - summary["m.invited_member_count"] = len(invited_user_ids) + summary["m.joined_member_count"] = ( + details.get(Membership.JOIN, empty_ms).count + ) + summary["m.invited_member_count"] = ( + details.get(Membership.INVITE, empty_ms).count + ) - if name_id or canonical_alias_id: - defer.returnValue(summary) + # if the room has a name or canonical_alias set, we can skip + # calculating heroes. we assume that if the event has contents, it'll + # be a valid name or canonical_alias - i.e. we're checking that they + # haven't been "deleted" by blatting {} over the top. + if name_id: + name = yield self.store.get_event(name_id, allow_none=True) + if name and name.content: + defer.returnValue(summary) + + if canonical_alias_id: + canonical_alias = yield self.store.get_event( + canonical_alias_id, allow_none=True, + ) + if canonical_alias and canonical_alias.content: + defer.returnValue(summary) + + joined_user_ids = [ + r[0] for r in details.get(Membership.JOIN, empty_ms).members + ] + invited_user_ids = [ + r[0] for r in details.get(Membership.INVITE, empty_ms).members + ] + gone_user_ids = ( + [r[0] for r in details.get(Membership.LEAVE, empty_ms).members] + + [r[0] for r in details.get(Membership.BAN, empty_ms).members] + ) - # FIXME: order by stream ordering, not alphabetic + # FIXME: only build up a member_ids list for our heroes + member_ids = {} + for membership in ( + Membership.JOIN, + Membership.INVITE, + Membership.LEAVE, + Membership.BAN + ): + for user_id, event_id in details.get(membership, empty_ms).members: + member_ids[user_id] = event_id + # FIXME: order by stream ordering rather than as returned by SQL me = sync_config.user.to_string() if (joined_user_ids or invited_user_ids): summary['m.heroes'] = sorted( @@ -586,7 +612,11 @@ class SyncHandler(object): )[0:5] else: summary['m.heroes'] = sorted( - [user_id for user_id in member_ids.keys() if user_id != me] + [ + user_id + for user_id in gone_user_ids + if user_id != me + ] )[0:5] if not sync_config.filter_collection.lazy_load_members(): @@ -692,6 +722,13 @@ class SyncHandler(object): } if full_state: + if lazy_load_members: + # always make sure we LL ourselves so we know we're in the room + # (if we are) to fix https://github.com/vector-im/riot-web/issues/7209 + # We only need apply this on full state syncs given we disabled + # LL for incr syncs in #3840. + types.append((EventTypes.Member, sync_config.user.to_string())) + if batch: current_state_ids = yield self.store.get_state_ids_for_event( batch.events[-1].event_id, types=types, @@ -719,6 +756,26 @@ class SyncHandler(object): lazy_load_members=lazy_load_members, ) elif batch.limited: + state_at_timeline_start = yield self.store.get_state_ids_for_event( + batch.events[0].event_id, types=types, + filtered_types=filtered_types, + ) + + # for now, we disable LL for gappy syncs - see + # https://github.com/vector-im/riot-web/issues/7211#issuecomment-419976346 + # N.B. this slows down incr syncs as we are now processing way + # more state in the server than if we were LLing. + # + # We still have to filter timeline_start to LL entries (above) in order + # for _calculate_state's LL logic to work, as we have to include LL + # members for timeline senders in case they weren't loaded in the initial + # sync. We do this by (counterintuitively) by filtering timeline_start + # members to just be ones which were timeline senders, which then ensures + # all of the rest get included in the state block (if we need to know + # about them). + types = None + filtered_types = None + state_at_previous_sync = yield self.get_state_at( room_id, stream_position=since_token, types=types, filtered_types=filtered_types, @@ -729,25 +786,29 @@ class SyncHandler(object): filtered_types=filtered_types, ) - state_at_timeline_start = yield self.store.get_state_ids_for_event( - batch.events[0].event_id, types=types, - filtered_types=filtered_types, - ) - state_ids = _calculate_state( timeline_contains=timeline_state, timeline_start=state_at_timeline_start, previous=state_at_previous_sync, current=current_state_ids, + # we have to include LL members in case LL initial sync missed them lazy_load_members=lazy_load_members, ) else: state_ids = {} if lazy_load_members: - if types: + if types and batch.events: + # We're returning an incremental sync, with no + # "gap" since the previous sync, so normally there would be + # no state to return. + # But we're lazy-loading, so the client might need some more + # member events to understand the events in this timeline. + # So we fish out all the member events corresponding to the + # timeline here, and then dedupe any redundant ones below. + state_ids = yield self.store.get_state_ids_for_event( batch.events[0].event_id, types=types, - filtered_types=filtered_types, + filtered_types=None, # we only want members! ) if lazy_load_members and not include_redundant_members: @@ -767,7 +828,7 @@ class SyncHandler(object): logger.debug("filtering state from %r...", state_ids) state_ids = { t: event_id - for t, event_id in state_ids.iteritems() + for t, event_id in iteritems(state_ids) if cache.get(t[1]) != event_id } logger.debug("...to %r", state_ids) @@ -1568,6 +1629,19 @@ class SyncHandler(object): newly_joined_room=newly_joined, ) + # When we join the room (or the client requests full_state), we should + # send down any existing tags. Usually the user won't have tags in a + # newly joined room, unless either a) they've joined before or b) the + # tag was added by synapse e.g. for server notice rooms. + if full_state: + user_id = sync_result_builder.sync_config.user.to_string() + tags = yield self.store.get_tags_for_room(user_id, room_id) + + # If there aren't any tags, don't send the empty tags list down + # sync + if not tags: + tags = None + account_data_events = [] if tags is not None: account_data_events.append({ @@ -1596,10 +1670,24 @@ class SyncHandler(object): ) summary = {} + + # we include a summary in room responses when we're lazy loading + # members (as the client otherwise doesn't have enough info to form + # the name itself). if ( sync_config.filter_collection.lazy_load_members() and ( + # we recalulate the summary: + # if there are membership changes in the timeline, or + # if membership has changed during a gappy sync, or + # if this is an initial sync. any(ev.type == EventTypes.Member for ev in batch.events) or + ( + # XXX: this may include false positives in the form of LL + # members which have snuck into state + batch.limited and + any(t == EventTypes.Member for (t, k) in state) + ) or since_token is None ) ): @@ -1629,6 +1717,16 @@ class SyncHandler(object): unread_notifications["highlight_count"] = notifs["highlight_count"] sync_result_builder.joined.append(room_sync) + + if batch.limited and since_token: + user_id = sync_result_builder.sync_config.user.to_string() + logger.info( + "Incremental gappy sync of %s for user %s with %d state events" % ( + room_id, + user_id, + len(state), + ) + ) elif room_builder.rtype == "archived": room_sync = ArchivedSyncResult( room_id=room_id, @@ -1722,17 +1820,17 @@ def _calculate_state( event_id_to_key = { e: key for key, e in itertools.chain( - timeline_contains.items(), - previous.items(), - timeline_start.items(), - current.items(), + iteritems(timeline_contains), + iteritems(previous), + iteritems(timeline_start), + iteritems(current), ) } - c_ids = set(e for e in current.values()) - ts_ids = set(e for e in timeline_start.values()) - p_ids = set(e for e in previous.values()) - tc_ids = set(e for e in timeline_contains.values()) + c_ids = set(e for e in itervalues(current)) + ts_ids = set(e for e in itervalues(timeline_start)) + p_ids = set(e for e in itervalues(previous)) + tc_ids = set(e for e in itervalues(timeline_contains)) # If we are lazyloading room members, we explicitly add the membership events # for the senders in the timeline into the state block returned by /sync, @@ -1746,7 +1844,7 @@ def _calculate_state( if lazy_load_members: p_ids.difference_update( - e for t, e in timeline_start.iteritems() + e for t, e in iteritems(timeline_start) if t[0] == EventTypes.Member ) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 2d2d3d5a0d..c610933dd4 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -20,6 +20,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError, SynapseError from synapse.types import UserID, get_domain_from_id +from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.logcontext import run_in_background from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -68,6 +69,11 @@ class TypingHandler(object): # map room IDs to sets of users currently typing self._room_typing = {} + # caches which room_ids changed at which serials + self._typing_stream_change_cache = StreamChangeCache( + "TypingStreamChangeCache", self._latest_room_serial, + ) + self.clock.looping_call( self._handle_timeouts, 5000, @@ -218,6 +224,7 @@ class TypingHandler(object): for domain in set(get_domain_from_id(u) for u in users): if domain != self.server_name: + logger.debug("sending typing update to %s", domain) self.federation.send_edu( destination=domain, edu_type="m.typing", @@ -274,19 +281,29 @@ class TypingHandler(object): self._latest_room_serial += 1 self._room_serials[member.room_id] = self._latest_room_serial + self._typing_stream_change_cache.entity_has_changed( + member.room_id, self._latest_room_serial, + ) self.notifier.on_new_event( "typing_key", self._latest_room_serial, rooms=[member.room_id] ) def get_all_typing_updates(self, last_id, current_id): - # TODO: Work out a way to do this without scanning the entire state. if last_id == current_id: return [] + changed_rooms = self._typing_stream_change_cache.get_all_entities_changed( + last_id, + ) + + if changed_rooms is None: + changed_rooms = self._room_serials + rows = [] - for room_id, serial in self._room_serials.items(): - if last_id < serial and serial <= current_id: + for room_id in changed_rooms: + serial = self._room_serials[room_id] + if last_id < serial <= current_id: typing = self._room_typing[room_id] rows.append((serial, room_id, list(typing))) rows.sort() diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index 58ef8d3ce4..a3f9e4f67c 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -38,12 +38,12 @@ def cancelled_to_request_timed_out_error(value, timeout): return value -ACCESS_TOKEN_RE = re.compile(br'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$') +ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$') def redact_uri(uri): """Strips access tokens from the uri replaces with <redacted>""" return ACCESS_TOKEN_RE.sub( - br'\1<redacted>\3', + r'\1<redacted>\3', uri ) diff --git a/synapse/http/client.py b/synapse/http/client.py index ab4fbf59b2..3d05f83b8c 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -13,24 +13,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import logging -import urllib -from six import StringIO +from six import text_type +from six.moves import urllib +import treq from canonicaljson import encode_canonical_json, json from prometheus_client import Counter from OpenSSL import SSL from OpenSSL.SSL import VERIFY_NONE -from twisted.internet import defer, protocol, reactor, ssl, task +from twisted.internet import defer, protocol, reactor, ssl from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.web._newclient import ResponseDone from twisted.web.client import ( Agent, BrowserLikeRedirectAgent, ContentDecoderAgent, - FileBodyProducer as TwistedFileBodyProducer, GzipDecoder, HTTPConnectionPool, PartialDownloadError, @@ -42,7 +43,7 @@ from twisted.web.http_headers import Headers from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.http import cancelled_to_request_timed_out_error, redact_uri from synapse.http.endpoint import SpiderEndpoint -from synapse.util.async_helpers import add_timeout_to_deferred +from synapse.util.async_helpers import timeout_deferred from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.logcontext import make_deferred_yieldable @@ -83,8 +84,10 @@ class SimpleHttpClient(object): if hs.config.user_agent_suffix: self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,) + self.user_agent = self.user_agent.encode('ascii') + @defer.inlineCallbacks - def request(self, method, uri, *args, **kwargs): + def request(self, method, uri, data=b'', headers=None): # A small wrapper around self.agent.request() so we can easily attach # counters to it outgoing_requests_counter.labels(method).inc() @@ -93,10 +96,10 @@ class SimpleHttpClient(object): logger.info("Sending request %s %s", method, redact_uri(uri)) try: - request_deferred = self.agent.request( - method, uri, *args, **kwargs + request_deferred = treq.request( + method, uri, agent=self.agent, data=data, headers=headers ) - add_timeout_to_deferred( + request_deferred = timeout_deferred( request_deferred, 60, self.hs.get_reactor(), cancelled_to_request_timed_out_error, ) @@ -112,7 +115,7 @@ class SimpleHttpClient(object): incoming_responses_counter.labels(method, "ERR").inc() logger.info( "Error sending request to %s %s: %s %s", - method, redact_uri(uri), type(e).__name__, e.message + method, redact_uri(uri), type(e).__name__, e.args[0] ) raise @@ -137,7 +140,8 @@ class SimpleHttpClient(object): # TODO: Do we ever want to log message contents? logger.debug("post_urlencoded_get_json args: %s", args) - query_bytes = urllib.urlencode(encode_urlencode_args(args), True) + query_bytes = urllib.parse.urlencode( + encode_urlencode_args(args), True).encode("utf8") actual_headers = { b"Content-Type": [b"application/x-www-form-urlencoded"], @@ -148,15 +152,14 @@ class SimpleHttpClient(object): response = yield self.request( "POST", - uri.encode("ascii"), + uri, headers=Headers(actual_headers), - bodyProducer=FileBodyProducer(StringIO(query_bytes)) + data=query_bytes ) - body = yield make_deferred_yieldable(readBody(response)) - if 200 <= response.code < 300: - defer.returnValue(json.loads(body)) + body = yield make_deferred_yieldable(treq.json_content(response)) + defer.returnValue(body) else: raise HttpResponseException(response.code, response.phrase, body) @@ -191,9 +194,9 @@ class SimpleHttpClient(object): response = yield self.request( "POST", - uri.encode("ascii"), + uri, headers=Headers(actual_headers), - bodyProducer=FileBodyProducer(StringIO(json_str)) + data=json_str ) body = yield make_deferred_yieldable(readBody(response)) @@ -248,7 +251,7 @@ class SimpleHttpClient(object): ValueError: if the response was not JSON """ if len(args): - query_bytes = urllib.urlencode(args, True) + query_bytes = urllib.parse.urlencode(args, True) uri = "%s?%s" % (uri, query_bytes) json_str = encode_canonical_json(json_body) @@ -262,9 +265,9 @@ class SimpleHttpClient(object): response = yield self.request( "PUT", - uri.encode("ascii"), + uri, headers=Headers(actual_headers), - bodyProducer=FileBodyProducer(StringIO(json_str)) + data=json_str ) body = yield make_deferred_yieldable(readBody(response)) @@ -293,7 +296,7 @@ class SimpleHttpClient(object): HttpResponseException on a non-2xx HTTP response. """ if len(args): - query_bytes = urllib.urlencode(args, True) + query_bytes = urllib.parse.urlencode(args, True) uri = "%s?%s" % (uri, query_bytes) actual_headers = { @@ -304,7 +307,7 @@ class SimpleHttpClient(object): response = yield self.request( "GET", - uri.encode("ascii"), + uri, headers=Headers(actual_headers), ) @@ -339,13 +342,14 @@ class SimpleHttpClient(object): response = yield self.request( "GET", - url.encode("ascii"), + url, headers=Headers(actual_headers), ) resp_headers = dict(response.headers.getAllRawHeaders()) - if 'Content-Length' in resp_headers and resp_headers['Content-Length'] > max_size: + if (b'Content-Length' in resp_headers and + int(resp_headers[b'Content-Length']) > max_size): logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) raise SynapseError( 502, @@ -378,7 +382,12 @@ class SimpleHttpClient(object): ) defer.returnValue( - (length, resp_headers, response.request.absoluteURI, response.code), + ( + length, + resp_headers, + response.request.absoluteURI.decode('ascii'), + response.code, + ), ) @@ -434,12 +443,12 @@ class CaptchaServerHttpClient(SimpleHttpClient): @defer.inlineCallbacks def post_urlencoded_get_raw(self, url, args={}): - query_bytes = urllib.urlencode(encode_urlencode_args(args), True) + query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True) response = yield self.request( "POST", - url.encode("ascii"), - bodyProducer=FileBodyProducer(StringIO(query_bytes)), + url, + data=query_bytes, headers=Headers({ b"Content-Type": [b"application/x-www-form-urlencoded"], b"User-Agent": [self.user_agent], @@ -463,9 +472,9 @@ class SpiderEndpointFactory(object): def endpointForURI(self, uri): logger.info("Getting endpoint for %s", uri.toBytes()) - if uri.scheme == "http": + if uri.scheme == b"http": endpoint_factory = HostnameEndpoint - elif uri.scheme == "https": + elif uri.scheme == b"https": tlsCreator = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port) def endpoint_factory(reactor, host, port, **kw): @@ -510,7 +519,7 @@ def encode_urlencode_args(args): def encode_urlencode_arg(arg): - if isinstance(arg, unicode): + if isinstance(arg, text_type): return arg.encode('utf-8') elif isinstance(arg, list): return [encode_urlencode_arg(i) for i in arg] @@ -542,26 +551,3 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory): def creatorForNetloc(self, hostname, port): return self - - -class FileBodyProducer(TwistedFileBodyProducer): - """Workaround for https://twistedmatrix.com/trac/ticket/8473 - - We override the pauseProducing and resumeProducing methods in twisted's - FileBodyProducer so that they do not raise exceptions if the task has - already completed. - """ - - def pauseProducing(self): - try: - super(FileBodyProducer, self).pauseProducing() - except task.TaskDone: - # task has already completed - pass - - def resumeProducing(self): - try: - super(FileBodyProducer, self).resumeProducing() - except task.NotPaused: - # task was not paused (probably because it had already completed) - pass diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index b0c9369519..91025037a3 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -108,7 +108,7 @@ def matrix_federation_endpoint(reactor, destination, tls_client_options_factory= Args: reactor: Twisted reactor. - destination (bytes): The name of the server to connect to. + destination (unicode): The name of the server to connect to. tls_client_options_factory (synapse.crypto.context_factory.ClientTLSOptionsFactory): Factory which generates TLS options for client connections. @@ -126,10 +126,17 @@ def matrix_federation_endpoint(reactor, destination, tls_client_options_factory= transport_endpoint = HostnameEndpoint default_port = 8008 else: + # the SNI string should be the same as the Host header, minus the port. + # as per https://github.com/matrix-org/synapse/issues/2525#issuecomment-336896777, + # the Host header and SNI should therefore be the server_name of the remote + # server. + tls_options = tls_client_options_factory.get_options(domain) + def transport_endpoint(reactor, host, port, timeout): return wrapClientTLS( - tls_client_options_factory.get_options(host), - HostnameEndpoint(reactor, host, port, timeout=timeout)) + tls_options, + HostnameEndpoint(reactor, host, port, timeout=timeout), + ) default_port = 8448 if port is None: diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index b34bb8e31a..14b12cd1c4 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -17,19 +17,22 @@ import cgi import logging import random import sys -import urllib +from io import BytesIO -from six import string_types -from six.moves.urllib import parse as urlparse +from six import PY3, string_types +from six.moves import urllib -from canonicaljson import encode_canonical_json, json +import attr +import treq +from canonicaljson import encode_canonical_json from prometheus_client import Counter from signedjson.sign import sign_json -from twisted.internet import defer, protocol, reactor +from twisted.internet import defer, protocol from twisted.internet.error import DNSLookupError +from twisted.internet.task import _EPSILON, Cooperator from twisted.web._newclient import ResponseDone -from twisted.web.client import Agent, HTTPConnectionPool, readBody +from twisted.web.client import Agent, FileBodyProducer, HTTPConnectionPool from twisted.web.http_headers import Headers import synapse.metrics @@ -40,14 +43,12 @@ from synapse.api.errors import ( HttpResponseException, SynapseError, ) -from synapse.http import cancelled_to_request_timed_out_error from synapse.http.endpoint import matrix_federation_endpoint -from synapse.util import logcontext -from synapse.util.async_helpers import add_timeout_to_deferred +from synapse.util.async_helpers import timeout_deferred from synapse.util.logcontext import make_deferred_yieldable +from synapse.util.metrics import Measure logger = logging.getLogger(__name__) -outbound_logger = logging.getLogger("synapse.http.outbound") outgoing_requests_counter = Counter("synapse_http_matrixfederationclient_requests", "", ["method"]) @@ -58,20 +59,119 @@ incoming_responses_counter = Counter("synapse_http_matrixfederationclient_respon MAX_LONG_RETRIES = 10 MAX_SHORT_RETRIES = 3 +if PY3: + MAXINT = sys.maxsize +else: + MAXINT = sys.maxint + class MatrixFederationEndpointFactory(object): def __init__(self, hs): + self.reactor = hs.get_reactor() self.tls_client_options_factory = hs.tls_client_options_factory def endpointForURI(self, uri): - destination = uri.netloc + destination = uri.netloc.decode('ascii') return matrix_federation_endpoint( - reactor, destination, timeout=10, + self.reactor, destination, timeout=10, tls_client_options_factory=self.tls_client_options_factory ) +_next_id = 1 + + +@attr.s +class MatrixFederationRequest(object): + method = attr.ib() + """HTTP method + :type: str + """ + + path = attr.ib() + """HTTP path + :type: str + """ + + destination = attr.ib() + """The remote server to send the HTTP request to. + :type: str""" + + json = attr.ib(default=None) + """JSON to send in the body. + :type: dict|None + """ + + json_callback = attr.ib(default=None) + """A callback to generate the JSON. + :type: func|None + """ + + query = attr.ib(default=None) + """Query arguments. + :type: dict|None + """ + + txn_id = attr.ib(default=None) + """Unique ID for this request (for logging) + :type: str|None + """ + + def __attrs_post_init__(self): + global _next_id + self.txn_id = "%s-O-%s" % (self.method, _next_id) + _next_id = (_next_id + 1) % (MAXINT - 1) + + def get_json(self): + if self.json_callback: + return self.json_callback() + return self.json + + +@defer.inlineCallbacks +def _handle_json_response(reactor, timeout_sec, request, response): + """ + Reads the JSON body of a response, with a timeout + + Args: + reactor (IReactor): twisted reactor, for the timeout + timeout_sec (float): number of seconds to wait for response to complete + request (MatrixFederationRequest): the request that triggered the response + response (IResponse): response to the request + + Returns: + dict: parsed JSON response + """ + try: + check_content_type_is_json(response.headers) + + d = treq.json_content(response) + d = timeout_deferred( + d, + timeout=timeout_sec, + reactor=reactor, + ) + + body = yield make_deferred_yieldable(d) + except Exception as e: + logger.warn( + "{%s} [%s] Error reading response: %s", + request.txn_id, + request.destination, + e, + ) + raise + logger.info( + "{%s} [%s] Completed: %d %s", + request.txn_id, + request.destination, + response.code, + response.phrase.decode('ascii', errors='replace'), + ) + defer.returnValue(body) + + class MatrixFederationHttpClient(object): """HTTP client used to talk to other homeservers over the federation protocol. Send client certificates and signs requests. @@ -85,7 +185,9 @@ class MatrixFederationHttpClient(object): self.hs = hs self.signing_key = hs.config.signing_key[0] self.server_name = hs.hostname + reactor = hs.get_reactor() pool = HTTPConnectionPool(reactor) + pool.retryAutomatically = False pool.maxPersistentPerHost = 5 pool.cachedConnectionTimeout = 2 * 60 self.agent = Agent.usingEndpointFactory( @@ -93,28 +195,36 @@ class MatrixFederationHttpClient(object): ) self.clock = hs.get_clock() self._store = hs.get_datastore() - self.version_string = hs.version_string - self._next_id = 1 + self.version_string = hs.version_string.encode('ascii') + self.default_timeout = 60 - def _create_url(self, destination, path_bytes, param_bytes, query_bytes): - return urlparse.urlunparse( - ("matrix", destination, path_bytes, param_bytes, query_bytes, "") - ) + def schedule(x): + reactor.callLater(_EPSILON, x) + + self._cooperator = Cooperator(scheduler=schedule) @defer.inlineCallbacks - def _request(self, destination, method, path, - body_callback, headers_dict={}, param_bytes=b"", - query_bytes=b"", retry_on_dns_fail=True, - timeout=None, long_retries=False, - ignore_backoff=False, - backoff_on_404=False): - """ Creates and sends a request to the given server + def _send_request( + self, + request, + retry_on_dns_fail=True, + timeout=None, + long_retries=False, + ignore_backoff=False, + backoff_on_404=False + ): + """ + Sends a request to the given server. + Args: - destination (str): The remote server to send the HTTP request to. - method (str): HTTP method - path (str): The HTTP path + request (MatrixFederationRequest): details of request to be sent + + timeout (int|None): number of milliseconds to wait for the response headers + (including connecting to the server). 60s by default. + ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. + backoff_on_404 (bool): Back off if we get a 404 Returns: @@ -132,38 +242,39 @@ class MatrixFederationHttpClient(object): (May also fail with plenty of other Exceptions for things like DNS failures, connection failures, SSL failures.) """ + if timeout: + _sec_timeout = timeout / 1000 + else: + _sec_timeout = self.default_timeout + if ( self.hs.config.federation_domain_whitelist is not None and - destination not in self.hs.config.federation_domain_whitelist + request.destination not in self.hs.config.federation_domain_whitelist ): - raise FederationDeniedError(destination) + raise FederationDeniedError(request.destination) limiter = yield synapse.util.retryutils.get_retry_limiter( - destination, + request.destination, self.clock, self._store, backoff_on_404=backoff_on_404, ignore_backoff=ignore_backoff, ) - destination = destination.encode("ascii") - path_bytes = path.encode("ascii") - with limiter: - headers_dict[b"User-Agent"] = [self.version_string] - headers_dict[b"Host"] = [destination] - - url_bytes = self._create_url( - destination, path_bytes, param_bytes, query_bytes - ) - - txn_id = "%s-O-%s" % (method, self._next_id) - self._next_id = (self._next_id + 1) % (sys.maxint - 1) + method = request.method + destination = request.destination + path_bytes = request.path.encode("ascii") + if request.query: + query_bytes = encode_query_args(request.query) + else: + query_bytes = b"" - outbound_logger.info( - "{%s} [%s] Sending request: %s %s", - txn_id, destination, method, url_bytes - ) + headers_dict = { + "User-Agent": [self.version_string], + "Host": [request.destination], + } + with limiter: # XXX: Would be much nicer to retry only at the transaction-layer # (once we have reliable transactions in place) if long_retries: @@ -171,88 +282,119 @@ class MatrixFederationHttpClient(object): else: retries_left = MAX_SHORT_RETRIES - http_url_bytes = urlparse.urlunparse( - ("", "", path_bytes, param_bytes, query_bytes, "") - ) - - log_result = None - try: - while True: - producer = None - if body_callback: - producer = body_callback(method, http_url_bytes, headers_dict) - - try: - request_deferred = self.agent.request( - method, - url_bytes, - Headers(headers_dict), - producer + url = urllib.parse.urlunparse(( + b"matrix", destination.encode("ascii"), + path_bytes, None, query_bytes, b"", + )).decode('ascii') + + http_url = urllib.parse.urlunparse(( + b"", b"", + path_bytes, None, query_bytes, b"", + )).decode('ascii') + + while True: + try: + json = request.get_json() + if json: + data = encode_canonical_json(json) + headers_dict["Content-Type"] = ["application/json"] + self.sign_request( + destination, method, http_url, headers_dict, json ) - add_timeout_to_deferred( - request_deferred, - timeout / 1000. if timeout else 60, - self.hs.get_reactor(), - cancelled_to_request_timed_out_error, + else: + data = None + self.sign_request(destination, method, http_url, headers_dict) + + logger.info( + "{%s} [%s] Sending request: %s %s", + request.txn_id, destination, method, url + ) + + if data: + producer = FileBodyProducer( + BytesIO(data), + cooperator=self._cooperator ) + else: + producer = None + + request_deferred = treq.request( + method, + url, + headers=Headers(headers_dict), + data=producer, + agent=self.agent, + reactor=self.hs.get_reactor(), + unbuffered=True + ) + + request_deferred = timeout_deferred( + request_deferred, + timeout=_sec_timeout, + reactor=self.hs.get_reactor(), + ) + + with Measure(self.clock, "outbound_request"): response = yield make_deferred_yieldable( request_deferred, ) - log_result = "%d %s" % (response.code, response.phrase,) - break - except Exception as e: - if not retry_on_dns_fail and isinstance(e, DNSLookupError): - logger.warn( - "DNS Lookup failed to %s with %s", - destination, - e - ) - log_result = "DNS Lookup failed to %s with %s" % ( - destination, e - ) - raise - - logger.warn( - "{%s} Sending request failed to %s: %s %s: %s", - txn_id, + break + except Exception as e: + logger.warn( + "{%s} [%s] Request failed: %s %s: %s", + request.txn_id, + destination, + method, + url, + _flatten_response_never_received(e), + ) + + if not retry_on_dns_fail and isinstance(e, DNSLookupError): + raise + + if retries_left and not timeout: + if long_retries: + delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left) + delay = min(delay, 60) + delay *= random.uniform(0.8, 1.4) + else: + delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left) + delay = min(delay, 2) + delay *= random.uniform(0.8, 1.4) + + logger.debug( + "{%s} [%s] Waiting %ss before re-sending...", + request.txn_id, destination, - method, - url_bytes, - _flatten_response_never_received(e), + delay, ) - log_result = _flatten_response_never_received(e) - - if retries_left and not timeout: - if long_retries: - delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left) - delay = min(delay, 60) - delay *= random.uniform(0.8, 1.4) - else: - delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left) - delay = min(delay, 2) - delay *= random.uniform(0.8, 1.4) - - yield self.clock.sleep(delay) - retries_left -= 1 - else: - raise - finally: - outbound_logger.info( - "{%s} [%s] Result: %s", - txn_id, - destination, - log_result, - ) + yield self.clock.sleep(delay) + retries_left -= 1 + else: + raise + + logger.info( + "{%s} [%s] Got response headers: %d %s", + request.txn_id, + destination, + response.code, + response.phrase.decode('ascii', errors='replace'), + ) if 200 <= response.code < 300: pass else: # :'( # Update transactions table? - with logcontext.PreserveLoggingContext(): - body = yield readBody(response) + d = treq.content(response) + d = timeout_deferred( + d, + timeout=_sec_timeout, + reactor=self.hs.get_reactor(), + ) + body = yield make_deferred_yieldable(d) raise HttpResponseException( response.code, response.phrase, body ) @@ -297,11 +439,11 @@ class MatrixFederationHttpClient(object): auth_headers = [] for key, sig in request["signatures"][self.server_name].items(): - auth_headers.append(bytes( + auth_headers.append(( "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % ( self.server_name, key, sig, - ) - )) + )).encode('ascii') + ) headers_dict[b"Authorization"] = auth_headers @@ -346,38 +488,27 @@ class MatrixFederationHttpClient(object): is not on our federation whitelist """ - if not json_data_callback: - def json_data_callback(): - return data + request = MatrixFederationRequest( + method="PUT", + destination=destination, + path=path, + query=args, + json_callback=json_data_callback, + json=data, + ) - def body_callback(method, url_bytes, headers_dict): - json_data = json_data_callback() - self.sign_request( - destination, method, url_bytes, headers_dict, json_data - ) - producer = _JsonProducer(json_data) - return producer - - response = yield self._request( - destination, - "PUT", - path, - body_callback=body_callback, - headers_dict={"Content-Type": ["application/json"]}, - query_bytes=encode_query_args(args), + response = yield self._send_request( + request, long_retries=long_retries, timeout=timeout, ignore_backoff=ignore_backoff, backoff_on_404=backoff_on_404, ) - if 200 <= response.code < 300: - # We need to update the transactions table to say it was sent? - check_content_type_is_json(response.headers) - - with logcontext.PreserveLoggingContext(): - body = yield readBody(response) - defer.returnValue(json.loads(body)) + body = yield _handle_json_response( + self.hs.get_reactor(), self.default_timeout, request, response, + ) + defer.returnValue(body) @defer.inlineCallbacks def post_json(self, destination, path, data={}, long_retries=False, @@ -411,32 +542,30 @@ class MatrixFederationHttpClient(object): is not on our federation whitelist """ - def body_callback(method, url_bytes, headers_dict): - self.sign_request( - destination, method, url_bytes, headers_dict, data - ) - return _JsonProducer(data) - - response = yield self._request( - destination, - "POST", - path, - query_bytes=encode_query_args(args), - body_callback=body_callback, - headers_dict={"Content-Type": ["application/json"]}, + request = MatrixFederationRequest( + method="POST", + destination=destination, + path=path, + query=args, + json=data, + ) + + response = yield self._send_request( + request, long_retries=long_retries, timeout=timeout, ignore_backoff=ignore_backoff, ) - if 200 <= response.code < 300: - # We need to update the transactions table to say it was sent? - check_content_type_is_json(response.headers) - - with logcontext.PreserveLoggingContext(): - body = yield readBody(response) + if timeout: + _sec_timeout = timeout / 1000 + else: + _sec_timeout = self.default_timeout - defer.returnValue(json.loads(body)) + body = yield _handle_json_response( + self.hs.get_reactor(), _sec_timeout, request, response, + ) + defer.returnValue(body) @defer.inlineCallbacks def get_json(self, destination, path, args=None, retry_on_dns_fail=True, @@ -471,29 +600,24 @@ class MatrixFederationHttpClient(object): logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail) - def body_callback(method, url_bytes, headers_dict): - self.sign_request(destination, method, url_bytes, headers_dict) - return None + request = MatrixFederationRequest( + method="GET", + destination=destination, + path=path, + query=args, + ) - response = yield self._request( - destination, - "GET", - path, - query_bytes=encode_query_args(args), - body_callback=body_callback, + response = yield self._send_request( + request, retry_on_dns_fail=retry_on_dns_fail, timeout=timeout, ignore_backoff=ignore_backoff, ) - if 200 <= response.code < 300: - # We need to update the transactions table to say it was sent? - check_content_type_is_json(response.headers) - - with logcontext.PreserveLoggingContext(): - body = yield readBody(response) - - defer.returnValue(json.loads(body)) + body = yield _handle_json_response( + self.hs.get_reactor(), self.default_timeout, request, response, + ) + defer.returnValue(body) @defer.inlineCallbacks def delete_json(self, destination, path, long_retries=False, @@ -523,26 +647,24 @@ class MatrixFederationHttpClient(object): Fails with ``FederationDeniedError`` if this destination is not on our federation whitelist """ + request = MatrixFederationRequest( + method="DELETE", + destination=destination, + path=path, + query=args, + ) - response = yield self._request( - destination, - "DELETE", - path, - query_bytes=encode_query_args(args), - headers_dict={"Content-Type": ["application/json"]}, + response = yield self._send_request( + request, long_retries=long_retries, timeout=timeout, ignore_backoff=ignore_backoff, ) - if 200 <= response.code < 300: - # We need to update the transactions table to say it was sent? - check_content_type_is_json(response.headers) - - with logcontext.PreserveLoggingContext(): - body = yield readBody(response) - - defer.returnValue(json.loads(body)) + body = yield _handle_json_response( + self.hs.get_reactor(), self.default_timeout, request, response, + ) + defer.returnValue(body) @defer.inlineCallbacks def get_file(self, destination, path, output_stream, args={}, @@ -569,26 +691,15 @@ class MatrixFederationHttpClient(object): Fails with ``FederationDeniedError`` if this destination is not on our federation whitelist """ + request = MatrixFederationRequest( + method="GET", + destination=destination, + path=path, + query=args, + ) - encoded_args = {} - for k, vs in args.items(): - if isinstance(vs, string_types): - vs = [vs] - encoded_args[k] = [v.encode("UTF-8") for v in vs] - - query_bytes = urllib.urlencode(encoded_args, True) - logger.debug("Query bytes: %s Retry DNS: %s", query_bytes, retry_on_dns_fail) - - def body_callback(method, url_bytes, headers_dict): - self.sign_request(destination, method, url_bytes, headers_dict) - return None - - response = yield self._request( - destination, - "GET", - path, - query_bytes=query_bytes, - body_callback=body_callback, + response = yield self._send_request( + request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff, ) @@ -596,14 +707,25 @@ class MatrixFederationHttpClient(object): headers = dict(response.headers.getAllRawHeaders()) try: - with logcontext.PreserveLoggingContext(): - length = yield _readBodyToFile( - response, output_stream, max_size - ) - except Exception: - logger.exception("Failed to download body") + d = _readBodyToFile(response, output_stream, max_size) + d.addTimeout(self.default_timeout, self.hs.get_reactor()) + length = yield make_deferred_yieldable(d) + except Exception as e: + logger.warn( + "{%s} [%s] Error reading response: %s", + request.txn_id, + request.destination, + e, + ) raise - + logger.info( + "{%s} [%s] Completed: %d %s [%d bytes]", + request.txn_id, + request.destination, + response.code, + response.phrase.decode('ascii', errors='replace'), + length, + ) defer.returnValue((length, headers)) @@ -639,30 +761,6 @@ def _readBodyToFile(response, stream, max_size): return d -class _JsonProducer(object): - """ Used by the twisted http client to create the HTTP body from json - """ - def __init__(self, jsn): - self.reset(jsn) - - def reset(self, jsn): - self.body = encode_canonical_json(jsn) - self.length = len(self.body) - - def startProducing(self, consumer): - consumer.write(self.body) - return defer.succeed(None) - - def pauseProducing(self): - pass - - def stopProducing(self): - pass - - def resumeProducing(self): - pass - - def _flatten_response_never_received(e): if hasattr(e, "reasons"): reasons = ", ".join( @@ -693,7 +791,7 @@ def check_content_type_is_json(headers): "No Content-Type header" ) - c_type = c_type[0] # only the first header + c_type = c_type[0].decode('ascii') # only the first header val, options = cgi.parse_header(c_type) if val != "application/json": raise RuntimeError( @@ -711,6 +809,6 @@ def encode_query_args(args): vs = [vs] encoded_args[k] = [v.encode("UTF-8") for v in vs] - query_bytes = urllib.urlencode(encoded_args, True) + query_bytes = urllib.parse.urlencode(encoded_args, True) - return query_bytes + return query_bytes.encode('utf8') diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py index 72c2654678..fedb4e6b18 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py @@ -162,7 +162,7 @@ class RequestMetrics(object): with _in_flight_requests_lock: _in_flight_requests.add(self) - def stop(self, time_sec, request): + def stop(self, time_sec, response_code, sent_bytes): with _in_flight_requests_lock: _in_flight_requests.discard(self) @@ -179,35 +179,35 @@ class RequestMetrics(object): ) return - response_code = str(request.code) + response_code = str(response_code) - outgoing_responses_counter.labels(request.method, response_code).inc() + outgoing_responses_counter.labels(self.method, response_code).inc() - response_count.labels(request.method, self.name, tag).inc() + response_count.labels(self.method, self.name, tag).inc() - response_timer.labels(request.method, self.name, tag, response_code).observe( + response_timer.labels(self.method, self.name, tag, response_code).observe( time_sec - self.start ) resource_usage = context.get_resource_usage() - response_ru_utime.labels(request.method, self.name, tag).inc( + response_ru_utime.labels(self.method, self.name, tag).inc( resource_usage.ru_utime, ) - response_ru_stime.labels(request.method, self.name, tag).inc( + response_ru_stime.labels(self.method, self.name, tag).inc( resource_usage.ru_stime, ) - response_db_txn_count.labels(request.method, self.name, tag).inc( + response_db_txn_count.labels(self.method, self.name, tag).inc( resource_usage.db_txn_count ) - response_db_txn_duration.labels(request.method, self.name, tag).inc( + response_db_txn_duration.labels(self.method, self.name, tag).inc( resource_usage.db_txn_duration_sec ) - response_db_sched_duration.labels(request.method, self.name, tag).inc( + response_db_sched_duration.labels(self.method, self.name, tag).inc( resource_usage.db_sched_duration_sec ) - response_size.labels(request.method, self.name, tag).inc(request.sentLength) + response_size.labels(self.method, self.name, tag).inc(sent_bytes) # We always call this at the end to ensure that we update the metrics # regardless of whether a call to /metrics while the request was in diff --git a/synapse/http/server.py b/synapse/http/server.py index 2d5c23e673..b4b25cab19 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -84,10 +84,21 @@ def wrap_json_request_handler(h): logger.info( "%s SynapseError: %s - %s", request, code, e.msg ) - respond_with_json( - request, code, e.error_dict(), send_cors=True, - pretty_print=_request_user_agent_is_curl(request), - ) + + # Only respond with an error response if we haven't already started + # writing, otherwise lets just kill the connection + if request.startedWriting: + if request.transport: + try: + request.transport.abortConnection() + except Exception: + # abortConnection throws if the connection is already closed + pass + else: + respond_with_json( + request, code, e.error_dict(), send_cors=True, + pretty_print=_request_user_agent_is_curl(request), + ) except Exception: # failure.Failure() fishes the original Failure out @@ -100,16 +111,26 @@ def wrap_json_request_handler(h): request, f.getTraceback().rstrip(), ) - respond_with_json( - request, - 500, - { - "error": "Internal server error", - "errcode": Codes.UNKNOWN, - }, - send_cors=True, - pretty_print=_request_user_agent_is_curl(request), - ) + # Only respond with an error response if we haven't already started + # writing, otherwise lets just kill the connection + if request.startedWriting: + if request.transport: + try: + request.transport.abortConnection() + except Exception: + # abortConnection throws if the connection is already closed + pass + else: + respond_with_json( + request, + 500, + { + "error": "Internal server error", + "errcode": Codes.UNKNOWN, + }, + send_cors=True, + pretty_print=_request_user_agent_is_curl(request), + ) return wrap_async_request_handler(wrapped_request_handler) diff --git a/synapse/http/site.py b/synapse/http/site.py index 88ed3714f9..e508c0bd4f 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -75,17 +75,35 @@ class SynapseRequest(Request): return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % ( self.__class__.__name__, id(self), - self.method, + self.get_method(), self.get_redacted_uri(), - self.clientproto, + self.clientproto.decode('ascii', errors='replace'), self.site.site_tag, ) def get_request_id(self): - return "%s-%i" % (self.method, self.request_seq) + return "%s-%i" % (self.get_method(), self.request_seq) def get_redacted_uri(self): - return redact_uri(self.uri) + uri = self.uri + if isinstance(uri, bytes): + uri = self.uri.decode('ascii') + return redact_uri(uri) + + def get_method(self): + """Gets the method associated with the request (or placeholder if not + method has yet been received). + + Note: This is necessary as the placeholder value in twisted is str + rather than bytes, so we need to sanitise `self.method`. + + Returns: + str + """ + method = self.method + if isinstance(method, bytes): + method = self.method.decode('ascii') + return method def get_user_agent(self): return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1] @@ -116,7 +134,7 @@ class SynapseRequest(Request): # dispatching to the handler, so that the handler # can update the servlet name in the request # metrics - requests_counter.labels(self.method, + requests_counter.labels(self.get_method(), self.request_metrics.name).inc() @contextlib.contextmanager @@ -204,14 +222,14 @@ class SynapseRequest(Request): self.start_time = time.time() self.request_metrics = RequestMetrics() self.request_metrics.start( - self.start_time, name=servlet_name, method=self.method, + self.start_time, name=servlet_name, method=self.get_method(), ) self.site.access_logger.info( "%s - %s - Received request: %s %s", self.getClientIP(), self.site.site_tag, - self.method, + self.get_method(), self.get_redacted_uri() ) @@ -277,15 +295,15 @@ class SynapseRequest(Request): int(usage.db_txn_count), self.sentLength, code, - self.method, + self.get_method(), self.get_redacted_uri(), - self.clientproto, + self.clientproto.decode('ascii', errors='replace'), user_agent, usage.evt_db_fetch_count, ) try: - self.request_metrics.stop(self.finish_time, self) + self.request_metrics.stop(self.finish_time, self.code, self.sentLength) except Exception as e: logger.warn("Failed to stop metrics: %r", e) @@ -305,7 +323,7 @@ class XForwardedForRequest(SynapseRequest): C{b"-"}. """ return self.requestHeaders.getRawHeaders( - b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip() + b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip().decode('ascii') class SynapseRequestFactory(object): diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 550f8443f7..59900aa5d1 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -18,8 +18,11 @@ import gc import logging import os import platform +import threading import time +import six + import attr from prometheus_client import Counter, Gauge, Histogram from prometheus_client.core import REGISTRY, GaugeMetricFamily @@ -68,7 +71,7 @@ class LaterGauge(object): return if isinstance(calls, dict): - for k, v in calls.items(): + for k, v in six.iteritems(calls): g.add_metric(k, v) else: g.add_metric([], calls) @@ -87,6 +90,109 @@ class LaterGauge(object): all_gauges[self.name] = self +class InFlightGauge(object): + """Tracks number of things (e.g. requests, Measure blocks, etc) in flight + at any given time. + + Each InFlightGauge will create a metric called `<name>_total` that counts + the number of in flight blocks, as well as a metrics for each item in the + given `sub_metrics` as `<name>_<sub_metric>` which will get updated by the + callbacks. + + Args: + name (str) + desc (str) + labels (list[str]) + sub_metrics (list[str]): A list of sub metrics that the callbacks + will update. + """ + + def __init__(self, name, desc, labels, sub_metrics): + self.name = name + self.desc = desc + self.labels = labels + self.sub_metrics = sub_metrics + + # Create a class which have the sub_metrics values as attributes, which + # default to 0 on initialization. Used to pass to registered callbacks. + self._metrics_class = attr.make_class( + "_MetricsEntry", + attrs={x: attr.ib(0) for x in sub_metrics}, + slots=True, + ) + + # Counts number of in flight blocks for a given set of label values + self._registrations = {} + + # Protects access to _registrations + self._lock = threading.Lock() + + self._register_with_collector() + + def register(self, key, callback): + """Registers that we've entered a new block with labels `key`. + + `callback` gets called each time the metrics are collected. The same + value must also be given to `unregister`. + + `callback` gets called with an object that has an attribute per + sub_metric, which should be updated with the necessary values. Note that + the metrics object is shared between all callbacks registered with the + same key. + + Note that `callback` may be called on a separate thread. + """ + with self._lock: + self._registrations.setdefault(key, set()).add(callback) + + def unregister(self, key, callback): + """Registers that we've exited a block with labels `key`. + """ + + with self._lock: + self._registrations.setdefault(key, set()).discard(callback) + + def collect(self): + """Called by prometheus client when it reads metrics. + + Note: may be called by a separate thread. + """ + in_flight = GaugeMetricFamily(self.name + "_total", self.desc, labels=self.labels) + + metrics_by_key = {} + + # We copy so that we don't mutate the list while iterating + with self._lock: + keys = list(self._registrations) + + for key in keys: + with self._lock: + callbacks = set(self._registrations[key]) + + in_flight.add_metric(key, len(callbacks)) + + metrics = self._metrics_class() + metrics_by_key[key] = metrics + for callback in callbacks: + callback(metrics) + + yield in_flight + + for name in self.sub_metrics: + gauge = GaugeMetricFamily("_".join([self.name, name]), "", labels=self.labels) + for key, metrics in six.iteritems(metrics_by_key): + gauge.add_metric(key, getattr(metrics, name)) + yield gauge + + def _register_with_collector(self): + if self.name in all_gauges.keys(): + logger.warning("%s already registered, reregistering" % (self.name,)) + REGISTRY.unregister(all_gauges.pop(self.name)) + + REGISTRY.register(self) + all_gauges[self.name] = self + + # # Detailed CPU metrics # diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 167167be0a..037f1c490e 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import threading import six @@ -23,6 +24,9 @@ from twisted.internet import defer from synapse.util.logcontext import LoggingContext, PreserveLoggingContext +logger = logging.getLogger(__name__) + + _background_process_start_count = Counter( "synapse_background_process_start_count", "Number of background processes started", @@ -97,9 +101,13 @@ class _Collector(object): labels=["name"], ) - # We copy the dict so that it doesn't change from underneath us + # We copy the dict so that it doesn't change from underneath us. + # We also copy the process lists as that can also change with _bg_metrics_lock: - _background_processes_copy = dict(_background_processes) + _background_processes_copy = { + k: list(v) + for k, v in six.iteritems(_background_processes) + } for desc, processes in six.iteritems(_background_processes_copy): background_process_in_flight_count.add_metric( @@ -191,6 +199,8 @@ def run_as_background_process(desc, func, *args, **kwargs): try: yield func(*args, **kwargs) + except Exception: + logger.exception("Background process '%s' threw an exception", desc) finally: proc.update_metrics() diff --git a/synapse/notifier.py b/synapse/notifier.py index 82f391481c..340b16ce25 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -24,13 +24,10 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError from synapse.handlers.presence import format_user_presence_state from synapse.metrics import LaterGauge +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import StreamToken -from synapse.util.async_helpers import ( - DeferredTimeoutError, - ObservableDeferred, - add_timeout_to_deferred, -) -from synapse.util.logcontext import PreserveLoggingContext, run_in_background +from synapse.util.async_helpers import ObservableDeferred, timeout_deferred +from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logutils import log_function from synapse.util.metrics import Measure from synapse.visibility import filter_events_for_client @@ -252,7 +249,10 @@ class Notifier(object): def _on_new_room_event(self, event, room_stream_id, extra_users=[]): """Notify any user streams that are interested in this room event""" # poke any interested application service. - run_in_background(self._notify_app_services, room_stream_id) + run_as_background_process( + "notify_app_services", + self._notify_app_services, room_stream_id, + ) if self.federation_sender: self.federation_sender.notify_new_events(room_stream_id) @@ -337,7 +337,7 @@ class Notifier(object): # Now we wait for the _NotifierUserStream to be told there # is a new token. listener = user_stream.new_listener(prev_token) - add_timeout_to_deferred( + listener.deferred = timeout_deferred( listener.deferred, (end_time - now) / 1000., self.hs.get_reactor(), @@ -354,7 +354,7 @@ class Notifier(object): # Update the prev_token to the current_token since nothing # has happened between the old prev_token and the current_token prev_token = current_token - except DeferredTimeoutError: + except defer.TimeoutError: break except defer.CancelledError: break @@ -559,15 +559,16 @@ class Notifier(object): if end_time <= now: break - add_timeout_to_deferred( - listener.deferred.addTimeout, - (end_time - now) / 1000., - self.hs.get_reactor(), + listener.deferred = timeout_deferred( + listener.deferred, + timeout=(end_time - now) / 1000., + reactor=self.hs.get_reactor(), ) + try: with PreserveLoggingContext(): yield listener.deferred - except DeferredTimeoutError: + except defer.TimeoutError: break except defer.CancelledError: break diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 81e18bcf7d..48abd5e4d6 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -15,6 +15,8 @@ # limitations under the License. import logging +import six + from prometheus_client import Counter from twisted.internet import defer @@ -26,6 +28,9 @@ from synapse.util.metrics import Measure from . import push_rule_evaluator, push_tools +if six.PY3: + long = int + logger = logging.getLogger(__name__) http_push_processed_counter = Counter("synapse_http_httppusher_http_pushes_processed", "") @@ -96,7 +101,7 @@ class HttpPusher(object): @defer.inlineCallbacks def on_new_notifications(self, min_stream_ordering, max_stream_ordering): - self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering) + self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering or 0) yield self._process() @defer.inlineCallbacks diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index bfa6df7b68..1a5a10d974 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -17,10 +17,11 @@ import email.mime.multipart import email.utils import logging import time -import urllib from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText +from six.moves import urllib + import bleach import jinja2 @@ -440,7 +441,7 @@ class Mailer(object): def make_room_link(self, room_id): if self.hs.config.email_riot_base_url: - base_url = self.hs.config.email_riot_base_url + base_url = "%s/#/room" % (self.hs.config.email_riot_base_url) elif self.app_name == "Vector": # need /beta for Universal Links to work on iOS base_url = "https://vector.im/beta/#/room" @@ -474,7 +475,7 @@ class Mailer(object): # XXX: make r0 once API is stable return "%s_matrix/client/unstable/pushers/remove?%s" % ( self.hs.config.public_baseurl, - urllib.urlencode(params), + urllib.parse.urlencode(params), ) @@ -561,7 +562,7 @@ def _create_mxc_to_http_filter(config): return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( config.public_baseurl, serverAndMediaId, - urllib.urlencode(params), + urllib.parse.urlencode(params), fragment or "", ) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 9c55e79ef5..d4d983b00a 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -33,31 +33,38 @@ logger = logging.getLogger(__name__) # [2] https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-dependencies REQUIREMENTS = { "jsonschema>=2.5.1": ["jsonschema>=2.5.1"], - "frozendict>=0.4": ["frozendict"], + "frozendict>=1": ["frozendict"], "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"], "canonicaljson>=1.1.3": ["canonicaljson>=1.1.3"], "signedjson>=1.0.0": ["signedjson>=1.0.0"], "pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"], - "service_identity>=1.0.0": ["service_identity>=1.0.0"], + "service_identity>=16.0.0": ["service_identity>=16.0.0"], "Twisted>=17.1.0": ["twisted>=17.1.0"], - - # We use crypto.get_elliptic_curve which is only supported in >=0.15 - "pyopenssl>=0.15": ["OpenSSL>=0.15"], - - "pyyaml": ["yaml"], - "pyasn1": ["pyasn1"], - "daemonize": ["daemonize"], - "bcrypt": ["bcrypt>=3.1.0"], - "pillow": ["PIL"], - "pydenticon": ["pydenticon"], - "sortedcontainers": ["sortedcontainers"], - "pysaml2>=3.0.0": ["saml2>=3.0.0"], - "pymacaroons-pynacl": ["pymacaroons"], + "treq>=15.1": ["treq>=15.1"], + + # Twisted has required pyopenssl 16.0 since about Twisted 16.6. + "pyopenssl>=16.0.0": ["OpenSSL>=16.0.0"], + + "pyyaml>=3.11": ["yaml"], + "pyasn1>=0.1.9": ["pyasn1"], + "pyasn1-modules>=0.0.7": ["pyasn1_modules"], + "daemonize>=2.3.1": ["daemonize"], + "bcrypt>=3.1.0": ["bcrypt>=3.1.0"], + "pillow>=3.1.2": ["PIL"], + "pydenticon>=0.2": ["pydenticon"], + "sortedcontainers>=1.4.4": ["sortedcontainers"], + "pysaml2>=3.0.0": ["saml2"], + "pymacaroons-pynacl>=0.9.3": ["pymacaroons"], "msgpack-python>=0.3.0": ["msgpack"], "phonenumbers>=8.2.0": ["phonenumbers"], - "six": ["six"], - "prometheus_client": ["prometheus_client"], - "attrs": ["attr"], + "six>=1.10": ["six"], + + # prometheus_client 0.4.0 changed the format of counter metrics + # (cf https://github.com/matrix-org/synapse/issues/4001) + "prometheus_client>=0.0.18,<0.4.0": ["prometheus_client"], + + # we use attr.s(slots), which arrived in 16.0.0 + "attrs>=16.0.0": ["attr>=16.0.0"], "netaddr>=0.7.18": ["netaddr"], } @@ -78,6 +85,9 @@ CONDITIONAL_REQUIREMENTS = { "affinity": { "affinity": ["affinity"], }, + "postgres": { + "psycopg2>=2.6": ["psycopg2"] + } } diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 8206a988f7..21b8c468fa 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import six + from synapse.storage import DataStore from synapse.storage.end_to_end_keys import EndToEndKeyStore from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -21,6 +23,13 @@ from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker +def __func__(inp): + if six.PY3: + return inp + else: + return inp.__func__ + + class SlavedDeviceStore(BaseSlavedStore): def __init__(self, db_conn, hs): super(SlavedDeviceStore, self).__init__(db_conn, hs) @@ -38,14 +47,14 @@ class SlavedDeviceStore(BaseSlavedStore): "DeviceListFederationStreamChangeCache", device_list_max, ) - get_device_stream_token = DataStore.get_device_stream_token.__func__ - get_user_whose_devices_changed = DataStore.get_user_whose_devices_changed.__func__ - get_devices_by_remote = DataStore.get_devices_by_remote.__func__ - _get_devices_by_remote_txn = DataStore._get_devices_by_remote_txn.__func__ - _get_e2e_device_keys_txn = DataStore._get_e2e_device_keys_txn.__func__ - mark_as_sent_devices_by_remote = DataStore.mark_as_sent_devices_by_remote.__func__ + get_device_stream_token = __func__(DataStore.get_device_stream_token) + get_user_whose_devices_changed = __func__(DataStore.get_user_whose_devices_changed) + get_devices_by_remote = __func__(DataStore.get_devices_by_remote) + _get_devices_by_remote_txn = __func__(DataStore._get_devices_by_remote_txn) + _get_e2e_device_keys_txn = __func__(DataStore._get_e2e_device_keys_txn) + mark_as_sent_devices_by_remote = __func__(DataStore.mark_as_sent_devices_by_remote) _mark_as_sent_devices_by_remote_txn = ( - DataStore._mark_as_sent_devices_by_remote_txn.__func__ + __func__(DataStore._mark_as_sent_devices_by_remote_txn) ) count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"] diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 74e892c104..5dc7b3fffc 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -590,9 +590,9 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): pending_commands = LaterGauge( "synapse_replication_tcp_protocol_pending_commands", "", - ["name", "conn_id"], + ["name"], lambda: { - (p.name, p.conn_id): len(p.pending_commands) for p in connected_connections + (p.name,): len(p.pending_commands) for p in connected_connections }, ) @@ -607,9 +607,9 @@ def transport_buffer_size(protocol): transport_send_buffer = LaterGauge( "synapse_replication_tcp_protocol_transport_send_buffer", "", - ["name", "conn_id"], + ["name"], lambda: { - (p.name, p.conn_id): transport_buffer_size(p) for p in connected_connections + (p.name,): transport_buffer_size(p) for p in connected_connections }, ) @@ -632,9 +632,9 @@ def transport_kernel_read_buffer_size(protocol, read=True): tcp_transport_kernel_send_buffer = LaterGauge( "synapse_replication_tcp_protocol_transport_kernel_send_buffer", "", - ["name", "conn_id"], + ["name"], lambda: { - (p.name, p.conn_id): transport_kernel_read_buffer_size(p, False) + (p.name,): transport_kernel_read_buffer_size(p, False) for p in connected_connections }, ) @@ -643,9 +643,9 @@ tcp_transport_kernel_send_buffer = LaterGauge( tcp_transport_kernel_read_buffer = LaterGauge( "synapse_replication_tcp_protocol_transport_kernel_read_buffer", "", - ["name", "conn_id"], + ["name"], lambda: { - (p.name, p.conn_id): transport_kernel_read_buffer_size(p, True) + (p.name,): transport_kernel_read_buffer_size(p, True) for p in connected_connections }, ) @@ -654,9 +654,9 @@ tcp_transport_kernel_read_buffer = LaterGauge( tcp_inbound_commands = LaterGauge( "synapse_replication_tcp_protocol_inbound_commands", "", - ["command", "name", "conn_id"], + ["command", "name"], lambda: { - (k[0], p.name, p.conn_id): count + (k[0], p.name,): count for p in connected_connections for k, count in iteritems(p.inbound_commands_counter) }, @@ -665,9 +665,9 @@ tcp_inbound_commands = LaterGauge( tcp_outbound_commands = LaterGauge( "synapse_replication_tcp_protocol_outbound_commands", "", - ["command", "name", "conn_id"], + ["command", "name"], lambda: { - (k[0], p.name, p.conn_id): count + (k[0], p.name,): count for p in connected_connections for k, count in iteritems(p.outbound_commands_counter) }, diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py index 55fe701c5c..c1e626be3f 100644 --- a/synapse/replication/tcp/streams.py +++ b/synapse/replication/tcp/streams.py @@ -196,7 +196,7 @@ class Stream(object): ) if len(rows) >= MAX_EVENTS_BEHIND: - raise Exception("stream %s has fallen behined" % (self.NAME)) + raise Exception("stream %s has fallen behind" % (self.NAME)) else: rows = yield self.update_function( from_token, current_token, diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index ad536ab570..41534b8c2a 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -101,7 +101,7 @@ class UserRegisterServlet(ClientV1RestServlet): nonce = self.hs.get_secrets().token_hex(64) self.nonces[nonce] = int(self.reactor.seconds()) - return (200, {"nonce": nonce.encode('ascii')}) + return (200, {"nonce": nonce}) @defer.inlineCallbacks def on_POST(self, request): @@ -164,7 +164,7 @@ class UserRegisterServlet(ClientV1RestServlet): key=self.hs.config.registration_shared_secret.encode(), digestmod=hashlib.sha1, ) - want_mac.update(nonce) + want_mac.update(nonce.encode('utf8')) want_mac.update(b"\x00") want_mac.update(username) want_mac.update(b"\x00") @@ -173,7 +173,10 @@ class UserRegisterServlet(ClientV1RestServlet): want_mac.update(b"admin" if admin else b"notadmin") want_mac = want_mac.hexdigest() - if not hmac.compare_digest(want_mac, got_mac.encode('ascii')): + if not hmac.compare_digest( + want_mac.encode('ascii'), + got_mac.encode('ascii') + ): raise SynapseError(403, "HMAC incorrect") # Reuse the parts of RegisterRestServlet to reduce code duplication diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 0f3a2e8b51..cd9b3bdbd1 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -45,20 +45,20 @@ class EventStreamRestServlet(ClientV1RestServlet): is_guest = requester.is_guest room_id = None if is_guest: - if "room_id" not in request.args: + if b"room_id" not in request.args: raise SynapseError(400, "Guest users must specify room_id param") - if "room_id" in request.args: - room_id = request.args["room_id"][0] + if b"room_id" in request.args: + room_id = request.args[b"room_id"][0].decode('ascii') pagin_config = PaginationConfig.from_request(request) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS - if "timeout" in request.args: + if b"timeout" in request.args: try: - timeout = int(request.args["timeout"][0]) + timeout = int(request.args[b"timeout"][0]) except ValueError: raise SynapseError(400, "timeout must be in milliseconds.") - as_client_event = "raw" not in request.args + as_client_event = b"raw" not in request.args chunk = yield self.event_stream_handler.get_stream( requester.user.to_string(), diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index fd5f85b53e..3ead75cb77 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -32,7 +32,7 @@ class InitialSyncRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) - as_client_event = "raw" not in request.args + as_client_event = b"raw" not in request.args pagination_config = PaginationConfig.from_request(request) include_archived = parse_boolean(request, "archived", default=False) content = yield self.initial_sync_handler.snapshot_all_rooms( diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index cb85fa1436..0010699d31 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -14,10 +14,9 @@ # limitations under the License. import logging -import urllib import xml.etree.ElementTree as ET -from six.moves.urllib import parse as urlparse +from six.moves import urllib from canonicaljson import json from saml2 import BINDING_HTTP_POST, config @@ -134,7 +133,7 @@ class LoginRestServlet(ClientV1RestServlet): LoginRestServlet.SAML2_TYPE): relay_state = "" if "relay_state" in login_submission: - relay_state = "&RelayState=" + urllib.quote( + relay_state = "&RelayState=" + urllib.parse.quote( login_submission["relay_state"]) result = { "uri": "%s%s" % (self.idp_redirect_url, relay_state) @@ -366,7 +365,7 @@ class SAML2RestServlet(ClientV1RestServlet): (user_id, token) = yield handler.register_saml2(username) # Forward to the RelayState callback along with ava if 'RelayState' in request.args: - request.redirect(urllib.unquote( + request.redirect(urllib.parse.unquote( request.args['RelayState'][0]) + '?status=authenticated&access_token=' + token + '&user_id=' + user_id + '&ava=' + @@ -377,7 +376,7 @@ class SAML2RestServlet(ClientV1RestServlet): "user_id": user_id, "token": token, "ava": saml2_auth.ava})) elif 'RelayState' in request.args: - request.redirect(urllib.unquote( + request.redirect(urllib.parse.unquote( request.args['RelayState'][0]) + '?status=not_authenticated') finish_request(request) @@ -390,21 +389,22 @@ class CasRedirectServlet(ClientV1RestServlet): def __init__(self, hs): super(CasRedirectServlet, self).__init__(hs) - self.cas_server_url = hs.config.cas_server_url - self.cas_service_url = hs.config.cas_service_url + self.cas_server_url = hs.config.cas_server_url.encode('ascii') + self.cas_service_url = hs.config.cas_service_url.encode('ascii') def on_GET(self, request): args = request.args - if "redirectUrl" not in args: + if b"redirectUrl" not in args: return (400, "Redirect URL not specified for CAS auth") - client_redirect_url_param = urllib.urlencode({ - "redirectUrl": args["redirectUrl"][0] - }) - hs_redirect_url = self.cas_service_url + "/_matrix/client/api/v1/login/cas/ticket" - service_param = urllib.urlencode({ - "service": "%s?%s" % (hs_redirect_url, client_redirect_url_param) - }) - request.redirect("%s/login?%s" % (self.cas_server_url, service_param)) + client_redirect_url_param = urllib.parse.urlencode({ + b"redirectUrl": args[b"redirectUrl"][0] + }).encode('ascii') + hs_redirect_url = (self.cas_service_url + + b"/_matrix/client/api/v1/login/cas/ticket") + service_param = urllib.parse.urlencode({ + b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param) + }).encode('ascii') + request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param)) finish_request(request) @@ -422,11 +422,11 @@ class CasTicketServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - client_redirect_url = request.args["redirectUrl"][0] + client_redirect_url = request.args[b"redirectUrl"][0] http_client = self.hs.get_simple_http_client() uri = self.cas_server_url + "/proxyValidate" args = { - "ticket": request.args["ticket"], + "ticket": request.args[b"ticket"][0].decode('ascii'), "service": self.cas_service_url } try: @@ -471,11 +471,11 @@ class CasTicketServlet(ClientV1RestServlet): finish_request(request) def add_login_token_to_redirect_url(self, url, token): - url_parts = list(urlparse.urlparse(url)) - query = dict(urlparse.parse_qsl(url_parts[4])) + url_parts = list(urllib.parse.urlparse(url)) + query = dict(urllib.parse.parse_qsl(url_parts[4])) query.update({"loginToken": token}) - url_parts[4] = urllib.urlencode(query) - return urlparse.urlunparse(url_parts) + url_parts[4] = urllib.parse.urlencode(query).encode('ascii') + return urllib.parse.urlunparse(url_parts) def parse_cas_response(self, cas_response_body): user = None diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 6e95d9bec2..9382b1f124 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -46,7 +46,7 @@ class PushRuleRestServlet(ClientV1RestServlet): try: priority_class = _priority_class_from_spec(spec) except InvalidRuleException as e: - raise SynapseError(400, e.message) + raise SynapseError(400, str(e)) requester = yield self.auth.get_user_by_req(request) @@ -73,7 +73,7 @@ class PushRuleRestServlet(ClientV1RestServlet): content, ) except InvalidRuleException as e: - raise SynapseError(400, e.message) + raise SynapseError(400, str(e)) before = parse_string(request, "before") if before: @@ -95,9 +95,9 @@ class PushRuleRestServlet(ClientV1RestServlet): ) self.notify_user(user_id) except InconsistentRuleException as e: - raise SynapseError(400, e.message) + raise SynapseError(400, str(e)) except RuleNotFoundException as e: - raise SynapseError(400, e.message) + raise SynapseError(400, str(e)) defer.returnValue((200, {})) @@ -142,10 +142,10 @@ class PushRuleRestServlet(ClientV1RestServlet): PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR ) - if path[0] == '': + if path[0] == b'': defer.returnValue((200, rules)) - elif path[0] == 'global': - path = path[1:] + elif path[0] == b'global': + path = [x.decode('ascii') for x in path[1:]] result = _filter_ruleset_with_path(rules['global'], path) defer.returnValue((200, result)) else: @@ -192,10 +192,10 @@ class PushRuleRestServlet(ClientV1RestServlet): def _rule_spec_from_path(path): if len(path) < 2: raise UnrecognizedRequestError() - if path[0] != 'pushrules': + if path[0] != b'pushrules': raise UnrecognizedRequestError() - scope = path[1] + scope = path[1].decode('ascii') path = path[2:] if scope != 'global': raise UnrecognizedRequestError() @@ -203,13 +203,13 @@ def _rule_spec_from_path(path): if len(path) == 0: raise UnrecognizedRequestError() - template = path[0] + template = path[0].decode('ascii') path = path[1:] if len(path) == 0 or len(path[0]) == 0: raise UnrecognizedRequestError() - rule_id = path[0] + rule_id = path[0].decode('ascii') spec = { 'scope': scope, @@ -220,7 +220,7 @@ def _rule_spec_from_path(path): path = path[1:] if len(path) > 0 and len(path[0]) > 0: - spec['attr'] = path[0] + spec['attr'] = path[0].decode('ascii') return spec diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 182a68b1e2..b84f0260f2 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -59,7 +59,7 @@ class PushersRestServlet(ClientV1RestServlet): ] for p in pushers: - for k, v in p.items(): + for k, v in list(p.items()): if k not in allowed_keys: del p[k] @@ -126,7 +126,7 @@ class PushersSetRestServlet(ClientV1RestServlet): profile_tag=content.get('profile_tag', ""), ) except PusherConfigException as pce: - raise SynapseError(400, "Config Error: " + pce.message, + raise SynapseError(400, "Config Error: " + str(pce), errcode=Codes.MISSING_PARAM) self.notifier.on_new_replication_data() diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 976d98387d..663934efd0 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -207,7 +207,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): "sender": requester.user.to_string(), } - if 'ts' in request.args and requester.app_service: + if b'ts' in request.args and requester.app_service: event_dict['origin_server_ts'] = parse_integer(request, "ts", 0) event = yield self.event_creation_hander.create_and_send_nonmember_event( @@ -255,7 +255,9 @@ class JoinRoomAliasServlet(ClientV1RestServlet): if RoomID.is_valid(room_identifier): room_id = room_identifier try: - remote_room_hosts = request.args["server_name"] + remote_room_hosts = [ + x.decode('ascii') for x in request.args[b"server_name"] + ] except Exception: remote_room_hosts = None elif RoomAlias.is_valid(room_identifier): @@ -461,10 +463,10 @@ class RoomMessageListRestServlet(ClientV1RestServlet): pagination_config = PaginationConfig.from_request( request, default_limit=10, ) - as_client_event = "raw" not in request.args - filter_bytes = parse_string(request, "filter") + as_client_event = b"raw" not in request.args + filter_bytes = parse_string(request, b"filter", encoding=None) if filter_bytes: - filter_json = urlparse.unquote(filter_bytes).decode("UTF-8") + filter_json = urlparse.unquote(filter_bytes.decode("UTF-8")) event_filter = Filter(json.loads(filter_json)) else: event_filter = None @@ -560,7 +562,7 @@ class RoomEventContextServlet(ClientV1RestServlet): # picking the API shape for symmetry with /messages filter_bytes = parse_string(request, "filter") if filter_bytes: - filter_json = urlparse.unquote(filter_bytes).decode("UTF-8") + filter_json = urlparse.unquote(filter_bytes) event_filter = Filter(json.loads(filter_json)) else: event_filter = None diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 62f4c3d93e..53da905eea 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -42,7 +42,11 @@ class VoipRestServlet(ClientV1RestServlet): expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000 username = "%d:%s" % (expiry, requester.user.to_string()) - mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1) + mac = hmac.new( + turnSecret.encode(), + msg=username.encode(), + digestmod=hashlib.sha1 + ) # We need to use standard padded base64 encoding here # encode_base64 because we need to add the standard padding to get the # same result as the TURN server. diff --git a/synapse/rest/client/v1_only/register.py b/synapse/rest/client/v1_only/register.py index 5e99cffbcb..dadb376b02 100644 --- a/synapse/rest/client/v1_only/register.py +++ b/synapse/rest/client/v1_only/register.py @@ -23,6 +23,7 @@ from twisted.internet import defer import synapse.util.stringutils as stringutils from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError +from synapse.config.server import is_threepid_reserved from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request from synapse.rest.client.v1.base import ClientV1RestServlet from synapse.types import create_requester @@ -281,12 +282,20 @@ class RegisterRestServlet(ClientV1RestServlet): register_json["user"].encode("utf-8") if "user" in register_json else None ) + threepid = None + if session.get(LoginType.EMAIL_IDENTITY): + threepid = session["threepidCreds"] handler = self.handlers.registration_handler (user_id, token) = yield handler.register( localpart=desired_user_id, - password=password + password=password, + threepid=threepid, ) + # Necessary due to auth checks prior to the threepid being + # written to the db + if is_threepid_reserved(self.hs.config, threepid): + yield self.store.upsert_monthly_active_user(user_id) if session[LoginType.EMAIL_IDENTITY]: logger.debug("Binding emails %s to %s" % ( diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 372648cafd..37b32dd37b 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -53,7 +53,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): if not check_3pid_allowed(self.hs, "email", body['email']): raise SynapseError( - 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + 403, + "Your email domain is not authorized on this server", + Codes.THREEPID_DENIED, ) existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( @@ -89,7 +91,9 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet): if not check_3pid_allowed(self.hs, "msisdn", msisdn): raise SynapseError( - 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + 403, + "Account phone numbers are not authorized on this server", + Codes.THREEPID_DENIED, ) existingUid = yield self.datastore.get_user_id_by_threepid( @@ -241,7 +245,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): if not check_3pid_allowed(self.hs, "email", body['email']): raise SynapseError( - 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + 403, + "Your email domain is not authorized on this server", + Codes.THREEPID_DENIED, ) existingUid = yield self.datastore.get_user_id_by_threepid( @@ -276,7 +282,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): if not check_3pid_allowed(self.hs, "msisdn", msisdn): raise SynapseError( - 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + 403, + "Account phone numbers are not authorized on this server", + Codes.THREEPID_DENIED, ) existingUid = yield self.datastore.get_user_id_by_threepid( diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 2f64155d13..192f52e462 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -26,6 +26,7 @@ import synapse import synapse.types from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError +from synapse.config.server import is_threepid_reserved from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -74,7 +75,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): if not check_3pid_allowed(self.hs, "email", body['email']): raise SynapseError( - 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + 403, + "Your email domain is not authorized to register on this server", + Codes.THREEPID_DENIED, ) existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( @@ -114,7 +117,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): if not check_3pid_allowed(self.hs, "msisdn", msisdn): raise SynapseError( - 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, + 403, + "Phone numbers are not authorized to register on this server", + Codes.THREEPID_DENIED, ) existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( @@ -372,7 +377,9 @@ class RegisterRestServlet(RestServlet): if not check_3pid_allowed(self.hs, medium, address): raise SynapseError( - 403, "Third party identifier is not allowed", + 403, + "Third party identifiers (email/phone numbers)" + + " are not authorized on this server", Codes.THREEPID_DENIED, ) @@ -395,12 +402,21 @@ class RegisterRestServlet(RestServlet): if desired_username is not None: desired_username = desired_username.lower() + threepid = None + if auth_result: + threepid = auth_result.get(LoginType.EMAIL_IDENTITY) + (registered_user_id, _) = yield self.registration_handler.register( localpart=desired_username, password=new_password, guest_access_token=guest_access_token, generate_token=False, + threepid=threepid, ) + # Necessary due to auth checks prior to the threepid being + # written to the db + if is_threepid_reserved(self.hs.config, threepid): + yield self.store.upsert_monthly_active_user(registered_user_id) # remember that we've now registered that user account, and with # what user ID (since the user may not have specified) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 1275baa1ba..0251146722 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection from synapse.events.utils import ( format_event_for_client_v2_without_room_id, + format_event_raw, serialize_event, ) from synapse.handlers.presence import format_user_presence_state @@ -88,7 +89,7 @@ class SyncRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): - if "from" in request.args: + if b"from" in request.args: # /events used to use 'from', but /sync uses 'since'. # Lets be helpful and whine if we see a 'from'. raise SynapseError( @@ -175,17 +176,28 @@ class SyncRestServlet(RestServlet): @staticmethod def encode_response(time_now, sync_result, access_token_id, filter): + if filter.event_format == 'client': + event_formatter = format_event_for_client_v2_without_room_id + elif filter.event_format == 'federation': + event_formatter = format_event_raw + else: + raise Exception("Unknown event format %s" % (filter.event_format, )) + joined = SyncRestServlet.encode_joined( - sync_result.joined, time_now, access_token_id, filter.event_fields + sync_result.joined, time_now, access_token_id, + filter.event_fields, + event_formatter, ) invited = SyncRestServlet.encode_invited( sync_result.invited, time_now, access_token_id, + event_formatter, ) archived = SyncRestServlet.encode_archived( sync_result.archived, time_now, access_token_id, filter.event_fields, + event_formatter, ) return { @@ -228,7 +240,7 @@ class SyncRestServlet(RestServlet): } @staticmethod - def encode_joined(rooms, time_now, token_id, event_fields): + def encode_joined(rooms, time_now, token_id, event_fields, event_formatter): """ Encode the joined rooms in a sync result @@ -240,7 +252,9 @@ class SyncRestServlet(RestServlet): token_id(int): ID of the user's auth token - used for namespacing of transaction IDs event_fields(list<str>): List of event fields to include. If empty, - all fields will be returned. + all fields will be returned. + event_formatter (func[dict]): function to convert from federation format + to client format Returns: dict[str, dict[str, object]]: the joined rooms list, in our response format @@ -248,13 +262,14 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = SyncRestServlet.encode_room( - room, time_now, token_id, only_fields=event_fields + room, time_now, token_id, joined=True, only_fields=event_fields, + event_formatter=event_formatter, ) return joined @staticmethod - def encode_invited(rooms, time_now, token_id): + def encode_invited(rooms, time_now, token_id, event_formatter): """ Encode the invited rooms in a sync result @@ -264,7 +279,9 @@ class SyncRestServlet(RestServlet): time_now(int): current time - used as a baseline for age calculations token_id(int): ID of the user's auth token - used for namespacing - of transaction IDs + of transaction IDs + event_formatter (func[dict]): function to convert from federation format + to client format Returns: dict[str, dict[str, object]]: the invited rooms list, in our @@ -274,7 +291,7 @@ class SyncRestServlet(RestServlet): for room in rooms: invite = serialize_event( room.invite, time_now, token_id=token_id, - event_format=format_event_for_client_v2_without_room_id, + event_format=event_formatter, is_invite=True, ) unsigned = dict(invite.get("unsigned", {})) @@ -288,7 +305,7 @@ class SyncRestServlet(RestServlet): return invited @staticmethod - def encode_archived(rooms, time_now, token_id, event_fields): + def encode_archived(rooms, time_now, token_id, event_fields, event_formatter): """ Encode the archived rooms in a sync result @@ -300,7 +317,9 @@ class SyncRestServlet(RestServlet): token_id(int): ID of the user's auth token - used for namespacing of transaction IDs event_fields(list<str>): List of event fields to include. If empty, - all fields will be returned. + all fields will be returned. + event_formatter (func[dict]): function to convert from federation format + to client format Returns: dict[str, dict[str, object]]: The invited rooms list, in our response format @@ -308,13 +327,18 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = SyncRestServlet.encode_room( - room, time_now, token_id, joined=False, only_fields=event_fields + room, time_now, token_id, joined=False, + only_fields=event_fields, + event_formatter=event_formatter, ) return joined @staticmethod - def encode_room(room, time_now, token_id, joined=True, only_fields=None): + def encode_room( + room, time_now, token_id, joined, + only_fields, event_formatter, + ): """ Args: room (JoinedSyncResult|ArchivedSyncResult): sync result for a @@ -326,14 +350,15 @@ class SyncRestServlet(RestServlet): joined (bool): True if the user is joined to this room - will mean we handle ephemeral events only_fields(list<str>): Optional. The list of event fields to include. + event_formatter (func[dict]): function to convert from federation format + to client format Returns: dict[str, object]: the room, encoded in our response format """ def serialize(event): - # TODO(mjark): Respect formatting requirements in the filter. return serialize_event( event, time_now, token_id=token_id, - event_format=format_event_for_client_v2_without_room_id, + event_format=event_formatter, only_event_fields=only_fields, ) diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index d9d379182e..b9b5d07677 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -79,7 +79,7 @@ class ThirdPartyUserServlet(RestServlet): yield self.auth.get_user_by_req(request, allow_guest=True) fields = request.args - fields.pop("access_token", None) + fields.pop(b"access_token", None) results = yield self.appservice_handler.query_3pe( ThirdPartyEntityKind.USER, protocol, fields @@ -102,7 +102,7 @@ class ThirdPartyLocationServlet(RestServlet): yield self.auth.get_user_by_req(request, allow_guest=True) fields = request.args - fields.pop("access_token", None) + fields.pop(b"access_token", None) results = yield self.appservice_handler.query_3pe( ThirdPartyEntityKind.LOCATION, protocol, fields diff --git a/synapse/rest/key/v1/server_key_resource.py b/synapse/rest/key/v1/server_key_resource.py index b9ee6e1c13..38eb2ee23f 100644 --- a/synapse/rest/key/v1/server_key_resource.py +++ b/synapse/rest/key/v1/server_key_resource.py @@ -88,5 +88,5 @@ class LocalKey(Resource): ) def getChild(self, name, request): - if name == '': + if name == b'': return self diff --git a/synapse/rest/key/v2/__init__.py b/synapse/rest/key/v2/__init__.py index 3491fd2118..cb5abcf826 100644 --- a/synapse/rest/key/v2/__init__.py +++ b/synapse/rest/key/v2/__init__.py @@ -22,5 +22,5 @@ from .remote_key_resource import RemoteKey class KeyApiV2Resource(Resource): def __init__(self, hs): Resource.__init__(self) - self.putChild("server", LocalKey(hs)) - self.putChild("query", RemoteKey(hs)) + self.putChild(b"server", LocalKey(hs)) + self.putChild(b"query", RemoteKey(hs)) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 7d67e4b064..eb8782aa6e 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -103,7 +103,7 @@ class RemoteKey(Resource): def async_render_GET(self, request): if len(request.postpath) == 1: server, = request.postpath - query = {server: {}} + query = {server.decode('ascii'): {}} elif len(request.postpath) == 2: server, key_id = request.postpath minimum_valid_until_ts = parse_integer( @@ -112,11 +112,12 @@ class RemoteKey(Resource): arguments = {} if minimum_valid_until_ts is not None: arguments["minimum_valid_until_ts"] = minimum_valid_until_ts - query = {server: {key_id: arguments}} + query = {server.decode('ascii'): {key_id.decode('ascii'): arguments}} else: raise SynapseError( 404, "Not found %r" % request.postpath, Codes.NOT_FOUND ) + yield self.query_keys(request, query, query_remote_on_cache_miss=True) def render_POST(self, request): @@ -135,6 +136,7 @@ class RemoteKey(Resource): @defer.inlineCallbacks def query_keys(self, request, query, query_remote_on_cache_miss=False): logger.info("Handling query for keys %r", query) + store_queries = [] for server_name, key_ids in query.items(): if ( diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py index f255f2883f..5a426ff2f6 100644 --- a/synapse/rest/media/v0/content_repository.py +++ b/synapse/rest/media/v0/content_repository.py @@ -56,7 +56,7 @@ class ContentRepoResource(resource.Resource): # servers. # TODO: A little crude here, we could do this better. - filename = request.path.split('/')[-1] + filename = request.path.decode('ascii').split('/')[-1] # be paranoid filename = re.sub("[^0-9A-z.-_]", "", filename) @@ -78,7 +78,7 @@ class ContentRepoResource(resource.Resource): # select private. don't bother setting Expires as all our matrix # clients are smart enough to be happy with Cache-Control (right?) request.setHeader( - "Cache-Control", "public,max-age=86400,s-maxage=86400" + b"Cache-Control", b"public,max-age=86400,s-maxage=86400" ) d = FileSender().beginFileTransfer(f, request) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 65f4bd2910..76e479afa3 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -15,9 +15,8 @@ import logging import os -import urllib -from six.moves.urllib import parse as urlparse +from six.moves import urllib from twisted.internet import defer from twisted.protocols.basic import FileSender @@ -35,10 +34,15 @@ def parse_media_id(request): # This allows users to append e.g. /test.png to the URL. Useful for # clients that parse the URL to see content type. server_name, media_id = request.postpath[:2] + + if isinstance(server_name, bytes): + server_name = server_name.decode('utf-8') + media_id = media_id.decode('utf8') + file_name = None if len(request.postpath) > 2: try: - file_name = urlparse.unquote(request.postpath[-1]).decode("utf-8") + file_name = urllib.parse.unquote(request.postpath[-1].decode("utf-8")) except UnicodeDecodeError: pass return server_name, media_id, file_name @@ -93,22 +97,18 @@ def add_file_headers(request, media_type, file_size, upload_name): file_size (int): Size in bytes of the media, if known. upload_name (str): The name of the requested file, if any. """ + def _quote(x): + return urllib.parse.quote(x.encode("utf-8")) + request.setHeader(b"Content-Type", media_type.encode("UTF-8")) if upload_name: if is_ascii(upload_name): - request.setHeader( - b"Content-Disposition", - b"inline; filename=%s" % ( - urllib.quote(upload_name.encode("utf-8")), - ), - ) + disposition = ("inline; filename=%s" % (_quote(upload_name),)).encode("ascii") else: - request.setHeader( - b"Content-Disposition", - b"inline; filename*=utf-8''%s" % ( - urllib.quote(upload_name.encode("utf-8")), - ), - ) + disposition = ( + "inline; filename*=utf-8''%s" % (_quote(upload_name),)).encode("ascii") + + request.setHeader(b"Content-Disposition", disposition) # cache for at least a day. # XXX: we might want to turn this off for data we don't want to diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index fbfa85f74f..f911b120b1 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -47,12 +47,13 @@ class DownloadResource(Resource): def _async_render_GET(self, request): set_cors_headers(request) request.setHeader( - "Content-Security-Policy", - "default-src 'none';" - " script-src 'none';" - " plugin-types application/pdf;" - " style-src 'unsafe-inline';" - " object-src 'self';" + b"Content-Security-Policy", + b"default-src 'none';" + b" script-src 'none';" + b" plugin-types application/pdf;" + b" style-src 'unsafe-inline';" + b" media-src 'self';" + b" object-src 'self';" ) server_name, media_id, name = parse_media_id(request) if server_name == self.server_name: diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 241c972070..a828ff4438 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -20,7 +20,7 @@ import logging import os import shutil -from six import iteritems +from six import PY3, iteritems from six.moves.urllib import parse as urlparse import twisted.internet.error @@ -397,13 +397,13 @@ class MediaRepository(object): yield finish() - media_type = headers["Content-Type"][0] + media_type = headers[b"Content-Type"][0].decode('ascii') time_now_ms = self.clock.time_msec() - content_disposition = headers.get("Content-Disposition", None) + content_disposition = headers.get(b"Content-Disposition", None) if content_disposition: - _, params = cgi.parse_header(content_disposition[0],) + _, params = cgi.parse_header(content_disposition[0].decode('ascii'),) upload_name = None # First check if there is a valid UTF-8 filename @@ -419,9 +419,13 @@ class MediaRepository(object): upload_name = upload_name_ascii if upload_name: - upload_name = urlparse.unquote(upload_name) + if PY3: + upload_name = urlparse.unquote(upload_name) + else: + upload_name = urlparse.unquote(upload_name.encode('ascii')) try: - upload_name = upload_name.decode("utf-8") + if isinstance(upload_name, bytes): + upload_name = upload_name.decode("utf-8") except UnicodeDecodeError: upload_name = None else: @@ -755,14 +759,15 @@ class MediaRepositoryResource(Resource): Resource.__init__(self) media_repo = hs.get_media_repository() - self.putChild("upload", UploadResource(hs, media_repo)) - self.putChild("download", DownloadResource(hs, media_repo)) - self.putChild("thumbnail", ThumbnailResource( + + self.putChild(b"upload", UploadResource(hs, media_repo)) + self.putChild(b"download", DownloadResource(hs, media_repo)) + self.putChild(b"thumbnail", ThumbnailResource( hs, media_repo, media_repo.media_storage, )) - self.putChild("identicon", IdenticonResource()) + self.putChild(b"identicon", IdenticonResource()) if hs.config.url_preview_enabled: - self.putChild("preview_url", PreviewUrlResource( + self.putChild(b"preview_url", PreviewUrlResource( hs, media_repo, media_repo.media_storage, )) - self.putChild("config", MediaConfigResource(hs)) + self.putChild(b"config", MediaConfigResource(hs)) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 778ef97337..af01040a38 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -79,7 +79,6 @@ class PreviewUrlResource(Resource): # don't spider URLs more often than once an hour expiry_ms=60 * 60 * 1000, ) - self._cache.start() self._cleaner_loop = self.clock.looping_call( self._start_expire_url_cache_data, 10 * 1000, @@ -261,7 +260,7 @@ class PreviewUrlResource(Resource): logger.debug("Calculated OG for %s as %s" % (url, og)) - jsonog = json.dumps(og) + jsonog = json.dumps(og).encode('utf8') # store OG in history-aware DB cache yield self.store.store_url_cache( @@ -301,20 +300,20 @@ class PreviewUrlResource(Resource): logger.warn("Error downloading %s: %r", url, e) raise SynapseError( 500, "Failed to download content: %s" % ( - traceback.format_exception_only(sys.exc_type, e), + traceback.format_exception_only(sys.exc_info()[0], e), ), Codes.UNKNOWN, ) yield finish() try: - if "Content-Type" in headers: - media_type = headers["Content-Type"][0] + if b"Content-Type" in headers: + media_type = headers[b"Content-Type"][0].decode('ascii') else: media_type = "application/octet-stream" time_now_ms = self.clock.time_msec() - content_disposition = headers.get("Content-Disposition", None) + content_disposition = headers.get(b"Content-Disposition", None) if content_disposition: _, params = cgi.parse_header(content_disposition[0],) download_name = None diff --git a/synapse/server.py b/synapse/server.py index a795643add..3e9d3d8256 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -19,6 +19,7 @@ # partial one for unit test mocking. # Imports required for the default HomeServer() implementation +import abc import logging from twisted.enterprise import adbapi @@ -82,7 +83,6 @@ from synapse.server_notices.server_notices_manager import ServerNoticesManager from synapse.server_notices.server_notices_sender import ServerNoticesSender from synapse.server_notices.worker_server_notices_sender import WorkerServerNoticesSender from synapse.state import StateHandler, StateResolutionHandler -from synapse.storage import DataStore from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor @@ -112,6 +112,8 @@ class HomeServer(object): config (synapse.config.homeserver.HomeserverConfig): """ + __metaclass__ = abc.ABCMeta + DEPENDENCIES = [ 'http_client', 'db_pool', @@ -174,6 +176,11 @@ class HomeServer(object): 'room_context_handler', ] + # This is overridden in derived application classes + # (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be + # instantiated during setup() for future return by get_datastore() + DATASTORE_CLASS = abc.abstractproperty() + def __init__(self, hostname, reactor=None, **kwargs): """ Args: @@ -190,13 +197,16 @@ class HomeServer(object): self.distributor = Distributor() self.ratelimiter = Ratelimiter() + self.datastore = None + # Other kwargs are explicit dependencies for depname in kwargs: setattr(self, depname, kwargs[depname]) def setup(self): logger.info("Setting up.") - self.datastore = DataStore(self.get_db_conn(), self) + with self.get_db_conn() as conn: + self.datastore = self.DATASTORE_CLASS(conn, self) logger.info("Finished setting up.") def get_reactor(self): diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 893b9001cd..af15cba0ee 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -118,7 +118,7 @@ class ResourceLimitsServerNotices(object): 'body': event_content, 'msgtype': ServerNoticeMsgType, 'server_notice_type': ServerNoticeLimitReached, - 'admin_uri': self._config.admin_uri, + 'admin_contact': self._config.admin_contact, 'limit_type': event_limit_type } event = yield self._server_notices_manager.send_notice( @@ -146,11 +146,10 @@ class ResourceLimitsServerNotices(object): user_id(str): the user in question room_id(str): the server notices room for that user """ - tags = yield self._store.get_tags_for_user(user_id) - server_notices_tags = tags.get(room_id) + tags = yield self._store.get_tags_for_room(user_id, room_id) need_to_set_tag = True - if server_notices_tags: - if server_notices_tags.get(SERVER_NOTICE_ROOM_TAG): + if tags: + if SERVER_NOTICE_ROOM_TAG in tags: # tag already present, nothing to do here need_to_set_tag = False if need_to_set_tag: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index b34970e4d1..b22495c1f9 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -95,10 +95,6 @@ class StateHandler(object): self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() - def start_caching(self): - # TODO: remove this shim - self._state_resolution_handler.start_caching() - @defer.inlineCallbacks def get_current_state(self, room_id, event_type=None, state_key="", latest_event_ids=None): @@ -385,6 +381,7 @@ class StateHandler(object): ev_ids, get_prev_content=False, check_redacted=False, ) + @defer.inlineCallbacks def resolve_events(self, room_version, state_sets, event): logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) @@ -401,15 +398,17 @@ class StateHandler(object): } with Measure(self.clock, "state._resolve_events"): - new_state = resolve_events_with_state_map( - room_version, state_set_ids, state_map, + new_state = yield resolve_events_with_factory( + room_version, state_set_ids, + event_map=state_map, + state_map_factory=self._state_map_factory ) new_state = { key: state_map[ev_id] for key, ev_id in iteritems(new_state) } - return new_state + defer.returnValue(new_state) class StateResolutionHandler(object): @@ -425,9 +424,6 @@ class StateResolutionHandler(object): self._state_cache = None self.resolve_linearizer = Linearizer(name="state_resolve_lock") - def start_caching(self): - logger.debug("start_caching") - self._state_cache = ExpiringCache( cache_name="state_cache", clock=self.clock, @@ -437,8 +433,6 @@ class StateResolutionHandler(object): reset_expiry_on_get=True, ) - self._state_cache.start() - @defer.inlineCallbacks @log_function def resolve_state_groups( @@ -589,31 +583,6 @@ def _make_state_cache_entry( ) -def resolve_events_with_state_map(room_version, state_sets, state_map): - """ - Args: - room_version(str): Version of the room - state_sets(list): List of dicts of (type, state_key) -> event_id, - which are the different state groups to resolve. - state_map(dict): a dict from event_id to event, for all events in - state_sets. - - Returns - dict[(str, str), str]: - a map from (type, state_key) to event_id. - """ - if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,): - return v1.resolve_events_with_state_map( - state_sets, state_map, - ) - else: - # This should only happen if we added a version but forgot to add it to - # the list above. - raise Exception( - "No state resolution algorithm defined for version %r" % (room_version,) - ) - - def resolve_events_with_factory(room_version, state_sets, event_map, state_map_factory): """ Args: diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 3a1f7054a1..7a7157b352 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -30,34 +30,6 @@ logger = logging.getLogger(__name__) POWER_KEY = (EventTypes.PowerLevels, "") -def resolve_events_with_state_map(state_sets, state_map): - """ - Args: - state_sets(list): List of dicts of (type, state_key) -> event_id, - which are the different state groups to resolve. - state_map(dict): a dict from event_id to event, for all events in - state_sets. - - Returns - dict[(str, str), str]: - a map from (type, state_key) to event_id. - """ - if len(state_sets) == 1: - return state_sets[0] - - unconflicted_state, conflicted_state = _seperate( - state_sets, - ) - - auth_events = _create_auth_events_from_maps( - unconflicted_state, conflicted_state, state_map - ) - - return _resolve_with_state( - unconflicted_state, conflicted_state, auth_events, state_map - ) - - @defer.inlineCallbacks def resolve_events_with_factory(state_sets, event_map, state_map_factory): """ @@ -93,10 +65,15 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory): for event_ids in itervalues(conflicted_state) for event_id in event_ids ) + needed_event_count = len(needed_events) if event_map is not None: needed_events -= set(iterkeys(event_map)) - logger.info("Asking for %d conflicted events", len(needed_events)) + logger.info( + "Asking for %d/%d conflicted events", + len(needed_events), + needed_event_count, + ) # dict[str, FrozenEvent]: a map from state event id to event. Only includes # the state events which are in conflict (and those in event_map) @@ -113,11 +90,16 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory): ) new_needed_events = set(itervalues(auth_events)) + new_needed_event_count = len(new_needed_events) new_needed_events -= needed_events if event_map is not None: new_needed_events -= set(iterkeys(event_map)) - logger.info("Asking for %d auth events", len(new_needed_events)) + logger.info( + "Asking for %d/%d auth events", + len(new_needed_events), + new_needed_event_count, + ) state_map_new = yield state_map_factory(new_needed_events) state_map.update(state_map_new) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 08dffd774f..be61147b9b 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -17,9 +17,10 @@ import sys import threading import time -from six import iteritems, iterkeys, itervalues +from six import PY2, iteritems, iterkeys, itervalues from six.moves import intern, range +from canonicaljson import json from prometheus_client import Histogram from twisted.internet import defer @@ -1216,3 +1217,32 @@ class _RollbackButIsFineException(Exception): something went wrong. """ pass + + +def db_to_json(db_content): + """ + Take some data from a database row and return a JSON-decoded object. + + Args: + db_content (memoryview|buffer|bytes|bytearray|unicode) + """ + # psycopg2 on Python 3 returns memoryview objects, which we need to + # cast to bytes to decode + if isinstance(db_content, memoryview): + db_content = db_content.tobytes() + + # psycopg2 on Python 2 returns buffer objects, which we need to cast to + # bytes to decode + if PY2 and isinstance(db_content, buffer): + db_content = bytes(db_content) + + # Decode it to a Unicode string before feeding it to json.loads, so we + # consistenty get a Unicode-containing object out. + if isinstance(db_content, (bytes, bytearray)): + db_content = db_content.decode('utf8') + + try: + return json.loads(db_content) + except Exception: + logging.warning("Tried to decode '%r' as JSON and failed", db_content) + raise diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 8fc678fa67..9ad17b7c25 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -119,21 +119,25 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): for entry in iteritems(to_update): (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry - self._simple_upsert_txn( - txn, - table="user_ips", - keyvalues={ - "user_id": user_id, - "access_token": access_token, - "ip": ip, - "user_agent": user_agent, - "device_id": device_id, - }, - values={ - "last_seen": last_seen, - }, - lock=False, - ) + try: + self._simple_upsert_txn( + txn, + table="user_ips", + keyvalues={ + "user_id": user_id, + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "device_id": device_id, + }, + values={ + "last_seen": last_seen, + }, + lock=False, + ) + except Exception as e: + # Failed to upsert, log and continue + logger.error("Failed to insert client IP %r: %r", entry, e) @defer.inlineCallbacks def get_last_client_ip_by_device(self, user_id, device_id): diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index 73646da025..e06b0bc56d 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -169,7 +169,7 @@ class DeviceInboxStore(BackgroundUpdateStore): local_by_user_then_device = {} for user_id, messages_by_device in messages_by_user_then_device.items(): messages_json_for_user = {} - devices = messages_by_device.keys() + devices = list(messages_by_device.keys()) if len(devices) == 1 and devices[0] == "*": # Handle wildcard device_ids. sql = ( diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index c0943ecf91..d10ff9e4b9 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -24,7 +24,7 @@ from synapse.api.errors import StoreError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList -from ._base import Cache, SQLBaseStore +from ._base import Cache, SQLBaseStore, db_to_json logger = logging.getLogger(__name__) @@ -411,7 +411,7 @@ class DeviceStore(SQLBaseStore): if device is not None: key_json = device.get("key_json", None) if key_json: - result["keys"] = json.loads(key_json) + result["keys"] = db_to_json(key_json) device_display_name = device.get("device_display_name", None) if device_display_name: result["device_display_name"] = device_display_name @@ -466,7 +466,7 @@ class DeviceStore(SQLBaseStore): retcol="content", desc="_get_cached_user_device", ) - defer.returnValue(json.loads(content)) + defer.returnValue(db_to_json(content)) @cachedInlineCallbacks() def _get_cached_devices_for_user(self, user_id): @@ -479,7 +479,7 @@ class DeviceStore(SQLBaseStore): desc="_get_cached_devices_for_user", ) defer.returnValue({ - device["device_id"]: json.loads(device["content"]) + device["device_id"]: db_to_json(device["content"]) for device in devices }) @@ -511,7 +511,7 @@ class DeviceStore(SQLBaseStore): key_json = device.get("key_json", None) if key_json: - result["keys"] = json.loads(key_json) + result["keys"] = db_to_json(key_json) device_display_name = device.get("device_display_name", None) if device_display_name: result["device_display_name"] = device_display_name diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 808194236a..cfb687cb53 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -75,7 +75,6 @@ class DirectoryWorkerStore(SQLBaseStore): }, retcol="creator", desc="get_room_alias_creator", - allow_none=True ) @cached(max_entries=5000) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 523b4360c3..1f1721e820 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -14,13 +14,13 @@ # limitations under the License. from six import iteritems -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.util.caches.descriptors import cached -from ._base import SQLBaseStore +from ._base import SQLBaseStore, db_to_json class EndToEndKeyStore(SQLBaseStore): @@ -90,7 +90,7 @@ class EndToEndKeyStore(SQLBaseStore): for user_id, device_keys in iteritems(results): for device_id, device_info in iteritems(device_keys): - device_info["keys"] = json.loads(device_info.pop("key_json")) + device_info["keys"] = db_to_json(device_info.pop("key_json")) defer.returnValue(results) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 8a0386c1a4..42225f8a2a 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -41,13 +41,18 @@ class PostgresEngine(object): db_conn.set_isolation_level( self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ ) + + # Set the bytea output to escape, vs the default of hex + cursor = db_conn.cursor() + cursor.execute("SET bytea_output TO escape") + # Asynchronous commit, don't wait for the server to call fsync before # ending the transaction. # https://www.postgresql.org/docs/current/static/wal-async-commit.html if not self.synchronous_commit: - cursor = db_conn.cursor() cursor.execute("SET synchronous_commit TO OFF") - cursor.close() + + cursor.close() def is_deadlock(self, error): if isinstance(error, self.module.DatabaseError): diff --git a/synapse/storage/events.py b/synapse/storage/events.py index f39c8c8461..03cedf3a75 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -19,7 +19,7 @@ import logging from collections import OrderedDict, deque, namedtuple from functools import wraps -from six import iteritems +from six import iteritems, text_type from six.moves import range from canonicaljson import json @@ -38,6 +38,7 @@ from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.event_federation import EventFederationStore from synapse.storage.events_worker import EventsWorkerStore from synapse.types import RoomStreamToken, get_domain_from_id +from synapse.util import batch_iter from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.frozenutils import frozendict_json_encoder @@ -386,12 +387,10 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore ) for room_id, ev_ctx_rm in iteritems(events_by_room): - # Work out new extremities by recursively adding and removing - # the new events. latest_event_ids = yield self.get_latest_event_ids_in_room( room_id ) - new_latest_event_ids = yield self._calculate_new_extremeties( + new_latest_event_ids = yield self._calculate_new_extremities( room_id, ev_ctx_rm, latest_event_ids ) @@ -400,6 +399,12 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore # No change in extremities, so no change in state continue + # there should always be at least one forward extremity. + # (except during the initial persistence of the send_join + # results, in which case there will be no existing + # extremities, so we'll `continue` above and skip this bit.) + assert new_latest_event_ids, "No forward extremities left!" + new_forward_extremeties[room_id] = new_latest_event_ids len_1 = ( @@ -517,44 +522,79 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore ) @defer.inlineCallbacks - def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids): - """Calculates the new forward extremeties for a room given events to + def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids): + """Calculates the new forward extremities for a room given events to persist. Assumes that we are only persisting events for one room at a time. """ - new_latest_event_ids = set(latest_event_ids) - # First, add all the new events to the list - new_latest_event_ids.update( - event.event_id for event, ctx in event_contexts + + # we're only interested in new events which aren't outliers and which aren't + # being rejected. + new_events = [ + event for event, ctx in event_contexts if not event.internal_metadata.is_outlier() and not ctx.rejected + ] + + # start with the existing forward extremities + result = set(latest_event_ids) + + # add all the new events to the list + result.update( + event.event_id for event in new_events ) - # Now remove all events that are referenced by the to-be-added events - new_latest_event_ids.difference_update( + + # Now remove all events which are prev_events of any of the new events + result.difference_update( e_id - for event, ctx in event_contexts + for event in new_events for e_id, _ in event.prev_events - if not event.internal_metadata.is_outlier() and not ctx.rejected ) - # And finally remove any events that are referenced by previously added - # events. - rows = yield self._simple_select_many_batch( - table="event_edges", - column="prev_event_id", - iterable=list(new_latest_event_ids), - retcols=["prev_event_id"], - keyvalues={ - "is_state": False, - }, - desc="_calculate_new_extremeties", - ) + # Finally, remove any events which are prev_events of any existing events. + existing_prevs = yield self._get_events_which_are_prevs(result) + result.difference_update(existing_prevs) - new_latest_event_ids.difference_update( - row["prev_event_id"] for row in rows - ) + defer.returnValue(result) + + @defer.inlineCallbacks + def _get_events_which_are_prevs(self, event_ids): + """Filter the supplied list of event_ids to get those which are prev_events of + existing (non-outlier/rejected) events. + + Args: + event_ids (Iterable[str]): event ids to filter + + Returns: + Deferred[List[str]]: filtered event ids + """ + results = [] + + def _get_events(txn, batch): + sql = """ + SELECT prev_event_id + FROM event_edges + INNER JOIN events USING (event_id) + LEFT JOIN rejections USING (event_id) + WHERE + prev_event_id IN (%s) + AND NOT events.outlier + AND rejections.event_id IS NULL + """ % ( + ",".join("?" for _ in batch), + ) + + txn.execute(sql, batch) + results.extend(r[0] for r in txn) + + for chunk in batch_iter(event_ids, 100): + yield self.runInteraction( + "_get_events_which_are_prevs", + _get_events, + chunk, + ) - defer.returnValue(new_latest_event_ids) + defer.returnValue(results) @defer.inlineCallbacks def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids, @@ -586,10 +626,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore the new current state is only returned if we've already calculated it. """ - - if not new_latest_event_ids: - return - # map from state_group to ((type, key) -> event_id) state map state_groups_map = {} @@ -930,6 +966,10 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore ) self._invalidate_cache_and_stream( + txn, self.get_room_summary, (room_id,) + ) + + self._invalidate_cache_and_stream( txn, self.get_current_state_ids, (room_id,) ) @@ -1220,7 +1260,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore "sender": event.sender, "contains_url": ( "url" in event.content - and isinstance(event.content["url"], basestring) + and isinstance(event.content["url"], text_type) ), } for event, _ in events_and_contexts @@ -1529,7 +1569,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore contains_url = "url" in content if contains_url: - contains_url &= isinstance(content["url"], basestring) + contains_url &= isinstance(content["url"], text_type) except (KeyError, AttributeError): # If the event is missing a necessary field then # skip over it. @@ -1886,20 +1926,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore ")" ) - # create an index on should_delete because later we'll be looking for - # the should_delete / shouldn't_delete subsets - txn.execute( - "CREATE INDEX events_to_purge_should_delete" - " ON events_to_purge(should_delete)", - ) - - # We do joins against events_to_purge for e.g. calculating state - # groups to purge, etc., so lets make an index. - txn.execute( - "CREATE INDEX events_to_purge_id" - " ON events_to_purge(event_id)", - ) - # First ensure that we're not about to delete all the forward extremeties txn.execute( "SELECT e.event_id, e.depth FROM events as e " @@ -1910,9 +1936,9 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore (room_id,) ) rows = txn.fetchall() - max_depth = max(row[0] for row in rows) + max_depth = max(row[1] for row in rows) - if max_depth <= token.topological: + if max_depth < token.topological: # We need to ensure we don't delete all the events from the database # otherwise we wouldn't be able to send any events (due to not # having any backwards extremeties) @@ -1926,19 +1952,45 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore should_delete_params = () if not delete_local_events: should_delete_expr += " AND event_id NOT LIKE ?" - should_delete_params += ("%:" + self.hs.hostname, ) + + # We include the parameter twice since we use the expression twice + should_delete_params += ( + "%:" + self.hs.hostname, + "%:" + self.hs.hostname, + ) should_delete_params += (room_id, token.topological) + # Note that we insert events that are outliers and aren't going to be + # deleted, as nothing will happen to them. txn.execute( "INSERT INTO events_to_purge" " SELECT event_id, %s" " FROM events AS e LEFT JOIN state_events USING (event_id)" - " WHERE e.room_id = ? AND topological_ordering < ?" % ( + " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?" + % ( + should_delete_expr, should_delete_expr, ), should_delete_params, ) + + # We create the indices *after* insertion as that's a lot faster. + + # create an index on should_delete because later we'll be looking for + # the should_delete / shouldn't_delete subsets + txn.execute( + "CREATE INDEX events_to_purge_should_delete" + " ON events_to_purge(should_delete)", + ) + + # We do joins against events_to_purge for e.g. calculating state + # groups to purge, etc., so lets make an index. + txn.execute( + "CREATE INDEX events_to_purge_id" + " ON events_to_purge(event_id)", + ) + txn.execute( "SELECT event_id, should_delete FROM events_to_purge" ) diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 59822178ff..a8326f5296 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import itertools import logging from collections import namedtuple @@ -265,7 +266,7 @@ class EventsWorkerStore(SQLBaseStore): """ with Measure(self._clock, "_fetch_event_list"): try: - event_id_lists = zip(*event_list)[0] + event_id_lists = list(zip(*event_list))[0] event_ids = [ item for sublist in event_id_lists for item in sublist ] @@ -299,14 +300,14 @@ class EventsWorkerStore(SQLBaseStore): logger.exception("do_fetch") # We only want to resolve deferreds from the main thread - def fire(evs): + def fire(evs, exc): for _, d in evs: if not d.called: with PreserveLoggingContext(): - d.errback(e) + d.errback(exc) with PreserveLoggingContext(): - self.hs.get_reactor().callFromThread(fire, event_list) + self.hs.get_reactor().callFromThread(fire, event_list, e) @defer.inlineCallbacks def _enqueue_events(self, events, check_redacted=True, allow_rejected=False): diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py index 2d5896c5b4..6ddcc909bf 100644 --- a/synapse/storage/filtering.py +++ b/synapse/storage/filtering.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.api.errors import Codes, SynapseError from synapse.util.caches.descriptors import cachedInlineCallbacks -from ._base import SQLBaseStore +from ._base import SQLBaseStore, db_to_json class FilteringStore(SQLBaseStore): @@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore): desc="get_user_filter", ) - defer.returnValue(json.loads(bytes(def_json).decode("utf-8"))) + defer.returnValue(db_to_json(def_json)) def add_user_filter(self, user_localpart, user_filter): def_json = encode_canonical_json(user_filter) diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index f547977600..a1331c1a61 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -134,6 +134,7 @@ class KeyStore(SQLBaseStore): """ key_id = "%s:%s" % (verify_key.alg, verify_key.version) + # XXX fix this to not need a lock (#3819) def _txn(txn): self._simple_upsert_txn( txn, diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index d178f5c5ba..0fe8c8e24c 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -36,7 +36,6 @@ class MonthlyActiveUsersStore(SQLBaseStore): @defer.inlineCallbacks def initialise_reserved_users(self, threepids): - # TODO Why can't I do this in init? store = self.hs.get_datastore() reserved_user_list = [] @@ -148,6 +147,23 @@ class MonthlyActiveUsersStore(SQLBaseStore): return self.runInteraction("count_users", _count_users) @defer.inlineCallbacks + def get_registered_reserved_users_count(self): + """Of the reserved threepids defined in config, how many are associated + with registered users? + + Returns: + Defered[int]: Number of real reserved users + """ + count = 0 + for tp in self.hs.config.mau_limits_reserved_threepids: + user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + tp["medium"], tp["address"] + ) + if user_id: + count = count + 1 + defer.returnValue(count) + + @defer.inlineCallbacks def upsert_monthly_active_user(self, user_id): """ Updates or inserts monthly active user member @@ -156,6 +172,10 @@ class MonthlyActiveUsersStore(SQLBaseStore): Deferred[bool]: True if a new entry was created, False if an existing one was updated. """ + # Am consciously deciding to lock the table on the basis that is ought + # never be a big table and alternative approaches (batching multiple + # upserts into a single txn) introduced a lot of extra complexity. + # See https://github.com/matrix-org/synapse/issues/3854 for more is_insert = yield self._simple_upsert( desc="upsert_monthly_active_user", table="monthly_active_users", @@ -165,7 +185,6 @@ class MonthlyActiveUsersStore(SQLBaseStore): values={ "timestamp": int(self._clock.time_msec()), }, - lock=False, ) if is_insert: self.user_last_seen_monthly_active.invalidate((user_id,)) @@ -200,10 +219,14 @@ class MonthlyActiveUsersStore(SQLBaseStore): Args: user_id(str): the user_id to query """ + if self.hs.config.limit_usage_by_mau: + # Trial users and guests should not be included as part of MAU group + is_guest = yield self.is_guest(user_id) + if is_guest: + return is_trial = yield self.is_trial_user(user_id) if is_trial: - # we don't track trial users in the MAU table. return last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id) diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 8443bd4c1b..c7987bfcdd 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -15,7 +15,8 @@ # limitations under the License. import logging -import types + +import six from canonicaljson import encode_canonical_json, json @@ -27,6 +28,11 @@ from ._base import SQLBaseStore logger = logging.getLogger(__name__) +if six.PY2: + db_binary_type = buffer +else: + db_binary_type = memoryview + class PusherWorkerStore(SQLBaseStore): def _decode_pushers_rows(self, rows): @@ -34,18 +40,18 @@ class PusherWorkerStore(SQLBaseStore): dataJson = r['data'] r['data'] = None try: - if isinstance(dataJson, types.BufferType): + if isinstance(dataJson, db_binary_type): dataJson = str(dataJson).decode("UTF8") r['data'] = json.loads(dataJson) except Exception as e: logger.warn( "Invalid JSON in data for pusher %d: %s, %s", - r['id'], dataJson, e.message, + r['id'], dataJson, e.args[0], ) pass - if isinstance(r['pushkey'], types.BufferType): + if isinstance(r['pushkey'], db_binary_type): r['pushkey'] = str(r['pushkey']).decode("UTF8") return rows diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 9b4e6d6aa8..0707f9a86a 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -51,6 +51,12 @@ ProfileInfo = namedtuple( "ProfileInfo", ("avatar_url", "display_name") ) +# "members" points to a truncated list of (user_id, event_id) tuples for users of +# a given membership type, suitable for use in calculating heroes for a room. +# "count" points to the total numberr of users of a given membership type. +MemberSummary = namedtuple( + "MemberSummary", ("members", "count") +) _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" @@ -82,6 +88,65 @@ class RoomMemberWorkerStore(EventsWorkerStore): return [to_ascii(r[0]) for r in txn] return self.runInteraction("get_users_in_room", f) + @cached(max_entries=100000) + def get_room_summary(self, room_id): + """ Get the details of a room roughly suitable for use by the room + summary extension to /sync. Useful when lazy loading room members. + Args: + room_id (str): The room ID to query + Returns: + Deferred[dict[str, MemberSummary]: + dict of membership states, pointing to a MemberSummary named tuple. + """ + + def _get_room_summary_txn(txn): + # first get counts. + # We do this all in one transaction to keep the cache small. + # FIXME: get rid of this when we have room_stats + sql = """ + SELECT count(*), m.membership FROM room_memberships as m + INNER JOIN current_state_events as c + ON m.event_id = c.event_id + AND m.room_id = c.room_id + AND m.user_id = c.state_key + WHERE c.type = 'm.room.member' AND c.room_id = ? + GROUP BY m.membership + """ + + txn.execute(sql, (room_id,)) + res = {} + for count, membership in txn: + summary = res.setdefault(to_ascii(membership), MemberSummary([], count)) + + # we order by membership and then fairly arbitrarily by event_id so + # heroes are consistent + sql = """ + SELECT m.user_id, m.membership, m.event_id + FROM room_memberships as m + INNER JOIN current_state_events as c + ON m.event_id = c.event_id + AND m.room_id = c.room_id + AND m.user_id = c.state_key + WHERE c.type = 'm.room.member' AND c.room_id = ? + ORDER BY + CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC, + m.event_id ASC + LIMIT ? + """ + + # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user. + txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6)) + for user_id, membership, event_id in txn: + summary = res[to_ascii(membership)] + # we will always have a summary for this membership type at this + # point given the summary currently contains the counts. + members = summary.members + members.append((to_ascii(user_id), to_ascii(event_id))) + + return res + + return self.runInteraction("get_room_summary", _get_room_summary_txn) + @cached() def get_invited_rooms_for_user(self, user_id): """ Get all the rooms the user is invited to diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 4b971efdba..3f4cbd61c4 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -255,7 +255,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) @defer.inlineCallbacks - def get_state_groups_ids(self, room_id, event_ids): + def get_state_groups_ids(self, _room_id, event_ids): + """Get the event IDs of all the state for the state groups for the given events + + Args: + _room_id (str): id of the room for these events + event_ids (iterable[str]): ids of the events + + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ if not event_ids: defer.returnValue({}) @@ -270,7 +280,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): @defer.inlineCallbacks def get_state_ids_for_group(self, state_group): - """Get the state IDs for the given state group + """Get the event IDs of all the state in the given state group Args: state_group (int) @@ -286,7 +296,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def get_state_groups(self, room_id, event_ids): """ Get the state groups for the given list of event_ids - The return value is a dict mapping group names to lists of events. + Returns: + Deferred[dict[int, list[EventBase]]]: + dict of state_group_id -> list of state events. """ if not event_ids: defer.returnValue({}) @@ -324,7 +336,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): member events (if True), or to exclude member events (if False) Returns: - dictionary state_group -> (dict of (type, state_key) -> event id) + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) """ results = {} @@ -732,8 +746,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): If None, `types` filtering is applied to all events. Returns: - Deferred[dict[int, dict[(type, state_key), EventBase]]] - a dictionary mapping from state group to state dictionary. + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) """ if types is not None: non_member_types = [t for t in types if t[0] != EventTypes.Member] @@ -788,8 +802,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): If None, `types` filtering is applied to all events. Returns: - Deferred[dict[int, dict[(type, state_key), EventBase]]] - a dictionary mapping from state group to state dictionary. + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) """ if types: types = frozenset(types) diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 428e7fa36e..a3032cdce9 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -18,14 +18,14 @@ from collections import namedtuple import six -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.caches.descriptors import cached +from synapse.util.caches.expiringcache import ExpiringCache -from ._base import SQLBaseStore +from ._base import SQLBaseStore, db_to_json # py2 sqlite has buffer hardcoded as only binary type, so we must use it, # despite being deprecated and removed in favor of memoryview @@ -50,6 +50,8 @@ _UpdateTransactionRow = namedtuple( ) ) +SENTINEL = object() + class TransactionStore(SQLBaseStore): """A collection of queries for handling PDUs. @@ -60,6 +62,12 @@ class TransactionStore(SQLBaseStore): self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) + self._destination_retry_cache = ExpiringCache( + cache_name="get_destination_retry_timings", + clock=self._clock, + expiry_ms=5 * 60 * 1000, + ) + def get_received_txn_response(self, transaction_id, origin): """For an incoming transaction from a given origin, check if we have already responded to it. If so, return the response code and response @@ -95,7 +103,8 @@ class TransactionStore(SQLBaseStore): ) if result and result["response_code"]: - return result["response_code"], json.loads(str(result["response_json"])) + return result["response_code"], db_to_json(result["response_json"]) + else: return None @@ -155,7 +164,7 @@ class TransactionStore(SQLBaseStore): """ pass - @cached(max_entries=10000) + @defer.inlineCallbacks def get_destination_retry_timings(self, destination): """Gets the current retry timings (if any) for a given destination. @@ -166,10 +175,20 @@ class TransactionStore(SQLBaseStore): None if not retrying Otherwise a dict for the retry scheme """ - return self.runInteraction( + + result = self._destination_retry_cache.get(destination, SENTINEL) + if result is not SENTINEL: + defer.returnValue(result) + + result = yield self.runInteraction( "get_destination_retry_timings", self._get_destination_retry_timings, destination) + # We don't hugely care about race conditions between getting and + # invalidating the cache, since we time out fairly quickly anyway. + self._destination_retry_cache[destination] = result + defer.returnValue(result) + def _get_destination_retry_timings(self, txn, destination): result = self._simple_select_one_txn( txn, @@ -197,8 +216,7 @@ class TransactionStore(SQLBaseStore): retry_interval (int) - how long until next retry in ms """ - # XXX: we could chose to not bother persisting this if our cache thinks - # this is a NOOP + self._destination_retry_cache.pop(destination, None) return self.runInteraction( "set_destination_retry_timings", self._set_destination_retry_timings, @@ -211,10 +229,6 @@ class TransactionStore(SQLBaseStore): retry_last_ts, retry_interval): self.database_engine.lock_table(txn, "destinations") - self._invalidate_cache_and_stream( - txn, self.get_destination_retry_timings, (destination,) - ) - # We need to be careful here as the data may have changed from under us # due to a worker setting the timings. diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 9b3f2f4b96..ec7b2c9672 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -374,29 +374,25 @@ class ReadWriteLock(object): defer.returnValue(_ctx_manager()) -class DeferredTimeoutError(Exception): - """ - This error is raised by default when a L{Deferred} times out. - """ - +def _cancelled_to_timed_out_error(value, timeout): + if isinstance(value, failure.Failure): + value.trap(CancelledError) + raise defer.TimeoutError(timeout, "Deferred") + return value -def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None): - """ - Add a timeout to a deferred by scheduling it to be cancelled after - timeout seconds. - This is essentially a backport of deferred.addTimeout, which was introduced - in twisted 16.5. +def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): + """The in built twisted `Deferred.addTimeout` fails to time out deferreds + that have a canceller that throws exceptions. This method creates a new + deferred that wraps and times out the given deferred, correctly handling + the case where the given deferred's canceller throws. - If the deferred gets timed out, it errbacks with a DeferredTimeoutError, - unless a cancelable function was passed to its initialization or unless - a different on_timeout_cancel callable is provided. + NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred Args: - deferred (defer.Deferred): deferred to be timed out - timeout (Number): seconds to time out after - reactor (twisted.internet.reactor): the Twisted reactor to use - + deferred (Deferred) + timeout (float): Timeout in seconds + reactor (twisted.internet.reactor): The twisted reactor to use on_timeout_cancel (callable): A callable which is called immediately after the deferred times out, and not if this deferred is otherwise cancelled before the timeout. @@ -406,13 +402,26 @@ def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None): the timeout. The default callable (if none is provided) will translate a - CancelledError Failure into a DeferredTimeoutError. + CancelledError Failure into a defer.TimeoutError. + + Returns: + Deferred """ + + new_d = defer.Deferred() + timed_out = [False] def time_it_out(): timed_out[0] = True - deferred.cancel() + + try: + deferred.cancel() + except: # noqa: E722, if we throw any exception it'll break time outs + logger.exception("Canceller failed during timeout") + + if not new_d.called: + new_d.errback(defer.TimeoutError(timeout, "Deferred")) delayed_call = reactor.callLater(timeout, time_it_out) @@ -432,9 +441,14 @@ def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None): deferred.addBoth(cancel_timeout) + def success_cb(val): + if not new_d.called: + new_d.callback(val) -def _cancelled_to_timed_out_error(value, timeout): - if isinstance(value, failure.Failure): - value.trap(CancelledError) - raise DeferredTimeoutError(timeout, "Deferred") - return value + def failure_cb(val): + if not new_d.called: + new_d.errback(val) + + deferred.addCallbacks(success_cb, failure_cb) + + return new_d diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 7b065b195e..f37d5bec08 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os import six @@ -20,6 +21,8 @@ from six.moves import intern from prometheus_client.core import REGISTRY, Gauge, GaugeMetricFamily +logger = logging.getLogger(__name__) + CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5)) @@ -76,16 +79,20 @@ def register_cache(cache_type, cache_name, cache): return [] def collect(self): - if cache_type == "response_cache": - response_cache_size.labels(cache_name).set(len(cache)) - response_cache_hits.labels(cache_name).set(self.hits) - response_cache_evicted.labels(cache_name).set(self.evicted_size) - response_cache_total.labels(cache_name).set(self.hits + self.misses) - else: - cache_size.labels(cache_name).set(len(cache)) - cache_hits.labels(cache_name).set(self.hits) - cache_evicted.labels(cache_name).set(self.evicted_size) - cache_total.labels(cache_name).set(self.hits + self.misses) + try: + if cache_type == "response_cache": + response_cache_size.labels(cache_name).set(len(cache)) + response_cache_hits.labels(cache_name).set(self.hits) + response_cache_evicted.labels(cache_name).set(self.evicted_size) + response_cache_total.labels(cache_name).set(self.hits + self.misses) + else: + cache_size.labels(cache_name).set(len(cache)) + cache_hits.labels(cache_name).set(self.hits) + cache_evicted.labels(cache_name).set(self.evicted_size) + cache_total.labels(cache_name).set(self.hits + self.misses) + except Exception as e: + logger.warn("Error calculating metrics for %s: %s", cache_name, e) + raise yield GaugeMetricFamily("__unused", "") diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index ce85b2ae11..f369780277 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -16,12 +16,17 @@ import logging from collections import OrderedDict +from six import iteritems, itervalues + from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.caches import register_cache logger = logging.getLogger(__name__) +SENTINEL = object() + + class ExpiringCache(object): def __init__(self, cache_name, clock, max_len=0, expiry_ms=0, reset_expiry_on_get=False, iterable=False): @@ -54,11 +59,8 @@ class ExpiringCache(object): self.iterable = iterable - self._size_estimate = 0 - self.metrics = register_cache("expiring", cache_name, self) - def start(self): if not self._expiry_ms: # Don't bother starting the loop if things never expire return @@ -75,16 +77,11 @@ class ExpiringCache(object): now = self._clock.time_msec() self._cache[key] = _CacheEntry(now, value) - if self.iterable: - self._size_estimate += len(value) - # Evict if there are now too many items while self._max_len and len(self) > self._max_len: _key, value = self._cache.popitem(last=False) if self.iterable: - removed_len = len(value.value) - self.metrics.inc_evictions(removed_len) - self._size_estimate -= removed_len + self.metrics.inc_evictions(len(value.value)) else: self.metrics.inc_evictions() @@ -101,6 +98,21 @@ class ExpiringCache(object): return entry.value + def pop(self, key, default=SENTINEL): + """Removes and returns the value with the given key from the cache. + + If the key isn't in the cache then `default` will be returned if + specified, otherwise `KeyError` will get raised. + + Identical functionality to `dict.pop(..)`. + """ + + value = self._cache.pop(key, default) + if value is SENTINEL: + raise KeyError(key) + + return value + def __contains__(self, key): return key in self._cache @@ -128,14 +140,16 @@ class ExpiringCache(object): keys_to_delete = set() - for key, cache_entry in self._cache.items(): + for key, cache_entry in iteritems(self._cache): if now - cache_entry.time > self._expiry_ms: keys_to_delete.add(key) for k in keys_to_delete: value = self._cache.pop(k) if self.iterable: - self._size_estimate -= len(value.value) + self.metrics.inc_evictions(len(value.value)) + else: + self.metrics.inc_evictions() logger.debug( "[%s] _prune_cache before: %d, after len: %d", @@ -144,12 +158,14 @@ class ExpiringCache(object): def __len__(self): if self.iterable: - return self._size_estimate + return sum(len(entry.value) for entry in itervalues(self._cache)) else: return len(self._cache) class _CacheEntry(object): + __slots__ = ["time", "value"] + def __init__(self, time, value): self.time = time self.value = value diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index a0c2d37610..89224b26cc 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -200,7 +200,7 @@ class LoggingContext(object): sentinel = Sentinel() - def __init__(self, name=None, parent_context=None): + def __init__(self, name=None, parent_context=None, request=None): self.previous_context = LoggingContext.current_context() self.name = name @@ -218,6 +218,13 @@ class LoggingContext(object): self.parent_context = parent_context + if self.parent_context is not None: + self.parent_context.copy_to(self) + + if request is not None: + # the request param overrides the request from the parent context + self.request = request + def __str__(self): return "%s@%x" % (self.name, id(self)) @@ -256,9 +263,6 @@ class LoggingContext(object): ) self.alive = True - if self.parent_context is not None: - self.parent_context.copy_to(self) - return self def __exit__(self, type, value, traceback): @@ -439,6 +443,35 @@ class PreserveLoggingContext(object): ) +def nested_logging_context(suffix, parent_context=None): + """Creates a new logging context as a child of another. + + The nested logging context will have a 'request' made up of the parent context's + request, plus the given suffix. + + CPU/db usage stats will be added to the parent context's on exit. + + Normal usage looks like: + + with nested_logging_context(suffix): + # ... do stuff + + Args: + suffix (str): suffix to add to the parent context's 'request'. + parent_context (LoggingContext|None): parent context. Will use the current context + if None. + + Returns: + LoggingContext: new logging context. + """ + if parent_context is None: + parent_context = LoggingContext.current_context() + return LoggingContext( + parent_context=parent_context, + request=parent_context.request + "-" + suffix, + ) + + def preserve_fn(f): """Function decorator which wraps the function with run_in_background""" def g(*args, **kwargs): diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index 14be3c7396..8d0f2a8918 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -19,22 +19,40 @@ from twisted.conch.ssh.keys import Key from twisted.cred import checkers, portal PUBLIC_KEY = ( - "ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAGEArzJx8OYOnJmzf4tfBEvLi8DVPrJ3/c9k2I/Az" - "64fxjHf9imyRJbixtQhlH9lfNjUIx+4LmrJH5QNRsFporcHDKOTwTTYLh5KmRpslkYHRivcJS" - "kbh/C+BR3utDS555mV" + "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDHhGATaW4KhE23+7nrH4jFx3yLq9OjaEs5" + "XALqeK+7385NlLja3DE/DO9mGhnd9+bAy39EKT3sTV6+WXQ4yD0TvEEyUEMtjWkSEm6U32+C" + "DaS3TW/vPBUMeJQwq+Ydcif1UlnpXrDDTamD0AU9VaEvHq+3HAkipqn0TGpKON6aqk4vauDx" + "oXSsV5TXBVrxP/y7HpMOpU4GUWsaaacBTKKNnUaQB4UflvydaPJUuwdaCUJGTMjbhWrjVfK+" + "jslseSPxU6XvrkZMyCr4znxvuDxjMk1RGIdO7v+rbBMLEgqtSMNqJbYeVCnj2CFgc3fcTcld" + "X2uOJDrJb/WRlHulthCh" ) PRIVATE_KEY = """-----BEGIN RSA PRIVATE KEY----- -MIIByAIBAAJhAK8ycfDmDpyZs3+LXwRLy4vA1T6yd/3PZNiPwM+uH8Yx3/YpskSW -4sbUIZR/ZXzY1CMfuC5qyR+UDUbBaaK3Bwyjk8E02C4eSpkabJZGB0Yr3CUpG4fw -vgUd7rQ0ueeZlQIBIwJgbh+1VZfr7WftK5lu7MHtqE1S1vPWZQYE3+VUn8yJADyb -Z4fsZaCrzW9lkIqXkE3GIY+ojdhZhkO1gbG0118sIgphwSWKRxK0mvh6ERxKqIt1 -xJEJO74EykXZV4oNJ8sjAjEA3J9r2ZghVhGN6V8DnQrTk24Td0E8hU8AcP0FVP+8 -PQm/g/aXf2QQkQT+omdHVEJrAjEAy0pL0EBH6EVS98evDCBtQw22OZT52qXlAwZ2 -gyTriKFVoqjeEjt3SZKKqXHSApP/AjBLpF99zcJJZRq2abgYlf9lv1chkrWqDHUu -DZttmYJeEfiFBBavVYIF1dOlZT0G8jMCMBc7sOSZodFnAiryP+Qg9otSBjJ3bQML -pSTqy7c3a2AScC/YyOwkDaICHnnD3XyjMwIxALRzl0tQEKMXs6hH8ToUdlLROCrP -EhQ0wahUTCk1gKA4uPD6TMTChavbh4K63OvbKg== +MIIEpQIBAAKCAQEAx4RgE2luCoRNt/u56x+Ixcd8i6vTo2hLOVwC6nivu9/OTZS4 +2twxPwzvZhoZ3ffmwMt/RCk97E1evll0OMg9E7xBMlBDLY1pEhJulN9vgg2kt01v +7zwVDHiUMKvmHXIn9VJZ6V6ww02pg9AFPVWhLx6vtxwJIqap9ExqSjjemqpOL2rg +8aF0rFeU1wVa8T/8ux6TDqVOBlFrGmmnAUyijZ1GkAeFH5b8nWjyVLsHWglCRkzI +24Vq41Xyvo7JbHkj8VOl765GTMgq+M58b7g8YzJNURiHTu7/q2wTCxIKrUjDaiW2 +HlQp49ghYHN33E3JXV9rjiQ6yW/1kZR7pbYQoQIDAQABAoIBAQC8KJ0q8Wzzwh5B +esa1dQHZ8+4DEsL/Amae66VcVwD0X3cCN1W2IZ7X5W0Ij2kBqr8V51RYhcR+S+Ek +BtzSiBUBvbKGrqcMGKaUgomDIMzai99hd0gvCCyZnEW1OQhFkNkaRNXCfqiZJ27M +fqvSUiU2eOwh9fCvmxoA6Of8o3FbzcJ+1GMcobWRllDtLmj6lgVbDzuA+0jC5daB +9Tj1pBzu3wn3ufxiS+gBnJ+7NcXH3E73lqCcPa2ufbZ1haxfiGCnRIhFXuQDgxFX +vKdEfDgtvas6r1ahGbc+b/q8E8fZT7cABuIU4yfOORK+MhpyWbvoyyzuVGKj3PKt +KSPJu5CZAoGBAOkoJfAVyYteqKcmGTanGqQnAY43CaYf6GdSPX/jg+JmKZg0zqMC +jWZUtPb93i+jnOInbrnuHOiHAxI8wmhEPed28H2lC/LU8PzlqFkZXKFZ4vLOhhRB +/HeHCFIDosPFlohWi3b+GAjD7sXgnIuGmnXWe2ea/TS3yersifDEoKKjAoGBANsQ +gJX2cJv1c3jhdgcs8vAt5zIOKcCLTOr/QPmVf/kxjNgndswcKHwsxE/voTO9q+TF +v/6yCSTxAdjuKz1oIYWgi/dZo82bBKWxNRpgrGviU3/zwxiHlyIXUhzQu78q3VS/ +7S1XVbc7qMV++XkYKHPVD+nVG/gGzFxumX7MLXfrAoGBAJit9cn2OnjNj9uFE1W6 +r7N254ndeLAUjPe73xH0RtTm2a4WRopwjW/JYIetTuYbWgyujc+robqTTuuOZjAp +H/CG7o0Ym251CypQqaFO/l2aowclPp/dZhpPjp9GSjuxFBZLtiBB3DNBOwbRQzIK +/vLTdRQvZkgzYkI4i0vjNt3JAoGBANP8HSKBLymMlShlrSx2b8TB9tc2Y2riohVJ +2ttqs0M2kt/dGJWdrgOz4mikL+983Olt/0P9juHDoxEEMK2kpcPEv40lnmBpYU7h +s8yJvnBLvJe2EJYdJ8AipyAhUX1FgpbvfxmASP8eaUxsegeXvBWTGWojAoS6N2o+ +0KSl+l3vAoGAFqm0gO9f/Q1Se60YQd4l2PZeMnJFv0slpgHHUwegmd6wJhOD7zJ1 +CkZcXwiv7Nog7AI9qKJEUXLjoqL+vJskBzSOqU3tcd670YQMi1aXSXJqYE202K7o +EddTrx3TNpr1D5m/f+6mnXWrc8u9y1+GNx9yz889xMjIBTBI9KqaaOs= -----END RSA PRIVATE KEY-----""" diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 97f1267380..4b4ac5f6c7 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -20,6 +20,7 @@ from prometheus_client import Counter from twisted.internet import defer +from synapse.metrics import InFlightGauge from synapse.util.logcontext import LoggingContext logger = logging.getLogger(__name__) @@ -45,6 +46,13 @@ block_db_txn_duration = Counter( block_db_sched_duration = Counter( "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"]) +# Tracks the number of blocks currently active +in_flight = InFlightGauge( + "synapse_util_metrics_block_in_flight", "", + labels=["block_name"], + sub_metrics=["real_time_max", "real_time_sum"], +) + def measure_func(name): def wrapper(func): @@ -82,10 +90,14 @@ class Measure(object): self.start_usage = self.start_context.get_resource_usage() + in_flight.register((self.name,), self._update_in_flight) + def __exit__(self, exc_type, exc_val, exc_tb): if isinstance(exc_type, Exception) or not self.start_context: return + in_flight.unregister((self.name,), self._update_in_flight) + duration = self.clock.time() - self.start block_counter.labels(self.name).inc() @@ -120,3 +132,13 @@ class Measure(object): if self.created_context: self.start_context.__exit__(exc_type, exc_val, exc_tb) + + def _update_in_flight(self, metrics): + """Gets called when processing in flight metrics + """ + duration = self.clock.time() - self.start + + metrics.real_time_max = max(metrics.real_time_max, duration) + metrics.real_time_sum += duration + + # TODO: Add other in flight metrics. diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 8a3a06fd74..26cce7d197 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -188,7 +188,7 @@ class RetryDestinationLimiter(object): else: self.retry_interval = self.min_retry_interval - logger.debug( + logger.info( "Connection to %s was unsuccessful (%s(%s)); backoff now %i", self.destination, exc_type, exc_val, self.retry_interval ) diff --git a/synapse/visibility.py b/synapse/visibility.py index d4680863d3..c64ad2144c 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -324,14 +324,13 @@ def filter_events_for_server(store, server_name, events): # server's domain. # # event_to_state_ids contains lots of duplicates, so it turns out to be - # cheaper to build a complete set of unique - # ((type, state_key), event_id) tuples, and then filter out the ones we - # don't want. + # cheaper to build a complete event_id => (type, state_key) dict, and then + # filter out the ones we don't want # - state_key_to_event_id_set = { - e + event_id_to_state_key = { + event_id: key for key_to_eid in itervalues(event_to_state_ids) - for e in key_to_eid.items() + for key, event_id in iteritems(key_to_eid) } def include(typ, state_key): @@ -346,7 +345,7 @@ def filter_events_for_server(store, server_name, events): event_map = yield store.get_events([ e_id - for key, e_id in state_key_to_event_id_set + for e_id, key in iteritems(event_id_to_state_key) if include(key[0], key[1]) ]) |