diff --git a/synapse/__init__.py b/synapse/__init__.py
index e62901b761..43c5821ade 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -17,4 +17,14 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.33.3"
+try:
+ from twisted.internet import protocol
+ from twisted.internet.protocol import Factory
+ from twisted.names.dns import DNSDatagramProtocol
+ protocol.Factory.noisy = False
+ Factory.noisy = False
+ DNSDatagramProtocol.noisy = False
+except ImportError:
+ pass
+
+__version__ = "0.33.6"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 8d2aa5870a..34382e4e3c 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -26,6 +26,7 @@ import synapse.types
from synapse import event_auth
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.errors import AuthError, Codes, ResourceLimitError
+from synapse.config.server import is_threepid_reserved
from synapse.types import UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches.lrucache import LruCache
@@ -775,13 +776,19 @@ class Auth(object):
)
@defer.inlineCallbacks
- def check_auth_blocking(self, user_id=None):
+ def check_auth_blocking(self, user_id=None, threepid=None):
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag
Args:
user_id(str|None): If present, checks for presence against existing
MAU cohort
+
+ threepid(dict|None): If present, checks for presence against configured
+ reserved threepid. Used in cases where the user is trying register
+ with a MAU blocked server, normally they would be rejected but their
+ threepid is on the reserved list. user_id and
+ threepid should never be set at the same time.
"""
# Never fail an auth check for the server notices users
@@ -793,10 +800,12 @@ class Auth(object):
raise ResourceLimitError(
403, self.hs.config.hs_disabled_message,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
- admin_uri=self.hs.config.admin_uri,
+ admin_contact=self.hs.config.admin_contact,
limit_type=self.hs.config.hs_disabled_limit_type
)
if self.hs.config.limit_usage_by_mau is True:
+ assert not (user_id and threepid)
+
# If the user is already part of the MAU cohort or a trial user
if user_id:
timestamp = yield self.store.user_last_seen_monthly_active(user_id)
@@ -806,13 +815,17 @@ class Auth(object):
is_trial = yield self.store.is_trial_user(user_id)
if is_trial:
return
+ elif threepid:
+ # If the user does not exist yet, but is signing up with a
+ # reserved threepid then pass auth check
+ if is_threepid_reserved(self.hs.config, threepid):
+ return
# Else if there is no room in the MAU bucket, bail
current_mau = yield self.store.get_monthly_active_count()
if current_mau >= self.hs.config.max_mau_value:
raise ResourceLimitError(
403, "Monthly Active User Limit Exceeded",
-
- admin_uri=self.hs.config.admin_uri,
+ admin_contact=self.hs.config.admin_contact,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
limit_type="monthly_active_user"
)
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 140dbfe8b8..4124469442 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -240,10 +240,10 @@ class ResourceLimitError(SynapseError):
def __init__(
self, code, msg,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
- admin_uri=None,
+ admin_contact=None,
limit_type=None,
):
- self.admin_uri = admin_uri
+ self.admin_contact = admin_contact
self.limit_type = limit_type
super(ResourceLimitError, self).__init__(code, msg, errcode=errcode)
@@ -251,7 +251,7 @@ class ResourceLimitError(SynapseError):
return cs_error(
self.msg,
self.errcode,
- admin_uri=self.admin_uri,
+ admin_contact=self.admin_contact,
limit_type=self.limit_type
)
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 186831e118..eed8c67e6a 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -226,7 +226,7 @@ class Filtering(object):
jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA,
format_checker=FormatChecker())
except jsonschema.ValidationError as e:
- raise SynapseError(400, e.message)
+ raise SynapseError(400, str(e))
class FilterCollection(object):
@@ -251,6 +251,7 @@ class FilterCollection(object):
"include_leave", False
)
self.event_fields = filter_json.get("event_fields", [])
+ self.event_format = filter_json.get("event_format", "client")
def __repr__(self):
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index 71347912f1..6d9f1ca0ef 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -64,7 +64,7 @@ class ConsentURIBuilder(object):
"""
mac = hmac.new(
key=self._hmac_secret,
- msg=user_id,
+ msg=user_id.encode('ascii'),
digestmod=sha256,
).hexdigest()
consent_uri = "%s_matrix/consent?%s" % (
diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py
index 3b6b9368b8..c3afcc573b 100644
--- a/synapse/app/__init__.py
+++ b/synapse/app/__init__.py
@@ -24,7 +24,7 @@ try:
python_dependencies.check_requirements()
except python_dependencies.MissingRequirementError as e:
message = "\n".join([
- "Missing Requirement: %s" % (e.message,),
+ "Missing Requirement: %s" % (str(e),),
"To install run:",
" pip install --upgrade --force \"%s\"" % (e.dependency,),
"",
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index 3348a8ec6d..8559e141af 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -51,10 +51,7 @@ class AppserviceSlaveStore(
class AppserviceServer(HomeServer):
- def setup(self):
- logger.info("Setting up.")
- self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
- logger.info("Finished setting up.")
+ DATASTORE_CLASS = AppserviceSlaveStore
def _listen_http(self, listener_config):
port = listener_config["port"]
@@ -139,7 +136,7 @@ def start(config_options):
"Synapse appservice", config_options
)
except ConfigError as e:
- sys.stderr.write("\n" + e.message + "\n")
+ sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.appservice"
@@ -175,7 +172,6 @@ def start(config_options):
def start():
ps.get_datastore().start_profiling()
- ps.get_state_handler().start_caching()
reactor.callWhenRunning(start)
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index ab79a45646..76aed8c60a 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -74,10 +74,7 @@ class ClientReaderSlavedStore(
class ClientReaderServer(HomeServer):
- def setup(self):
- logger.info("Setting up.")
- self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)
- logger.info("Finished setting up.")
+ DATASTORE_CLASS = ClientReaderSlavedStore
def _listen_http(self, listener_config):
port = listener_config["port"]
@@ -156,7 +153,7 @@ def start(config_options):
"Synapse client reader", config_options
)
except ConfigError as e:
- sys.stderr.write("\n" + e.message + "\n")
+ sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.client_reader"
@@ -184,7 +181,6 @@ def start(config_options):
ss.start_listening(config.worker_listeners)
def start():
- ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling()
reactor.callWhenRunning(start)
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
index a34c89fa99..9060ab14f6 100644
--- a/synapse/app/event_creator.py
+++ b/synapse/app/event_creator.py
@@ -90,10 +90,7 @@ class EventCreatorSlavedStore(
class EventCreatorServer(HomeServer):
- def setup(self):
- logger.info("Setting up.")
- self.datastore = EventCreatorSlavedStore(self.get_db_conn(), self)
- logger.info("Finished setting up.")
+ DATASTORE_CLASS = EventCreatorSlavedStore
def _listen_http(self, listener_config):
port = listener_config["port"]
@@ -172,7 +169,7 @@ def start(config_options):
"Synapse event creator", config_options
)
except ConfigError as e:
- sys.stderr.write("\n" + e.message + "\n")
+ sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.event_creator"
@@ -202,7 +199,6 @@ def start(config_options):
ss.start_listening(config.worker_listeners)
def start():
- ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling()
reactor.callWhenRunning(start)
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index 7d8105778d..228a297fb8 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -72,10 +72,7 @@ class FederationReaderSlavedStore(
class FederationReaderServer(HomeServer):
- def setup(self):
- logger.info("Setting up.")
- self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
- logger.info("Finished setting up.")
+ DATASTORE_CLASS = FederationReaderSlavedStore
def _listen_http(self, listener_config):
port = listener_config["port"]
@@ -143,7 +140,7 @@ def start(config_options):
"Synapse federation reader", config_options
)
except ConfigError as e:
- sys.stderr.write("\n" + e.message + "\n")
+ sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.federation_reader"
@@ -171,7 +168,6 @@ def start(config_options):
ss.start_listening(config.worker_listeners)
def start():
- ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling()
reactor.callWhenRunning(start)
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index d59007099b..e9a99d76e1 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -78,10 +78,7 @@ class FederationSenderSlaveStore(
class FederationSenderServer(HomeServer):
- def setup(self):
- logger.info("Setting up.")
- self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)
- logger.info("Finished setting up.")
+ DATASTORE_CLASS = FederationSenderSlaveStore
def _listen_http(self, listener_config):
port = listener_config["port"]
@@ -163,7 +160,7 @@ def start(config_options):
"Synapse federation sender", config_options
)
except ConfigError as e:
- sys.stderr.write("\n" + e.message + "\n")
+ sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.federation_sender"
@@ -204,7 +201,6 @@ def start(config_options):
def start():
ps.get_datastore().start_profiling()
- ps.get_state_handler().start_caching()
reactor.callWhenRunning(start)
_base.start_worker_reactor("synapse-federation-sender", config)
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index 8d484c1cd4..fc4b25de1c 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -148,10 +148,7 @@ class FrontendProxySlavedStore(
class FrontendProxyServer(HomeServer):
- def setup(self):
- logger.info("Setting up.")
- self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self)
- logger.info("Finished setting up.")
+ DATASTORE_CLASS = FrontendProxySlavedStore
def _listen_http(self, listener_config):
port = listener_config["port"]
@@ -231,7 +228,7 @@ def start(config_options):
"Synapse frontend proxy", config_options
)
except ConfigError as e:
- sys.stderr.write("\n" + e.message + "\n")
+ sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.frontend_proxy"
@@ -261,7 +258,6 @@ def start(config_options):
ss.start_listening(config.worker_listeners)
def start():
- ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling()
reactor.callWhenRunning(start)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 005921dcf7..e3f0d99a3f 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -62,7 +62,7 @@ from synapse.rest.key.v1.server_key_resource import LocalKey
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.server import HomeServer
-from synapse.storage import are_all_users_on_domain
+from synapse.storage import DataStore, are_all_users_on_domain
from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
from synapse.util.caches import CACHE_SIZE_FACTOR
@@ -111,6 +111,8 @@ def build_resource_for_web_client(hs):
class SynapseHomeServer(HomeServer):
+ DATASTORE_CLASS = DataStore
+
def _listener_http(self, config, listener_config):
port = listener_config["port"]
bind_addresses = listener_config["bind_addresses"]
@@ -299,12 +301,16 @@ class SynapseHomeServer(HomeServer):
try:
database_engine.check_database(db_conn.cursor())
except IncorrectDatabaseSetup as e:
- quit_with_error(e.message)
+ quit_with_error(str(e))
# Gauges to expose monthly active user control metrics
current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit")
+registered_reserved_users_mau_gauge = Gauge(
+ "synapse_admin_mau:registered_reserved_users",
+ "Registered users with reserved threepids"
+)
def setup(config_options):
@@ -322,7 +328,7 @@ def setup(config_options):
config_options,
)
except ConfigError as e:
- sys.stderr.write("\n" + e.message + "\n")
+ sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
if not config:
@@ -356,13 +362,13 @@ def setup(config_options):
logger.info("Preparing database: %s...", config.database_config['name'])
try:
- db_conn = hs.get_db_conn(run_new_connection=False)
- prepare_database(db_conn, database_engine, config=config)
- database_engine.on_new_connection(db_conn)
+ with hs.get_db_conn(run_new_connection=False) as db_conn:
+ prepare_database(db_conn, database_engine, config=config)
+ database_engine.on_new_connection(db_conn)
- hs.run_startup_checks(db_conn, database_engine)
+ hs.run_startup_checks(db_conn, database_engine)
- db_conn.commit()
+ db_conn.commit()
except UpgradeDatabaseException:
sys.stderr.write(
"\nFailed to upgrade database.\n"
@@ -378,10 +384,8 @@ def setup(config_options):
def start():
hs.get_pusherpool().start()
- hs.get_state_handler().start_caching()
hs.get_datastore().start_profiling()
hs.get_datastore().start_doing_background_updates()
- hs.get_federation_client().start_get_pdu_cache()
reactor.callWhenRunning(start)
@@ -451,6 +455,10 @@ def run(hs):
stats["homeserver"] = hs.config.server_name
stats["timestamp"] = now
stats["uptime_seconds"] = uptime
+ version = sys.version_info
+ stats["python_version"] = "{}.{}.{}".format(
+ version.major, version.minor, version.micro
+ )
stats["total_users"] = yield hs.get_datastore().count_all_users()
total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users()
@@ -529,10 +537,14 @@ def run(hs):
@defer.inlineCallbacks
def generate_monthly_active_users():
- count = 0
+ current_mau_count = 0
+ reserved_count = 0
+ store = hs.get_datastore()
if hs.config.limit_usage_by_mau:
- count = yield hs.get_datastore().get_monthly_active_count()
- current_mau_gauge.set(float(count))
+ current_mau_count = yield store.get_monthly_active_count()
+ reserved_count = yield store.get_registered_reserved_users_count()
+ current_mau_gauge.set(float(current_mau_count))
+ registered_reserved_users_mau_gauge.set(float(reserved_count))
max_mau_gauge.set(float(hs.config.max_mau_value))
hs.get_datastore().initialise_reserved_users(
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index fd1f6cbf7e..acc0487adc 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -60,10 +60,7 @@ class MediaRepositorySlavedStore(
class MediaRepositoryServer(HomeServer):
- def setup(self):
- logger.info("Setting up.")
- self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
- logger.info("Finished setting up.")
+ DATASTORE_CLASS = MediaRepositorySlavedStore
def _listen_http(self, listener_config):
port = listener_config["port"]
@@ -136,7 +133,7 @@ def start(config_options):
"Synapse media repository", config_options
)
except ConfigError as e:
- sys.stderr.write("\n" + e.message + "\n")
+ sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.media_repository"
@@ -171,7 +168,6 @@ def start(config_options):
ss.start_listening(config.worker_listeners)
def start():
- ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling()
reactor.callWhenRunning(start)
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index a4fc7e91fa..630dcda478 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -78,10 +78,7 @@ class PusherSlaveStore(
class PusherServer(HomeServer):
- def setup(self):
- logger.info("Setting up.")
- self.datastore = PusherSlaveStore(self.get_db_conn(), self)
- logger.info("Finished setting up.")
+ DATASTORE_CLASS = PusherSlaveStore
def remove_pusher(self, app_id, push_key, user_id):
self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
@@ -194,7 +191,7 @@ def start(config_options):
"Synapse pusher", config_options
)
except ConfigError as e:
- sys.stderr.write("\n" + e.message + "\n")
+ sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.pusher"
@@ -231,7 +228,6 @@ def start(config_options):
def start():
ps.get_pusherpool().start()
ps.get_datastore().start_profiling()
- ps.get_state_handler().start_caching()
reactor.callWhenRunning(start)
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 27e1998660..9a7fc6ee9d 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -249,10 +249,7 @@ class SynchrotronApplicationService(object):
class SynchrotronServer(HomeServer):
- def setup(self):
- logger.info("Setting up.")
- self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
- logger.info("Finished setting up.")
+ DATASTORE_CLASS = SynchrotronSlavedStore
def _listen_http(self, listener_config):
port = listener_config["port"]
@@ -413,7 +410,7 @@ def start(config_options):
"Synapse synchrotron", config_options
)
except ConfigError as e:
- sys.stderr.write("\n" + e.message + "\n")
+ sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.synchrotron"
@@ -438,7 +435,6 @@ def start(config_options):
def start():
ss.get_datastore().start_profiling()
- ss.get_state_handler().start_caching()
reactor.callWhenRunning(start)
diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py
deleted file mode 100755
index d658f967ba..0000000000
--- a/synapse/app/synctl.py
+++ /dev/null
@@ -1,284 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import argparse
-import collections
-import errno
-import glob
-import os
-import os.path
-import signal
-import subprocess
-import sys
-import time
-
-from six import iteritems
-
-import yaml
-
-SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
-
-GREEN = "\x1b[1;32m"
-YELLOW = "\x1b[1;33m"
-RED = "\x1b[1;31m"
-NORMAL = "\x1b[m"
-
-
-def pid_running(pid):
- try:
- os.kill(pid, 0)
- return True
- except OSError as err:
- if err.errno == errno.EPERM:
- return True
- return False
-
-
-def write(message, colour=NORMAL, stream=sys.stdout):
- if colour == NORMAL:
- stream.write(message + "\n")
- else:
- stream.write(colour + message + NORMAL + "\n")
-
-
-def abort(message, colour=RED, stream=sys.stderr):
- write(message, colour, stream)
- sys.exit(1)
-
-
-def start(configfile):
- write("Starting ...")
- args = SYNAPSE
- args.extend(["--daemonize", "-c", configfile])
-
- try:
- subprocess.check_call(args)
- write("started synapse.app.homeserver(%r)" %
- (configfile,), colour=GREEN)
- except subprocess.CalledProcessError as e:
- write(
- "error starting (exit code: %d); see above for logs" % e.returncode,
- colour=RED,
- )
-
-
-def start_worker(app, configfile, worker_configfile):
- args = [
- "python", "-B",
- "-m", app,
- "-c", configfile,
- "-c", worker_configfile
- ]
-
- try:
- subprocess.check_call(args)
- write("started %s(%r)" % (app, worker_configfile), colour=GREEN)
- except subprocess.CalledProcessError as e:
- write(
- "error starting %s(%r) (exit code: %d); see above for logs" % (
- app, worker_configfile, e.returncode,
- ),
- colour=RED,
- )
-
-
-def stop(pidfile, app):
- if os.path.exists(pidfile):
- pid = int(open(pidfile).read())
- try:
- os.kill(pid, signal.SIGTERM)
- write("stopped %s" % (app,), colour=GREEN)
- except OSError as err:
- if err.errno == errno.ESRCH:
- write("%s not running" % (app,), colour=YELLOW)
- elif err.errno == errno.EPERM:
- abort("Cannot stop %s: Operation not permitted" % (app,))
- else:
- abort("Cannot stop %s: Unknown error" % (app,))
-
-
-Worker = collections.namedtuple("Worker", [
- "app", "configfile", "pidfile", "cache_factor"
-])
-
-
-def main():
-
- parser = argparse.ArgumentParser()
-
- parser.add_argument(
- "action",
- choices=["start", "stop", "restart"],
- help="whether to start, stop or restart the synapse",
- )
- parser.add_argument(
- "configfile",
- nargs="?",
- default="homeserver.yaml",
- help="the homeserver config file, defaults to homeserver.yaml",
- )
- parser.add_argument(
- "-w", "--worker",
- metavar="WORKERCONFIG",
- help="start or stop a single worker",
- )
- parser.add_argument(
- "-a", "--all-processes",
- metavar="WORKERCONFIGDIR",
- help="start or stop all the workers in the given directory"
- " and the main synapse process",
- )
-
- options = parser.parse_args()
-
- if options.worker and options.all_processes:
- write(
- 'Cannot use "--worker" with "--all-processes"',
- stream=sys.stderr
- )
- sys.exit(1)
-
- configfile = options.configfile
-
- if not os.path.exists(configfile):
- write(
- "No config file found\n"
- "To generate a config file, run '%s -c %s --generate-config"
- " --server-name=<server name>'\n" % (
- " ".join(SYNAPSE), options.configfile
- ),
- stream=sys.stderr,
- )
- sys.exit(1)
-
- with open(configfile) as stream:
- config = yaml.load(stream)
-
- pidfile = config["pid_file"]
- cache_factor = config.get("synctl_cache_factor")
- start_stop_synapse = True
-
- if cache_factor:
- os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
-
- cache_factors = config.get("synctl_cache_factors", {})
- for cache_name, factor in iteritems(cache_factors):
- os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor)
-
- worker_configfiles = []
- if options.worker:
- start_stop_synapse = False
- worker_configfile = options.worker
- if not os.path.exists(worker_configfile):
- write(
- "No worker config found at %r" % (worker_configfile,),
- stream=sys.stderr,
- )
- sys.exit(1)
- worker_configfiles.append(worker_configfile)
-
- if options.all_processes:
- # To start the main synapse with -a you need to add a worker file
- # with worker_app == "synapse.app.homeserver"
- start_stop_synapse = False
- worker_configdir = options.all_processes
- if not os.path.isdir(worker_configdir):
- write(
- "No worker config directory found at %r" % (worker_configdir,),
- stream=sys.stderr,
- )
- sys.exit(1)
- worker_configfiles.extend(sorted(glob.glob(
- os.path.join(worker_configdir, "*.yaml")
- )))
-
- workers = []
- for worker_configfile in worker_configfiles:
- with open(worker_configfile) as stream:
- worker_config = yaml.load(stream)
- worker_app = worker_config["worker_app"]
- if worker_app == "synapse.app.homeserver":
- # We need to special case all of this to pick up options that may
- # be set in the main config file or in this worker config file.
- worker_pidfile = (
- worker_config.get("pid_file")
- or pidfile
- )
- worker_cache_factor = worker_config.get("synctl_cache_factor") or cache_factor
- daemonize = worker_config.get("daemonize") or config.get("daemonize")
- assert daemonize, "Main process must have daemonize set to true"
-
- # The master process doesn't support using worker_* config.
- for key in worker_config:
- if key == "worker_app": # But we allow worker_app
- continue
- assert not key.startswith("worker_"), \
- "Main process cannot use worker_* config"
- else:
- worker_pidfile = worker_config["worker_pid_file"]
- worker_daemonize = worker_config["worker_daemonize"]
- assert worker_daemonize, "In config %r: expected '%s' to be True" % (
- worker_configfile, "worker_daemonize")
- worker_cache_factor = worker_config.get("synctl_cache_factor")
- workers.append(Worker(
- worker_app, worker_configfile, worker_pidfile, worker_cache_factor,
- ))
-
- action = options.action
-
- if action == "stop" or action == "restart":
- for worker in workers:
- stop(worker.pidfile, worker.app)
-
- if start_stop_synapse:
- stop(pidfile, "synapse.app.homeserver")
-
- # Wait for synapse to actually shutdown before starting it again
- if action == "restart":
- running_pids = []
- if start_stop_synapse and os.path.exists(pidfile):
- running_pids.append(int(open(pidfile).read()))
- for worker in workers:
- if os.path.exists(worker.pidfile):
- running_pids.append(int(open(worker.pidfile).read()))
- if len(running_pids) > 0:
- write("Waiting for process to exit before restarting...")
- for running_pid in running_pids:
- while pid_running(running_pid):
- time.sleep(0.2)
- write("All processes exited; now restarting...")
-
- if action == "start" or action == "restart":
- if start_stop_synapse:
- # Check if synapse is already running
- if os.path.exists(pidfile) and pid_running(int(open(pidfile).read())):
- abort("synapse.app.homeserver already running")
- start(configfile)
-
- for worker in workers:
- if worker.cache_factor:
- os.environ["SYNAPSE_CACHE_FACTOR"] = str(worker.cache_factor)
-
- start_worker(worker.app, configfile, worker.configfile)
-
- if cache_factor:
- os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
- else:
- os.environ.pop("SYNAPSE_CACHE_FACTOR", None)
-
-
-if __name__ == "__main__":
- main()
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index 1388a42b59..0a5f62b509 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -94,10 +94,7 @@ class UserDirectorySlaveStore(
class UserDirectoryServer(HomeServer):
- def setup(self):
- logger.info("Setting up.")
- self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self)
- logger.info("Finished setting up.")
+ DATASTORE_CLASS = UserDirectorySlaveStore
def _listen_http(self, listener_config):
port = listener_config["port"]
@@ -191,7 +188,7 @@ def start(config_options):
"Synapse user directory", config_options
)
except ConfigError as e:
- sys.stderr.write("\n" + e.message + "\n")
+ sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.user_dir"
@@ -232,7 +229,6 @@ def start(config_options):
def start():
ps.get_datastore().start_profiling()
- ps.get_state_handler().start_caching()
reactor.callWhenRunning(start)
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 6980e5890e..9ccc5a80fc 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import urllib
+
+from six.moves import urllib
from prometheus_client import Counter
@@ -98,7 +99,7 @@ class ApplicationServiceApi(SimpleHttpClient):
def query_user(self, service, user_id):
if service.url is None:
defer.returnValue(False)
- uri = service.url + ("/users/%s" % urllib.quote(user_id))
+ uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
response = None
try:
response = yield self.get_json(uri, {
@@ -119,7 +120,7 @@ class ApplicationServiceApi(SimpleHttpClient):
def query_alias(self, service, alias):
if service.url is None:
defer.returnValue(False)
- uri = service.url + ("/rooms/%s" % urllib.quote(alias))
+ uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
response = None
try:
response = yield self.get_json(uri, {
@@ -153,7 +154,7 @@ class ApplicationServiceApi(SimpleHttpClient):
service.url,
APP_SERVICE_PREFIX,
kind,
- urllib.quote(protocol)
+ urllib.parse.quote(protocol)
)
try:
response = yield self.get_json(uri, fields)
@@ -188,7 +189,7 @@ class ApplicationServiceApi(SimpleHttpClient):
uri = "%s%s/thirdparty/protocol/%s" % (
service.url,
APP_SERVICE_PREFIX,
- urllib.quote(protocol)
+ urllib.parse.quote(protocol)
)
try:
info = yield self.get_json(uri, {})
@@ -228,7 +229,7 @@ class ApplicationServiceApi(SimpleHttpClient):
txn_id = str(txn_id)
uri = service.url + ("/transactions/%s" %
- urllib.quote(txn_id))
+ urllib.parse.quote(txn_id))
try:
yield self.put_json(
uri=uri,
diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py
index 58c97a70af..8fccf573ee 100644
--- a/synapse/config/__main__.py
+++ b/synapse/config/__main__.py
@@ -25,7 +25,7 @@ if __name__ == "__main__":
try:
config = HomeServerConfig.load_config("", sys.argv[3:])
except ConfigError as e:
- sys.stderr.write("\n" + e.message + "\n")
+ sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
print (getattr(config, key))
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 2fd9c48abf..b8d5690f2b 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -21,7 +21,7 @@ from .consent_config import ConsentConfig
from .database import DatabaseConfig
from .emailconfig import EmailConfig
from .groups import GroupsConfig
-from .jwt import JWTConfig
+from .jwt_config import JWTConfig
from .key import KeyConfig
from .logger import LoggingConfig
from .metrics import MetricsConfig
diff --git a/synapse/config/jwt.py b/synapse/config/jwt_config.py
index 51e7f7e003..51e7f7e003 100644
--- a/synapse/config/jwt.py
+++ b/synapse/config/jwt_config.py
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 3f187adfc8..e9a936118d 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -227,7 +227,22 @@ def setup_logging(config, use_worker_options=False):
#
# However this may not be too much of a problem if we are just writing to a file.
observer = STDLibLogObserver()
+
+ def _log(event):
+
+ if "log_text" in event:
+ if event["log_text"].startswith("DNSDatagramProtocol starting on "):
+ return
+
+ if event["log_text"].startswith("(UDP Port "):
+ return
+
+ if event["log_text"].startswith("Timing out client"):
+ return
+
+ return observer(event)
+
globalLogBeginner.beginLoggingTo(
- [observer],
+ [_log],
redirectStandardIO=not config.no_redirect_stdio,
)
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 8eecd28e7d..c1c7c0105e 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -93,7 +93,7 @@ class ServerConfig(Config):
# Admin uri to direct users at should their instance become blocked
# due to resource constraints
- self.admin_uri = config.get("admin_uri", None)
+ self.admin_contact = config.get("admin_contact", None)
# FIXME: federation_domain_whitelist needs sytests
self.federation_domain_whitelist = None
@@ -357,7 +357,7 @@ class ServerConfig(Config):
# Homeserver blocking
#
# How to reach the server admin, used in ResourceLimitError
- # admin_uri: 'mailto:admin@server.com'
+ # admin_contact: 'mailto:admin@server.com'
#
# Global block config
#
@@ -404,6 +404,23 @@ class ServerConfig(Config):
" service on the given port.")
+def is_threepid_reserved(config, threepid):
+ """Check the threepid against the reserved threepid config
+ Args:
+ config(ServerConfig) - to access server config attributes
+ threepid(dict) - The threepid to test for
+
+ Returns:
+ boolean Is the threepid undertest reserved_user
+ """
+
+ for tp in config.mau_limits_reserved_threepids:
+ if (threepid['medium'] == tp['medium']
+ and threepid['address'] == tp['address']):
+ return True
+ return False
+
+
def read_gc_thresholds(thresholds):
"""Reads the three integer thresholds for garbage collection. Ensures that
the thresholds are integers if thresholds are supplied.
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 1a391adec1..02b76dfcfb 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -123,6 +123,6 @@ class ClientTLSOptionsFactory(object):
def get_options(self, host):
return ClientTLSOptions(
- host.decode('utf-8'),
+ host,
CertificateOptions(verify=False).getContext()
)
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index e94400b8e2..57d4665e84 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -50,7 +50,7 @@ def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1):
defer.returnValue((server_response, server_certificate))
except SynapseKeyClientError as e:
logger.warn("Error getting key for %r: %s", server_name, e)
- if e.status.startswith("4"):
+ if e.status.startswith(b"4"):
# Don't retry for 4xx responses.
raise IOError("Cannot get key for %r" % server_name)
except (ConnectError, DomainError) as e:
@@ -82,6 +82,12 @@ class SynapseKeyClientProtocol(HTTPClient):
self._peer = self.transport.getPeer()
logger.debug("Connected to %s", self._peer)
+ if not isinstance(self.path, bytes):
+ self.path = self.path.encode('ascii')
+
+ if not isinstance(self.host, bytes):
+ self.host = self.host.encode('ascii')
+
self.sendCommand(b"GET", self.path)
if self.host:
self.sendHeader(b"Host", self.host)
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 30e2742102..d89f94c219 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -16,9 +16,10 @@
import hashlib
import logging
-import urllib
from collections import namedtuple
+from six.moves import urllib
+
from signedjson.key import (
decode_verify_key_bytes,
encode_verify_key_base64,
@@ -40,6 +41,7 @@ from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyclient import fetch_server_key
from synapse.util import logcontext, unwrapFirstError
from synapse.util.logcontext import (
+ LoggingContext,
PreserveLoggingContext,
preserve_fn,
run_in_background,
@@ -216,23 +218,34 @@ class Keyring(object):
servers have completed. Follows the synapse rules of logcontext
preservation.
"""
+ loop_count = 1
while True:
wait_on = [
- self.key_downloads[server_name]
+ (server_name, self.key_downloads[server_name])
for server_name in server_names
if server_name in self.key_downloads
]
- if wait_on:
- with PreserveLoggingContext():
- yield defer.DeferredList(wait_on)
- else:
+ if not wait_on:
break
+ logger.info(
+ "Waiting for existing lookups for %s to complete [loop %i]",
+ [w[0] for w in wait_on], loop_count,
+ )
+ with PreserveLoggingContext():
+ yield defer.DeferredList((w[1] for w in wait_on))
+
+ loop_count += 1
+
+ ctx = LoggingContext.current_context()
def rm(r, server_name_):
- self.key_downloads.pop(server_name_, None)
+ with PreserveLoggingContext(ctx):
+ logger.debug("Releasing key lookup lock on %s", server_name_)
+ self.key_downloads.pop(server_name_, None)
return r
for server_name, deferred in server_to_deferred.items():
+ logger.debug("Got key lookup lock on %s", server_name)
self.key_downloads[server_name] = deferred
deferred.addBoth(rm, server_name)
@@ -432,7 +445,7 @@ class Keyring(object):
# an incoming request.
query_response = yield self.client.post_json(
destination=perspective_name,
- path=b"/_matrix/key/v2/query",
+ path="/_matrix/key/v2/query",
data={
u"server_keys": {
server_name: {
@@ -513,8 +526,8 @@ class Keyring(object):
(response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_client_options_factory,
- path=(b"/_matrix/key/v2/server/%s" % (
- urllib.quote(requested_key_id),
+ path=("/_matrix/key/v2/server/%s" % (
+ urllib.parse.quote(requested_key_id),
)).encode("ascii"),
)
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 6baeccca38..af3eee95b9 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -98,9 +98,9 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
creation_event = auth_events.get((EventTypes.Create, ""), None)
if not creation_event:
- raise SynapseError(
+ raise AuthError(
403,
- "Room %r does not exist" % (event.room_id,)
+ "No create event in auth events",
)
creating_domain = get_domain_from_id(event.room_id)
@@ -155,10 +155,7 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
if user_level < invite_level:
raise AuthError(
- 403, (
- "You cannot issue a third party invite for %s." %
- (event.content.display_name,)
- )
+ 403, "You don't have permission to invite users",
)
else:
logger.debug("Allowing! %s", event)
@@ -305,7 +302,7 @@ def _is_membership_change_allowed(event, auth_events):
if user_level < invite_level:
raise AuthError(
- 403, "You cannot invite user %s." % target_user_id
+ 403, "You don't have permission to invite users",
)
elif Membership.JOIN == membership:
# Joins are valid iff caller == target and they were:
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 51f9084b90..12f1eb0a3e 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -13,13 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
+from distutils.util import strtobool
+
+import six
+
from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze
# Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
-# bugs where we accidentally share e.g. signature dicts. However, converting
-# a dict to frozen_dicts is expensive.
-USE_FROZEN_DICTS = True
+# bugs where we accidentally share e.g. signature dicts. However, converting a
+# dict to frozen_dicts is expensive.
+#
+# NOTE: This is overridden by the configuration by the Synapse worker apps, but
+# for the sake of tests, it is set here while it cannot be configured on the
+# homeserver object itself.
+USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))
class _EventInternalMetadata(object):
@@ -147,6 +156,9 @@ class EventBase(object):
def items(self):
return list(self._event_dict.items())
+ def keys(self):
+ return six.iterkeys(self._event_dict)
+
class FrozenEvent(EventBase):
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index c11798093d..b7ad729c63 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -13,17 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from collections import namedtuple
import six
from twisted.internet import defer
+from twisted.internet.defer import DeferredList
-from synapse.api.constants import MAX_DEPTH
+from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_dict
+from synapse.types import get_domain_from_id
from synapse.util import logcontext, unwrapFirstError
logger = logging.getLogger(__name__)
@@ -133,34 +136,45 @@ class FederationBase(object):
* throws a SynapseError if the signature check failed.
The deferreds run their callbacks in the sentinel logcontext.
"""
-
- redacted_pdus = [
- prune_event(pdu)
- for pdu in pdus
- ]
-
- deferreds = self.keyring.verify_json_objects_for_server([
- (p.origin, p.get_pdu_json())
- for p in redacted_pdus
- ])
+ deferreds = _check_sigs_on_pdus(self.keyring, pdus)
ctx = logcontext.LoggingContext.current_context()
- def callback(_, pdu, redacted):
+ def callback(_, pdu):
with logcontext.PreserveLoggingContext(ctx):
if not check_event_content_hash(pdu):
- logger.warn(
- "Event content has been tampered, redacting %s: %s",
- pdu.event_id, pdu.get_pdu_json()
- )
- return redacted
+ # let's try to distinguish between failures because the event was
+ # redacted (which are somewhat expected) vs actual ball-tampering
+ # incidents.
+ #
+ # This is just a heuristic, so we just assume that if the keys are
+ # about the same between the redacted and received events, then the
+ # received event was probably a redacted copy (but we then use our
+ # *actual* redacted copy to be on the safe side.)
+ redacted_event = prune_event(pdu)
+ if (
+ set(redacted_event.keys()) == set(pdu.keys()) and
+ set(six.iterkeys(redacted_event.content))
+ == set(six.iterkeys(pdu.content))
+ ):
+ logger.info(
+ "Event %s seems to have been redacted; using our redacted "
+ "copy",
+ pdu.event_id,
+ )
+ else:
+ logger.warning(
+ "Event %s content has been tampered, redacting",
+ pdu.event_id, pdu.get_pdu_json(),
+ )
+ return redacted_event
if self.spam_checker.check_event_for_spam(pdu):
logger.warn(
"Event contains spam, redacting %s: %s",
pdu.event_id, pdu.get_pdu_json()
)
- return redacted
+ return prune_event(pdu)
return pdu
@@ -168,21 +182,121 @@ class FederationBase(object):
failure.trap(SynapseError)
with logcontext.PreserveLoggingContext(ctx):
logger.warn(
- "Signature check failed for %s",
- pdu.event_id,
+ "Signature check failed for %s: %s",
+ pdu.event_id, failure.getErrorMessage(),
)
return failure
- for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
+ for deferred, pdu in zip(deferreds, pdus):
deferred.addCallbacks(
callback, errback,
- callbackArgs=[pdu, redacted],
+ callbackArgs=[pdu],
errbackArgs=[pdu],
)
return deferreds
+class PduToCheckSig(namedtuple("PduToCheckSig", [
+ "pdu", "redacted_pdu_json", "event_id_domain", "sender_domain", "deferreds",
+])):
+ pass
+
+
+def _check_sigs_on_pdus(keyring, pdus):
+ """Check that the given events are correctly signed
+
+ Args:
+ keyring (synapse.crypto.Keyring): keyring object to do the checks
+ pdus (Collection[EventBase]): the events to be checked
+
+ Returns:
+ List[Deferred]: a Deferred for each event in pdus, which will either succeed if
+ the signatures are valid, or fail (with a SynapseError) if not.
+ """
+
+ # (currently this is written assuming the v1 room structure; we'll probably want a
+ # separate function for checking v2 rooms)
+
+ # we want to check that the event is signed by:
+ #
+ # (a) the server which created the event_id
+ #
+ # (b) the sender's server.
+ #
+ # - except in the case of invites created from a 3pid invite, which are exempt
+ # from this check, because the sender has to match that of the original 3pid
+ # invite, but the event may come from a different HS, for reasons that I don't
+ # entirely grok (why do the senders have to match? and if they do, why doesn't the
+ # joining server ask the inviting server to do the switcheroo with
+ # exchange_third_party_invite?).
+ #
+ # That's pretty awful, since redacting such an invite will render it invalid
+ # (because it will then look like a regular invite without a valid signature),
+ # and signatures are *supposed* to be valid whether or not an event has been
+ # redacted. But this isn't the worst of the ways that 3pid invites are broken.
+ #
+ # let's start by getting the domain for each pdu, and flattening the event back
+ # to JSON.
+ pdus_to_check = [
+ PduToCheckSig(
+ pdu=p,
+ redacted_pdu_json=prune_event(p).get_pdu_json(),
+ event_id_domain=get_domain_from_id(p.event_id),
+ sender_domain=get_domain_from_id(p.sender),
+ deferreds=[],
+ )
+ for p in pdus
+ ]
+
+ # first make sure that the event is signed by the event_id's domain
+ deferreds = keyring.verify_json_objects_for_server([
+ (p.event_id_domain, p.redacted_pdu_json)
+ for p in pdus_to_check
+ ])
+
+ for p, d in zip(pdus_to_check, deferreds):
+ p.deferreds.append(d)
+
+ # now let's look for events where the sender's domain is different to the
+ # event id's domain (normally only the case for joins/leaves), and add additional
+ # checks.
+ pdus_to_check_sender = [
+ p for p in pdus_to_check
+ if p.sender_domain != p.event_id_domain and not _is_invite_via_3pid(p.pdu)
+ ]
+
+ more_deferreds = keyring.verify_json_objects_for_server([
+ (p.sender_domain, p.redacted_pdu_json)
+ for p in pdus_to_check_sender
+ ])
+
+ for p, d in zip(pdus_to_check_sender, more_deferreds):
+ p.deferreds.append(d)
+
+ # replace lists of deferreds with single Deferreds
+ return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
+
+
+def _flatten_deferred_list(deferreds):
+ """Given a list of one or more deferreds, either return the single deferred, or
+ combine into a DeferredList.
+ """
+ if len(deferreds) > 1:
+ return DeferredList(deferreds, fireOnOneErrback=True, consumeErrors=True)
+ else:
+ assert len(deferreds) == 1
+ return deferreds[0]
+
+
+def _is_invite_via_3pid(event):
+ return (
+ event.type == EventTypes.Member
+ and event.membership == Membership.INVITE
+ and "third_party_invite" in event.content
+ )
+
+
def event_from_pdu_json(pdu_json, outlier=False):
"""Construct a FrozenEvent from an event json received over federation
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index c9f3c2d352..d05ed91d64 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -66,6 +66,14 @@ class FederationClient(FederationBase):
self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client()
+ self._get_pdu_cache = ExpiringCache(
+ cache_name="get_pdu_cache",
+ clock=self._clock,
+ max_len=1000,
+ expiry_ms=120 * 1000,
+ reset_expiry_on_get=False,
+ )
+
def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache"""
now = self._clock.time_msec()
@@ -82,17 +90,6 @@ class FederationClient(FederationBase):
if destination_dict:
self.pdu_destination_tried[event_id] = destination_dict
- def start_get_pdu_cache(self):
- self._get_pdu_cache = ExpiringCache(
- cache_name="get_pdu_cache",
- clock=self._clock,
- max_len=1000,
- expiry_ms=120 * 1000,
- reset_expiry_on_get=False,
- )
-
- self._get_pdu_cache.start()
-
@log_function
def make_query(self, destination, query_type, args,
retry_on_dns_fail=False, ignore_backoff=False):
@@ -212,8 +209,6 @@ class FederationClient(FederationBase):
Will attempt to get the PDU from each destination in the list until
one succeeds.
- This will persist the PDU locally upon receipt.
-
Args:
destinations (list): Which home servers to query
event_id (str): event to fetch
@@ -229,10 +224,9 @@ class FederationClient(FederationBase):
# TODO: Rate limit the number of times we try and get the same event.
- if self._get_pdu_cache:
- ev = self._get_pdu_cache.get(event_id)
- if ev:
- defer.returnValue(ev)
+ ev = self._get_pdu_cache.get(event_id)
+ if ev:
+ defer.returnValue(ev)
pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
@@ -271,10 +265,10 @@ class FederationClient(FederationBase):
event_id, destination, e,
)
except NotRetryingDestination as e:
- logger.info(e.message)
+ logger.info(str(e))
continue
except FederationDeniedError as e:
- logger.info(e.message)
+ logger.info(str(e))
continue
except Exception as e:
pdu_attempts[destination] = now
@@ -285,7 +279,7 @@ class FederationClient(FederationBase):
)
continue
- if self._get_pdu_cache is not None and signed_pdu:
+ if signed_pdu:
self._get_pdu_cache[event_id] = signed_pdu
defer.returnValue(signed_pdu)
@@ -293,8 +287,7 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
@log_function
def get_state_for_room(self, destination, room_id, event_id):
- """Requests all of the `current` state PDUs for a given room from
- a remote home server.
+ """Requests all of the room state at a given event from a remote home server.
Args:
destination (str): The remote homeserver to query for the state.
@@ -302,9 +295,10 @@ class FederationClient(FederationBase):
event_id (str): The id of the event we want the state at.
Returns:
- Deferred: Results in a list of PDUs.
+ Deferred[Tuple[List[EventBase], List[EventBase]]]:
+ A list of events in the state, and a list of events in the auth chain
+ for the given event.
"""
-
try:
# First we try and ask for just the IDs, as thats far quicker if
# we have most of the state and auth_chain already.
@@ -510,7 +504,7 @@ class FederationClient(FederationBase):
else:
logger.warn(
"Failed to %s via %s: %i %s",
- description, destination, e.code, e.message,
+ description, destination, e.code, e.args[0],
)
except Exception:
logger.warn(
@@ -875,7 +869,7 @@ class FederationClient(FederationBase):
except Exception as e:
logger.exception(
"Failed to send_third_party_invite via %s: %s",
- destination, e.message
+ destination, str(e)
)
raise RuntimeError("Failed to send to any server.")
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 3e0cd294a1..819e8f7331 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -46,6 +46,7 @@ from synapse.replication.http.federation import (
from synapse.types import get_domain_from_id
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.logcontext import nested_logging_context
from synapse.util.logutils import log_function
# when processing incoming transactions, we try to handle multiple rooms in
@@ -99,7 +100,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
@log_function
- def on_incoming_transaction(self, transaction_data):
+ def on_incoming_transaction(self, origin, transaction_data):
# keep this as early as possible to make the calculated origin ts as
# accurate as possible.
request_time = self._clock.time_msec()
@@ -108,34 +109,33 @@ class FederationServer(FederationBase):
if not transaction.transaction_id:
raise Exception("Transaction missing transaction_id")
- if not transaction.origin:
- raise Exception("Transaction missing origin")
logger.debug("[%s] Got transaction", transaction.transaction_id)
# use a linearizer to ensure that we don't process the same transaction
# multiple times in parallel.
with (yield self._transaction_linearizer.queue(
- (transaction.origin, transaction.transaction_id),
+ (origin, transaction.transaction_id),
)):
result = yield self._handle_incoming_transaction(
- transaction, request_time,
+ origin, transaction, request_time,
)
defer.returnValue(result)
@defer.inlineCallbacks
- def _handle_incoming_transaction(self, transaction, request_time):
+ def _handle_incoming_transaction(self, origin, transaction, request_time):
""" Process an incoming transaction and return the HTTP response
Args:
+ origin (unicode): the server making the request
transaction (Transaction): incoming transaction
request_time (int): timestamp that the HTTP request arrived at
Returns:
Deferred[(int, object)]: http response code and body
"""
- response = yield self.transaction_actions.have_responded(transaction)
+ response = yield self.transaction_actions.have_responded(origin, transaction)
if response:
logger.debug(
@@ -149,7 +149,7 @@ class FederationServer(FederationBase):
received_pdus_counter.inc(len(transaction.pdus))
- origin_host, _ = parse_server_name(transaction.origin)
+ origin_host, _ = parse_server_name(origin)
pdus_by_room = {}
@@ -188,21 +188,22 @@ class FederationServer(FederationBase):
for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id
- try:
- yield self._handle_received_pdu(
- transaction.origin, pdu
- )
- pdu_results[event_id] = {}
- except FederationError as e:
- logger.warn("Error handling PDU %s: %s", event_id, e)
- pdu_results[event_id] = {"error": str(e)}
- except Exception as e:
- f = failure.Failure()
- pdu_results[event_id] = {"error": str(e)}
- logger.error(
- "Failed to handle PDU %s: %s",
- event_id, f.getTraceback().rstrip(),
- )
+ with nested_logging_context(event_id):
+ try:
+ yield self._handle_received_pdu(
+ origin, pdu
+ )
+ pdu_results[event_id] = {}
+ except FederationError as e:
+ logger.warn("Error handling PDU %s: %s", event_id, e)
+ pdu_results[event_id] = {"error": str(e)}
+ except Exception as e:
+ f = failure.Failure()
+ pdu_results[event_id] = {"error": str(e)}
+ logger.error(
+ "Failed to handle PDU %s: %s",
+ event_id, f.getTraceback().rstrip(),
+ )
yield concurrently_execute(
process_pdus_for_room, pdus_by_room.keys(),
@@ -212,7 +213,7 @@ class FederationServer(FederationBase):
if hasattr(transaction, "edus"):
for edu in (Edu(**x) for x in transaction.edus):
yield self.received_edu(
- transaction.origin,
+ origin,
edu.edu_type,
edu.content
)
@@ -224,6 +225,7 @@ class FederationServer(FederationBase):
logger.debug("Returning: %s", str(response))
yield self.transaction_actions.set_response(
+ origin,
transaction,
200, response
)
@@ -618,7 +620,7 @@ class FederationServer(FederationBase):
)
yield self.handler.on_receive_pdu(
- origin, pdu, get_missing=True, sent_to_us_directly=True,
+ origin, pdu, sent_to_us_directly=True,
)
def __str__(self):
@@ -838,9 +840,9 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
)
return self._send_edu(
- edu_type=edu_type,
- origin=origin,
- content=content,
+ edu_type=edu_type,
+ origin=origin,
+ content=content,
)
def on_query(self, query_type, args):
@@ -851,6 +853,6 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
return handler(args)
return self._get_query_client(
- query_type=query_type,
- args=args,
+ query_type=query_type,
+ args=args,
)
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 9146215c21..74ffd13b4f 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -36,7 +36,7 @@ class TransactionActions(object):
self.store = datastore
@log_function
- def have_responded(self, transaction):
+ def have_responded(self, origin, transaction):
""" Have we already responded to a transaction with the same id and
origin?
@@ -50,11 +50,11 @@ class TransactionActions(object):
"transaction_id")
return self.store.get_received_txn_response(
- transaction.transaction_id, transaction.origin
+ transaction.transaction_id, origin
)
@log_function
- def set_response(self, transaction, code, response):
+ def set_response(self, origin, transaction, code, response):
""" Persist how we responded to a transaction.
Returns:
@@ -66,7 +66,7 @@ class TransactionActions(object):
return self.store.set_received_txn_response(
transaction.transaction_id,
- transaction.origin,
+ origin,
code,
response,
)
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 0bb468385d..6f5995735a 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -32,7 +32,7 @@ Events are replicated via a separate events stream.
import logging
from collections import namedtuple
-from six import iteritems, itervalues
+from six import iteritems
from sortedcontainers import SortedDict
@@ -117,7 +117,7 @@ class FederationRemoteSendQueue(object):
user_ids = set(
user_id
- for uids in itervalues(self.presence_changed)
+ for uids in self.presence_changed.values()
for user_id in uids
)
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 94d7423d01..98b5950800 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -137,26 +137,6 @@ class TransactionQueue(object):
self._processing_pending_presence = False
- def can_send_to(self, destination):
- """Can we send messages to the given server?
-
- We can't send messages to ourselves. If we are running on localhost
- then we can only federation with other servers running on localhost.
- Otherwise we only federate with servers on a public domain.
-
- Args:
- destination(str): The server we are possibly trying to send to.
- Returns:
- bool: True if we can send to the server.
- """
-
- if destination == self.server_name:
- return False
- if self.server_name.startswith("localhost"):
- return destination.startswith("localhost")
- else:
- return not destination.startswith("localhost")
-
def notify_new_events(self, current_id):
"""This gets called when we have some new events we might want to
send out to other servers.
@@ -279,10 +259,7 @@ class TransactionQueue(object):
self._order += 1
destinations = set(destinations)
- destinations = set(
- dest for dest in destinations if self.can_send_to(dest)
- )
-
+ destinations.discard(self.server_name)
logger.debug("Sending to: %s", str(destinations))
if not destinations:
@@ -358,7 +335,7 @@ class TransactionQueue(object):
for destinations, states in hosts_and_states:
for destination in destinations:
- if not self.can_send_to(destination):
+ if destination == self.server_name:
continue
self.pending_presence_by_dest.setdefault(
@@ -377,7 +354,8 @@ class TransactionQueue(object):
content=content,
)
- if not self.can_send_to(destination):
+ if destination == self.server_name:
+ logger.info("Not sending EDU to ourselves")
return
sent_edus_counter.inc()
@@ -392,10 +370,8 @@ class TransactionQueue(object):
self._attempt_new_transaction(destination)
def send_device_messages(self, destination):
- if destination == self.server_name or destination == "localhost":
- return
-
- if not self.can_send_to(destination):
+ if destination == self.server_name:
+ logger.info("Not sending device update to ourselves")
return
self._attempt_new_transaction(destination)
@@ -463,7 +439,19 @@ class TransactionQueue(object):
# pending_transactions flag.
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+
+ # We can only include at most 50 PDUs per transactions
+ pending_pdus, leftover_pdus = pending_pdus[:50], pending_pdus[50:]
+ if leftover_pdus:
+ self.pending_pdus_by_dest[destination] = leftover_pdus
+
pending_edus = self.pending_edus_by_dest.pop(destination, [])
+
+ # We can only include at most 100 EDUs per transactions
+ pending_edus, leftover_edus = pending_edus[:100], pending_edus[100:]
+ if leftover_edus:
+ self.pending_edus_by_dest[destination] = leftover_edus
+
pending_presence = self.pending_presence_by_dest.pop(destination, {})
pending_edus.extend(
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 1054441ca5..2ab973d6c8 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -15,7 +15,8 @@
# limitations under the License.
import logging
-import urllib
+
+from six.moves import urllib
from twisted.internet import defer
@@ -951,4 +952,4 @@ def _create_path(prefix, path, *args):
Returns:
str
"""
- return prefix + path % tuple(urllib.quote(arg, "") for arg in args)
+ return prefix + path % tuple(urllib.parse.quote(arg, "") for arg in args)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 7a993fd1cf..2f874b4838 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -90,8 +90,8 @@ class Authenticator(object):
@defer.inlineCallbacks
def authenticate_request(self, request, content):
json_request = {
- "method": request.method,
- "uri": request.uri,
+ "method": request.method.decode('ascii'),
+ "uri": request.uri.decode('ascii'),
"destination": self.server_name,
"signatures": {},
}
@@ -252,7 +252,7 @@ class BaseFederationServlet(object):
by the callback method. None if the request has already been handled.
"""
content = None
- if request.method in ["PUT", "POST"]:
+ if request.method in [b"PUT", b"POST"]:
# TODO: Handle other method types? other content types?
content = parse_json_object_from_request(request)
@@ -353,7 +353,7 @@ class FederationSendServlet(BaseFederationServlet):
try:
code, response = yield self.handler.on_incoming_transaction(
- transaction_data
+ origin, transaction_data,
)
except Exception:
logger.exception("on_incoming_transaction failed")
@@ -386,7 +386,7 @@ class FederationStateServlet(BaseFederationServlet):
return self.handler.on_context_state_request(
origin,
context,
- query.get("event_id", [None])[0],
+ parse_string_from_args(query, "event_id", None),
)
@@ -397,7 +397,7 @@ class FederationStateIdsServlet(BaseFederationServlet):
return self.handler.on_state_ids_request(
origin,
room_id,
- query.get("event_id", [None])[0],
+ parse_string_from_args(query, "event_id", None),
)
@@ -405,14 +405,12 @@ class FederationBackfillServlet(BaseFederationServlet):
PATH = "/backfill/(?P<context>[^/]*)/"
def on_GET(self, origin, content, query, context):
- versions = query["v"]
- limits = query["limit"]
+ versions = [x.decode('ascii') for x in query[b"v"]]
+ limit = parse_integer_from_args(query, "limit", None)
- if not limits:
+ if not limit:
return defer.succeed((400, {"error": "Did not include limit param"}))
- limit = int(limits[-1])
-
return self.handler.on_backfill_request(origin, context, versions, limit)
@@ -423,7 +421,7 @@ class FederationQueryServlet(BaseFederationServlet):
def on_GET(self, origin, content, query, query_type):
return self.handler.on_query_request(
query_type,
- {k: v[0].decode("utf-8") for k, v in query.items()}
+ {k.decode('utf8'): v[0].decode("utf-8") for k, v in query.items()}
)
@@ -630,14 +628,14 @@ class OpenIdUserInfo(BaseFederationServlet):
@defer.inlineCallbacks
def on_GET(self, origin, content, query):
- token = query.get("access_token", [None])[0]
+ token = query.get(b"access_token", [None])[0]
if token is None:
defer.returnValue((401, {
"errcode": "M_MISSING_TOKEN", "error": "Access Token required"
}))
return
- user_id = yield self.handler.on_openid_userinfo(token)
+ user_id = yield self.handler.on_openid_userinfo(token.decode('ascii'))
if user_id is None:
defer.returnValue((401, {
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 4a81bd2ba9..2a5eab124f 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -895,22 +895,24 @@ class AuthHandler(BaseHandler):
Args:
password (unicode): Password to hash.
- stored_hash (unicode): Expected hash value.
+ stored_hash (bytes): Expected hash value.
Returns:
Deferred(bool): Whether self.hash(password) == stored_hash.
"""
-
def _do_validate_hash():
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
return bcrypt.checkpw(
pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
- stored_hash.encode('utf8')
+ stored_hash
)
if stored_hash:
+ if not isinstance(stored_hash, bytes):
+ stored_hash = stored_hash.encode('ascii')
+
return make_deferred_yieldable(
threads.deferToThreadPool(
self.hs.get_reactor(),
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index ef866da1b6..18741c5fac 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -20,7 +20,14 @@ import string
from twisted.internet import defer
from synapse.api.constants import EventTypes
-from synapse.api.errors import AuthError, CodeMessageException, Codes, SynapseError
+from synapse.api.errors import (
+ AuthError,
+ CodeMessageException,
+ Codes,
+ NotFoundError,
+ StoreError,
+ SynapseError,
+)
from synapse.types import RoomAlias, UserID, get_domain_from_id
from ._base import BaseHandler
@@ -109,7 +116,13 @@ class DirectoryHandler(BaseHandler):
def delete_association(self, requester, user_id, room_alias):
# association deletion for human users
- can_delete = yield self._user_can_delete_alias(room_alias, user_id)
+ try:
+ can_delete = yield self._user_can_delete_alias(room_alias, user_id)
+ except StoreError as e:
+ if e.code == 404:
+ raise NotFoundError("Unknown room alias")
+ raise
+
if not can_delete:
raise AuthError(
403, "You don't have permission to delete the alias.",
@@ -320,7 +333,7 @@ class DirectoryHandler(BaseHandler):
def _user_can_delete_alias(self, alias, user_id):
creator = yield self.store.get_room_alias_creator(alias.to_string())
- if creator and creator == user_id:
+ if creator is not None and creator == user_id:
defer.returnValue(True)
is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id))
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 5816bf8b4f..9dc46aa15f 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -330,7 +330,8 @@ class E2eKeysHandler(object):
(algorithm, key_id, ex_json, key)
)
else:
- new_keys.append((algorithm, key_id, encode_canonical_json(key)))
+ new_keys.append((
+ algorithm, key_id, encode_canonical_json(key).decode('ascii')))
yield self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, new_keys
@@ -340,7 +341,7 @@ class E2eKeysHandler(object):
def _exception_to_failure(e):
if isinstance(e, CodeMessageException):
return {
- "status": e.code, "message": e.message,
+ "status": e.code, "message": str(e),
}
if isinstance(e, NotRetryingDestination):
@@ -358,7 +359,7 @@ def _exception_to_failure(e):
# Note that some Exceptions (notably twisted's ResponseFailed etc) don't
# give a string for e.message, which json then fails to serialize.
return {
- "status": 503, "message": str(e.message),
+ "status": 503, "message": str(e),
}
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 0ebf0fd188..45d955e6f5 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -18,7 +18,6 @@
import itertools
import logging
-import sys
import six
from six import iteritems, itervalues
@@ -69,6 +68,27 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
+def shortstr(iterable, maxitems=5):
+ """If iterable has maxitems or fewer, return the stringification of a list
+ containing those items.
+
+ Otherwise, return the stringification of a a list with the first maxitems items,
+ followed by "...".
+
+ Args:
+ iterable (Iterable): iterable to truncate
+ maxitems (int): number of items to return before truncating
+
+ Returns:
+ unicode
+ """
+
+ items = list(itertools.islice(iterable, maxitems + 1))
+ if len(items) <= maxitems:
+ return str(items)
+ return u"[" + u", ".join(repr(r) for r in items[:maxitems]) + u", ...]"
+
+
class FederationHandler(BaseHandler):
"""Handles events that originated from federation.
Responsible for:
@@ -85,7 +105,7 @@ class FederationHandler(BaseHandler):
self.hs = hs
- self.store = hs.get_datastore()
+ self.store = hs.get_datastore() # type: synapse.storage.DataStore
self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname
@@ -114,9 +134,8 @@ class FederationHandler(BaseHandler):
self._room_pdu_linearizer = Linearizer("fed_room_pdu")
@defer.inlineCallbacks
- @log_function
def on_receive_pdu(
- self, origin, pdu, get_missing=True, sent_to_us_directly=False,
+ self, origin, pdu, sent_to_us_directly=False,
):
""" Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
@@ -125,14 +144,23 @@ class FederationHandler(BaseHandler):
origin (str): server which initiated the /send/ transaction. Will
be used to fetch missing events or state.
pdu (FrozenEvent): received PDU
- get_missing (bool): True if we should fetch missing prev_events
+ sent_to_us_directly (bool): True if this event was pushed to us; False if
+ we pulled it as the result of a missing prev_event.
Returns (Deferred): completes with None
"""
+ room_id = pdu.room_id
+ event_id = pdu.event_id
+
+ logger.info(
+ "[%s %s] handling received PDU: %s",
+ room_id, event_id, pdu,
+ )
+
# We reprocess pdus when we have seen them only as outliers
existing = yield self.store.get_event(
- pdu.event_id,
+ event_id,
allow_none=True,
allow_rejected=True,
)
@@ -147,7 +175,7 @@ class FederationHandler(BaseHandler):
)
)
if already_seen:
- logger.debug("Already seen pdu %s", pdu.event_id)
+ logger.debug("[%s %s]: Already seen pdu", room_id, event_id)
return
# do some initial sanity-checking of the event. In particular, make
@@ -156,6 +184,7 @@ class FederationHandler(BaseHandler):
try:
self._sanity_check_event(pdu)
except SynapseError as err:
+ logger.warn("[%s %s] Received event failed sanity checks", room_id, event_id)
raise FederationError(
"ERROR",
err.code,
@@ -165,10 +194,12 @@ class FederationHandler(BaseHandler):
# If we are currently in the process of joining this room, then we
# queue up events for later processing.
- if pdu.room_id in self.room_queues:
- logger.info("Ignoring PDU %s for room %s from %s for now; join "
- "in progress", pdu.event_id, pdu.room_id, origin)
- self.room_queues[pdu.room_id].append((pdu, origin))
+ if room_id in self.room_queues:
+ logger.info(
+ "[%s %s] Queuing PDU from %s for now: join in progress",
+ room_id, event_id, origin,
+ )
+ self.room_queues[room_id].append((pdu, origin))
return
# If we're no longer in the room just ditch the event entirely. This
@@ -179,7 +210,7 @@ class FederationHandler(BaseHandler):
# we should check if we *are* in fact in the room. If we are then we
# can magically rejoin the room.
is_in_room = yield self.auth.check_host_in_room(
- pdu.room_id,
+ room_id,
self.server_name
)
if not is_in_room:
@@ -188,8 +219,8 @@ class FederationHandler(BaseHandler):
)
if was_in_room:
logger.info(
- "Ignoring PDU %s for room %s from %s as we've left the room!",
- pdu.event_id, pdu.room_id, origin,
+ "[%s %s] Ignoring PDU from %s as we've left the room",
+ room_id, event_id, origin,
)
defer.returnValue(None)
@@ -204,8 +235,8 @@ class FederationHandler(BaseHandler):
)
logger.debug(
- "_handle_new_pdu min_depth for %s: %d",
- pdu.room_id, min_depth
+ "[%s %s] min_depth: %d",
+ room_id, event_id, min_depth,
)
prevs = {e_id for e_id, _ in pdu.prev_events}
@@ -218,17 +249,18 @@ class FederationHandler(BaseHandler):
# send to the clients.
pdu.internal_metadata.outlier = True
elif min_depth and pdu.depth > min_depth:
- if get_missing and prevs - seen:
+ missing_prevs = prevs - seen
+ if sent_to_us_directly and missing_prevs:
# If we're missing stuff, ensure we only fetch stuff one
# at a time.
logger.info(
- "Acquiring lock for room %r to fetch %d missing events: %r...",
- pdu.room_id, len(prevs - seen), list(prevs - seen)[:5],
+ "[%s %s] Acquiring room lock to fetch %d missing prev_events: %s",
+ room_id, event_id, len(missing_prevs), shortstr(missing_prevs),
)
with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
- "Acquired lock for room %r to fetch %d missing events",
- pdu.room_id, len(prevs - seen),
+ "[%s %s] Acquired room lock to fetch %d missing prev_events",
+ room_id, event_id, len(missing_prevs),
)
yield self._get_missing_events_for_pdu(
@@ -241,69 +273,150 @@ class FederationHandler(BaseHandler):
if not prevs - seen:
logger.info(
- "Found all missing prev events for %s", pdu.event_id
+ "[%s %s] Found all missing prev_events",
+ room_id, event_id,
)
- elif prevs - seen:
+ elif missing_prevs:
logger.info(
- "Not fetching %d missing events for room %r,event %s: %r...",
- len(prevs - seen), pdu.room_id, pdu.event_id,
- list(prevs - seen)[:5],
+ "[%s %s] Not recursively fetching %d missing prev_events: %s",
+ room_id, event_id, len(missing_prevs), shortstr(missing_prevs),
)
- if sent_to_us_directly and prevs - seen:
- # If they have sent it to us directly, and the server
- # isn't telling us about the auth events that it's
- # made a message referencing, we explode
- raise FederationError(
- "ERROR",
- 403,
- (
- "Your server isn't divulging details about prev_events "
- "referenced in this event."
- ),
- affected=pdu.event_id,
- )
- elif prevs - seen:
- # Calculate the state of the previous events, and
- # de-conflict them to find the current state.
- state_groups = []
+ if prevs - seen:
+ # We've still not been able to get all of the prev_events for this event.
+ #
+ # In this case, we need to fall back to asking another server in the
+ # federation for the state at this event. That's ok provided we then
+ # resolve the state against other bits of the DAG before using it (which
+ # will ensure that you can't just take over a room by sending an event,
+ # withholding its prev_events, and declaring yourself to be an admin in
+ # the subsequent state request).
+ #
+ # Now, if we're pulling this event as a missing prev_event, then clearly
+ # this event is not going to become the only forward-extremity and we are
+ # guaranteed to resolve its state against our existing forward
+ # extremities, so that should be fine.
+ #
+ # On the other hand, if this event was pushed to us, it is possible for
+ # it to become the only forward-extremity in the room, and we would then
+ # trust its state to be the state for the whole room. This is very bad.
+ # Further, if the event was pushed to us, there is no excuse for us not to
+ # have all the prev_events. We therefore reject any such events.
+ #
+ # XXX this really feels like it could/should be merged with the above,
+ # but there is an interaction with min_depth that I'm not really
+ # following.
+
+ if sent_to_us_directly:
+ logger.warn(
+ "[%s %s] Failed to fetch %d prev events: rejecting",
+ room_id, event_id, len(prevs - seen),
+ )
+ raise FederationError(
+ "ERROR",
+ 403,
+ (
+ "Your server isn't divulging details about prev_events "
+ "referenced in this event."
+ ),
+ affected=pdu.event_id,
+ )
+
+ # Calculate the state after each of the previous events, and
+ # resolve them to find the correct state at the current event.
auth_chains = set()
+ event_map = {
+ event_id: pdu,
+ }
try:
# Get the state of the events we know about
- ours = yield self.store.get_state_groups(pdu.room_id, list(seen))
- state_groups.append(ours)
+ ours = yield self.store.get_state_groups_ids(room_id, seen)
+
+ # state_maps is a list of mappings from (type, state_key) to event_id
+ # type: list[dict[tuple[str, str], str]]
+ state_maps = list(ours.values())
+
+ # we don't need this any more, let's delete it.
+ del ours
# Ask the remote server for the states we don't
# know about
for p in prevs - seen:
- state, got_auth_chain = (
- yield self.federation_client.get_state_for_room(
- origin, pdu.room_id, p
- )
+ logger.info(
+ "[%s %s] Requesting state at missing prev_event %s",
+ room_id, event_id, p,
)
- auth_chains.update(got_auth_chain)
- state_group = {(x.type, x.state_key): x.event_id for x in state}
- state_groups.append(state_group)
+
+ with logcontext.nested_logging_context(p):
+ # note that if any of the missing prevs share missing state or
+ # auth events, the requests to fetch those events are deduped
+ # by the get_pdu_cache in federation_client.
+ remote_state, got_auth_chain = (
+ yield self.federation_client.get_state_for_room(
+ origin, room_id, p,
+ )
+ )
+
+ # we want the state *after* p; get_state_for_room returns the
+ # state *before* p.
+ remote_event = yield self.federation_client.get_pdu(
+ [origin], p, outlier=True,
+ )
+
+ if remote_event is None:
+ raise Exception(
+ "Unable to get missing prev_event %s" % (p, )
+ )
+
+ if remote_event.is_state():
+ remote_state.append(remote_event)
+
+ # XXX hrm I'm not convinced that duplicate events will compare
+ # for equality, so I'm not sure this does what the author
+ # hoped.
+ auth_chains.update(got_auth_chain)
+
+ remote_state_map = {
+ (x.type, x.state_key): x.event_id for x in remote_state
+ }
+ state_maps.append(remote_state_map)
+
+ for x in remote_state:
+ event_map[x.event_id] = x
# Resolve any conflicting state
+ @defer.inlineCallbacks
def fetch(ev_ids):
- return self.store.get_events(
- ev_ids, get_prev_content=False, check_redacted=False
+ fetched = yield self.store.get_events(
+ ev_ids, get_prev_content=False, check_redacted=False,
)
+ # add any events we fetch here to the `event_map` so that we
+ # can use them to build the state event list below.
+ event_map.update(fetched)
+ defer.returnValue(fetched)
- room_version = yield self.store.get_room_version(pdu.room_id)
+ room_version = yield self.store.get_room_version(room_id)
state_map = yield resolve_events_with_factory(
- room_version, state_groups, {pdu.event_id: pdu}, fetch
+ room_version, state_maps, event_map, fetch,
)
- state = (yield self.store.get_events(state_map.values())).values()
+ # we need to give _process_received_pdu the actual state events
+ # rather than event ids, so generate that now.
+ state = [
+ event_map[e] for e in six.itervalues(state_map)
+ ]
auth_chain = list(auth_chains)
except Exception:
+ logger.warn(
+ "[%s %s] Error attempting to resolve state at missing "
+ "prev_events",
+ room_id, event_id, exc_info=True,
+ )
raise FederationError(
"ERROR",
403,
"We can't get valid state history.",
- affected=pdu.event_id,
+ affected=event_id,
)
yield self._process_received_pdu(
@@ -322,15 +435,16 @@ class FederationHandler(BaseHandler):
prevs (set(str)): List of event ids which we are missing
min_depth (int): Minimum depth of events to return.
"""
- # We recalculate seen, since it may have changed.
+
+ room_id = pdu.room_id
+ event_id = pdu.event_id
+
seen = yield self.store.have_seen_events(prevs)
if not prevs - seen:
return
- latest = yield self.store.get_latest_event_ids_in_room(
- pdu.room_id
- )
+ latest = yield self.store.get_latest_event_ids_in_room(room_id)
# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
@@ -338,8 +452,8 @@ class FederationHandler(BaseHandler):
latest |= seen
logger.info(
- "Missing %d events for room %r pdu %s: %r...",
- len(prevs - seen), pdu.room_id, pdu.event_id, list(prevs - seen)[:5]
+ "[%s %s]: Requesting %d prev_events: %s",
+ room_id, event_id, len(prevs - seen), shortstr(prevs - seen)
)
# XXX: we set timeout to 10s to help workaround
@@ -360,49 +474,88 @@ class FederationHandler(BaseHandler):
# apparently.
#
# see https://github.com/matrix-org/synapse/pull/1744
+ #
+ # ----
+ #
+ # Update richvdh 2018/09/18: There are a number of problems with timing this
+ # request out agressively on the client side:
+ #
+ # - it plays badly with the server-side rate-limiter, which starts tarpitting you
+ # if you send too many requests at once, so you end up with the server carefully
+ # working through the backlog of your requests, which you have already timed
+ # out.
+ #
+ # - for this request in particular, we now (as of
+ # https://github.com/matrix-org/synapse/pull/3456) reject any PDUs where the
+ # server can't produce a plausible-looking set of prev_events - so we becone
+ # much more likely to reject the event.
+ #
+ # - contrary to what it says above, we do *not* fall back to fetching fresh state
+ # for the room if get_missing_events times out. Rather, we give up processing
+ # the PDU whose prevs we are missing, which then makes it much more likely that
+ # we'll end up back here for the *next* PDU in the list, which exacerbates the
+ # problem.
+ #
+ # - the agressive 10s timeout was introduced to deal with incoming federation
+ # requests taking 8 hours to process. It's not entirely clear why that was going
+ # on; certainly there were other issues causing traffic storms which are now
+ # resolved, and I think in any case we may be more sensible about our locking
+ # now. We're *certainly* more sensible about our logging.
+ #
+ # All that said: Let's try increasing the timout to 60s and see what happens.
missing_events = yield self.federation_client.get_missing_events(
origin,
- pdu.room_id,
+ room_id,
earliest_events_ids=list(latest),
latest_events=[pdu],
limit=10,
min_depth=min_depth,
- timeout=10000,
+ timeout=60000,
)
logger.info(
- "Got %d events: %r...",
- len(missing_events), [e.event_id for e in missing_events[:5]]
+ "[%s %s]: Got %d prev_events: %s",
+ room_id, event_id, len(missing_events), shortstr(missing_events),
)
# We want to sort these by depth so we process them and
# tell clients about them in order.
missing_events.sort(key=lambda x: x.depth)
- for e in missing_events:
- logger.info("Handling found event %s", e.event_id)
- try:
- yield self.on_receive_pdu(
- origin,
- e,
- get_missing=False
- )
- except FederationError as e:
- if e.code == 403:
- logger.warn("Event %s failed history check.")
- else:
- raise
+ for ev in missing_events:
+ logger.info(
+ "[%s %s] Handling received prev_event %s",
+ room_id, event_id, ev.event_id,
+ )
+ with logcontext.nested_logging_context(ev.event_id):
+ try:
+ yield self.on_receive_pdu(
+ origin,
+ ev,
+ sent_to_us_directly=False,
+ )
+ except FederationError as e:
+ if e.code == 403:
+ logger.warn(
+ "[%s %s] Received prev_event %s failed history check.",
+ room_id, event_id, ev.event_id,
+ )
+ else:
+ raise
- @log_function
@defer.inlineCallbacks
- def _process_received_pdu(self, origin, pdu, state, auth_chain):
+ def _process_received_pdu(self, origin, event, state, auth_chain):
""" Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler.
"""
- event = pdu
+ room_id = event.room_id
+ event_id = event.event_id
- logger.debug("Processing event: %s", event)
+ logger.debug(
+ "[%s %s] Processing event: %s",
+ room_id, event_id, event,
+ )
# FIXME (erikj): Awful hack to make the case where we are not currently
# in the room work
@@ -411,15 +564,16 @@ class FederationHandler(BaseHandler):
# event.
if state and auth_chain and not event.internal_metadata.is_outlier():
is_in_room = yield self.auth.check_host_in_room(
- event.room_id,
+ room_id,
self.server_name
)
else:
is_in_room = True
+
if not is_in_room:
logger.info(
- "Got event for room we're not in: %r %r",
- event.room_id, event.event_id
+ "[%s %s] Got event for room we're not in",
+ room_id, event_id,
)
try:
@@ -431,7 +585,7 @@ class FederationHandler(BaseHandler):
"ERROR",
e.code,
e.msg,
- affected=event.event_id,
+ affected=event_id,
)
else:
@@ -464,6 +618,10 @@ class FederationHandler(BaseHandler):
})
seen_ids.add(e.event_id)
+ logger.info(
+ "[%s %s] persisting newly-received auth/state events %s",
+ room_id, event_id, [e["event"].event_id for e in event_infos]
+ )
yield self._handle_new_events(origin, event_infos)
try:
@@ -480,12 +638,12 @@ class FederationHandler(BaseHandler):
affected=event.event_id,
)
- room = yield self.store.get_room(event.room_id)
+ room = yield self.store.get_room(room_id)
if not room:
try:
yield self.store.store_room(
- room_id=event.room_id,
+ room_id=room_id,
room_creator_user_id="",
is_public=False,
)
@@ -513,7 +671,7 @@ class FederationHandler(BaseHandler):
if newly_joined:
user = UserID.from_string(event.state_key)
- yield self.user_joined_room(user, event.room_id)
+ yield self.user_joined_room(user, room_id)
@log_function
@defer.inlineCallbacks
@@ -594,7 +752,7 @@ class FederationHandler(BaseHandler):
required_auth = set(
a_id
- for event in events + state_events.values() + auth_events.values()
+ for event in events + list(state_events.values()) + list(auth_events.values())
for a_id, _ in event.auth_events
)
auth_events.update({
@@ -802,7 +960,7 @@ class FederationHandler(BaseHandler):
)
continue
except NotRetryingDestination as e:
- logger.info(e.message)
+ logger.info(str(e))
continue
except FederationDeniedError as e:
logger.info(e)
@@ -1027,7 +1185,8 @@ class FederationHandler(BaseHandler):
try:
logger.info("Processing queued PDU %s which was received "
"while we were joining %s", p.event_id, p.room_id)
- yield self.on_receive_pdu(origin, p)
+ with logcontext.nested_logging_context(p.event_id):
+ yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
except Exception as e:
logger.warn(
"Error handling queued PDU %s from %s: %s",
@@ -1358,7 +1517,7 @@ class FederationHandler(BaseHandler):
)
if state_groups:
- _, state = state_groups.items().pop()
+ _, state = list(state_groups.items()).pop()
results = state
if event.is_state():
@@ -1430,12 +1589,10 @@ class FederationHandler(BaseHandler):
else:
defer.returnValue(None)
- @log_function
def get_min_depth_for_context(self, context):
return self.store.get_min_depth(context)
@defer.inlineCallbacks
- @log_function
def _handle_new_event(self, origin, event, state=None, auth_events=None,
backfilled=False):
context = yield self._prep_event(
@@ -1444,6 +1601,9 @@ class FederationHandler(BaseHandler):
auth_events=auth_events,
)
+ # reraise does not allow inlineCallbacks to preserve the stacktrace, so we
+ # hack around with a try/finally instead.
+ success = False
try:
if not event.internal_metadata.is_outlier() and not backfilled:
yield self.action_generator.handle_push_actions_for_event(
@@ -1454,15 +1614,13 @@ class FederationHandler(BaseHandler):
[(event, context)],
backfilled=backfilled,
)
- except: # noqa: E722, as we reraise the exception this is fine.
- tp, value, tb = sys.exc_info()
-
- logcontext.run_in_background(
- self.store.remove_push_actions_from_staging,
- event.event_id,
- )
-
- six.reraise(tp, value, tb)
+ success = True
+ finally:
+ if not success:
+ logcontext.run_in_background(
+ self.store.remove_push_actions_from_staging,
+ event.event_id,
+ )
defer.returnValue(context)
@@ -1475,15 +1633,22 @@ class FederationHandler(BaseHandler):
Notifies about the events where appropriate.
"""
- contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- logcontext.run_in_background(
- self._prep_event,
+
+ @defer.inlineCallbacks
+ def prep(ev_info):
+ event = ev_info["event"]
+ with logcontext.nested_logging_context(suffix=event.event_id):
+ res = yield self._prep_event(
origin,
- ev_info["event"],
+ event,
state=ev_info.get("state"),
auth_events=ev_info.get("auth_events"),
)
+ defer.returnValue(res)
+
+ contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
+ [
+ logcontext.run_in_background(prep, ev_info)
for ev_info in event_infos
], consumeErrors=True,
))
@@ -1635,8 +1800,8 @@ class FederationHandler(BaseHandler):
)
except AuthError as e:
logger.warn(
- "Rejecting %s because %s",
- event.event_id, e.msg
+ "[%s %s] Rejecting: %s",
+ event.room_id, event.event_id, e.msg
)
context.rejected = RejectedReason.AUTH_ERROR
@@ -1831,7 +1996,7 @@ class FederationHandler(BaseHandler):
room_version = yield self.store.get_room_version(event.room_id)
- new_state = self.state_handler.resolve_events(
+ new_state = yield self.state_handler.resolve_events(
room_version,
[list(local_view.values()), list(remote_view.values())],
event
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index e484061cc0..4954b23a0d 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -14,9 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import sys
-import six
from six import iteritems, itervalues, string_types
from canonicaljson import encode_canonical_json, json
@@ -624,6 +622,9 @@ class EventCreationHandler(object):
event, context
)
+ # reraise does not allow inlineCallbacks to preserve the stacktrace, so we
+ # hack around with a try/finally instead.
+ success = False
try:
# If we're a worker we need to hit out to the master.
if self.config.worker_app:
@@ -636,6 +637,7 @@ class EventCreationHandler(object):
ratelimit=ratelimit,
extra_users=extra_users,
)
+ success = True
return
yield self.persist_and_notify_client_event(
@@ -645,17 +647,16 @@ class EventCreationHandler(object):
ratelimit=ratelimit,
extra_users=extra_users,
)
- except: # noqa: E722, as we reraise the exception this is fine.
- # Ensure that we actually remove the entries in the push actions
- # staging area, if we calculated them.
- tp, value, tb = sys.exc_info()
-
- run_in_background(
- self.store.remove_push_actions_from_staging,
- event.event_id,
- )
- six.reraise(tp, value, tb)
+ success = True
+ finally:
+ if not success:
+ # Ensure that we actually remove the entries in the push actions
+ # staging area, if we calculated them.
+ run_in_background(
+ self.store.remove_push_actions_from_staging,
+ event.event_id,
+ )
@defer.inlineCallbacks
def persist_and_notify_client_event(
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 5170d093e3..a155b6e938 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -269,14 +269,7 @@ class PaginationHandler(object):
if state_ids:
state = yield self.store.get_events(list(state_ids.values()))
-
- if state:
- state = yield filter_events_for_client(
- self.store,
- user_id,
- state.values(),
- is_peeking=(member_event_id is None),
- )
+ state = state.values()
time_now = self.clock.time_msec()
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 75b8b7ce6a..1dfbde84fd 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -142,10 +142,8 @@ class BaseProfileHandler(BaseHandler):
if e.code != 404:
logger.exception("Failed to get displayname")
raise
- except Exception:
- logger.exception("Failed to get displayname")
- else:
- defer.returnValue(result["displayname"])
+
+ defer.returnValue(result["displayname"])
@defer.inlineCallbacks
def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
@@ -199,8 +197,6 @@ class BaseProfileHandler(BaseHandler):
if e.code != 404:
logger.exception("Failed to get avatar_url")
raise
- except Exception:
- logger.exception("Failed to get avatar_url")
defer.returnValue(result["avatar_url"])
@@ -278,7 +274,7 @@ class BaseProfileHandler(BaseHandler):
except Exception as e:
logger.warn(
"Failed to update join event for room %s - %s",
- room_id, str(e.message)
+ room_id, str(e)
)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index f03ee1476b..da914c46ff 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -125,6 +125,7 @@ class RegistrationHandler(BaseHandler):
guest_access_token=None,
make_guest=False,
admin=False,
+ threepid=None,
):
"""Registers a new client on the server.
@@ -145,7 +146,7 @@ class RegistrationHandler(BaseHandler):
RegistrationError if there was a problem registering.
"""
- yield self.auth.check_auth_blocking()
+ yield self.auth.check_auth_blocking(threepid=threepid)
password_hash = None
if password:
password_hash = yield self.auth_handler().hash(password)
@@ -533,4 +534,5 @@ class RegistrationHandler(BaseHandler):
room_id=room_id,
remote_room_hosts=remote_room_hosts,
action="join",
+ ratelimit=False,
)
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 37e41afd61..38e1737ec9 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -162,7 +162,7 @@ class RoomListHandler(BaseHandler):
# Filter out rooms that we don't want to return
rooms_to_scan = [
r for r in sorted_rooms
- if r not in newly_unpublished and rooms_to_num_joined[room_id] > 0
+ if r not in newly_unpublished and rooms_to_num_joined[r] > 0
]
total_room_count = len(rooms_to_scan)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index f643619047..07fd3e82fc 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -583,6 +583,11 @@ class RoomMemberHandler(object):
room_id = mapping["room_id"]
servers = mapping["servers"]
+ # put the server which owns the alias at the front of the server list.
+ if room_alias.domain in servers:
+ servers.remove(room_alias.domain)
+ servers.insert(0, room_alias.domain)
+
defer.returnValue((RoomID.from_string(room_id), servers))
@defer.inlineCallbacks
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index c464adbd0b..0c1d52fd11 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -54,7 +54,7 @@ class SearchHandler(BaseHandler):
batch_token = None
if batch:
try:
- b = decode_base64(batch)
+ b = decode_base64(batch).decode('ascii')
batch_group, batch_group_key, batch_token = b.split("\n")
assert batch_group is not None
@@ -258,18 +258,18 @@ class SearchHandler(BaseHandler):
# it returns more from the same group (if applicable) rather
# than reverting to searching all results again.
if batch_group and batch_group_key:
- global_next_batch = encode_base64("%s\n%s\n%s" % (
+ global_next_batch = encode_base64(("%s\n%s\n%s" % (
batch_group, batch_group_key, pagination_token
- ))
+ )).encode('ascii'))
else:
- global_next_batch = encode_base64("%s\n%s\n%s" % (
+ global_next_batch = encode_base64(("%s\n%s\n%s" % (
"all", "", pagination_token
- ))
+ )).encode('ascii'))
for room_id, group in room_groups.items():
- group["next_batch"] = encode_base64("%s\n%s\n%s" % (
+ group["next_batch"] = encode_base64(("%s\n%s\n%s" % (
"room_id", room_id, pagination_token
- ))
+ )).encode('ascii'))
allowed_events.extend(room_events)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 648debc8aa..67b8ca28c7 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -24,6 +24,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.push.clientformat import format_push_rules_for_user
+from synapse.storage.roommember import MemberSummary
from synapse.types import RoomStreamToken
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
@@ -525,6 +526,8 @@ class SyncHandler(object):
A deferred dict describing the room summary
"""
+ # FIXME: we could/should get this from room_stats when matthew/stats lands
+
# FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305
last_events, _ = yield self.store.get_recent_event_ids_for_room(
room_id, end_token=now_token.room_key, limit=1,
@@ -537,44 +540,67 @@ class SyncHandler(object):
last_event = last_events[-1]
state_ids = yield self.store.get_state_ids_for_event(
last_event.event_id, [
- (EventTypes.Member, None),
(EventTypes.Name, ''),
(EventTypes.CanonicalAlias, ''),
]
)
- member_ids = {
- state_key: event_id
- for (t, state_key), event_id in state_ids.iteritems()
- if t == EventTypes.Member
- }
+ # this is heavily cached, thus: fast.
+ details = yield self.store.get_room_summary(room_id)
+
name_id = state_ids.get((EventTypes.Name, ''))
canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, ''))
summary = {}
-
- # FIXME: it feels very heavy to load up every single membership event
- # just to calculate the counts.
- member_events = yield self.store.get_events(member_ids.values())
-
- joined_user_ids = []
- invited_user_ids = []
-
- for ev in member_events.values():
- if ev.content.get("membership") == Membership.JOIN:
- joined_user_ids.append(ev.state_key)
- elif ev.content.get("membership") == Membership.INVITE:
- invited_user_ids.append(ev.state_key)
+ empty_ms = MemberSummary([], 0)
# TODO: only send these when they change.
- summary["m.joined_member_count"] = len(joined_user_ids)
- summary["m.invited_member_count"] = len(invited_user_ids)
+ summary["m.joined_member_count"] = (
+ details.get(Membership.JOIN, empty_ms).count
+ )
+ summary["m.invited_member_count"] = (
+ details.get(Membership.INVITE, empty_ms).count
+ )
- if name_id or canonical_alias_id:
- defer.returnValue(summary)
+ # if the room has a name or canonical_alias set, we can skip
+ # calculating heroes. we assume that if the event has contents, it'll
+ # be a valid name or canonical_alias - i.e. we're checking that they
+ # haven't been "deleted" by blatting {} over the top.
+ if name_id:
+ name = yield self.store.get_event(name_id, allow_none=True)
+ if name and name.content:
+ defer.returnValue(summary)
+
+ if canonical_alias_id:
+ canonical_alias = yield self.store.get_event(
+ canonical_alias_id, allow_none=True,
+ )
+ if canonical_alias and canonical_alias.content:
+ defer.returnValue(summary)
+
+ joined_user_ids = [
+ r[0] for r in details.get(Membership.JOIN, empty_ms).members
+ ]
+ invited_user_ids = [
+ r[0] for r in details.get(Membership.INVITE, empty_ms).members
+ ]
+ gone_user_ids = (
+ [r[0] for r in details.get(Membership.LEAVE, empty_ms).members] +
+ [r[0] for r in details.get(Membership.BAN, empty_ms).members]
+ )
- # FIXME: order by stream ordering, not alphabetic
+ # FIXME: only build up a member_ids list for our heroes
+ member_ids = {}
+ for membership in (
+ Membership.JOIN,
+ Membership.INVITE,
+ Membership.LEAVE,
+ Membership.BAN
+ ):
+ for user_id, event_id in details.get(membership, empty_ms).members:
+ member_ids[user_id] = event_id
+ # FIXME: order by stream ordering rather than as returned by SQL
me = sync_config.user.to_string()
if (joined_user_ids or invited_user_ids):
summary['m.heroes'] = sorted(
@@ -586,7 +612,11 @@ class SyncHandler(object):
)[0:5]
else:
summary['m.heroes'] = sorted(
- [user_id for user_id in member_ids.keys() if user_id != me]
+ [
+ user_id
+ for user_id in gone_user_ids
+ if user_id != me
+ ]
)[0:5]
if not sync_config.filter_collection.lazy_load_members():
@@ -692,6 +722,13 @@ class SyncHandler(object):
}
if full_state:
+ if lazy_load_members:
+ # always make sure we LL ourselves so we know we're in the room
+ # (if we are) to fix https://github.com/vector-im/riot-web/issues/7209
+ # We only need apply this on full state syncs given we disabled
+ # LL for incr syncs in #3840.
+ types.append((EventTypes.Member, sync_config.user.to_string()))
+
if batch:
current_state_ids = yield self.store.get_state_ids_for_event(
batch.events[-1].event_id, types=types,
@@ -719,6 +756,26 @@ class SyncHandler(object):
lazy_load_members=lazy_load_members,
)
elif batch.limited:
+ state_at_timeline_start = yield self.store.get_state_ids_for_event(
+ batch.events[0].event_id, types=types,
+ filtered_types=filtered_types,
+ )
+
+ # for now, we disable LL for gappy syncs - see
+ # https://github.com/vector-im/riot-web/issues/7211#issuecomment-419976346
+ # N.B. this slows down incr syncs as we are now processing way
+ # more state in the server than if we were LLing.
+ #
+ # We still have to filter timeline_start to LL entries (above) in order
+ # for _calculate_state's LL logic to work, as we have to include LL
+ # members for timeline senders in case they weren't loaded in the initial
+ # sync. We do this by (counterintuitively) by filtering timeline_start
+ # members to just be ones which were timeline senders, which then ensures
+ # all of the rest get included in the state block (if we need to know
+ # about them).
+ types = None
+ filtered_types = None
+
state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token, types=types,
filtered_types=filtered_types,
@@ -729,25 +786,29 @@ class SyncHandler(object):
filtered_types=filtered_types,
)
- state_at_timeline_start = yield self.store.get_state_ids_for_event(
- batch.events[0].event_id, types=types,
- filtered_types=filtered_types,
- )
-
state_ids = _calculate_state(
timeline_contains=timeline_state,
timeline_start=state_at_timeline_start,
previous=state_at_previous_sync,
current=current_state_ids,
+ # we have to include LL members in case LL initial sync missed them
lazy_load_members=lazy_load_members,
)
else:
state_ids = {}
if lazy_load_members:
- if types:
+ if types and batch.events:
+ # We're returning an incremental sync, with no
+ # "gap" since the previous sync, so normally there would be
+ # no state to return.
+ # But we're lazy-loading, so the client might need some more
+ # member events to understand the events in this timeline.
+ # So we fish out all the member events corresponding to the
+ # timeline here, and then dedupe any redundant ones below.
+
state_ids = yield self.store.get_state_ids_for_event(
batch.events[0].event_id, types=types,
- filtered_types=filtered_types,
+ filtered_types=None, # we only want members!
)
if lazy_load_members and not include_redundant_members:
@@ -767,7 +828,7 @@ class SyncHandler(object):
logger.debug("filtering state from %r...", state_ids)
state_ids = {
t: event_id
- for t, event_id in state_ids.iteritems()
+ for t, event_id in iteritems(state_ids)
if cache.get(t[1]) != event_id
}
logger.debug("...to %r", state_ids)
@@ -1568,6 +1629,19 @@ class SyncHandler(object):
newly_joined_room=newly_joined,
)
+ # When we join the room (or the client requests full_state), we should
+ # send down any existing tags. Usually the user won't have tags in a
+ # newly joined room, unless either a) they've joined before or b) the
+ # tag was added by synapse e.g. for server notice rooms.
+ if full_state:
+ user_id = sync_result_builder.sync_config.user.to_string()
+ tags = yield self.store.get_tags_for_room(user_id, room_id)
+
+ # If there aren't any tags, don't send the empty tags list down
+ # sync
+ if not tags:
+ tags = None
+
account_data_events = []
if tags is not None:
account_data_events.append({
@@ -1596,10 +1670,24 @@ class SyncHandler(object):
)
summary = {}
+
+ # we include a summary in room responses when we're lazy loading
+ # members (as the client otherwise doesn't have enough info to form
+ # the name itself).
if (
sync_config.filter_collection.lazy_load_members() and
(
+ # we recalulate the summary:
+ # if there are membership changes in the timeline, or
+ # if membership has changed during a gappy sync, or
+ # if this is an initial sync.
any(ev.type == EventTypes.Member for ev in batch.events) or
+ (
+ # XXX: this may include false positives in the form of LL
+ # members which have snuck into state
+ batch.limited and
+ any(t == EventTypes.Member for (t, k) in state)
+ ) or
since_token is None
)
):
@@ -1629,6 +1717,16 @@ class SyncHandler(object):
unread_notifications["highlight_count"] = notifs["highlight_count"]
sync_result_builder.joined.append(room_sync)
+
+ if batch.limited and since_token:
+ user_id = sync_result_builder.sync_config.user.to_string()
+ logger.info(
+ "Incremental gappy sync of %s for user %s with %d state events" % (
+ room_id,
+ user_id,
+ len(state),
+ )
+ )
elif room_builder.rtype == "archived":
room_sync = ArchivedSyncResult(
room_id=room_id,
@@ -1722,17 +1820,17 @@ def _calculate_state(
event_id_to_key = {
e: key
for key, e in itertools.chain(
- timeline_contains.items(),
- previous.items(),
- timeline_start.items(),
- current.items(),
+ iteritems(timeline_contains),
+ iteritems(previous),
+ iteritems(timeline_start),
+ iteritems(current),
)
}
- c_ids = set(e for e in current.values())
- ts_ids = set(e for e in timeline_start.values())
- p_ids = set(e for e in previous.values())
- tc_ids = set(e for e in timeline_contains.values())
+ c_ids = set(e for e in itervalues(current))
+ ts_ids = set(e for e in itervalues(timeline_start))
+ p_ids = set(e for e in itervalues(previous))
+ tc_ids = set(e for e in itervalues(timeline_contains))
# If we are lazyloading room members, we explicitly add the membership events
# for the senders in the timeline into the state block returned by /sync,
@@ -1746,7 +1844,7 @@ def _calculate_state(
if lazy_load_members:
p_ids.difference_update(
- e for t, e in timeline_start.iteritems()
+ e for t, e in iteritems(timeline_start)
if t[0] == EventTypes.Member
)
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 2d2d3d5a0d..c610933dd4 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -20,6 +20,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError
from synapse.types import UserID, get_domain_from_id
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.logcontext import run_in_background
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
@@ -68,6 +69,11 @@ class TypingHandler(object):
# map room IDs to sets of users currently typing
self._room_typing = {}
+ # caches which room_ids changed at which serials
+ self._typing_stream_change_cache = StreamChangeCache(
+ "TypingStreamChangeCache", self._latest_room_serial,
+ )
+
self.clock.looping_call(
self._handle_timeouts,
5000,
@@ -218,6 +224,7 @@ class TypingHandler(object):
for domain in set(get_domain_from_id(u) for u in users):
if domain != self.server_name:
+ logger.debug("sending typing update to %s", domain)
self.federation.send_edu(
destination=domain,
edu_type="m.typing",
@@ -274,19 +281,29 @@ class TypingHandler(object):
self._latest_room_serial += 1
self._room_serials[member.room_id] = self._latest_room_serial
+ self._typing_stream_change_cache.entity_has_changed(
+ member.room_id, self._latest_room_serial,
+ )
self.notifier.on_new_event(
"typing_key", self._latest_room_serial, rooms=[member.room_id]
)
def get_all_typing_updates(self, last_id, current_id):
- # TODO: Work out a way to do this without scanning the entire state.
if last_id == current_id:
return []
+ changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
+ last_id,
+ )
+
+ if changed_rooms is None:
+ changed_rooms = self._room_serials
+
rows = []
- for room_id, serial in self._room_serials.items():
- if last_id < serial and serial <= current_id:
+ for room_id in changed_rooms:
+ serial = self._room_serials[room_id]
+ if last_id < serial <= current_id:
typing = self._room_typing[room_id]
rows.append((serial, room_id, list(typing)))
rows.sort()
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index 58ef8d3ce4..a3f9e4f67c 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -38,12 +38,12 @@ def cancelled_to_request_timed_out_error(value, timeout):
return value
-ACCESS_TOKEN_RE = re.compile(br'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
+ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
def redact_uri(uri):
"""Strips access tokens from the uri replaces with <redacted>"""
return ACCESS_TOKEN_RE.sub(
- br'\1<redacted>\3',
+ r'\1<redacted>\3',
uri
)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index ab4fbf59b2..3d05f83b8c 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -13,24 +13,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import logging
-import urllib
-from six import StringIO
+from six import text_type
+from six.moves import urllib
+import treq
from canonicaljson import encode_canonical_json, json
from prometheus_client import Counter
from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
-from twisted.internet import defer, protocol, reactor, ssl, task
+from twisted.internet import defer, protocol, reactor, ssl
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.web._newclient import ResponseDone
from twisted.web.client import (
Agent,
BrowserLikeRedirectAgent,
ContentDecoderAgent,
- FileBodyProducer as TwistedFileBodyProducer,
GzipDecoder,
HTTPConnectionPool,
PartialDownloadError,
@@ -42,7 +43,7 @@ from twisted.web.http_headers import Headers
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import cancelled_to_request_timed_out_error, redact_uri
from synapse.http.endpoint import SpiderEndpoint
-from synapse.util.async_helpers import add_timeout_to_deferred
+from synapse.util.async_helpers import timeout_deferred
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.logcontext import make_deferred_yieldable
@@ -83,8 +84,10 @@ class SimpleHttpClient(object):
if hs.config.user_agent_suffix:
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,)
+ self.user_agent = self.user_agent.encode('ascii')
+
@defer.inlineCallbacks
- def request(self, method, uri, *args, **kwargs):
+ def request(self, method, uri, data=b'', headers=None):
# A small wrapper around self.agent.request() so we can easily attach
# counters to it
outgoing_requests_counter.labels(method).inc()
@@ -93,10 +96,10 @@ class SimpleHttpClient(object):
logger.info("Sending request %s %s", method, redact_uri(uri))
try:
- request_deferred = self.agent.request(
- method, uri, *args, **kwargs
+ request_deferred = treq.request(
+ method, uri, agent=self.agent, data=data, headers=headers
)
- add_timeout_to_deferred(
+ request_deferred = timeout_deferred(
request_deferred, 60, self.hs.get_reactor(),
cancelled_to_request_timed_out_error,
)
@@ -112,7 +115,7 @@ class SimpleHttpClient(object):
incoming_responses_counter.labels(method, "ERR").inc()
logger.info(
"Error sending request to %s %s: %s %s",
- method, redact_uri(uri), type(e).__name__, e.message
+ method, redact_uri(uri), type(e).__name__, e.args[0]
)
raise
@@ -137,7 +140,8 @@ class SimpleHttpClient(object):
# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args)
- query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
+ query_bytes = urllib.parse.urlencode(
+ encode_urlencode_args(args), True).encode("utf8")
actual_headers = {
b"Content-Type": [b"application/x-www-form-urlencoded"],
@@ -148,15 +152,14 @@ class SimpleHttpClient(object):
response = yield self.request(
"POST",
- uri.encode("ascii"),
+ uri,
headers=Headers(actual_headers),
- bodyProducer=FileBodyProducer(StringIO(query_bytes))
+ data=query_bytes
)
- body = yield make_deferred_yieldable(readBody(response))
-
if 200 <= response.code < 300:
- defer.returnValue(json.loads(body))
+ body = yield make_deferred_yieldable(treq.json_content(response))
+ defer.returnValue(body)
else:
raise HttpResponseException(response.code, response.phrase, body)
@@ -191,9 +194,9 @@ class SimpleHttpClient(object):
response = yield self.request(
"POST",
- uri.encode("ascii"),
+ uri,
headers=Headers(actual_headers),
- bodyProducer=FileBodyProducer(StringIO(json_str))
+ data=json_str
)
body = yield make_deferred_yieldable(readBody(response))
@@ -248,7 +251,7 @@ class SimpleHttpClient(object):
ValueError: if the response was not JSON
"""
if len(args):
- query_bytes = urllib.urlencode(args, True)
+ query_bytes = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
json_str = encode_canonical_json(json_body)
@@ -262,9 +265,9 @@ class SimpleHttpClient(object):
response = yield self.request(
"PUT",
- uri.encode("ascii"),
+ uri,
headers=Headers(actual_headers),
- bodyProducer=FileBodyProducer(StringIO(json_str))
+ data=json_str
)
body = yield make_deferred_yieldable(readBody(response))
@@ -293,7 +296,7 @@ class SimpleHttpClient(object):
HttpResponseException on a non-2xx HTTP response.
"""
if len(args):
- query_bytes = urllib.urlencode(args, True)
+ query_bytes = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
actual_headers = {
@@ -304,7 +307,7 @@ class SimpleHttpClient(object):
response = yield self.request(
"GET",
- uri.encode("ascii"),
+ uri,
headers=Headers(actual_headers),
)
@@ -339,13 +342,14 @@ class SimpleHttpClient(object):
response = yield self.request(
"GET",
- url.encode("ascii"),
+ url,
headers=Headers(actual_headers),
)
resp_headers = dict(response.headers.getAllRawHeaders())
- if 'Content-Length' in resp_headers and resp_headers['Content-Length'] > max_size:
+ if (b'Content-Length' in resp_headers and
+ int(resp_headers[b'Content-Length']) > max_size):
logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
raise SynapseError(
502,
@@ -378,7 +382,12 @@ class SimpleHttpClient(object):
)
defer.returnValue(
- (length, resp_headers, response.request.absoluteURI, response.code),
+ (
+ length,
+ resp_headers,
+ response.request.absoluteURI.decode('ascii'),
+ response.code,
+ ),
)
@@ -434,12 +443,12 @@ class CaptchaServerHttpClient(SimpleHttpClient):
@defer.inlineCallbacks
def post_urlencoded_get_raw(self, url, args={}):
- query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
+ query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True)
response = yield self.request(
"POST",
- url.encode("ascii"),
- bodyProducer=FileBodyProducer(StringIO(query_bytes)),
+ url,
+ data=query_bytes,
headers=Headers({
b"Content-Type": [b"application/x-www-form-urlencoded"],
b"User-Agent": [self.user_agent],
@@ -463,9 +472,9 @@ class SpiderEndpointFactory(object):
def endpointForURI(self, uri):
logger.info("Getting endpoint for %s", uri.toBytes())
- if uri.scheme == "http":
+ if uri.scheme == b"http":
endpoint_factory = HostnameEndpoint
- elif uri.scheme == "https":
+ elif uri.scheme == b"https":
tlsCreator = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port)
def endpoint_factory(reactor, host, port, **kw):
@@ -510,7 +519,7 @@ def encode_urlencode_args(args):
def encode_urlencode_arg(arg):
- if isinstance(arg, unicode):
+ if isinstance(arg, text_type):
return arg.encode('utf-8')
elif isinstance(arg, list):
return [encode_urlencode_arg(i) for i in arg]
@@ -542,26 +551,3 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
def creatorForNetloc(self, hostname, port):
return self
-
-
-class FileBodyProducer(TwistedFileBodyProducer):
- """Workaround for https://twistedmatrix.com/trac/ticket/8473
-
- We override the pauseProducing and resumeProducing methods in twisted's
- FileBodyProducer so that they do not raise exceptions if the task has
- already completed.
- """
-
- def pauseProducing(self):
- try:
- super(FileBodyProducer, self).pauseProducing()
- except task.TaskDone:
- # task has already completed
- pass
-
- def resumeProducing(self):
- try:
- super(FileBodyProducer, self).resumeProducing()
- except task.NotPaused:
- # task was not paused (probably because it had already completed)
- pass
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index b0c9369519..91025037a3 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -108,7 +108,7 @@ def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=
Args:
reactor: Twisted reactor.
- destination (bytes): The name of the server to connect to.
+ destination (unicode): The name of the server to connect to.
tls_client_options_factory
(synapse.crypto.context_factory.ClientTLSOptionsFactory):
Factory which generates TLS options for client connections.
@@ -126,10 +126,17 @@ def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=
transport_endpoint = HostnameEndpoint
default_port = 8008
else:
+ # the SNI string should be the same as the Host header, minus the port.
+ # as per https://github.com/matrix-org/synapse/issues/2525#issuecomment-336896777,
+ # the Host header and SNI should therefore be the server_name of the remote
+ # server.
+ tls_options = tls_client_options_factory.get_options(domain)
+
def transport_endpoint(reactor, host, port, timeout):
return wrapClientTLS(
- tls_client_options_factory.get_options(host),
- HostnameEndpoint(reactor, host, port, timeout=timeout))
+ tls_options,
+ HostnameEndpoint(reactor, host, port, timeout=timeout),
+ )
default_port = 8448
if port is None:
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index b34bb8e31a..14b12cd1c4 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -17,19 +17,22 @@ import cgi
import logging
import random
import sys
-import urllib
+from io import BytesIO
-from six import string_types
-from six.moves.urllib import parse as urlparse
+from six import PY3, string_types
+from six.moves import urllib
-from canonicaljson import encode_canonical_json, json
+import attr
+import treq
+from canonicaljson import encode_canonical_json
from prometheus_client import Counter
from signedjson.sign import sign_json
-from twisted.internet import defer, protocol, reactor
+from twisted.internet import defer, protocol
from twisted.internet.error import DNSLookupError
+from twisted.internet.task import _EPSILON, Cooperator
from twisted.web._newclient import ResponseDone
-from twisted.web.client import Agent, HTTPConnectionPool, readBody
+from twisted.web.client import Agent, FileBodyProducer, HTTPConnectionPool
from twisted.web.http_headers import Headers
import synapse.metrics
@@ -40,14 +43,12 @@ from synapse.api.errors import (
HttpResponseException,
SynapseError,
)
-from synapse.http import cancelled_to_request_timed_out_error
from synapse.http.endpoint import matrix_federation_endpoint
-from synapse.util import logcontext
-from synapse.util.async_helpers import add_timeout_to_deferred
+from synapse.util.async_helpers import timeout_deferred
from synapse.util.logcontext import make_deferred_yieldable
+from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
-outbound_logger = logging.getLogger("synapse.http.outbound")
outgoing_requests_counter = Counter("synapse_http_matrixfederationclient_requests",
"", ["method"])
@@ -58,20 +59,119 @@ incoming_responses_counter = Counter("synapse_http_matrixfederationclient_respon
MAX_LONG_RETRIES = 10
MAX_SHORT_RETRIES = 3
+if PY3:
+ MAXINT = sys.maxsize
+else:
+ MAXINT = sys.maxint
+
class MatrixFederationEndpointFactory(object):
def __init__(self, hs):
+ self.reactor = hs.get_reactor()
self.tls_client_options_factory = hs.tls_client_options_factory
def endpointForURI(self, uri):
- destination = uri.netloc
+ destination = uri.netloc.decode('ascii')
return matrix_federation_endpoint(
- reactor, destination, timeout=10,
+ self.reactor, destination, timeout=10,
tls_client_options_factory=self.tls_client_options_factory
)
+_next_id = 1
+
+
+@attr.s
+class MatrixFederationRequest(object):
+ method = attr.ib()
+ """HTTP method
+ :type: str
+ """
+
+ path = attr.ib()
+ """HTTP path
+ :type: str
+ """
+
+ destination = attr.ib()
+ """The remote server to send the HTTP request to.
+ :type: str"""
+
+ json = attr.ib(default=None)
+ """JSON to send in the body.
+ :type: dict|None
+ """
+
+ json_callback = attr.ib(default=None)
+ """A callback to generate the JSON.
+ :type: func|None
+ """
+
+ query = attr.ib(default=None)
+ """Query arguments.
+ :type: dict|None
+ """
+
+ txn_id = attr.ib(default=None)
+ """Unique ID for this request (for logging)
+ :type: str|None
+ """
+
+ def __attrs_post_init__(self):
+ global _next_id
+ self.txn_id = "%s-O-%s" % (self.method, _next_id)
+ _next_id = (_next_id + 1) % (MAXINT - 1)
+
+ def get_json(self):
+ if self.json_callback:
+ return self.json_callback()
+ return self.json
+
+
+@defer.inlineCallbacks
+def _handle_json_response(reactor, timeout_sec, request, response):
+ """
+ Reads the JSON body of a response, with a timeout
+
+ Args:
+ reactor (IReactor): twisted reactor, for the timeout
+ timeout_sec (float): number of seconds to wait for response to complete
+ request (MatrixFederationRequest): the request that triggered the response
+ response (IResponse): response to the request
+
+ Returns:
+ dict: parsed JSON response
+ """
+ try:
+ check_content_type_is_json(response.headers)
+
+ d = treq.json_content(response)
+ d = timeout_deferred(
+ d,
+ timeout=timeout_sec,
+ reactor=reactor,
+ )
+
+ body = yield make_deferred_yieldable(d)
+ except Exception as e:
+ logger.warn(
+ "{%s} [%s] Error reading response: %s",
+ request.txn_id,
+ request.destination,
+ e,
+ )
+ raise
+ logger.info(
+ "{%s} [%s] Completed: %d %s",
+ request.txn_id,
+ request.destination,
+ response.code,
+ response.phrase.decode('ascii', errors='replace'),
+ )
+ defer.returnValue(body)
+
+
class MatrixFederationHttpClient(object):
"""HTTP client used to talk to other homeservers over the federation
protocol. Send client certificates and signs requests.
@@ -85,7 +185,9 @@ class MatrixFederationHttpClient(object):
self.hs = hs
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
+ reactor = hs.get_reactor()
pool = HTTPConnectionPool(reactor)
+ pool.retryAutomatically = False
pool.maxPersistentPerHost = 5
pool.cachedConnectionTimeout = 2 * 60
self.agent = Agent.usingEndpointFactory(
@@ -93,28 +195,36 @@ class MatrixFederationHttpClient(object):
)
self.clock = hs.get_clock()
self._store = hs.get_datastore()
- self.version_string = hs.version_string
- self._next_id = 1
+ self.version_string = hs.version_string.encode('ascii')
+ self.default_timeout = 60
- def _create_url(self, destination, path_bytes, param_bytes, query_bytes):
- return urlparse.urlunparse(
- ("matrix", destination, path_bytes, param_bytes, query_bytes, "")
- )
+ def schedule(x):
+ reactor.callLater(_EPSILON, x)
+
+ self._cooperator = Cooperator(scheduler=schedule)
@defer.inlineCallbacks
- def _request(self, destination, method, path,
- body_callback, headers_dict={}, param_bytes=b"",
- query_bytes=b"", retry_on_dns_fail=True,
- timeout=None, long_retries=False,
- ignore_backoff=False,
- backoff_on_404=False):
- """ Creates and sends a request to the given server
+ def _send_request(
+ self,
+ request,
+ retry_on_dns_fail=True,
+ timeout=None,
+ long_retries=False,
+ ignore_backoff=False,
+ backoff_on_404=False
+ ):
+ """
+ Sends a request to the given server.
+
Args:
- destination (str): The remote server to send the HTTP request to.
- method (str): HTTP method
- path (str): The HTTP path
+ request (MatrixFederationRequest): details of request to be sent
+
+ timeout (int|None): number of milliseconds to wait for the response headers
+ (including connecting to the server). 60s by default.
+
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
+
backoff_on_404 (bool): Back off if we get a 404
Returns:
@@ -132,38 +242,39 @@ class MatrixFederationHttpClient(object):
(May also fail with plenty of other Exceptions for things like DNS
failures, connection failures, SSL failures.)
"""
+ if timeout:
+ _sec_timeout = timeout / 1000
+ else:
+ _sec_timeout = self.default_timeout
+
if (
self.hs.config.federation_domain_whitelist is not None and
- destination not in self.hs.config.federation_domain_whitelist
+ request.destination not in self.hs.config.federation_domain_whitelist
):
- raise FederationDeniedError(destination)
+ raise FederationDeniedError(request.destination)
limiter = yield synapse.util.retryutils.get_retry_limiter(
- destination,
+ request.destination,
self.clock,
self._store,
backoff_on_404=backoff_on_404,
ignore_backoff=ignore_backoff,
)
- destination = destination.encode("ascii")
- path_bytes = path.encode("ascii")
- with limiter:
- headers_dict[b"User-Agent"] = [self.version_string]
- headers_dict[b"Host"] = [destination]
-
- url_bytes = self._create_url(
- destination, path_bytes, param_bytes, query_bytes
- )
-
- txn_id = "%s-O-%s" % (method, self._next_id)
- self._next_id = (self._next_id + 1) % (sys.maxint - 1)
+ method = request.method
+ destination = request.destination
+ path_bytes = request.path.encode("ascii")
+ if request.query:
+ query_bytes = encode_query_args(request.query)
+ else:
+ query_bytes = b""
- outbound_logger.info(
- "{%s} [%s] Sending request: %s %s",
- txn_id, destination, method, url_bytes
- )
+ headers_dict = {
+ "User-Agent": [self.version_string],
+ "Host": [request.destination],
+ }
+ with limiter:
# XXX: Would be much nicer to retry only at the transaction-layer
# (once we have reliable transactions in place)
if long_retries:
@@ -171,88 +282,119 @@ class MatrixFederationHttpClient(object):
else:
retries_left = MAX_SHORT_RETRIES
- http_url_bytes = urlparse.urlunparse(
- ("", "", path_bytes, param_bytes, query_bytes, "")
- )
-
- log_result = None
- try:
- while True:
- producer = None
- if body_callback:
- producer = body_callback(method, http_url_bytes, headers_dict)
-
- try:
- request_deferred = self.agent.request(
- method,
- url_bytes,
- Headers(headers_dict),
- producer
+ url = urllib.parse.urlunparse((
+ b"matrix", destination.encode("ascii"),
+ path_bytes, None, query_bytes, b"",
+ )).decode('ascii')
+
+ http_url = urllib.parse.urlunparse((
+ b"", b"",
+ path_bytes, None, query_bytes, b"",
+ )).decode('ascii')
+
+ while True:
+ try:
+ json = request.get_json()
+ if json:
+ data = encode_canonical_json(json)
+ headers_dict["Content-Type"] = ["application/json"]
+ self.sign_request(
+ destination, method, http_url, headers_dict, json
)
- add_timeout_to_deferred(
- request_deferred,
- timeout / 1000. if timeout else 60,
- self.hs.get_reactor(),
- cancelled_to_request_timed_out_error,
+ else:
+ data = None
+ self.sign_request(destination, method, http_url, headers_dict)
+
+ logger.info(
+ "{%s} [%s] Sending request: %s %s",
+ request.txn_id, destination, method, url
+ )
+
+ if data:
+ producer = FileBodyProducer(
+ BytesIO(data),
+ cooperator=self._cooperator
)
+ else:
+ producer = None
+
+ request_deferred = treq.request(
+ method,
+ url,
+ headers=Headers(headers_dict),
+ data=producer,
+ agent=self.agent,
+ reactor=self.hs.get_reactor(),
+ unbuffered=True
+ )
+
+ request_deferred = timeout_deferred(
+ request_deferred,
+ timeout=_sec_timeout,
+ reactor=self.hs.get_reactor(),
+ )
+
+ with Measure(self.clock, "outbound_request"):
response = yield make_deferred_yieldable(
request_deferred,
)
- log_result = "%d %s" % (response.code, response.phrase,)
- break
- except Exception as e:
- if not retry_on_dns_fail and isinstance(e, DNSLookupError):
- logger.warn(
- "DNS Lookup failed to %s with %s",
- destination,
- e
- )
- log_result = "DNS Lookup failed to %s with %s" % (
- destination, e
- )
- raise
-
- logger.warn(
- "{%s} Sending request failed to %s: %s %s: %s",
- txn_id,
+ break
+ except Exception as e:
+ logger.warn(
+ "{%s} [%s] Request failed: %s %s: %s",
+ request.txn_id,
+ destination,
+ method,
+ url,
+ _flatten_response_never_received(e),
+ )
+
+ if not retry_on_dns_fail and isinstance(e, DNSLookupError):
+ raise
+
+ if retries_left and not timeout:
+ if long_retries:
+ delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left)
+ delay = min(delay, 60)
+ delay *= random.uniform(0.8, 1.4)
+ else:
+ delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left)
+ delay = min(delay, 2)
+ delay *= random.uniform(0.8, 1.4)
+
+ logger.debug(
+ "{%s} [%s] Waiting %ss before re-sending...",
+ request.txn_id,
destination,
- method,
- url_bytes,
- _flatten_response_never_received(e),
+ delay,
)
- log_result = _flatten_response_never_received(e)
-
- if retries_left and not timeout:
- if long_retries:
- delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left)
- delay = min(delay, 60)
- delay *= random.uniform(0.8, 1.4)
- else:
- delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left)
- delay = min(delay, 2)
- delay *= random.uniform(0.8, 1.4)
-
- yield self.clock.sleep(delay)
- retries_left -= 1
- else:
- raise
- finally:
- outbound_logger.info(
- "{%s} [%s] Result: %s",
- txn_id,
- destination,
- log_result,
- )
+ yield self.clock.sleep(delay)
+ retries_left -= 1
+ else:
+ raise
+
+ logger.info(
+ "{%s} [%s] Got response headers: %d %s",
+ request.txn_id,
+ destination,
+ response.code,
+ response.phrase.decode('ascii', errors='replace'),
+ )
if 200 <= response.code < 300:
pass
else:
# :'(
# Update transactions table?
- with logcontext.PreserveLoggingContext():
- body = yield readBody(response)
+ d = treq.content(response)
+ d = timeout_deferred(
+ d,
+ timeout=_sec_timeout,
+ reactor=self.hs.get_reactor(),
+ )
+ body = yield make_deferred_yieldable(d)
raise HttpResponseException(
response.code, response.phrase, body
)
@@ -297,11 +439,11 @@ class MatrixFederationHttpClient(object):
auth_headers = []
for key, sig in request["signatures"][self.server_name].items():
- auth_headers.append(bytes(
+ auth_headers.append((
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
self.server_name, key, sig,
- )
- ))
+ )).encode('ascii')
+ )
headers_dict[b"Authorization"] = auth_headers
@@ -346,38 +488,27 @@ class MatrixFederationHttpClient(object):
is not on our federation whitelist
"""
- if not json_data_callback:
- def json_data_callback():
- return data
+ request = MatrixFederationRequest(
+ method="PUT",
+ destination=destination,
+ path=path,
+ query=args,
+ json_callback=json_data_callback,
+ json=data,
+ )
- def body_callback(method, url_bytes, headers_dict):
- json_data = json_data_callback()
- self.sign_request(
- destination, method, url_bytes, headers_dict, json_data
- )
- producer = _JsonProducer(json_data)
- return producer
-
- response = yield self._request(
- destination,
- "PUT",
- path,
- body_callback=body_callback,
- headers_dict={"Content-Type": ["application/json"]},
- query_bytes=encode_query_args(args),
+ response = yield self._send_request(
+ request,
long_retries=long_retries,
timeout=timeout,
ignore_backoff=ignore_backoff,
backoff_on_404=backoff_on_404,
)
- if 200 <= response.code < 300:
- # We need to update the transactions table to say it was sent?
- check_content_type_is_json(response.headers)
-
- with logcontext.PreserveLoggingContext():
- body = yield readBody(response)
- defer.returnValue(json.loads(body))
+ body = yield _handle_json_response(
+ self.hs.get_reactor(), self.default_timeout, request, response,
+ )
+ defer.returnValue(body)
@defer.inlineCallbacks
def post_json(self, destination, path, data={}, long_retries=False,
@@ -411,32 +542,30 @@ class MatrixFederationHttpClient(object):
is not on our federation whitelist
"""
- def body_callback(method, url_bytes, headers_dict):
- self.sign_request(
- destination, method, url_bytes, headers_dict, data
- )
- return _JsonProducer(data)
-
- response = yield self._request(
- destination,
- "POST",
- path,
- query_bytes=encode_query_args(args),
- body_callback=body_callback,
- headers_dict={"Content-Type": ["application/json"]},
+ request = MatrixFederationRequest(
+ method="POST",
+ destination=destination,
+ path=path,
+ query=args,
+ json=data,
+ )
+
+ response = yield self._send_request(
+ request,
long_retries=long_retries,
timeout=timeout,
ignore_backoff=ignore_backoff,
)
- if 200 <= response.code < 300:
- # We need to update the transactions table to say it was sent?
- check_content_type_is_json(response.headers)
-
- with logcontext.PreserveLoggingContext():
- body = yield readBody(response)
+ if timeout:
+ _sec_timeout = timeout / 1000
+ else:
+ _sec_timeout = self.default_timeout
- defer.returnValue(json.loads(body))
+ body = yield _handle_json_response(
+ self.hs.get_reactor(), _sec_timeout, request, response,
+ )
+ defer.returnValue(body)
@defer.inlineCallbacks
def get_json(self, destination, path, args=None, retry_on_dns_fail=True,
@@ -471,29 +600,24 @@ class MatrixFederationHttpClient(object):
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
- def body_callback(method, url_bytes, headers_dict):
- self.sign_request(destination, method, url_bytes, headers_dict)
- return None
+ request = MatrixFederationRequest(
+ method="GET",
+ destination=destination,
+ path=path,
+ query=args,
+ )
- response = yield self._request(
- destination,
- "GET",
- path,
- query_bytes=encode_query_args(args),
- body_callback=body_callback,
+ response = yield self._send_request(
+ request,
retry_on_dns_fail=retry_on_dns_fail,
timeout=timeout,
ignore_backoff=ignore_backoff,
)
- if 200 <= response.code < 300:
- # We need to update the transactions table to say it was sent?
- check_content_type_is_json(response.headers)
-
- with logcontext.PreserveLoggingContext():
- body = yield readBody(response)
-
- defer.returnValue(json.loads(body))
+ body = yield _handle_json_response(
+ self.hs.get_reactor(), self.default_timeout, request, response,
+ )
+ defer.returnValue(body)
@defer.inlineCallbacks
def delete_json(self, destination, path, long_retries=False,
@@ -523,26 +647,24 @@ class MatrixFederationHttpClient(object):
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
"""
+ request = MatrixFederationRequest(
+ method="DELETE",
+ destination=destination,
+ path=path,
+ query=args,
+ )
- response = yield self._request(
- destination,
- "DELETE",
- path,
- query_bytes=encode_query_args(args),
- headers_dict={"Content-Type": ["application/json"]},
+ response = yield self._send_request(
+ request,
long_retries=long_retries,
timeout=timeout,
ignore_backoff=ignore_backoff,
)
- if 200 <= response.code < 300:
- # We need to update the transactions table to say it was sent?
- check_content_type_is_json(response.headers)
-
- with logcontext.PreserveLoggingContext():
- body = yield readBody(response)
-
- defer.returnValue(json.loads(body))
+ body = yield _handle_json_response(
+ self.hs.get_reactor(), self.default_timeout, request, response,
+ )
+ defer.returnValue(body)
@defer.inlineCallbacks
def get_file(self, destination, path, output_stream, args={},
@@ -569,26 +691,15 @@ class MatrixFederationHttpClient(object):
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
"""
+ request = MatrixFederationRequest(
+ method="GET",
+ destination=destination,
+ path=path,
+ query=args,
+ )
- encoded_args = {}
- for k, vs in args.items():
- if isinstance(vs, string_types):
- vs = [vs]
- encoded_args[k] = [v.encode("UTF-8") for v in vs]
-
- query_bytes = urllib.urlencode(encoded_args, True)
- logger.debug("Query bytes: %s Retry DNS: %s", query_bytes, retry_on_dns_fail)
-
- def body_callback(method, url_bytes, headers_dict):
- self.sign_request(destination, method, url_bytes, headers_dict)
- return None
-
- response = yield self._request(
- destination,
- "GET",
- path,
- query_bytes=query_bytes,
- body_callback=body_callback,
+ response = yield self._send_request(
+ request,
retry_on_dns_fail=retry_on_dns_fail,
ignore_backoff=ignore_backoff,
)
@@ -596,14 +707,25 @@ class MatrixFederationHttpClient(object):
headers = dict(response.headers.getAllRawHeaders())
try:
- with logcontext.PreserveLoggingContext():
- length = yield _readBodyToFile(
- response, output_stream, max_size
- )
- except Exception:
- logger.exception("Failed to download body")
+ d = _readBodyToFile(response, output_stream, max_size)
+ d.addTimeout(self.default_timeout, self.hs.get_reactor())
+ length = yield make_deferred_yieldable(d)
+ except Exception as e:
+ logger.warn(
+ "{%s} [%s] Error reading response: %s",
+ request.txn_id,
+ request.destination,
+ e,
+ )
raise
-
+ logger.info(
+ "{%s} [%s] Completed: %d %s [%d bytes]",
+ request.txn_id,
+ request.destination,
+ response.code,
+ response.phrase.decode('ascii', errors='replace'),
+ length,
+ )
defer.returnValue((length, headers))
@@ -639,30 +761,6 @@ def _readBodyToFile(response, stream, max_size):
return d
-class _JsonProducer(object):
- """ Used by the twisted http client to create the HTTP body from json
- """
- def __init__(self, jsn):
- self.reset(jsn)
-
- def reset(self, jsn):
- self.body = encode_canonical_json(jsn)
- self.length = len(self.body)
-
- def startProducing(self, consumer):
- consumer.write(self.body)
- return defer.succeed(None)
-
- def pauseProducing(self):
- pass
-
- def stopProducing(self):
- pass
-
- def resumeProducing(self):
- pass
-
-
def _flatten_response_never_received(e):
if hasattr(e, "reasons"):
reasons = ", ".join(
@@ -693,7 +791,7 @@ def check_content_type_is_json(headers):
"No Content-Type header"
)
- c_type = c_type[0] # only the first header
+ c_type = c_type[0].decode('ascii') # only the first header
val, options = cgi.parse_header(c_type)
if val != "application/json":
raise RuntimeError(
@@ -711,6 +809,6 @@ def encode_query_args(args):
vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs]
- query_bytes = urllib.urlencode(encoded_args, True)
+ query_bytes = urllib.parse.urlencode(encoded_args, True)
- return query_bytes
+ return query_bytes.encode('utf8')
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 72c2654678..fedb4e6b18 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -162,7 +162,7 @@ class RequestMetrics(object):
with _in_flight_requests_lock:
_in_flight_requests.add(self)
- def stop(self, time_sec, request):
+ def stop(self, time_sec, response_code, sent_bytes):
with _in_flight_requests_lock:
_in_flight_requests.discard(self)
@@ -179,35 +179,35 @@ class RequestMetrics(object):
)
return
- response_code = str(request.code)
+ response_code = str(response_code)
- outgoing_responses_counter.labels(request.method, response_code).inc()
+ outgoing_responses_counter.labels(self.method, response_code).inc()
- response_count.labels(request.method, self.name, tag).inc()
+ response_count.labels(self.method, self.name, tag).inc()
- response_timer.labels(request.method, self.name, tag, response_code).observe(
+ response_timer.labels(self.method, self.name, tag, response_code).observe(
time_sec - self.start
)
resource_usage = context.get_resource_usage()
- response_ru_utime.labels(request.method, self.name, tag).inc(
+ response_ru_utime.labels(self.method, self.name, tag).inc(
resource_usage.ru_utime,
)
- response_ru_stime.labels(request.method, self.name, tag).inc(
+ response_ru_stime.labels(self.method, self.name, tag).inc(
resource_usage.ru_stime,
)
- response_db_txn_count.labels(request.method, self.name, tag).inc(
+ response_db_txn_count.labels(self.method, self.name, tag).inc(
resource_usage.db_txn_count
)
- response_db_txn_duration.labels(request.method, self.name, tag).inc(
+ response_db_txn_duration.labels(self.method, self.name, tag).inc(
resource_usage.db_txn_duration_sec
)
- response_db_sched_duration.labels(request.method, self.name, tag).inc(
+ response_db_sched_duration.labels(self.method, self.name, tag).inc(
resource_usage.db_sched_duration_sec
)
- response_size.labels(request.method, self.name, tag).inc(request.sentLength)
+ response_size.labels(self.method, self.name, tag).inc(sent_bytes)
# We always call this at the end to ensure that we update the metrics
# regardless of whether a call to /metrics while the request was in
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 2d5c23e673..b4b25cab19 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -84,10 +84,21 @@ def wrap_json_request_handler(h):
logger.info(
"%s SynapseError: %s - %s", request, code, e.msg
)
- respond_with_json(
- request, code, e.error_dict(), send_cors=True,
- pretty_print=_request_user_agent_is_curl(request),
- )
+
+ # Only respond with an error response if we haven't already started
+ # writing, otherwise lets just kill the connection
+ if request.startedWriting:
+ if request.transport:
+ try:
+ request.transport.abortConnection()
+ except Exception:
+ # abortConnection throws if the connection is already closed
+ pass
+ else:
+ respond_with_json(
+ request, code, e.error_dict(), send_cors=True,
+ pretty_print=_request_user_agent_is_curl(request),
+ )
except Exception:
# failure.Failure() fishes the original Failure out
@@ -100,16 +111,26 @@ def wrap_json_request_handler(h):
request,
f.getTraceback().rstrip(),
)
- respond_with_json(
- request,
- 500,
- {
- "error": "Internal server error",
- "errcode": Codes.UNKNOWN,
- },
- send_cors=True,
- pretty_print=_request_user_agent_is_curl(request),
- )
+ # Only respond with an error response if we haven't already started
+ # writing, otherwise lets just kill the connection
+ if request.startedWriting:
+ if request.transport:
+ try:
+ request.transport.abortConnection()
+ except Exception:
+ # abortConnection throws if the connection is already closed
+ pass
+ else:
+ respond_with_json(
+ request,
+ 500,
+ {
+ "error": "Internal server error",
+ "errcode": Codes.UNKNOWN,
+ },
+ send_cors=True,
+ pretty_print=_request_user_agent_is_curl(request),
+ )
return wrap_async_request_handler(wrapped_request_handler)
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 88ed3714f9..e508c0bd4f 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -75,17 +75,35 @@ class SynapseRequest(Request):
return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % (
self.__class__.__name__,
id(self),
- self.method,
+ self.get_method(),
self.get_redacted_uri(),
- self.clientproto,
+ self.clientproto.decode('ascii', errors='replace'),
self.site.site_tag,
)
def get_request_id(self):
- return "%s-%i" % (self.method, self.request_seq)
+ return "%s-%i" % (self.get_method(), self.request_seq)
def get_redacted_uri(self):
- return redact_uri(self.uri)
+ uri = self.uri
+ if isinstance(uri, bytes):
+ uri = self.uri.decode('ascii')
+ return redact_uri(uri)
+
+ def get_method(self):
+ """Gets the method associated with the request (or placeholder if not
+ method has yet been received).
+
+ Note: This is necessary as the placeholder value in twisted is str
+ rather than bytes, so we need to sanitise `self.method`.
+
+ Returns:
+ str
+ """
+ method = self.method
+ if isinstance(method, bytes):
+ method = self.method.decode('ascii')
+ return method
def get_user_agent(self):
return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
@@ -116,7 +134,7 @@ class SynapseRequest(Request):
# dispatching to the handler, so that the handler
# can update the servlet name in the request
# metrics
- requests_counter.labels(self.method,
+ requests_counter.labels(self.get_method(),
self.request_metrics.name).inc()
@contextlib.contextmanager
@@ -204,14 +222,14 @@ class SynapseRequest(Request):
self.start_time = time.time()
self.request_metrics = RequestMetrics()
self.request_metrics.start(
- self.start_time, name=servlet_name, method=self.method,
+ self.start_time, name=servlet_name, method=self.get_method(),
)
self.site.access_logger.info(
"%s - %s - Received request: %s %s",
self.getClientIP(),
self.site.site_tag,
- self.method,
+ self.get_method(),
self.get_redacted_uri()
)
@@ -277,15 +295,15 @@ class SynapseRequest(Request):
int(usage.db_txn_count),
self.sentLength,
code,
- self.method,
+ self.get_method(),
self.get_redacted_uri(),
- self.clientproto,
+ self.clientproto.decode('ascii', errors='replace'),
user_agent,
usage.evt_db_fetch_count,
)
try:
- self.request_metrics.stop(self.finish_time, self)
+ self.request_metrics.stop(self.finish_time, self.code, self.sentLength)
except Exception as e:
logger.warn("Failed to stop metrics: %r", e)
@@ -305,7 +323,7 @@ class XForwardedForRequest(SynapseRequest):
C{b"-"}.
"""
return self.requestHeaders.getRawHeaders(
- b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip()
+ b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip().decode('ascii')
class SynapseRequestFactory(object):
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 550f8443f7..59900aa5d1 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -18,8 +18,11 @@ import gc
import logging
import os
import platform
+import threading
import time
+import six
+
import attr
from prometheus_client import Counter, Gauge, Histogram
from prometheus_client.core import REGISTRY, GaugeMetricFamily
@@ -68,7 +71,7 @@ class LaterGauge(object):
return
if isinstance(calls, dict):
- for k, v in calls.items():
+ for k, v in six.iteritems(calls):
g.add_metric(k, v)
else:
g.add_metric([], calls)
@@ -87,6 +90,109 @@ class LaterGauge(object):
all_gauges[self.name] = self
+class InFlightGauge(object):
+ """Tracks number of things (e.g. requests, Measure blocks, etc) in flight
+ at any given time.
+
+ Each InFlightGauge will create a metric called `<name>_total` that counts
+ the number of in flight blocks, as well as a metrics for each item in the
+ given `sub_metrics` as `<name>_<sub_metric>` which will get updated by the
+ callbacks.
+
+ Args:
+ name (str)
+ desc (str)
+ labels (list[str])
+ sub_metrics (list[str]): A list of sub metrics that the callbacks
+ will update.
+ """
+
+ def __init__(self, name, desc, labels, sub_metrics):
+ self.name = name
+ self.desc = desc
+ self.labels = labels
+ self.sub_metrics = sub_metrics
+
+ # Create a class which have the sub_metrics values as attributes, which
+ # default to 0 on initialization. Used to pass to registered callbacks.
+ self._metrics_class = attr.make_class(
+ "_MetricsEntry",
+ attrs={x: attr.ib(0) for x in sub_metrics},
+ slots=True,
+ )
+
+ # Counts number of in flight blocks for a given set of label values
+ self._registrations = {}
+
+ # Protects access to _registrations
+ self._lock = threading.Lock()
+
+ self._register_with_collector()
+
+ def register(self, key, callback):
+ """Registers that we've entered a new block with labels `key`.
+
+ `callback` gets called each time the metrics are collected. The same
+ value must also be given to `unregister`.
+
+ `callback` gets called with an object that has an attribute per
+ sub_metric, which should be updated with the necessary values. Note that
+ the metrics object is shared between all callbacks registered with the
+ same key.
+
+ Note that `callback` may be called on a separate thread.
+ """
+ with self._lock:
+ self._registrations.setdefault(key, set()).add(callback)
+
+ def unregister(self, key, callback):
+ """Registers that we've exited a block with labels `key`.
+ """
+
+ with self._lock:
+ self._registrations.setdefault(key, set()).discard(callback)
+
+ def collect(self):
+ """Called by prometheus client when it reads metrics.
+
+ Note: may be called by a separate thread.
+ """
+ in_flight = GaugeMetricFamily(self.name + "_total", self.desc, labels=self.labels)
+
+ metrics_by_key = {}
+
+ # We copy so that we don't mutate the list while iterating
+ with self._lock:
+ keys = list(self._registrations)
+
+ for key in keys:
+ with self._lock:
+ callbacks = set(self._registrations[key])
+
+ in_flight.add_metric(key, len(callbacks))
+
+ metrics = self._metrics_class()
+ metrics_by_key[key] = metrics
+ for callback in callbacks:
+ callback(metrics)
+
+ yield in_flight
+
+ for name in self.sub_metrics:
+ gauge = GaugeMetricFamily("_".join([self.name, name]), "", labels=self.labels)
+ for key, metrics in six.iteritems(metrics_by_key):
+ gauge.add_metric(key, getattr(metrics, name))
+ yield gauge
+
+ def _register_with_collector(self):
+ if self.name in all_gauges.keys():
+ logger.warning("%s already registered, reregistering" % (self.name,))
+ REGISTRY.unregister(all_gauges.pop(self.name))
+
+ REGISTRY.register(self)
+ all_gauges[self.name] = self
+
+
#
# Detailed CPU metrics
#
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 167167be0a..037f1c490e 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
import threading
import six
@@ -23,6 +24,9 @@ from twisted.internet import defer
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+logger = logging.getLogger(__name__)
+
+
_background_process_start_count = Counter(
"synapse_background_process_start_count",
"Number of background processes started",
@@ -97,9 +101,13 @@ class _Collector(object):
labels=["name"],
)
- # We copy the dict so that it doesn't change from underneath us
+ # We copy the dict so that it doesn't change from underneath us.
+ # We also copy the process lists as that can also change
with _bg_metrics_lock:
- _background_processes_copy = dict(_background_processes)
+ _background_processes_copy = {
+ k: list(v)
+ for k, v in six.iteritems(_background_processes)
+ }
for desc, processes in six.iteritems(_background_processes_copy):
background_process_in_flight_count.add_metric(
@@ -191,6 +199,8 @@ def run_as_background_process(desc, func, *args, **kwargs):
try:
yield func(*args, **kwargs)
+ except Exception:
+ logger.exception("Background process '%s' threw an exception", desc)
finally:
proc.update_metrics()
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 82f391481c..340b16ce25 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -24,13 +24,10 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError
from synapse.handlers.presence import format_user_presence_state
from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import StreamToken
-from synapse.util.async_helpers import (
- DeferredTimeoutError,
- ObservableDeferred,
- add_timeout_to_deferred,
-)
-from synapse.util.logcontext import PreserveLoggingContext, run_in_background
+from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
+from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
from synapse.visibility import filter_events_for_client
@@ -252,7 +249,10 @@ class Notifier(object):
def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
"""Notify any user streams that are interested in this room event"""
# poke any interested application service.
- run_in_background(self._notify_app_services, room_stream_id)
+ run_as_background_process(
+ "notify_app_services",
+ self._notify_app_services, room_stream_id,
+ )
if self.federation_sender:
self.federation_sender.notify_new_events(room_stream_id)
@@ -337,7 +337,7 @@ class Notifier(object):
# Now we wait for the _NotifierUserStream to be told there
# is a new token.
listener = user_stream.new_listener(prev_token)
- add_timeout_to_deferred(
+ listener.deferred = timeout_deferred(
listener.deferred,
(end_time - now) / 1000.,
self.hs.get_reactor(),
@@ -354,7 +354,7 @@ class Notifier(object):
# Update the prev_token to the current_token since nothing
# has happened between the old prev_token and the current_token
prev_token = current_token
- except DeferredTimeoutError:
+ except defer.TimeoutError:
break
except defer.CancelledError:
break
@@ -559,15 +559,16 @@ class Notifier(object):
if end_time <= now:
break
- add_timeout_to_deferred(
- listener.deferred.addTimeout,
- (end_time - now) / 1000.,
- self.hs.get_reactor(),
+ listener.deferred = timeout_deferred(
+ listener.deferred,
+ timeout=(end_time - now) / 1000.,
+ reactor=self.hs.get_reactor(),
)
+
try:
with PreserveLoggingContext():
yield listener.deferred
- except DeferredTimeoutError:
+ except defer.TimeoutError:
break
except defer.CancelledError:
break
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 81e18bcf7d..48abd5e4d6 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -15,6 +15,8 @@
# limitations under the License.
import logging
+import six
+
from prometheus_client import Counter
from twisted.internet import defer
@@ -26,6 +28,9 @@ from synapse.util.metrics import Measure
from . import push_rule_evaluator, push_tools
+if six.PY3:
+ long = int
+
logger = logging.getLogger(__name__)
http_push_processed_counter = Counter("synapse_http_httppusher_http_pushes_processed", "")
@@ -96,7 +101,7 @@ class HttpPusher(object):
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
- self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
+ self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering or 0)
yield self._process()
@defer.inlineCallbacks
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index bfa6df7b68..1a5a10d974 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -17,10 +17,11 @@ import email.mime.multipart
import email.utils
import logging
import time
-import urllib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
+from six.moves import urllib
+
import bleach
import jinja2
@@ -440,7 +441,7 @@ class Mailer(object):
def make_room_link(self, room_id):
if self.hs.config.email_riot_base_url:
- base_url = self.hs.config.email_riot_base_url
+ base_url = "%s/#/room" % (self.hs.config.email_riot_base_url)
elif self.app_name == "Vector":
# need /beta for Universal Links to work on iOS
base_url = "https://vector.im/beta/#/room"
@@ -474,7 +475,7 @@ class Mailer(object):
# XXX: make r0 once API is stable
return "%s_matrix/client/unstable/pushers/remove?%s" % (
self.hs.config.public_baseurl,
- urllib.urlencode(params),
+ urllib.parse.urlencode(params),
)
@@ -561,7 +562,7 @@ def _create_mxc_to_http_filter(config):
return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
config.public_baseurl,
serverAndMediaId,
- urllib.urlencode(params),
+ urllib.parse.urlencode(params),
fragment or "",
)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 9c55e79ef5..d4d983b00a 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -33,31 +33,38 @@ logger = logging.getLogger(__name__)
# [2] https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-dependencies
REQUIREMENTS = {
"jsonschema>=2.5.1": ["jsonschema>=2.5.1"],
- "frozendict>=0.4": ["frozendict"],
+ "frozendict>=1": ["frozendict"],
"unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
"canonicaljson>=1.1.3": ["canonicaljson>=1.1.3"],
"signedjson>=1.0.0": ["signedjson>=1.0.0"],
"pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"],
- "service_identity>=1.0.0": ["service_identity>=1.0.0"],
+ "service_identity>=16.0.0": ["service_identity>=16.0.0"],
"Twisted>=17.1.0": ["twisted>=17.1.0"],
-
- # We use crypto.get_elliptic_curve which is only supported in >=0.15
- "pyopenssl>=0.15": ["OpenSSL>=0.15"],
-
- "pyyaml": ["yaml"],
- "pyasn1": ["pyasn1"],
- "daemonize": ["daemonize"],
- "bcrypt": ["bcrypt>=3.1.0"],
- "pillow": ["PIL"],
- "pydenticon": ["pydenticon"],
- "sortedcontainers": ["sortedcontainers"],
- "pysaml2>=3.0.0": ["saml2>=3.0.0"],
- "pymacaroons-pynacl": ["pymacaroons"],
+ "treq>=15.1": ["treq>=15.1"],
+
+ # Twisted has required pyopenssl 16.0 since about Twisted 16.6.
+ "pyopenssl>=16.0.0": ["OpenSSL>=16.0.0"],
+
+ "pyyaml>=3.11": ["yaml"],
+ "pyasn1>=0.1.9": ["pyasn1"],
+ "pyasn1-modules>=0.0.7": ["pyasn1_modules"],
+ "daemonize>=2.3.1": ["daemonize"],
+ "bcrypt>=3.1.0": ["bcrypt>=3.1.0"],
+ "pillow>=3.1.2": ["PIL"],
+ "pydenticon>=0.2": ["pydenticon"],
+ "sortedcontainers>=1.4.4": ["sortedcontainers"],
+ "pysaml2>=3.0.0": ["saml2"],
+ "pymacaroons-pynacl>=0.9.3": ["pymacaroons"],
"msgpack-python>=0.3.0": ["msgpack"],
"phonenumbers>=8.2.0": ["phonenumbers"],
- "six": ["six"],
- "prometheus_client": ["prometheus_client"],
- "attrs": ["attr"],
+ "six>=1.10": ["six"],
+
+ # prometheus_client 0.4.0 changed the format of counter metrics
+ # (cf https://github.com/matrix-org/synapse/issues/4001)
+ "prometheus_client>=0.0.18,<0.4.0": ["prometheus_client"],
+
+ # we use attr.s(slots), which arrived in 16.0.0
+ "attrs>=16.0.0": ["attr>=16.0.0"],
"netaddr>=0.7.18": ["netaddr"],
}
@@ -78,6 +85,9 @@ CONDITIONAL_REQUIREMENTS = {
"affinity": {
"affinity": ["affinity"],
},
+ "postgres": {
+ "psycopg2>=2.6": ["psycopg2"]
+ }
}
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 8206a988f7..21b8c468fa 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import six
+
from synapse.storage import DataStore
from synapse.storage.end_to_end_keys import EndToEndKeyStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -21,6 +23,13 @@ from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
+def __func__(inp):
+ if six.PY3:
+ return inp
+ else:
+ return inp.__func__
+
+
class SlavedDeviceStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedDeviceStore, self).__init__(db_conn, hs)
@@ -38,14 +47,14 @@ class SlavedDeviceStore(BaseSlavedStore):
"DeviceListFederationStreamChangeCache", device_list_max,
)
- get_device_stream_token = DataStore.get_device_stream_token.__func__
- get_user_whose_devices_changed = DataStore.get_user_whose_devices_changed.__func__
- get_devices_by_remote = DataStore.get_devices_by_remote.__func__
- _get_devices_by_remote_txn = DataStore._get_devices_by_remote_txn.__func__
- _get_e2e_device_keys_txn = DataStore._get_e2e_device_keys_txn.__func__
- mark_as_sent_devices_by_remote = DataStore.mark_as_sent_devices_by_remote.__func__
+ get_device_stream_token = __func__(DataStore.get_device_stream_token)
+ get_user_whose_devices_changed = __func__(DataStore.get_user_whose_devices_changed)
+ get_devices_by_remote = __func__(DataStore.get_devices_by_remote)
+ _get_devices_by_remote_txn = __func__(DataStore._get_devices_by_remote_txn)
+ _get_e2e_device_keys_txn = __func__(DataStore._get_e2e_device_keys_txn)
+ mark_as_sent_devices_by_remote = __func__(DataStore.mark_as_sent_devices_by_remote)
_mark_as_sent_devices_by_remote_txn = (
- DataStore._mark_as_sent_devices_by_remote_txn.__func__
+ __func__(DataStore._mark_as_sent_devices_by_remote_txn)
)
count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"]
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 74e892c104..5dc7b3fffc 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -590,9 +590,9 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
pending_commands = LaterGauge(
"synapse_replication_tcp_protocol_pending_commands",
"",
- ["name", "conn_id"],
+ ["name"],
lambda: {
- (p.name, p.conn_id): len(p.pending_commands) for p in connected_connections
+ (p.name,): len(p.pending_commands) for p in connected_connections
},
)
@@ -607,9 +607,9 @@ def transport_buffer_size(protocol):
transport_send_buffer = LaterGauge(
"synapse_replication_tcp_protocol_transport_send_buffer",
"",
- ["name", "conn_id"],
+ ["name"],
lambda: {
- (p.name, p.conn_id): transport_buffer_size(p) for p in connected_connections
+ (p.name,): transport_buffer_size(p) for p in connected_connections
},
)
@@ -632,9 +632,9 @@ def transport_kernel_read_buffer_size(protocol, read=True):
tcp_transport_kernel_send_buffer = LaterGauge(
"synapse_replication_tcp_protocol_transport_kernel_send_buffer",
"",
- ["name", "conn_id"],
+ ["name"],
lambda: {
- (p.name, p.conn_id): transport_kernel_read_buffer_size(p, False)
+ (p.name,): transport_kernel_read_buffer_size(p, False)
for p in connected_connections
},
)
@@ -643,9 +643,9 @@ tcp_transport_kernel_send_buffer = LaterGauge(
tcp_transport_kernel_read_buffer = LaterGauge(
"synapse_replication_tcp_protocol_transport_kernel_read_buffer",
"",
- ["name", "conn_id"],
+ ["name"],
lambda: {
- (p.name, p.conn_id): transport_kernel_read_buffer_size(p, True)
+ (p.name,): transport_kernel_read_buffer_size(p, True)
for p in connected_connections
},
)
@@ -654,9 +654,9 @@ tcp_transport_kernel_read_buffer = LaterGauge(
tcp_inbound_commands = LaterGauge(
"synapse_replication_tcp_protocol_inbound_commands",
"",
- ["command", "name", "conn_id"],
+ ["command", "name"],
lambda: {
- (k[0], p.name, p.conn_id): count
+ (k[0], p.name,): count
for p in connected_connections
for k, count in iteritems(p.inbound_commands_counter)
},
@@ -665,9 +665,9 @@ tcp_inbound_commands = LaterGauge(
tcp_outbound_commands = LaterGauge(
"synapse_replication_tcp_protocol_outbound_commands",
"",
- ["command", "name", "conn_id"],
+ ["command", "name"],
lambda: {
- (k[0], p.name, p.conn_id): count
+ (k[0], p.name,): count
for p in connected_connections
for k, count in iteritems(p.outbound_commands_counter)
},
diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py
index 55fe701c5c..c1e626be3f 100644
--- a/synapse/replication/tcp/streams.py
+++ b/synapse/replication/tcp/streams.py
@@ -196,7 +196,7 @@ class Stream(object):
)
if len(rows) >= MAX_EVENTS_BEHIND:
- raise Exception("stream %s has fallen behined" % (self.NAME))
+ raise Exception("stream %s has fallen behind" % (self.NAME))
else:
rows = yield self.update_function(
from_token, current_token,
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index ad536ab570..41534b8c2a 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -101,7 +101,7 @@ class UserRegisterServlet(ClientV1RestServlet):
nonce = self.hs.get_secrets().token_hex(64)
self.nonces[nonce] = int(self.reactor.seconds())
- return (200, {"nonce": nonce.encode('ascii')})
+ return (200, {"nonce": nonce})
@defer.inlineCallbacks
def on_POST(self, request):
@@ -164,7 +164,7 @@ class UserRegisterServlet(ClientV1RestServlet):
key=self.hs.config.registration_shared_secret.encode(),
digestmod=hashlib.sha1,
)
- want_mac.update(nonce)
+ want_mac.update(nonce.encode('utf8'))
want_mac.update(b"\x00")
want_mac.update(username)
want_mac.update(b"\x00")
@@ -173,7 +173,10 @@ class UserRegisterServlet(ClientV1RestServlet):
want_mac.update(b"admin" if admin else b"notadmin")
want_mac = want_mac.hexdigest()
- if not hmac.compare_digest(want_mac, got_mac.encode('ascii')):
+ if not hmac.compare_digest(
+ want_mac.encode('ascii'),
+ got_mac.encode('ascii')
+ ):
raise SynapseError(403, "HMAC incorrect")
# Reuse the parts of RegisterRestServlet to reduce code duplication
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index 0f3a2e8b51..cd9b3bdbd1 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -45,20 +45,20 @@ class EventStreamRestServlet(ClientV1RestServlet):
is_guest = requester.is_guest
room_id = None
if is_guest:
- if "room_id" not in request.args:
+ if b"room_id" not in request.args:
raise SynapseError(400, "Guest users must specify room_id param")
- if "room_id" in request.args:
- room_id = request.args["room_id"][0]
+ if b"room_id" in request.args:
+ room_id = request.args[b"room_id"][0].decode('ascii')
pagin_config = PaginationConfig.from_request(request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
- if "timeout" in request.args:
+ if b"timeout" in request.args:
try:
- timeout = int(request.args["timeout"][0])
+ timeout = int(request.args[b"timeout"][0])
except ValueError:
raise SynapseError(400, "timeout must be in milliseconds.")
- as_client_event = "raw" not in request.args
+ as_client_event = b"raw" not in request.args
chunk = yield self.event_stream_handler.get_stream(
requester.user.to_string(),
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index fd5f85b53e..3ead75cb77 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -32,7 +32,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
- as_client_event = "raw" not in request.args
+ as_client_event = b"raw" not in request.args
pagination_config = PaginationConfig.from_request(request)
include_archived = parse_boolean(request, "archived", default=False)
content = yield self.initial_sync_handler.snapshot_all_rooms(
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index cb85fa1436..0010699d31 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,10 +14,9 @@
# limitations under the License.
import logging
-import urllib
import xml.etree.ElementTree as ET
-from six.moves.urllib import parse as urlparse
+from six.moves import urllib
from canonicaljson import json
from saml2 import BINDING_HTTP_POST, config
@@ -134,7 +133,7 @@ class LoginRestServlet(ClientV1RestServlet):
LoginRestServlet.SAML2_TYPE):
relay_state = ""
if "relay_state" in login_submission:
- relay_state = "&RelayState=" + urllib.quote(
+ relay_state = "&RelayState=" + urllib.parse.quote(
login_submission["relay_state"])
result = {
"uri": "%s%s" % (self.idp_redirect_url, relay_state)
@@ -366,7 +365,7 @@ class SAML2RestServlet(ClientV1RestServlet):
(user_id, token) = yield handler.register_saml2(username)
# Forward to the RelayState callback along with ava
if 'RelayState' in request.args:
- request.redirect(urllib.unquote(
+ request.redirect(urllib.parse.unquote(
request.args['RelayState'][0]) +
'?status=authenticated&access_token=' +
token + '&user_id=' + user_id + '&ava=' +
@@ -377,7 +376,7 @@ class SAML2RestServlet(ClientV1RestServlet):
"user_id": user_id, "token": token,
"ava": saml2_auth.ava}))
elif 'RelayState' in request.args:
- request.redirect(urllib.unquote(
+ request.redirect(urllib.parse.unquote(
request.args['RelayState'][0]) +
'?status=not_authenticated')
finish_request(request)
@@ -390,21 +389,22 @@ class CasRedirectServlet(ClientV1RestServlet):
def __init__(self, hs):
super(CasRedirectServlet, self).__init__(hs)
- self.cas_server_url = hs.config.cas_server_url
- self.cas_service_url = hs.config.cas_service_url
+ self.cas_server_url = hs.config.cas_server_url.encode('ascii')
+ self.cas_service_url = hs.config.cas_service_url.encode('ascii')
def on_GET(self, request):
args = request.args
- if "redirectUrl" not in args:
+ if b"redirectUrl" not in args:
return (400, "Redirect URL not specified for CAS auth")
- client_redirect_url_param = urllib.urlencode({
- "redirectUrl": args["redirectUrl"][0]
- })
- hs_redirect_url = self.cas_service_url + "/_matrix/client/api/v1/login/cas/ticket"
- service_param = urllib.urlencode({
- "service": "%s?%s" % (hs_redirect_url, client_redirect_url_param)
- })
- request.redirect("%s/login?%s" % (self.cas_server_url, service_param))
+ client_redirect_url_param = urllib.parse.urlencode({
+ b"redirectUrl": args[b"redirectUrl"][0]
+ }).encode('ascii')
+ hs_redirect_url = (self.cas_service_url +
+ b"/_matrix/client/api/v1/login/cas/ticket")
+ service_param = urllib.parse.urlencode({
+ b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)
+ }).encode('ascii')
+ request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param))
finish_request(request)
@@ -422,11 +422,11 @@ class CasTicketServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
- client_redirect_url = request.args["redirectUrl"][0]
+ client_redirect_url = request.args[b"redirectUrl"][0]
http_client = self.hs.get_simple_http_client()
uri = self.cas_server_url + "/proxyValidate"
args = {
- "ticket": request.args["ticket"],
+ "ticket": request.args[b"ticket"][0].decode('ascii'),
"service": self.cas_service_url
}
try:
@@ -471,11 +471,11 @@ class CasTicketServlet(ClientV1RestServlet):
finish_request(request)
def add_login_token_to_redirect_url(self, url, token):
- url_parts = list(urlparse.urlparse(url))
- query = dict(urlparse.parse_qsl(url_parts[4]))
+ url_parts = list(urllib.parse.urlparse(url))
+ query = dict(urllib.parse.parse_qsl(url_parts[4]))
query.update({"loginToken": token})
- url_parts[4] = urllib.urlencode(query)
- return urlparse.urlunparse(url_parts)
+ url_parts[4] = urllib.parse.urlencode(query).encode('ascii')
+ return urllib.parse.urlunparse(url_parts)
def parse_cas_response(self, cas_response_body):
user = None
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 6e95d9bec2..9382b1f124 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -46,7 +46,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
try:
priority_class = _priority_class_from_spec(spec)
except InvalidRuleException as e:
- raise SynapseError(400, e.message)
+ raise SynapseError(400, str(e))
requester = yield self.auth.get_user_by_req(request)
@@ -73,7 +73,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
content,
)
except InvalidRuleException as e:
- raise SynapseError(400, e.message)
+ raise SynapseError(400, str(e))
before = parse_string(request, "before")
if before:
@@ -95,9 +95,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
)
self.notify_user(user_id)
except InconsistentRuleException as e:
- raise SynapseError(400, e.message)
+ raise SynapseError(400, str(e))
except RuleNotFoundException as e:
- raise SynapseError(400, e.message)
+ raise SynapseError(400, str(e))
defer.returnValue((200, {}))
@@ -142,10 +142,10 @@ class PushRuleRestServlet(ClientV1RestServlet):
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
- if path[0] == '':
+ if path[0] == b'':
defer.returnValue((200, rules))
- elif path[0] == 'global':
- path = path[1:]
+ elif path[0] == b'global':
+ path = [x.decode('ascii') for x in path[1:]]
result = _filter_ruleset_with_path(rules['global'], path)
defer.returnValue((200, result))
else:
@@ -192,10 +192,10 @@ class PushRuleRestServlet(ClientV1RestServlet):
def _rule_spec_from_path(path):
if len(path) < 2:
raise UnrecognizedRequestError()
- if path[0] != 'pushrules':
+ if path[0] != b'pushrules':
raise UnrecognizedRequestError()
- scope = path[1]
+ scope = path[1].decode('ascii')
path = path[2:]
if scope != 'global':
raise UnrecognizedRequestError()
@@ -203,13 +203,13 @@ def _rule_spec_from_path(path):
if len(path) == 0:
raise UnrecognizedRequestError()
- template = path[0]
+ template = path[0].decode('ascii')
path = path[1:]
if len(path) == 0 or len(path[0]) == 0:
raise UnrecognizedRequestError()
- rule_id = path[0]
+ rule_id = path[0].decode('ascii')
spec = {
'scope': scope,
@@ -220,7 +220,7 @@ def _rule_spec_from_path(path):
path = path[1:]
if len(path) > 0 and len(path[0]) > 0:
- spec['attr'] = path[0]
+ spec['attr'] = path[0].decode('ascii')
return spec
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 182a68b1e2..b84f0260f2 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -59,7 +59,7 @@ class PushersRestServlet(ClientV1RestServlet):
]
for p in pushers:
- for k, v in p.items():
+ for k, v in list(p.items()):
if k not in allowed_keys:
del p[k]
@@ -126,7 +126,7 @@ class PushersSetRestServlet(ClientV1RestServlet):
profile_tag=content.get('profile_tag', ""),
)
except PusherConfigException as pce:
- raise SynapseError(400, "Config Error: " + pce.message,
+ raise SynapseError(400, "Config Error: " + str(pce),
errcode=Codes.MISSING_PARAM)
self.notifier.on_new_replication_data()
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 976d98387d..663934efd0 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -207,7 +207,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
"sender": requester.user.to_string(),
}
- if 'ts' in request.args and requester.app_service:
+ if b'ts' in request.args and requester.app_service:
event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
event = yield self.event_creation_hander.create_and_send_nonmember_event(
@@ -255,7 +255,9 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
if RoomID.is_valid(room_identifier):
room_id = room_identifier
try:
- remote_room_hosts = request.args["server_name"]
+ remote_room_hosts = [
+ x.decode('ascii') for x in request.args[b"server_name"]
+ ]
except Exception:
remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier):
@@ -461,10 +463,10 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
pagination_config = PaginationConfig.from_request(
request, default_limit=10,
)
- as_client_event = "raw" not in request.args
- filter_bytes = parse_string(request, "filter")
+ as_client_event = b"raw" not in request.args
+ filter_bytes = parse_string(request, b"filter", encoding=None)
if filter_bytes:
- filter_json = urlparse.unquote(filter_bytes).decode("UTF-8")
+ filter_json = urlparse.unquote(filter_bytes.decode("UTF-8"))
event_filter = Filter(json.loads(filter_json))
else:
event_filter = None
@@ -560,7 +562,7 @@ class RoomEventContextServlet(ClientV1RestServlet):
# picking the API shape for symmetry with /messages
filter_bytes = parse_string(request, "filter")
if filter_bytes:
- filter_json = urlparse.unquote(filter_bytes).decode("UTF-8")
+ filter_json = urlparse.unquote(filter_bytes)
event_filter = Filter(json.loads(filter_json))
else:
event_filter = None
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 62f4c3d93e..53da905eea 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -42,7 +42,11 @@ class VoipRestServlet(ClientV1RestServlet):
expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
username = "%d:%s" % (expiry, requester.user.to_string())
- mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1)
+ mac = hmac.new(
+ turnSecret.encode(),
+ msg=username.encode(),
+ digestmod=hashlib.sha1
+ )
# We need to use standard padded base64 encoding here
# encode_base64 because we need to add the standard padding to get the
# same result as the TURN server.
diff --git a/synapse/rest/client/v1_only/register.py b/synapse/rest/client/v1_only/register.py
index 5e99cffbcb..dadb376b02 100644
--- a/synapse/rest/client/v1_only/register.py
+++ b/synapse/rest/client/v1_only/register.py
@@ -23,6 +23,7 @@ from twisted.internet import defer
import synapse.util.stringutils as stringutils
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError
+from synapse.config.server import is_threepid_reserved
from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.rest.client.v1.base import ClientV1RestServlet
from synapse.types import create_requester
@@ -281,12 +282,20 @@ class RegisterRestServlet(ClientV1RestServlet):
register_json["user"].encode("utf-8")
if "user" in register_json else None
)
+ threepid = None
+ if session.get(LoginType.EMAIL_IDENTITY):
+ threepid = session["threepidCreds"]
handler = self.handlers.registration_handler
(user_id, token) = yield handler.register(
localpart=desired_user_id,
- password=password
+ password=password,
+ threepid=threepid,
)
+ # Necessary due to auth checks prior to the threepid being
+ # written to the db
+ if is_threepid_reserved(self.hs.config, threepid):
+ yield self.store.upsert_monthly_active_user(user_id)
if session[LoginType.EMAIL_IDENTITY]:
logger.debug("Binding emails %s to %s" % (
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 372648cafd..37b32dd37b 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -53,7 +53,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, "email", body['email']):
raise SynapseError(
- 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ 403,
+ "Your email domain is not authorized on this server",
+ Codes.THREEPID_DENIED,
)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
@@ -89,7 +91,9 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
- 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ 403,
+ "Account phone numbers are not authorized on this server",
+ Codes.THREEPID_DENIED,
)
existingUid = yield self.datastore.get_user_id_by_threepid(
@@ -241,7 +245,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, "email", body['email']):
raise SynapseError(
- 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ 403,
+ "Your email domain is not authorized on this server",
+ Codes.THREEPID_DENIED,
)
existingUid = yield self.datastore.get_user_id_by_threepid(
@@ -276,7 +282,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
- 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ 403,
+ "Account phone numbers are not authorized on this server",
+ Codes.THREEPID_DENIED,
)
existingUid = yield self.datastore.get_user_id_by_threepid(
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 2f64155d13..192f52e462 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -26,6 +26,7 @@ import synapse
import synapse.types
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError
+from synapse.config.server import is_threepid_reserved
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -74,7 +75,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, "email", body['email']):
raise SynapseError(
- 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ 403,
+ "Your email domain is not authorized to register on this server",
+ Codes.THREEPID_DENIED,
)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
@@ -114,7 +117,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
- 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ 403,
+ "Phone numbers are not authorized to register on this server",
+ Codes.THREEPID_DENIED,
)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
@@ -372,7 +377,9 @@ class RegisterRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, medium, address):
raise SynapseError(
- 403, "Third party identifier is not allowed",
+ 403,
+ "Third party identifiers (email/phone numbers)" +
+ " are not authorized on this server",
Codes.THREEPID_DENIED,
)
@@ -395,12 +402,21 @@ class RegisterRestServlet(RestServlet):
if desired_username is not None:
desired_username = desired_username.lower()
+ threepid = None
+ if auth_result:
+ threepid = auth_result.get(LoginType.EMAIL_IDENTITY)
+
(registered_user_id, _) = yield self.registration_handler.register(
localpart=desired_username,
password=new_password,
guest_access_token=guest_access_token,
generate_token=False,
+ threepid=threepid,
)
+ # Necessary due to auth checks prior to the threepid being
+ # written to the db
+ if is_threepid_reserved(self.hs.config, threepid):
+ yield self.store.upsert_monthly_active_user(registered_user_id)
# remember that we've now registered that user account, and with
# what user ID (since the user may not have specified)
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 1275baa1ba..0251146722 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
from synapse.events.utils import (
format_event_for_client_v2_without_room_id,
+ format_event_raw,
serialize_event,
)
from synapse.handlers.presence import format_user_presence_state
@@ -88,7 +89,7 @@ class SyncRestServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
- if "from" in request.args:
+ if b"from" in request.args:
# /events used to use 'from', but /sync uses 'since'.
# Lets be helpful and whine if we see a 'from'.
raise SynapseError(
@@ -175,17 +176,28 @@ class SyncRestServlet(RestServlet):
@staticmethod
def encode_response(time_now, sync_result, access_token_id, filter):
+ if filter.event_format == 'client':
+ event_formatter = format_event_for_client_v2_without_room_id
+ elif filter.event_format == 'federation':
+ event_formatter = format_event_raw
+ else:
+ raise Exception("Unknown event format %s" % (filter.event_format, ))
+
joined = SyncRestServlet.encode_joined(
- sync_result.joined, time_now, access_token_id, filter.event_fields
+ sync_result.joined, time_now, access_token_id,
+ filter.event_fields,
+ event_formatter,
)
invited = SyncRestServlet.encode_invited(
sync_result.invited, time_now, access_token_id,
+ event_formatter,
)
archived = SyncRestServlet.encode_archived(
sync_result.archived, time_now, access_token_id,
filter.event_fields,
+ event_formatter,
)
return {
@@ -228,7 +240,7 @@ class SyncRestServlet(RestServlet):
}
@staticmethod
- def encode_joined(rooms, time_now, token_id, event_fields):
+ def encode_joined(rooms, time_now, token_id, event_fields, event_formatter):
"""
Encode the joined rooms in a sync result
@@ -240,7 +252,9 @@ class SyncRestServlet(RestServlet):
token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs
event_fields(list<str>): List of event fields to include. If empty,
- all fields will be returned.
+ all fields will be returned.
+ event_formatter (func[dict]): function to convert from federation format
+ to client format
Returns:
dict[str, dict[str, object]]: the joined rooms list, in our
response format
@@ -248,13 +262,14 @@ class SyncRestServlet(RestServlet):
joined = {}
for room in rooms:
joined[room.room_id] = SyncRestServlet.encode_room(
- room, time_now, token_id, only_fields=event_fields
+ room, time_now, token_id, joined=True, only_fields=event_fields,
+ event_formatter=event_formatter,
)
return joined
@staticmethod
- def encode_invited(rooms, time_now, token_id):
+ def encode_invited(rooms, time_now, token_id, event_formatter):
"""
Encode the invited rooms in a sync result
@@ -264,7 +279,9 @@ class SyncRestServlet(RestServlet):
time_now(int): current time - used as a baseline for age
calculations
token_id(int): ID of the user's auth token - used for namespacing
- of transaction IDs
+ of transaction IDs
+ event_formatter (func[dict]): function to convert from federation format
+ to client format
Returns:
dict[str, dict[str, object]]: the invited rooms list, in our
@@ -274,7 +291,7 @@ class SyncRestServlet(RestServlet):
for room in rooms:
invite = serialize_event(
room.invite, time_now, token_id=token_id,
- event_format=format_event_for_client_v2_without_room_id,
+ event_format=event_formatter,
is_invite=True,
)
unsigned = dict(invite.get("unsigned", {}))
@@ -288,7 +305,7 @@ class SyncRestServlet(RestServlet):
return invited
@staticmethod
- def encode_archived(rooms, time_now, token_id, event_fields):
+ def encode_archived(rooms, time_now, token_id, event_fields, event_formatter):
"""
Encode the archived rooms in a sync result
@@ -300,7 +317,9 @@ class SyncRestServlet(RestServlet):
token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs
event_fields(list<str>): List of event fields to include. If empty,
- all fields will be returned.
+ all fields will be returned.
+ event_formatter (func[dict]): function to convert from federation format
+ to client format
Returns:
dict[str, dict[str, object]]: The invited rooms list, in our
response format
@@ -308,13 +327,18 @@ class SyncRestServlet(RestServlet):
joined = {}
for room in rooms:
joined[room.room_id] = SyncRestServlet.encode_room(
- room, time_now, token_id, joined=False, only_fields=event_fields
+ room, time_now, token_id, joined=False,
+ only_fields=event_fields,
+ event_formatter=event_formatter,
)
return joined
@staticmethod
- def encode_room(room, time_now, token_id, joined=True, only_fields=None):
+ def encode_room(
+ room, time_now, token_id, joined,
+ only_fields, event_formatter,
+ ):
"""
Args:
room (JoinedSyncResult|ArchivedSyncResult): sync result for a
@@ -326,14 +350,15 @@ class SyncRestServlet(RestServlet):
joined (bool): True if the user is joined to this room - will mean
we handle ephemeral events
only_fields(list<str>): Optional. The list of event fields to include.
+ event_formatter (func[dict]): function to convert from federation format
+ to client format
Returns:
dict[str, object]: the room, encoded in our response format
"""
def serialize(event):
- # TODO(mjark): Respect formatting requirements in the filter.
return serialize_event(
event, time_now, token_id=token_id,
- event_format=format_event_for_client_v2_without_room_id,
+ event_format=event_formatter,
only_event_fields=only_fields,
)
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index d9d379182e..b9b5d07677 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -79,7 +79,7 @@ class ThirdPartyUserServlet(RestServlet):
yield self.auth.get_user_by_req(request, allow_guest=True)
fields = request.args
- fields.pop("access_token", None)
+ fields.pop(b"access_token", None)
results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.USER, protocol, fields
@@ -102,7 +102,7 @@ class ThirdPartyLocationServlet(RestServlet):
yield self.auth.get_user_by_req(request, allow_guest=True)
fields = request.args
- fields.pop("access_token", None)
+ fields.pop(b"access_token", None)
results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.LOCATION, protocol, fields
diff --git a/synapse/rest/key/v1/server_key_resource.py b/synapse/rest/key/v1/server_key_resource.py
index b9ee6e1c13..38eb2ee23f 100644
--- a/synapse/rest/key/v1/server_key_resource.py
+++ b/synapse/rest/key/v1/server_key_resource.py
@@ -88,5 +88,5 @@ class LocalKey(Resource):
)
def getChild(self, name, request):
- if name == '':
+ if name == b'':
return self
diff --git a/synapse/rest/key/v2/__init__.py b/synapse/rest/key/v2/__init__.py
index 3491fd2118..cb5abcf826 100644
--- a/synapse/rest/key/v2/__init__.py
+++ b/synapse/rest/key/v2/__init__.py
@@ -22,5 +22,5 @@ from .remote_key_resource import RemoteKey
class KeyApiV2Resource(Resource):
def __init__(self, hs):
Resource.__init__(self)
- self.putChild("server", LocalKey(hs))
- self.putChild("query", RemoteKey(hs))
+ self.putChild(b"server", LocalKey(hs))
+ self.putChild(b"query", RemoteKey(hs))
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 7d67e4b064..eb8782aa6e 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -103,7 +103,7 @@ class RemoteKey(Resource):
def async_render_GET(self, request):
if len(request.postpath) == 1:
server, = request.postpath
- query = {server: {}}
+ query = {server.decode('ascii'): {}}
elif len(request.postpath) == 2:
server, key_id = request.postpath
minimum_valid_until_ts = parse_integer(
@@ -112,11 +112,12 @@ class RemoteKey(Resource):
arguments = {}
if minimum_valid_until_ts is not None:
arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
- query = {server: {key_id: arguments}}
+ query = {server.decode('ascii'): {key_id.decode('ascii'): arguments}}
else:
raise SynapseError(
404, "Not found %r" % request.postpath, Codes.NOT_FOUND
)
+
yield self.query_keys(request, query, query_remote_on_cache_miss=True)
def render_POST(self, request):
@@ -135,6 +136,7 @@ class RemoteKey(Resource):
@defer.inlineCallbacks
def query_keys(self, request, query, query_remote_on_cache_miss=False):
logger.info("Handling query for keys %r", query)
+
store_queries = []
for server_name, key_ids in query.items():
if (
diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py
index f255f2883f..5a426ff2f6 100644
--- a/synapse/rest/media/v0/content_repository.py
+++ b/synapse/rest/media/v0/content_repository.py
@@ -56,7 +56,7 @@ class ContentRepoResource(resource.Resource):
# servers.
# TODO: A little crude here, we could do this better.
- filename = request.path.split('/')[-1]
+ filename = request.path.decode('ascii').split('/')[-1]
# be paranoid
filename = re.sub("[^0-9A-z.-_]", "", filename)
@@ -78,7 +78,7 @@ class ContentRepoResource(resource.Resource):
# select private. don't bother setting Expires as all our matrix
# clients are smart enough to be happy with Cache-Control (right?)
request.setHeader(
- "Cache-Control", "public,max-age=86400,s-maxage=86400"
+ b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
)
d = FileSender().beginFileTransfer(f, request)
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 65f4bd2910..76e479afa3 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -15,9 +15,8 @@
import logging
import os
-import urllib
-from six.moves.urllib import parse as urlparse
+from six.moves import urllib
from twisted.internet import defer
from twisted.protocols.basic import FileSender
@@ -35,10 +34,15 @@ def parse_media_id(request):
# This allows users to append e.g. /test.png to the URL. Useful for
# clients that parse the URL to see content type.
server_name, media_id = request.postpath[:2]
+
+ if isinstance(server_name, bytes):
+ server_name = server_name.decode('utf-8')
+ media_id = media_id.decode('utf8')
+
file_name = None
if len(request.postpath) > 2:
try:
- file_name = urlparse.unquote(request.postpath[-1]).decode("utf-8")
+ file_name = urllib.parse.unquote(request.postpath[-1].decode("utf-8"))
except UnicodeDecodeError:
pass
return server_name, media_id, file_name
@@ -93,22 +97,18 @@ def add_file_headers(request, media_type, file_size, upload_name):
file_size (int): Size in bytes of the media, if known.
upload_name (str): The name of the requested file, if any.
"""
+ def _quote(x):
+ return urllib.parse.quote(x.encode("utf-8"))
+
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
if upload_name:
if is_ascii(upload_name):
- request.setHeader(
- b"Content-Disposition",
- b"inline; filename=%s" % (
- urllib.quote(upload_name.encode("utf-8")),
- ),
- )
+ disposition = ("inline; filename=%s" % (_quote(upload_name),)).encode("ascii")
else:
- request.setHeader(
- b"Content-Disposition",
- b"inline; filename*=utf-8''%s" % (
- urllib.quote(upload_name.encode("utf-8")),
- ),
- )
+ disposition = (
+ "inline; filename*=utf-8''%s" % (_quote(upload_name),)).encode("ascii")
+
+ request.setHeader(b"Content-Disposition", disposition)
# cache for at least a day.
# XXX: we might want to turn this off for data we don't want to
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index fbfa85f74f..f911b120b1 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -47,12 +47,13 @@ class DownloadResource(Resource):
def _async_render_GET(self, request):
set_cors_headers(request)
request.setHeader(
- "Content-Security-Policy",
- "default-src 'none';"
- " script-src 'none';"
- " plugin-types application/pdf;"
- " style-src 'unsafe-inline';"
- " object-src 'self';"
+ b"Content-Security-Policy",
+ b"default-src 'none';"
+ b" script-src 'none';"
+ b" plugin-types application/pdf;"
+ b" style-src 'unsafe-inline';"
+ b" media-src 'self';"
+ b" object-src 'self';"
)
server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name:
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 241c972070..a828ff4438 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -20,7 +20,7 @@ import logging
import os
import shutil
-from six import iteritems
+from six import PY3, iteritems
from six.moves.urllib import parse as urlparse
import twisted.internet.error
@@ -397,13 +397,13 @@ class MediaRepository(object):
yield finish()
- media_type = headers["Content-Type"][0]
+ media_type = headers[b"Content-Type"][0].decode('ascii')
time_now_ms = self.clock.time_msec()
- content_disposition = headers.get("Content-Disposition", None)
+ content_disposition = headers.get(b"Content-Disposition", None)
if content_disposition:
- _, params = cgi.parse_header(content_disposition[0],)
+ _, params = cgi.parse_header(content_disposition[0].decode('ascii'),)
upload_name = None
# First check if there is a valid UTF-8 filename
@@ -419,9 +419,13 @@ class MediaRepository(object):
upload_name = upload_name_ascii
if upload_name:
- upload_name = urlparse.unquote(upload_name)
+ if PY3:
+ upload_name = urlparse.unquote(upload_name)
+ else:
+ upload_name = urlparse.unquote(upload_name.encode('ascii'))
try:
- upload_name = upload_name.decode("utf-8")
+ if isinstance(upload_name, bytes):
+ upload_name = upload_name.decode("utf-8")
except UnicodeDecodeError:
upload_name = None
else:
@@ -755,14 +759,15 @@ class MediaRepositoryResource(Resource):
Resource.__init__(self)
media_repo = hs.get_media_repository()
- self.putChild("upload", UploadResource(hs, media_repo))
- self.putChild("download", DownloadResource(hs, media_repo))
- self.putChild("thumbnail", ThumbnailResource(
+
+ self.putChild(b"upload", UploadResource(hs, media_repo))
+ self.putChild(b"download", DownloadResource(hs, media_repo))
+ self.putChild(b"thumbnail", ThumbnailResource(
hs, media_repo, media_repo.media_storage,
))
- self.putChild("identicon", IdenticonResource())
+ self.putChild(b"identicon", IdenticonResource())
if hs.config.url_preview_enabled:
- self.putChild("preview_url", PreviewUrlResource(
+ self.putChild(b"preview_url", PreviewUrlResource(
hs, media_repo, media_repo.media_storage,
))
- self.putChild("config", MediaConfigResource(hs))
+ self.putChild(b"config", MediaConfigResource(hs))
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 778ef97337..af01040a38 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -79,7 +79,6 @@ class PreviewUrlResource(Resource):
# don't spider URLs more often than once an hour
expiry_ms=60 * 60 * 1000,
)
- self._cache.start()
self._cleaner_loop = self.clock.looping_call(
self._start_expire_url_cache_data, 10 * 1000,
@@ -261,7 +260,7 @@ class PreviewUrlResource(Resource):
logger.debug("Calculated OG for %s as %s" % (url, og))
- jsonog = json.dumps(og)
+ jsonog = json.dumps(og).encode('utf8')
# store OG in history-aware DB cache
yield self.store.store_url_cache(
@@ -301,20 +300,20 @@ class PreviewUrlResource(Resource):
logger.warn("Error downloading %s: %r", url, e)
raise SynapseError(
500, "Failed to download content: %s" % (
- traceback.format_exception_only(sys.exc_type, e),
+ traceback.format_exception_only(sys.exc_info()[0], e),
),
Codes.UNKNOWN,
)
yield finish()
try:
- if "Content-Type" in headers:
- media_type = headers["Content-Type"][0]
+ if b"Content-Type" in headers:
+ media_type = headers[b"Content-Type"][0].decode('ascii')
else:
media_type = "application/octet-stream"
time_now_ms = self.clock.time_msec()
- content_disposition = headers.get("Content-Disposition", None)
+ content_disposition = headers.get(b"Content-Disposition", None)
if content_disposition:
_, params = cgi.parse_header(content_disposition[0],)
download_name = None
diff --git a/synapse/server.py b/synapse/server.py
index a795643add..3e9d3d8256 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -19,6 +19,7 @@
# partial one for unit test mocking.
# Imports required for the default HomeServer() implementation
+import abc
import logging
from twisted.enterprise import adbapi
@@ -82,7 +83,6 @@ from synapse.server_notices.server_notices_manager import ServerNoticesManager
from synapse.server_notices.server_notices_sender import ServerNoticesSender
from synapse.server_notices.worker_server_notices_sender import WorkerServerNoticesSender
from synapse.state import StateHandler, StateResolutionHandler
-from synapse.storage import DataStore
from synapse.streams.events import EventSources
from synapse.util import Clock
from synapse.util.distributor import Distributor
@@ -112,6 +112,8 @@ class HomeServer(object):
config (synapse.config.homeserver.HomeserverConfig):
"""
+ __metaclass__ = abc.ABCMeta
+
DEPENDENCIES = [
'http_client',
'db_pool',
@@ -174,6 +176,11 @@ class HomeServer(object):
'room_context_handler',
]
+ # This is overridden in derived application classes
+ # (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be
+ # instantiated during setup() for future return by get_datastore()
+ DATASTORE_CLASS = abc.abstractproperty()
+
def __init__(self, hostname, reactor=None, **kwargs):
"""
Args:
@@ -190,13 +197,16 @@ class HomeServer(object):
self.distributor = Distributor()
self.ratelimiter = Ratelimiter()
+ self.datastore = None
+
# Other kwargs are explicit dependencies
for depname in kwargs:
setattr(self, depname, kwargs[depname])
def setup(self):
logger.info("Setting up.")
- self.datastore = DataStore(self.get_db_conn(), self)
+ with self.get_db_conn() as conn:
+ self.datastore = self.DATASTORE_CLASS(conn, self)
logger.info("Finished setting up.")
def get_reactor(self):
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index 893b9001cd..af15cba0ee 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -118,7 +118,7 @@ class ResourceLimitsServerNotices(object):
'body': event_content,
'msgtype': ServerNoticeMsgType,
'server_notice_type': ServerNoticeLimitReached,
- 'admin_uri': self._config.admin_uri,
+ 'admin_contact': self._config.admin_contact,
'limit_type': event_limit_type
}
event = yield self._server_notices_manager.send_notice(
@@ -146,11 +146,10 @@ class ResourceLimitsServerNotices(object):
user_id(str): the user in question
room_id(str): the server notices room for that user
"""
- tags = yield self._store.get_tags_for_user(user_id)
- server_notices_tags = tags.get(room_id)
+ tags = yield self._store.get_tags_for_room(user_id, room_id)
need_to_set_tag = True
- if server_notices_tags:
- if server_notices_tags.get(SERVER_NOTICE_ROOM_TAG):
+ if tags:
+ if SERVER_NOTICE_ROOM_TAG in tags:
# tag already present, nothing to do here
need_to_set_tag = False
if need_to_set_tag:
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index b34970e4d1..b22495c1f9 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -95,10 +95,6 @@ class StateHandler(object):
self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
- def start_caching(self):
- # TODO: remove this shim
- self._state_resolution_handler.start_caching()
-
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key="",
latest_event_ids=None):
@@ -385,6 +381,7 @@ class StateHandler(object):
ev_ids, get_prev_content=False, check_redacted=False,
)
+ @defer.inlineCallbacks
def resolve_events(self, room_version, state_sets, event):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
@@ -401,15 +398,17 @@ class StateHandler(object):
}
with Measure(self.clock, "state._resolve_events"):
- new_state = resolve_events_with_state_map(
- room_version, state_set_ids, state_map,
+ new_state = yield resolve_events_with_factory(
+ room_version, state_set_ids,
+ event_map=state_map,
+ state_map_factory=self._state_map_factory
)
new_state = {
key: state_map[ev_id] for key, ev_id in iteritems(new_state)
}
- return new_state
+ defer.returnValue(new_state)
class StateResolutionHandler(object):
@@ -425,9 +424,6 @@ class StateResolutionHandler(object):
self._state_cache = None
self.resolve_linearizer = Linearizer(name="state_resolve_lock")
- def start_caching(self):
- logger.debug("start_caching")
-
self._state_cache = ExpiringCache(
cache_name="state_cache",
clock=self.clock,
@@ -437,8 +433,6 @@ class StateResolutionHandler(object):
reset_expiry_on_get=True,
)
- self._state_cache.start()
-
@defer.inlineCallbacks
@log_function
def resolve_state_groups(
@@ -589,31 +583,6 @@ def _make_state_cache_entry(
)
-def resolve_events_with_state_map(room_version, state_sets, state_map):
- """
- Args:
- room_version(str): Version of the room
- state_sets(list): List of dicts of (type, state_key) -> event_id,
- which are the different state groups to resolve.
- state_map(dict): a dict from event_id to event, for all events in
- state_sets.
-
- Returns
- dict[(str, str), str]:
- a map from (type, state_key) to event_id.
- """
- if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,):
- return v1.resolve_events_with_state_map(
- state_sets, state_map,
- )
- else:
- # This should only happen if we added a version but forgot to add it to
- # the list above.
- raise Exception(
- "No state resolution algorithm defined for version %r" % (room_version,)
- )
-
-
def resolve_events_with_factory(room_version, state_sets, event_map, state_map_factory):
"""
Args:
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 3a1f7054a1..7a7157b352 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -30,34 +30,6 @@ logger = logging.getLogger(__name__)
POWER_KEY = (EventTypes.PowerLevels, "")
-def resolve_events_with_state_map(state_sets, state_map):
- """
- Args:
- state_sets(list): List of dicts of (type, state_key) -> event_id,
- which are the different state groups to resolve.
- state_map(dict): a dict from event_id to event, for all events in
- state_sets.
-
- Returns
- dict[(str, str), str]:
- a map from (type, state_key) to event_id.
- """
- if len(state_sets) == 1:
- return state_sets[0]
-
- unconflicted_state, conflicted_state = _seperate(
- state_sets,
- )
-
- auth_events = _create_auth_events_from_maps(
- unconflicted_state, conflicted_state, state_map
- )
-
- return _resolve_with_state(
- unconflicted_state, conflicted_state, auth_events, state_map
- )
-
-
@defer.inlineCallbacks
def resolve_events_with_factory(state_sets, event_map, state_map_factory):
"""
@@ -93,10 +65,15 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
for event_ids in itervalues(conflicted_state)
for event_id in event_ids
)
+ needed_event_count = len(needed_events)
if event_map is not None:
needed_events -= set(iterkeys(event_map))
- logger.info("Asking for %d conflicted events", len(needed_events))
+ logger.info(
+ "Asking for %d/%d conflicted events",
+ len(needed_events),
+ needed_event_count,
+ )
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
# the state events which are in conflict (and those in event_map)
@@ -113,11 +90,16 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
)
new_needed_events = set(itervalues(auth_events))
+ new_needed_event_count = len(new_needed_events)
new_needed_events -= needed_events
if event_map is not None:
new_needed_events -= set(iterkeys(event_map))
- logger.info("Asking for %d auth events", len(new_needed_events))
+ logger.info(
+ "Asking for %d/%d auth events",
+ len(new_needed_events),
+ new_needed_event_count,
+ )
state_map_new = yield state_map_factory(new_needed_events)
state_map.update(state_map_new)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 08dffd774f..be61147b9b 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -17,9 +17,10 @@ import sys
import threading
import time
-from six import iteritems, iterkeys, itervalues
+from six import PY2, iteritems, iterkeys, itervalues
from six.moves import intern, range
+from canonicaljson import json
from prometheus_client import Histogram
from twisted.internet import defer
@@ -1216,3 +1217,32 @@ class _RollbackButIsFineException(Exception):
something went wrong.
"""
pass
+
+
+def db_to_json(db_content):
+ """
+ Take some data from a database row and return a JSON-decoded object.
+
+ Args:
+ db_content (memoryview|buffer|bytes|bytearray|unicode)
+ """
+ # psycopg2 on Python 3 returns memoryview objects, which we need to
+ # cast to bytes to decode
+ if isinstance(db_content, memoryview):
+ db_content = db_content.tobytes()
+
+ # psycopg2 on Python 2 returns buffer objects, which we need to cast to
+ # bytes to decode
+ if PY2 and isinstance(db_content, buffer):
+ db_content = bytes(db_content)
+
+ # Decode it to a Unicode string before feeding it to json.loads, so we
+ # consistenty get a Unicode-containing object out.
+ if isinstance(db_content, (bytes, bytearray)):
+ db_content = db_content.decode('utf8')
+
+ try:
+ return json.loads(db_content)
+ except Exception:
+ logging.warning("Tried to decode '%r' as JSON and failed", db_content)
+ raise
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 8fc678fa67..9ad17b7c25 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -119,21 +119,25 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
for entry in iteritems(to_update):
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
- self._simple_upsert_txn(
- txn,
- table="user_ips",
- keyvalues={
- "user_id": user_id,
- "access_token": access_token,
- "ip": ip,
- "user_agent": user_agent,
- "device_id": device_id,
- },
- values={
- "last_seen": last_seen,
- },
- lock=False,
- )
+ try:
+ self._simple_upsert_txn(
+ txn,
+ table="user_ips",
+ keyvalues={
+ "user_id": user_id,
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "device_id": device_id,
+ },
+ values={
+ "last_seen": last_seen,
+ },
+ lock=False,
+ )
+ except Exception as e:
+ # Failed to upsert, log and continue
+ logger.error("Failed to insert client IP %r: %r", entry, e)
@defer.inlineCallbacks
def get_last_client_ip_by_device(self, user_id, device_id):
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 73646da025..e06b0bc56d 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -169,7 +169,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
local_by_user_then_device = {}
for user_id, messages_by_device in messages_by_user_then_device.items():
messages_json_for_user = {}
- devices = messages_by_device.keys()
+ devices = list(messages_by_device.keys())
if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids.
sql = (
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index c0943ecf91..d10ff9e4b9 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -24,7 +24,7 @@ from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
-from ._base import Cache, SQLBaseStore
+from ._base import Cache, SQLBaseStore, db_to_json
logger = logging.getLogger(__name__)
@@ -411,7 +411,7 @@ class DeviceStore(SQLBaseStore):
if device is not None:
key_json = device.get("key_json", None)
if key_json:
- result["keys"] = json.loads(key_json)
+ result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
@@ -466,7 +466,7 @@ class DeviceStore(SQLBaseStore):
retcol="content",
desc="_get_cached_user_device",
)
- defer.returnValue(json.loads(content))
+ defer.returnValue(db_to_json(content))
@cachedInlineCallbacks()
def _get_cached_devices_for_user(self, user_id):
@@ -479,7 +479,7 @@ class DeviceStore(SQLBaseStore):
desc="_get_cached_devices_for_user",
)
defer.returnValue({
- device["device_id"]: json.loads(device["content"])
+ device["device_id"]: db_to_json(device["content"])
for device in devices
})
@@ -511,7 +511,7 @@ class DeviceStore(SQLBaseStore):
key_json = device.get("key_json", None)
if key_json:
- result["keys"] = json.loads(key_json)
+ result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 808194236a..cfb687cb53 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -75,7 +75,6 @@ class DirectoryWorkerStore(SQLBaseStore):
},
retcol="creator",
desc="get_room_alias_creator",
- allow_none=True
)
@cached(max_entries=5000)
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 523b4360c3..1f1721e820 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -14,13 +14,13 @@
# limitations under the License.
from six import iteritems
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.util.caches.descriptors import cached
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, db_to_json
class EndToEndKeyStore(SQLBaseStore):
@@ -90,7 +90,7 @@ class EndToEndKeyStore(SQLBaseStore):
for user_id, device_keys in iteritems(results):
for device_id, device_info in iteritems(device_keys):
- device_info["keys"] = json.loads(device_info.pop("key_json"))
+ device_info["keys"] = db_to_json(device_info.pop("key_json"))
defer.returnValue(results)
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 8a0386c1a4..42225f8a2a 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -41,13 +41,18 @@ class PostgresEngine(object):
db_conn.set_isolation_level(
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
+
+ # Set the bytea output to escape, vs the default of hex
+ cursor = db_conn.cursor()
+ cursor.execute("SET bytea_output TO escape")
+
# Asynchronous commit, don't wait for the server to call fsync before
# ending the transaction.
# https://www.postgresql.org/docs/current/static/wal-async-commit.html
if not self.synchronous_commit:
- cursor = db_conn.cursor()
cursor.execute("SET synchronous_commit TO OFF")
- cursor.close()
+
+ cursor.close()
def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index f39c8c8461..03cedf3a75 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -19,7 +19,7 @@ import logging
from collections import OrderedDict, deque, namedtuple
from functools import wraps
-from six import iteritems
+from six import iteritems, text_type
from six.moves import range
from canonicaljson import json
@@ -38,6 +38,7 @@ from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import RoomStreamToken, get_domain_from_id
+from synapse.util import batch_iter
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.frozenutils import frozendict_json_encoder
@@ -386,12 +387,10 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
)
for room_id, ev_ctx_rm in iteritems(events_by_room):
- # Work out new extremities by recursively adding and removing
- # the new events.
latest_event_ids = yield self.get_latest_event_ids_in_room(
room_id
)
- new_latest_event_ids = yield self._calculate_new_extremeties(
+ new_latest_event_ids = yield self._calculate_new_extremities(
room_id, ev_ctx_rm, latest_event_ids
)
@@ -400,6 +399,12 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
# No change in extremities, so no change in state
continue
+ # there should always be at least one forward extremity.
+ # (except during the initial persistence of the send_join
+ # results, in which case there will be no existing
+ # extremities, so we'll `continue` above and skip this bit.)
+ assert new_latest_event_ids, "No forward extremities left!"
+
new_forward_extremeties[room_id] = new_latest_event_ids
len_1 = (
@@ -517,44 +522,79 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
)
@defer.inlineCallbacks
- def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids):
- """Calculates the new forward extremeties for a room given events to
+ def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids):
+ """Calculates the new forward extremities for a room given events to
persist.
Assumes that we are only persisting events for one room at a time.
"""
- new_latest_event_ids = set(latest_event_ids)
- # First, add all the new events to the list
- new_latest_event_ids.update(
- event.event_id for event, ctx in event_contexts
+
+ # we're only interested in new events which aren't outliers and which aren't
+ # being rejected.
+ new_events = [
+ event for event, ctx in event_contexts
if not event.internal_metadata.is_outlier() and not ctx.rejected
+ ]
+
+ # start with the existing forward extremities
+ result = set(latest_event_ids)
+
+ # add all the new events to the list
+ result.update(
+ event.event_id for event in new_events
)
- # Now remove all events that are referenced by the to-be-added events
- new_latest_event_ids.difference_update(
+
+ # Now remove all events which are prev_events of any of the new events
+ result.difference_update(
e_id
- for event, ctx in event_contexts
+ for event in new_events
for e_id, _ in event.prev_events
- if not event.internal_metadata.is_outlier() and not ctx.rejected
)
- # And finally remove any events that are referenced by previously added
- # events.
- rows = yield self._simple_select_many_batch(
- table="event_edges",
- column="prev_event_id",
- iterable=list(new_latest_event_ids),
- retcols=["prev_event_id"],
- keyvalues={
- "is_state": False,
- },
- desc="_calculate_new_extremeties",
- )
+ # Finally, remove any events which are prev_events of any existing events.
+ existing_prevs = yield self._get_events_which_are_prevs(result)
+ result.difference_update(existing_prevs)
- new_latest_event_ids.difference_update(
- row["prev_event_id"] for row in rows
- )
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def _get_events_which_are_prevs(self, event_ids):
+ """Filter the supplied list of event_ids to get those which are prev_events of
+ existing (non-outlier/rejected) events.
+
+ Args:
+ event_ids (Iterable[str]): event ids to filter
+
+ Returns:
+ Deferred[List[str]]: filtered event ids
+ """
+ results = []
+
+ def _get_events(txn, batch):
+ sql = """
+ SELECT prev_event_id
+ FROM event_edges
+ INNER JOIN events USING (event_id)
+ LEFT JOIN rejections USING (event_id)
+ WHERE
+ prev_event_id IN (%s)
+ AND NOT events.outlier
+ AND rejections.event_id IS NULL
+ """ % (
+ ",".join("?" for _ in batch),
+ )
+
+ txn.execute(sql, batch)
+ results.extend(r[0] for r in txn)
+
+ for chunk in batch_iter(event_ids, 100):
+ yield self.runInteraction(
+ "_get_events_which_are_prevs",
+ _get_events,
+ chunk,
+ )
- defer.returnValue(new_latest_event_ids)
+ defer.returnValue(results)
@defer.inlineCallbacks
def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids,
@@ -586,10 +626,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
the new current state is only returned if we've already calculated
it.
"""
-
- if not new_latest_event_ids:
- return
-
# map from state_group to ((type, key) -> event_id) state map
state_groups_map = {}
@@ -930,6 +966,10 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
)
self._invalidate_cache_and_stream(
+ txn, self.get_room_summary, (room_id,)
+ )
+
+ self._invalidate_cache_and_stream(
txn, self.get_current_state_ids, (room_id,)
)
@@ -1220,7 +1260,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
"sender": event.sender,
"contains_url": (
"url" in event.content
- and isinstance(event.content["url"], basestring)
+ and isinstance(event.content["url"], text_type)
),
}
for event, _ in events_and_contexts
@@ -1529,7 +1569,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
contains_url = "url" in content
if contains_url:
- contains_url &= isinstance(content["url"], basestring)
+ contains_url &= isinstance(content["url"], text_type)
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
@@ -1886,20 +1926,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
")"
)
- # create an index on should_delete because later we'll be looking for
- # the should_delete / shouldn't_delete subsets
- txn.execute(
- "CREATE INDEX events_to_purge_should_delete"
- " ON events_to_purge(should_delete)",
- )
-
- # We do joins against events_to_purge for e.g. calculating state
- # groups to purge, etc., so lets make an index.
- txn.execute(
- "CREATE INDEX events_to_purge_id"
- " ON events_to_purge(event_id)",
- )
-
# First ensure that we're not about to delete all the forward extremeties
txn.execute(
"SELECT e.event_id, e.depth FROM events as e "
@@ -1910,9 +1936,9 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
(room_id,)
)
rows = txn.fetchall()
- max_depth = max(row[0] for row in rows)
+ max_depth = max(row[1] for row in rows)
- if max_depth <= token.topological:
+ if max_depth < token.topological:
# We need to ensure we don't delete all the events from the database
# otherwise we wouldn't be able to send any events (due to not
# having any backwards extremeties)
@@ -1926,19 +1952,45 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
should_delete_params = ()
if not delete_local_events:
should_delete_expr += " AND event_id NOT LIKE ?"
- should_delete_params += ("%:" + self.hs.hostname, )
+
+ # We include the parameter twice since we use the expression twice
+ should_delete_params += (
+ "%:" + self.hs.hostname,
+ "%:" + self.hs.hostname,
+ )
should_delete_params += (room_id, token.topological)
+ # Note that we insert events that are outliers and aren't going to be
+ # deleted, as nothing will happen to them.
txn.execute(
"INSERT INTO events_to_purge"
" SELECT event_id, %s"
" FROM events AS e LEFT JOIN state_events USING (event_id)"
- " WHERE e.room_id = ? AND topological_ordering < ?" % (
+ " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?"
+ % (
+ should_delete_expr,
should_delete_expr,
),
should_delete_params,
)
+
+ # We create the indices *after* insertion as that's a lot faster.
+
+ # create an index on should_delete because later we'll be looking for
+ # the should_delete / shouldn't_delete subsets
+ txn.execute(
+ "CREATE INDEX events_to_purge_should_delete"
+ " ON events_to_purge(should_delete)",
+ )
+
+ # We do joins against events_to_purge for e.g. calculating state
+ # groups to purge, etc., so lets make an index.
+ txn.execute(
+ "CREATE INDEX events_to_purge_id"
+ " ON events_to_purge(event_id)",
+ )
+
txn.execute(
"SELECT event_id, should_delete FROM events_to_purge"
)
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 59822178ff..a8326f5296 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import itertools
import logging
from collections import namedtuple
@@ -265,7 +266,7 @@ class EventsWorkerStore(SQLBaseStore):
"""
with Measure(self._clock, "_fetch_event_list"):
try:
- event_id_lists = zip(*event_list)[0]
+ event_id_lists = list(zip(*event_list))[0]
event_ids = [
item for sublist in event_id_lists for item in sublist
]
@@ -299,14 +300,14 @@ class EventsWorkerStore(SQLBaseStore):
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
- def fire(evs):
+ def fire(evs, exc):
for _, d in evs:
if not d.called:
with PreserveLoggingContext():
- d.errback(e)
+ d.errback(exc)
with PreserveLoggingContext():
- self.hs.get_reactor().callFromThread(fire, event_list)
+ self.hs.get_reactor().callFromThread(fire, event_list, e)
@defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index 2d5896c5b4..6ddcc909bf 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
from synapse.util.caches.descriptors import cachedInlineCallbacks
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, db_to_json
class FilteringStore(SQLBaseStore):
@@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore):
desc="get_user_filter",
)
- defer.returnValue(json.loads(bytes(def_json).decode("utf-8")))
+ defer.returnValue(db_to_json(def_json))
def add_user_filter(self, user_localpart, user_filter):
def_json = encode_canonical_json(user_filter)
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index f547977600..a1331c1a61 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -134,6 +134,7 @@ class KeyStore(SQLBaseStore):
"""
key_id = "%s:%s" % (verify_key.alg, verify_key.version)
+ # XXX fix this to not need a lock (#3819)
def _txn(txn):
self._simple_upsert_txn(
txn,
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py
index d178f5c5ba..0fe8c8e24c 100644
--- a/synapse/storage/monthly_active_users.py
+++ b/synapse/storage/monthly_active_users.py
@@ -36,7 +36,6 @@ class MonthlyActiveUsersStore(SQLBaseStore):
@defer.inlineCallbacks
def initialise_reserved_users(self, threepids):
- # TODO Why can't I do this in init?
store = self.hs.get_datastore()
reserved_user_list = []
@@ -148,6 +147,23 @@ class MonthlyActiveUsersStore(SQLBaseStore):
return self.runInteraction("count_users", _count_users)
@defer.inlineCallbacks
+ def get_registered_reserved_users_count(self):
+ """Of the reserved threepids defined in config, how many are associated
+ with registered users?
+
+ Returns:
+ Defered[int]: Number of real reserved users
+ """
+ count = 0
+ for tp in self.hs.config.mau_limits_reserved_threepids:
+ user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+ tp["medium"], tp["address"]
+ )
+ if user_id:
+ count = count + 1
+ defer.returnValue(count)
+
+ @defer.inlineCallbacks
def upsert_monthly_active_user(self, user_id):
"""
Updates or inserts monthly active user member
@@ -156,6 +172,10 @@ class MonthlyActiveUsersStore(SQLBaseStore):
Deferred[bool]: True if a new entry was created, False if an
existing one was updated.
"""
+ # Am consciously deciding to lock the table on the basis that is ought
+ # never be a big table and alternative approaches (batching multiple
+ # upserts into a single txn) introduced a lot of extra complexity.
+ # See https://github.com/matrix-org/synapse/issues/3854 for more
is_insert = yield self._simple_upsert(
desc="upsert_monthly_active_user",
table="monthly_active_users",
@@ -165,7 +185,6 @@ class MonthlyActiveUsersStore(SQLBaseStore):
values={
"timestamp": int(self._clock.time_msec()),
},
- lock=False,
)
if is_insert:
self.user_last_seen_monthly_active.invalidate((user_id,))
@@ -200,10 +219,14 @@ class MonthlyActiveUsersStore(SQLBaseStore):
Args:
user_id(str): the user_id to query
"""
+
if self.hs.config.limit_usage_by_mau:
+ # Trial users and guests should not be included as part of MAU group
+ is_guest = yield self.is_guest(user_id)
+ if is_guest:
+ return
is_trial = yield self.is_trial_user(user_id)
if is_trial:
- # we don't track trial users in the MAU table.
return
last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id)
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 8443bd4c1b..c7987bfcdd 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -15,7 +15,8 @@
# limitations under the License.
import logging
-import types
+
+import six
from canonicaljson import encode_canonical_json, json
@@ -27,6 +28,11 @@ from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
+if six.PY2:
+ db_binary_type = buffer
+else:
+ db_binary_type = memoryview
+
class PusherWorkerStore(SQLBaseStore):
def _decode_pushers_rows(self, rows):
@@ -34,18 +40,18 @@ class PusherWorkerStore(SQLBaseStore):
dataJson = r['data']
r['data'] = None
try:
- if isinstance(dataJson, types.BufferType):
+ if isinstance(dataJson, db_binary_type):
dataJson = str(dataJson).decode("UTF8")
r['data'] = json.loads(dataJson)
except Exception as e:
logger.warn(
"Invalid JSON in data for pusher %d: %s, %s",
- r['id'], dataJson, e.message,
+ r['id'], dataJson, e.args[0],
)
pass
- if isinstance(r['pushkey'], types.BufferType):
+ if isinstance(r['pushkey'], db_binary_type):
r['pushkey'] = str(r['pushkey']).decode("UTF8")
return rows
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 9b4e6d6aa8..0707f9a86a 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -51,6 +51,12 @@ ProfileInfo = namedtuple(
"ProfileInfo", ("avatar_url", "display_name")
)
+# "members" points to a truncated list of (user_id, event_id) tuples for users of
+# a given membership type, suitable for use in calculating heroes for a room.
+# "count" points to the total numberr of users of a given membership type.
+MemberSummary = namedtuple(
+ "MemberSummary", ("members", "count")
+)
_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
@@ -82,6 +88,65 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return [to_ascii(r[0]) for r in txn]
return self.runInteraction("get_users_in_room", f)
+ @cached(max_entries=100000)
+ def get_room_summary(self, room_id):
+ """ Get the details of a room roughly suitable for use by the room
+ summary extension to /sync. Useful when lazy loading room members.
+ Args:
+ room_id (str): The room ID to query
+ Returns:
+ Deferred[dict[str, MemberSummary]:
+ dict of membership states, pointing to a MemberSummary named tuple.
+ """
+
+ def _get_room_summary_txn(txn):
+ # first get counts.
+ # We do this all in one transaction to keep the cache small.
+ # FIXME: get rid of this when we have room_stats
+ sql = """
+ SELECT count(*), m.membership FROM room_memberships as m
+ INNER JOIN current_state_events as c
+ ON m.event_id = c.event_id
+ AND m.room_id = c.room_id
+ AND m.user_id = c.state_key
+ WHERE c.type = 'm.room.member' AND c.room_id = ?
+ GROUP BY m.membership
+ """
+
+ txn.execute(sql, (room_id,))
+ res = {}
+ for count, membership in txn:
+ summary = res.setdefault(to_ascii(membership), MemberSummary([], count))
+
+ # we order by membership and then fairly arbitrarily by event_id so
+ # heroes are consistent
+ sql = """
+ SELECT m.user_id, m.membership, m.event_id
+ FROM room_memberships as m
+ INNER JOIN current_state_events as c
+ ON m.event_id = c.event_id
+ AND m.room_id = c.room_id
+ AND m.user_id = c.state_key
+ WHERE c.type = 'm.room.member' AND c.room_id = ?
+ ORDER BY
+ CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
+ m.event_id ASC
+ LIMIT ?
+ """
+
+ # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
+ txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
+ for user_id, membership, event_id in txn:
+ summary = res[to_ascii(membership)]
+ # we will always have a summary for this membership type at this
+ # point given the summary currently contains the counts.
+ members = summary.members
+ members.append((to_ascii(user_id), to_ascii(event_id)))
+
+ return res
+
+ return self.runInteraction("get_room_summary", _get_room_summary_txn)
+
@cached()
def get_invited_rooms_for_user(self, user_id):
""" Get all the rooms the user is invited to
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 4b971efdba..3f4cbd61c4 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -255,7 +255,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
)
@defer.inlineCallbacks
- def get_state_groups_ids(self, room_id, event_ids):
+ def get_state_groups_ids(self, _room_id, event_ids):
+ """Get the event IDs of all the state for the state groups for the given events
+
+ Args:
+ _room_id (str): id of the room for these events
+ event_ids (iterable[str]): ids of the events
+
+ Returns:
+ Deferred[dict[int, dict[tuple[str, str], str]]]:
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
+ """
if not event_ids:
defer.returnValue({})
@@ -270,7 +280,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
@defer.inlineCallbacks
def get_state_ids_for_group(self, state_group):
- """Get the state IDs for the given state group
+ """Get the event IDs of all the state in the given state group
Args:
state_group (int)
@@ -286,7 +296,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
def get_state_groups(self, room_id, event_ids):
""" Get the state groups for the given list of event_ids
- The return value is a dict mapping group names to lists of events.
+ Returns:
+ Deferred[dict[int, list[EventBase]]]:
+ dict of state_group_id -> list of state events.
"""
if not event_ids:
defer.returnValue({})
@@ -324,7 +336,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
member events (if True), or to exclude member events (if False)
Returns:
- dictionary state_group -> (dict of (type, state_key) -> event id)
+ Returns:
+ Deferred[dict[int, dict[tuple[str, str], str]]]:
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
results = {}
@@ -732,8 +746,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
If None, `types` filtering is applied to all events.
Returns:
- Deferred[dict[int, dict[(type, state_key), EventBase]]]
- a dictionary mapping from state group to state dictionary.
+ Deferred[dict[int, dict[tuple[str, str], str]]]:
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
if types is not None:
non_member_types = [t for t in types if t[0] != EventTypes.Member]
@@ -788,8 +802,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
If None, `types` filtering is applied to all events.
Returns:
- Deferred[dict[int, dict[(type, state_key), EventBase]]]
- a dictionary mapping from state group to state dictionary.
+ Deferred[dict[int, dict[tuple[str, str], str]]]:
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
if types:
types = frozenset(types)
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 428e7fa36e..a3032cdce9 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -18,14 +18,14 @@ from collections import namedtuple
import six
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.expiringcache import ExpiringCache
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, db_to_json
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
@@ -50,6 +50,8 @@ _UpdateTransactionRow = namedtuple(
)
)
+SENTINEL = object()
+
class TransactionStore(SQLBaseStore):
"""A collection of queries for handling PDUs.
@@ -60,6 +62,12 @@ class TransactionStore(SQLBaseStore):
self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
+ self._destination_retry_cache = ExpiringCache(
+ cache_name="get_destination_retry_timings",
+ clock=self._clock,
+ expiry_ms=5 * 60 * 1000,
+ )
+
def get_received_txn_response(self, transaction_id, origin):
"""For an incoming transaction from a given origin, check if we have
already responded to it. If so, return the response code and response
@@ -95,7 +103,8 @@ class TransactionStore(SQLBaseStore):
)
if result and result["response_code"]:
- return result["response_code"], json.loads(str(result["response_json"]))
+ return result["response_code"], db_to_json(result["response_json"])
+
else:
return None
@@ -155,7 +164,7 @@ class TransactionStore(SQLBaseStore):
"""
pass
- @cached(max_entries=10000)
+ @defer.inlineCallbacks
def get_destination_retry_timings(self, destination):
"""Gets the current retry timings (if any) for a given destination.
@@ -166,10 +175,20 @@ class TransactionStore(SQLBaseStore):
None if not retrying
Otherwise a dict for the retry scheme
"""
- return self.runInteraction(
+
+ result = self._destination_retry_cache.get(destination, SENTINEL)
+ if result is not SENTINEL:
+ defer.returnValue(result)
+
+ result = yield self.runInteraction(
"get_destination_retry_timings",
self._get_destination_retry_timings, destination)
+ # We don't hugely care about race conditions between getting and
+ # invalidating the cache, since we time out fairly quickly anyway.
+ self._destination_retry_cache[destination] = result
+ defer.returnValue(result)
+
def _get_destination_retry_timings(self, txn, destination):
result = self._simple_select_one_txn(
txn,
@@ -197,8 +216,7 @@ class TransactionStore(SQLBaseStore):
retry_interval (int) - how long until next retry in ms
"""
- # XXX: we could chose to not bother persisting this if our cache thinks
- # this is a NOOP
+ self._destination_retry_cache.pop(destination, None)
return self.runInteraction(
"set_destination_retry_timings",
self._set_destination_retry_timings,
@@ -211,10 +229,6 @@ class TransactionStore(SQLBaseStore):
retry_last_ts, retry_interval):
self.database_engine.lock_table(txn, "destinations")
- self._invalidate_cache_and_stream(
- txn, self.get_destination_retry_timings, (destination,)
- )
-
# We need to be careful here as the data may have changed from under us
# due to a worker setting the timings.
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 9b3f2f4b96..ec7b2c9672 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -374,29 +374,25 @@ class ReadWriteLock(object):
defer.returnValue(_ctx_manager())
-class DeferredTimeoutError(Exception):
- """
- This error is raised by default when a L{Deferred} times out.
- """
-
+def _cancelled_to_timed_out_error(value, timeout):
+ if isinstance(value, failure.Failure):
+ value.trap(CancelledError)
+ raise defer.TimeoutError(timeout, "Deferred")
+ return value
-def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
- """
- Add a timeout to a deferred by scheduling it to be cancelled after
- timeout seconds.
- This is essentially a backport of deferred.addTimeout, which was introduced
- in twisted 16.5.
+def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
+ """The in built twisted `Deferred.addTimeout` fails to time out deferreds
+ that have a canceller that throws exceptions. This method creates a new
+ deferred that wraps and times out the given deferred, correctly handling
+ the case where the given deferred's canceller throws.
- If the deferred gets timed out, it errbacks with a DeferredTimeoutError,
- unless a cancelable function was passed to its initialization or unless
- a different on_timeout_cancel callable is provided.
+ NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred
Args:
- deferred (defer.Deferred): deferred to be timed out
- timeout (Number): seconds to time out after
- reactor (twisted.internet.reactor): the Twisted reactor to use
-
+ deferred (Deferred)
+ timeout (float): Timeout in seconds
+ reactor (twisted.internet.reactor): The twisted reactor to use
on_timeout_cancel (callable): A callable which is called immediately
after the deferred times out, and not if this deferred is
otherwise cancelled before the timeout.
@@ -406,13 +402,26 @@ def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
the timeout.
The default callable (if none is provided) will translate a
- CancelledError Failure into a DeferredTimeoutError.
+ CancelledError Failure into a defer.TimeoutError.
+
+ Returns:
+ Deferred
"""
+
+ new_d = defer.Deferred()
+
timed_out = [False]
def time_it_out():
timed_out[0] = True
- deferred.cancel()
+
+ try:
+ deferred.cancel()
+ except: # noqa: E722, if we throw any exception it'll break time outs
+ logger.exception("Canceller failed during timeout")
+
+ if not new_d.called:
+ new_d.errback(defer.TimeoutError(timeout, "Deferred"))
delayed_call = reactor.callLater(timeout, time_it_out)
@@ -432,9 +441,14 @@ def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
deferred.addBoth(cancel_timeout)
+ def success_cb(val):
+ if not new_d.called:
+ new_d.callback(val)
-def _cancelled_to_timed_out_error(value, timeout):
- if isinstance(value, failure.Failure):
- value.trap(CancelledError)
- raise DeferredTimeoutError(timeout, "Deferred")
- return value
+ def failure_cb(val):
+ if not new_d.called:
+ new_d.errback(val)
+
+ deferred.addCallbacks(success_cb, failure_cb)
+
+ return new_d
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 7b065b195e..f37d5bec08 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
import os
import six
@@ -20,6 +21,8 @@ from six.moves import intern
from prometheus_client.core import REGISTRY, Gauge, GaugeMetricFamily
+logger = logging.getLogger(__name__)
+
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5))
@@ -76,16 +79,20 @@ def register_cache(cache_type, cache_name, cache):
return []
def collect(self):
- if cache_type == "response_cache":
- response_cache_size.labels(cache_name).set(len(cache))
- response_cache_hits.labels(cache_name).set(self.hits)
- response_cache_evicted.labels(cache_name).set(self.evicted_size)
- response_cache_total.labels(cache_name).set(self.hits + self.misses)
- else:
- cache_size.labels(cache_name).set(len(cache))
- cache_hits.labels(cache_name).set(self.hits)
- cache_evicted.labels(cache_name).set(self.evicted_size)
- cache_total.labels(cache_name).set(self.hits + self.misses)
+ try:
+ if cache_type == "response_cache":
+ response_cache_size.labels(cache_name).set(len(cache))
+ response_cache_hits.labels(cache_name).set(self.hits)
+ response_cache_evicted.labels(cache_name).set(self.evicted_size)
+ response_cache_total.labels(cache_name).set(self.hits + self.misses)
+ else:
+ cache_size.labels(cache_name).set(len(cache))
+ cache_hits.labels(cache_name).set(self.hits)
+ cache_evicted.labels(cache_name).set(self.evicted_size)
+ cache_total.labels(cache_name).set(self.hits + self.misses)
+ except Exception as e:
+ logger.warn("Error calculating metrics for %s: %s", cache_name, e)
+ raise
yield GaugeMetricFamily("__unused", "")
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index ce85b2ae11..f369780277 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -16,12 +16,17 @@
import logging
from collections import OrderedDict
+from six import iteritems, itervalues
+
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches import register_cache
logger = logging.getLogger(__name__)
+SENTINEL = object()
+
+
class ExpiringCache(object):
def __init__(self, cache_name, clock, max_len=0, expiry_ms=0,
reset_expiry_on_get=False, iterable=False):
@@ -54,11 +59,8 @@ class ExpiringCache(object):
self.iterable = iterable
- self._size_estimate = 0
-
self.metrics = register_cache("expiring", cache_name, self)
- def start(self):
if not self._expiry_ms:
# Don't bother starting the loop if things never expire
return
@@ -75,16 +77,11 @@ class ExpiringCache(object):
now = self._clock.time_msec()
self._cache[key] = _CacheEntry(now, value)
- if self.iterable:
- self._size_estimate += len(value)
-
# Evict if there are now too many items
while self._max_len and len(self) > self._max_len:
_key, value = self._cache.popitem(last=False)
if self.iterable:
- removed_len = len(value.value)
- self.metrics.inc_evictions(removed_len)
- self._size_estimate -= removed_len
+ self.metrics.inc_evictions(len(value.value))
else:
self.metrics.inc_evictions()
@@ -101,6 +98,21 @@ class ExpiringCache(object):
return entry.value
+ def pop(self, key, default=SENTINEL):
+ """Removes and returns the value with the given key from the cache.
+
+ If the key isn't in the cache then `default` will be returned if
+ specified, otherwise `KeyError` will get raised.
+
+ Identical functionality to `dict.pop(..)`.
+ """
+
+ value = self._cache.pop(key, default)
+ if value is SENTINEL:
+ raise KeyError(key)
+
+ return value
+
def __contains__(self, key):
return key in self._cache
@@ -128,14 +140,16 @@ class ExpiringCache(object):
keys_to_delete = set()
- for key, cache_entry in self._cache.items():
+ for key, cache_entry in iteritems(self._cache):
if now - cache_entry.time > self._expiry_ms:
keys_to_delete.add(key)
for k in keys_to_delete:
value = self._cache.pop(k)
if self.iterable:
- self._size_estimate -= len(value.value)
+ self.metrics.inc_evictions(len(value.value))
+ else:
+ self.metrics.inc_evictions()
logger.debug(
"[%s] _prune_cache before: %d, after len: %d",
@@ -144,12 +158,14 @@ class ExpiringCache(object):
def __len__(self):
if self.iterable:
- return self._size_estimate
+ return sum(len(entry.value) for entry in itervalues(self._cache))
else:
return len(self._cache)
class _CacheEntry(object):
+ __slots__ = ["time", "value"]
+
def __init__(self, time, value):
self.time = time
self.value = value
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index a0c2d37610..89224b26cc 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -200,7 +200,7 @@ class LoggingContext(object):
sentinel = Sentinel()
- def __init__(self, name=None, parent_context=None):
+ def __init__(self, name=None, parent_context=None, request=None):
self.previous_context = LoggingContext.current_context()
self.name = name
@@ -218,6 +218,13 @@ class LoggingContext(object):
self.parent_context = parent_context
+ if self.parent_context is not None:
+ self.parent_context.copy_to(self)
+
+ if request is not None:
+ # the request param overrides the request from the parent context
+ self.request = request
+
def __str__(self):
return "%s@%x" % (self.name, id(self))
@@ -256,9 +263,6 @@ class LoggingContext(object):
)
self.alive = True
- if self.parent_context is not None:
- self.parent_context.copy_to(self)
-
return self
def __exit__(self, type, value, traceback):
@@ -439,6 +443,35 @@ class PreserveLoggingContext(object):
)
+def nested_logging_context(suffix, parent_context=None):
+ """Creates a new logging context as a child of another.
+
+ The nested logging context will have a 'request' made up of the parent context's
+ request, plus the given suffix.
+
+ CPU/db usage stats will be added to the parent context's on exit.
+
+ Normal usage looks like:
+
+ with nested_logging_context(suffix):
+ # ... do stuff
+
+ Args:
+ suffix (str): suffix to add to the parent context's 'request'.
+ parent_context (LoggingContext|None): parent context. Will use the current context
+ if None.
+
+ Returns:
+ LoggingContext: new logging context.
+ """
+ if parent_context is None:
+ parent_context = LoggingContext.current_context()
+ return LoggingContext(
+ parent_context=parent_context,
+ request=parent_context.request + "-" + suffix,
+ )
+
+
def preserve_fn(f):
"""Function decorator which wraps the function with run_in_background"""
def g(*args, **kwargs):
diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py
index 14be3c7396..8d0f2a8918 100644
--- a/synapse/util/manhole.py
+++ b/synapse/util/manhole.py
@@ -19,22 +19,40 @@ from twisted.conch.ssh.keys import Key
from twisted.cred import checkers, portal
PUBLIC_KEY = (
- "ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAGEArzJx8OYOnJmzf4tfBEvLi8DVPrJ3/c9k2I/Az"
- "64fxjHf9imyRJbixtQhlH9lfNjUIx+4LmrJH5QNRsFporcHDKOTwTTYLh5KmRpslkYHRivcJS"
- "kbh/C+BR3utDS555mV"
+ "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDHhGATaW4KhE23+7nrH4jFx3yLq9OjaEs5"
+ "XALqeK+7385NlLja3DE/DO9mGhnd9+bAy39EKT3sTV6+WXQ4yD0TvEEyUEMtjWkSEm6U32+C"
+ "DaS3TW/vPBUMeJQwq+Ydcif1UlnpXrDDTamD0AU9VaEvHq+3HAkipqn0TGpKON6aqk4vauDx"
+ "oXSsV5TXBVrxP/y7HpMOpU4GUWsaaacBTKKNnUaQB4UflvydaPJUuwdaCUJGTMjbhWrjVfK+"
+ "jslseSPxU6XvrkZMyCr4znxvuDxjMk1RGIdO7v+rbBMLEgqtSMNqJbYeVCnj2CFgc3fcTcld"
+ "X2uOJDrJb/WRlHulthCh"
)
PRIVATE_KEY = """-----BEGIN RSA PRIVATE KEY-----
-MIIByAIBAAJhAK8ycfDmDpyZs3+LXwRLy4vA1T6yd/3PZNiPwM+uH8Yx3/YpskSW
-4sbUIZR/ZXzY1CMfuC5qyR+UDUbBaaK3Bwyjk8E02C4eSpkabJZGB0Yr3CUpG4fw
-vgUd7rQ0ueeZlQIBIwJgbh+1VZfr7WftK5lu7MHtqE1S1vPWZQYE3+VUn8yJADyb
-Z4fsZaCrzW9lkIqXkE3GIY+ojdhZhkO1gbG0118sIgphwSWKRxK0mvh6ERxKqIt1
-xJEJO74EykXZV4oNJ8sjAjEA3J9r2ZghVhGN6V8DnQrTk24Td0E8hU8AcP0FVP+8
-PQm/g/aXf2QQkQT+omdHVEJrAjEAy0pL0EBH6EVS98evDCBtQw22OZT52qXlAwZ2
-gyTriKFVoqjeEjt3SZKKqXHSApP/AjBLpF99zcJJZRq2abgYlf9lv1chkrWqDHUu
-DZttmYJeEfiFBBavVYIF1dOlZT0G8jMCMBc7sOSZodFnAiryP+Qg9otSBjJ3bQML
-pSTqy7c3a2AScC/YyOwkDaICHnnD3XyjMwIxALRzl0tQEKMXs6hH8ToUdlLROCrP
-EhQ0wahUTCk1gKA4uPD6TMTChavbh4K63OvbKg==
+MIIEpQIBAAKCAQEAx4RgE2luCoRNt/u56x+Ixcd8i6vTo2hLOVwC6nivu9/OTZS4
+2twxPwzvZhoZ3ffmwMt/RCk97E1evll0OMg9E7xBMlBDLY1pEhJulN9vgg2kt01v
+7zwVDHiUMKvmHXIn9VJZ6V6ww02pg9AFPVWhLx6vtxwJIqap9ExqSjjemqpOL2rg
+8aF0rFeU1wVa8T/8ux6TDqVOBlFrGmmnAUyijZ1GkAeFH5b8nWjyVLsHWglCRkzI
+24Vq41Xyvo7JbHkj8VOl765GTMgq+M58b7g8YzJNURiHTu7/q2wTCxIKrUjDaiW2
+HlQp49ghYHN33E3JXV9rjiQ6yW/1kZR7pbYQoQIDAQABAoIBAQC8KJ0q8Wzzwh5B
+esa1dQHZ8+4DEsL/Amae66VcVwD0X3cCN1W2IZ7X5W0Ij2kBqr8V51RYhcR+S+Ek
+BtzSiBUBvbKGrqcMGKaUgomDIMzai99hd0gvCCyZnEW1OQhFkNkaRNXCfqiZJ27M
+fqvSUiU2eOwh9fCvmxoA6Of8o3FbzcJ+1GMcobWRllDtLmj6lgVbDzuA+0jC5daB
+9Tj1pBzu3wn3ufxiS+gBnJ+7NcXH3E73lqCcPa2ufbZ1haxfiGCnRIhFXuQDgxFX
+vKdEfDgtvas6r1ahGbc+b/q8E8fZT7cABuIU4yfOORK+MhpyWbvoyyzuVGKj3PKt
+KSPJu5CZAoGBAOkoJfAVyYteqKcmGTanGqQnAY43CaYf6GdSPX/jg+JmKZg0zqMC
+jWZUtPb93i+jnOInbrnuHOiHAxI8wmhEPed28H2lC/LU8PzlqFkZXKFZ4vLOhhRB
+/HeHCFIDosPFlohWi3b+GAjD7sXgnIuGmnXWe2ea/TS3yersifDEoKKjAoGBANsQ
+gJX2cJv1c3jhdgcs8vAt5zIOKcCLTOr/QPmVf/kxjNgndswcKHwsxE/voTO9q+TF
+v/6yCSTxAdjuKz1oIYWgi/dZo82bBKWxNRpgrGviU3/zwxiHlyIXUhzQu78q3VS/
+7S1XVbc7qMV++XkYKHPVD+nVG/gGzFxumX7MLXfrAoGBAJit9cn2OnjNj9uFE1W6
+r7N254ndeLAUjPe73xH0RtTm2a4WRopwjW/JYIetTuYbWgyujc+robqTTuuOZjAp
+H/CG7o0Ym251CypQqaFO/l2aowclPp/dZhpPjp9GSjuxFBZLtiBB3DNBOwbRQzIK
+/vLTdRQvZkgzYkI4i0vjNt3JAoGBANP8HSKBLymMlShlrSx2b8TB9tc2Y2riohVJ
+2ttqs0M2kt/dGJWdrgOz4mikL+983Olt/0P9juHDoxEEMK2kpcPEv40lnmBpYU7h
+s8yJvnBLvJe2EJYdJ8AipyAhUX1FgpbvfxmASP8eaUxsegeXvBWTGWojAoS6N2o+
+0KSl+l3vAoGAFqm0gO9f/Q1Se60YQd4l2PZeMnJFv0slpgHHUwegmd6wJhOD7zJ1
+CkZcXwiv7Nog7AI9qKJEUXLjoqL+vJskBzSOqU3tcd670YQMi1aXSXJqYE202K7o
+EddTrx3TNpr1D5m/f+6mnXWrc8u9y1+GNx9yz889xMjIBTBI9KqaaOs=
-----END RSA PRIVATE KEY-----"""
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 97f1267380..4b4ac5f6c7 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -20,6 +20,7 @@ from prometheus_client import Counter
from twisted.internet import defer
+from synapse.metrics import InFlightGauge
from synapse.util.logcontext import LoggingContext
logger = logging.getLogger(__name__)
@@ -45,6 +46,13 @@ block_db_txn_duration = Counter(
block_db_sched_duration = Counter(
"synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"])
+# Tracks the number of blocks currently active
+in_flight = InFlightGauge(
+ "synapse_util_metrics_block_in_flight", "",
+ labels=["block_name"],
+ sub_metrics=["real_time_max", "real_time_sum"],
+)
+
def measure_func(name):
def wrapper(func):
@@ -82,10 +90,14 @@ class Measure(object):
self.start_usage = self.start_context.get_resource_usage()
+ in_flight.register((self.name,), self._update_in_flight)
+
def __exit__(self, exc_type, exc_val, exc_tb):
if isinstance(exc_type, Exception) or not self.start_context:
return
+ in_flight.unregister((self.name,), self._update_in_flight)
+
duration = self.clock.time() - self.start
block_counter.labels(self.name).inc()
@@ -120,3 +132,13 @@ class Measure(object):
if self.created_context:
self.start_context.__exit__(exc_type, exc_val, exc_tb)
+
+ def _update_in_flight(self, metrics):
+ """Gets called when processing in flight metrics
+ """
+ duration = self.clock.time() - self.start
+
+ metrics.real_time_max = max(metrics.real_time_max, duration)
+ metrics.real_time_sum += duration
+
+ # TODO: Add other in flight metrics.
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 8a3a06fd74..26cce7d197 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -188,7 +188,7 @@ class RetryDestinationLimiter(object):
else:
self.retry_interval = self.min_retry_interval
- logger.debug(
+ logger.info(
"Connection to %s was unsuccessful (%s(%s)); backoff now %i",
self.destination, exc_type, exc_val, self.retry_interval
)
diff --git a/synapse/visibility.py b/synapse/visibility.py
index d4680863d3..c64ad2144c 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -324,14 +324,13 @@ def filter_events_for_server(store, server_name, events):
# server's domain.
#
# event_to_state_ids contains lots of duplicates, so it turns out to be
- # cheaper to build a complete set of unique
- # ((type, state_key), event_id) tuples, and then filter out the ones we
- # don't want.
+ # cheaper to build a complete event_id => (type, state_key) dict, and then
+ # filter out the ones we don't want
#
- state_key_to_event_id_set = {
- e
+ event_id_to_state_key = {
+ event_id: key
for key_to_eid in itervalues(event_to_state_ids)
- for e in key_to_eid.items()
+ for key, event_id in iteritems(key_to_eid)
}
def include(typ, state_key):
@@ -346,7 +345,7 @@ def filter_events_for_server(store, server_name, events):
event_map = yield store.get_events([
e_id
- for key, e_id in state_key_to_event_id_set
+ for e_id, key in iteritems(event_id_to_state_key)
if include(key[0], key[1])
])
|