diff options
Diffstat (limited to 'synapse')
139 files changed, 5839 insertions, 2461 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index 9dbe0b9f10..89ea9a9775 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -27,4 +27,4 @@ try: except ImportError: pass -__version__ = "0.33.4" +__version__ = "0.33.8" diff --git a/synapse/_scripts/__init__.py b/synapse/_scripts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/synapse/_scripts/__init__.py diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py new file mode 100644 index 0000000000..70cecde486 --- /dev/null +++ b/synapse/_scripts/register_new_matrix_user.py @@ -0,0 +1,215 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import argparse +import getpass +import hashlib +import hmac +import logging +import sys + +from six.moves import input + +import requests as _requests +import yaml + + +def request_registration( + user, + password, + server_location, + shared_secret, + admin=False, + requests=_requests, + _print=print, + exit=sys.exit, +): + + url = "%s/_matrix/client/r0/admin/register" % (server_location,) + + # Get the nonce + r = requests.get(url, verify=False) + + if r.status_code is not 200: + _print("ERROR! Received %d %s" % (r.status_code, r.reason)) + if 400 <= r.status_code < 500: + try: + _print(r.json()["error"]) + except Exception: + pass + return exit(1) + + nonce = r.json()["nonce"] + + mac = hmac.new(key=shared_secret.encode('utf8'), digestmod=hashlib.sha1) + + mac.update(nonce.encode('utf8')) + mac.update(b"\x00") + mac.update(user.encode('utf8')) + mac.update(b"\x00") + mac.update(password.encode('utf8')) + mac.update(b"\x00") + mac.update(b"admin" if admin else b"notadmin") + + mac = mac.hexdigest() + + data = { + "nonce": nonce, + "username": user, + "password": password, + "mac": mac, + "admin": admin, + } + + _print("Sending registration request...") + r = requests.post(url, json=data, verify=False) + + if r.status_code is not 200: + _print("ERROR! Received %d %s" % (r.status_code, r.reason)) + if 400 <= r.status_code < 500: + try: + _print(r.json()["error"]) + except Exception: + pass + return exit(1) + + _print("Success!") + + +def register_new_user(user, password, server_location, shared_secret, admin): + if not user: + try: + default_user = getpass.getuser() + except Exception: + default_user = None + + if default_user: + user = input("New user localpart [%s]: " % (default_user,)) + if not user: + user = default_user + else: + user = input("New user localpart: ") + + if not user: + print("Invalid user name") + sys.exit(1) + + if not password: + password = getpass.getpass("Password: ") + + if not password: + print("Password cannot be blank.") + sys.exit(1) + + confirm_password = getpass.getpass("Confirm password: ") + + if password != confirm_password: + print("Passwords do not match") + sys.exit(1) + + if admin is None: + admin = input("Make admin [no]: ") + if admin in ("y", "yes", "true"): + admin = True + else: + admin = False + + request_registration(user, password, server_location, shared_secret, bool(admin)) + + +def main(): + + logging.captureWarnings(True) + + parser = argparse.ArgumentParser( + description="Used to register new users with a given home server when" + " registration has been disabled. The home server must be" + " configured with the 'registration_shared_secret' option" + " set." + ) + parser.add_argument( + "-u", + "--user", + default=None, + help="Local part of the new user. Will prompt if omitted.", + ) + parser.add_argument( + "-p", + "--password", + default=None, + help="New password for user. Will prompt if omitted.", + ) + admin_group = parser.add_mutually_exclusive_group() + admin_group.add_argument( + "-a", + "--admin", + action="store_true", + help=( + "Register new user as an admin. " + "Will prompt if --no-admin is not set either." + ), + ) + admin_group.add_argument( + "--no-admin", + action="store_true", + help=( + "Register new user as a regular user. " + "Will prompt if --admin is not set either." + ), + ) + + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument( + "-c", + "--config", + type=argparse.FileType('r'), + help="Path to server config file. Used to read in shared secret.", + ) + + group.add_argument( + "-k", "--shared-secret", help="Shared secret as defined in server config file." + ) + + parser.add_argument( + "server_url", + default="https://localhost:8448", + nargs='?', + help="URL to use to talk to the home server. Defaults to " + " 'https://localhost:8448'.", + ) + + args = parser.parse_args() + + if "config" in args and args.config: + config = yaml.safe_load(args.config) + secret = config.get("registration_shared_secret", None) + if not secret: + print("No 'registration_shared_secret' defined in config.") + sys.exit(1) + else: + secret = args.shared_secret + + admin = None + if args.admin or args.no_admin: + admin = args.admin + + register_new_user(args.user, args.password, args.server_url, secret, admin) + + +if __name__ == "__main__": + main() diff --git a/synapse/api/constants.py b/synapse/api/constants.py index c2630c4c64..f20e0fcf0b 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -51,6 +51,7 @@ class LoginType(object): EMAIL_IDENTITY = u"m.login.email.identity" MSISDN = u"m.login.msisdn" RECAPTCHA = u"m.login.recaptcha" + TERMS = u"m.login.terms" DUMMY = u"m.login.dummy" # Only for C/S API v1 @@ -61,6 +62,7 @@ class LoginType(object): class EventTypes(object): Member = "m.room.member" Create = "m.room.create" + Tombstone = "m.room.tombstone" JoinRules = "m.room.join_rules" PowerLevels = "m.room.power_levels" Aliases = "m.room.aliases" @@ -101,6 +103,7 @@ class ThirdPartyEntityKind(object): class RoomVersions(object): V1 = "1" VDH_TEST = "vdh-test-version" + STATE_V2_TEST = "state-v2-test" # the version we will give rooms which are created on this server @@ -108,7 +111,11 @@ DEFAULT_ROOM_VERSION = RoomVersions.V1 # vdh-test-version is a placeholder to get room versioning support working and tested # until we have a working v2. -KNOWN_ROOM_VERSIONS = {RoomVersions.V1, RoomVersions.VDH_TEST} +KNOWN_ROOM_VERSIONS = { + RoomVersions.V1, + RoomVersions.VDH_TEST, + RoomVersions.STATE_V2_TEST, +} ServerNoticeMsgType = "m.server_notice" ServerNoticeLimitReached = "m.server_notice.usage_limit_reached" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 2e7f98404d..48b903374d 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -59,6 +59,7 @@ class Codes(object): RESOURCE_LIMIT_EXCEEDED = "M_RESOURCE_LIMIT_EXCEEDED" UNSUPPORTED_ROOM_VERSION = "M_UNSUPPORTED_ROOM_VERSION" INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION" + WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION" class CodeMessageException(RuntimeError): @@ -312,6 +313,20 @@ class LimitExceededError(SynapseError): ) +class RoomKeysVersionError(SynapseError): + """A client has tried to upload to a non-current version of the room_keys store + """ + def __init__(self, current_version): + """ + Args: + current_version (str): the current version of the store they should have used + """ + super(RoomKeysVersionError, self).__init__( + 403, "Wrong room_keys version", Codes.WRONG_ROOM_KEYS_VERSION + ) + self.current_version = current_version + + class IncompatibleRoomVersionError(SynapseError): """A server is trying to join a room whose version it does not support.""" diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index a31a9a17e0..677c0bdd4c 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -172,7 +172,10 @@ USER_FILTER_SCHEMA = { # events a lot easier as we can then use a negative lookbehind # assertion to split '\.' If we allowed \\ then it would # incorrectly split '\\.' See synapse.events.utils.serialize_event - "pattern": "^((?!\\\).)*$" + # + # Note that because this is a regular expression, we have to escape + # each backslash in the pattern. + "pattern": r"^((?!\\\\).)*$" } } }, @@ -226,7 +229,7 @@ class Filtering(object): jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA, format_checker=FormatChecker()) except jsonschema.ValidationError as e: - raise SynapseError(400, e.message) + raise SynapseError(400, str(e)) class FilterCollection(object): diff --git a/synapse/api/urls.py b/synapse/api/urls.py index 71347912f1..f78695b657 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -28,7 +28,6 @@ FEDERATION_PREFIX = "/_matrix/federation/v1" STATIC_PREFIX = "/_matrix/static" WEB_CLIENT_PREFIX = "/_matrix/client" CONTENT_REPO_PREFIX = "/_matrix/content" -SERVER_KEY_PREFIX = "/_matrix/key/v1" SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" MEDIA_PREFIX = "/_matrix/media/r0" LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" @@ -64,7 +63,7 @@ class ConsentURIBuilder(object): """ mac = hmac.new( key=self._hmac_secret, - msg=user_id, + msg=user_id.encode('ascii'), digestmod=sha256, ).hexdigest() consent_uri = "%s_matrix/consent?%s" % ( diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index 3b6b9368b8..c3afcc573b 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -24,7 +24,7 @@ try: python_dependencies.check_requirements() except python_dependencies.MissingRequirementError as e: message = "\n".join([ - "Missing Requirement: %s" % (e.message,), + "Missing Requirement: %s" % (str(e),), "To install run:", " pip install --upgrade --force \"%s\"" % (e.dependency,), "", diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 7c866e246a..18584226e9 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -17,6 +17,7 @@ import gc import logging import sys +import psutil from daemonize import Daemonize from twisted.internet import error, reactor @@ -24,12 +25,6 @@ from twisted.internet import error, reactor from synapse.util import PreserveLoggingContext from synapse.util.rlimit import change_resource_limit -try: - import affinity -except Exception: - affinity = None - - logger = logging.getLogger(__name__) @@ -89,15 +84,20 @@ def start_reactor( with PreserveLoggingContext(): logger.info("Running") if cpu_affinity is not None: - if not affinity: - quit_with_error( - "Missing package 'affinity' required for cpu_affinity\n" - "option\n\n" - "Install by running:\n\n" - " pip install affinity\n\n" - ) - logger.info("Setting CPU affinity to %s" % cpu_affinity) - affinity.set_process_affinity_mask(0, cpu_affinity) + # Turn the bitmask into bits, reverse it so we go from 0 up + mask_to_bits = bin(cpu_affinity)[2:][::-1] + + cpus = [] + cpu_num = 0 + + for i in mask_to_bits: + if i == "1": + cpus.append(cpu_num) + cpu_num += 1 + + p = psutil.Process() + p.cpu_affinity(cpus) + change_resource_limit(soft_file_limit) if gc_thresholds: gc.set_threshold(*gc_thresholds) diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 86b5067400..8559e141af 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -136,7 +136,7 @@ def start(config_options): "Synapse appservice", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.appservice" @@ -172,7 +172,6 @@ def start(config_options): def start(): ps.get_datastore().start_profiling() - ps.get_state_handler().start_caching() reactor.callWhenRunning(start) diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index ce2b113dbb..76aed8c60a 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -153,7 +153,7 @@ def start(config_options): "Synapse client reader", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.client_reader" @@ -181,7 +181,6 @@ def start(config_options): ss.start_listening(config.worker_listeners) def start(): - ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() reactor.callWhenRunning(start) diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index f98e456ea0..e4a68715aa 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -169,7 +169,7 @@ def start(config_options): "Synapse event creator", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.event_creator" @@ -178,6 +178,9 @@ def start(config_options): setup_logging(config, use_worker_options=True) + # This should only be done on the user directory worker or the master + config.update_user_directory = False + events.USE_FROZEN_DICTS = config.use_frozen_dicts database_engine = create_engine(config.database_config) @@ -199,7 +202,6 @@ def start(config_options): ss.start_listening(config.worker_listeners) def start(): - ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() reactor.callWhenRunning(start) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 60f5973505..228a297fb8 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -140,7 +140,7 @@ def start(config_options): "Synapse federation reader", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.federation_reader" @@ -168,7 +168,6 @@ def start(config_options): ss.start_listening(config.worker_listeners) def start(): - ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() reactor.callWhenRunning(start) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 60dd09aac3..e9a99d76e1 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -160,7 +160,7 @@ def start(config_options): "Synapse federation sender", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.federation_sender" @@ -201,7 +201,6 @@ def start(config_options): def start(): ps.get_datastore().start_profiling() - ps.get_state_handler().start_caching() reactor.callWhenRunning(start) _base.start_worker_reactor("synapse-federation-sender", config) diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index 8c0b9c67b0..f5c61dec5b 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -68,7 +68,7 @@ class PresenceStatusStubServlet(ClientV1RestServlet): "Authorization": auth_headers, } result = yield self.http_client.get_json( - self.main_uri + request.uri, + self.main_uri + request.uri.decode('ascii'), headers=headers, ) defer.returnValue((200, result)) @@ -125,7 +125,7 @@ class KeyUploadServlet(RestServlet): "Authorization": auth_headers, } result = yield self.http_client.post_json_get_json( - self.main_uri + request.uri, + self.main_uri + request.uri.decode('ascii'), body, headers=headers, ) @@ -228,7 +228,7 @@ def start(config_options): "Synapse frontend proxy", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.frontend_proxy" @@ -258,7 +258,6 @@ def start(config_options): ss.start_listening(config.worker_listeners) def start(): - ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() reactor.callWhenRunning(start) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index ac97e19649..415374a2ce 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -20,6 +20,7 @@ import sys from six import iteritems +import psutil from prometheus_client import Gauge from twisted.application import service @@ -36,7 +37,6 @@ from synapse.api.urls import ( FEDERATION_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, - SERVER_KEY_PREFIX, SERVER_KEY_V2_PREFIX, STATIC_PREFIX, WEB_CLIENT_PREFIX, @@ -58,7 +58,6 @@ from synapse.python_dependencies import CONDITIONAL_REQUIREMENTS, check_requirem from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.rest import ClientRestResource -from synapse.rest.key.v1.server_key_resource import LocalKey from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.server import HomeServer @@ -235,10 +234,7 @@ class SynapseHomeServer(HomeServer): ) if name in ["keys", "federation"]: - resources.update({ - SERVER_KEY_PREFIX: LocalKey(self), - SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self), - }) + resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) if name == "webclient": resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self) @@ -301,7 +297,7 @@ class SynapseHomeServer(HomeServer): try: database_engine.check_database(db_conn.cursor()) except IncorrectDatabaseSetup as e: - quit_with_error(e.message) + quit_with_error(str(e)) # Gauges to expose monthly active user control metrics @@ -328,7 +324,7 @@ def setup(config_options): config_options, ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) if not config: @@ -384,10 +380,8 @@ def setup(config_options): def start(): hs.get_pusherpool().start() - hs.get_state_handler().start_caching() hs.get_datastore().start_profiling() hs.get_datastore().start_doing_background_updates() - hs.get_federation_client().start_get_pdu_cache() reactor.callWhenRunning(start) @@ -457,6 +451,10 @@ def run(hs): stats["homeserver"] = hs.config.server_name stats["timestamp"] = now stats["uptime_seconds"] = uptime + version = sys.version_info + stats["python_version"] = "{}.{}.{}".format( + version.major, version.minor, version.micro + ) stats["total_users"] = yield hs.get_datastore().count_all_users() total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users() @@ -500,7 +498,6 @@ def run(hs): def performance_stats_init(): try: - import psutil process = psutil.Process() # Ensure we can fetch both, and make the initial request for cpu_percent # so the next request will use this as the initial point. @@ -508,12 +505,9 @@ def run(hs): process.cpu_percent(interval=None) logger.info("report_stats can use psutil") stats_process.append(process) - except (ImportError, AttributeError): - logger.warn( - "report_stats enabled but psutil is not installed or incorrect version." - " Disabling reporting of memory/cpu stats." - " Ensuring psutil is available will help matrix.org track performance" - " changes across releases." + except (AttributeError): + logger.warning( + "Unable to read memory/cpu stats. Disabling reporting." ) def generate_user_daily_visit_stats(): @@ -528,10 +522,13 @@ def run(hs): clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000) # monthly active user limiting functionality - clock.looping_call( - hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60 - ) - hs.get_datastore().reap_monthly_active_users() + def reap_monthly_active_users(): + return run_as_background_process( + "reap_monthly_active_users", + hs.get_datastore().reap_monthly_active_users, + ) + clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60) + reap_monthly_active_users() @defer.inlineCallbacks def generate_monthly_active_users(): @@ -545,12 +542,15 @@ def run(hs): registered_reserved_users_mau_gauge.set(float(reserved_count)) max_mau_gauge.set(float(hs.config.max_mau_value)) - hs.get_datastore().initialise_reserved_users( - hs.config.mau_limits_reserved_threepids - ) - generate_monthly_active_users() + def start_generate_monthly_active_users(): + return run_as_background_process( + "generate_monthly_active_users", + generate_monthly_active_users, + ) + + start_generate_monthly_active_users() if hs.config.limit_usage_by_mau: - clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000) + clock.looping_call(start_generate_monthly_active_users, 5 * 60 * 1000) # End of monthly active user settings if hs.config.report_stats: @@ -566,7 +566,7 @@ def run(hs): clock.call_later(5 * 60, start_phone_stats_home) if hs.config.daemonize and hs.config.print_pidfile: - print (hs.config.pid_file) + print(hs.config.pid_file) _base.start_reactor( "synapse-homeserver", diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index e3dbb3b4e6..acc0487adc 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -133,7 +133,7 @@ def start(config_options): "Synapse media repository", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.media_repository" @@ -168,7 +168,6 @@ def start(config_options): ss.start_listening(config.worker_listeners) def start(): - ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() reactor.callWhenRunning(start) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index 244c604de9..83b0863f00 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -28,6 +28,7 @@ from synapse.config.logger import setup_logging from synapse.http.site import SynapseSite from synapse.metrics import RegistryProxy from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.replication.slave.storage._base import __func__ from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.pushers import SlavedPusherStore @@ -49,31 +50,31 @@ class PusherSlaveStore( SlavedAccountDataStore ): update_pusher_last_stream_ordering_and_success = ( - DataStore.update_pusher_last_stream_ordering_and_success.__func__ + __func__(DataStore.update_pusher_last_stream_ordering_and_success) ) update_pusher_failing_since = ( - DataStore.update_pusher_failing_since.__func__ + __func__(DataStore.update_pusher_failing_since) ) update_pusher_last_stream_ordering = ( - DataStore.update_pusher_last_stream_ordering.__func__ + __func__(DataStore.update_pusher_last_stream_ordering) ) get_throttle_params_by_room = ( - DataStore.get_throttle_params_by_room.__func__ + __func__(DataStore.get_throttle_params_by_room) ) set_throttle_params = ( - DataStore.set_throttle_params.__func__ + __func__(DataStore.set_throttle_params) ) get_time_of_last_push_action_before = ( - DataStore.get_time_of_last_push_action_before.__func__ + __func__(DataStore.get_time_of_last_push_action_before) ) get_profile_displayname = ( - DataStore.get_profile_displayname.__func__ + __func__(DataStore.get_profile_displayname) ) @@ -160,11 +161,11 @@ class PusherReplicationHandler(ReplicationClientHandler): else: yield self.start_pusher(row.user_id, row.app_id, row.pushkey) elif stream_name == "events": - self.pusher_pool.on_new_notifications( + yield self.pusher_pool.on_new_notifications( token, token, ) elif stream_name == "receipts": - self.pusher_pool.on_new_receipts( + yield self.pusher_pool.on_new_receipts( token, token, set(row.room_id for row in rows) ) except Exception: @@ -182,7 +183,7 @@ class PusherReplicationHandler(ReplicationClientHandler): def start_pusher(self, user_id, app_id, pushkey): key = "%s:%s" % (app_id, pushkey) logger.info("Starting pusher %r / %r", user_id, key) - return self.pusher_pool._refresh_pusher(app_id, pushkey, user_id) + return self.pusher_pool.start_pusher_by_id(app_id, pushkey, user_id) def start(config_options): @@ -191,7 +192,7 @@ def start(config_options): "Synapse pusher", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.pusher" @@ -228,7 +229,6 @@ def start(config_options): def start(): ps.get_pusherpool().start() ps.get_datastore().start_profiling() - ps.get_state_handler().start_caching() reactor.callWhenRunning(start) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 6662340797..0354e82bf8 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -33,7 +33,7 @@ from synapse.http.server import JsonResource from synapse.http.site import SynapseSite from synapse.metrics import RegistryProxy from synapse.metrics.resource import METRICS_PREFIX, MetricsResource -from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage._base import BaseSlavedStore, __func__ from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.client_ips import SlavedClientIpStore @@ -147,7 +147,7 @@ class SynchrotronPresence(object): and haven't come back yet. If there are poke the master about them. """ now = self.clock.time_msec() - for user_id, last_sync_ms in self.users_going_offline.items(): + for user_id, last_sync_ms in list(self.users_going_offline.items()): if now - last_sync_ms > 10 * 1000: self.users_going_offline.pop(user_id, None) self.send_user_sync(user_id, False, last_sync_ms) @@ -156,9 +156,9 @@ class SynchrotronPresence(object): # TODO Hows this supposed to work? pass - get_states = PresenceHandler.get_states.__func__ - get_state = PresenceHandler.get_state.__func__ - current_state_for_users = PresenceHandler.current_state_for_users.__func__ + get_states = __func__(PresenceHandler.get_states) + get_state = __func__(PresenceHandler.get_state) + current_state_for_users = __func__(PresenceHandler.current_state_for_users) def user_syncing(self, user_id, affect_presence): if affect_presence: @@ -208,7 +208,7 @@ class SynchrotronPresence(object): ) for row in rows] for state in states: - self.user_to_current_state[row.user_id] = state + self.user_to_current_state[state.user_id] = state stream_id = token yield self.notify_from_replication(states, stream_id) @@ -226,7 +226,15 @@ class SynchrotronPresence(object): class SynchrotronTyping(object): def __init__(self, hs): self._latest_room_serial = 0 + self._reset() + + def _reset(self): + """ + Reset the typing handler's data caches. + """ + # map room IDs to serial numbers self._room_serials = {} + # map room IDs to sets of users currently typing self._room_typing = {} def stream_positions(self): @@ -236,6 +244,12 @@ class SynchrotronTyping(object): return {"typing": self._latest_room_serial} def process_replication_rows(self, token, rows): + if self._latest_room_serial > token: + # The master has gone backwards. To prevent inconsistent data, just + # clear everything. + self._reset() + + # Set the latest serial token to whatever the server gave us. self._latest_room_serial = token for row in rows: @@ -410,7 +424,7 @@ def start(config_options): "Synapse synchrotron", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.synchrotron" @@ -435,7 +449,6 @@ def start(config_options): def start(): ss.get_datastore().start_profiling() - ss.get_state_handler().start_caching() reactor.callWhenRunning(start) diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py deleted file mode 100755 index d658f967ba..0000000000 --- a/synapse/app/synctl.py +++ /dev/null @@ -1,284 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import collections -import errno -import glob -import os -import os.path -import signal -import subprocess -import sys -import time - -from six import iteritems - -import yaml - -SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"] - -GREEN = "\x1b[1;32m" -YELLOW = "\x1b[1;33m" -RED = "\x1b[1;31m" -NORMAL = "\x1b[m" - - -def pid_running(pid): - try: - os.kill(pid, 0) - return True - except OSError as err: - if err.errno == errno.EPERM: - return True - return False - - -def write(message, colour=NORMAL, stream=sys.stdout): - if colour == NORMAL: - stream.write(message + "\n") - else: - stream.write(colour + message + NORMAL + "\n") - - -def abort(message, colour=RED, stream=sys.stderr): - write(message, colour, stream) - sys.exit(1) - - -def start(configfile): - write("Starting ...") - args = SYNAPSE - args.extend(["--daemonize", "-c", configfile]) - - try: - subprocess.check_call(args) - write("started synapse.app.homeserver(%r)" % - (configfile,), colour=GREEN) - except subprocess.CalledProcessError as e: - write( - "error starting (exit code: %d); see above for logs" % e.returncode, - colour=RED, - ) - - -def start_worker(app, configfile, worker_configfile): - args = [ - "python", "-B", - "-m", app, - "-c", configfile, - "-c", worker_configfile - ] - - try: - subprocess.check_call(args) - write("started %s(%r)" % (app, worker_configfile), colour=GREEN) - except subprocess.CalledProcessError as e: - write( - "error starting %s(%r) (exit code: %d); see above for logs" % ( - app, worker_configfile, e.returncode, - ), - colour=RED, - ) - - -def stop(pidfile, app): - if os.path.exists(pidfile): - pid = int(open(pidfile).read()) - try: - os.kill(pid, signal.SIGTERM) - write("stopped %s" % (app,), colour=GREEN) - except OSError as err: - if err.errno == errno.ESRCH: - write("%s not running" % (app,), colour=YELLOW) - elif err.errno == errno.EPERM: - abort("Cannot stop %s: Operation not permitted" % (app,)) - else: - abort("Cannot stop %s: Unknown error" % (app,)) - - -Worker = collections.namedtuple("Worker", [ - "app", "configfile", "pidfile", "cache_factor" -]) - - -def main(): - - parser = argparse.ArgumentParser() - - parser.add_argument( - "action", - choices=["start", "stop", "restart"], - help="whether to start, stop or restart the synapse", - ) - parser.add_argument( - "configfile", - nargs="?", - default="homeserver.yaml", - help="the homeserver config file, defaults to homeserver.yaml", - ) - parser.add_argument( - "-w", "--worker", - metavar="WORKERCONFIG", - help="start or stop a single worker", - ) - parser.add_argument( - "-a", "--all-processes", - metavar="WORKERCONFIGDIR", - help="start or stop all the workers in the given directory" - " and the main synapse process", - ) - - options = parser.parse_args() - - if options.worker and options.all_processes: - write( - 'Cannot use "--worker" with "--all-processes"', - stream=sys.stderr - ) - sys.exit(1) - - configfile = options.configfile - - if not os.path.exists(configfile): - write( - "No config file found\n" - "To generate a config file, run '%s -c %s --generate-config" - " --server-name=<server name>'\n" % ( - " ".join(SYNAPSE), options.configfile - ), - stream=sys.stderr, - ) - sys.exit(1) - - with open(configfile) as stream: - config = yaml.load(stream) - - pidfile = config["pid_file"] - cache_factor = config.get("synctl_cache_factor") - start_stop_synapse = True - - if cache_factor: - os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor) - - cache_factors = config.get("synctl_cache_factors", {}) - for cache_name, factor in iteritems(cache_factors): - os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor) - - worker_configfiles = [] - if options.worker: - start_stop_synapse = False - worker_configfile = options.worker - if not os.path.exists(worker_configfile): - write( - "No worker config found at %r" % (worker_configfile,), - stream=sys.stderr, - ) - sys.exit(1) - worker_configfiles.append(worker_configfile) - - if options.all_processes: - # To start the main synapse with -a you need to add a worker file - # with worker_app == "synapse.app.homeserver" - start_stop_synapse = False - worker_configdir = options.all_processes - if not os.path.isdir(worker_configdir): - write( - "No worker config directory found at %r" % (worker_configdir,), - stream=sys.stderr, - ) - sys.exit(1) - worker_configfiles.extend(sorted(glob.glob( - os.path.join(worker_configdir, "*.yaml") - ))) - - workers = [] - for worker_configfile in worker_configfiles: - with open(worker_configfile) as stream: - worker_config = yaml.load(stream) - worker_app = worker_config["worker_app"] - if worker_app == "synapse.app.homeserver": - # We need to special case all of this to pick up options that may - # be set in the main config file or in this worker config file. - worker_pidfile = ( - worker_config.get("pid_file") - or pidfile - ) - worker_cache_factor = worker_config.get("synctl_cache_factor") or cache_factor - daemonize = worker_config.get("daemonize") or config.get("daemonize") - assert daemonize, "Main process must have daemonize set to true" - - # The master process doesn't support using worker_* config. - for key in worker_config: - if key == "worker_app": # But we allow worker_app - continue - assert not key.startswith("worker_"), \ - "Main process cannot use worker_* config" - else: - worker_pidfile = worker_config["worker_pid_file"] - worker_daemonize = worker_config["worker_daemonize"] - assert worker_daemonize, "In config %r: expected '%s' to be True" % ( - worker_configfile, "worker_daemonize") - worker_cache_factor = worker_config.get("synctl_cache_factor") - workers.append(Worker( - worker_app, worker_configfile, worker_pidfile, worker_cache_factor, - )) - - action = options.action - - if action == "stop" or action == "restart": - for worker in workers: - stop(worker.pidfile, worker.app) - - if start_stop_synapse: - stop(pidfile, "synapse.app.homeserver") - - # Wait for synapse to actually shutdown before starting it again - if action == "restart": - running_pids = [] - if start_stop_synapse and os.path.exists(pidfile): - running_pids.append(int(open(pidfile).read())) - for worker in workers: - if os.path.exists(worker.pidfile): - running_pids.append(int(open(worker.pidfile).read())) - if len(running_pids) > 0: - write("Waiting for process to exit before restarting...") - for running_pid in running_pids: - while pid_running(running_pid): - time.sleep(0.2) - write("All processes exited; now restarting...") - - if action == "start" or action == "restart": - if start_stop_synapse: - # Check if synapse is already running - if os.path.exists(pidfile) and pid_running(int(open(pidfile).read())): - abort("synapse.app.homeserver already running") - start(configfile) - - for worker in workers: - if worker.cache_factor: - os.environ["SYNAPSE_CACHE_FACTOR"] = str(worker.cache_factor) - - start_worker(worker.app, configfile, worker.configfile) - - if cache_factor: - os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor) - else: - os.environ.pop("SYNAPSE_CACHE_FACTOR", None) - - -if __name__ == "__main__": - main() diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 96ffcaf073..0a5f62b509 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -188,7 +188,7 @@ def start(config_options): "Synapse user directory", config_options ) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) assert config.worker_app == "synapse.app.user_dir" @@ -229,7 +229,6 @@ def start(config_options): def start(): ps.get_datastore().start_profiling() - ps.get_state_handler().start_caching() reactor.callWhenRunning(start) diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py index 58c97a70af..79fe9c3dac 100644 --- a/synapse/config/__main__.py +++ b/synapse/config/__main__.py @@ -25,10 +25,10 @@ if __name__ == "__main__": try: config = HomeServerConfig.load_config("", sys.argv[3:]) except ConfigError as e: - sys.stderr.write("\n" + e.message + "\n") + sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) - print (getattr(config, key)) + print(getattr(config, key)) sys.exit(0) else: sys.stderr.write("Unknown command %r\n" % (action,)) diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 3d2e90dd5b..14dae65ea0 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -106,10 +106,7 @@ class Config(object): @classmethod def check_file(cls, file_path, config_name): if file_path is None: - raise ConfigError( - "Missing config for %s." - % (config_name,) - ) + raise ConfigError("Missing config for %s." % (config_name,)) try: os.stat(file_path) except OSError as e: @@ -128,9 +125,7 @@ class Config(object): if e.errno != errno.EEXIST: raise if not os.path.isdir(dir_path): - raise ConfigError( - "%s is not a directory" % (dir_path,) - ) + raise ConfigError("%s is not a directory" % (dir_path,)) return dir_path @classmethod @@ -156,21 +151,20 @@ class Config(object): return results def generate_config( - self, - config_dir_path, - server_name, - is_generating_file, - report_stats=None, + self, config_dir_path, server_name, is_generating_file, report_stats=None ): default_config = "# vim:ft=yaml\n" - default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all( - "default_config", - config_dir_path=config_dir_path, - server_name=server_name, - is_generating_file=is_generating_file, - report_stats=report_stats, - )) + default_config += "\n\n".join( + dedent(conf) + for conf in self.invoke_all( + "default_config", + config_dir_path=config_dir_path, + server_name=server_name, + is_generating_file=is_generating_file, + report_stats=report_stats, + ) + ) config = yaml.load(default_config) @@ -178,23 +172,22 @@ class Config(object): @classmethod def load_config(cls, description, argv): - config_parser = argparse.ArgumentParser( - description=description, - ) + config_parser = argparse.ArgumentParser(description=description) config_parser.add_argument( - "-c", "--config-path", + "-c", + "--config-path", action="append", metavar="CONFIG_FILE", help="Specify config file. Can be given multiple times and" - " may specify directories containing *.yaml files." + " may specify directories containing *.yaml files.", ) config_parser.add_argument( "--keys-directory", metavar="DIRECTORY", help="Where files such as certs and signing keys are stored when" - " their location is given explicitly in the config." - " Defaults to the directory containing the last config file", + " their location is given explicitly in the config." + " Defaults to the directory containing the last config file", ) config_args = config_parser.parse_args(argv) @@ -203,9 +196,7 @@ class Config(object): obj = cls() obj.read_config_files( - config_files, - keys_directory=config_args.keys_directory, - generate_keys=False, + config_files, keys_directory=config_args.keys_directory, generate_keys=False ) return obj @@ -213,38 +204,38 @@ class Config(object): def load_or_generate_config(cls, description, argv): config_parser = argparse.ArgumentParser(add_help=False) config_parser.add_argument( - "-c", "--config-path", + "-c", + "--config-path", action="append", metavar="CONFIG_FILE", help="Specify config file. Can be given multiple times and" - " may specify directories containing *.yaml files." + " may specify directories containing *.yaml files.", ) config_parser.add_argument( "--generate-config", action="store_true", - help="Generate a config file for the server name" + help="Generate a config file for the server name", ) config_parser.add_argument( "--report-stats", action="store", help="Whether the generated config reports anonymized usage statistics", - choices=["yes", "no"] + choices=["yes", "no"], ) config_parser.add_argument( "--generate-keys", action="store_true", - help="Generate any missing key files then exit" + help="Generate any missing key files then exit", ) config_parser.add_argument( "--keys-directory", metavar="DIRECTORY", help="Used with 'generate-*' options to specify where files such as" - " certs and signing keys should be stored in, unless explicitly" - " specified in the config." + " certs and signing keys should be stored in, unless explicitly" + " specified in the config.", ) config_parser.add_argument( - "-H", "--server-name", - help="The server name to generate a config file for" + "-H", "--server-name", help="The server name to generate a config file for" ) config_args, remaining_args = config_parser.parse_known_args(argv) @@ -257,8 +248,8 @@ class Config(object): if config_args.generate_config: if config_args.report_stats is None: config_parser.error( - "Please specify either --report-stats=yes or --report-stats=no\n\n" + - MISSING_REPORT_STATS_SPIEL + "Please specify either --report-stats=yes or --report-stats=no\n\n" + + MISSING_REPORT_STATS_SPIEL ) if not config_files: config_parser.error( @@ -287,26 +278,32 @@ class Config(object): config_dir_path=config_dir_path, server_name=server_name, report_stats=(config_args.report_stats == "yes"), - is_generating_file=True + is_generating_file=True, ) obj.invoke_all("generate_files", config) config_file.write(config_str) - print(( - "A config file has been generated in %r for server name" - " %r with corresponding SSL keys and self-signed" - " certificates. Please review this file and customise it" - " to your needs." - ) % (config_path, server_name)) + print( + ( + "A config file has been generated in %r for server name" + " %r with corresponding SSL keys and self-signed" + " certificates. Please review this file and customise it" + " to your needs." + ) + % (config_path, server_name) + ) print( "If this server name is incorrect, you will need to" " regenerate the SSL certificates" ) return else: - print(( - "Config file %r already exists. Generating any missing key" - " files." - ) % (config_path,)) + print( + ( + "Config file %r already exists. Generating any missing key" + " files." + ) + % (config_path,) + ) generate_keys = True parser = argparse.ArgumentParser( @@ -338,8 +335,7 @@ class Config(object): return obj - def read_config_files(self, config_files, keys_directory=None, - generate_keys=False): + def read_config_files(self, config_files, keys_directory=None, generate_keys=False): if not keys_directory: keys_directory = os.path.dirname(config_files[-1]) @@ -364,8 +360,9 @@ class Config(object): if "report_stats" not in config: raise ConfigError( - MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" + - MISSING_REPORT_STATS_SPIEL + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + + "\n" + + MISSING_REPORT_STATS_SPIEL ) if generate_keys: @@ -399,16 +396,16 @@ def find_config_files(search_paths): for entry in os.listdir(config_path): entry_path = os.path.join(config_path, entry) if not os.path.isfile(entry_path): - print ( - "Found subdirectory in config directory: %r. IGNORING." - ) % (entry_path, ) + err = "Found subdirectory in config directory: %r. IGNORING." + print(err % (entry_path,)) continue if not entry.endswith(".yaml"): - print ( - "Found file in config directory that does not" - " end in '.yaml': %r. IGNORING." - ) % (entry_path, ) + err = ( + "Found file in config directory that does not end in " + "'.yaml': %r. IGNORING." + ) + print(err % (entry_path,)) continue files.append(entry_path) diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index fe156b6930..93d70cff14 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -13,10 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + # This file can't be called email.py because if it is, we cannot: import email.utils +import logging +import os + +import pkg_resources -from ._base import Config +from ._base import Config, ConfigError + +logger = logging.getLogger(__name__) class EmailConfig(Config): @@ -38,7 +46,6 @@ class EmailConfig(Config): "smtp_host", "smtp_port", "notif_from", - "template_dir", "notif_template_html", "notif_template_text", ] @@ -62,9 +69,26 @@ class EmailConfig(Config): self.email_smtp_host = email_config["smtp_host"] self.email_smtp_port = email_config["smtp_port"] self.email_notif_from = email_config["notif_from"] - self.email_template_dir = email_config["template_dir"] self.email_notif_template_html = email_config["notif_template_html"] self.email_notif_template_text = email_config["notif_template_text"] + + template_dir = email_config.get("template_dir") + # we need an absolute path, because we change directory after starting (and + # we don't yet know what auxilliary templates like mail.css we will need). + # (Note that loading as package_resources with jinja.PackageLoader doesn't + # work for the same reason.) + if not template_dir: + template_dir = pkg_resources.resource_filename( + 'synapse', 'res/templates' + ) + template_dir = os.path.abspath(template_dir) + + for f in self.email_notif_template_text, self.email_notif_template_html: + p = os.path.join(template_dir, f) + if not os.path.isfile(p): + raise ConfigError("Unable to find email template file %s" % (p, )) + self.email_template_dir = template_dir + self.email_notif_for_new_users = email_config.get( "notif_for_new_users", True ) @@ -113,7 +137,9 @@ class EmailConfig(Config): # require_transport_security: False # notif_from: "Your Friendly %(app)s Home Server <noreply@example.com>" # app_name: Matrix - # template_dir: res/templates + # # if template_dir is unset, uses the example templates that are part of + # # the Synapse distribution. + # #template_dir: res/templates # notif_template_html: notif_mail.html # notif_template_text: notif_mail.txt # notif_for_new_users: True diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index b8d5690f2b..10dd40159f 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -31,6 +31,7 @@ from .push import PushConfig from .ratelimiting import RatelimitConfig from .registration import RegistrationConfig from .repository import ContentRepositoryConfig +from .room_directory import RoomDirectoryConfig from .saml2 import SAML2Config from .server import ServerConfig from .server_notices_config import ServerNoticesConfig @@ -49,7 +50,7 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, WorkerConfig, PasswordAuthProviderConfig, PushConfig, SpamCheckerConfig, GroupsConfig, UserDirectoryConfig, ConsentConfig, - ServerNoticesConfig, + ServerNoticesConfig, RoomDirectoryConfig, ): pass diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 0fb964eb67..7480ed5145 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -15,10 +15,10 @@ from distutils.util import strtobool +from synapse.config._base import Config, ConfigError +from synapse.types import RoomAlias from synapse.util.stringutils import random_string_with_symbols -from ._base import Config - class RegistrationConfig(Config): @@ -44,6 +44,10 @@ class RegistrationConfig(Config): ) self.auto_join_rooms = config.get("auto_join_rooms", []) + for room_alias in self.auto_join_rooms: + if not RoomAlias.is_valid(room_alias): + raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,)) + self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) def default_config(self, **kwargs): registration_shared_secret = random_string_with_symbols(50) @@ -98,6 +102,13 @@ class RegistrationConfig(Config): # to these rooms #auto_join_rooms: # - "#example:example.com" + + # Where auto_join_rooms are specified, setting this flag ensures that the + # the rooms exist by creating them when the first user on the + # homeserver registers. + # Setting to false means that if the rooms are not manually created, + # users cannot be auto-joined since they do not exist. + autocreate_auto_join_rooms: true """ % locals() def add_arguments(self, parser): diff --git a/synapse/config/repository.py b/synapse/config/repository.py index fc909c1fac..06c62ab62c 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -178,7 +178,7 @@ class ContentRepositoryConfig(Config): def default_config(self, **kwargs): media_store = self.default_path("media_store") uploads_path = self.default_path("uploads") - return """ + return r""" # Directory where uploaded images and attachments are stored. media_store_path: "%(media_store)s" diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py new file mode 100644 index 0000000000..9da13ab11b --- /dev/null +++ b/synapse/config/room_directory.py @@ -0,0 +1,102 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.util import glob_to_regex + +from ._base import Config, ConfigError + + +class RoomDirectoryConfig(Config): + def read_config(self, config): + alias_creation_rules = config["alias_creation_rules"] + + self._alias_creation_rules = [ + _AliasRule(rule) + for rule in alias_creation_rules + ] + + def default_config(self, config_dir_path, server_name, **kwargs): + return """ + # The `alias_creation` option controls who's allowed to create aliases + # on this server. + # + # The format of this option is a list of rules that contain globs that + # match against user_id and the new alias (fully qualified with server + # name). The action in the first rule that matches is taken, which can + # currently either be "allow" or "deny". + # + # If no rules match the request is denied. + alias_creation_rules: + - user_id: "*" + alias: "*" + action: allow + """ + + def is_alias_creation_allowed(self, user_id, alias): + """Checks if the given user is allowed to create the given alias + + Args: + user_id (str) + alias (str) + + Returns: + boolean: True if user is allowed to crate the alias + """ + for rule in self._alias_creation_rules: + if rule.matches(user_id, alias): + return rule.action == "allow" + + return False + + +class _AliasRule(object): + def __init__(self, rule): + action = rule["action"] + user_id = rule["user_id"] + alias = rule["alias"] + + if action in ("allow", "deny"): + self.action = action + else: + raise ConfigError( + "alias_creation_rules rules can only have action of 'allow'" + " or 'deny'" + ) + + try: + self._user_id_regex = glob_to_regex(user_id) + self._alias_regex = glob_to_regex(alias) + except Exception as e: + raise ConfigError("Failed to parse glob into regex: %s", e) + + def matches(self, user_id, alias): + """Tests if this rule matches the given user_id and alias. + + Args: + user_id (str) + alias (str) + + Returns: + boolean + """ + + # Note: The regexes are anchored at both ends + if not self._user_id_regex.match(user_id): + return False + + if not self._alias_regex.match(alias): + return False + + return True diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py index 57d4665e84..d40e4b8591 100644 --- a/synapse/crypto/keyclient.py +++ b/synapse/crypto/keyclient.py @@ -15,6 +15,8 @@ import logging +from six.moves import urllib + from canonicaljson import json from twisted.internet import defer, reactor @@ -28,15 +30,15 @@ from synapse.util import logcontext logger = logging.getLogger(__name__) -KEY_API_V1 = b"/_matrix/key/v1/" +KEY_API_V2 = "/_matrix/key/v2/server/%s" @defer.inlineCallbacks -def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1): +def fetch_server_key(server_name, tls_client_options_factory, key_id): """Fetch the keys for a remote server.""" factory = SynapseKeyClientFactory() - factory.path = path + factory.path = KEY_API_V2 % (urllib.parse.quote(key_id), ) factory.host = server_name endpoint = matrix_federation_endpoint( reactor, server_name, tls_client_options_factory, timeout=30 @@ -55,7 +57,7 @@ def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1): raise IOError("Cannot get key for %r" % server_name) except (ConnectError, DomainError) as e: logger.warn("Error getting key for %r: %s", server_name, e) - except Exception as e: + except Exception: logger.exception("Error getting key for %r", server_name) raise IOError("Cannot get key for %r" % server_name) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index d89f94c219..515ebbc148 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2017 New Vector Ltd. +# Copyright 2017, 2018 New Vector Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,8 +18,6 @@ import hashlib import logging from collections import namedtuple -from six.moves import urllib - from signedjson.key import ( decode_verify_key_bytes, encode_verify_key_base64, @@ -395,32 +393,13 @@ class Keyring(object): @defer.inlineCallbacks def get_keys_from_server(self, server_name_and_key_ids): - @defer.inlineCallbacks - def get_key(server_name, key_ids): - keys = None - try: - keys = yield self.get_server_verify_key_v2_direct( - server_name, key_ids - ) - except Exception as e: - logger.info( - "Unable to get key %r for %r directly: %s %s", - key_ids, server_name, - type(e).__name__, str(e), - ) - - if not keys: - keys = yield self.get_server_verify_key_v1_direct( - server_name, key_ids - ) - - keys = {server_name: keys} - - defer.returnValue(keys) - results = yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - run_in_background(get_key, server_name, key_ids) + run_in_background( + self.get_server_verify_key_v2_direct, + server_name, + key_ids, + ) for server_name, key_ids in server_name_and_key_ids ], consumeErrors=True, @@ -525,10 +504,7 @@ class Keyring(object): continue (response, tls_certificate) = yield fetch_server_key( - server_name, self.hs.tls_client_options_factory, - path=("/_matrix/key/v2/server/%s" % ( - urllib.parse.quote(requested_key_id), - )).encode("ascii"), + server_name, self.hs.tls_client_options_factory, requested_key_id ) if (u"signatures" not in response @@ -657,78 +633,6 @@ class Keyring(object): defer.returnValue(results) - @defer.inlineCallbacks - def get_server_verify_key_v1_direct(self, server_name, key_ids): - """Finds a verification key for the server with one of the key ids. - Args: - server_name (str): The name of the server to fetch a key for. - keys_ids (list of str): The key_ids to check for. - """ - - # Try to fetch the key from the remote server. - - (response, tls_certificate) = yield fetch_server_key( - server_name, self.hs.tls_client_options_factory - ) - - # Check the response. - - x509_certificate_bytes = crypto.dump_certificate( - crypto.FILETYPE_ASN1, tls_certificate - ) - - if ("signatures" not in response - or server_name not in response["signatures"]): - raise KeyLookupError("Key response not signed by remote server") - - if "tls_certificate" not in response: - raise KeyLookupError("Key response missing TLS certificate") - - tls_certificate_b64 = response["tls_certificate"] - - if encode_base64(x509_certificate_bytes) != tls_certificate_b64: - raise KeyLookupError("TLS certificate doesn't match") - - # Cache the result in the datastore. - - time_now_ms = self.clock.time_msec() - - verify_keys = {} - for key_id, key_base64 in response["verify_keys"].items(): - if is_signing_algorithm_supported(key_id): - key_bytes = decode_base64(key_base64) - verify_key = decode_verify_key_bytes(key_id, key_bytes) - verify_key.time_added = time_now_ms - verify_keys[key_id] = verify_key - - for key_id in response["signatures"][server_name]: - if key_id not in response["verify_keys"]: - raise KeyLookupError( - "Key response must include verification keys for all" - " signatures" - ) - if key_id in verify_keys: - verify_signed_json( - response, - server_name, - verify_keys[key_id] - ) - - yield self.store.store_server_certificate( - server_name, - server_name, - time_now_ms, - tls_certificate, - ) - - yield self.store_keys( - server_name=server_name, - from_server=server_name, - verify_keys=verify_keys, - ) - - defer.returnValue(verify_keys) - def store_keys(self, server_name, from_server, verify_keys): """Store a collection of verify keys for a given server Args: diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 6baeccca38..d4d4474847 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -98,9 +98,9 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True): creation_event = auth_events.get((EventTypes.Create, ""), None) if not creation_event: - raise SynapseError( + raise AuthError( 403, - "Room %r does not exist" % (event.room_id,) + "No create event in auth events", ) creating_domain = get_domain_from_id(event.room_id) @@ -155,10 +155,7 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True): if user_level < invite_level: raise AuthError( - 403, ( - "You cannot issue a third party invite for %s." % - (event.content.display_name,) - ) + 403, "You don't have permission to invite users", ) else: logger.debug("Allowing! %s", event) @@ -305,7 +302,7 @@ def _is_membership_change_allowed(event, auth_events): if user_level < invite_level: raise AuthError( - 403, "You cannot invite user %s." % target_user_id + 403, "You don't have permission to invite users", ) elif Membership.JOIN == membership: # Joins are valid iff caller == target and they were: @@ -693,7 +690,7 @@ def auth_types_for_event(event): auth_types = [] auth_types.append((EventTypes.PowerLevels, "", )) - auth_types.append((EventTypes.Member, event.user_id, )) + auth_types.append((EventTypes.Member, event.sender, )) auth_types.append((EventTypes.Create, "", )) if event.type == EventTypes.Member: diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index b782af6308..12f1eb0a3e 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -13,15 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +from distutils.util import strtobool + import six from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze # Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents -# bugs where we accidentally share e.g. signature dicts. However, converting -# a dict to frozen_dicts is expensive. -USE_FROZEN_DICTS = True +# bugs where we accidentally share e.g. signature dicts. However, converting a +# dict to frozen_dicts is expensive. +# +# NOTE: This is overridden by the configuration by the Synapse worker apps, but +# for the sake of tests, it is set here while it cannot be configured on the +# homeserver object itself. +USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0")) class _EventInternalMetadata(object): diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index fe67b2ff42..d05ed91d64 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -66,6 +66,14 @@ class FederationClient(FederationBase): self.state = hs.get_state_handler() self.transport_layer = hs.get_federation_transport_client() + self._get_pdu_cache = ExpiringCache( + cache_name="get_pdu_cache", + clock=self._clock, + max_len=1000, + expiry_ms=120 * 1000, + reset_expiry_on_get=False, + ) + def _clear_tried_cache(self): """Clear pdu_destination_tried cache""" now = self._clock.time_msec() @@ -82,17 +90,6 @@ class FederationClient(FederationBase): if destination_dict: self.pdu_destination_tried[event_id] = destination_dict - def start_get_pdu_cache(self): - self._get_pdu_cache = ExpiringCache( - cache_name="get_pdu_cache", - clock=self._clock, - max_len=1000, - expiry_ms=120 * 1000, - reset_expiry_on_get=False, - ) - - self._get_pdu_cache.start() - @log_function def make_query(self, destination, query_type, args, retry_on_dns_fail=False, ignore_backoff=False): @@ -212,8 +209,6 @@ class FederationClient(FederationBase): Will attempt to get the PDU from each destination in the list until one succeeds. - This will persist the PDU locally upon receipt. - Args: destinations (list): Which home servers to query event_id (str): event to fetch @@ -229,10 +224,9 @@ class FederationClient(FederationBase): # TODO: Rate limit the number of times we try and get the same event. - if self._get_pdu_cache: - ev = self._get_pdu_cache.get(event_id) - if ev: - defer.returnValue(ev) + ev = self._get_pdu_cache.get(event_id) + if ev: + defer.returnValue(ev) pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) @@ -285,7 +279,7 @@ class FederationClient(FederationBase): ) continue - if self._get_pdu_cache is not None and signed_pdu: + if signed_pdu: self._get_pdu_cache[event_id] = signed_pdu defer.returnValue(signed_pdu) @@ -293,8 +287,7 @@ class FederationClient(FederationBase): @defer.inlineCallbacks @log_function def get_state_for_room(self, destination, room_id, event_id): - """Requests all of the `current` state PDUs for a given room from - a remote home server. + """Requests all of the room state at a given event from a remote home server. Args: destination (str): The remote homeserver to query for the state. @@ -302,9 +295,10 @@ class FederationClient(FederationBase): event_id (str): The id of the event we want the state at. Returns: - Deferred: Results in a list of PDUs. + Deferred[Tuple[List[EventBase], List[EventBase]]]: + A list of events in the state, and a list of events in the auth chain + for the given event. """ - try: # First we try and ask for just the IDs, as thats far quicker if # we have most of the state and auth_chain already. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index dbee404ea7..fa2cc550e2 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import re import six from six import iteritems @@ -44,8 +43,10 @@ from synapse.replication.http.federation import ( ReplicationGetQueryRestServlet, ) from synapse.types import get_domain_from_id +from synapse.util import glob_to_regex from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache +from synapse.util.logcontext import nested_logging_context from synapse.util.logutils import log_function # when processing incoming transactions, we try to handle multiple rooms in @@ -187,21 +188,22 @@ class FederationServer(FederationBase): for pdu in pdus_by_room[room_id]: event_id = pdu.event_id - try: - yield self._handle_received_pdu( - origin, pdu - ) - pdu_results[event_id] = {} - except FederationError as e: - logger.warn("Error handling PDU %s: %s", event_id, e) - pdu_results[event_id] = {"error": str(e)} - except Exception as e: - f = failure.Failure() - pdu_results[event_id] = {"error": str(e)} - logger.error( - "Failed to handle PDU %s: %s", - event_id, f.getTraceback().rstrip(), - ) + with nested_logging_context(event_id): + try: + yield self._handle_received_pdu( + origin, pdu + ) + pdu_results[event_id] = {} + except FederationError as e: + logger.warn("Error handling PDU %s: %s", event_id, e) + pdu_results[event_id] = {"error": str(e)} + except Exception as e: + f = failure.Failure() + pdu_results[event_id] = {"error": str(e)} + logger.error( + "Failed to handle PDU %s: %s", + event_id, f.getTraceback().rstrip(), + ) yield concurrently_execute( process_pdus_for_room, pdus_by_room.keys(), @@ -322,11 +324,6 @@ class FederationServer(FederationBase): defer.returnValue((404, "")) @defer.inlineCallbacks - @log_function - def on_pull_request(self, origin, versions): - raise NotImplementedError("Pull transactions not implemented") - - @defer.inlineCallbacks def on_query_request(self, query_type, args): received_queries_counter.labels(query_type).inc() resp = yield self.registry.on_query(query_type, args) @@ -505,19 +502,19 @@ class FederationServer(FederationBase): @defer.inlineCallbacks @log_function def on_get_missing_events(self, origin, room_id, earliest_events, - latest_events, limit, min_depth): + latest_events, limit): with (yield self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) yield self.check_server_matches_acl(origin_host, room_id) logger.info( "on_get_missing_events: earliest_events: %r, latest_events: %r," - " limit: %d, min_depth: %d", - earliest_events, latest_events, limit, min_depth + " limit: %d", + earliest_events, latest_events, limit, ) missing_events = yield self.handler.on_get_missing_events( - origin, room_id, earliest_events, latest_events, limit, min_depth + origin, room_id, earliest_events, latest_events, limit, ) if len(missing_events) < 5: @@ -618,7 +615,7 @@ class FederationServer(FederationBase): ) yield self.handler.on_receive_pdu( - origin, pdu, get_missing=True, sent_to_us_directly=True, + origin, pdu, sent_to_us_directly=True, ) def __str__(self): @@ -727,22 +724,10 @@ def _acl_entry_matches(server_name, acl_entry): if not isinstance(acl_entry, six.string_types): logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)) return False - regex = _glob_to_regex(acl_entry) + regex = glob_to_regex(acl_entry) return regex.match(server_name) -def _glob_to_regex(glob): - res = '' - for c in glob: - if c == '*': - res = res + '.*' - elif c == '?': - res = res + '.' - else: - res = res + re.escape(c) - return re.compile(res + "\\Z", re.IGNORECASE) - - class FederationHandlerRegistry(object): """Allows classes to register themselves as handlers for a given EDU or query type for incoming federation traffic. @@ -798,7 +783,7 @@ class FederationHandlerRegistry(object): yield handler(origin, content) except SynapseError as e: logger.info("Failed to handle edu %r: %r", edu_type, e) - except Exception as e: + except Exception: logger.exception("Failed to handle edu %r", edu_type) def on_query(self, query_type, args): diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 8cbf8c4f7f..3fdd63be95 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -137,26 +137,6 @@ class TransactionQueue(object): self._processing_pending_presence = False - def can_send_to(self, destination): - """Can we send messages to the given server? - - We can't send messages to ourselves. If we are running on localhost - then we can only federation with other servers running on localhost. - Otherwise we only federate with servers on a public domain. - - Args: - destination(str): The server we are possibly trying to send to. - Returns: - bool: True if we can send to the server. - """ - - if destination == self.server_name: - return False - if self.server_name.startswith("localhost"): - return destination.startswith("localhost") - else: - return not destination.startswith("localhost") - def notify_new_events(self, current_id): """This gets called when we have some new events we might want to send out to other servers. @@ -279,10 +259,7 @@ class TransactionQueue(object): self._order += 1 destinations = set(destinations) - destinations = set( - dest for dest in destinations if self.can_send_to(dest) - ) - + destinations.discard(self.server_name) logger.debug("Sending to: %s", str(destinations)) if not destinations: @@ -358,7 +335,7 @@ class TransactionQueue(object): for destinations, states in hosts_and_states: for destination in destinations: - if not self.can_send_to(destination): + if destination == self.server_name: continue self.pending_presence_by_dest.setdefault( @@ -377,7 +354,8 @@ class TransactionQueue(object): content=content, ) - if not self.can_send_to(destination): + if destination == self.server_name: + logger.info("Not sending EDU to ourselves") return sent_edus_counter.inc() @@ -392,10 +370,8 @@ class TransactionQueue(object): self._attempt_new_transaction(destination) def send_device_messages(self, destination): - if destination == self.server_name or destination == "localhost": - return - - if not self.can_send_to(destination): + if destination == self.server_name: + logger.info("Not sending device update to ourselves") return self._attempt_new_transaction(destination) @@ -657,14 +633,6 @@ class TransactionQueue(object): transaction, json_data_cb ) code = 200 - - if response: - for e_id, r in response.get("pdus", {}).items(): - if "error" in r: - logger.warn( - "Transaction returned error for %s: %s", - e_id, r, - ) except HttpResponseException as e: code = e.code response = e.response @@ -681,19 +649,24 @@ class TransactionQueue(object): destination, txn_id, code ) - logger.debug("TX [%s] Sent transaction", destination) - logger.debug("TX [%s] Marking as delivered...", destination) - yield self.transaction_actions.delivered( transaction, code, response ) - logger.debug("TX [%s] Marked as delivered", destination) + logger.debug("TX [%s] {%s} Marked as delivered", destination, txn_id) - if code != 200: + if code == 200: + for e_id, r in response.get("pdus", {}).items(): + if "error" in r: + logger.warn( + "TX [%s] {%s} Remote returned error for %s: %s", + destination, txn_id, e_id, r, + ) + else: for p in pdus: - logger.info( - "Failed to send event %s to %s", p.event_id, destination + logger.warn( + "TX [%s] {%s} Failed to send event %s", + destination, txn_id, p.event_id, ) success = False diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 2ab973d6c8..edba5a9808 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -143,9 +143,17 @@ class TransportLayerClient(object): transaction (Transaction) Returns: - Deferred: Results of the deferred is a tuple in the form of - (response_code, response_body) where the response_body is a - python dict decoded from json + Deferred: Succeeds when we get a 2xx HTTP response. The result + will be the decoded JSON body. + + Fails with ``HTTPRequestException`` if we get an HTTP response + code >= 300. + + Fails with ``NotRetryingDestination`` if we are not yet ready + to retry this server. + + Fails with ``FederationDeniedError`` if this destination + is not on our federation whitelist """ logger.debug( "send_data dest=%s, txid=%s", @@ -170,11 +178,6 @@ class TransportLayerClient(object): backoff_on_404=True, # If we get a 404 the other side has gone ) - logger.debug( - "send_data dest=%s, txid=%s, got response: 200", - transaction.destination, transaction.transaction_id, - ) - defer.returnValue(response) @defer.inlineCallbacks diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 2f874b4838..3553f418f1 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -362,14 +362,6 @@ class FederationSendServlet(BaseFederationServlet): defer.returnValue((code, response)) -class FederationPullServlet(BaseFederationServlet): - PATH = "/pull/" - - # This is for when someone asks us for everything since version X - def on_GET(self, origin, content, query): - return self.handler.on_pull_request(query["origin"][0], query["v"]) - - class FederationEventServlet(BaseFederationServlet): PATH = "/event/(?P<event_id>[^/]*)/" @@ -560,7 +552,6 @@ class FederationGetMissingEventsServlet(BaseFederationServlet): @defer.inlineCallbacks def on_POST(self, origin, content, query, room_id): limit = int(content.get("limit", 10)) - min_depth = int(content.get("min_depth", 0)) earliest_events = content.get("earliest_events", []) latest_events = content.get("latest_events", []) @@ -569,7 +560,6 @@ class FederationGetMissingEventsServlet(BaseFederationServlet): room_id=room_id, earliest_events=earliest_events, latest_events=latest_events, - min_depth=min_depth, limit=limit, ) @@ -1263,7 +1253,6 @@ class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): FEDERATION_SERVLET_CLASSES = ( FederationSendServlet, - FederationPullServlet, FederationEventServlet, FederationStateServlet, FederationStateIdsServlet, diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index f0f89af7dc..17eedf4dbf 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -28,6 +28,7 @@ from synapse.metrics import ( event_processing_loop_room_count, ) from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util import log_failure from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.metrics import Measure @@ -36,17 +37,6 @@ logger = logging.getLogger(__name__) events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "") -def log_failure(failure): - logger.error( - "Application Services Failure", - exc_info=( - failure.type, - failure.value, - failure.getTracebackObject() - ) - ) - - class ApplicationServicesHandler(object): def __init__(self, hs): @@ -112,7 +102,10 @@ class ApplicationServicesHandler(object): if not self.started_scheduler: def start_scheduler(): - return self.scheduler.start().addErrback(log_failure) + return self.scheduler.start().addErrback( + log_failure, "Application Services Failure", + ) + run_as_background_process("as_scheduler", start_scheduler) self.started_scheduler = True diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 2a5eab124f..85fc1fc525 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -22,7 +22,7 @@ import bcrypt import pymacaroons from canonicaljson import json -from twisted.internet import defer, threads +from twisted.internet import defer from twisted.web.client import PartialDownloadError import synapse.util.stringutils as stringutils @@ -37,8 +37,8 @@ from synapse.api.errors import ( ) from synapse.module_api import ModuleApi from synapse.types import UserID +from synapse.util import logcontext from synapse.util.caches.expiringcache import ExpiringCache -from synapse.util.logcontext import make_deferred_yieldable from ._base import BaseHandler @@ -59,6 +59,7 @@ class AuthHandler(BaseHandler): LoginType.EMAIL_IDENTITY: self._check_email_identity, LoginType.MSISDN: self._check_msisdn, LoginType.DUMMY: self._check_dummy_auth, + LoginType.TERMS: self._check_terms_auth, } self.bcrypt_rounds = hs.config.bcrypt_rounds @@ -431,6 +432,9 @@ class AuthHandler(BaseHandler): def _check_dummy_auth(self, authdict, _): return defer.succeed(True) + def _check_terms_auth(self, authdict, _): + return defer.succeed(True) + @defer.inlineCallbacks def _check_threepid(self, medium, authdict): if 'threepid_creds' not in authdict: @@ -462,6 +466,22 @@ class AuthHandler(BaseHandler): def _get_params_recaptcha(self): return {"public_key": self.hs.config.recaptcha_public_key} + def _get_params_terms(self): + return { + "policies": { + "privacy_policy": { + "version": self.hs.config.user_consent_version, + "en": { + "name": "Privacy Policy", + "url": "%s/_matrix/consent?v=%s" % ( + self.hs.config.public_baseurl, + self.hs.config.user_consent_version, + ), + }, + }, + }, + } + def _auth_dict_for_flows(self, flows, session): public_flows = [] for f in flows: @@ -469,6 +489,7 @@ class AuthHandler(BaseHandler): get_params = { LoginType.RECAPTCHA: self._get_params_recaptcha, + LoginType.TERMS: self._get_params_terms, } params = {} @@ -884,11 +905,7 @@ class AuthHandler(BaseHandler): bcrypt.gensalt(self.bcrypt_rounds), ).decode('ascii') - return make_deferred_yieldable( - threads.deferToThreadPool( - self.hs.get_reactor(), self.hs.get_reactor().getThreadPool(), _do_hash - ), - ) + return logcontext.defer_to_thread(self.hs.get_reactor(), _do_hash) def validate_hash(self, password, stored_hash): """Validates that self.hash(password) == stored_hash. @@ -913,13 +930,7 @@ class AuthHandler(BaseHandler): if not isinstance(stored_hash, bytes): stored_hash = stored_hash.encode('ascii') - return make_deferred_yieldable( - threads.deferToThreadPool( - self.hs.get_reactor(), - self.hs.get_reactor().getThreadPool(), - _do_validate_hash, - ), - ) + return logcontext.defer_to_thread(self.hs.get_reactor(), _do_validate_hash) else: return defer.succeed(False) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index b078df4a76..75fe50c42c 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -17,8 +17,8 @@ import logging from twisted.internet import defer from synapse.api.errors import SynapseError +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import UserID, create_requester -from synapse.util.logcontext import run_in_background from ._base import BaseHandler @@ -121,7 +121,7 @@ class DeactivateAccountHandler(BaseHandler): None """ if not self._user_parter_running: - run_in_background(self._user_parter_loop) + run_as_background_process("user_parter_loop", self._user_parter_loop) @defer.inlineCallbacks def _user_parter_loop(self): diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index ef866da1b6..0699731c13 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -20,7 +20,14 @@ import string from twisted.internet import defer from synapse.api.constants import EventTypes -from synapse.api.errors import AuthError, CodeMessageException, Codes, SynapseError +from synapse.api.errors import ( + AuthError, + CodeMessageException, + Codes, + NotFoundError, + StoreError, + SynapseError, +) from synapse.types import RoomAlias, UserID, get_domain_from_id from ._base import BaseHandler @@ -36,6 +43,7 @@ class DirectoryHandler(BaseHandler): self.state = hs.get_state_handler() self.appservice_handler = hs.get_application_service_handler() self.event_creation_handler = hs.get_event_creation_handler() + self.config = hs.config self.federation = hs.get_federation_client() hs.get_federation_registry().register_query_handler( @@ -73,43 +81,96 @@ class DirectoryHandler(BaseHandler): ) @defer.inlineCallbacks - def create_association(self, user_id, room_alias, room_id, servers=None): - # association creation for human users - # TODO(erikj): Do user auth. + def create_association(self, requester, room_alias, room_id, servers=None, + send_event=True): + """Attempt to create a new alias - if not self.spam_checker.user_may_create_room_alias(user_id, room_alias): - raise SynapseError( - 403, "This user is not permitted to create this alias", - ) + Args: + requester (Requester) + room_alias (RoomAlias) + room_id (str) + servers (list[str]|None): List of servers that others servers + should try and join via + send_event (bool): Whether to send an updated m.room.aliases event - can_create = yield self.can_modify_alias( - room_alias, - user_id=user_id - ) - if not can_create: - raise SynapseError( - 400, "This alias is reserved by an application service.", - errcode=Codes.EXCLUSIVE + Returns: + Deferred + """ + + user_id = requester.user.to_string() + + service = requester.app_service + if service: + if not service.is_interested_in_alias(room_alias.to_string()): + raise SynapseError( + 400, "This application service has not reserved" + " this kind of alias.", errcode=Codes.EXCLUSIVE + ) + else: + if not self.spam_checker.user_may_create_room_alias(user_id, room_alias): + raise AuthError( + 403, "This user is not permitted to create this alias", + ) + + if not self.config.is_alias_creation_allowed(user_id, room_alias.to_string()): + # Lets just return a generic message, as there may be all sorts of + # reasons why we said no. TODO: Allow configurable error messages + # per alias creation rule? + raise SynapseError( + 403, "Not allowed to create alias", + ) + + can_create = yield self.can_modify_alias( + room_alias, + user_id=user_id ) + if not can_create: + raise AuthError( + 400, "This alias is reserved by an application service.", + errcode=Codes.EXCLUSIVE + ) + yield self._create_association(room_alias, room_id, servers, creator=user_id) + if send_event: + yield self.send_room_alias_update_event( + requester, + room_id + ) @defer.inlineCallbacks - def create_appservice_association(self, service, room_alias, room_id, - servers=None): - if not service.is_interested_in_alias(room_alias.to_string()): - raise SynapseError( - 400, "This application service has not reserved" - " this kind of alias.", errcode=Codes.EXCLUSIVE - ) + def delete_association(self, requester, room_alias, send_event=True): + """Remove an alias from the directory - # association creation for app services - yield self._create_association(room_alias, room_id, servers) + (this is only meant for human users; AS users should call + delete_appservice_association) - @defer.inlineCallbacks - def delete_association(self, requester, user_id, room_alias): - # association deletion for human users + Args: + requester (Requester): + room_alias (RoomAlias): + send_event (bool): Whether to send an updated m.room.aliases event. + Note that, if we delete the canonical alias, we will always attempt + to send an m.room.canonical_alias event + + Returns: + Deferred[unicode]: room id that the alias used to point to + + Raises: + NotFoundError: if the alias doesn't exist + + AuthError: if the user doesn't have perms to delete the alias (ie, the user + is neither the creator of the alias, nor a server admin. + + SynapseError: if the alias belongs to an AS + """ + user_id = requester.user.to_string() + + try: + can_delete = yield self._user_can_delete_alias(room_alias, user_id) + except StoreError as e: + if e.code == 404: + raise NotFoundError("Unknown room alias") + raise - can_delete = yield self._user_can_delete_alias(room_alias, user_id) if not can_delete: raise AuthError( 403, "You don't have permission to delete the alias.", @@ -128,11 +189,11 @@ class DirectoryHandler(BaseHandler): room_id = yield self._delete_association(room_alias) try: - yield self.send_room_alias_update_event( - requester, - requester.user.to_string(), - room_id - ) + if send_event: + yield self.send_room_alias_update_event( + requester, + room_id + ) yield self._update_canonical_alias( requester, @@ -248,7 +309,7 @@ class DirectoryHandler(BaseHandler): ) @defer.inlineCallbacks - def send_room_alias_update_event(self, requester, user_id, room_id): + def send_room_alias_update_event(self, requester, room_id): aliases = yield self.store.get_aliases_for_room(room_id) yield self.event_creation_handler.create_and_send_nonmember_event( @@ -257,7 +318,7 @@ class DirectoryHandler(BaseHandler): "type": EventTypes.Aliases, "state_key": self.hs.hostname, "room_id": room_id, - "sender": user_id, + "sender": requester.user.to_string(), "content": {"aliases": aliases}, }, ratelimit=False @@ -320,7 +381,7 @@ class DirectoryHandler(BaseHandler): def _user_can_delete_alias(self, alias, user_id): creator = yield self.store.get_room_alias_creator(alias.to_string()) - if creator and creator == user_id: + if creator is not None and creator == user_id: defer.returnValue(True) is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id)) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 578e9250fb..9dc46aa15f 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -341,7 +341,7 @@ class E2eKeysHandler(object): def _exception_to_failure(e): if isinstance(e, CodeMessageException): return { - "status": e.code, "message": e.message, + "status": e.code, "message": str(e), } if isinstance(e, NotRetryingDestination): diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py new file mode 100644 index 0000000000..5edb3cfe04 --- /dev/null +++ b/synapse/handlers/e2e_room_keys.py @@ -0,0 +1,289 @@ +# -*- coding: utf-8 -*- +# Copyright 2017, 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from six import iteritems + +from twisted.internet import defer + +from synapse.api.errors import RoomKeysVersionError, StoreError, SynapseError +from synapse.util.async_helpers import Linearizer + +logger = logging.getLogger(__name__) + + +class E2eRoomKeysHandler(object): + """ + Implements an optional realtime backup mechanism for encrypted E2E megolm room keys. + This gives a way for users to store and recover their megolm keys if they lose all + their clients. It should also extend easily to future room key mechanisms. + The actual payload of the encrypted keys is completely opaque to the handler. + """ + + def __init__(self, hs): + self.store = hs.get_datastore() + + # Used to lock whenever a client is uploading key data. This prevents collisions + # between clients trying to upload the details of a new session, given all + # clients belonging to a user will receive and try to upload a new session at + # roughly the same time. Also used to lock out uploads when the key is being + # changed. + self._upload_linearizer = Linearizer("upload_room_keys_lock") + + @defer.inlineCallbacks + def get_room_keys(self, user_id, version, room_id=None, session_id=None): + """Bulk get the E2E room keys for a given backup, optionally filtered to a given + room, or a given session. + See EndToEndRoomKeyStore.get_e2e_room_keys for full details. + + Args: + user_id(str): the user whose keys we're getting + version(str): the version ID of the backup we're getting keys from + room_id(string): room ID to get keys for, for None to get keys for all rooms + session_id(string): session ID to get keys for, for None to get keys for all + sessions + Returns: + A deferred list of dicts giving the session_data and message metadata for + these room keys. + """ + + # we deliberately take the lock to get keys so that changing the version + # works atomically + with (yield self._upload_linearizer.queue(user_id)): + results = yield self.store.get_e2e_room_keys( + user_id, version, room_id, session_id + ) + + if results['rooms'] == {}: + raise SynapseError(404, "No room_keys found") + + defer.returnValue(results) + + @defer.inlineCallbacks + def delete_room_keys(self, user_id, version, room_id=None, session_id=None): + """Bulk delete the E2E room keys for a given backup, optionally filtered to a given + room or a given session. + See EndToEndRoomKeyStore.delete_e2e_room_keys for full details. + + Args: + user_id(str): the user whose backup we're deleting + version(str): the version ID of the backup we're deleting + room_id(string): room ID to delete keys for, for None to delete keys for all + rooms + session_id(string): session ID to delete keys for, for None to delete keys + for all sessions + Returns: + A deferred of the deletion transaction + """ + + # lock for consistency with uploading + with (yield self._upload_linearizer.queue(user_id)): + yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id) + + @defer.inlineCallbacks + def upload_room_keys(self, user_id, version, room_keys): + """Bulk upload a list of room keys into a given backup version, asserting + that the given version is the current backup version. room_keys are merged + into the current backup as described in RoomKeysServlet.on_PUT(). + + Args: + user_id(str): the user whose backup we're setting + version(str): the version ID of the backup we're updating + room_keys(dict): a nested dict describing the room_keys we're setting: + + { + "rooms": { + "!abc:matrix.org": { + "sessions": { + "c0ff33": { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + } + } + } + } + + Raises: + SynapseError: with code 404 if there are no versions defined + RoomKeysVersionError: if the uploaded version is not the current version + """ + + # TODO: Validate the JSON to make sure it has the right keys. + + # XXX: perhaps we should use a finer grained lock here? + with (yield self._upload_linearizer.queue(user_id)): + + # Check that the version we're trying to upload is the current version + try: + version_info = yield self.store.get_e2e_room_keys_version_info(user_id) + except StoreError as e: + if e.code == 404: + raise SynapseError(404, "Version '%s' not found" % (version,)) + else: + raise + + if version_info['version'] != version: + # Check that the version we're trying to upload actually exists + try: + version_info = yield self.store.get_e2e_room_keys_version_info( + user_id, version, + ) + # if we get this far, the version must exist + raise RoomKeysVersionError(current_version=version_info['version']) + except StoreError as e: + if e.code == 404: + raise SynapseError(404, "Version '%s' not found" % (version,)) + else: + raise + + # go through the room_keys. + # XXX: this should/could be done concurrently, given we're in a lock. + for room_id, room in iteritems(room_keys['rooms']): + for session_id, session in iteritems(room['sessions']): + yield self._upload_room_key( + user_id, version, room_id, session_id, session + ) + + @defer.inlineCallbacks + def _upload_room_key(self, user_id, version, room_id, session_id, room_key): + """Upload a given room_key for a given room and session into a given + version of the backup. Merges the key with any which might already exist. + + Args: + user_id(str): the user whose backup we're setting + version(str): the version ID of the backup we're updating + room_id(str): the ID of the room whose keys we're setting + session_id(str): the session whose room_key we're setting + room_key(dict): the room_key being set + """ + + # get the room_key for this particular row + current_room_key = None + try: + current_room_key = yield self.store.get_e2e_room_key( + user_id, version, room_id, session_id + ) + except StoreError as e: + if e.code == 404: + pass + else: + raise + + if self._should_replace_room_key(current_room_key, room_key): + yield self.store.set_e2e_room_key( + user_id, version, room_id, session_id, room_key + ) + + @staticmethod + def _should_replace_room_key(current_room_key, room_key): + """ + Determine whether to replace a given current_room_key (if any) + with a newly uploaded room_key backup + + Args: + current_room_key (dict): Optional, the current room_key dict if any + room_key (dict): The new room_key dict which may or may not be fit to + replace the current_room_key + + Returns: + True if current_room_key should be replaced by room_key in the backup + """ + + if current_room_key: + # spelt out with if/elifs rather than nested boolean expressions + # purely for legibility. + + if room_key['is_verified'] and not current_room_key['is_verified']: + return True + elif ( + room_key['first_message_index'] < + current_room_key['first_message_index'] + ): + return True + elif room_key['forwarded_count'] < current_room_key['forwarded_count']: + return True + else: + return False + return True + + @defer.inlineCallbacks + def create_version(self, user_id, version_info): + """Create a new backup version. This automatically becomes the new + backup version for the user's keys; previous backups will no longer be + writeable to. + + Args: + user_id(str): the user whose backup version we're creating + version_info(dict): metadata about the new version being created + + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" + } + + Returns: + A deferred of a string that gives the new version number. + """ + + # TODO: Validate the JSON to make sure it has the right keys. + + # lock everyone out until we've switched version + with (yield self._upload_linearizer.queue(user_id)): + new_version = yield self.store.create_e2e_room_keys_version( + user_id, version_info + ) + defer.returnValue(new_version) + + @defer.inlineCallbacks + def get_version_info(self, user_id, version=None): + """Get the info about a given version of the user's backup + + Args: + user_id(str): the user whose current backup version we're querying + version(str): Optional; if None gives the most recent version + otherwise a historical one. + Raises: + StoreError: code 404 if the requested backup version doesn't exist + Returns: + A deferred of a info dict that gives the info about the new version. + + { + "version": "1234", + "algorithm": "m.megolm_backup.v1", + "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" + } + """ + + with (yield self._upload_linearizer.queue(user_id)): + res = yield self.store.get_e2e_room_keys_version_info(user_id, version) + defer.returnValue(res) + + @defer.inlineCallbacks + def delete_version(self, user_id, version=None): + """Deletes a given version of the user's e2e_room_keys backup + + Args: + user_id(str): the user whose current backup version we're deleting + version(str): the version id of the backup being deleted + Raises: + StoreError: code 404 if this backup version doesn't exist + """ + + with (yield self._upload_linearizer.queue(user_id)): + yield self.store.delete_e2e_room_keys_version(user_id, version) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0c68e8a472..cd5b9bbb19 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -18,7 +18,6 @@ import itertools import logging -import sys import six from six import iteritems, itervalues @@ -54,7 +53,7 @@ from synapse.replication.http.federation import ( ReplicationFederationSendEventsRestServlet, ) from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet -from synapse.state import resolve_events_with_factory +from synapse.state import StateResolutionStore, resolve_events_with_store from synapse.types import UserID, get_domain_from_id from synapse.util import logcontext, unwrapFirstError from synapse.util.async_helpers import Linearizer @@ -69,6 +68,27 @@ from ._base import BaseHandler logger = logging.getLogger(__name__) +def shortstr(iterable, maxitems=5): + """If iterable has maxitems or fewer, return the stringification of a list + containing those items. + + Otherwise, return the stringification of a a list with the first maxitems items, + followed by "...". + + Args: + iterable (Iterable): iterable to truncate + maxitems (int): number of items to return before truncating + + Returns: + unicode + """ + + items = list(itertools.islice(iterable, maxitems + 1)) + if len(items) <= maxitems: + return str(items) + return u"[" + u", ".join(repr(r) for r in items[:maxitems]) + u", ...]" + + class FederationHandler(BaseHandler): """Handles events that originated from federation. Responsible for: @@ -85,7 +105,7 @@ class FederationHandler(BaseHandler): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastore() # type: synapse.storage.DataStore self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() self.server_name = hs.hostname @@ -114,9 +134,8 @@ class FederationHandler(BaseHandler): self._room_pdu_linearizer = Linearizer("fed_room_pdu") @defer.inlineCallbacks - @log_function def on_receive_pdu( - self, origin, pdu, get_missing=True, sent_to_us_directly=False, + self, origin, pdu, sent_to_us_directly=False, ): """ Process a PDU received via a federation /send/ transaction, or via backfill of missing prev_events @@ -125,14 +144,23 @@ class FederationHandler(BaseHandler): origin (str): server which initiated the /send/ transaction. Will be used to fetch missing events or state. pdu (FrozenEvent): received PDU - get_missing (bool): True if we should fetch missing prev_events + sent_to_us_directly (bool): True if this event was pushed to us; False if + we pulled it as the result of a missing prev_event. Returns (Deferred): completes with None """ + room_id = pdu.room_id + event_id = pdu.event_id + + logger.info( + "[%s %s] handling received PDU: %s", + room_id, event_id, pdu, + ) + # We reprocess pdus when we have seen them only as outliers existing = yield self.store.get_event( - pdu.event_id, + event_id, allow_none=True, allow_rejected=True, ) @@ -147,7 +175,7 @@ class FederationHandler(BaseHandler): ) ) if already_seen: - logger.debug("Already seen pdu %s", pdu.event_id) + logger.debug("[%s %s]: Already seen pdu", room_id, event_id) return # do some initial sanity-checking of the event. In particular, make @@ -156,6 +184,7 @@ class FederationHandler(BaseHandler): try: self._sanity_check_event(pdu) except SynapseError as err: + logger.warn("[%s %s] Received event failed sanity checks", room_id, event_id) raise FederationError( "ERROR", err.code, @@ -165,10 +194,12 @@ class FederationHandler(BaseHandler): # If we are currently in the process of joining this room, then we # queue up events for later processing. - if pdu.room_id in self.room_queues: - logger.info("Ignoring PDU %s for room %s from %s for now; join " - "in progress", pdu.event_id, pdu.room_id, origin) - self.room_queues[pdu.room_id].append((pdu, origin)) + if room_id in self.room_queues: + logger.info( + "[%s %s] Queuing PDU from %s for now: join in progress", + room_id, event_id, origin, + ) + self.room_queues[room_id].append((pdu, origin)) return # If we're no longer in the room just ditch the event entirely. This @@ -179,7 +210,7 @@ class FederationHandler(BaseHandler): # we should check if we *are* in fact in the room. If we are then we # can magically rejoin the room. is_in_room = yield self.auth.check_host_in_room( - pdu.room_id, + room_id, self.server_name ) if not is_in_room: @@ -188,8 +219,8 @@ class FederationHandler(BaseHandler): ) if was_in_room: logger.info( - "Ignoring PDU %s for room %s from %s as we've left the room!", - pdu.event_id, pdu.room_id, origin, + "[%s %s] Ignoring PDU from %s as we've left the room", + room_id, event_id, origin, ) defer.returnValue(None) @@ -204,8 +235,8 @@ class FederationHandler(BaseHandler): ) logger.debug( - "_handle_new_pdu min_depth for %s: %d", - pdu.room_id, min_depth + "[%s %s] min_depth: %d", + room_id, event_id, min_depth, ) prevs = {e_id for e_id, _ in pdu.prev_events} @@ -218,17 +249,18 @@ class FederationHandler(BaseHandler): # send to the clients. pdu.internal_metadata.outlier = True elif min_depth and pdu.depth > min_depth: - if get_missing and prevs - seen: + missing_prevs = prevs - seen + if sent_to_us_directly and missing_prevs: # If we're missing stuff, ensure we only fetch stuff one # at a time. logger.info( - "Acquiring lock for room %r to fetch %d missing events: %r...", - pdu.room_id, len(prevs - seen), list(prevs - seen)[:5], + "[%s %s] Acquiring room lock to fetch %d missing prev_events: %s", + room_id, event_id, len(missing_prevs), shortstr(missing_prevs), ) with (yield self._room_pdu_linearizer.queue(pdu.room_id)): logger.info( - "Acquired lock for room %r to fetch %d missing events", - pdu.room_id, len(prevs - seen), + "[%s %s] Acquired room lock to fetch %d missing prev_events", + room_id, event_id, len(missing_prevs), ) yield self._get_missing_events_for_pdu( @@ -241,69 +273,150 @@ class FederationHandler(BaseHandler): if not prevs - seen: logger.info( - "Found all missing prev events for %s", pdu.event_id + "[%s %s] Found all missing prev_events", + room_id, event_id, ) - elif prevs - seen: + elif missing_prevs: logger.info( - "Not fetching %d missing events for room %r,event %s: %r...", - len(prevs - seen), pdu.room_id, pdu.event_id, - list(prevs - seen)[:5], + "[%s %s] Not recursively fetching %d missing prev_events: %s", + room_id, event_id, len(missing_prevs), shortstr(missing_prevs), ) - if sent_to_us_directly and prevs - seen: - # If they have sent it to us directly, and the server - # isn't telling us about the auth events that it's - # made a message referencing, we explode - raise FederationError( - "ERROR", - 403, - ( - "Your server isn't divulging details about prev_events " - "referenced in this event." - ), - affected=pdu.event_id, - ) - elif prevs - seen: - # Calculate the state of the previous events, and - # de-conflict them to find the current state. - state_groups = [] + if prevs - seen: + # We've still not been able to get all of the prev_events for this event. + # + # In this case, we need to fall back to asking another server in the + # federation for the state at this event. That's ok provided we then + # resolve the state against other bits of the DAG before using it (which + # will ensure that you can't just take over a room by sending an event, + # withholding its prev_events, and declaring yourself to be an admin in + # the subsequent state request). + # + # Now, if we're pulling this event as a missing prev_event, then clearly + # this event is not going to become the only forward-extremity and we are + # guaranteed to resolve its state against our existing forward + # extremities, so that should be fine. + # + # On the other hand, if this event was pushed to us, it is possible for + # it to become the only forward-extremity in the room, and we would then + # trust its state to be the state for the whole room. This is very bad. + # Further, if the event was pushed to us, there is no excuse for us not to + # have all the prev_events. We therefore reject any such events. + # + # XXX this really feels like it could/should be merged with the above, + # but there is an interaction with min_depth that I'm not really + # following. + + if sent_to_us_directly: + logger.warn( + "[%s %s] Rejecting: failed to fetch %d prev events: %s", + room_id, event_id, len(prevs - seen), shortstr(prevs - seen) + ) + raise FederationError( + "ERROR", + 403, + ( + "Your server isn't divulging details about prev_events " + "referenced in this event." + ), + affected=pdu.event_id, + ) + + # Calculate the state after each of the previous events, and + # resolve them to find the correct state at the current event. auth_chains = set() + event_map = { + event_id: pdu, + } try: # Get the state of the events we know about - ours = yield self.store.get_state_groups(pdu.room_id, list(seen)) - state_groups.append(ours) + ours = yield self.store.get_state_groups_ids(room_id, seen) + + # state_maps is a list of mappings from (type, state_key) to event_id + # type: list[dict[tuple[str, str], str]] + state_maps = list(ours.values()) + + # we don't need this any more, let's delete it. + del ours # Ask the remote server for the states we don't # know about for p in prevs - seen: - state, got_auth_chain = ( - yield self.federation_client.get_state_for_room( - origin, pdu.room_id, p - ) - ) - auth_chains.update(got_auth_chain) - state_group = {(x.type, x.state_key): x.event_id for x in state} - state_groups.append(state_group) - - # Resolve any conflicting state - def fetch(ev_ids): - return self.store.get_events( - ev_ids, get_prev_content=False, check_redacted=False + logger.info( + "[%s %s] Requesting state at missing prev_event %s", + room_id, event_id, p, ) - room_version = yield self.store.get_room_version(pdu.room_id) - state_map = yield resolve_events_with_factory( - room_version, state_groups, {pdu.event_id: pdu}, fetch + with logcontext.nested_logging_context(p): + # note that if any of the missing prevs share missing state or + # auth events, the requests to fetch those events are deduped + # by the get_pdu_cache in federation_client. + remote_state, got_auth_chain = ( + yield self.federation_client.get_state_for_room( + origin, room_id, p, + ) + ) + + # we want the state *after* p; get_state_for_room returns the + # state *before* p. + remote_event = yield self.federation_client.get_pdu( + [origin], p, outlier=True, + ) + + if remote_event is None: + raise Exception( + "Unable to get missing prev_event %s" % (p, ) + ) + + if remote_event.is_state(): + remote_state.append(remote_event) + + # XXX hrm I'm not convinced that duplicate events will compare + # for equality, so I'm not sure this does what the author + # hoped. + auth_chains.update(got_auth_chain) + + remote_state_map = { + (x.type, x.state_key): x.event_id for x in remote_state + } + state_maps.append(remote_state_map) + + for x in remote_state: + event_map[x.event_id] = x + + room_version = yield self.store.get_room_version(room_id) + state_map = yield resolve_events_with_store( + room_version, state_maps, event_map, + state_res_store=StateResolutionStore(self.store), + ) + + # We need to give _process_received_pdu the actual state events + # rather than event ids, so generate that now. + + # First though we need to fetch all the events that are in + # state_map, so we can build up the state below. + evs = yield self.store.get_events( + list(state_map.values()), + get_prev_content=False, + check_redacted=False, ) + event_map.update(evs) - state = (yield self.store.get_events(state_map.values())).values() + state = [ + event_map[e] for e in six.itervalues(state_map) + ] auth_chain = list(auth_chains) except Exception: + logger.warn( + "[%s %s] Error attempting to resolve state at missing " + "prev_events", + room_id, event_id, exc_info=True, + ) raise FederationError( "ERROR", 403, "We can't get valid state history.", - affected=pdu.event_id, + affected=event_id, ) yield self._process_received_pdu( @@ -322,15 +435,16 @@ class FederationHandler(BaseHandler): prevs (set(str)): List of event ids which we are missing min_depth (int): Minimum depth of events to return. """ - # We recalculate seen, since it may have changed. + + room_id = pdu.room_id + event_id = pdu.event_id + seen = yield self.store.have_seen_events(prevs) if not prevs - seen: return - latest = yield self.store.get_latest_event_ids_in_room( - pdu.room_id - ) + latest = yield self.store.get_latest_event_ids_in_room(room_id) # We add the prev events that we have seen to the latest # list to ensure the remote server doesn't give them to us @@ -338,8 +452,8 @@ class FederationHandler(BaseHandler): latest |= seen logger.info( - "Missing %d events for room %r pdu %s: %r...", - len(prevs - seen), pdu.room_id, pdu.event_id, list(prevs - seen)[:5] + "[%s %s]: Requesting missing events between %s and %s", + room_id, event_id, shortstr(latest), event_id, ) # XXX: we set timeout to 10s to help workaround @@ -360,49 +474,88 @@ class FederationHandler(BaseHandler): # apparently. # # see https://github.com/matrix-org/synapse/pull/1744 + # + # ---- + # + # Update richvdh 2018/09/18: There are a number of problems with timing this + # request out agressively on the client side: + # + # - it plays badly with the server-side rate-limiter, which starts tarpitting you + # if you send too many requests at once, so you end up with the server carefully + # working through the backlog of your requests, which you have already timed + # out. + # + # - for this request in particular, we now (as of + # https://github.com/matrix-org/synapse/pull/3456) reject any PDUs where the + # server can't produce a plausible-looking set of prev_events - so we becone + # much more likely to reject the event. + # + # - contrary to what it says above, we do *not* fall back to fetching fresh state + # for the room if get_missing_events times out. Rather, we give up processing + # the PDU whose prevs we are missing, which then makes it much more likely that + # we'll end up back here for the *next* PDU in the list, which exacerbates the + # problem. + # + # - the agressive 10s timeout was introduced to deal with incoming federation + # requests taking 8 hours to process. It's not entirely clear why that was going + # on; certainly there were other issues causing traffic storms which are now + # resolved, and I think in any case we may be more sensible about our locking + # now. We're *certainly* more sensible about our logging. + # + # All that said: Let's try increasing the timout to 60s and see what happens. missing_events = yield self.federation_client.get_missing_events( origin, - pdu.room_id, + room_id, earliest_events_ids=list(latest), latest_events=[pdu], limit=10, min_depth=min_depth, - timeout=10000, + timeout=60000, ) logger.info( - "Got %d events: %r...", - len(missing_events), [e.event_id for e in missing_events[:5]] + "[%s %s]: Got %d prev_events: %s", + room_id, event_id, len(missing_events), shortstr(missing_events), ) # We want to sort these by depth so we process them and # tell clients about them in order. missing_events.sort(key=lambda x: x.depth) - for e in missing_events: - logger.info("Handling found event %s", e.event_id) - try: - yield self.on_receive_pdu( - origin, - e, - get_missing=False - ) - except FederationError as e: - if e.code == 403: - logger.warn("Event %s failed history check.") - else: - raise + for ev in missing_events: + logger.info( + "[%s %s] Handling received prev_event %s", + room_id, event_id, ev.event_id, + ) + with logcontext.nested_logging_context(ev.event_id): + try: + yield self.on_receive_pdu( + origin, + ev, + sent_to_us_directly=False, + ) + except FederationError as e: + if e.code == 403: + logger.warn( + "[%s %s] Received prev_event %s failed history check.", + room_id, event_id, ev.event_id, + ) + else: + raise - @log_function @defer.inlineCallbacks - def _process_received_pdu(self, origin, pdu, state, auth_chain): + def _process_received_pdu(self, origin, event, state, auth_chain): """ Called when we have a new pdu. We need to do auth checks and put it through the StateHandler. """ - event = pdu + room_id = event.room_id + event_id = event.event_id - logger.debug("Processing event: %s", event) + logger.debug( + "[%s %s] Processing event: %s", + room_id, event_id, event, + ) # FIXME (erikj): Awful hack to make the case where we are not currently # in the room work @@ -411,15 +564,16 @@ class FederationHandler(BaseHandler): # event. if state and auth_chain and not event.internal_metadata.is_outlier(): is_in_room = yield self.auth.check_host_in_room( - event.room_id, + room_id, self.server_name ) else: is_in_room = True + if not is_in_room: logger.info( - "Got event for room we're not in: %r %r", - event.room_id, event.event_id + "[%s %s] Got event for room we're not in", + room_id, event_id, ) try: @@ -431,7 +585,7 @@ class FederationHandler(BaseHandler): "ERROR", e.code, e.msg, - affected=event.event_id, + affected=event_id, ) else: @@ -464,6 +618,10 @@ class FederationHandler(BaseHandler): }) seen_ids.add(e.event_id) + logger.info( + "[%s %s] persisting newly-received auth/state events %s", + room_id, event_id, [e["event"].event_id for e in event_infos] + ) yield self._handle_new_events(origin, event_infos) try: @@ -480,12 +638,12 @@ class FederationHandler(BaseHandler): affected=event.event_id, ) - room = yield self.store.get_room(event.room_id) + room = yield self.store.get_room(room_id) if not room: try: yield self.store.store_room( - room_id=event.room_id, + room_id=room_id, room_creator_user_id="", is_public=False, ) @@ -513,7 +671,7 @@ class FederationHandler(BaseHandler): if newly_joined: user = UserID.from_string(event.state_key) - yield self.user_joined_room(user, event.room_id) + yield self.user_joined_room(user, room_id) @log_function @defer.inlineCallbacks @@ -1027,7 +1185,8 @@ class FederationHandler(BaseHandler): try: logger.info("Processing queued PDU %s which was received " "while we were joining %s", p.event_id, p.room_id) - yield self.on_receive_pdu(origin, p) + with logcontext.nested_logging_context(p.event_id): + yield self.on_receive_pdu(origin, p, sent_to_us_directly=True) except Exception as e: logger.warn( "Error handling queued PDU %s from %s: %s", @@ -1430,12 +1589,10 @@ class FederationHandler(BaseHandler): else: defer.returnValue(None) - @log_function def get_min_depth_for_context(self, context): return self.store.get_min_depth(context) @defer.inlineCallbacks - @log_function def _handle_new_event(self, origin, event, state=None, auth_events=None, backfilled=False): context = yield self._prep_event( @@ -1444,6 +1601,9 @@ class FederationHandler(BaseHandler): auth_events=auth_events, ) + # reraise does not allow inlineCallbacks to preserve the stacktrace, so we + # hack around with a try/finally instead. + success = False try: if not event.internal_metadata.is_outlier() and not backfilled: yield self.action_generator.handle_push_actions_for_event( @@ -1454,15 +1614,13 @@ class FederationHandler(BaseHandler): [(event, context)], backfilled=backfilled, ) - except: # noqa: E722, as we reraise the exception this is fine. - tp, value, tb = sys.exc_info() - - logcontext.run_in_background( - self.store.remove_push_actions_from_staging, - event.event_id, - ) - - six.reraise(tp, value, tb) + success = True + finally: + if not success: + logcontext.run_in_background( + self.store.remove_push_actions_from_staging, + event.event_id, + ) defer.returnValue(context) @@ -1475,15 +1633,22 @@ class FederationHandler(BaseHandler): Notifies about the events where appropriate. """ - contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - logcontext.run_in_background( - self._prep_event, + + @defer.inlineCallbacks + def prep(ev_info): + event = ev_info["event"] + with logcontext.nested_logging_context(suffix=event.event_id): + res = yield self._prep_event( origin, - ev_info["event"], + event, state=ev_info.get("state"), auth_events=ev_info.get("auth_events"), ) + defer.returnValue(res) + + contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults( + [ + logcontext.run_in_background(prep, ev_info) for ev_info in event_infos ], consumeErrors=True, )) @@ -1635,8 +1800,8 @@ class FederationHandler(BaseHandler): ) except AuthError as e: logger.warn( - "Rejecting %s because %s", - event.event_id, e.msg + "[%s %s] Rejecting: %s", + event.room_id, event.event_id, e.msg ) context.rejected = RejectedReason.AUTH_ERROR @@ -1687,7 +1852,7 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def on_get_missing_events(self, origin, room_id, earliest_events, - latest_events, limit, min_depth): + latest_events, limit): in_room = yield self.auth.check_host_in_room( room_id, origin @@ -1696,14 +1861,12 @@ class FederationHandler(BaseHandler): raise AuthError(403, "Host not in room.") limit = min(limit, 20) - min_depth = max(min_depth, 0) missing_events = yield self.store.get_missing_events( room_id=room_id, earliest_events=earliest_events, latest_events=latest_events, limit=limit, - min_depth=min_depth, ) missing_events = yield filter_events_for_server( @@ -2357,7 +2520,7 @@ class FederationHandler(BaseHandler): if not backfilled: # Never notify for backfilled events for event, _ in event_and_contexts: - self._notify_persisted_event(event, max_stream_id) + yield self._notify_persisted_event(event, max_stream_id) def _notify_persisted_event(self, event, max_stream_id): """Checks to see if notifier/pushers should be notified about the @@ -2390,7 +2553,7 @@ class FederationHandler(BaseHandler): extra_users=extra_users ) - self.pusher_pool.on_new_notifications( + return self.pusher_pool.on_new_notifications( event_stream_id, max_stream_id, ) diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 53e5e2648b..173315af6c 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -20,7 +20,7 @@ from six import iteritems from twisted.internet import defer -from synapse.api.errors import SynapseError +from synapse.api.errors import HttpResponseException, SynapseError from synapse.types import get_domain_from_id logger = logging.getLogger(__name__) @@ -37,9 +37,23 @@ def _create_rerouter(func_name): ) else: destination = get_domain_from_id(group_id) - return getattr(self.transport_client, func_name)( + d = getattr(self.transport_client, func_name)( destination, group_id, *args, **kwargs ) + + # Capture errors returned by the remote homeserver and + # re-throw specific errors as SynapseErrors. This is so + # when the remote end responds with things like 403 Not + # In Group, we can communicate that to the client instead + # of a 500. + def h(failure): + failure.trap(HttpResponseException) + e = failure.value + if e.code == 403: + raise e.to_synapse_error() + return failure + d.addErrback(h) + return d return f diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index e009395207..563bb3cea3 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -156,7 +156,7 @@ class InitialSyncHandler(BaseHandler): room_end_token = "s%d" % (event.stream_ordering,) deferred_room_state = run_in_background( self.store.get_state_for_events, - [event.event_id], None, + [event.event_id], ) deferred_room_state.addCallback( lambda states: states[event.event_id] @@ -301,7 +301,7 @@ class InitialSyncHandler(BaseHandler): def _room_initial_sync_parted(self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking): room_state = yield self.store.get_state_for_events( - [member_event_id], None + [member_event_id], ) room_state = room_state[member_event_id] diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e484061cc0..a7cd779b02 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -14,9 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import sys -import six from six import iteritems, itervalues, string_types from canonicaljson import encode_canonical_json, json @@ -37,6 +35,7 @@ from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.replication.http.send_event import ReplicationSendEventRestServlet +from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder @@ -82,7 +81,7 @@ class MessageHandler(object): elif membership == Membership.LEAVE: key = (event_type, state_key) room_state = yield self.store.get_state_for_events( - [membership_event_id], [key] + [membership_event_id], StateFilter.from_types([key]) ) data = room_state[membership_event_id].get(key) @@ -90,7 +89,7 @@ class MessageHandler(object): @defer.inlineCallbacks def get_state_events( - self, user_id, room_id, types=None, filtered_types=None, + self, user_id, room_id, state_filter=StateFilter.all(), at_token=None, is_guest=False, ): """Retrieve all state events for a given room. If the user is @@ -102,13 +101,8 @@ class MessageHandler(object): Args: user_id(str): The user requesting state events. room_id(str): The room ID to get all state events from. - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. at_token(StreamToken|None): the stream token of the at which we are requesting the stats. If the user is not allowed to view the state as of that stream token, we raise a 403 SynapseError. If None, returns the current @@ -141,7 +135,7 @@ class MessageHandler(object): event = last_events[0] if visible_events: room_state = yield self.store.get_state_for_events( - [event.event_id], types, filtered_types=filtered_types, + [event.event_id], state_filter=state_filter, ) room_state = room_state[event.event_id] else: @@ -160,12 +154,12 @@ class MessageHandler(object): if membership == Membership.JOIN: state_ids = yield self.store.get_filtered_current_state_ids( - room_id, types, filtered_types=filtered_types, + room_id, state_filter=state_filter, ) room_state = yield self.store.get_events(state_ids.values()) elif membership == Membership.LEAVE: room_state = yield self.store.get_state_for_events( - [membership_event_id], types, filtered_types=filtered_types, + [membership_event_id], state_filter=state_filter, ) room_state = room_state[membership_event_id] @@ -433,6 +427,9 @@ class EventCreationHandler(object): if event.is_state(): prev_state = yield self.deduplicate_state_event(event, context) + logger.info( + "Not bothering to persist duplicate state event %s", event.event_id, + ) if prev_state is not None: defer.returnValue(prev_state) @@ -624,6 +621,9 @@ class EventCreationHandler(object): event, context ) + # reraise does not allow inlineCallbacks to preserve the stacktrace, so we + # hack around with a try/finally instead. + success = False try: # If we're a worker we need to hit out to the master. if self.config.worker_app: @@ -636,6 +636,7 @@ class EventCreationHandler(object): ratelimit=ratelimit, extra_users=extra_users, ) + success = True return yield self.persist_and_notify_client_event( @@ -645,17 +646,16 @@ class EventCreationHandler(object): ratelimit=ratelimit, extra_users=extra_users, ) - except: # noqa: E722, as we reraise the exception this is fine. - # Ensure that we actually remove the entries in the push actions - # staging area, if we calculated them. - tp, value, tb = sys.exc_info() - - run_in_background( - self.store.remove_push_actions_from_staging, - event.event_id, - ) - six.reraise(tp, value, tb) + success = True + finally: + if not success: + # Ensure that we actually remove the entries in the push actions + # staging area, if we calculated them. + run_in_background( + self.store.remove_push_actions_from_staging, + event.event_id, + ) @defer.inlineCallbacks def persist_and_notify_client_event( @@ -778,7 +778,7 @@ class EventCreationHandler(object): event, context=context ) - self.pusher_pool.on_new_notifications( + yield self.pusher_pool.on_new_notifications( event_stream_id, max_stream_id, ) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index a155b6e938..43f81bd607 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -21,6 +21,7 @@ from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError from synapse.events.utils import serialize_event +from synapse.storage.state import StateFilter from synapse.types import RoomStreamToken from synapse.util.async_helpers import ReadWriteLock from synapse.util.logcontext import run_in_background @@ -255,16 +256,14 @@ class PaginationHandler(object): if event_filter and event_filter.lazy_load_members(): # TODO: remove redundant members - types = [ - (EventTypes.Member, state_key) - for state_key in set( - event.sender # FIXME: we also care about invite targets etc. - for event in events - ) - ] + # FIXME: we also care about invite targets etc. + state_filter = StateFilter.from_types( + (EventTypes.Member, event.sender) + for event in events + ) state_ids = yield self.store.get_state_ids_for_event( - events[0].event_id, types=types, + events[0].event_id, state_filter=state_filter, ) if state_ids: diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 75b8b7ce6a..1dfbde84fd 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -142,10 +142,8 @@ class BaseProfileHandler(BaseHandler): if e.code != 404: logger.exception("Failed to get displayname") raise - except Exception: - logger.exception("Failed to get displayname") - else: - defer.returnValue(result["displayname"]) + + defer.returnValue(result["displayname"]) @defer.inlineCallbacks def set_displayname(self, target_user, requester, new_displayname, by_admin=False): @@ -199,8 +197,6 @@ class BaseProfileHandler(BaseHandler): if e.code != 404: logger.exception("Failed to get avatar_url") raise - except Exception: - logger.exception("Failed to get avatar_url") defer.returnValue(result["avatar_url"]) @@ -278,7 +274,7 @@ class BaseProfileHandler(BaseHandler): except Exception as e: logger.warn( "Failed to update join event for room %s - %s", - room_id, str(e.message) + room_id, str(e) ) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index a6f3181f09..4c2690ba26 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -119,7 +119,7 @@ class ReceiptsHandler(BaseHandler): "receipt_key", max_batch_id, rooms=affected_room_ids ) # Note that the min here shouldn't be relied upon to be accurate. - self.hs.get_pusherpool().on_new_receipts( + yield self.hs.get_pusherpool().on_new_receipts( min_batch_id, max_batch_id, affected_room_ids, ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 1e53f2c635..d2beb275cf 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -220,15 +220,42 @@ class RegistrationHandler(BaseHandler): # auto-join the user to any rooms we're supposed to dump them into fake_requester = create_requester(user_id) + + # try to create the room if we're the first user on the server + should_auto_create_rooms = False + if self.hs.config.autocreate_auto_join_rooms: + count = yield self.store.count_all_users() + should_auto_create_rooms = count == 1 + for r in self.hs.config.auto_join_rooms: try: - yield self._join_user_to_room(fake_requester, r) + if should_auto_create_rooms: + room_alias = RoomAlias.from_string(r) + if self.hs.hostname != room_alias.domain: + logger.warning( + 'Cannot create room alias %s, ' + 'it does not match server domain', + r, + ) + else: + # create room expects the localpart of the room alias + room_alias_localpart = room_alias.localpart + + # getting the RoomCreationHandler during init gives a dependency + # loop + yield self.hs.get_room_creation_handler().create_room( + fake_requester, + config={ + "preset": "public_chat", + "room_alias_name": room_alias_localpart + }, + ratelimit=False, + ) + else: + yield self._join_user_to_room(fake_requester, r) except Exception as e: logger.error("Failed to join new user to %r: %r", r, e) - # We used to generate default identicons here, but nowadays - # we want clients to generate their own as part of their branding - # rather than there being consistent matrix-wide ones, so we don't. defer.returnValue((user_id, token)) @defer.inlineCallbacks @@ -534,4 +561,5 @@ class RegistrationHandler(BaseHandler): room_id=room_id, remote_room_hosts=remote_room_hosts, action="join", + ratelimit=False, ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index c3f820b975..3928faa6e7 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -21,7 +21,7 @@ import math import string from collections import OrderedDict -from six import string_types +from six import iteritems, string_types from twisted.internet import defer @@ -32,9 +32,11 @@ from synapse.api.constants import ( JoinRules, RoomCreationPreset, ) -from synapse.api.errors import AuthError, Codes, StoreError, SynapseError +from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError +from synapse.storage.state import StateFilter from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID from synapse.util import stringutils +from synapse.util.async_helpers import Linearizer from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -72,6 +74,334 @@ class RoomCreationHandler(BaseHandler): self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() + self.room_member_handler = hs.get_room_member_handler() + + # linearizer to stop two upgrades happening at once + self._upgrade_linearizer = Linearizer("room_upgrade_linearizer") + + @defer.inlineCallbacks + def upgrade_room(self, requester, old_room_id, new_version): + """Replace a room with a new room with a different version + + Args: + requester (synapse.types.Requester): the user requesting the upgrade + old_room_id (unicode): the id of the room to be replaced + new_version (unicode): the new room version to use + + Returns: + Deferred[unicode]: the new room id + """ + yield self.ratelimit(requester) + + user_id = requester.user.to_string() + + with (yield self._upgrade_linearizer.queue(old_room_id)): + # start by allocating a new room id + r = yield self.store.get_room(old_room_id) + if r is None: + raise NotFoundError("Unknown room id %s" % (old_room_id,)) + new_room_id = yield self._generate_room_id( + creator_id=user_id, is_public=r["is_public"], + ) + + logger.info("Creating new room %s to replace %s", new_room_id, old_room_id) + + # we create and auth the tombstone event before properly creating the new + # room, to check our user has perms in the old room. + tombstone_event, tombstone_context = ( + yield self.event_creation_handler.create_event( + requester, { + "type": EventTypes.Tombstone, + "state_key": "", + "room_id": old_room_id, + "sender": user_id, + "content": { + "body": "This room has been replaced", + "replacement_room": new_room_id, + } + }, + token_id=requester.access_token_id, + ) + ) + yield self.auth.check_from_context(tombstone_event, tombstone_context) + + yield self.clone_exiting_room( + requester, + old_room_id=old_room_id, + new_room_id=new_room_id, + new_room_version=new_version, + tombstone_event_id=tombstone_event.event_id, + ) + + # now send the tombstone + yield self.event_creation_handler.send_nonmember_event( + requester, tombstone_event, tombstone_context, + ) + + old_room_state = yield tombstone_context.get_current_state_ids(self.store) + + # update any aliases + yield self._move_aliases_to_new_room( + requester, old_room_id, new_room_id, old_room_state, + ) + + # and finally, shut down the PLs in the old room, and update them in the new + # room. + yield self._update_upgraded_room_pls( + requester, old_room_id, new_room_id, old_room_state, + ) + + defer.returnValue(new_room_id) + + @defer.inlineCallbacks + def _update_upgraded_room_pls( + self, requester, old_room_id, new_room_id, old_room_state, + ): + """Send updated power levels in both rooms after an upgrade + + Args: + requester (synapse.types.Requester): the user requesting the upgrade + old_room_id (unicode): the id of the room to be replaced + new_room_id (unicode): the id of the replacement room + old_room_state (dict[tuple[str, str], str]): the state map for the old room + + Returns: + Deferred + """ + old_room_pl_event_id = old_room_state.get((EventTypes.PowerLevels, "")) + + if old_room_pl_event_id is None: + logger.warning( + "Not supported: upgrading a room with no PL event. Not setting PLs " + "in old room.", + ) + return + + old_room_pl_state = yield self.store.get_event(old_room_pl_event_id) + + # we try to stop regular users from speaking by setting the PL required + # to send regular events and invites to 'Moderator' level. That's normally + # 50, but if the default PL in a room is 50 or more, then we set the + # required PL above that. + + pl_content = dict(old_room_pl_state.content) + users_default = int(pl_content.get("users_default", 0)) + restricted_level = max(users_default + 1, 50) + + updated = False + for v in ("invite", "events_default"): + current = int(pl_content.get(v, 0)) + if current < restricted_level: + logger.info( + "Setting level for %s in %s to %i (was %i)", + v, old_room_id, restricted_level, current, + ) + pl_content[v] = restricted_level + updated = True + else: + logger.info( + "Not setting level for %s (already %i)", + v, current, + ) + + if updated: + try: + yield self.event_creation_handler.create_and_send_nonmember_event( + requester, { + "type": EventTypes.PowerLevels, + "state_key": '', + "room_id": old_room_id, + "sender": requester.user.to_string(), + "content": pl_content, + }, ratelimit=False, + ) + except AuthError as e: + logger.warning("Unable to update PLs in old room: %s", e) + + logger.info("Setting correct PLs in new room") + yield self.event_creation_handler.create_and_send_nonmember_event( + requester, { + "type": EventTypes.PowerLevels, + "state_key": '', + "room_id": new_room_id, + "sender": requester.user.to_string(), + "content": old_room_pl_state.content, + }, ratelimit=False, + ) + + @defer.inlineCallbacks + def clone_exiting_room( + self, requester, old_room_id, new_room_id, new_room_version, + tombstone_event_id, + ): + """Populate a new room based on an old room + + Args: + requester (synapse.types.Requester): the user requesting the upgrade + old_room_id (unicode): the id of the room to be replaced + new_room_id (unicode): the id to give the new room (should already have been + created with _gemerate_room_id()) + new_room_version (unicode): the new room version to use + tombstone_event_id (unicode|str): the ID of the tombstone event in the old + room. + Returns: + Deferred[None] + """ + user_id = requester.user.to_string() + + if not self.spam_checker.user_may_create_room(user_id): + raise SynapseError(403, "You are not permitted to create rooms") + + creation_content = { + "room_version": new_room_version, + "predecessor": { + "room_id": old_room_id, + "event_id": tombstone_event_id, + } + } + + initial_state = dict() + + types_to_copy = ( + (EventTypes.JoinRules, ""), + (EventTypes.Name, ""), + (EventTypes.Topic, ""), + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.GuestAccess, ""), + (EventTypes.RoomAvatar, ""), + ) + + old_room_state_ids = yield self.store.get_filtered_current_state_ids( + old_room_id, StateFilter.from_types(types_to_copy), + ) + # map from event_id to BaseEvent + old_room_state_events = yield self.store.get_events(old_room_state_ids.values()) + + for k, old_event_id in iteritems(old_room_state_ids): + old_event = old_room_state_events.get(old_event_id) + if old_event: + initial_state[k] = old_event.content + + yield self._send_events_for_new_room( + requester, + new_room_id, + + # we expect to override all the presets with initial_state, so this is + # somewhat arbitrary. + preset_config=RoomCreationPreset.PRIVATE_CHAT, + + invite_list=[], + initial_state=initial_state, + creation_content=creation_content, + ) + + # XXX invites/joins + # XXX 3pid invites + + @defer.inlineCallbacks + def _move_aliases_to_new_room( + self, requester, old_room_id, new_room_id, old_room_state, + ): + directory_handler = self.hs.get_handlers().directory_handler + + aliases = yield self.store.get_aliases_for_room(old_room_id) + + # check to see if we have a canonical alias. + canonical_alias = None + canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, "")) + if canonical_alias_event_id: + canonical_alias_event = yield self.store.get_event(canonical_alias_event_id) + if canonical_alias_event: + canonical_alias = canonical_alias_event.content.get("alias", "") + + # first we try to remove the aliases from the old room (we suppress sending + # the room_aliases event until the end). + # + # Note that we'll only be able to remove aliases that (a) aren't owned by an AS, + # and (b) unless the user is a server admin, which the user created. + # + # This is probably correct - given we don't allow such aliases to be deleted + # normally, it would be odd to allow it in the case of doing a room upgrade - + # but it makes the upgrade less effective, and you have to wonder why a room + # admin can't remove aliases that point to that room anyway. + # (cf https://github.com/matrix-org/synapse/issues/2360) + # + removed_aliases = [] + for alias_str in aliases: + alias = RoomAlias.from_string(alias_str) + try: + yield directory_handler.delete_association( + requester, alias, send_event=False, + ) + removed_aliases.append(alias_str) + except SynapseError as e: + logger.warning( + "Unable to remove alias %s from old room: %s", + alias, e, + ) + + # if we didn't find any aliases, or couldn't remove anyway, we can skip the rest + # of this. + if not removed_aliases: + return + + try: + # this can fail if, for some reason, our user doesn't have perms to send + # m.room.aliases events in the old room (note that we've already checked that + # they have perms to send a tombstone event, so that's not terribly likely). + # + # If that happens, it's regrettable, but we should carry on: it's the same + # as when you remove an alias from the directory normally - it just means that + # the aliases event gets out of sync with the directory + # (cf https://github.com/vector-im/riot-web/issues/2369) + yield directory_handler.send_room_alias_update_event( + requester, old_room_id, + ) + except AuthError as e: + logger.warning( + "Failed to send updated alias event on old room: %s", e, + ) + + # we can now add any aliases we successfully removed to the new room. + for alias in removed_aliases: + try: + yield directory_handler.create_association( + requester, RoomAlias.from_string(alias), + new_room_id, servers=(self.hs.hostname, ), + send_event=False, + ) + logger.info("Moved alias %s to new room", alias) + except SynapseError as e: + # I'm not really expecting this to happen, but it could if the spam + # checking module decides it shouldn't, or similar. + logger.error( + "Error adding alias %s to new room: %s", + alias, e, + ) + + try: + if canonical_alias and (canonical_alias in removed_aliases): + yield self.event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.CanonicalAlias, + "state_key": "", + "room_id": new_room_id, + "sender": requester.user.to_string(), + "content": {"alias": canonical_alias, }, + }, + ratelimit=False + ) + + yield directory_handler.send_room_alias_update_event( + requester, new_room_id, + ) + except SynapseError as e: + # again I'm not really expecting this to fail, but if it does, I'd rather + # we returned the new room to the client at this point. + logger.error( + "Unable to send updated alias events in new room: %s", e, + ) @defer.inlineCallbacks def create_room(self, requester, config, ratelimit=True, @@ -164,36 +494,16 @@ class RoomCreationHandler(BaseHandler): visibility = config.get("visibility", None) is_public = visibility == "public" - # autogen room IDs and try to create it. We may clash, so just - # try a few times till one goes through, giving up eventually. - attempts = 0 - room_id = None - while attempts < 5: - try: - random_string = stringutils.random_string(18) - gen_room_id = RoomID( - random_string, - self.hs.hostname, - ) - yield self.store.store_room( - room_id=gen_room_id.to_string(), - room_creator_user_id=user_id, - is_public=is_public - ) - room_id = gen_room_id.to_string() - break - except StoreError: - attempts += 1 - if not room_id: - raise StoreError(500, "Couldn't generate a room ID.") + room_id = yield self._generate_room_id(creator_id=user_id, is_public=is_public) if room_alias: directory_handler = self.hs.get_handlers().directory_handler yield directory_handler.create_association( - user_id=user_id, + requester=requester, room_id=room_id, room_alias=room_alias, servers=[self.hs.hostname], + send_event=False, ) preset_config = config.get( @@ -214,18 +524,15 @@ class RoomCreationHandler(BaseHandler): # override any attempt to set room versions via the creation_content creation_content["room_version"] = room_version - room_member_handler = self.hs.get_room_member_handler() - yield self._send_events_for_new_room( requester, room_id, - room_member_handler, preset_config=preset_config, invite_list=invite_list, initial_state=initial_state, creation_content=creation_content, room_alias=room_alias, - power_level_content_override=config.get("power_level_content_override", {}), + power_level_content_override=config.get("power_level_content_override"), creator_join_profile=creator_join_profile, ) @@ -261,7 +568,7 @@ class RoomCreationHandler(BaseHandler): if is_direct: content["is_direct"] = is_direct - yield room_member_handler.update_membership( + yield self.room_member_handler.update_membership( requester, UserID.from_string(invitee), room_id, @@ -289,7 +596,7 @@ class RoomCreationHandler(BaseHandler): if room_alias: result["room_alias"] = room_alias.to_string() yield directory_handler.send_room_alias_update_event( - requester, user_id, room_id + requester, room_id ) defer.returnValue(result) @@ -299,14 +606,13 @@ class RoomCreationHandler(BaseHandler): self, creator, # A Requester object. room_id, - room_member_handler, preset_config, invite_list, initial_state, creation_content, - room_alias, - power_level_content_override, - creator_join_profile, + room_alias=None, + power_level_content_override=None, + creator_join_profile=None, ): def create(etype, content, **kwargs): e = { @@ -322,6 +628,7 @@ class RoomCreationHandler(BaseHandler): @defer.inlineCallbacks def send(etype, content, **kwargs): event = create(etype, content, **kwargs) + logger.info("Sending %s in new room", etype) yield self.event_creation_handler.create_and_send_nonmember_event( creator, event, @@ -344,7 +651,8 @@ class RoomCreationHandler(BaseHandler): content=creation_content, ) - yield room_member_handler.update_membership( + logger.info("Sending %s in new room", EventTypes.Member) + yield self.room_member_handler.update_membership( creator, creator.user, room_id, @@ -386,7 +694,8 @@ class RoomCreationHandler(BaseHandler): for invitee in invite_list: power_level_content["users"][invitee] = 100 - power_level_content.update(power_level_content_override) + if power_level_content_override: + power_level_content.update(power_level_content_override) yield send( etype=EventTypes.PowerLevels, @@ -425,6 +734,30 @@ class RoomCreationHandler(BaseHandler): content=content, ) + @defer.inlineCallbacks + def _generate_room_id(self, creator_id, is_public): + # autogen room IDs and try to create it. We may clash, so just + # try a few times till one goes through, giving up eventually. + attempts = 0 + while attempts < 5: + try: + random_string = stringutils.random_string(18) + gen_room_id = RoomID( + random_string, + self.hs.hostname, + ).to_string() + if isinstance(gen_room_id, bytes): + gen_room_id = gen_room_id.decode('utf-8') + yield self.store.store_room( + room_id=gen_room_id, + room_creator_user_id=creator_id, + is_public=is_public, + ) + defer.returnValue(gen_room_id) + except StoreError: + attempts += 1 + raise StoreError(500, "Couldn't generate a room ID.") + class RoomContextHandler(object): def __init__(self, hs): @@ -488,23 +821,24 @@ class RoomContextHandler(object): else: last_event_id = event_id - types = None - filtered_types = None if event_filter and event_filter.lazy_load_members(): - members = set(ev.sender for ev in itertools.chain( - results["events_before"], - (results["event"],), - results["events_after"], - )) - filtered_types = [EventTypes.Member] - types = [(EventTypes.Member, member) for member in members] + state_filter = StateFilter.from_lazy_load_member_list( + ev.sender + for ev in itertools.chain( + results["events_before"], + (results["event"],), + results["events_after"], + ) + ) + else: + state_filter = StateFilter.all() # XXX: why do we return the state as of the last event rather than the # first? Shouldn't we be consistent with /sync? # https://github.com/matrix-org/matrix-doc/issues/687 state = yield self.store.get_state_for_events( - [last_event_id], types, filtered_types=filtered_types, + [last_event_id], state_filter=state_filter, ) results["state"] = list(state[last_event_id].values()) diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 38e1737ec9..dc88620885 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -16,7 +16,7 @@ import logging from collections import namedtuple -from six import iteritems +from six import PY3, iteritems from six.moves import range import msgpack @@ -444,9 +444,16 @@ class RoomListNextBatch(namedtuple("RoomListNextBatch", ( @classmethod def from_token(cls, token): + if PY3: + # The argument raw=False is only available on new versions of + # msgpack, and only really needed on Python 3. Gate it behind + # a PY3 check to avoid causing issues on Debian-packaged versions. + decoded = msgpack.loads(decode_base64(token), raw=False) + else: + decoded = msgpack.loads(decode_base64(token)) return RoomListNextBatch(**{ cls.REVERSE_KEY_DICT[key]: val - for key, val in msgpack.loads(decode_base64(token)).items() + for key, val in decoded.items() }) def to_token(self): diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index f643619047..07fd3e82fc 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -583,6 +583,11 @@ class RoomMemberHandler(object): room_id = mapping["room_id"] servers = mapping["servers"] + # put the server which owns the alias at the front of the server list. + if room_alias.domain in servers: + servers.remove(room_alias.domain) + servers.insert(0, room_alias.domain) + defer.returnValue((RoomID.from_string(room_id), servers)) @defer.inlineCallbacks diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 0c1d52fd11..80e7b15de8 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -24,6 +24,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError from synapse.api.filtering import Filter from synapse.events.utils import serialize_event +from synapse.storage.state import StateFilter from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -324,9 +325,12 @@ class SearchHandler(BaseHandler): else: last_event_id = event.event_id + state_filter = StateFilter.from_types( + [(EventTypes.Member, sender) for sender in senders] + ) + state = yield self.store.get_state_for_event( - last_event_id, - types=[(EventTypes.Member, sender) for sender in senders] + last_event_id, state_filter ) res["profile_info"] = { diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c1257521c1..48b7431433 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -20,11 +20,14 @@ import logging from six import iteritems, itervalues +from prometheus_client import Counter + from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.roommember import MemberSummary +from synapse.storage.state import StateFilter from synapse.types import RoomStreamToken from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache @@ -36,6 +39,19 @@ from synapse.visibility import filter_events_for_client logger = logging.getLogger(__name__) + +# Counts the number of times we returned a non-empty sync. `type` is one of +# "initial_sync", "full_state_sync" or "incremental_sync", `lazy_loaded` is +# "true" or "false" depending on if the request asked for lazy loaded members or +# not. +non_empty_sync_counter = Counter( + "synapse_handlers_sync_nonempty_total", + "Count of non empty sync responses. type is initial_sync/full_state_sync" + "/incremental_sync. lazy_loaded indicates if lazy loaded members were " + "enabled for that request.", + ["type", "lazy_loaded"], +) + # Store the cache that tracks which lazy-loaded members have been sent to a given # client for no more than 30 minutes. LAZY_LOADED_MEMBERS_CACHE_MAX_AGE = 30 * 60 * 1000 @@ -227,14 +243,16 @@ class SyncHandler(object): @defer.inlineCallbacks def _wait_for_sync_for_user(self, sync_config, since_token, timeout, full_state): + if since_token is None: + sync_type = "initial_sync" + elif full_state: + sync_type = "full_state_sync" + else: + sync_type = "incremental_sync" + context = LoggingContext.current_context() if context: - if since_token is None: - context.tag = "initial_sync" - elif full_state: - context.tag = "full_state_sync" - else: - context.tag = "incremental_sync" + context.tag = sync_type if timeout == 0 or since_token is None or full_state: # we are going to return immediately, so don't bother calling @@ -242,7 +260,6 @@ class SyncHandler(object): result = yield self.current_sync_for_user( sync_config, since_token, full_state=full_state, ) - defer.returnValue(result) else: def current_sync_callback(before_token, after_token): return self.current_sync_for_user(sync_config, since_token) @@ -251,7 +268,15 @@ class SyncHandler(object): sync_config.user.to_string(), timeout, current_sync_callback, from_token=since_token, ) - defer.returnValue(result) + + if result: + if sync_config.filter_collection.lazy_load_members(): + lazy_loaded = "true" + else: + lazy_loaded = "false" + non_empty_sync_counter.labels(sync_type, lazy_loaded).inc() + + defer.returnValue(result) def current_sync_for_user(self, sync_config, since_token=None, full_state=False): @@ -445,25 +470,20 @@ class SyncHandler(object): )) @defer.inlineCallbacks - def get_state_after_event(self, event, types=None, filtered_types=None): + def get_state_after_event(self, event, state_filter=StateFilter.all()): """ Get the room state after the given event Args: event(synapse.events.EventBase): event of interest - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: A Deferred map from ((type, state_key)->Event) """ state_ids = yield self.store.get_state_ids_for_event( - event.event_id, types, filtered_types=filtered_types, + event.event_id, state_filter=state_filter, ) if event.is_state(): state_ids = state_ids.copy() @@ -471,18 +491,14 @@ class SyncHandler(object): defer.returnValue(state_ids) @defer.inlineCallbacks - def get_state_at(self, room_id, stream_position, types=None, filtered_types=None): + def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()): """ Get the room state at a particular stream position Args: room_id(str): room for which to get state stream_position(StreamToken): point at which to get state - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: A Deferred map from ((type, state_key)->Event) @@ -498,7 +514,7 @@ class SyncHandler(object): if last_events: last_event = last_events[-1] state = yield self.get_state_after_event( - last_event, types, filtered_types=filtered_types, + last_event, state_filter=state_filter, ) else: @@ -539,11 +555,12 @@ class SyncHandler(object): last_event = last_events[-1] state_ids = yield self.store.get_state_ids_for_event( - last_event.event_id, [ + last_event.event_id, + state_filter=StateFilter.from_types([ (EventTypes.Name, ''), (EventTypes.CanonicalAlias, ''), (EventTypes.RoomAvatar, ''), - ] + ]), ) # this is heavily cached, thus: fast. @@ -568,20 +585,20 @@ class SyncHandler(object): # haven't been "deleted" by blatting {} over the top. room_avatar_id = state_ids.get((EventTypes.RoomAvatar, '')) if room_avatar_id: - room_avatar = yield self.store.get_event(room_avatar_id, allow_none=False) + room_avatar = yield self.store.get_event(room_avatar_id, allow_none=True) if room_avatar and room_avatar.content: # we have a room avatar; check to see if we can skip heroes # because the room has an explicit name or canonical alias: name_id = state_ids.get((EventTypes.Name, '')) if name_id: - name = yield self.store.get_event(name_id, allow_none=False) + name = yield self.store.get_event(name_id, allow_none=True) if name and name.content: defer.returnValue(summary) canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, '')) if canonical_alias_id: canonical_alias = yield self.store.get_event( - canonical_alias_id, allow_none=False, + canonical_alias_id, allow_none=Tue, ) if canonical_alias and canonical_alias.content: defer.returnValue(summary) @@ -701,8 +718,7 @@ class SyncHandler(object): with Measure(self.clock, "compute_state_delta"): - types = None - filtered_types = None + members_to_fetch = None lazy_load_members = sync_config.filter_collection.lazy_load_members() include_redundant_members = ( @@ -713,16 +729,21 @@ class SyncHandler(object): # We only request state for the members needed to display the # timeline: - types = [ - (EventTypes.Member, state_key) - for state_key in set( - event.sender # FIXME: we also care about invite targets etc. - for event in batch.events - ) - ] + members_to_fetch = set( + event.sender # FIXME: we also care about invite targets etc. + for event in batch.events + ) + + if full_state: + # always make sure we LL ourselves so we know we're in the room + # (if we are) to fix https://github.com/vector-im/riot-web/issues/7209 + # We only need apply this on full state syncs given we disabled + # LL for incr syncs in #3840. + members_to_fetch.add(sync_config.user.to_string()) - # only apply the filtering to room members - filtered_types = [EventTypes.Member] + state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch) + else: + state_filter = StateFilter.all() timeline_state = { (event.type, event.state_key): event.event_id @@ -732,19 +753,17 @@ class SyncHandler(object): if full_state: if batch: current_state_ids = yield self.store.get_state_ids_for_event( - batch.events[-1].event_id, types=types, - filtered_types=filtered_types, + batch.events[-1].event_id, state_filter=state_filter, ) state_ids = yield self.store.get_state_ids_for_event( - batch.events[0].event_id, types=types, - filtered_types=filtered_types, + batch.events[0].event_id, state_filter=state_filter, ) else: current_state_ids = yield self.get_state_at( - room_id, stream_position=now_token, types=types, - filtered_types=filtered_types, + room_id, stream_position=now_token, + state_filter=state_filter, ) state_ids = current_state_ids @@ -758,8 +777,7 @@ class SyncHandler(object): ) elif batch.limited: state_at_timeline_start = yield self.store.get_state_ids_for_event( - batch.events[0].event_id, types=types, - filtered_types=filtered_types, + batch.events[0].event_id, state_filter=state_filter, ) # for now, we disable LL for gappy syncs - see @@ -774,17 +792,15 @@ class SyncHandler(object): # members to just be ones which were timeline senders, which then ensures # all of the rest get included in the state block (if we need to know # about them). - types = None - filtered_types = None + state_filter = StateFilter.all() state_at_previous_sync = yield self.get_state_at( - room_id, stream_position=since_token, types=types, - filtered_types=filtered_types, + room_id, stream_position=since_token, + state_filter=state_filter, ) current_state_ids = yield self.store.get_state_ids_for_event( - batch.events[-1].event_id, types=types, - filtered_types=filtered_types, + batch.events[-1].event_id, state_filter=state_filter, ) state_ids = _calculate_state( @@ -798,7 +814,7 @@ class SyncHandler(object): else: state_ids = {} if lazy_load_members: - if types: + if members_to_fetch and batch.events: # We're returning an incremental sync, with no # "gap" since the previous sync, so normally there would be # no state to return. @@ -808,8 +824,12 @@ class SyncHandler(object): # timeline here, and then dedupe any redundant ones below. state_ids = yield self.store.get_state_ids_for_event( - batch.events[0].event_id, types=types, - filtered_types=None, # we only want members! + batch.events[0].event_id, + # we only want members! + state_filter=StateFilter.from_types( + (EventTypes.Member, member) + for member in members_to_fetch + ), ) if lazy_load_members and not include_redundant_members: diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 2d2d3d5a0d..a61bbf9392 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -20,6 +20,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError, SynapseError from synapse.types import UserID, get_domain_from_id +from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.logcontext import run_in_background from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -62,17 +63,28 @@ class TypingHandler(object): self._member_typing_until = {} # clock time we expect to stop self._member_last_federation_poke = {} - # map room IDs to serial numbers - self._room_serials = {} self._latest_room_serial = 0 - # map room IDs to sets of users currently typing - self._room_typing = {} + self._reset() + + # caches which room_ids changed at which serials + self._typing_stream_change_cache = StreamChangeCache( + "TypingStreamChangeCache", self._latest_room_serial, + ) self.clock.looping_call( self._handle_timeouts, 5000, ) + def _reset(self): + """ + Reset the typing handler's data caches. + """ + # map room IDs to serial numbers + self._room_serials = {} + # map room IDs to sets of users currently typing + self._room_typing = {} + def _handle_timeouts(self): logger.info("Checking for typing timeouts") @@ -218,6 +230,7 @@ class TypingHandler(object): for domain in set(get_domain_from_id(u) for u in users): if domain != self.server_name: + logger.debug("sending typing update to %s", domain) self.federation.send_edu( destination=domain, edu_type="m.typing", @@ -274,19 +287,29 @@ class TypingHandler(object): self._latest_room_serial += 1 self._room_serials[member.room_id] = self._latest_room_serial + self._typing_stream_change_cache.entity_has_changed( + member.room_id, self._latest_room_serial, + ) self.notifier.on_new_event( "typing_key", self._latest_room_serial, rooms=[member.room_id] ) def get_all_typing_updates(self, last_id, current_id): - # TODO: Work out a way to do this without scanning the entire state. if last_id == current_id: return [] + changed_rooms = self._typing_stream_change_cache.get_all_entities_changed( + last_id, + ) + + if changed_rooms is None: + changed_rooms = self._room_serials + rows = [] - for room_id, serial in self._room_serials.items(): - if last_id < serial and serial <= current_id: + for room_id in changed_rooms: + serial = self._room_serials[room_id] + if last_id < serial <= current_id: typing = self._room_typing[room_id] rows.append((serial, room_id, list(typing))) rows.sort() diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index d8413d6aa7..f11b430126 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -20,6 +20,7 @@ from six import iteritems from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.roommember import ProfileInfo from synapse.types import get_localpart_from_id from synapse.util.metrics import Measure @@ -98,7 +99,6 @@ class UserDirectoryHandler(object): """ return self.store.search_user_dir(user_id, search_term, limit) - @defer.inlineCallbacks def notify_new_event(self): """Called when there may be more deltas to process """ @@ -108,11 +108,15 @@ class UserDirectoryHandler(object): if self._is_processing: return + @defer.inlineCallbacks + def process(): + try: + yield self._unsafe_process() + finally: + self._is_processing = False + self._is_processing = True - try: - yield self._unsafe_process() - finally: - self._is_processing = False + run_as_background_process("user_directory.notify_new_event", process) @defer.inlineCallbacks def handle_local_profile_change(self, user_id, profile): diff --git a/synapse/http/client.py b/synapse/http/client.py index ec339a92ad..3d05f83b8c 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -43,7 +43,7 @@ from twisted.web.http_headers import Headers from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.http import cancelled_to_request_timed_out_error, redact_uri from synapse.http.endpoint import SpiderEndpoint -from synapse.util.async_helpers import add_timeout_to_deferred +from synapse.util.async_helpers import timeout_deferred from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.logcontext import make_deferred_yieldable @@ -99,7 +99,7 @@ class SimpleHttpClient(object): request_deferred = treq.request( method, uri, agent=self.agent, data=data, headers=headers ) - add_timeout_to_deferred( + request_deferred = timeout_deferred( request_deferred, 60, self.hs.get_reactor(), cancelled_to_request_timed_out_error, ) diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index b0c9369519..91025037a3 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -108,7 +108,7 @@ def matrix_federation_endpoint(reactor, destination, tls_client_options_factory= Args: reactor: Twisted reactor. - destination (bytes): The name of the server to connect to. + destination (unicode): The name of the server to connect to. tls_client_options_factory (synapse.crypto.context_factory.ClientTLSOptionsFactory): Factory which generates TLS options for client connections. @@ -126,10 +126,17 @@ def matrix_federation_endpoint(reactor, destination, tls_client_options_factory= transport_endpoint = HostnameEndpoint default_port = 8008 else: + # the SNI string should be the same as the Host header, minus the port. + # as per https://github.com/matrix-org/synapse/issues/2525#issuecomment-336896777, + # the Host header and SNI should therefore be the server_name of the remote + # server. + tls_options = tls_client_options_factory.get_options(domain) + def transport_endpoint(reactor, host, port, timeout): return wrapClientTLS( - tls_client_options_factory.get_options(host), - HostnameEndpoint(reactor, host, port, timeout=timeout)) + tls_options, + HostnameEndpoint(reactor, host, port, timeout=timeout), + ) default_port = 8448 if port is None: diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 083484a687..24b6110c20 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -17,10 +17,12 @@ import cgi import logging import random import sys +from io import BytesIO from six import PY3, string_types from six.moves import urllib +import attr import treq from canonicaljson import encode_canonical_json from prometheus_client import Counter @@ -28,8 +30,9 @@ from signedjson.sign import sign_json from twisted.internet import defer, protocol from twisted.internet.error import DNSLookupError +from twisted.internet.task import _EPSILON, Cooperator from twisted.web._newclient import ResponseDone -from twisted.web.client import Agent, HTTPConnectionPool +from twisted.web.client import Agent, FileBodyProducer, HTTPConnectionPool from twisted.web.http_headers import Headers import synapse.metrics @@ -41,13 +44,11 @@ from synapse.api.errors import ( SynapseError, ) from synapse.http.endpoint import matrix_federation_endpoint -from synapse.util import logcontext -from synapse.util.async_helpers import timeout_no_seriously +from synapse.util.async_helpers import timeout_deferred from synapse.util.logcontext import make_deferred_yieldable from synapse.util.metrics import Measure logger = logging.getLogger(__name__) -outbound_logger = logging.getLogger("synapse.http.outbound") outgoing_requests_counter = Counter("synapse_http_matrixfederationclient_requests", "", ["method"]) @@ -78,6 +79,99 @@ class MatrixFederationEndpointFactory(object): ) +_next_id = 1 + + +@attr.s +class MatrixFederationRequest(object): + method = attr.ib() + """HTTP method + :type: str + """ + + path = attr.ib() + """HTTP path + :type: str + """ + + destination = attr.ib() + """The remote server to send the HTTP request to. + :type: str""" + + json = attr.ib(default=None) + """JSON to send in the body. + :type: dict|None + """ + + json_callback = attr.ib(default=None) + """A callback to generate the JSON. + :type: func|None + """ + + query = attr.ib(default=None) + """Query arguments. + :type: dict|None + """ + + txn_id = attr.ib(default=None) + """Unique ID for this request (for logging) + :type: str|None + """ + + def __attrs_post_init__(self): + global _next_id + self.txn_id = "%s-O-%s" % (self.method, _next_id) + _next_id = (_next_id + 1) % (MAXINT - 1) + + def get_json(self): + if self.json_callback: + return self.json_callback() + return self.json + + +@defer.inlineCallbacks +def _handle_json_response(reactor, timeout_sec, request, response): + """ + Reads the JSON body of a response, with a timeout + + Args: + reactor (IReactor): twisted reactor, for the timeout + timeout_sec (float): number of seconds to wait for response to complete + request (MatrixFederationRequest): the request that triggered the response + response (IResponse): response to the request + + Returns: + dict: parsed JSON response + """ + try: + check_content_type_is_json(response.headers) + + d = treq.json_content(response) + d = timeout_deferred( + d, + timeout=timeout_sec, + reactor=reactor, + ) + + body = yield make_deferred_yieldable(d) + except Exception as e: + logger.warn( + "{%s} [%s] Error reading response: %s", + request.txn_id, + request.destination, + e, + ) + raise + logger.info( + "{%s} [%s] Completed: %d %s", + request.txn_id, + request.destination, + response.code, + response.phrase.decode('ascii', errors='replace'), + ) + defer.returnValue(body) + + class MatrixFederationHttpClient(object): """HTTP client used to talk to other homeservers over the federation protocol. Send client certificates and signs requests. @@ -101,41 +195,42 @@ class MatrixFederationHttpClient(object): ) self.clock = hs.get_clock() self._store = hs.get_datastore() - self.version_string = hs.version_string.encode('ascii') - self._next_id = 1 + self.version_string_bytes = hs.version_string.encode('ascii') self.default_timeout = 60 - def _create_url(self, destination, path_bytes, param_bytes, query_bytes): - return urllib.parse.urlunparse( - (b"matrix", destination, path_bytes, param_bytes, query_bytes, b"") - ) + def schedule(x): + reactor.callLater(_EPSILON, x) + + self._cooperator = Cooperator(scheduler=schedule) @defer.inlineCallbacks - def _request(self, destination, method, path, - json=None, json_callback=None, - param_bytes=b"", - query=None, retry_on_dns_fail=True, - timeout=None, long_retries=False, - ignore_backoff=False, - backoff_on_404=False): + def _send_request( + self, + request, + retry_on_dns_fail=True, + timeout=None, + long_retries=False, + ignore_backoff=False, + backoff_on_404=False + ): """ - Creates and sends a request to the given server. + Sends a request to the given server. Args: - destination (str): The remote server to send the HTTP request to. - method (str): HTTP method - path (str): The HTTP path - json (dict or None): JSON to send in the body. - json_callback (func or None): A callback to generate the JSON. - query (dict or None): Query arguments. + request (MatrixFederationRequest): details of request to be sent + + timeout (int|None): number of milliseconds to wait for the response headers + (including connecting to the server). 60s by default. + ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. + backoff_on_404 (bool): Back off if we get a 404 Returns: Deferred: resolves with the http response object on success. - Fails with ``HTTPRequestException``: if we get an HTTP response + Fails with ``HttpResponseException``: if we get an HTTP response code >= 300. Fails with ``NotRetryingDestination`` if we are not yet ready @@ -154,38 +249,32 @@ class MatrixFederationHttpClient(object): if ( self.hs.config.federation_domain_whitelist is not None and - destination not in self.hs.config.federation_domain_whitelist + request.destination not in self.hs.config.federation_domain_whitelist ): - raise FederationDeniedError(destination) + raise FederationDeniedError(request.destination) limiter = yield synapse.util.retryutils.get_retry_limiter( - destination, + request.destination, self.clock, self._store, backoff_on_404=backoff_on_404, ignore_backoff=ignore_backoff, ) - headers_dict = {} - path_bytes = path.encode("ascii") - if query: - query_bytes = encode_query_args(query) + method_bytes = request.method.encode("ascii") + destination_bytes = request.destination.encode("ascii") + path_bytes = request.path.encode("ascii") + if request.query: + query_bytes = encode_query_args(request.query) else: query_bytes = b"" headers_dict = { - "User-Agent": [self.version_string], - "Host": [destination], + b"User-Agent": [self.version_string_bytes], + b"Host": [destination_bytes], } with limiter: - url = self._create_url( - destination.encode("ascii"), path_bytes, param_bytes, query_bytes - ).decode('ascii') - - txn_id = "%s-O-%s" % (method, self._next_id) - self._next_id = (self._next_id + 1) % (MAXINT - 1) - # XXX: Would be much nicer to retry only at the transaction-layer # (once we have reliable transactions in place) if long_retries: @@ -193,49 +282,56 @@ class MatrixFederationHttpClient(object): else: retries_left = MAX_SHORT_RETRIES - http_url = urllib.parse.urlunparse( - (b"", b"", path_bytes, param_bytes, query_bytes, b"") - ).decode('ascii') + url_bytes = urllib.parse.urlunparse(( + b"matrix", destination_bytes, + path_bytes, None, query_bytes, b"", + )) + url_str = url_bytes.decode('ascii') + + url_to_sign_bytes = urllib.parse.urlunparse(( + b"", b"", + path_bytes, None, query_bytes, b"", + )) - log_result = None while True: try: - if json_callback: - json = json_callback() - + json = request.get_json() if json: - data = encode_canonical_json(json) - headers_dict["Content-Type"] = ["application/json"] + headers_dict[b"Content-Type"] = [b"application/json"] self.sign_request( - destination, method, http_url, headers_dict, json + destination_bytes, method_bytes, url_to_sign_bytes, + headers_dict, json, + ) + data = encode_canonical_json(json) + producer = FileBodyProducer( + BytesIO(data), + cooperator=self._cooperator, ) else: - data = None - self.sign_request(destination, method, http_url, headers_dict) + producer = None + self.sign_request( + destination_bytes, method_bytes, url_to_sign_bytes, + headers_dict, + ) - outbound_logger.info( + logger.info( "{%s} [%s] Sending request: %s %s", - txn_id, destination, method, url + request.txn_id, request.destination, request.method, + url_str, ) - request_deferred = treq.request( - method, - url, + # we don't want all the fancy cookie and redirect handling that + # treq.request gives: just use the raw Agent. + request_deferred = self.agent.request( + method_bytes, + url_bytes, headers=Headers(headers_dict), - data=data, - agent=self.agent, - reactor=self.hs.get_reactor(), - unbuffered=True + bodyProducer=producer, ) - request_deferred.addTimeout(_sec_timeout, self.hs.get_reactor()) - # Sometimes the timeout above doesn't work, so lets hack yet - # another layer of timeouts in in the vain hope that at some - # point the world made sense and this really really really - # should work. - request_deferred = timeout_no_seriously( + request_deferred = timeout_deferred( request_deferred, - timeout=_sec_timeout * 2, + timeout=_sec_timeout, reactor=self.hs.get_reactor(), ) @@ -244,33 +340,19 @@ class MatrixFederationHttpClient(object): request_deferred, ) - log_result = "%d %s" % ( - response.code, - response.phrase.decode('ascii', errors='replace'), - ) break except Exception as e: - if not retry_on_dns_fail and isinstance(e, DNSLookupError): - logger.warn( - "DNS Lookup failed to %s with %s", - destination, - e - ) - log_result = "DNS Lookup failed to %s with %s" % ( - destination, e - ) - raise - logger.warn( - "{%s} Sending request failed to %s: %s %s: %s", - txn_id, - destination, - method, - url, + "{%s} [%s] Request failed: %s %s: %s", + request.txn_id, + request.destination, + request.method, + url_str, _flatten_response_never_received(e), ) - log_result = _flatten_response_never_received(e) + if not retry_on_dns_fail and isinstance(e, DNSLookupError): + raise if retries_left and not timeout: if long_retries: @@ -283,33 +365,37 @@ class MatrixFederationHttpClient(object): delay *= random.uniform(0.8, 1.4) logger.debug( - "{%s} Waiting %s before sending to %s...", - txn_id, + "{%s} [%s] Waiting %ss before re-sending...", + request.txn_id, + request.destination, delay, - destination ) yield self.clock.sleep(delay) retries_left -= 1 else: raise - finally: - outbound_logger.info( - "{%s} [%s] Result: %s", - txn_id, - destination, - log_result, - ) + + logger.info( + "{%s} [%s] Got response headers: %d %s", + request.txn_id, + request.destination, + response.code, + response.phrase.decode('ascii', errors='replace'), + ) if 200 <= response.code < 300: pass else: # :'( # Update transactions table? - with logcontext.PreserveLoggingContext(): - d = treq.content(response) - d.addTimeout(_sec_timeout, self.hs.get_reactor()) - body = yield make_deferred_yieldable(d) + d = treq.content(response) + d = timeout_deferred( + d, + timeout=_sec_timeout, + reactor=self.hs.get_reactor(), + ) + body = yield make_deferred_yieldable(d) raise HttpResponseException( response.code, response.phrase, body ) @@ -326,8 +412,9 @@ class MatrixFederationHttpClient(object): destination_is must be non-None. method (bytes): The HTTP method of the request url_bytes (bytes): The URI path of the request - headers_dict (dict): Dictionary of request headers to append to - content (bytes): The body of the request + headers_dict (dict[bytes, list[bytes]]): Dictionary of request headers to + append to + content (object): The body of the request destination_is (bytes): As 'destination', but if the destination is an identity server @@ -393,7 +480,7 @@ class MatrixFederationHttpClient(object): Deferred: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. - Fails with ``HTTPRequestException`` if we get an HTTP response + Fails with ``HttpResponseException`` if we get an HTTP response code >= 300. Fails with ``NotRetryingDestination`` if we are not yet ready @@ -403,29 +490,26 @@ class MatrixFederationHttpClient(object): is not on our federation whitelist """ - if not json_data_callback: - json_data_callback = lambda: data - - response = yield self._request( - destination, - "PUT", - path, - json_callback=json_data_callback, + request = MatrixFederationRequest( + method="PUT", + destination=destination, + path=path, query=args, + json_callback=json_data_callback, + json=data, + ) + + response = yield self._send_request( + request, long_retries=long_retries, timeout=timeout, ignore_backoff=ignore_backoff, backoff_on_404=backoff_on_404, ) - if 200 <= response.code < 300: - # We need to update the transactions table to say it was sent? - check_content_type_is_json(response.headers) - - with logcontext.PreserveLoggingContext(): - d = treq.json_content(response) - d.addTimeout(self.default_timeout, self.hs.get_reactor()) - body = yield make_deferred_yieldable(d) + body = yield _handle_json_response( + self.hs.get_reactor(), self.default_timeout, request, response, + ) defer.returnValue(body) @defer.inlineCallbacks @@ -450,7 +534,7 @@ class MatrixFederationHttpClient(object): Deferred: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. - Fails with ``HTTPRequestException`` if we get an HTTP response + Fails with ``HttpResponseException`` if we get an HTTP response code >= 300. Fails with ``NotRetryingDestination`` if we are not yet ready @@ -459,31 +543,30 @@ class MatrixFederationHttpClient(object): Fails with ``FederationDeniedError`` if this destination is not on our federation whitelist """ - response = yield self._request( - destination, - "POST", - path, + + request = MatrixFederationRequest( + method="POST", + destination=destination, + path=path, query=args, json=data, + ) + + response = yield self._send_request( + request, long_retries=long_retries, timeout=timeout, ignore_backoff=ignore_backoff, ) - if 200 <= response.code < 300: - # We need to update the transactions table to say it was sent? - check_content_type_is_json(response.headers) - - with logcontext.PreserveLoggingContext(): - d = treq.json_content(response) - if timeout: - _sec_timeout = timeout / 1000 - else: - _sec_timeout = self.default_timeout - - d.addTimeout(_sec_timeout, self.hs.get_reactor()) - body = yield make_deferred_yieldable(d) + if timeout: + _sec_timeout = timeout / 1000 + else: + _sec_timeout = self.default_timeout + body = yield _handle_json_response( + self.hs.get_reactor(), _sec_timeout, request, response, + ) defer.returnValue(body) @defer.inlineCallbacks @@ -506,7 +589,7 @@ class MatrixFederationHttpClient(object): Deferred: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. - Fails with ``HTTPRequestException`` if we get an HTTP response + Fails with ``HttpResponseException`` if we get an HTTP response code >= 300. Fails with ``NotRetryingDestination`` if we are not yet ready @@ -519,25 +602,23 @@ class MatrixFederationHttpClient(object): logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail) - response = yield self._request( - destination, - "GET", - path, + request = MatrixFederationRequest( + method="GET", + destination=destination, + path=path, query=args, + ) + + response = yield self._send_request( + request, retry_on_dns_fail=retry_on_dns_fail, timeout=timeout, ignore_backoff=ignore_backoff, ) - if 200 <= response.code < 300: - # We need to update the transactions table to say it was sent? - check_content_type_is_json(response.headers) - - with logcontext.PreserveLoggingContext(): - d = treq.json_content(response) - d.addTimeout(self.default_timeout, self.hs.get_reactor()) - body = yield make_deferred_yieldable(d) - + body = yield _handle_json_response( + self.hs.get_reactor(), self.default_timeout, request, response, + ) defer.returnValue(body) @defer.inlineCallbacks @@ -559,7 +640,7 @@ class MatrixFederationHttpClient(object): Deferred: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. - Fails with ``HTTPRequestException`` if we get an HTTP response + Fails with ``HttpResponseException`` if we get an HTTP response code >= 300. Fails with ``NotRetryingDestination`` if we are not yet ready @@ -568,25 +649,23 @@ class MatrixFederationHttpClient(object): Fails with ``FederationDeniedError`` if this destination is not on our federation whitelist """ - response = yield self._request( - destination, - "DELETE", - path, + request = MatrixFederationRequest( + method="DELETE", + destination=destination, + path=path, query=args, + ) + + response = yield self._send_request( + request, long_retries=long_retries, timeout=timeout, ignore_backoff=ignore_backoff, ) - if 200 <= response.code < 300: - # We need to update the transactions table to say it was sent? - check_content_type_is_json(response.headers) - - with logcontext.PreserveLoggingContext(): - d = treq.json_content(response) - d.addTimeout(self.default_timeout, self.hs.get_reactor()) - body = yield make_deferred_yieldable(d) - + body = yield _handle_json_response( + self.hs.get_reactor(), self.default_timeout, request, response, + ) defer.returnValue(body) @defer.inlineCallbacks @@ -605,7 +684,7 @@ class MatrixFederationHttpClient(object): Deferred: resolves with an (int,dict) tuple of the file length and a dict of the response headers. - Fails with ``HTTPRequestException`` if we get an HTTP response code + Fails with ``HttpResponseException`` if we get an HTTP response code >= 300 Fails with ``NotRetryingDestination`` if we are not yet ready @@ -614,11 +693,15 @@ class MatrixFederationHttpClient(object): Fails with ``FederationDeniedError`` if this destination is not on our federation whitelist """ - response = yield self._request( - destination, - "GET", - path, + request = MatrixFederationRequest( + method="GET", + destination=destination, + path=path, query=args, + ) + + response = yield self._send_request( + request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff, ) @@ -626,14 +709,25 @@ class MatrixFederationHttpClient(object): headers = dict(response.headers.getAllRawHeaders()) try: - with logcontext.PreserveLoggingContext(): - d = _readBodyToFile(response, output_stream, max_size) - d.addTimeout(self.default_timeout, self.hs.get_reactor()) - length = yield make_deferred_yieldable(d) - except Exception: - logger.exception("Failed to download body") + d = _readBodyToFile(response, output_stream, max_size) + d.addTimeout(self.default_timeout, self.hs.get_reactor()) + length = yield make_deferred_yieldable(d) + except Exception as e: + logger.warn( + "{%s} [%s] Error reading response: %s", + request.txn_id, + request.destination, + e, + ) raise - + logger.info( + "{%s} [%s] Completed: %d %s [%d bytes]", + request.txn_id, + request.destination, + response.code, + response.phrase.decode('ascii', errors='replace'), + length, + ) defer.returnValue((length, headers)) diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py index 72c2654678..62045a918b 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py @@ -39,7 +39,8 @@ outgoing_responses_counter = Counter( ) response_timer = Histogram( - "synapse_http_server_response_time_seconds", "sec", + "synapse_http_server_response_time_seconds", + "sec", ["method", "servlet", "tag", "code"], ) @@ -79,15 +80,11 @@ response_size = Counter( # than when the response was written. in_flight_requests_ru_utime = Counter( - "synapse_http_server_in_flight_requests_ru_utime_seconds", - "", - ["method", "servlet"], + "synapse_http_server_in_flight_requests_ru_utime_seconds", "", ["method", "servlet"] ) in_flight_requests_ru_stime = Counter( - "synapse_http_server_in_flight_requests_ru_stime_seconds", - "", - ["method", "servlet"], + "synapse_http_server_in_flight_requests_ru_stime_seconds", "", ["method", "servlet"] ) in_flight_requests_db_txn_count = Counter( @@ -134,7 +131,7 @@ def _get_in_flight_counts(): # type counts = {} for rm in reqs: - key = (rm.method, rm.name,) + key = (rm.method, rm.name) counts[key] = counts.get(key, 0) + 1 return counts @@ -162,7 +159,7 @@ class RequestMetrics(object): with _in_flight_requests_lock: _in_flight_requests.add(self) - def stop(self, time_sec, request): + def stop(self, time_sec, response_code, sent_bytes): with _in_flight_requests_lock: _in_flight_requests.discard(self) @@ -175,39 +172,40 @@ class RequestMetrics(object): if context != self.start_context: logger.warn( "Context have unexpectedly changed %r, %r", - context, self.start_context + context, + self.start_context, ) return - response_code = str(request.code) + response_code = str(response_code) - outgoing_responses_counter.labels(request.method, response_code).inc() + outgoing_responses_counter.labels(self.method, response_code).inc() - response_count.labels(request.method, self.name, tag).inc() + response_count.labels(self.method, self.name, tag).inc() - response_timer.labels(request.method, self.name, tag, response_code).observe( + response_timer.labels(self.method, self.name, tag, response_code).observe( time_sec - self.start ) resource_usage = context.get_resource_usage() - response_ru_utime.labels(request.method, self.name, tag).inc( - resource_usage.ru_utime, + response_ru_utime.labels(self.method, self.name, tag).inc( + resource_usage.ru_utime ) - response_ru_stime.labels(request.method, self.name, tag).inc( - resource_usage.ru_stime, + response_ru_stime.labels(self.method, self.name, tag).inc( + resource_usage.ru_stime ) - response_db_txn_count.labels(request.method, self.name, tag).inc( + response_db_txn_count.labels(self.method, self.name, tag).inc( resource_usage.db_txn_count ) - response_db_txn_duration.labels(request.method, self.name, tag).inc( + response_db_txn_duration.labels(self.method, self.name, tag).inc( resource_usage.db_txn_duration_sec ) - response_db_sched_duration.labels(request.method, self.name, tag).inc( + response_db_sched_duration.labels(self.method, self.name, tag).inc( resource_usage.db_sched_duration_sec ) - response_size.labels(request.method, self.name, tag).inc(request.sentLength) + response_size.labels(self.method, self.name, tag).inc(sent_bytes) # We always call this at the end to ensure that we update the metrics # regardless of whether a call to /metrics while the request was in @@ -222,8 +220,15 @@ class RequestMetrics(object): diff = new_stats - self._request_stats self._request_stats = new_stats - in_flight_requests_ru_utime.labels(self.method, self.name).inc(diff.ru_utime) - in_flight_requests_ru_stime.labels(self.method, self.name).inc(diff.ru_stime) + # max() is used since rapid use of ru_stime/ru_utime can end up with the + # count going backwards due to NTP, time smearing, fine-grained + # correction, or floating points. Who knows, really? + in_flight_requests_ru_utime.labels(self.method, self.name).inc( + max(diff.ru_utime, 0) + ) + in_flight_requests_ru_stime.labels(self.method, self.name).inc( + max(diff.ru_stime, 0) + ) in_flight_requests_db_txn_count.labels(self.method, self.name).inc( diff.db_txn_count diff --git a/synapse/http/server.py b/synapse/http/server.py index 2d5c23e673..b4b25cab19 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -84,10 +84,21 @@ def wrap_json_request_handler(h): logger.info( "%s SynapseError: %s - %s", request, code, e.msg ) - respond_with_json( - request, code, e.error_dict(), send_cors=True, - pretty_print=_request_user_agent_is_curl(request), - ) + + # Only respond with an error response if we haven't already started + # writing, otherwise lets just kill the connection + if request.startedWriting: + if request.transport: + try: + request.transport.abortConnection() + except Exception: + # abortConnection throws if the connection is already closed + pass + else: + respond_with_json( + request, code, e.error_dict(), send_cors=True, + pretty_print=_request_user_agent_is_curl(request), + ) except Exception: # failure.Failure() fishes the original Failure out @@ -100,16 +111,26 @@ def wrap_json_request_handler(h): request, f.getTraceback().rstrip(), ) - respond_with_json( - request, - 500, - { - "error": "Internal server error", - "errcode": Codes.UNKNOWN, - }, - send_cors=True, - pretty_print=_request_user_agent_is_curl(request), - ) + # Only respond with an error response if we haven't already started + # writing, otherwise lets just kill the connection + if request.startedWriting: + if request.transport: + try: + request.transport.abortConnection() + except Exception: + # abortConnection throws if the connection is already closed + pass + else: + respond_with_json( + request, + 500, + { + "error": "Internal server error", + "errcode": Codes.UNKNOWN, + }, + send_cors=True, + pretty_print=_request_user_agent_is_curl(request), + ) return wrap_async_request_handler(wrapped_request_handler) diff --git a/synapse/http/site.py b/synapse/http/site.py index 2191de9836..e508c0bd4f 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -75,14 +75,14 @@ class SynapseRequest(Request): return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % ( self.__class__.__name__, id(self), - self.method, + self.get_method(), self.get_redacted_uri(), - self.clientproto, + self.clientproto.decode('ascii', errors='replace'), self.site.site_tag, ) def get_request_id(self): - return "%s-%i" % (self.method.decode('ascii'), self.request_seq) + return "%s-%i" % (self.get_method(), self.request_seq) def get_redacted_uri(self): uri = self.uri @@ -90,6 +90,21 @@ class SynapseRequest(Request): uri = self.uri.decode('ascii') return redact_uri(uri) + def get_method(self): + """Gets the method associated with the request (or placeholder if not + method has yet been received). + + Note: This is necessary as the placeholder value in twisted is str + rather than bytes, so we need to sanitise `self.method`. + + Returns: + str + """ + method = self.method + if isinstance(method, bytes): + method = self.method.decode('ascii') + return method + def get_user_agent(self): return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1] @@ -119,7 +134,7 @@ class SynapseRequest(Request): # dispatching to the handler, so that the handler # can update the servlet name in the request # metrics - requests_counter.labels(self.method.decode('ascii'), + requests_counter.labels(self.get_method(), self.request_metrics.name).inc() @contextlib.contextmanager @@ -207,14 +222,14 @@ class SynapseRequest(Request): self.start_time = time.time() self.request_metrics = RequestMetrics() self.request_metrics.start( - self.start_time, name=servlet_name, method=self.method.decode('ascii'), + self.start_time, name=servlet_name, method=self.get_method(), ) self.site.access_logger.info( "%s - %s - Received request: %s %s", self.getClientIP(), self.site.site_tag, - self.method.decode('ascii'), + self.get_method(), self.get_redacted_uri() ) @@ -280,7 +295,7 @@ class SynapseRequest(Request): int(usage.db_txn_count), self.sentLength, code, - self.method.decode('ascii'), + self.get_method(), self.get_redacted_uri(), self.clientproto.decode('ascii', errors='replace'), user_agent, @@ -288,7 +303,7 @@ class SynapseRequest(Request): ) try: - self.request_metrics.stop(self.finish_time, self) + self.request_metrics.stop(self.finish_time, self.code, self.sentLength) except Exception as e: logger.warn("Failed to stop metrics: %r", e) @@ -308,7 +323,7 @@ class XForwardedForRequest(SynapseRequest): C{b"-"}. """ return self.requestHeaders.getRawHeaders( - b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip() + b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip().decode('ascii') class SynapseRequestFactory(object): diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 167167be0a..037f1c490e 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import threading import six @@ -23,6 +24,9 @@ from twisted.internet import defer from synapse.util.logcontext import LoggingContext, PreserveLoggingContext +logger = logging.getLogger(__name__) + + _background_process_start_count = Counter( "synapse_background_process_start_count", "Number of background processes started", @@ -97,9 +101,13 @@ class _Collector(object): labels=["name"], ) - # We copy the dict so that it doesn't change from underneath us + # We copy the dict so that it doesn't change from underneath us. + # We also copy the process lists as that can also change with _bg_metrics_lock: - _background_processes_copy = dict(_background_processes) + _background_processes_copy = { + k: list(v) + for k, v in six.iteritems(_background_processes) + } for desc, processes in six.iteritems(_background_processes_copy): background_process_in_flight_count.add_metric( @@ -191,6 +199,8 @@ def run_as_background_process(desc, func, *args, **kwargs): try: yield func(*args, **kwargs) + except Exception: + logger.exception("Background process '%s' threw an exception", desc) finally: proc.update_metrics() diff --git a/synapse/notifier.py b/synapse/notifier.py index 82f391481c..de02b1017e 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -24,13 +24,10 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError from synapse.handlers.presence import format_user_presence_state from synapse.metrics import LaterGauge +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import StreamToken -from synapse.util.async_helpers import ( - DeferredTimeoutError, - ObservableDeferred, - add_timeout_to_deferred, -) -from synapse.util.logcontext import PreserveLoggingContext, run_in_background +from synapse.util.async_helpers import ObservableDeferred, timeout_deferred +from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logutils import log_function from synapse.util.metrics import Measure from synapse.visibility import filter_events_for_client @@ -189,9 +186,9 @@ class Notifier(object): def count_listeners(): all_user_streams = set() - for x in self.room_to_user_streams.values(): + for x in list(self.room_to_user_streams.values()): all_user_streams |= x - for x in self.user_to_user_stream.values(): + for x in list(self.user_to_user_stream.values()): all_user_streams.add(x) return sum(stream.count_listeners() for stream in all_user_streams) @@ -199,7 +196,7 @@ class Notifier(object): LaterGauge( "synapse_notifier_rooms", "", [], - lambda: count(bool, self.room_to_user_streams.values()), + lambda: count(bool, list(self.room_to_user_streams.values())), ) LaterGauge( "synapse_notifier_users", "", [], @@ -252,7 +249,10 @@ class Notifier(object): def _on_new_room_event(self, event, room_stream_id, extra_users=[]): """Notify any user streams that are interested in this room event""" # poke any interested application service. - run_in_background(self._notify_app_services, room_stream_id) + run_as_background_process( + "notify_app_services", + self._notify_app_services, room_stream_id, + ) if self.federation_sender: self.federation_sender.notify_new_events(room_stream_id) @@ -337,7 +337,7 @@ class Notifier(object): # Now we wait for the _NotifierUserStream to be told there # is a new token. listener = user_stream.new_listener(prev_token) - add_timeout_to_deferred( + listener.deferred = timeout_deferred( listener.deferred, (end_time - now) / 1000., self.hs.get_reactor(), @@ -354,7 +354,7 @@ class Notifier(object): # Update the prev_token to the current_token since nothing # has happened between the old prev_token and the current_token prev_token = current_token - except DeferredTimeoutError: + except defer.TimeoutError: break except defer.CancelledError: break @@ -559,15 +559,16 @@ class Notifier(object): if end_time <= now: break - add_timeout_to_deferred( - listener.deferred.addTimeout, - (end_time - now) / 1000., - self.hs.get_reactor(), + listener.deferred = timeout_deferred( + listener.deferred, + timeout=(end_time - now) / 1000., + reactor=self.hs.get_reactor(), ) + try: with PreserveLoggingContext(): yield listener.deferred - except DeferredTimeoutError: + except defer.TimeoutError: break except defer.CancelledError: break diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index d746371420..50e1007d84 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -18,8 +18,7 @@ import logging from twisted.internet import defer from twisted.internet.error import AlreadyCalled, AlreadyCancelled -from synapse.util.logcontext import LoggingContext -from synapse.util.metrics import Measure +from synapse.metrics.background_process_metrics import run_as_background_process logger = logging.getLogger(__name__) @@ -71,18 +70,11 @@ class EmailPusher(object): # See httppusher self.max_stream_ordering = None - self.processing = False + self._is_processing = False - @defer.inlineCallbacks def on_started(self): if self.mailer is not None: - try: - self.throttle_params = yield self.store.get_throttle_params_by_room( - self.pusher_id - ) - yield self._process() - except Exception: - logger.exception("Error starting email pusher") + self._start_processing() def on_stop(self): if self.timed_call: @@ -92,43 +84,55 @@ class EmailPusher(object): pass self.timed_call = None - @defer.inlineCallbacks def on_new_notifications(self, min_stream_ordering, max_stream_ordering): - self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering) - yield self._process() + if self.max_stream_ordering: + self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering) + else: + self.max_stream_ordering = max_stream_ordering + self._start_processing() def on_new_receipts(self, min_stream_id, max_stream_id): # We could wake up and cancel the timer but there tend to be quite a # lot of read receipts so it's probably less work to just let the # timer fire - return defer.succeed(None) + pass - @defer.inlineCallbacks def on_timer(self): self.timed_call = None - yield self._process() + self._start_processing() + + def _start_processing(self): + if self._is_processing: + return + + run_as_background_process("emailpush.process", self._process) @defer.inlineCallbacks def _process(self): - if self.processing: - return + # we should never get here if we are already processing + assert not self._is_processing + + try: + self._is_processing = True + + if self.throttle_params is None: + # this is our first loop: load up the throttle params + self.throttle_params = yield self.store.get_throttle_params_by_room( + self.pusher_id + ) - with LoggingContext("emailpush._process"): - with Measure(self.clock, "emailpush._process"): + # if the max ordering changes while we're running _unsafe_process, + # call it again, and so on until we've caught up. + while True: + starting_max_ordering = self.max_stream_ordering try: - self.processing = True - # if the max ordering changes while we're running _unsafe_process, - # call it again, and so on until we've caught up. - while True: - starting_max_ordering = self.max_stream_ordering - try: - yield self._unsafe_process() - except Exception: - logger.exception("Exception processing notifs") - if self.max_stream_ordering == starting_max_ordering: - break - finally: - self.processing = False + yield self._unsafe_process() + except Exception: + logger.exception("Exception processing notifs") + if self.max_stream_ordering == starting_max_ordering: + break + finally: + self._is_processing = False @defer.inlineCallbacks def _unsafe_process(self): diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 48abd5e4d6..6bd703632d 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -22,9 +22,8 @@ from prometheus_client import Counter from twisted.internet import defer from twisted.internet.error import AlreadyCalled, AlreadyCancelled +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import PusherConfigException -from synapse.util.logcontext import LoggingContext -from synapse.util.metrics import Measure from . import push_rule_evaluator, push_tools @@ -61,7 +60,7 @@ class HttpPusher(object): self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.failing_since = pusherdict['failing_since'] self.timed_call = None - self.processing = False + self._is_processing = False # This is the highest stream ordering we know it's safe to process. # When new events arrive, we'll be given a window of new events: we @@ -92,34 +91,27 @@ class HttpPusher(object): self.data_minus_url.update(self.data) del self.data_minus_url['url'] - @defer.inlineCallbacks def on_started(self): - try: - yield self._process() - except Exception: - logger.exception("Error starting http pusher") + self._start_processing() - @defer.inlineCallbacks def on_new_notifications(self, min_stream_ordering, max_stream_ordering): self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering or 0) - yield self._process() + self._start_processing() - @defer.inlineCallbacks def on_new_receipts(self, min_stream_id, max_stream_id): # Note that the min here shouldn't be relied upon to be accurate. # We could check the receipts are actually m.read receipts here, # but currently that's the only type of receipt anyway... - with LoggingContext("push.on_new_receipts"): - with Measure(self.clock, "push.on_new_receipts"): - badge = yield push_tools.get_badge_count( - self.hs.get_datastore(), self.user_id - ) - yield self._send_badge(badge) + run_as_background_process("http_pusher.on_new_receipts", self._update_badge) @defer.inlineCallbacks + def _update_badge(self): + badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) + yield self._send_badge(badge) + def on_timer(self): - yield self._process() + self._start_processing() def on_stop(self): if self.timed_call: @@ -129,27 +121,31 @@ class HttpPusher(object): pass self.timed_call = None + def _start_processing(self): + if self._is_processing: + return + + run_as_background_process("httppush.process", self._process) + @defer.inlineCallbacks def _process(self): - if self.processing: - return + # we should never get here if we are already processing + assert not self._is_processing - with LoggingContext("push._process"): - with Measure(self.clock, "push._process"): + try: + self._is_processing = True + # if the max ordering changes while we're running _unsafe_process, + # call it again, and so on until we've caught up. + while True: + starting_max_ordering = self.max_stream_ordering try: - self.processing = True - # if the max ordering changes while we're running _unsafe_process, - # call it again, and so on until we've caught up. - while True: - starting_max_ordering = self.max_stream_ordering - try: - yield self._unsafe_process() - except Exception: - logger.exception("Exception processing notifs") - if self.max_stream_ordering == starting_max_ordering: - break - finally: - self.processing = False + yield self._unsafe_process() + except Exception: + logger.exception("Exception processing notifs") + if self.max_stream_ordering == starting_max_ordering: + break + finally: + self._is_processing = False @defer.inlineCallbacks def _unsafe_process(self): diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index b78ce10396..ebcb93bfc7 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -26,7 +26,6 @@ import bleach import jinja2 from twisted.internet import defer -from twisted.mail.smtp import sendmail from synapse.api.constants import EventTypes from synapse.api.errors import StoreError @@ -85,6 +84,7 @@ class Mailer(object): self.notif_template_html = notif_template_html self.notif_template_text = notif_template_text + self.sendmail = self.hs.get_sendmail() self.store = self.hs.get_datastore() self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() @@ -191,11 +191,11 @@ class Mailer(object): multipart_msg.attach(html_part) logger.info("Sending email push notification to %s" % email_address) - # logger.debug(html_text) - yield sendmail( + yield self.sendmail( self.hs.config.email_smtp_host, - raw_from, raw_to, multipart_msg.as_string(), + raw_from, raw_to, multipart_msg.as_string().encode('utf8'), + reactor=self.hs.get_reactor(), port=self.hs.config.email_smtp_port, requireAuthentication=self.hs.config.email_smtp_user is not None, username=self.hs.config.email_smtp_user, @@ -333,7 +333,7 @@ class Mailer(object): notif_events, user_id, reason): if len(notifs_by_room) == 1: # Only one room has new stuff - room_id = notifs_by_room.keys()[0] + room_id = list(notifs_by_room.keys())[0] # If the room has some kind of name, use it, but we don't # want the generated-from-names one here otherwise we'll @@ -441,7 +441,7 @@ class Mailer(object): def make_room_link(self, room_id): if self.hs.config.email_riot_base_url: - base_url = self.hs.config.email_riot_base_url + base_url = "%s/#/room" % (self.hs.config.email_riot_base_url) elif self.app_name == "Vector": # need /beta for Universal Links to work on iOS base_url = "https://vector.im/beta/#/room" @@ -526,8 +526,7 @@ def load_jinja2_templates(config): Returns: (notif_template_html, notif_template_text) """ - logger.info("loading jinja2") - + logger.info("loading email templates from '%s'", config.email_template_dir) loader = jinja2.FileSystemLoader(config.email_template_dir) env = jinja2.Environment(loader=loader) env.filters["format_ts"] = format_ts_filter diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 9f7d5ef217..5a4e73ccd6 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -20,24 +20,39 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push.pusher import PusherFactory -from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) class PusherPool: + """ + The pusher pool. This is responsible for dispatching notifications of new events to + the http and email pushers. + + It provides three methods which are designed to be called by the rest of the + application: `start`, `on_new_notifications`, and `on_new_receipts`: each of these + delegates to each of the relevant pushers. + + Note that it is expected that each pusher will have its own 'processing' loop which + will send out the notifications in the background, rather than blocking until the + notifications are sent; accordingly Pusher.on_started, Pusher.on_new_notifications and + Pusher.on_new_receipts are not expected to return deferreds. + """ def __init__(self, _hs): self.hs = _hs self.pusher_factory = PusherFactory(_hs) - self.start_pushers = _hs.config.start_pushers + self._should_start_pushers = _hs.config.start_pushers self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() self.pushers = {} - @defer.inlineCallbacks def start(self): - pushers = yield self.store.get_all_pushers() - self._start_pushers(pushers) + """Starts the pushers off in a background process. + """ + if not self._should_start_pushers: + logger.info("Not starting pushers because they are disabled in the config") + return + run_as_background_process("start_pushers", self._start_pushers) @defer.inlineCallbacks def add_pusher(self, user_id, access_token, kind, app_id, @@ -86,7 +101,7 @@ class PusherPool: last_stream_ordering=last_stream_ordering, profile_tag=profile_tag, ) - yield self._refresh_pusher(app_id, pushkey, user_id) + yield self.start_pusher_by_id(app_id, pushkey, user_id) @defer.inlineCallbacks def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey, @@ -123,45 +138,23 @@ class PusherPool: p['app_id'], p['pushkey'], p['user_name'], ) - def on_new_notifications(self, min_stream_id, max_stream_id): - run_as_background_process( - "on_new_notifications", - self._on_new_notifications, min_stream_id, max_stream_id, - ) - @defer.inlineCallbacks - def _on_new_notifications(self, min_stream_id, max_stream_id): + def on_new_notifications(self, min_stream_id, max_stream_id): try: users_affected = yield self.store.get_push_action_users_in_range( min_stream_id, max_stream_id ) - deferreds = [] - for u in users_affected: if u in self.pushers: for p in self.pushers[u].values(): - deferreds.append( - run_in_background( - p.on_new_notifications, - min_stream_id, max_stream_id, - ) - ) - - yield make_deferred_yieldable( - defer.gatherResults(deferreds, consumeErrors=True), - ) + p.on_new_notifications(min_stream_id, max_stream_id) + except Exception: logger.exception("Exception in pusher on_new_notifications") - def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): - run_as_background_process( - "on_new_receipts", - self._on_new_receipts, min_stream_id, max_stream_id, affected_room_ids, - ) - @defer.inlineCallbacks - def _on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): + def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): try: # Need to subtract 1 from the minimum because the lower bound here # is not inclusive @@ -171,26 +164,20 @@ class PusherPool: # This returns a tuple, user_id is at index 3 users_affected = set([r[3] for r in updated_receipts]) - deferreds = [] - for u in users_affected: if u in self.pushers: for p in self.pushers[u].values(): - deferreds.append( - run_in_background( - p.on_new_receipts, - min_stream_id, max_stream_id, - ) - ) - - yield make_deferred_yieldable( - defer.gatherResults(deferreds, consumeErrors=True), - ) + p.on_new_receipts(min_stream_id, max_stream_id) + except Exception: logger.exception("Exception in pusher on_new_receipts") @defer.inlineCallbacks - def _refresh_pusher(self, app_id, pushkey, user_id): + def start_pusher_by_id(self, app_id, pushkey, user_id): + """Look up the details for the given pusher, and start it""" + if not self._should_start_pushers: + return + resultlist = yield self.store.get_pushers_by_app_id_and_pushkey( app_id, pushkey ) @@ -201,33 +188,49 @@ class PusherPool: p = r if p: + self._start_pusher(p) - self._start_pushers([p]) + @defer.inlineCallbacks + def _start_pushers(self): + """Start all the pushers - def _start_pushers(self, pushers): - if not self.start_pushers: - logger.info("Not starting pushers because they are disabled in the config") - return + Returns: + Deferred + """ + pushers = yield self.store.get_all_pushers() logger.info("Starting %d pushers", len(pushers)) for pusherdict in pushers: - try: - p = self.pusher_factory.create_pusher(pusherdict) - except Exception: - logger.exception("Couldn't start a pusher: caught Exception") - continue - if p: - appid_pushkey = "%s:%s" % ( - pusherdict['app_id'], - pusherdict['pushkey'], - ) - byuser = self.pushers.setdefault(pusherdict['user_name'], {}) + self._start_pusher(pusherdict) + logger.info("Started pushers") - if appid_pushkey in byuser: - byuser[appid_pushkey].on_stop() - byuser[appid_pushkey] = p - run_in_background(p.on_started) + def _start_pusher(self, pusherdict): + """Start the given pusher - logger.info("Started pushers") + Args: + pusherdict (dict): + + Returns: + None + """ + try: + p = self.pusher_factory.create_pusher(pusherdict) + except Exception: + logger.exception("Couldn't start a pusher: caught Exception") + return + + if not p: + return + + appid_pushkey = "%s:%s" % ( + pusherdict['app_id'], + pusherdict['pushkey'], + ) + byuser = self.pushers.setdefault(pusherdict['user_name'], {}) + + if appid_pushkey in byuser: + byuser[appid_pushkey].on_stop() + byuser[appid_pushkey] = p + p.on_started() @defer.inlineCallbacks def remove_pusher(self, app_id, pushkey, user_id): diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 0d8de600cf..ca62ee7637 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -33,32 +33,38 @@ logger = logging.getLogger(__name__) # [2] https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-dependencies REQUIREMENTS = { "jsonschema>=2.5.1": ["jsonschema>=2.5.1"], - "frozendict>=0.4": ["frozendict"], + "frozendict>=1": ["frozendict"], "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"], "canonicaljson>=1.1.3": ["canonicaljson>=1.1.3"], "signedjson>=1.0.0": ["signedjson>=1.0.0"], "pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"], - "service_identity>=1.0.0": ["service_identity>=1.0.0"], + "service_identity>=16.0.0": ["service_identity>=16.0.0"], "Twisted>=17.1.0": ["twisted>=17.1.0"], "treq>=15.1": ["treq>=15.1"], # Twisted has required pyopenssl 16.0 since about Twisted 16.6. "pyopenssl>=16.0.0": ["OpenSSL>=16.0.0"], - "pyyaml": ["yaml"], - "pyasn1": ["pyasn1"], - "daemonize": ["daemonize"], - "bcrypt": ["bcrypt>=3.1.0"], - "pillow": ["PIL"], - "pydenticon": ["pydenticon"], - "sortedcontainers": ["sortedcontainers"], - "pysaml2>=3.0.0": ["saml2>=3.0.0"], - "pymacaroons-pynacl": ["pymacaroons"], - "msgpack-python>=0.3.0": ["msgpack"], + "pyyaml>=3.11": ["yaml"], + "pyasn1>=0.1.9": ["pyasn1"], + "pyasn1-modules>=0.0.7": ["pyasn1_modules"], + "daemonize>=2.3.1": ["daemonize"], + "bcrypt>=3.1.0": ["bcrypt>=3.1.0"], + "pillow>=3.1.2": ["PIL"], + "sortedcontainers>=1.4.4": ["sortedcontainers"], + "psutil>=2.0.0": ["psutil>=2.0.0"], + "pysaml2>=3.0.0": ["saml2"], + "pymacaroons-pynacl>=0.9.3": ["pymacaroons"], + "msgpack-python>=0.4.2": ["msgpack"], "phonenumbers>=8.2.0": ["phonenumbers"], - "six": ["six"], - "prometheus_client": ["prometheus_client"], - "attrs": ["attr"], + "six>=1.10": ["six"], + + # prometheus_client 0.4.0 changed the format of counter metrics + # (cf https://github.com/matrix-org/synapse/issues/4001) + "prometheus_client>=0.0.18,<0.4.0": ["prometheus_client"], + + # we use attr.s(slots), which arrived in 16.0.0 + "attrs>=16.0.0": ["attr>=16.0.0"], "netaddr>=0.7.18": ["netaddr"], } @@ -73,12 +79,6 @@ CONDITIONAL_REQUIREMENTS = { "matrix-synapse-ldap3": { "matrix-synapse-ldap3>=0.1": ["ldap_auth_provider"], }, - "psutil": { - "psutil>=2.0.0": ["psutil>=2.0.0"], - }, - "affinity": { - "affinity": ["affinity"], - }, "postgres": { "psycopg2>=2.6": ["psycopg2"] } diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 3f7be74e02..2d81d49e9a 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -15,6 +15,8 @@ import logging +import six + from synapse.storage._base import SQLBaseStore from synapse.storage.engines import PostgresEngine @@ -23,6 +25,13 @@ from ._slaved_id_tracker import SlavedIdTracker logger = logging.getLogger(__name__) +def __func__(inp): + if six.PY3: + return inp + else: + return inp.__func__ + + class BaseSlavedStore(SQLBaseStore): def __init__(self, db_conn, hs): super(BaseSlavedStore, self).__init__(db_conn, hs) diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 87eaa53004..4f19fd35aa 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -17,7 +17,7 @@ from synapse.storage import DataStore from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.stream_change_cache import StreamChangeCache -from ._base import BaseSlavedStore +from ._base import BaseSlavedStore, __func__ from ._slaved_id_tracker import SlavedIdTracker @@ -43,11 +43,11 @@ class SlavedDeviceInboxStore(BaseSlavedStore): expiry_ms=30 * 60 * 1000, ) - get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__ - get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__ - get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__ - delete_messages_for_device = DataStore.delete_messages_for_device.__func__ - delete_device_msgs_for_remote = DataStore.delete_device_msgs_for_remote.__func__ + get_to_device_stream_token = __func__(DataStore.get_to_device_stream_token) + get_new_messages_for_device = __func__(DataStore.get_new_messages_for_device) + get_new_device_msgs_for_remote = __func__(DataStore.get_new_device_msgs_for_remote) + delete_messages_for_device = __func__(DataStore.delete_messages_for_device) + delete_device_msgs_for_remote = __func__(DataStore.delete_device_msgs_for_remote) def stream_positions(self): result = super(SlavedDeviceInboxStore, self).stream_positions() diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 21b8c468fa..ec2fd561cc 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -13,23 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import six - from synapse.storage import DataStore from synapse.storage.end_to_end_keys import EndToEndKeyStore from synapse.util.caches.stream_change_cache import StreamChangeCache -from ._base import BaseSlavedStore +from ._base import BaseSlavedStore, __func__ from ._slaved_id_tracker import SlavedIdTracker -def __func__(inp): - if six.PY3: - return inp - else: - return inp.__func__ - - class SlavedDeviceStore(BaseSlavedStore): def __init__(self, db_conn, hs): super(SlavedDeviceStore, self).__init__(db_conn, hs) diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 5777f07c8d..e933b170bb 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -16,7 +16,7 @@ from synapse.storage import DataStore from synapse.util.caches.stream_change_cache import StreamChangeCache -from ._base import BaseSlavedStore +from ._base import BaseSlavedStore, __func__ from ._slaved_id_tracker import SlavedIdTracker @@ -33,9 +33,9 @@ class SlavedGroupServerStore(BaseSlavedStore): "_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(), ) - get_groups_changes_for_user = DataStore.get_groups_changes_for_user.__func__ - get_group_stream_token = DataStore.get_group_stream_token.__func__ - get_all_groups_for_user = DataStore.get_all_groups_for_user.__func__ + get_groups_changes_for_user = __func__(DataStore.get_groups_changes_for_user) + get_group_stream_token = __func__(DataStore.get_group_stream_token) + get_all_groups_for_user = __func__(DataStore.get_all_groups_for_user) def stream_positions(self): result = super(SlavedGroupServerStore, self).stream_positions() diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py index 05ed168463..8032f53fec 100644 --- a/synapse/replication/slave/storage/keys.py +++ b/synapse/replication/slave/storage/keys.py @@ -16,7 +16,7 @@ from synapse.storage import DataStore from synapse.storage.keys import KeyStore -from ._base import BaseSlavedStore +from ._base import BaseSlavedStore, __func__ class SlavedKeyStore(BaseSlavedStore): @@ -24,11 +24,11 @@ class SlavedKeyStore(BaseSlavedStore): "_get_server_verify_key" ] - get_server_verify_keys = DataStore.get_server_verify_keys.__func__ - store_server_verify_key = DataStore.store_server_verify_key.__func__ + get_server_verify_keys = __func__(DataStore.get_server_verify_keys) + store_server_verify_key = __func__(DataStore.store_server_verify_key) - get_server_certificate = DataStore.get_server_certificate.__func__ - store_server_certificate = DataStore.store_server_certificate.__func__ + get_server_certificate = __func__(DataStore.get_server_certificate) + store_server_certificate = __func__(DataStore.store_server_certificate) - get_server_keys_json = DataStore.get_server_keys_json.__func__ - store_server_keys_json = DataStore.store_server_keys_json.__func__ + get_server_keys_json = __func__(DataStore.get_server_keys_json) + store_server_keys_json = __func__(DataStore.store_server_keys_json) diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 80b744082a..92447b00d4 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -17,7 +17,7 @@ from synapse.storage import DataStore from synapse.storage.presence import PresenceStore from synapse.util.caches.stream_change_cache import StreamChangeCache -from ._base import BaseSlavedStore +from ._base import BaseSlavedStore, __func__ from ._slaved_id_tracker import SlavedIdTracker @@ -34,8 +34,8 @@ class SlavedPresenceStore(BaseSlavedStore): "PresenceStreamChangeCache", self._presence_id_gen.get_current_token() ) - _get_active_presence = DataStore._get_active_presence.__func__ - take_presence_startup_info = DataStore.take_presence_startup_info.__func__ + _get_active_presence = __func__(DataStore._get_active_presence) + take_presence_startup_info = __func__(DataStore.take_presence_startup_info) _get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"] get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"] diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index cbe9645817..586dddb40b 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -106,7 +106,7 @@ class ReplicationClientHandler(object): Can be overriden in subclasses to handle more. """ - logger.info("Received rdata %s -> %s", stream_name, token) + logger.debug("Received rdata %s -> %s", stream_name, token) return self.store.process_replication_rows(stream_name, token, rows) def on_position(self, stream_name, token): diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 5dc7b3fffc..0b3fe6cbf5 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -656,7 +656,7 @@ tcp_inbound_commands = LaterGauge( "", ["command", "name"], lambda: { - (k[0], p.name,): count + (k, p.name,): count for p in connected_connections for k, count in iteritems(p.inbound_commands_counter) }, @@ -667,7 +667,7 @@ tcp_outbound_commands = LaterGauge( "", ["command", "name"], lambda: { - (k[0], p.name,): count + (k, p.name,): count for p in connected_connections for k, count in iteritems(p.outbound_commands_counter) }, diff --git a/synapse/res/templates/mail-Vector.css b/synapse/res/templates/mail-Vector.css new file mode 100644 index 0000000000..6a3e36eda1 --- /dev/null +++ b/synapse/res/templates/mail-Vector.css @@ -0,0 +1,7 @@ +.header { + border-bottom: 4px solid #e4f7ed ! important; +} + +.notif_link a, .footer a { + color: #76CFA6 ! important; +} diff --git a/synapse/res/templates/mail.css b/synapse/res/templates/mail.css new file mode 100644 index 0000000000..5ab3e1b06d --- /dev/null +++ b/synapse/res/templates/mail.css @@ -0,0 +1,156 @@ +body { + margin: 0px; +} + +pre, code { + word-break: break-word; + white-space: pre-wrap; +} + +#page { + font-family: 'Open Sans', Helvetica, Arial, Sans-Serif; + font-color: #454545; + font-size: 12pt; + width: 100%; + padding: 20px; +} + +#inner { + width: 640px; +} + +.header { + width: 100%; + height: 87px; + color: #454545; + border-bottom: 4px solid #e5e5e5; +} + +.logo { + text-align: right; + margin-left: 20px; +} + +.salutation { + padding-top: 10px; + font-weight: bold; +} + +.summarytext { +} + +.room { + width: 100%; + color: #454545; + border-bottom: 1px solid #e5e5e5; +} + +.room_header td { + padding-top: 38px; + padding-bottom: 10px; + border-bottom: 1px solid #e5e5e5; +} + +.room_name { + vertical-align: middle; + font-size: 18px; + font-weight: bold; +} + +.room_header h2 { + margin-top: 0px; + margin-left: 75px; + font-size: 20px; +} + +.room_avatar { + width: 56px; + line-height: 0px; + text-align: center; + vertical-align: middle; +} + +.room_avatar img { + width: 48px; + height: 48px; + object-fit: cover; + border-radius: 24px; +} + +.notif { + border-bottom: 1px solid #e5e5e5; + margin-top: 16px; + padding-bottom: 16px; +} + +.historical_message .sender_avatar { + opacity: 0.3; +} + +/* spell out opacity and historical_message class names for Outlook aka Word */ +.historical_message .sender_name { + color: #e3e3e3; +} + +.historical_message .message_time { + color: #e3e3e3; +} + +.historical_message .message_body { + color: #c7c7c7; +} + +.historical_message td, +.message td { + padding-top: 10px; +} + +.sender_avatar { + width: 56px; + text-align: center; + vertical-align: top; +} + +.sender_avatar img { + margin-top: -2px; + width: 32px; + height: 32px; + border-radius: 16px; +} + +.sender_name { + display: inline; + font-size: 13px; + color: #a2a2a2; +} + +.message_time { + text-align: right; + width: 100px; + font-size: 11px; + color: #a2a2a2; +} + +.message_body { +} + +.notif_link td { + padding-top: 10px; + padding-bottom: 10px; + font-weight: bold; +} + +.notif_link a, .footer a { + color: #454545; + text-decoration: none; +} + +.debug { + font-size: 10px; + color: #888; +} + +.footer { + margin-top: 20px; + text-align: center; +} \ No newline at end of file diff --git a/synapse/res/templates/notif.html b/synapse/res/templates/notif.html new file mode 100644 index 0000000000..88b921ca9c --- /dev/null +++ b/synapse/res/templates/notif.html @@ -0,0 +1,45 @@ +{% for message in notif.messages %} + <tr class="{{ "historical_message" if message.is_historical else "message" }}"> + <td class="sender_avatar"> + {% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %} + {% if message.sender_avatar_url %} + <img alt="" class="sender_avatar" src="{{ message.sender_avatar_url|mxc_to_http(32,32) }}" /> + {% else %} + {% if message.sender_hash % 3 == 0 %} + <img class="sender_avatar" src="https://vector.im/beta/img/76cfa6.png" /> + {% elif message.sender_hash % 3 == 1 %} + <img class="sender_avatar" src="https://vector.im/beta/img/50e2c2.png" /> + {% else %} + <img class="sender_avatar" src="https://vector.im/beta/img/f4c371.png" /> + {% endif %} + {% endif %} + {% endif %} + </td> + <td class="message_contents"> + {% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %} + <div class="sender_name">{% if message.msgtype == "m.emote" %}*{% endif %} {{ message.sender_name }}</div> + {% endif %} + <div class="message_body"> + {% if message.msgtype == "m.text" %} + {{ message.body_text_html }} + {% elif message.msgtype == "m.emote" %} + {{ message.body_text_html }} + {% elif message.msgtype == "m.notice" %} + {{ message.body_text_html }} + {% elif message.msgtype == "m.image" %} + <img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" /> + {% elif message.msgtype == "m.file" %} + <span class="filename">{{ message.body_text_plain }}</span> + {% endif %} + </div> + </td> + <td class="message_time">{{ message.ts|format_ts("%H:%M") }}</td> + </tr> +{% endfor %} +<tr class="notif_link"> + <td></td> + <td> + <a href="{{ notif.link }}">View {{ room.title }}</a> + </td> + <td></td> +</tr> diff --git a/synapse/res/templates/notif.txt b/synapse/res/templates/notif.txt new file mode 100644 index 0000000000..a37bee9833 --- /dev/null +++ b/synapse/res/templates/notif.txt @@ -0,0 +1,16 @@ +{% for message in notif.messages %} +{% if message.msgtype == "m.emote" %}* {% endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }}) +{% if message.msgtype == "m.text" %} +{{ message.body_text_plain }} +{% elif message.msgtype == "m.emote" %} +{{ message.body_text_plain }} +{% elif message.msgtype == "m.notice" %} +{{ message.body_text_plain }} +{% elif message.msgtype == "m.image" %} +{{ message.body_text_plain }} +{% elif message.msgtype == "m.file" %} +{{ message.body_text_plain }} +{% endif %} +{% endfor %} + +View {{ room.title }} at {{ notif.link }} diff --git a/synapse/res/templates/notif_mail.html b/synapse/res/templates/notif_mail.html new file mode 100644 index 0000000000..fcdb3109fe --- /dev/null +++ b/synapse/res/templates/notif_mail.html @@ -0,0 +1,55 @@ +<!doctype html> +<html lang="en"> + <head> + <style type="text/css"> + {% include 'mail.css' without context %} + {% include "mail-%s.css" % app_name ignore missing without context %} + </style> + </head> + <body> + <table id="page"> + <tr> + <td> </td> + <td id="inner"> + <table class="header"> + <tr> + <td> + <div class="salutation">Hi {{ user_display_name }},</div> + <div class="summarytext">{{ summary_text }}</div> + </td> + <td class="logo"> + {% if app_name == "Riot" %} + <img src="http://matrix.org/img/riot-logo-email.png" width="83" height="83" alt="[Riot]"/> + {% elif app_name == "Vector" %} + <img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/> + {% else %} + <img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/> + {% endif %} + </td> + </tr> + </table> + {% for room in rooms %} + {% include 'room.html' with context %} + {% endfor %} + <div class="footer"> + <a href="{{ unsubscribe_link }}">Unsubscribe</a> + <br/> + <br/> + <div class="debug"> + Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because + an event was received at {{ reason.received_at|format_ts("%c") }} + which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago, + {% if reason.last_sent_ts %} + and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }}, + which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago. + {% else %} + and we don't have a last time we sent a mail for this room. + {% endif %} + </div> + </div> + </td> + <td> </td> + </tr> + </table> + </body> +</html> diff --git a/synapse/res/templates/notif_mail.txt b/synapse/res/templates/notif_mail.txt new file mode 100644 index 0000000000..24843042a5 --- /dev/null +++ b/synapse/res/templates/notif_mail.txt @@ -0,0 +1,10 @@ +Hi {{ user_display_name }}, + +{{ summary_text }} + +{% for room in rooms %} +{% include 'room.txt' with context %} +{% endfor %} + +You can disable these notifications at {{ unsubscribe_link }} + diff --git a/synapse/res/templates/room.html b/synapse/res/templates/room.html new file mode 100644 index 0000000000..723c222d25 --- /dev/null +++ b/synapse/res/templates/room.html @@ -0,0 +1,33 @@ +<table class="room"> + <tr class="room_header"> + <td class="room_avatar"> + {% if room.avatar_url %} + <img alt="" src="{{ room.avatar_url|mxc_to_http(48,48) }}" /> + {% else %} + {% if room.hash % 3 == 0 %} + <img alt="" src="https://vector.im/beta/img/76cfa6.png" /> + {% elif room.hash % 3 == 1 %} + <img alt="" src="https://vector.im/beta/img/50e2c2.png" /> + {% else %} + <img alt="" src="https://vector.im/beta/img/f4c371.png" /> + {% endif %} + {% endif %} + </td> + <td class="room_name" colspan="2"> + {{ room.title }} + </td> + </tr> + {% if room.invite %} + <tr> + <td></td> + <td> + <a href="{{ room.link }}">Join the conversation.</a> + </td> + <td></td> + </tr> + {% else %} + {% for notif in room.notifs %} + {% include 'notif.html' with context %} + {% endfor %} + {% endif %} +</table> diff --git a/synapse/res/templates/room.txt b/synapse/res/templates/room.txt new file mode 100644 index 0000000000..84648c710e --- /dev/null +++ b/synapse/res/templates/room.txt @@ -0,0 +1,9 @@ +{{ room.title }} + +{% if room.invite %} + You've been invited, join at {{ room.link }} +{% else %} + {% for notif in room.notifs %} + {% include 'notif.txt' with context %} + {% endfor %} +{% endif %} diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 3418f06fd6..5f35c2d1be 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -46,6 +46,8 @@ from synapse.rest.client.v2_alpha import ( receipts, register, report_event, + room_keys, + room_upgrade_rest_servlet, sendtodevice, sync, tags, @@ -102,6 +104,7 @@ class ClientRestResource(JsonResource): auth.register_servlets(hs, client_resource) receipts.register_servlets(hs, client_resource) read_marker.register_servlets(hs, client_resource) + room_keys.register_servlets(hs, client_resource) keys.register_servlets(hs, client_resource) tokenrefresh.register_servlets(hs, client_resource) tags.register_servlets(hs, client_resource) @@ -114,3 +117,4 @@ class ClientRestResource(JsonResource): sendtodevice.register_servlets(hs, client_resource) user_directory.register_servlets(hs, client_resource) groups.register_servlets(hs, client_resource) + room_upgrade_rest_servlet.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 97733f3026..0220acf644 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -74,38 +74,11 @@ class ClientDirectoryServer(ClientV1RestServlet): if room is None: raise SynapseError(400, "Room does not exist") - dir_handler = self.handlers.directory_handler + requester = yield self.auth.get_user_by_req(request) - try: - # try to auth as a user - requester = yield self.auth.get_user_by_req(request) - try: - user_id = requester.user.to_string() - yield dir_handler.create_association( - user_id, room_alias, room_id, servers - ) - yield dir_handler.send_room_alias_update_event( - requester, - user_id, - room_id - ) - except SynapseError as e: - raise e - except Exception: - logger.exception("Failed to create association") - raise - except AuthError: - # try to auth as an application service - service = yield self.auth.get_appservice_by_req(request) - yield dir_handler.create_appservice_association( - service, room_alias, room_id, servers - ) - logger.info( - "Application service at %s created alias %s pointing to %s", - service.url, - room_alias.to_string(), - room_id - ) + yield self.handlers.directory_handler.create_association( + requester, room_alias, room_id, servers + ) defer.returnValue((200, {})) @@ -135,7 +108,7 @@ class ClientDirectoryServer(ClientV1RestServlet): room_alias = RoomAlias.from_string(room_alias) yield dir_handler.delete_association( - requester, user.to_string(), room_alias + requester, room_alias ) logger.info( diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 663934efd0..fcfe7857f6 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -33,6 +33,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) +from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID @@ -409,7 +410,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet): room_id=room_id, user_id=requester.user.to_string(), at_token=at_token, - types=[(EventTypes.Member, None)], + state_filter=StateFilter.from_types([(EventTypes.Member, None)]), ) chunk = [] diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index bd8b5f4afa..a8d8ed6590 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -68,6 +68,29 @@ function captchaDone() { </html> """ +TERMS_TEMPLATE = """ +<html> +<head> +<title>Authentication</title> +<meta name='viewport' content='width=device-width, initial-scale=1, + user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> +<link rel="stylesheet" href="/_matrix/static/client/register/style.css"> +</head> +<body> +<form id="registrationForm" method="post" action="%(myurl)s"> + <div> + <p> + Please click the button below if you agree to the + <a href="%(terms_url)s">privacy policy of this homeserver.</a> + </p> + <input type="hidden" name="session" value="%(session)s" /> + <input type="submit" value="Agree" /> + </div> +</form> +</body> +</html> +""" + SUCCESS_TEMPLATE = """ <html> <head> @@ -99,7 +122,7 @@ class AuthRestServlet(RestServlet): cannot be handled in the normal flow (with requests to the same endpoint). Current use is for web fallback auth. """ - PATTERNS = client_v2_patterns("/auth/(?P<stagetype>[\w\.]*)/fallback/web") + PATTERNS = client_v2_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web") def __init__(self, hs): super(AuthRestServlet, self).__init__() @@ -133,13 +156,34 @@ class AuthRestServlet(RestServlet): request.write(html_bytes) finish_request(request) defer.returnValue(None) + elif stagetype == LoginType.TERMS: + session = request.args['session'][0] + + html = TERMS_TEMPLATE % { + 'session': session, + 'terms_url': "%s/_matrix/consent?v=%s" % ( + self.hs.config.public_baseurl, + self.hs.config.user_consent_version, + ), + 'myurl': "%s/auth/%s/fallback/web" % ( + CLIENT_V2_ALPHA_PREFIX, LoginType.TERMS + ), + } + html_bytes = html.encode("utf8") + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + + request.write(html_bytes) + finish_request(request) + defer.returnValue(None) else: raise SynapseError(404, "Unknown auth stage type") @defer.inlineCallbacks def on_POST(self, request, stagetype): yield - if stagetype == "m.login.recaptcha": + if stagetype == LoginType.RECAPTCHA: if ('g-recaptcha-response' not in request.args or len(request.args['g-recaptcha-response'])) == 0: raise SynapseError(400, "No captcha response supplied") @@ -179,6 +223,41 @@ class AuthRestServlet(RestServlet): finish_request(request) defer.returnValue(None) + elif stagetype == LoginType.TERMS: + if ('session' not in request.args or + len(request.args['session'])) == 0: + raise SynapseError(400, "No session supplied") + + session = request.args['session'][0] + authdict = {'session': session} + + success = yield self.auth_handler.add_oob_auth( + LoginType.TERMS, + authdict, + self.hs.get_ip_from_request(request) + ) + + if success: + html = SUCCESS_TEMPLATE + else: + html = TERMS_TEMPLATE % { + 'session': session, + 'terms_url': "%s/_matrix/consent?v=%s" % ( + self.hs.config.public_baseurl, + self.hs.config.user_consent_version, + ), + 'myurl': "%s/auth/%s/fallback/web" % ( + CLIENT_V2_ALPHA_PREFIX, LoginType.TERMS + ), + } + html_bytes = html.encode("utf8") + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + + request.write(html_bytes) + finish_request(request) + defer.returnValue(None) else: raise SynapseError(404, "Unknown auth stage type") diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 192f52e462..c5214330ad 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -359,6 +359,13 @@ class RegisterRestServlet(RestServlet): [LoginType.MSISDN, LoginType.EMAIL_IDENTITY] ]) + # Append m.login.terms to all flows if we're requiring consent + if self.hs.config.block_events_without_consent_error is not None: + new_flows = [] + for flow in flows: + flow.append(LoginType.TERMS) + flows.extend(new_flows) + auth_result, params, session_id = yield self.auth_handler.check_auth( flows, body, self.hs.get_ip_from_request(request) ) @@ -445,6 +452,12 @@ class RegisterRestServlet(RestServlet): params.get("bind_msisdn") ) + if auth_result and LoginType.TERMS in auth_result: + logger.info("%s has consented to the privacy policy" % registered_user_id) + yield self.store.user_set_consent_version( + registered_user_id, self.hs.config.user_consent_version, + ) + defer.returnValue((200, return_dict)) def on_OPTIONS(self, _): diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py new file mode 100644 index 0000000000..45b5817d8b --- /dev/null +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -0,0 +1,372 @@ +# -*- coding: utf-8 -*- +# Copyright 2017, 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import ( + RestServlet, + parse_json_object_from_request, + parse_string, +) + +from ._base import client_v2_patterns + +logger = logging.getLogger(__name__) + + +class RoomKeysServlet(RestServlet): + PATTERNS = client_v2_patterns( + "/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$" + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(RoomKeysServlet, self).__init__() + self.auth = hs.get_auth() + self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, room_id, session_id): + """ + Uploads one or more encrypted E2E room keys for backup purposes. + room_id: the ID of the room the keys are for (optional) + session_id: the ID for the E2E room keys for the room (optional) + version: the version of the user's backup which this data is for. + the version must already have been created via the /room_keys/version API. + + Each session has: + * first_message_index: a numeric index indicating the oldest message + encrypted by this session. + * forwarded_count: how many times the uploading client claims this key + has been shared (forwarded) + * is_verified: whether the client that uploaded the keys claims they + were sent by a device which they've verified + * session_data: base64-encrypted data describing the session. + + Returns 200 OK on success with body {} + Returns 403 Forbidden if the version in question is not the most recently + created version (i.e. if this is an old client trying to write to a stale backup) + Returns 404 Not Found if the version in question doesn't exist + + The API is designed to be otherwise agnostic to the room_key encryption + algorithm being used. Sessions are merged with existing ones in the + backup using the heuristics: + * is_verified sessions always win over unverified sessions + * older first_message_index always win over newer sessions + * lower forwarded_count always wins over higher forwarded_count + + We trust the clients not to lie and corrupt their own backups. + It also means that if your access_token is stolen, the attacker could + delete your backup. + + POST /room_keys/keys/!abc:matrix.org/c0ff33?version=1 HTTP/1.1 + Content-Type: application/json + + { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + + Or... + + POST /room_keys/keys/!abc:matrix.org?version=1 HTTP/1.1 + Content-Type: application/json + + { + "sessions": { + "c0ff33": { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + } + } + + Or... + + POST /room_keys/keys?version=1 HTTP/1.1 + Content-Type: application/json + + { + "rooms": { + "!abc:matrix.org": { + "sessions": { + "c0ff33": { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + } + } + } + } + """ + requester = yield self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + body = parse_json_object_from_request(request) + version = parse_string(request, "version") + + if session_id: + body = { + "sessions": { + session_id: body + } + } + + if room_id: + body = { + "rooms": { + room_id: body + } + } + + yield self.e2e_room_keys_handler.upload_room_keys( + user_id, version, body + ) + defer.returnValue((200, {})) + + @defer.inlineCallbacks + def on_GET(self, request, room_id, session_id): + """ + Retrieves one or more encrypted E2E room keys for backup purposes. + Symmetric with the PUT version of the API. + + room_id: the ID of the room to retrieve the keys for (optional) + session_id: the ID for the E2E room keys to retrieve the keys for (optional) + version: the version of the user's backup which this data is for. + the version must already have been created via the /change_secret API. + + Returns as follows: + + GET /room_keys/keys/!abc:matrix.org/c0ff33?version=1 HTTP/1.1 + { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + + Or... + + GET /room_keys/keys/!abc:matrix.org?version=1 HTTP/1.1 + { + "sessions": { + "c0ff33": { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + } + } + + Or... + + GET /room_keys/keys?version=1 HTTP/1.1 + { + "rooms": { + "!abc:matrix.org": { + "sessions": { + "c0ff33": { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + } + } + } + } + """ + requester = yield self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + version = parse_string(request, "version") + + room_keys = yield self.e2e_room_keys_handler.get_room_keys( + user_id, version, room_id, session_id + ) + + if session_id: + room_keys = room_keys['rooms'][room_id]['sessions'][session_id] + elif room_id: + room_keys = room_keys['rooms'][room_id] + + defer.returnValue((200, room_keys)) + + @defer.inlineCallbacks + def on_DELETE(self, request, room_id, session_id): + """ + Deletes one or more encrypted E2E room keys for a user for backup purposes. + + DELETE /room_keys/keys/!abc:matrix.org/c0ff33?version=1 + HTTP/1.1 200 OK + {} + + room_id: the ID of the room whose keys to delete (optional) + session_id: the ID for the E2E session to delete (optional) + version: the version of the user's backup which this data is for. + the version must already have been created via the /change_secret API. + """ + + requester = yield self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + version = parse_string(request, "version") + + yield self.e2e_room_keys_handler.delete_room_keys( + user_id, version, room_id, session_id + ) + defer.returnValue((200, {})) + + +class RoomKeysNewVersionServlet(RestServlet): + PATTERNS = client_v2_patterns( + "/room_keys/version$" + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(RoomKeysNewVersionServlet, self).__init__() + self.auth = hs.get_auth() + self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() + + @defer.inlineCallbacks + def on_POST(self, request): + """ + Create a new backup version for this user's room_keys with the given + info. The version is allocated by the server and returned to the user + in the response. This API is intended to be used whenever the user + changes the encryption key for their backups, ensuring that backups + encrypted with different keys don't collide. + + It takes out an exclusive lock on this user's room_key backups, to ensure + clients only upload to the current backup. + + The algorithm passed in the version info is a reverse-DNS namespaced + identifier to describe the format of the encrypted backupped keys. + + The auth_data is { user_id: "user_id", nonce: <random string> } + encrypted using the algorithm and current encryption key described above. + + POST /room_keys/version + Content-Type: application/json + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" + } + + HTTP/1.1 200 OK + Content-Type: application/json + { + "version": 12345 + } + """ + requester = yield self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + info = parse_json_object_from_request(request) + + new_version = yield self.e2e_room_keys_handler.create_version( + user_id, info + ) + defer.returnValue((200, {"version": new_version})) + + # we deliberately don't have a PUT /version, as these things really should + # be immutable to avoid people footgunning + + +class RoomKeysVersionServlet(RestServlet): + PATTERNS = client_v2_patterns( + "/room_keys/version(/(?P<version>[^/]+))?$" + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(RoomKeysVersionServlet, self).__init__() + self.auth = hs.get_auth() + self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() + + @defer.inlineCallbacks + def on_GET(self, request, version): + """ + Retrieve the version information about a given version of the user's + room_keys backup. If the version part is missing, returns info about the + most current backup version (if any) + + It takes out an exclusive lock on this user's room_key backups, to ensure + clients only upload to the current backup. + + Returns 404 if the given version does not exist. + + GET /room_keys/version/12345 HTTP/1.1 + { + "version": "12345", + "algorithm": "m.megolm_backup.v1", + "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" + } + """ + requester = yield self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + + try: + info = yield self.e2e_room_keys_handler.get_version_info( + user_id, version + ) + except SynapseError as e: + if e.code == 404: + raise SynapseError(404, "No backup found", Codes.NOT_FOUND) + defer.returnValue((200, info)) + + @defer.inlineCallbacks + def on_DELETE(self, request, version): + """ + Delete the information about a given version of the user's + room_keys backup. If the version part is missing, deletes the most + current backup version (if any). Doesn't delete the actual room data. + + DELETE /room_keys/version/12345 HTTP/1.1 + HTTP/1.1 200 OK + {} + """ + if version is None: + raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND) + + requester = yield self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + + yield self.e2e_room_keys_handler.delete_version( + user_id, version + ) + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + RoomKeysServlet(hs).register(http_server) + RoomKeysVersionServlet(hs).register(http_server) + RoomKeysNewVersionServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py new file mode 100644 index 0000000000..e6356101fd --- /dev/null +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.api.constants import KNOWN_ROOM_VERSIONS +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) + +from ._base import client_v2_patterns + +logger = logging.getLogger(__name__) + + +class RoomUpgradeRestServlet(RestServlet): + """Handler for room uprade requests. + + Handles requests of the form: + + POST /_matrix/client/r0/rooms/$roomid/upgrade HTTP/1.1 + Content-Type: application/json + + { + "new_version": "2", + } + + Creates a new room and shuts down the old one. Returns the ID of the new room. + + Args: + hs (synapse.server.HomeServer): + """ + PATTERNS = client_v2_patterns( + # /rooms/$roomid/upgrade + "/rooms/(?P<room_id>[^/]*)/upgrade$", + v2_alpha=False, + ) + + def __init__(self, hs): + super(RoomUpgradeRestServlet, self).__init__() + self._hs = hs + self._room_creation_handler = hs.get_room_creation_handler() + self._auth = hs.get_auth() + + @defer.inlineCallbacks + def on_POST(self, request, room_id): + requester = yield self._auth.get_user_by_req(request) + + content = parse_json_object_from_request(request) + assert_params_in_dict(content, ("new_version", )) + new_version = content["new_version"] + + if new_version not in KNOWN_ROOM_VERSIONS: + raise SynapseError( + 400, + "Your homeserver does not support this room version", + Codes.UNSUPPORTED_ROOM_VERSION, + ) + + new_room_id = yield self._room_creation_handler.upgrade_room( + requester, room_id, new_version + ) + + ret = { + "replacement_room": new_room_id, + } + + defer.returnValue((200, ret)) + + +def register_servlets(hs, http_server): + RoomUpgradeRestServlet(hs).register(http_server) diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 7362e1858d..89b82b0591 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -137,27 +137,31 @@ class ConsentResource(Resource): request (twisted.web.http.Request): """ - version = parse_string(request, "v", - default=self._default_consent_version) - username = parse_string(request, "u", required=True) - userhmac = parse_string(request, "h", required=True, encoding=None) - - self._check_hash(username, userhmac) - - if username.startswith('@'): - qualified_user_id = username - else: - qualified_user_id = UserID(username, self.hs.hostname).to_string() - - u = yield self.store.get_user_by_id(qualified_user_id) - if u is None: - raise NotFoundError("Unknown user") + version = parse_string(request, "v", default=self._default_consent_version) + username = parse_string(request, "u", required=False, default="") + userhmac = None + has_consented = False + public_version = username == "" + if not public_version: + userhmac = parse_string(request, "h", required=True, encoding=None) + + self._check_hash(username, userhmac) + + if username.startswith('@'): + qualified_user_id = username + else: + qualified_user_id = UserID(username, self.hs.hostname).to_string() + + u = yield self.store.get_user_by_id(qualified_user_id) + if u is None: + raise NotFoundError("Unknown user") + has_consented = u["consent_version"] == version try: self._render_template( request, "%s.html" % (version,), user=username, userhmac=userhmac, version=version, - has_consented=(u["consent_version"] == version), + has_consented=has_consented, public_version=public_version, ) except TemplateNotFound: raise NotFoundError("Unknown policy version") diff --git a/synapse/rest/key/v1/__init__.py b/synapse/rest/key/v1/__init__.py deleted file mode 100644 index fe0ac3f8e9..0000000000 --- a/synapse/rest/key/v1/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/synapse/rest/key/v1/server_key_resource.py b/synapse/rest/key/v1/server_key_resource.py deleted file mode 100644 index 38eb2ee23f..0000000000 --- a/synapse/rest/key/v1/server_key_resource.py +++ /dev/null @@ -1,92 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging - -from canonicaljson import encode_canonical_json -from signedjson.sign import sign_json -from unpaddedbase64 import encode_base64 - -from OpenSSL import crypto -from twisted.web.resource import Resource - -from synapse.http.server import respond_with_json_bytes - -logger = logging.getLogger(__name__) - - -class LocalKey(Resource): - """HTTP resource containing encoding the TLS X.509 certificate and NACL - signature verification keys for this server:: - - GET /key HTTP/1.1 - - HTTP/1.1 200 OK - Content-Type: application/json - { - "server_name": "this.server.example.com" - "verify_keys": { - "algorithm:version": # base64 encoded NACL verification key. - }, - "tls_certificate": # base64 ASN.1 DER encoded X.509 tls cert. - "signatures": { - "this.server.example.com": { - "algorithm:version": # NACL signature for this server. - } - } - } - """ - - def __init__(self, hs): - self.response_body = encode_canonical_json( - self.response_json_object(hs.config) - ) - Resource.__init__(self) - - @staticmethod - def response_json_object(server_config): - verify_keys = {} - for key in server_config.signing_key: - verify_key_bytes = key.verify_key.encode() - key_id = "%s:%s" % (key.alg, key.version) - verify_keys[key_id] = encode_base64(verify_key_bytes) - - x509_certificate_bytes = crypto.dump_certificate( - crypto.FILETYPE_ASN1, - server_config.tls_certificate - ) - json_object = { - u"server_name": server_config.server_name, - u"verify_keys": verify_keys, - u"tls_certificate": encode_base64(x509_certificate_bytes) - } - for key in server_config.signing_key: - json_object = sign_json( - json_object, - server_config.server_name, - key, - ) - - return json_object - - def render_GET(self, request): - return respond_with_json_bytes( - request, 200, self.response_body, - ) - - def getChild(self, name, request): - if name == b'': - return self diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index ca90964d1d..f911b120b1 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -52,6 +52,7 @@ class DownloadResource(Resource): b" script-src 'none';" b" plugin-types application/pdf;" b" style-src 'unsafe-inline';" + b" media-src 'self';" b" object-src 'self';" ) server_name, media_id, name = parse_media_id(request) diff --git a/synapse/rest/media/v1/identicon_resource.py b/synapse/rest/media/v1/identicon_resource.py deleted file mode 100644 index bdbd8d50dd..0000000000 --- a/synapse/rest/media/v1/identicon_resource.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from pydenticon import Generator - -from twisted.web.resource import Resource - -from synapse.http.servlet import parse_integer - -FOREGROUND = [ - "rgb(45,79,255)", - "rgb(254,180,44)", - "rgb(226,121,234)", - "rgb(30,179,253)", - "rgb(232,77,65)", - "rgb(49,203,115)", - "rgb(141,69,170)" -] - -BACKGROUND = "rgb(224,224,224)" -SIZE = 5 - - -class IdenticonResource(Resource): - isLeaf = True - - def __init__(self): - Resource.__init__(self) - self.generator = Generator( - SIZE, SIZE, foreground=FOREGROUND, background=BACKGROUND, - ) - - def generate_identicon(self, name, width, height): - v_padding = width % SIZE - h_padding = height % SIZE - top_padding = v_padding // 2 - left_padding = h_padding // 2 - bottom_padding = v_padding - top_padding - right_padding = h_padding - left_padding - width -= v_padding - height -= h_padding - padding = (top_padding, bottom_padding, left_padding, right_padding) - identicon = self.generator.generate( - name, width, height, padding=padding - ) - return identicon - - def render_GET(self, request): - name = "/".join(request.postpath) - width = parse_integer(request, "width", default=96) - height = parse_integer(request, "height", default=96) - identicon_bytes = self.generate_identicon(name, width, height) - request.setHeader(b"Content-Type", b"image/png") - request.setHeader( - b"Cache-Control", b"public,max-age=86400,s-maxage=86400" - ) - return identicon_bytes diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index a828ff4438..d6c5f07af0 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -25,7 +25,7 @@ from six.moves.urllib import parse as urlparse import twisted.internet.error import twisted.web.http -from twisted.internet import defer, threads +from twisted.internet import defer from twisted.web.resource import Resource from synapse.api.errors import ( @@ -36,8 +36,8 @@ from synapse.api.errors import ( ) from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util import logcontext from synapse.util.async_helpers import Linearizer -from synapse.util.logcontext import make_deferred_yieldable from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import is_ascii, random_string @@ -45,7 +45,6 @@ from ._base import FileInfo, respond_404, respond_with_responder from .config_resource import MediaConfigResource from .download_resource import DownloadResource from .filepath import MediaFilePaths -from .identicon_resource import IdenticonResource from .media_storage import MediaStorage from .preview_url_resource import PreviewUrlResource from .storage_provider import StorageProviderWrapper @@ -492,10 +491,11 @@ class MediaRepository(object): )) thumbnailer = Thumbnailer(input_path) - t_byte_source = yield make_deferred_yieldable(threads.deferToThread( + t_byte_source = yield logcontext.defer_to_thread( + self.hs.get_reactor(), self._generate_thumbnail, thumbnailer, t_width, t_height, t_method, t_type - )) + ) if t_byte_source: try: @@ -534,10 +534,11 @@ class MediaRepository(object): )) thumbnailer = Thumbnailer(input_path) - t_byte_source = yield make_deferred_yieldable(threads.deferToThread( + t_byte_source = yield logcontext.defer_to_thread( + self.hs.get_reactor(), self._generate_thumbnail, thumbnailer, t_width, t_height, t_method, t_type - )) + ) if t_byte_source: try: @@ -620,15 +621,17 @@ class MediaRepository(object): for (t_width, t_height, t_type), t_method in iteritems(thumbnails): # Generate the thumbnail if t_method == "crop": - t_byte_source = yield make_deferred_yieldable(threads.deferToThread( + t_byte_source = yield logcontext.defer_to_thread( + self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type, - )) + ) elif t_method == "scale": - t_byte_source = yield make_deferred_yieldable(threads.deferToThread( + t_byte_source = yield logcontext.defer_to_thread( + self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type, - )) + ) else: logger.error("Unrecognized method: %r", t_method) continue @@ -765,7 +768,6 @@ class MediaRepositoryResource(Resource): self.putChild(b"thumbnail", ThumbnailResource( hs, media_repo, media_repo.media_storage, )) - self.putChild(b"identicon", IdenticonResource()) if hs.config.url_preview_enabled: self.putChild(b"preview_url", PreviewUrlResource( hs, media_repo, media_repo.media_storage, diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index a6189224ee..896078fe76 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -21,9 +21,10 @@ import sys import six -from twisted.internet import defer, threads +from twisted.internet import defer from twisted.protocols.basic import FileSender +from synapse.util import logcontext from synapse.util.file_consumer import BackgroundFileConsumer from synapse.util.logcontext import make_deferred_yieldable @@ -64,9 +65,10 @@ class MediaStorage(object): with self.store_into_file(file_info) as (f, fname, finish_cb): # Write to the main repository - yield make_deferred_yieldable(threads.deferToThread( + yield logcontext.defer_to_thread( + self.hs.get_reactor(), _write_file_synchronously, source, f, - )) + ) yield finish_cb() defer.returnValue(fname) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index cad2dec33a..1a7bfd6b56 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -79,7 +79,6 @@ class PreviewUrlResource(Resource): # don't spider URLs more often than once an hour expiry_ms=60 * 60 * 1000, ) - self._cache.start() self._cleaner_loop = self.clock.looping_call( self._start_expire_url_cache_data, 10 * 1000, @@ -597,10 +596,13 @@ def _iterate_over_text(tree, *tags_to_ignore): # to be returned. elements = iter([tree]) while True: - el = next(elements) + el = next(elements, None) + if el is None: + return + if isinstance(el, string_types): yield el - elif el is not None and el.tag not in tags_to_ignore: + elif el.tag not in tags_to_ignore: # el.text is the text before the first child, so we can immediately # return it if the text exists. if el.text: @@ -672,7 +674,7 @@ def summarize_paragraphs(text_nodes, min_size=200, max_size=500): # This splits the paragraph into words, but keeping the # (preceeding) whitespace intact so we can easily concat # words back together. - for match in re.finditer("\s*\S+", description): + for match in re.finditer(r"\s*\S+", description): word = match.group() # Keep adding words while the total length is less than diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 7b9f8b4d79..5aa03031f6 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -17,9 +17,10 @@ import logging import os import shutil -from twisted.internet import defer, threads +from twisted.internet import defer from synapse.config._base import Config +from synapse.util import logcontext from synapse.util.logcontext import run_in_background from .media_storage import FileResponder @@ -120,7 +121,8 @@ class FileStorageProviderBackend(StorageProvider): if not os.path.exists(dirname): os.makedirs(dirname) - return threads.deferToThread( + return logcontext.defer_to_thread( + self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname, ) diff --git a/synapse/server.py b/synapse/server.py index 938a05f9dc..9985687b95 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -23,6 +23,7 @@ import abc import logging from twisted.enterprise import adbapi +from twisted.mail.smtp import sendmail from twisted.web.client import BrowserLikePolicyForHTTPS from synapse.api.auth import Auth @@ -51,6 +52,7 @@ from synapse.handlers.deactivate_account import DeactivateAccountHandler from synapse.handlers.device import DeviceHandler from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.e2e_keys import E2eKeysHandler +from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.groups_local import GroupsLocalHandler from synapse.handlers.initial_sync import InitialSyncHandler @@ -130,6 +132,7 @@ class HomeServer(object): 'auth_handler', 'device_handler', 'e2e_keys_handler', + 'e2e_room_keys_handler', 'event_handler', 'event_stream_handler', 'initial_sync_handler', @@ -172,6 +175,7 @@ class HomeServer(object): 'message_handler', 'pagination_handler', 'room_context_handler', + 'sendmail', ] # This is overridden in derived application classes @@ -205,6 +209,7 @@ class HomeServer(object): logger.info("Setting up.") with self.get_db_conn() as conn: self.datastore = self.DATASTORE_CLASS(conn, self) + conn.commit() logger.info("Finished setting up.") def get_reactor(self): @@ -266,6 +271,9 @@ class HomeServer(object): def build_room_creation_handler(self): return RoomCreationHandler(self) + def build_sendmail(self): + return sendmail + def build_state_handler(self): return StateHandler(self) @@ -299,6 +307,9 @@ class HomeServer(object): def build_e2e_keys_handler(self): return E2eKeysHandler(self) + def build_e2e_room_keys_handler(self): + return E2eRoomKeysHandler(self) + def build_application_service_api(self): return ApplicationServiceApi(self) diff --git a/synapse/server.pyi b/synapse/server.pyi index ce28486233..06cd083a74 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -7,6 +7,9 @@ import synapse.handlers.auth import synapse.handlers.deactivate_account import synapse.handlers.device import synapse.handlers.e2e_keys +import synapse.handlers.room +import synapse.handlers.room_member +import synapse.handlers.message import synapse.handlers.set_password import synapse.rest.media.v1.media_repository import synapse.server_notices.server_notices_manager @@ -50,6 +53,9 @@ class HomeServer(object): def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler: pass + def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler: + pass + def get_event_creation_handler(self) -> synapse.handlers.message.EventCreationHandler: pass diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index d7ae22a661..943d5d6bb5 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -19,13 +19,14 @@ from collections import namedtuple from six import iteritems, itervalues +import attr from frozendict import frozendict from twisted.internet import defer from synapse.api.constants import EventTypes, RoomVersions from synapse.events.snapshot import EventContext -from synapse.state import v1 +from synapse.state import v1, v2 from synapse.util.async_helpers import Linearizer from synapse.util.caches import get_cache_factor_for from synapse.util.caches.expiringcache import ExpiringCache @@ -95,10 +96,6 @@ class StateHandler(object): self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() - def start_caching(self): - # TODO: remove this shim - self._state_resolution_handler.start_caching() - @defer.inlineCallbacks def get_current_state(self, room_id, event_type=None, state_key="", latest_event_ids=None): @@ -376,15 +373,10 @@ class StateHandler(object): result = yield self._state_resolution_handler.resolve_state_groups( room_id, room_version, state_groups_ids, None, - self._state_map_factory, + state_res_store=StateResolutionStore(self.store), ) defer.returnValue(result) - def _state_map_factory(self, ev_ids): - return self.store.get_events( - ev_ids, get_prev_content=False, check_redacted=False, - ) - @defer.inlineCallbacks def resolve_events(self, room_version, state_sets, event): logger.info( @@ -402,10 +394,10 @@ class StateHandler(object): } with Measure(self.clock, "state._resolve_events"): - new_state = yield resolve_events_with_factory( + new_state = yield resolve_events_with_store( room_version, state_set_ids, event_map=state_map, - state_map_factory=self._state_map_factory + state_res_store=StateResolutionStore(self.store), ) new_state = { @@ -428,9 +420,6 @@ class StateResolutionHandler(object): self._state_cache = None self.resolve_linearizer = Linearizer(name="state_resolve_lock") - def start_caching(self): - logger.debug("start_caching") - self._state_cache = ExpiringCache( cache_name="state_cache", clock=self.clock, @@ -440,12 +429,10 @@ class StateResolutionHandler(object): reset_expiry_on_get=True, ) - self._state_cache.start() - @defer.inlineCallbacks @log_function def resolve_state_groups( - self, room_id, room_version, state_groups_ids, event_map, state_map_factory, + self, room_id, room_version, state_groups_ids, event_map, state_res_store, ): """Resolves conflicts between a set of state groups @@ -463,9 +450,11 @@ class StateResolutionHandler(object): a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing - events will be requested via state_map_factory. + events will be requested via state_res_store. - If None, all events will be fetched via state_map_factory. + If None, all events will be fetched via state_res_store. + + state_res_store (StateResolutionStore) Returns: Deferred[_StateCacheEntry]: resolved state @@ -489,10 +478,10 @@ class StateResolutionHandler(object): # start by assuming we won't have any conflicted state, and build up the new # state map by iterating through the state groups. If we discover a conflict, - # we give up and instead use `resolve_events_with_factory`. + # we give up and instead use `resolve_events_with_store`. # # XXX: is this actually worthwhile, or should we just let - # resolve_events_with_factory do it? + # resolve_events_with_store do it? new_state = {} conflicted_state = False for st in itervalues(state_groups_ids): @@ -507,11 +496,11 @@ class StateResolutionHandler(object): if conflicted_state: logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): - new_state = yield resolve_events_with_factory( + new_state = yield resolve_events_with_store( room_version, list(itervalues(state_groups_ids)), event_map=event_map, - state_map_factory=state_map_factory, + state_res_store=state_res_store, ) # if the new state matches any of the input state groups, we can @@ -592,7 +581,7 @@ def _make_state_cache_entry( ) -def resolve_events_with_factory(room_version, state_sets, event_map, state_map_factory): +def resolve_events_with_store(room_version, state_sets, event_map, state_res_store): """ Args: room_version(str): Version of the room @@ -608,17 +597,19 @@ def resolve_events_with_factory(room_version, state_sets, event_map, state_map_f If None, all events will be fetched via state_map_factory. - state_map_factory(func): will be called - with a list of event_ids that are needed, and should return with - a Deferred of dict of event_id to event. + state_res_store (StateResolutionStore) Returns Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. """ - if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,): - return v1.resolve_events_with_factory( - state_sets, event_map, state_map_factory, + if room_version == RoomVersions.V1: + return v1.resolve_events_with_store( + state_sets, event_map, state_res_store.get_events, + ) + elif room_version in (RoomVersions.VDH_TEST, RoomVersions.STATE_V2_TEST): + return v2.resolve_events_with_store( + state_sets, event_map, state_res_store, ) else: # This should only happen if we added a version but forgot to add it to @@ -626,3 +617,54 @@ def resolve_events_with_factory(room_version, state_sets, event_map, state_map_f raise Exception( "No state resolution algorithm defined for version %r" % (room_version,) ) + + +@attr.s +class StateResolutionStore(object): + """Interface that allows state resolution algorithms to access the database + in well defined way. + + Args: + store (DataStore) + """ + + store = attr.ib() + + def get_events(self, event_ids, allow_rejected=False): + """Get events from the database + + Args: + event_ids (list): The event_ids of the events to fetch + allow_rejected (bool): If True return rejected events. + + Returns: + Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. + """ + + return self.store.get_events( + event_ids, + check_redacted=False, + get_prev_content=False, + allow_rejected=allow_rejected, + ) + + def get_auth_chain(self, event_ids): + """Gets the full auth chain for a set of events (including rejected + events). + + Includes the given event IDs in the result. + + Note that: + 1. All events must be state events. + 2. For v1 rooms this may not have the full auth chain in the + presence of rejected events + + Args: + event_ids (list): The event IDs of the events to fetch the auth + chain for. Must be state events. + + Returns: + Deferred[list[str]]: List of event IDs of the auth chain. + """ + + return self.store.get_auth_chain_ids(event_ids, include_given=True) diff --git a/synapse/state/v1.py b/synapse/state/v1.py index c95477d318..70a981f4a2 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -31,7 +31,7 @@ POWER_KEY = (EventTypes.PowerLevels, "") @defer.inlineCallbacks -def resolve_events_with_factory(state_sets, event_map, state_map_factory): +def resolve_events_with_store(state_sets, event_map, state_map_factory): """ Args: state_sets(list): List of dicts of (type, state_key) -> event_id, @@ -65,10 +65,15 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory): for event_ids in itervalues(conflicted_state) for event_id in event_ids ) + needed_event_count = len(needed_events) if event_map is not None: needed_events -= set(iterkeys(event_map)) - logger.info("Asking for %d conflicted events", len(needed_events)) + logger.info( + "Asking for %d/%d conflicted events", + len(needed_events), + needed_event_count, + ) # dict[str, FrozenEvent]: a map from state event id to event. Only includes # the state events which are in conflict (and those in event_map) @@ -85,11 +90,16 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory): ) new_needed_events = set(itervalues(auth_events)) + new_needed_event_count = len(new_needed_events) new_needed_events -= needed_events if event_map is not None: new_needed_events -= set(iterkeys(event_map)) - logger.info("Asking for %d auth events", len(new_needed_events)) + logger.info( + "Asking for %d/%d auth events", + len(new_needed_events), + new_needed_event_count, + ) state_map_new = yield state_map_factory(new_needed_events) state_map.update(state_map_new) diff --git a/synapse/state/v2.py b/synapse/state/v2.py new file mode 100644 index 0000000000..dbc9688c56 --- /dev/null +++ b/synapse/state/v2.py @@ -0,0 +1,548 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import heapq +import itertools +import logging + +from six import iteritems, itervalues + +from twisted.internet import defer + +from synapse import event_auth +from synapse.api.constants import EventTypes +from synapse.api.errors import AuthError + +logger = logging.getLogger(__name__) + + +@defer.inlineCallbacks +def resolve_events_with_store(state_sets, event_map, state_res_store): + """Resolves the state using the v2 state resolution algorithm + + Args: + state_sets(list): List of dicts of (type, state_key) -> event_id, + which are the different state groups to resolve. + + event_map(dict[str,FrozenEvent]|None): + a dict from event_id to event, for any events that we happen to + have in flight (eg, those currently being persisted). This will be + used as a starting point fof finding the state we need; any missing + events will be requested via state_res_store. + + If None, all events will be fetched via state_res_store. + + state_res_store (StateResolutionStore) + + Returns + Deferred[dict[(str, str), str]]: + a map from (type, state_key) to event_id. + """ + + logger.debug("Computing conflicted state") + + # We use event_map as a cache, so if its None we need to initialize it + if event_map is None: + event_map = {} + + # First split up the un/conflicted state + unconflicted_state, conflicted_state = _seperate(state_sets) + + if not conflicted_state: + defer.returnValue(unconflicted_state) + + logger.debug("%d conflicted state entries", len(conflicted_state)) + logger.debug("Calculating auth chain difference") + + # Also fetch all auth events that appear in only some of the state sets' + # auth chains. + auth_diff = yield _get_auth_chain_difference( + state_sets, event_map, state_res_store, + ) + + full_conflicted_set = set(itertools.chain( + itertools.chain.from_iterable(itervalues(conflicted_state)), + auth_diff, + )) + + events = yield state_res_store.get_events([ + eid for eid in full_conflicted_set + if eid not in event_map + ], allow_rejected=True) + event_map.update(events) + + full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map) + + logger.debug("%d full_conflicted_set entries", len(full_conflicted_set)) + + # Get and sort all the power events (kicks/bans/etc) + power_events = ( + eid for eid in full_conflicted_set + if _is_power_event(event_map[eid]) + ) + + sorted_power_events = yield _reverse_topological_power_sort( + power_events, + event_map, + state_res_store, + full_conflicted_set, + ) + + logger.debug("sorted %d power events", len(sorted_power_events)) + + # Now sequentially auth each one + resolved_state = yield _iterative_auth_checks( + sorted_power_events, unconflicted_state, event_map, + state_res_store, + ) + + logger.debug("resolved power events") + + # OK, so we've now resolved the power events. Now sort the remaining + # events using the mainline of the resolved power level. + + leftover_events = [ + ev_id + for ev_id in full_conflicted_set + if ev_id not in sorted_power_events + ] + + logger.debug("sorting %d remaining events", len(leftover_events)) + + pl = resolved_state.get((EventTypes.PowerLevels, ""), None) + leftover_events = yield _mainline_sort( + leftover_events, pl, event_map, state_res_store, + ) + + logger.debug("resolving remaining events") + + resolved_state = yield _iterative_auth_checks( + leftover_events, resolved_state, event_map, + state_res_store, + ) + + logger.debug("resolved") + + # We make sure that unconflicted state always still applies. + resolved_state.update(unconflicted_state) + + logger.debug("done") + + defer.returnValue(resolved_state) + + +@defer.inlineCallbacks +def _get_power_level_for_sender(event_id, event_map, state_res_store): + """Return the power level of the sender of the given event according to + their auth events. + + Args: + event_id (str) + event_map (dict[str,FrozenEvent]) + state_res_store (StateResolutionStore) + + Returns: + Deferred[int] + """ + event = yield _get_event(event_id, event_map, state_res_store) + + pl = None + for aid, _ in event.auth_events: + aev = yield _get_event(aid, event_map, state_res_store) + if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): + pl = aev + break + + if pl is None: + # Couldn't find power level. Check if they're the creator of the room + for aid, _ in event.auth_events: + aev = yield _get_event(aid, event_map, state_res_store) + if (aev.type, aev.state_key) == (EventTypes.Create, ""): + if aev.content.get("creator") == event.sender: + defer.returnValue(100) + break + defer.returnValue(0) + + level = pl.content.get("users", {}).get(event.sender) + if level is None: + level = pl.content.get("users_default", 0) + + if level is None: + defer.returnValue(0) + else: + defer.returnValue(int(level)) + + +@defer.inlineCallbacks +def _get_auth_chain_difference(state_sets, event_map, state_res_store): + """Compare the auth chains of each state set and return the set of events + that only appear in some but not all of the auth chains. + + Args: + state_sets (list) + event_map (dict[str,FrozenEvent]) + state_res_store (StateResolutionStore) + + Returns: + Deferred[set[str]]: Set of event IDs + """ + common = set(itervalues(state_sets[0])).intersection( + *(itervalues(s) for s in state_sets[1:]) + ) + + auth_sets = [] + for state_set in state_sets: + auth_ids = set( + eid + for key, eid in iteritems(state_set) + if (key[0] in ( + EventTypes.Member, + EventTypes.ThirdPartyInvite, + ) or key in ( + (EventTypes.PowerLevels, ''), + (EventTypes.Create, ''), + (EventTypes.JoinRules, ''), + )) and eid not in common + ) + + auth_chain = yield state_res_store.get_auth_chain(auth_ids) + auth_ids.update(auth_chain) + + auth_sets.append(auth_ids) + + intersection = set(auth_sets[0]).intersection(*auth_sets[1:]) + union = set().union(*auth_sets) + + defer.returnValue(union - intersection) + + +def _seperate(state_sets): + """Return the unconflicted and conflicted state. This is different than in + the original algorithm, as this defines a key to be conflicted if one of + the state sets doesn't have that key. + + Args: + state_sets (list) + + Returns: + tuple[dict, dict]: A tuple of unconflicted and conflicted state. The + conflicted state dict is a map from type/state_key to set of event IDs + """ + unconflicted_state = {} + conflicted_state = {} + + for key in set(itertools.chain.from_iterable(state_sets)): + event_ids = set(state_set.get(key) for state_set in state_sets) + if len(event_ids) == 1: + unconflicted_state[key] = event_ids.pop() + else: + event_ids.discard(None) + conflicted_state[key] = event_ids + + return unconflicted_state, conflicted_state + + +def _is_power_event(event): + """Return whether or not the event is a "power event", as defined by the + v2 state resolution algorithm + + Args: + event (FrozenEvent) + + Returns: + boolean + """ + if (event.type, event.state_key) in ( + (EventTypes.PowerLevels, ""), + (EventTypes.JoinRules, ""), + (EventTypes.Create, ""), + ): + return True + + if event.type == EventTypes.Member: + if event.membership in ('leave', 'ban'): + return event.sender != event.state_key + + return False + + +@defer.inlineCallbacks +def _add_event_and_auth_chain_to_graph(graph, event_id, event_map, + state_res_store, auth_diff): + """Helper function for _reverse_topological_power_sort that add the event + and its auth chain (that is in the auth diff) to the graph + + Args: + graph (dict[str, set[str]]): A map from event ID to the events auth + event IDs + event_id (str): Event to add to the graph + event_map (dict[str,FrozenEvent]) + state_res_store (StateResolutionStore) + auth_diff (set[str]): Set of event IDs that are in the auth difference. + """ + + state = [event_id] + while state: + eid = state.pop() + graph.setdefault(eid, set()) + + event = yield _get_event(eid, event_map, state_res_store) + for aid, _ in event.auth_events: + if aid in auth_diff: + if aid not in graph: + state.append(aid) + + graph.setdefault(eid, set()).add(aid) + + +@defer.inlineCallbacks +def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_diff): + """Returns a list of the event_ids sorted by reverse topological ordering, + and then by power level and origin_server_ts + + Args: + event_ids (list[str]): The events to sort + event_map (dict[str,FrozenEvent]) + state_res_store (StateResolutionStore) + auth_diff (set[str]): Set of event IDs that are in the auth difference. + + Returns: + Deferred[list[str]]: The sorted list + """ + + graph = {} + for event_id in event_ids: + yield _add_event_and_auth_chain_to_graph( + graph, event_id, event_map, state_res_store, auth_diff, + ) + + event_to_pl = {} + for event_id in graph: + pl = yield _get_power_level_for_sender(event_id, event_map, state_res_store) + event_to_pl[event_id] = pl + + def _get_power_order(event_id): + ev = event_map[event_id] + pl = event_to_pl[event_id] + + return -pl, ev.origin_server_ts, event_id + + # Note: graph is modified during the sort + it = lexicographical_topological_sort( + graph, + key=_get_power_order, + ) + sorted_events = list(it) + + defer.returnValue(sorted_events) + + +@defer.inlineCallbacks +def _iterative_auth_checks(event_ids, base_state, event_map, state_res_store): + """Sequentially apply auth checks to each event in given list, updating the + state as it goes along. + + Args: + event_ids (list[str]): Ordered list of events to apply auth checks to + base_state (dict[tuple[str, str], str]): The set of state to start with + event_map (dict[str,FrozenEvent]) + state_res_store (StateResolutionStore) + + Returns: + Deferred[dict[tuple[str, str], str]]: Returns the final updated state + """ + resolved_state = base_state.copy() + + for event_id in event_ids: + event = event_map[event_id] + + auth_events = {} + for aid, _ in event.auth_events: + ev = yield _get_event(aid, event_map, state_res_store) + + if ev.rejected_reason is None: + auth_events[(ev.type, ev.state_key)] = ev + + for key in event_auth.auth_types_for_event(event): + if key in resolved_state: + ev_id = resolved_state[key] + ev = yield _get_event(ev_id, event_map, state_res_store) + + if ev.rejected_reason is None: + auth_events[key] = event_map[ev_id] + + try: + event_auth.check( + event, auth_events, + do_sig_check=False, + do_size_check=False + ) + + resolved_state[(event.type, event.state_key)] = event_id + except AuthError: + pass + + defer.returnValue(resolved_state) + + +@defer.inlineCallbacks +def _mainline_sort(event_ids, resolved_power_event_id, event_map, + state_res_store): + """Returns a sorted list of event_ids sorted by mainline ordering based on + the given event resolved_power_event_id + + Args: + event_ids (list[str]): Events to sort + resolved_power_event_id (str): The final resolved power level event ID + event_map (dict[str,FrozenEvent]) + state_res_store (StateResolutionStore) + + Returns: + Deferred[list[str]]: The sorted list + """ + mainline = [] + pl = resolved_power_event_id + while pl: + mainline.append(pl) + pl_ev = yield _get_event(pl, event_map, state_res_store) + auth_events = pl_ev.auth_events + pl = None + for aid, _ in auth_events: + ev = yield _get_event(aid, event_map, state_res_store) + if (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""): + pl = aid + break + + mainline_map = {ev_id: i + 1 for i, ev_id in enumerate(reversed(mainline))} + + event_ids = list(event_ids) + + order_map = {} + for ev_id in event_ids: + depth = yield _get_mainline_depth_for_event( + event_map[ev_id], mainline_map, + event_map, state_res_store, + ) + order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id) + + event_ids.sort(key=lambda ev_id: order_map[ev_id]) + + defer.returnValue(event_ids) + + +@defer.inlineCallbacks +def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_store): + """Get the mainline depths for the given event based on the mainline map + + Args: + event (FrozenEvent) + mainline_map (dict[str, int]): Map from event_id to mainline depth for + events in the mainline. + event_map (dict[str,FrozenEvent]) + state_res_store (StateResolutionStore) + + Returns: + Deferred[int] + """ + + # We do an iterative search, replacing `event with the power level in its + # auth events (if any) + while event: + depth = mainline_map.get(event.event_id) + if depth is not None: + defer.returnValue(depth) + + auth_events = event.auth_events + event = None + + for aid, _ in auth_events: + aev = yield _get_event(aid, event_map, state_res_store) + if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): + event = aev + break + + # Didn't find a power level auth event, so we just return 0 + defer.returnValue(0) + + +@defer.inlineCallbacks +def _get_event(event_id, event_map, state_res_store): + """Helper function to look up event in event_map, falling back to looking + it up in the store + + Args: + event_id (str) + event_map (dict[str,FrozenEvent]) + state_res_store (StateResolutionStore) + + Returns: + Deferred[FrozenEvent] + """ + if event_id not in event_map: + events = yield state_res_store.get_events([event_id], allow_rejected=True) + event_map.update(events) + defer.returnValue(event_map[event_id]) + + +def lexicographical_topological_sort(graph, key): + """Performs a lexicographic reverse topological sort on the graph. + + This returns a reverse topological sort (i.e. if node A references B then B + appears before A in the sort), with ties broken lexicographically based on + return value of the `key` function. + + NOTE: `graph` is modified during the sort. + + Args: + graph (dict[str, set[str]]): A representation of the graph where each + node is a key in the dict and its value are the nodes edges. + key (func): A function that takes a node and returns a value that is + comparable and used to order nodes + + Yields: + str: The next node in the topological sort + """ + + # Note, this is basically Kahn's algorithm except we look at nodes with no + # outgoing edges, c.f. + # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm + outdegree_map = graph + reverse_graph = {} + + # Lists of nodes with zero out degree. Is actually a tuple of + # `(key(node), node)` so that sorting does the right thing + zero_outdegree = [] + + for node, edges in iteritems(graph): + if len(edges) == 0: + zero_outdegree.append((key(node), node)) + + reverse_graph.setdefault(node, set()) + for edge in edges: + reverse_graph.setdefault(edge, set()).add(node) + + # heapq is a built in implementation of a sorted queue. + heapq.heapify(zero_outdegree) + + while zero_outdegree: + _, node = heapq.heappop(zero_outdegree) + + for parent in reverse_graph[node]: + out = outdegree_map[parent] + out.discard(node) + if len(out) == 0: + heapq.heappush(zero_outdegree, (key(parent), parent)) + + yield node diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 23b4a8d76d..53c685c173 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -30,6 +30,7 @@ from .appservice import ApplicationServiceStore, ApplicationServiceTransactionSt from .client_ips import ClientIpStore from .deviceinbox import DeviceInboxStore from .directory import DirectoryStore +from .e2e_room_keys import EndToEndRoomKeyStore from .end_to_end_keys import EndToEndKeyStore from .engines import PostgresEngine from .event_federation import EventFederationStore @@ -77,6 +78,7 @@ class DataStore(RoomMemberStore, RoomStore, ApplicationServiceTransactionStore, ReceiptsStore, EndToEndKeyStore, + EndToEndRoomKeyStore, SearchStore, TagsStore, AccountDataStore, diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index be61147b9b..d9d0255d0b 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -18,7 +18,7 @@ import threading import time from six import PY2, iteritems, iterkeys, itervalues -from six.moves import intern, range +from six.moves import builtins, intern, range from canonicaljson import json from prometheus_client import Histogram @@ -1233,7 +1233,7 @@ def db_to_json(db_content): # psycopg2 on Python 2 returns buffer objects, which we need to cast to # bytes to decode - if PY2 and isinstance(db_content, buffer): + if PY2 and isinstance(db_content, builtins.buffer): db_content = bytes(db_content) # Decode it to a Unicode string before feeding it to json.loads, so we diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 8fc678fa67..9ad17b7c25 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -119,21 +119,25 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): for entry in iteritems(to_update): (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry - self._simple_upsert_txn( - txn, - table="user_ips", - keyvalues={ - "user_id": user_id, - "access_token": access_token, - "ip": ip, - "user_agent": user_agent, - "device_id": device_id, - }, - values={ - "last_seen": last_seen, - }, - lock=False, - ) + try: + self._simple_upsert_txn( + txn, + table="user_ips", + keyvalues={ + "user_id": user_id, + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "device_id": device_id, + }, + values={ + "last_seen": last_seen, + }, + lock=False, + ) + except Exception as e: + # Failed to upsert, log and continue + logger.error("Failed to insert client IP %r: %r", entry, e) @defer.inlineCallbacks def get_last_client_ip_by_device(self, user_id, device_id): diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index d10ff9e4b9..ecdab34e7d 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -22,14 +22,19 @@ from twisted.internet import defer from synapse.api.errors import StoreError from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.background_updates import BackgroundUpdateStore from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList -from ._base import Cache, SQLBaseStore, db_to_json +from ._base import Cache, db_to_json logger = logging.getLogger(__name__) +DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( + "drop_device_list_streams_non_unique_indexes" +) -class DeviceStore(SQLBaseStore): + +class DeviceStore(BackgroundUpdateStore): def __init__(self, db_conn, hs): super(DeviceStore, self).__init__(db_conn, hs) @@ -52,6 +57,30 @@ class DeviceStore(SQLBaseStore): columns=["user_id", "device_id"], ) + # create a unique index on device_lists_remote_cache + self.register_background_index_update( + "device_lists_remote_cache_unique_idx", + index_name="device_lists_remote_cache_unique_id", + table="device_lists_remote_cache", + columns=["user_id", "device_id"], + unique=True, + ) + + # And one on device_lists_remote_extremeties + self.register_background_index_update( + "device_lists_remote_extremeties_unique_idx", + index_name="device_lists_remote_extremeties_unique_idx", + table="device_lists_remote_extremeties", + columns=["user_id"], + unique=True, + ) + + # once they complete, we can remove the old non-unique indexes. + self.register_background_update_handler( + DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES, + self._drop_device_list_streams_non_unique_indexes, + ) + @defer.inlineCallbacks def store_device(self, user_id, device_id, initial_device_display_name): @@ -239,7 +268,19 @@ class DeviceStore(SQLBaseStore): def update_remote_device_list_cache_entry(self, user_id, device_id, content, stream_id): - """Updates a single user's device in the cache. + """Updates a single device in the cache of a remote user's devicelist. + + Note: assumes that we are the only thread that can be updating this user's + device list. + + Args: + user_id (str): User to update device list for + device_id (str): ID of decivice being updated + content (dict): new data on this device + stream_id (int): the version of the device list + + Returns: + Deferred[None] """ return self.runInteraction( "update_remote_device_list_cache_entry", @@ -272,7 +313,11 @@ class DeviceStore(SQLBaseStore): }, values={ "content": json.dumps(content), - } + }, + + # we don't need to lock, because we assume we are the only thread + # updating this user's devices. + lock=False, ) txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,)) @@ -289,11 +334,26 @@ class DeviceStore(SQLBaseStore): }, values={ "stream_id": stream_id, - } + }, + + # again, we can assume we are the only thread updating this user's + # extremity. + lock=False, ) def update_remote_device_list_cache(self, user_id, devices, stream_id): - """Replace the cache of the remote user's devices. + """Replace the entire cache of the remote user's devices. + + Note: assumes that we are the only thread that can be updating this user's + device list. + + Args: + user_id (str): User to update device list for + devices (list[dict]): list of device objects supplied over federation + stream_id (int): the version of the device list + + Returns: + Deferred[None] """ return self.runInteraction( "update_remote_device_list_cache", @@ -338,7 +398,11 @@ class DeviceStore(SQLBaseStore): }, values={ "stream_id": stream_id, - } + }, + + # we don't need to lock, because we can assume we are the only thread + # updating this user's extremity. + lock=False, ) def get_devices_by_remote(self, destination, from_stream_id): @@ -589,10 +653,14 @@ class DeviceStore(SQLBaseStore): combined list of changes to devices, and which destinations need to be poked. `destination` may be None if no destinations need to be poked. """ + # We do a group by here as there can be a large number of duplicate + # entries, since we throw away device IDs. sql = """ - SELECT stream_id, user_id, destination FROM device_lists_stream + SELECT MAX(stream_id) AS stream_id, user_id, destination + FROM device_lists_stream LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id) WHERE ? < stream_id AND stream_id <= ? + GROUP BY user_id, destination """ return self._execute( "get_all_device_list_changes_for_remotes", None, @@ -718,3 +786,19 @@ class DeviceStore(SQLBaseStore): "_prune_old_outbound_device_pokes", _prune_txn, ) + + @defer.inlineCallbacks + def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): + def f(conn): + txn = conn.cursor() + txn.execute( + "DROP INDEX IF EXISTS device_lists_remote_cache_id" + ) + txn.execute( + "DROP INDEX IF EXISTS device_lists_remote_extremeties_id" + ) + txn.close() + + yield self.runWithConnection(f) + yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES) + defer.returnValue(1) diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 808194236a..61a029a53c 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -75,7 +75,6 @@ class DirectoryWorkerStore(SQLBaseStore): }, retcol="creator", desc="get_room_alias_creator", - allow_none=True ) @cached(max_entries=5000) @@ -91,7 +90,7 @@ class DirectoryWorkerStore(SQLBaseStore): class DirectoryStore(DirectoryWorkerStore): @defer.inlineCallbacks def create_room_alias_association(self, room_alias, room_id, servers, creator=None): - """ Creates an associatin between a room alias and room_id/servers + """ Creates an association between a room alias and room_id/servers Args: room_alias (RoomAlias) diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py new file mode 100644 index 0000000000..f25ded2295 --- /dev/null +++ b/synapse/storage/e2e_room_keys.py @@ -0,0 +1,320 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from twisted.internet import defer + +from synapse.api.errors import StoreError + +from ._base import SQLBaseStore + + +class EndToEndRoomKeyStore(SQLBaseStore): + + @defer.inlineCallbacks + def get_e2e_room_key(self, user_id, version, room_id, session_id): + """Get the encrypted E2E room key for a given session from a given + backup version of room_keys. We only store the 'best' room key for a given + session at a given time, as determined by the handler. + + Args: + user_id(str): the user whose backup we're querying + version(str): the version ID of the backup for the set of keys we're querying + room_id(str): the ID of the room whose keys we're querying. + This is a bit redundant as it's implied by the session_id, but + we include for consistency with the rest of the API. + session_id(str): the session whose room_key we're querying. + + Returns: + A deferred dict giving the session_data and message metadata for + this room key. + """ + + row = yield self._simple_select_one( + table="e2e_room_keys", + keyvalues={ + "user_id": user_id, + "version": version, + "room_id": room_id, + "session_id": session_id, + }, + retcols=( + "first_message_index", + "forwarded_count", + "is_verified", + "session_data", + ), + desc="get_e2e_room_key", + ) + + row["session_data"] = json.loads(row["session_data"]) + + defer.returnValue(row) + + @defer.inlineCallbacks + def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key): + """Replaces or inserts the encrypted E2E room key for a given session in + a given backup + + Args: + user_id(str): the user whose backup we're setting + version(str): the version ID of the backup we're updating + room_id(str): the ID of the room whose keys we're setting + session_id(str): the session whose room_key we're setting + room_key(dict): the room_key being set + Raises: + StoreError + """ + + yield self._simple_upsert( + table="e2e_room_keys", + keyvalues={ + "user_id": user_id, + "room_id": room_id, + "session_id": session_id, + }, + values={ + "version": version, + "first_message_index": room_key['first_message_index'], + "forwarded_count": room_key['forwarded_count'], + "is_verified": room_key['is_verified'], + "session_data": json.dumps(room_key['session_data']), + }, + lock=False, + ) + + @defer.inlineCallbacks + def get_e2e_room_keys( + self, user_id, version, room_id=None, session_id=None + ): + """Bulk get the E2E room keys for a given backup, optionally filtered to a given + room, or a given session. + + Args: + user_id(str): the user whose backup we're querying + version(str): the version ID of the backup for the set of keys we're querying + room_id(str): Optional. the ID of the room whose keys we're querying, if any. + If not specified, we return the keys for all the rooms in the backup. + session_id(str): Optional. the session whose room_key we're querying, if any. + If specified, we also require the room_id to be specified. + If not specified, we return all the keys in this version of + the backup (or for the specified room) + + Returns: + A deferred list of dicts giving the session_data and message metadata for + these room keys. + """ + + keyvalues = { + "user_id": user_id, + "version": version, + } + if room_id: + keyvalues['room_id'] = room_id + if session_id: + keyvalues['session_id'] = session_id + + rows = yield self._simple_select_list( + table="e2e_room_keys", + keyvalues=keyvalues, + retcols=( + "user_id", + "room_id", + "session_id", + "first_message_index", + "forwarded_count", + "is_verified", + "session_data", + ), + desc="get_e2e_room_keys", + ) + + sessions = {'rooms': {}} + for row in rows: + room_entry = sessions['rooms'].setdefault(row['room_id'], {"sessions": {}}) + room_entry['sessions'][row['session_id']] = { + "first_message_index": row["first_message_index"], + "forwarded_count": row["forwarded_count"], + "is_verified": row["is_verified"], + "session_data": json.loads(row["session_data"]), + } + + defer.returnValue(sessions) + + @defer.inlineCallbacks + def delete_e2e_room_keys( + self, user_id, version, room_id=None, session_id=None + ): + """Bulk delete the E2E room keys for a given backup, optionally filtered to a given + room or a given session. + + Args: + user_id(str): the user whose backup we're deleting from + version(str): the version ID of the backup for the set of keys we're deleting + room_id(str): Optional. the ID of the room whose keys we're deleting, if any. + If not specified, we delete the keys for all the rooms in the backup. + session_id(str): Optional. the session whose room_key we're querying, if any. + If specified, we also require the room_id to be specified. + If not specified, we delete all the keys in this version of + the backup (or for the specified room) + + Returns: + A deferred of the deletion transaction + """ + + keyvalues = { + "user_id": user_id, + "version": version, + } + if room_id: + keyvalues['room_id'] = room_id + if session_id: + keyvalues['session_id'] = session_id + + yield self._simple_delete( + table="e2e_room_keys", + keyvalues=keyvalues, + desc="delete_e2e_room_keys", + ) + + @staticmethod + def _get_current_version(txn, user_id): + txn.execute( + "SELECT MAX(version) FROM e2e_room_keys_versions " + "WHERE user_id=? AND deleted=0", + (user_id,) + ) + row = txn.fetchone() + if not row: + raise StoreError(404, 'No current backup version') + return row[0] + + def get_e2e_room_keys_version_info(self, user_id, version=None): + """Get info metadata about a version of our room_keys backup. + + Args: + user_id(str): the user whose backup we're querying + version(str): Optional. the version ID of the backup we're querying about + If missing, we return the information about the current version. + Raises: + StoreError: with code 404 if there are no e2e_room_keys_versions present + Returns: + A deferred dict giving the info metadata for this backup version + """ + + def _get_e2e_room_keys_version_info_txn(txn): + if version is None: + this_version = self._get_current_version(txn, user_id) + else: + this_version = version + + result = self._simple_select_one_txn( + txn, + table="e2e_room_keys_versions", + keyvalues={ + "user_id": user_id, + "version": this_version, + "deleted": 0, + }, + retcols=( + "version", + "algorithm", + "auth_data", + ), + ) + result["auth_data"] = json.loads(result["auth_data"]) + return result + + return self.runInteraction( + "get_e2e_room_keys_version_info", + _get_e2e_room_keys_version_info_txn + ) + + def create_e2e_room_keys_version(self, user_id, info): + """Atomically creates a new version of this user's e2e_room_keys store + with the given version info. + + Args: + user_id(str): the user whose backup we're creating a version + info(dict): the info about the backup version to be created + + Returns: + A deferred string for the newly created version ID + """ + + def _create_e2e_room_keys_version_txn(txn): + txn.execute( + "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?", + (user_id,) + ) + current_version = txn.fetchone()[0] + if current_version is None: + current_version = '0' + + new_version = str(int(current_version) + 1) + + self._simple_insert_txn( + txn, + table="e2e_room_keys_versions", + values={ + "user_id": user_id, + "version": new_version, + "algorithm": info["algorithm"], + "auth_data": json.dumps(info["auth_data"]), + }, + ) + + return new_version + + return self.runInteraction( + "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn + ) + + def delete_e2e_room_keys_version(self, user_id, version=None): + """Delete a given backup version of the user's room keys. + Doesn't delete their actual key data. + + Args: + user_id(str): the user whose backup version we're deleting + version(str): Optional. the version ID of the backup version we're deleting + If missing, we delete the current backup version info. + Raises: + StoreError: with code 404 if there are no e2e_room_keys_versions present, + or if the version requested doesn't exist. + """ + + def _delete_e2e_room_keys_version_txn(txn): + if version is None: + this_version = self._get_current_version(txn, user_id) + else: + this_version = version + + return self._simple_update_one_txn( + txn, + table="e2e_room_keys_versions", + keyvalues={ + "user_id": user_id, + "version": this_version, + }, + updatevalues={ + "deleted": 1, + } + ) + + return self.runInteraction( + "delete_e2e_room_keys_version", + _delete_e2e_room_keys_version_txn + ) diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 24345b20a6..3faca2a042 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -376,33 +376,25 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, @defer.inlineCallbacks def get_missing_events(self, room_id, earliest_events, latest_events, - limit, min_depth): + limit): ids = yield self.runInteraction( "get_missing_events", self._get_missing_events, - room_id, earliest_events, latest_events, limit, min_depth + room_id, earliest_events, latest_events, limit, ) - events = yield self._get_events(ids) - - events = sorted( - [ev for ev in events if ev.depth >= min_depth], - key=lambda e: e.depth, - ) - - defer.returnValue(events[:limit]) + defer.returnValue(events) def _get_missing_events(self, txn, room_id, earliest_events, latest_events, - limit, min_depth): - - earliest_events = set(earliest_events) - front = set(latest_events) - earliest_events + limit): - event_results = set() + seen_events = set(earliest_events) + front = set(latest_events) - seen_events + event_results = [] query = ( "SELECT prev_event_id FROM event_edges " - "WHERE event_id = ? AND is_state = ? " + "WHERE room_id = ? AND event_id = ? AND is_state = ? " "LIMIT ?" ) @@ -411,18 +403,20 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, for event_id in front: txn.execute( query, - (event_id, False, limit - len(event_results)) + (room_id, event_id, False, limit - len(event_results)) ) - for e_id, in txn: - new_front.add(e_id) + new_results = set(t[0] for t in txn) - seen_events - new_front -= earliest_events - new_front -= event_results + new_front |= new_results + seen_events |= new_results + event_results.extend(new_results) front = new_front - event_results |= new_front + # we built the list working backwards from latest_events; we now need to + # reverse it so that the events are approximately chronological. + event_results.reverse() return event_results diff --git a/synapse/storage/events.py b/synapse/storage/events.py index e7487311ce..919e855f3b 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -34,10 +34,13 @@ from synapse.api.errors import SynapseError from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.state import StateResolutionStore from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.event_federation import EventFederationStore from synapse.storage.events_worker import EventsWorkerStore +from synapse.storage.state import StateGroupWorkerStore from synapse.types import RoomStreamToken, get_domain_from_id +from synapse.util import batch_iter from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.frozenutils import frozendict_json_encoder @@ -203,7 +206,8 @@ def _retry_on_integrity_error(func): # inherits from EventFederationStore so that we can call _update_backward_extremities # and _handle_mult_prev_events (though arguably those could both be moved in here) -class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore): +class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore, + BackgroundUpdateStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" @@ -386,12 +390,10 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore ) for room_id, ev_ctx_rm in iteritems(events_by_room): - # Work out new extremities by recursively adding and removing - # the new events. latest_event_ids = yield self.get_latest_event_ids_in_room( room_id ) - new_latest_event_ids = yield self._calculate_new_extremeties( + new_latest_event_ids = yield self._calculate_new_extremities( room_id, ev_ctx_rm, latest_event_ids ) @@ -400,6 +402,12 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore # No change in extremities, so no change in state continue + # there should always be at least one forward extremity. + # (except during the initial persistence of the send_join + # results, in which case there will be no existing + # extremities, so we'll `continue` above and skip this bit.) + assert new_latest_event_ids, "No forward extremities left!" + new_forward_extremeties[room_id] = new_latest_event_ids len_1 = ( @@ -517,44 +525,79 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore ) @defer.inlineCallbacks - def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids): - """Calculates the new forward extremeties for a room given events to + def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids): + """Calculates the new forward extremities for a room given events to persist. Assumes that we are only persisting events for one room at a time. """ - new_latest_event_ids = set(latest_event_ids) - # First, add all the new events to the list - new_latest_event_ids.update( - event.event_id for event, ctx in event_contexts + + # we're only interested in new events which aren't outliers and which aren't + # being rejected. + new_events = [ + event for event, ctx in event_contexts if not event.internal_metadata.is_outlier() and not ctx.rejected + ] + + # start with the existing forward extremities + result = set(latest_event_ids) + + # add all the new events to the list + result.update( + event.event_id for event in new_events ) - # Now remove all events that are referenced by the to-be-added events - new_latest_event_ids.difference_update( + + # Now remove all events which are prev_events of any of the new events + result.difference_update( e_id - for event, ctx in event_contexts + for event in new_events for e_id, _ in event.prev_events - if not event.internal_metadata.is_outlier() and not ctx.rejected ) - # And finally remove any events that are referenced by previously added - # events. - rows = yield self._simple_select_many_batch( - table="event_edges", - column="prev_event_id", - iterable=list(new_latest_event_ids), - retcols=["prev_event_id"], - keyvalues={ - "is_state": False, - }, - desc="_calculate_new_extremeties", - ) + # Finally, remove any events which are prev_events of any existing events. + existing_prevs = yield self._get_events_which_are_prevs(result) + result.difference_update(existing_prevs) - new_latest_event_ids.difference_update( - row["prev_event_id"] for row in rows - ) + defer.returnValue(result) - defer.returnValue(new_latest_event_ids) + @defer.inlineCallbacks + def _get_events_which_are_prevs(self, event_ids): + """Filter the supplied list of event_ids to get those which are prev_events of + existing (non-outlier/rejected) events. + + Args: + event_ids (Iterable[str]): event ids to filter + + Returns: + Deferred[List[str]]: filtered event ids + """ + results = [] + + def _get_events(txn, batch): + sql = """ + SELECT prev_event_id + FROM event_edges + INNER JOIN events USING (event_id) + LEFT JOIN rejections USING (event_id) + WHERE + prev_event_id IN (%s) + AND NOT events.outlier + AND rejections.event_id IS NULL + """ % ( + ",".join("?" for _ in batch), + ) + + txn.execute(sql, batch) + results.extend(r[0] for r in txn) + + for chunk in batch_iter(event_ids, 100): + yield self.runInteraction( + "_get_events_which_are_prevs", + _get_events, + chunk, + ) + + defer.returnValue(results) @defer.inlineCallbacks def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids, @@ -586,10 +629,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore the new current state is only returned if we've already calculated it. """ - - if not new_latest_event_ids: - return - # map from state_group to ((type, key) -> event_id) state map state_groups_map = {} @@ -695,11 +734,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore # Ok, we need to defer to the state handler to resolve our state sets. - def get_events(ev_ids): - return self.get_events( - ev_ids, get_prev_content=False, check_redacted=False, - ) - state_groups = { sg: state_groups_map[sg] for sg in new_state_groups } @@ -709,7 +743,8 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore logger.debug("calling resolve_state_groups from preserve_events") res = yield self._state_resolution_handler.resolve_state_groups( - room_id, room_version, state_groups, events_map, get_events + room_id, room_version, state_groups, events_map, + state_res_store=StateResolutionStore(self) ) defer.returnValue((res.state, None)) @@ -818,6 +853,27 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore # Insert into event_to_state_groups. self._store_event_state_mappings_txn(txn, events_and_contexts) + # We want to store event_auth mappings for rejected events, as they're + # used in state res v2. + # This is only necessary if the rejected event appears in an accepted + # event's auth chain, but its easier for now just to store them (and + # it doesn't take much storage compared to storing the entire event + # anyway). + self._simple_insert_many_txn( + txn, + table="event_auth", + values=[ + { + "event_id": event.event_id, + "room_id": event.room_id, + "auth_id": auth_id, + } + for event, _ in events_and_contexts + for auth_id, _ in event.auth_events + if event.is_state() + ], + ) + # _store_rejected_events_txn filters out any events which were # rejected, and returns the filtered list. events_and_contexts = self._store_rejected_events_txn( @@ -1293,21 +1349,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore txn, event.room_id, event.redacts ) - self._simple_insert_many_txn( - txn, - table="event_auth", - values=[ - { - "event_id": event.event_id, - "room_id": event.room_id, - "auth_id": auth_id, - } - for event, _ in events_and_contexts - for auth_id, _ in event.auth_events - if event.is_state() - ], - ) - # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. self._handle_mult_prev_events( @@ -1995,62 +2036,44 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore logger.info("[purge] finding redundant state groups") - # Get all state groups that are only referenced by events that are - # to be deleted. - # This works by first getting state groups that we may want to delete, - # joining against event_to_state_groups to get events that use that - # state group, then left joining against events_to_purge again. Any - # state group where the left join produce *no nulls* are referenced - # only by events that are going to be purged. + # Get all state groups that are referenced by events that are to be + # deleted. We then go and check if they are referenced by other events + # or state groups, and if not we delete them. txn.execute(""" - SELECT state_group FROM - ( - SELECT DISTINCT state_group FROM events_to_purge - INNER JOIN event_to_state_groups USING (event_id) - ) AS sp - INNER JOIN event_to_state_groups USING (state_group) - LEFT JOIN events_to_purge AS ep USING (event_id) - GROUP BY state_group - HAVING SUM(CASE WHEN ep.event_id IS NULL THEN 1 ELSE 0 END) = 0 + SELECT DISTINCT state_group FROM events_to_purge + INNER JOIN event_to_state_groups USING (event_id) """) - state_rows = txn.fetchall() - logger.info("[purge] found %i redundant state groups", len(state_rows)) - - # make a set of the redundant state groups, so that we can look them up - # efficiently - state_groups_to_delete = set([sg for sg, in state_rows]) - - # Now we get all the state groups that rely on these state groups - logger.info("[purge] finding state groups which depend on redundant" - " state groups") - remaining_state_groups = [] - for i in range(0, len(state_rows), 100): - chunk = [sg for sg, in state_rows[i:i + 100]] - # look for state groups whose prev_state_group is one we are about - # to delete - rows = self._simple_select_many_txn( - txn, - table="state_group_edges", - column="prev_state_group", - iterable=chunk, - retcols=["state_group"], - keyvalues={}, - ) - remaining_state_groups.extend( - row["state_group"] for row in rows + referenced_state_groups = set(sg for sg, in txn) + logger.info( + "[purge] found %i referenced state groups", + len(referenced_state_groups), + ) + + logger.info("[purge] finding state groups that can be deleted") - # exclude state groups we are about to delete: no point in - # updating them - if row["state_group"] not in state_groups_to_delete + state_groups_to_delete, remaining_state_groups = ( + self._find_unreferenced_groups_during_purge( + txn, referenced_state_groups, ) + ) + + logger.info( + "[purge] found %i state groups to delete", + len(state_groups_to_delete), + ) + + logger.info( + "[purge] de-delta-ing %i remaining state groups", + len(remaining_state_groups), + ) # Now we turn the state groups that reference to-be-deleted state # groups to non delta versions. for sg in remaining_state_groups: logger.info("[purge] de-delta-ing remaining state group %s", sg) curr_state = self._get_state_groups_from_groups_txn( - txn, [sg], types=None + txn, [sg], ) curr_state = curr_state[sg] @@ -2088,11 +2111,11 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore logger.info("[purge] removing redundant state groups") txn.executemany( "DELETE FROM state_groups_state WHERE state_group = ?", - state_rows + ((sg,) for sg in state_groups_to_delete), ) txn.executemany( "DELETE FROM state_groups WHERE id = ?", - state_rows + ((sg,) for sg in state_groups_to_delete), ) logger.info("[purge] removing events from event_to_state_groups") @@ -2188,6 +2211,85 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore logger.info("[purge] done") + def _find_unreferenced_groups_during_purge(self, txn, state_groups): + """Used when purging history to figure out which state groups can be + deleted and which need to be de-delta'ed (due to one of its prev groups + being scheduled for deletion). + + Args: + txn + state_groups (set[int]): Set of state groups referenced by events + that are going to be deleted. + + Returns: + tuple[set[int], set[int]]: The set of state groups that can be + deleted and the set of state groups that need to be de-delta'ed + """ + # Graph of state group -> previous group + graph = {} + + # Set of events that we have found to be referenced by events + referenced_groups = set() + + # Set of state groups we've already seen + state_groups_seen = set(state_groups) + + # Set of state groups to handle next. + next_to_search = set(state_groups) + while next_to_search: + # We bound size of groups we're looking up at once, to stop the + # SQL query getting too big + if len(next_to_search) < 100: + current_search = next_to_search + next_to_search = set() + else: + current_search = set(itertools.islice(next_to_search, 100)) + next_to_search -= current_search + + # Check if state groups are referenced + sql = """ + SELECT DISTINCT state_group FROM event_to_state_groups + LEFT JOIN events_to_purge AS ep USING (event_id) + WHERE state_group IN (%s) AND ep.event_id IS NULL + """ % (",".join("?" for _ in current_search),) + txn.execute(sql, list(current_search)) + + referenced = set(sg for sg, in txn) + referenced_groups |= referenced + + # We don't continue iterating up the state group graphs for state + # groups that are referenced. + current_search -= referenced + + rows = self._simple_select_many_txn( + txn, + table="state_group_edges", + column="prev_state_group", + iterable=current_search, + keyvalues={}, + retcols=("prev_state_group", "state_group",), + ) + + prevs = set(row["state_group"] for row in rows) + # We don't bother re-handling groups we've already seen + prevs -= state_groups_seen + next_to_search |= prevs + state_groups_seen |= prevs + + for row in rows: + # Note: Each state group can have at most one prev group + graph[row["state_group"]] = row["prev_state_group"] + + to_delete = state_groups_seen - referenced_groups + + to_dedelta = set() + for sg in referenced_groups: + prev_sg = graph.get(sg) + if prev_sg and prev_sg in to_delete: + to_dedelta.add(sg) + + return to_delete, to_dedelta + @defer.inlineCallbacks def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index a1331c1a61..8af17921e3 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) # py2 sqlite has buffer hardcoded as only binary type, so we must use it, # despite being deprecated and removed in favor of memoryview if six.PY2: - db_binary_type = buffer + db_binary_type = six.moves.builtins.buffer else: db_binary_type = memoryview diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 59580949f1..cf4104dc2e 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -33,19 +33,29 @@ class MonthlyActiveUsersStore(SQLBaseStore): self._clock = hs.get_clock() self.hs = hs self.reserved_users = () + # Do not add more reserved users than the total allowable number + self._initialise_reserved_users( + dbconn.cursor(), + hs.config.mau_limits_reserved_threepids[:self.hs.config.max_mau_value], + ) - @defer.inlineCallbacks - def initialise_reserved_users(self, threepids): - store = self.hs.get_datastore() + def _initialise_reserved_users(self, txn, threepids): + """Ensures that reserved threepids are accounted for in the MAU table, should + be called on start up. + + Args: + txn (cursor): + threepids (list[dict]): List of threepid dicts to reserve + """ reserved_user_list = [] - # Do not add more reserved users than the total allowable number - for tp in threepids[:self.hs.config.max_mau_value]: - user_id = yield store.get_user_id_by_threepid( + for tp in threepids: + user_id = self.get_user_id_by_threepid_txn( + txn, tp["medium"], tp["address"] ) if user_id: - yield self.upsert_monthly_active_user(user_id) + self.upsert_monthly_active_user_txn(txn, user_id) reserved_user_list.append(user_id) else: logger.warning( @@ -55,8 +65,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): @defer.inlineCallbacks def reap_monthly_active_users(self): - """ - Cleans out monthly active user table to ensure that no stale + """Cleans out monthly active user table to ensure that no stale entries exist. Returns: @@ -165,15 +174,44 @@ class MonthlyActiveUsersStore(SQLBaseStore): @defer.inlineCallbacks def upsert_monthly_active_user(self, user_id): + """Updates or inserts the user into the monthly active user table, which + is used to track the current MAU usage of the server + + Args: + user_id (str): user to add/update """ - Updates or inserts monthly active user member - Arguments: - user_id (str): user to add/update - Deferred[bool]: True if a new entry was created, False if an - existing one was updated. + is_insert = yield self.runInteraction( + "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, + user_id + ) + + if is_insert: + self.user_last_seen_monthly_active.invalidate((user_id,)) + self.get_monthly_active_count.invalidate(()) + + def upsert_monthly_active_user_txn(self, txn, user_id): + """Updates or inserts monthly active user member + + Note that, after calling this method, it will generally be necessary + to invalidate the caches on user_last_seen_monthly_active and + get_monthly_active_count. We can't do that here, because we are running + in a database thread rather than the main thread, and we can't call + txn.call_after because txn may not be a LoggingTransaction. + + Args: + txn (cursor): + user_id (str): user to add/update + + Returns: + bool: True if a new entry was created, False if an + existing one was updated. """ - is_insert = yield self._simple_upsert( - desc="upsert_monthly_active_user", + # Am consciously deciding to lock the table on the basis that is ought + # never be a big table and alternative approaches (batching multiple + # upserts into a single txn) introduced a lot of extra complexity. + # See https://github.com/matrix-org/synapse/issues/3854 for more + is_insert = self._simple_upsert_txn( + txn, table="monthly_active_users", keyvalues={ "user_id": user_id, @@ -181,11 +219,9 @@ class MonthlyActiveUsersStore(SQLBaseStore): values={ "timestamp": int(self._clock.time_msec()), }, - lock=False, ) - if is_insert: - self.user_last_seen_monthly_active.invalidate((user_id,)) - self.get_monthly_active_count.invalidate(()) + + return is_insert @cached(num_args=1) def user_last_seen_monthly_active(self, user_id): diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index b364719312..bd740e1e45 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 51 +SCHEMA_VERSION = 52 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index c7987bfcdd..2743b52bad 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -29,7 +29,7 @@ from ._base import SQLBaseStore logger = logging.getLogger(__name__) if six.PY2: - db_binary_type = buffer + db_binary_type = six.moves.builtins.buffer else: db_binary_type = memoryview diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 26b429e307..80d76bf9d7 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -474,17 +474,44 @@ class RegistrationStore(RegistrationWorkerStore, @defer.inlineCallbacks def get_user_id_by_threepid(self, medium, address): - ret = yield self._simple_select_one( + """Returns user id from threepid + + Args: + medium (str): threepid medium e.g. email + address (str): threepid address e.g. me@example.com + + Returns: + Deferred[str|None]: user id or None if no user id/threepid mapping exists + """ + user_id = yield self.runInteraction( + "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, + medium, address + ) + defer.returnValue(user_id) + + def get_user_id_by_threepid_txn(self, txn, medium, address): + """Returns user id from threepid + + Args: + txn (cursor): + medium (str): threepid medium e.g. email + address (str): threepid address e.g. me@example.com + + Returns: + str|None: user id or None if no user id/threepid mapping exists + """ + ret = self._simple_select_one_txn( + txn, "user_threepids", { "medium": medium, "address": address }, - ['user_id'], True, 'get_user_id_by_threepid' + ['user_id'], True ) if ret: - defer.returnValue(ret['user_id']) - defer.returnValue(None) + return ret['user_id'] + return None def user_delete_threepid(self, user_id, medium, address): return self._simple_delete( @@ -567,7 +594,7 @@ class RegistrationStore(RegistrationWorkerStore, def _find_next_generated_user_id(txn): txn.execute("SELECT name FROM users") - regex = re.compile("^@(\d+):") + regex = re.compile(r"^@(\d+):") found = set() diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 61013b8919..41c65e112a 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -47,7 +47,7 @@ class RoomWorkerStore(SQLBaseStore): Args: room_id (str): The ID of the room to retrieve. Returns: - A namedtuple containing the room information, or an empty list. + A dict containing the room information, or None if the room is unknown. """ return self._simple_select_one( table="rooms", diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/schema/delta/40/device_list_streams.sql index 54841b3843..dd6dcb65f1 100644 --- a/synapse/storage/schema/delta/40/device_list_streams.sql +++ b/synapse/storage/schema/delta/40/device_list_streams.sql @@ -20,9 +20,6 @@ CREATE TABLE device_lists_remote_cache ( content TEXT NOT NULL ); -CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id); - - -- The last update we got for a user. Empty if we're not receiving updates for -- that user. CREATE TABLE device_lists_remote_extremeties ( @@ -30,7 +27,11 @@ CREATE TABLE device_lists_remote_extremeties ( stream_id TEXT NOT NULL ); -CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id); +-- we used to create non-unique indexes on these tables, but as of update 52 we create +-- unique indexes concurrently: +-- +-- CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id); +-- CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id); -- Stream of device lists updates. Includes both local and remotes diff --git a/synapse/storage/schema/delta/51/e2e_room_keys.sql b/synapse/storage/schema/delta/51/e2e_room_keys.sql new file mode 100644 index 0000000000..c0e66a697d --- /dev/null +++ b/synapse/storage/schema/delta/51/e2e_room_keys.sql @@ -0,0 +1,39 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- users' optionally backed up encrypted e2e sessions +CREATE TABLE e2e_room_keys ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + session_id TEXT NOT NULL, + version TEXT NOT NULL, + first_message_index INT, + forwarded_count INT, + is_verified BOOLEAN, + session_data TEXT NOT NULL +); + +CREATE UNIQUE INDEX e2e_room_keys_idx ON e2e_room_keys(user_id, room_id, session_id); + +-- the metadata for each generation of encrypted e2e session backups +CREATE TABLE e2e_room_keys_versions ( + user_id TEXT NOT NULL, + version TEXT NOT NULL, + algorithm TEXT NOT NULL, + auth_data TEXT NOT NULL, + deleted SMALLINT DEFAULT 0 NOT NULL +); + +CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions(user_id, version); diff --git a/synapse/storage/schema/delta/52/add_event_to_state_group_index.sql b/synapse/storage/schema/delta/52/add_event_to_state_group_index.sql new file mode 100644 index 0000000000..91e03d13e1 --- /dev/null +++ b/synapse/storage/schema/delta/52/add_event_to_state_group_index.sql @@ -0,0 +1,19 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This is needed to efficiently check for unreferenced state groups during +-- purge. Added events_to_state_group(state_group) index +INSERT into background_updates (update_name, progress_json) + VALUES ('event_to_state_groups_sg_index', '{}'); diff --git a/synapse/storage/schema/delta/52/device_list_streams_unique_idx.sql b/synapse/storage/schema/delta/52/device_list_streams_unique_idx.sql new file mode 100644 index 0000000000..bfa49e6f92 --- /dev/null +++ b/synapse/storage/schema/delta/52/device_list_streams_unique_idx.sql @@ -0,0 +1,36 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- register a background update which will create a unique index on +-- device_lists_remote_cache +INSERT into background_updates (update_name, progress_json) + VALUES ('device_lists_remote_cache_unique_idx', '{}'); + +-- and one on device_lists_remote_extremeties +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ( + 'device_lists_remote_extremeties_unique_idx', '{}', + + -- doesn't really depend on this, but we need to make sure both happen + -- before we drop the old indexes. + 'device_lists_remote_cache_unique_idx' + ); + +-- once they complete, we can drop the old indexes. +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ( + 'drop_device_list_streams_non_unique_indexes', '{}', + 'device_lists_remote_extremeties_unique_idx' + ); diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index 5623391f6e..158e9dbe7b 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -27,7 +27,7 @@ from ._base import SQLBaseStore # py2 sqlite has buffer hardcoded as only binary type, so we must use it, # despite being deprecated and removed in favor of memoryview if six.PY2: - db_binary_type = buffer + db_binary_type = six.moves.builtins.buffer else: db_binary_type = memoryview diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 4b971efdba..d737bd6778 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -19,6 +19,8 @@ from collections import namedtuple from six import iteritems, itervalues from six.moves import range +import attr + from twisted.internet import defer from synapse.api.constants import EventTypes @@ -48,6 +50,318 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt return len(self.delta_ids) if self.delta_ids else 0 +@attr.s(slots=True) +class StateFilter(object): + """A filter used when querying for state. + + Attributes: + types (dict[str, set[str]|None]): Map from type to set of state keys (or + None). This specifies which state_keys for the given type to fetch + from the DB. If None then all events with that type are fetched. If + the set is empty then no events with that type are fetched. + include_others (bool): Whether to fetch events with types that do not + appear in `types`. + """ + + types = attr.ib() + include_others = attr.ib(default=False) + + def __attrs_post_init__(self): + # If `include_others` is set we canonicalise the filter by removing + # wildcards from the types dictionary + if self.include_others: + self.types = { + k: v for k, v in iteritems(self.types) + if v is not None + } + + @staticmethod + def all(): + """Creates a filter that fetches everything. + + Returns: + StateFilter + """ + return StateFilter(types={}, include_others=True) + + @staticmethod + def none(): + """Creates a filter that fetches nothing. + + Returns: + StateFilter + """ + return StateFilter(types={}, include_others=False) + + @staticmethod + def from_types(types): + """Creates a filter that only fetches the given types + + Args: + types (Iterable[tuple[str, str|None]]): A list of type and state + keys to fetch. A state_key of None fetches everything for + that type + + Returns: + StateFilter + """ + type_dict = {} + for typ, s in types: + if typ in type_dict: + if type_dict[typ] is None: + continue + + if s is None: + type_dict[typ] = None + continue + + type_dict.setdefault(typ, set()).add(s) + + return StateFilter(types=type_dict) + + @staticmethod + def from_lazy_load_member_list(members): + """Creates a filter that returns all non-member events, plus the member + events for the given users + + Args: + members (iterable[str]): Set of user IDs + + Returns: + StateFilter + """ + return StateFilter( + types={EventTypes.Member: set(members)}, + include_others=True, + ) + + def return_expanded(self): + """Creates a new StateFilter where type wild cards have been removed + (except for memberships). The returned filter is a superset of the + current one, i.e. anything that passes the current filter will pass + the returned filter. + + This helps the caching as the DictionaryCache knows if it has *all* the + state, but does not know if it has all of the keys of a particular type, + which makes wildcard lookups expensive unless we have a complete cache. + Hence, if we are doing a wildcard lookup, populate the cache fully so + that we can do an efficient lookup next time. + + Note that since we have two caches, one for membership events and one for + other events, we can be a bit more clever than simply returning + `StateFilter.all()` if `has_wildcards()` is True. + + We return a StateFilter where: + 1. the list of membership events to return is the same + 2. if there is a wildcard that matches non-member events we + return all non-member events + + Returns: + StateFilter + """ + + if self.is_full(): + # If we're going to return everything then there's nothing to do + return self + + if not self.has_wildcards(): + # If there are no wild cards, there's nothing to do + return self + + if EventTypes.Member in self.types: + get_all_members = self.types[EventTypes.Member] is None + else: + get_all_members = self.include_others + + has_non_member_wildcard = self.include_others or any( + state_keys is None + for t, state_keys in iteritems(self.types) + if t != EventTypes.Member + ) + + if not has_non_member_wildcard: + # If there are no non-member wild cards we can just return ourselves + return self + + if get_all_members: + # We want to return everything. + return StateFilter.all() + else: + # We want to return all non-members, but only particular + # memberships + return StateFilter( + types={EventTypes.Member: self.types[EventTypes.Member]}, + include_others=True, + ) + + def make_sql_filter_clause(self): + """Converts the filter to an SQL clause. + + For example: + + f = StateFilter.from_types([("m.room.create", "")]) + clause, args = f.make_sql_filter_clause() + clause == "(type = ? AND state_key = ?)" + args == ['m.room.create', ''] + + + Returns: + tuple[str, list]: The SQL string (may be empty) and arguments. An + empty SQL string is returned when the filter matches everything + (i.e. is "full"). + """ + + where_clause = "" + where_args = [] + + if self.is_full(): + return where_clause, where_args + + if not self.include_others and not self.types: + # i.e. this is an empty filter, so we need to return a clause that + # will match nothing + return "1 = 2", [] + + # First we build up a lost of clauses for each type/state_key combo + clauses = [] + for etype, state_keys in iteritems(self.types): + if state_keys is None: + clauses.append("(type = ?)") + where_args.append(etype) + continue + + for state_key in state_keys: + clauses.append("(type = ? AND state_key = ?)") + where_args.extend((etype, state_key)) + + # This will match anything that appears in `self.types` + where_clause = " OR ".join(clauses) + + # If we want to include stuff that's not in the types dict then we add + # a `OR type NOT IN (...)` clause to the end. + if self.include_others: + if where_clause: + where_clause += " OR " + + where_clause += "type NOT IN (%s)" % ( + ",".join(["?"] * len(self.types)), + ) + where_args.extend(self.types) + + return where_clause, where_args + + def max_entries_returned(self): + """Returns the maximum number of entries this filter will return if + known, otherwise returns None. + + For example a simple state filter asking for `("m.room.create", "")` + will return 1, whereas the default state filter will return None. + + This is used to bail out early if the right number of entries have been + fetched. + """ + if self.has_wildcards(): + return None + + return len(self.concrete_types()) + + def filter_state(self, state_dict): + """Returns the state filtered with by this StateFilter + + Args: + state (dict[tuple[str, str], Any]): The state map to filter + + Returns: + dict[tuple[str, str], Any]: The filtered state map + """ + if self.is_full(): + return dict(state_dict) + + filtered_state = {} + for k, v in iteritems(state_dict): + typ, state_key = k + if typ in self.types: + state_keys = self.types[typ] + if state_keys is None or state_key in state_keys: + filtered_state[k] = v + elif self.include_others: + filtered_state[k] = v + + return filtered_state + + def is_full(self): + """Whether this filter fetches everything or not + + Returns: + bool + """ + return self.include_others and not self.types + + def has_wildcards(self): + """Whether the filter includes wildcards or is attempting to fetch + specific state. + + Returns: + bool + """ + + return ( + self.include_others + or any( + state_keys is None + for state_keys in itervalues(self.types) + ) + ) + + def concrete_types(self): + """Returns a list of concrete type/state_keys (i.e. not None) that + will be fetched. This will be a complete list if `has_wildcards` + returns False, but otherwise will be a subset (or even empty). + + Returns: + list[tuple[str,str]] + """ + return [ + (t, s) + for t, state_keys in iteritems(self.types) + if state_keys is not None + for s in state_keys + ] + + def get_member_split(self): + """Return the filter split into two: one which assumes it's exclusively + matching against member state, and one which assumes it's matching + against non member state. + + This is useful due to the returned filters giving correct results for + `is_full()`, `has_wildcards()`, etc, when operating against maps that + either exclusively contain member events or only contain non-member + events. (Which is the case when dealing with the member vs non-member + state caches). + + Returns: + tuple[StateFilter, StateFilter]: The member and non member filters + """ + + if EventTypes.Member in self.types: + state_keys = self.types[EventTypes.Member] + if state_keys is None: + member_filter = StateFilter.all() + else: + member_filter = StateFilter({EventTypes.Member: state_keys}) + elif self.include_others: + member_filter = StateFilter.all() + else: + member_filter = StateFilter.none() + + non_member_filter = StateFilter( + types={k: v for k, v in iteritems(self.types) if k != EventTypes.Member}, + include_others=self.include_others, + ) + + return member_filter, non_member_filter + + # this inherits from EventsWorkerStore because it calls self.get_events class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers. @@ -152,61 +466,41 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) # FIXME: how should this be cached? - def get_filtered_current_state_ids(self, room_id, types, filtered_types=None): + def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()): """Get the current state event of a given type for a room based on the current_state_events table. This may not be as up-to-date as the result of doing a fresh state resolution as per state_handler.get_current_state + Args: room_id (str) - types (list[(Str, (Str|None))]): List of (type, state_key) tuples - which are used to filter the state fetched. `state_key` may be - None, which matches any `state_key` - filtered_types (list[Str]|None): List of types to apply the above filter to. + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: - deferred: dict of (type, state_key) -> event + Deferred[dict[tuple[str, str], str]]: Map from type/state_key to + event ID. """ - include_other_types = False if filtered_types is None else True - def _get_filtered_current_state_ids_txn(txn): results = {} - sql = """SELECT type, state_key, event_id FROM current_state_events - WHERE room_id = ? %s""" - # Turns out that postgres doesn't like doing a list of OR's and - # is about 1000x slower, so we just issue a query for each specific - # type seperately. - if types: - clause_to_args = [ - ( - "AND type = ? AND state_key = ?", - (etype, state_key) - ) if state_key is not None else ( - "AND type = ?", - (etype,) - ) - for etype, state_key in types - ] - - if include_other_types: - unique_types = set(filtered_types) - clause_to_args.append( - ( - "AND type <> ? " * len(unique_types), - list(unique_types) - ) - ) - else: - # If types is None we fetch all the state, and so just use an - # empty where clause with no extra args. - clause_to_args = [("", [])] - for where_clause, where_args in clause_to_args: - args = [room_id] - args.extend(where_args) - txn.execute(sql % (where_clause,), args) - for row in txn: - typ, state_key, event_id = row - key = (intern_string(typ), intern_string(state_key)) - results[key] = event_id + sql = """ + SELECT type, state_key, event_id FROM current_state_events + WHERE room_id = ? + """ + + where_clause, where_args = state_filter.make_sql_filter_clause() + + if where_clause: + sql += " AND (%s)" % (where_clause,) + + args = [room_id] + args.extend(where_args) + txn.execute(sql, args) + for row in txn: + typ, state_key, event_id = row + key = (intern_string(typ), intern_string(state_key)) + results[key] = event_id + return results return self.runInteraction( @@ -255,7 +549,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) @defer.inlineCallbacks - def get_state_groups_ids(self, room_id, event_ids): + def get_state_groups_ids(self, _room_id, event_ids): + """Get the event IDs of all the state for the state groups for the given events + + Args: + _room_id (str): id of the room for these events + event_ids (iterable[str]): ids of the events + + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ if not event_ids: defer.returnValue({}) @@ -270,7 +574,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): @defer.inlineCallbacks def get_state_ids_for_group(self, state_group): - """Get the state IDs for the given state group + """Get the event IDs of all the state in the given state group Args: state_group (int) @@ -286,7 +590,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def get_state_groups(self, room_id, event_ids): """ Get the state groups for the given list of event_ids - The return value is a dict mapping group names to lists of events. + Returns: + Deferred[dict[int, list[EventBase]]]: + dict of state_group_id -> list of state events. """ if not event_ids: defer.returnValue({}) @@ -310,21 +616,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): }) @defer.inlineCallbacks - def _get_state_groups_from_groups(self, groups, types, members=None): + def _get_state_groups_from_groups(self, groups, state_filter): """Returns the state groups for a given set of groups, filtering on types of state events. Args: groups(list[int]): list of state group IDs to query - types (Iterable[str, str|None]|None): list of 2-tuples of the form - (`type`, `state_key`), where a `state_key` of `None` matches all - state_keys for the `type`. If None, all types are returned. - members (bool|None): If not None, then, in addition to any filtering - implied by types, the results are also filtered to only include - member events (if True), or to exclude member events (if False) - + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: - dictionary state_group -> (dict of (type, state_key) -> event id) + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) """ results = {} @@ -332,19 +634,23 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): for chunk in chunks: res = yield self.runInteraction( "_get_state_groups_from_groups", - self._get_state_groups_from_groups_txn, chunk, types, members, + self._get_state_groups_from_groups_txn, chunk, state_filter, ) results.update(res) defer.returnValue(results) def _get_state_groups_from_groups_txn( - self, txn, groups, types=None, members=None, + self, txn, groups, state_filter=StateFilter.all(), ): results = {group: {} for group in groups} - if types is not None: - types = list(set(types)) # deduplicate types list + where_clause, where_args = state_filter.make_sql_filter_clause() + + # Unless the filter clause is empty, we're going to append it after an + # existing where clause + if where_clause: + where_clause = " AND (%s)" % (where_clause,) if isinstance(self.database_engine, PostgresEngine): # Temporarily disable sequential scans in this transaction. This is @@ -360,79 +666,33 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # group for the given type, state_key. # This may return multiple rows per (type, state_key), but last_value # should be the same. - sql = (""" + sql = """ WITH RECURSIVE state(state_group) AS ( VALUES(?::bigint) UNION ALL SELECT prev_state_group FROM state_group_edges e, state s WHERE s.state_group = e.state_group ) - SELECT type, state_key, last_value(event_id) OVER ( + SELECT DISTINCT type, state_key, last_value(event_id) OVER ( PARTITION BY type, state_key ORDER BY state_group ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING ) AS event_id FROM state_groups_state WHERE state_group IN ( SELECT state_group FROM state ) - %s - """) - - if members is True: - sql += " AND type = '%s'" % (EventTypes.Member,) - elif members is False: - sql += " AND type <> '%s'" % (EventTypes.Member,) - - # Turns out that postgres doesn't like doing a list of OR's and - # is about 1000x slower, so we just issue a query for each specific - # type seperately. - if types is not None: - clause_to_args = [ - ( - "AND type = ? AND state_key = ?", - (etype, state_key) - ) if state_key is not None else ( - "AND type = ?", - (etype,) - ) - for etype, state_key in types - ] - else: - # If types is None we fetch all the state, and so just use an - # empty where clause with no extra args. - clause_to_args = [("", [])] + """ - for where_clause, where_args in clause_to_args: - for group in groups: - args = [group] - args.extend(where_args) + for group in groups: + args = [group] + args.extend(where_args) - txn.execute(sql % (where_clause,), args) - for row in txn: - typ, state_key, event_id = row - key = (typ, state_key) - results[group][key] = event_id + txn.execute(sql + where_clause, args) + for row in txn: + typ, state_key, event_id = row + key = (typ, state_key) + results[group][key] = event_id else: - where_args = [] - where_clauses = [] - wildcard_types = False - if types is not None: - for typ in types: - if typ[1] is None: - where_clauses.append("(type = ?)") - where_args.append(typ[0]) - wildcard_types = True - else: - where_clauses.append("(type = ? AND state_key = ?)") - where_args.extend([typ[0], typ[1]]) - - where_clause = "AND (%s)" % (" OR ".join(where_clauses)) - else: - where_clause = "" - - if members is True: - where_clause += " AND type = '%s'" % EventTypes.Member - elif members is False: - where_clause += " AND type <> '%s'" % EventTypes.Member + max_entries_returned = state_filter.max_entries_returned() # We don't use WITH RECURSIVE on sqlite3 as there are distributions # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) @@ -446,12 +706,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # without the right indices (which we can't add until # after we finish deduping state, which requires this func) args = [next_group] - if types: - args.extend(where_args) + args.extend(where_args) txn.execute( "SELECT type, state_key, event_id FROM state_groups_state" - " WHERE state_group = ? %s" % (where_clause,), + " WHERE state_group = ? " + where_clause, args ) results[group].update( @@ -467,9 +726,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # wildcards (i.e. Nones) in which case we have to do an exhaustive # search if ( - types is not None and - not wildcard_types and - len(results[group]) == len(types) + max_entries_returned is not None and + len(results[group]) == max_entries_returned ): break @@ -484,20 +742,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return results @defer.inlineCallbacks - def get_state_for_events(self, event_ids, types, filtered_types=None): + def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): """Given a list of event_ids and type tuples, return a list of state - dicts for each event. The state dicts will only have the type/state_keys - that are in the `types` list. + dicts for each event. Args: event_ids (list[string]) - types (list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: deferred: A dict of (event_id) -> (type, state_key) -> [state_events] @@ -507,7 +759,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, types, filtered_types) + group_to_state = yield self._get_state_for_groups(groups, state_filter) state_event_map = yield self.get_events( [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], @@ -526,20 +778,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): defer.returnValue({event: event_to_state[event] for event in event_ids}) @defer.inlineCallbacks - def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None): + def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): """ Get the state dicts corresponding to a list of events, containing the event_ids of the state events (as opposed to the events themselves) Args: event_ids(list(str)): events whose state should be returned - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: A deferred dict from event_id -> (type, state_key) -> event_id @@ -549,7 +796,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, types, filtered_types) + group_to_state = yield self._get_state_for_groups(groups, state_filter) event_to_state = { event_id: group_to_state[group] @@ -559,45 +806,35 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): defer.returnValue({event: event_to_state[event] for event in event_ids}) @defer.inlineCallbacks - def get_state_for_event(self, event_id, types=None, filtered_types=None): + def get_state_for_event(self, event_id, state_filter=StateFilter.all()): """ Get the state dict corresponding to a particular event Args: event_id(str): event whose state should be returned - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: A deferred dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_for_events([event_id], types, filtered_types) + state_map = yield self.get_state_for_events([event_id], state_filter) defer.returnValue(state_map[event_id]) @defer.inlineCallbacks - def get_state_ids_for_event(self, event_id, types=None, filtered_types=None): + def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): """ Get the state dict corresponding to a particular event Args: event_id(str): event whose state should be returned - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: A deferred dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_ids_for_events([event_id], types, filtered_types) + state_map = yield self.get_state_ids_for_events([event_id], state_filter) defer.returnValue(state_map[event_id]) @cached(max_entries=50000) @@ -628,18 +865,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): defer.returnValue({row["event_id"]: row["state_group"] for row in rows}) - def _get_some_state_from_cache(self, cache, group, types, filtered_types=None): + def _get_state_for_group_using_cache(self, cache, group, state_filter): """Checks if group is in cache. See `_get_state_for_groups` Args: cache(DictionaryCache): the state group cache to use group(int): The state group to lookup - types(list[str, str|None]): List of 2-tuples of the form - (`type`, `state_key`), where a `state_key` of `None` matches all - state_keys for the `type`. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool indicating if we successfully retrieved all @@ -648,124 +881,102 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """ is_all, known_absent, state_dict_ids = cache.get(group) - type_to_key = {} + if is_all or state_filter.is_full(): + # Either we have everything or want everything, either way + # `is_all` tells us whether we've gotten everything. + return state_filter.filter_state(state_dict_ids), is_all # tracks whether any of our requested types are missing from the cache missing_types = False - for typ, state_key in types: - key = (typ, state_key) - - if ( - state_key is None or - (filtered_types is not None and typ not in filtered_types) - ): - type_to_key[typ] = None - # we mark the type as missing from the cache because - # when the cache was populated it might have been done with a - # restricted set of state_keys, so the wildcard will not work - # and the cache may be incomplete. - missing_types = True - else: - if type_to_key.get(typ, object()) is not None: - type_to_key.setdefault(typ, set()).add(state_key) - + if state_filter.has_wildcards(): + # We don't know if we fetched all the state keys for the types in + # the filter that are wildcards, so we have to assume that we may + # have missed some. + missing_types = True + else: + # There aren't any wild cards, so `concrete_types()` returns the + # complete list of event types we're wanting. + for key in state_filter.concrete_types(): if key not in state_dict_ids and key not in known_absent: missing_types = True + break - sentinel = object() - - def include(typ, state_key): - valid_state_keys = type_to_key.get(typ, sentinel) - if valid_state_keys is sentinel: - return filtered_types is not None and typ not in filtered_types - if valid_state_keys is None: - return True - if state_key in valid_state_keys: - return True - return False - - got_all = is_all - if not got_all: - # the cache is incomplete. We may still have got all the results we need, if - # we don't have any wildcards in the match list. - if not missing_types and filtered_types is None: - got_all = True - - return { - k: v for k, v in iteritems(state_dict_ids) - if include(k[0], k[1]) - }, got_all - - def _get_all_state_from_cache(self, cache, group): - """Checks if group is in cache. See `_get_state_for_groups` - - Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool - indicating if we successfully retrieved all requests state from the - cache, if False we need to query the DB for the missing state. - - Args: - cache(DictionaryCache): the state group cache to use - group: The state group to lookup - """ - is_all, _, state_dict_ids = cache.get(group) - - return state_dict_ids, is_all + return state_filter.filter_state(state_dict_ids), not missing_types @defer.inlineCallbacks - def _get_state_for_groups(self, groups, types=None, filtered_types=None): + def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): """Gets the state at each of a list of state groups, optionally filtering by type/state_key Args: groups (iterable[int]): list of state groups for which we want to get the state. - types (None|iterable[(str, None|str)]): - indicates the state type/keys required. If None, the whole - state is fetched and returned. - - Otherwise, each entry should be a `(type, state_key)` tuple to - include in the response. A `state_key` of None is a wildcard - meaning that we require all state with that type. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. - + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: - Deferred[dict[int, dict[(type, state_key), EventBase]]] - a dictionary mapping from state group to state dictionary. + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) """ - if types is not None: - non_member_types = [t for t in types if t[0] != EventTypes.Member] - - if filtered_types is not None and EventTypes.Member not in filtered_types: - # we want all of the membership events - member_types = None - else: - member_types = [t for t in types if t[0] == EventTypes.Member] - else: - non_member_types = None - member_types = None + member_filter, non_member_filter = state_filter.get_member_split() - non_member_state = yield self._get_state_for_groups_using_cache( - groups, self._state_group_cache, non_member_types, filtered_types, + # Now we look them up in the member and non-member caches + non_member_state, incomplete_groups_nm, = ( + yield self._get_state_for_groups_using_cache( + groups, self._state_group_cache, + state_filter=non_member_filter, + ) ) - # XXX: we could skip this entirely if member_types is [] - member_state = yield self._get_state_for_groups_using_cache( - # we set filtered_types=None as member_state only ever contain members. - groups, self._state_group_members_cache, member_types, None, + + member_state, incomplete_groups_m, = ( + yield self._get_state_for_groups_using_cache( + groups, self._state_group_members_cache, + state_filter=member_filter, + ) ) - state = non_member_state + state = dict(non_member_state) for group in groups: state[group].update(member_state[group]) + # Now fetch any missing groups from the database + + incomplete_groups = incomplete_groups_m | incomplete_groups_nm + + if not incomplete_groups: + defer.returnValue(state) + + cache_sequence_nm = self._state_group_cache.sequence + cache_sequence_m = self._state_group_members_cache.sequence + + # Help the cache hit ratio by expanding the filter a bit + db_state_filter = state_filter.return_expanded() + + group_to_state_dict = yield self._get_state_groups_from_groups( + list(incomplete_groups), + state_filter=db_state_filter, + ) + + # Now lets update the caches + self._insert_into_cache( + group_to_state_dict, + db_state_filter, + cache_seq_num_members=cache_sequence_m, + cache_seq_num_non_members=cache_sequence_nm, + ) + + # And finally update the result dict, by filtering out any extra + # stuff we pulled out of the database. + for group, group_state_dict in iteritems(group_to_state_dict): + # We just replace any existing entries, as we will have loaded + # everything we need from the database anyway. + state[group] = state_filter.filter_state(group_state_dict) + defer.returnValue(state) - @defer.inlineCallbacks def _get_state_for_groups_using_cache( - self, groups, cache, types=None, filtered_types=None + self, groups, cache, state_filter, ): """Gets the state at each of a list of state groups, optionally filtering by type/state_key, querying from a specific cache. @@ -776,89 +987,85 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): cache (DictionaryCache): the cache of group ids to state dicts which we will pass through - either the normal state cache or the specific members state cache. - types (None|iterable[(str, None|str)]): - indicates the state type/keys required. If None, the whole - state is fetched and returned. - - Otherwise, each entry should be a `(type, state_key)` tuple to - include in the response. A `state_key` of None is a wildcard - meaning that we require all state with that type. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: - Deferred[dict[int, dict[(type, state_key), EventBase]]] - a dictionary mapping from state group to state dictionary. + tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of + dict of state_group_id -> (dict of (type, state_key) -> event id) + of entries in the cache, and the state group ids either missing + from the cache or incomplete. """ - if types: - types = frozenset(types) results = {} - missing_groups = [] - if types is not None: - for group in set(groups): - state_dict_ids, got_all = self._get_some_state_from_cache( - cache, group, types, filtered_types - ) - results[group] = state_dict_ids + incomplete_groups = set() + for group in set(groups): + state_dict_ids, got_all = self._get_state_for_group_using_cache( + cache, group, state_filter + ) + results[group] = state_dict_ids - if not got_all: - missing_groups.append(group) - else: - for group in set(groups): - state_dict_ids, got_all = self._get_all_state_from_cache( - cache, group - ) + if not got_all: + incomplete_groups.add(group) - results[group] = state_dict_ids + return results, incomplete_groups - if not got_all: - missing_groups.append(group) + def _insert_into_cache(self, group_to_state_dict, state_filter, + cache_seq_num_members, cache_seq_num_non_members): + """Inserts results from querying the database into the relevant cache. - if missing_groups: - # Okay, so we have some missing_types, let's fetch them. - cache_seq_num = cache.sequence + Args: + group_to_state_dict (dict): The new entries pulled from database. + Map from state group to state dict + state_filter (StateFilter): The state filter used to fetch state + from the database. + cache_seq_num_members (int): Sequence number of member cache since + last lookup in cache + cache_seq_num_non_members (int): Sequence number of member cache since + last lookup in cache + """ - # the DictionaryCache knows if it has *all* the state, but - # does not know if it has all of the keys of a particular type, - # which makes wildcard lookups expensive unless we have a complete - # cache. Hence, if we are doing a wildcard lookup, populate the - # cache fully so that we can do an efficient lookup next time. + # We need to work out which types we've fetched from the DB for the + # member vs non-member caches. This should be as accurate as possible, + # but can be an underestimate (e.g. when we have wild cards) - if filtered_types or (types and any(k is None for (t, k) in types)): - types_to_fetch = None - else: - types_to_fetch = types + member_filter, non_member_filter = state_filter.get_member_split() + if member_filter.is_full(): + # We fetched all member events + member_types = None + else: + # `concrete_types()` will only return a subset when there are wild + # cards in the filter, but that's fine. + member_types = member_filter.concrete_types() - group_to_state_dict = yield self._get_state_groups_from_groups( - missing_groups, types_to_fetch, cache == self._state_group_members_cache, - ) + if non_member_filter.is_full(): + # We fetched all non member events + non_member_types = None + else: + non_member_types = non_member_filter.concrete_types() - for group, group_state_dict in iteritems(group_to_state_dict): - state_dict = results[group] - - # update the result, filtering by `types`. - if types: - for k, v in iteritems(group_state_dict): - (typ, _) = k - if ( - (k in types or (typ, None) in types) or - (filtered_types and typ not in filtered_types) - ): - state_dict[k] = v + for group, group_state_dict in iteritems(group_to_state_dict): + state_dict_members = {} + state_dict_non_members = {} + + for k, v in iteritems(group_state_dict): + if k[0] == EventTypes.Member: + state_dict_members[k] = v else: - state_dict.update(group_state_dict) - - # update the cache with all the things we fetched from the - # database. - cache.update( - cache_seq_num, - key=group, - value=group_state_dict, - fetched_keys=types_to_fetch, - ) + state_dict_non_members[k] = v - defer.returnValue(results) + self._state_group_members_cache.update( + cache_seq_num_members, + key=group, + value=state_dict_members, + fetched_keys=member_types, + ) + + self._state_group_cache.update( + cache_seq_num_non_members, + key=group, + value=state_dict_non_members, + fetched_keys=non_member_types, + ) def store_state_group(self, event_id, room_id, prev_group, delta_ids, current_state_ids): @@ -1050,6 +1257,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" + EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" def __init__(self, db_conn, hs): super(StateStore, self).__init__(db_conn, hs) @@ -1068,6 +1276,12 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): columns=["state_key"], where_clause="type='m.room.member'", ) + self.register_background_index_update( + self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME, + index_name="event_to_state_groups_sg_index", + table="event_to_state_groups", + columns=["state_group"], + ) def _store_event_state_mappings_txn(self, txn, events_and_contexts): state_groups = {} @@ -1167,12 +1381,12 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): continue prev_state = self._get_state_groups_from_groups_txn( - txn, [prev_group], types=None + txn, [prev_group], ) prev_state = prev_state[prev_group] curr_state = self._get_state_groups_from_groups_txn( - txn, [state_group], types=None + txn, [state_group], ) curr_state = curr_state[state_group] diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 4c296d72c0..d6cfdba519 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -630,7 +630,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): @defer.inlineCallbacks def get_all_new_events_stream(self, from_id, current_id, limit): - """Get all new events""" + """Get all new events + + Returns all events with from_id < stream_ordering <= current_id. + + Args: + from_id (int): the stream_ordering of the last event we processed + current_id (int): the stream_ordering of the most recently processed event + limit (int): the maximum number of events to return + + Returns: + Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where + `next_id` is the next value to pass as `from_id` (it will either be the + stream_ordering of the last returned event, or, if fewer than `limit` events + were found, `current_id`. + """ def get_all_new_events_stream_txn(txn): sql = ( diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 0c42bd3322..d8bf953ec0 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -23,14 +23,14 @@ from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.caches.descriptors import cached +from synapse.util.caches.expiringcache import ExpiringCache from ._base import SQLBaseStore, db_to_json # py2 sqlite has buffer hardcoded as only binary type, so we must use it, # despite being deprecated and removed in favor of memoryview if six.PY2: - db_binary_type = buffer + db_binary_type = six.moves.builtins.buffer else: db_binary_type = memoryview @@ -50,6 +50,8 @@ _UpdateTransactionRow = namedtuple( ) ) +SENTINEL = object() + class TransactionStore(SQLBaseStore): """A collection of queries for handling PDUs. @@ -60,6 +62,12 @@ class TransactionStore(SQLBaseStore): self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) + self._destination_retry_cache = ExpiringCache( + cache_name="get_destination_retry_timings", + clock=self._clock, + expiry_ms=5 * 60 * 1000, + ) + def get_received_txn_response(self, transaction_id, origin): """For an incoming transaction from a given origin, check if we have already responded to it. If so, return the response code and response @@ -156,7 +164,7 @@ class TransactionStore(SQLBaseStore): """ pass - @cached(max_entries=10000) + @defer.inlineCallbacks def get_destination_retry_timings(self, destination): """Gets the current retry timings (if any) for a given destination. @@ -167,10 +175,20 @@ class TransactionStore(SQLBaseStore): None if not retrying Otherwise a dict for the retry scheme """ - return self.runInteraction( + + result = self._destination_retry_cache.get(destination, SENTINEL) + if result is not SENTINEL: + defer.returnValue(result) + + result = yield self.runInteraction( "get_destination_retry_timings", self._get_destination_retry_timings, destination) + # We don't hugely care about race conditions between getting and + # invalidating the cache, since we time out fairly quickly anyway. + self._destination_retry_cache[destination] = result + defer.returnValue(result) + def _get_destination_retry_timings(self, txn, destination): result = self._simple_select_one_txn( txn, @@ -198,8 +216,7 @@ class TransactionStore(SQLBaseStore): retry_interval (int) - how long until next retry in ms """ - # XXX: we could chose to not bother persisting this if our cache thinks - # this is a NOOP + self._destination_retry_cache.pop(destination, None) return self.runInteraction( "set_destination_retry_timings", self._set_destination_retry_timings, @@ -212,10 +229,6 @@ class TransactionStore(SQLBaseStore): retry_last_ts, retry_interval): self.database_engine.lock_table(txn, "destinations") - self._invalidate_cache_and_stream( - txn, self.get_destination_retry_timings, (destination,) - ) - # We need to be careful here as the data may have changed from under us # due to a worker setting the timings. diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 680ea928c7..0ae7e2ef3b 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +import re from itertools import islice import attr @@ -68,7 +69,10 @@ class Clock(object): """ call = task.LoopingCall(f) call.clock = self._reactor - call.start(msec / 1000.0, now=False) + d = call.start(msec / 1000.0, now=False) + d.addErrback( + log_failure, "Looping call died", consumeErrors=False, + ) return call def call_later(self, delay, callback, *args, **kwargs): @@ -109,3 +113,53 @@ def batch_iter(iterable, size): sourceiter = iter(iterable) # call islice until it returns an empty tuple return iter(lambda: tuple(islice(sourceiter, size)), ()) + + +def log_failure(failure, msg, consumeErrors=True): + """Creates a function suitable for passing to `Deferred.addErrback` that + logs any failures that occur. + + Args: + msg (str): Message to log + consumeErrors (bool): If true consumes the failure, otherwise passes + on down the callback chain + + Returns: + func(Failure) + """ + + logger.error( + msg, + exc_info=( + failure.type, + failure.value, + failure.getTracebackObject() + ) + ) + + if not consumeErrors: + return failure + + +def glob_to_regex(glob): + """Converts a glob to a compiled regex object. + + The regex is anchored at the beginning and end of the string. + + Args: + glob (str) + + Returns: + re.RegexObject + """ + res = '' + for c in glob: + if c == '*': + res = res + '.*' + elif c == '?': + res = res + '.' + else: + res = res + re.escape(c) + + # \A anchors at start of string, \Z at end of string + return re.compile(r"\A" + res + r"\Z", re.IGNORECASE) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 40c2946129..ec7b2c9672 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -374,29 +374,25 @@ class ReadWriteLock(object): defer.returnValue(_ctx_manager()) -class DeferredTimeoutError(Exception): - """ - This error is raised by default when a L{Deferred} times out. - """ - +def _cancelled_to_timed_out_error(value, timeout): + if isinstance(value, failure.Failure): + value.trap(CancelledError) + raise defer.TimeoutError(timeout, "Deferred") + return value -def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None): - """ - Add a timeout to a deferred by scheduling it to be cancelled after - timeout seconds. - This is essentially a backport of deferred.addTimeout, which was introduced - in twisted 16.5. +def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): + """The in built twisted `Deferred.addTimeout` fails to time out deferreds + that have a canceller that throws exceptions. This method creates a new + deferred that wraps and times out the given deferred, correctly handling + the case where the given deferred's canceller throws. - If the deferred gets timed out, it errbacks with a DeferredTimeoutError, - unless a cancelable function was passed to its initialization or unless - a different on_timeout_cancel callable is provided. + NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred Args: - deferred (defer.Deferred): deferred to be timed out - timeout (Number): seconds to time out after - reactor (twisted.internet.reactor): the Twisted reactor to use - + deferred (Deferred) + timeout (float): Timeout in seconds + reactor (twisted.internet.reactor): The twisted reactor to use on_timeout_cancel (callable): A callable which is called immediately after the deferred times out, and not if this deferred is otherwise cancelled before the timeout. @@ -406,48 +402,10 @@ def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None): the timeout. The default callable (if none is provided) will translate a - CancelledError Failure into a DeferredTimeoutError. - """ - timed_out = [False] - - def time_it_out(): - timed_out[0] = True - deferred.cancel() - - delayed_call = reactor.callLater(timeout, time_it_out) - - def convert_cancelled(value): - if timed_out[0]: - to_call = on_timeout_cancel or _cancelled_to_timed_out_error - return to_call(value, timeout) - return value - - deferred.addBoth(convert_cancelled) + CancelledError Failure into a defer.TimeoutError. - def cancel_timeout(result): - # stop the pending call to cancel the deferred if it's been fired - if delayed_call.active(): - delayed_call.cancel() - return result - - deferred.addBoth(cancel_timeout) - - -def _cancelled_to_timed_out_error(value, timeout): - if isinstance(value, failure.Failure): - value.trap(CancelledError) - raise DeferredTimeoutError(timeout, "Deferred") - return value - - -def timeout_no_seriously(deferred, timeout, reactor): - """The in build twisted deferred addTimeout (and the method above) - completely fail to time things out under some unknown circumstances. - - Lets try a different way of timing things out and maybe that will make - things work?! - - TODO: Kill this with fire. + Returns: + Deferred """ new_d = defer.Deferred() @@ -457,16 +415,20 @@ def timeout_no_seriously(deferred, timeout, reactor): def time_it_out(): timed_out[0] = True - if not new_d.called: - new_d.errback(DeferredTimeoutError(timeout, "Deferred")) + try: + deferred.cancel() + except: # noqa: E722, if we throw any exception it'll break time outs + logger.exception("Canceller failed during timeout") - deferred.cancel() + if not new_d.called: + new_d.errback(defer.TimeoutError(timeout, "Deferred")) delayed_call = reactor.callLater(timeout, time_it_out) def convert_cancelled(value): if timed_out[0]: - return _cancelled_to_timed_out_error(value, timeout) + to_call = on_timeout_cancel or _cancelled_to_timed_out_error + return to_call(value, timeout) return value deferred.addBoth(convert_cancelled) diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 7b065b195e..f37d5bec08 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os import six @@ -20,6 +21,8 @@ from six.moves import intern from prometheus_client.core import REGISTRY, Gauge, GaugeMetricFamily +logger = logging.getLogger(__name__) + CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5)) @@ -76,16 +79,20 @@ def register_cache(cache_type, cache_name, cache): return [] def collect(self): - if cache_type == "response_cache": - response_cache_size.labels(cache_name).set(len(cache)) - response_cache_hits.labels(cache_name).set(self.hits) - response_cache_evicted.labels(cache_name).set(self.evicted_size) - response_cache_total.labels(cache_name).set(self.hits + self.misses) - else: - cache_size.labels(cache_name).set(len(cache)) - cache_hits.labels(cache_name).set(self.hits) - cache_evicted.labels(cache_name).set(self.evicted_size) - cache_total.labels(cache_name).set(self.hits + self.misses) + try: + if cache_type == "response_cache": + response_cache_size.labels(cache_name).set(len(cache)) + response_cache_hits.labels(cache_name).set(self.hits) + response_cache_evicted.labels(cache_name).set(self.evicted_size) + response_cache_total.labels(cache_name).set(self.hits + self.misses) + else: + cache_size.labels(cache_name).set(len(cache)) + cache_hits.labels(cache_name).set(self.hits) + cache_evicted.labels(cache_name).set(self.evicted_size) + cache_total.labels(cache_name).set(self.hits + self.misses) + except Exception as e: + logger.warn("Error calculating metrics for %s: %s", cache_name, e) + raise yield GaugeMetricFamily("__unused", "") diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index ce85b2ae11..f369780277 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -16,12 +16,17 @@ import logging from collections import OrderedDict +from six import iteritems, itervalues + from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.caches import register_cache logger = logging.getLogger(__name__) +SENTINEL = object() + + class ExpiringCache(object): def __init__(self, cache_name, clock, max_len=0, expiry_ms=0, reset_expiry_on_get=False, iterable=False): @@ -54,11 +59,8 @@ class ExpiringCache(object): self.iterable = iterable - self._size_estimate = 0 - self.metrics = register_cache("expiring", cache_name, self) - def start(self): if not self._expiry_ms: # Don't bother starting the loop if things never expire return @@ -75,16 +77,11 @@ class ExpiringCache(object): now = self._clock.time_msec() self._cache[key] = _CacheEntry(now, value) - if self.iterable: - self._size_estimate += len(value) - # Evict if there are now too many items while self._max_len and len(self) > self._max_len: _key, value = self._cache.popitem(last=False) if self.iterable: - removed_len = len(value.value) - self.metrics.inc_evictions(removed_len) - self._size_estimate -= removed_len + self.metrics.inc_evictions(len(value.value)) else: self.metrics.inc_evictions() @@ -101,6 +98,21 @@ class ExpiringCache(object): return entry.value + def pop(self, key, default=SENTINEL): + """Removes and returns the value with the given key from the cache. + + If the key isn't in the cache then `default` will be returned if + specified, otherwise `KeyError` will get raised. + + Identical functionality to `dict.pop(..)`. + """ + + value = self._cache.pop(key, default) + if value is SENTINEL: + raise KeyError(key) + + return value + def __contains__(self, key): return key in self._cache @@ -128,14 +140,16 @@ class ExpiringCache(object): keys_to_delete = set() - for key, cache_entry in self._cache.items(): + for key, cache_entry in iteritems(self._cache): if now - cache_entry.time > self._expiry_ms: keys_to_delete.add(key) for k in keys_to_delete: value = self._cache.pop(k) if self.iterable: - self._size_estimate -= len(value.value) + self.metrics.inc_evictions(len(value.value)) + else: + self.metrics.inc_evictions() logger.debug( "[%s] _prune_cache before: %d, after len: %d", @@ -144,12 +158,14 @@ class ExpiringCache(object): def __len__(self): if self.iterable: - return self._size_estimate + return sum(len(entry.value) for entry in itervalues(self._cache)) else: return len(self._cache) class _CacheEntry(object): + __slots__ = ["time", "value"] + def __init__(self, time, value): self.time = time self.value = value diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index f2bde74dc5..625aedc940 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -15,6 +15,8 @@ import logging +from six import integer_types + from sortedcontainers import SortedDict from synapse.util import caches @@ -47,7 +49,7 @@ class StreamChangeCache(object): def has_entity_changed(self, entity, stream_pos): """Returns True if the entity may have been updated since stream_pos """ - assert type(stream_pos) is int or type(stream_pos) is long + assert type(stream_pos) in integer_types if stream_pos < self._earliest_known_stream_pos: self.metrics.inc_misses() diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index a0c2d37610..4c6e92beb8 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -25,7 +25,7 @@ See doc/log_contexts.rst for details on how this works. import logging import threading -from twisted.internet import defer +from twisted.internet import defer, threads logger = logging.getLogger(__name__) @@ -200,7 +200,7 @@ class LoggingContext(object): sentinel = Sentinel() - def __init__(self, name=None, parent_context=None): + def __init__(self, name=None, parent_context=None, request=None): self.previous_context = LoggingContext.current_context() self.name = name @@ -218,6 +218,13 @@ class LoggingContext(object): self.parent_context = parent_context + if self.parent_context is not None: + self.parent_context.copy_to(self) + + if request is not None: + # the request param overrides the request from the parent context + self.request = request + def __str__(self): return "%s@%x" % (self.name, id(self)) @@ -256,9 +263,6 @@ class LoggingContext(object): ) self.alive = True - if self.parent_context is not None: - self.parent_context.copy_to(self) - return self def __exit__(self, type, value, traceback): @@ -439,6 +443,35 @@ class PreserveLoggingContext(object): ) +def nested_logging_context(suffix, parent_context=None): + """Creates a new logging context as a child of another. + + The nested logging context will have a 'request' made up of the parent context's + request, plus the given suffix. + + CPU/db usage stats will be added to the parent context's on exit. + + Normal usage looks like: + + with nested_logging_context(suffix): + # ... do stuff + + Args: + suffix (str): suffix to add to the parent context's 'request'. + parent_context (LoggingContext|None): parent context. Will use the current context + if None. + + Returns: + LoggingContext: new logging context. + """ + if parent_context is None: + parent_context = LoggingContext.current_context() + return LoggingContext( + parent_context=parent_context, + request=parent_context.request + "-" + suffix, + ) + + def preserve_fn(f): """Function decorator which wraps the function with run_in_background""" def g(*args, **kwargs): @@ -529,58 +562,76 @@ def _set_context_cb(result, context): return result -# modules to ignore in `logcontext_tracer` -_to_ignore = [ - "synapse.util.logcontext", - "synapse.http.server", - "synapse.storage._base", - "synapse.util.async_helpers", -] +def defer_to_thread(reactor, f, *args, **kwargs): + """ + Calls the function `f` using a thread from the reactor's default threadpool and + returns the result as a Deferred. + + Creates a new logcontext for `f`, which is created as a child of the current + logcontext (so its CPU usage metrics will get attributed to the current + logcontext). `f` should preserve the logcontext it is given. + The result deferred follows the Synapse logcontext rules: you should `yield` + on it. + + Args: + reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread + the Deferred will be invoked, and whose threadpool we should use for the + function. -def logcontext_tracer(frame, event, arg): - """A tracer that logs whenever a logcontext "unexpectedly" changes within - a function. Probably inaccurate. + Normally this will be hs.get_reactor(). - Use by calling `sys.settrace(logcontext_tracer)` in the main thread. + f (callable): The function to call. + + args: positional arguments to pass to f. + + kwargs: keyword arguments to pass to f. + + Returns: + Deferred: A Deferred which fires a callback with the result of `f`, or an + errback if `f` throws an exception. """ - if event == 'call': - name = frame.f_globals["__name__"] - if name.startswith("synapse"): - if name == "synapse.util.logcontext": - if frame.f_code.co_name in ["__enter__", "__exit__"]: - tracer = frame.f_back.f_trace - if tracer: - tracer.just_changed = True - - tracer = frame.f_trace - if tracer: - return tracer - - if not any(name.startswith(ig) for ig in _to_ignore): - return LineTracer() - - -class LineTracer(object): - __slots__ = ["context", "just_changed"] - - def __init__(self): - self.context = LoggingContext.current_context() - self.just_changed = False - - def __call__(self, frame, event, arg): - if event in 'line': - if self.just_changed: - self.context = LoggingContext.current_context() - self.just_changed = False - else: - c = LoggingContext.current_context() - if c != self.context: - logger.info( - "Context changed! %s -> %s, %s, %s", - self.context, c, - frame.f_code.co_filename, frame.f_lineno - ) - self.context = c + return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs) - return self + +def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): + """ + A wrapper for twisted.internet.threads.deferToThreadpool, which handles + logcontexts correctly. + + Calls the function `f` using a thread from the given threadpool and returns + the result as a Deferred. + + Creates a new logcontext for `f`, which is created as a child of the current + logcontext (so its CPU usage metrics will get attributed to the current + logcontext). `f` should preserve the logcontext it is given. + + The result deferred follows the Synapse logcontext rules: you should `yield` + on it. + + Args: + reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread + the Deferred will be invoked. Normally this will be hs.get_reactor(). + + threadpool (twisted.python.threadpool.ThreadPool): The threadpool to use for + running `f`. Normally this will be hs.get_reactor().getThreadPool(). + + f (callable): The function to call. + + args: positional arguments to pass to f. + + kwargs: keyword arguments to pass to f. + + Returns: + Deferred: A Deferred which fires a callback with the result of `f`, or an + errback if `f` throws an exception. + """ + logcontext = LoggingContext.current_context() + + def g(): + with LoggingContext(parent_context=logcontext): + return f(*args, **kwargs) + + return make_deferred_yieldable( + threads.deferToThreadPool(reactor, threadpool, g) + ) diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index 8d0f2a8918..9cb7e9c9ab 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -70,6 +70,8 @@ def manhole(username, password, globals): Returns: twisted.internet.protocol.Factory: A factory to pass to ``listenTCP`` """ + if not isinstance(password, bytes): + password = password.encode('ascii') checker = checkers.InMemoryUsernamePasswordDatabaseDontUse( **{username: password} @@ -82,7 +84,7 @@ def manhole(username, password, globals): ) factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) - factory.publicKeys['ssh-rsa'] = Key.fromString(PUBLIC_KEY) - factory.privateKeys['ssh-rsa'] = Key.fromString(PRIVATE_KEY) + factory.publicKeys[b'ssh-rsa'] = Key.fromString(PUBLIC_KEY) + factory.privateKeys[b'ssh-rsa'] = Key.fromString(PRIVATE_KEY) return factory diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 8a3a06fd74..26cce7d197 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -188,7 +188,7 @@ class RetryDestinationLimiter(object): else: self.retry_interval = self.min_retry_interval - logger.debug( + logger.info( "Connection to %s was unsuccessful (%s(%s)); backoff now %i", self.destination, exc_type, exc_val, self.retry_interval ) diff --git a/synapse/visibility.py b/synapse/visibility.py index d4680863d3..0281a7c919 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -23,6 +23,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.events.utils import prune_event +from synapse.storage.state import StateFilter from synapse.types import get_domain_from_id logger = logging.getLogger(__name__) @@ -72,7 +73,7 @@ def filter_events_for_client(store, user_id, events, is_peeking=False, ) event_id_to_state = yield store.get_state_for_events( frozenset(e.event_id for e in events), - types=types, + state_filter=StateFilter.from_types(types), ) ignore_dict_content = yield store.get_global_account_data_by_type_for_user( @@ -219,7 +220,7 @@ def filter_events_for_server(store, server_name, events): # Whatever else we do, we need to check for senders which have requested # erasure of their data. erased_senders = yield store.are_users_erased( - e.sender for e in events, + (e.sender for e in events), ) def redact_disallowed(event, state): @@ -273,8 +274,8 @@ def filter_events_for_server(store, server_name, events): # need to check membership (as we know the server is in the room). event_to_state_ids = yield store.get_state_ids_for_events( frozenset(e.event_id for e in events), - types=( - (EventTypes.RoomHistoryVisibility, ""), + state_filter=StateFilter.from_types( + types=((EventTypes.RoomHistoryVisibility, ""),), ) ) @@ -314,9 +315,11 @@ def filter_events_for_server(store, server_name, events): # of the history vis and membership state at those events. event_to_state_ids = yield store.get_state_ids_for_events( frozenset(e.event_id for e in events), - types=( - (EventTypes.RoomHistoryVisibility, ""), - (EventTypes.Member, None), + state_filter=StateFilter.from_types( + types=( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, None), + ), ) ) @@ -324,14 +327,13 @@ def filter_events_for_server(store, server_name, events): # server's domain. # # event_to_state_ids contains lots of duplicates, so it turns out to be - # cheaper to build a complete set of unique - # ((type, state_key), event_id) tuples, and then filter out the ones we - # don't want. + # cheaper to build a complete event_id => (type, state_key) dict, and then + # filter out the ones we don't want # - state_key_to_event_id_set = { - e + event_id_to_state_key = { + event_id: key for key_to_eid in itervalues(event_to_state_ids) - for e in key_to_eid.items() + for key, event_id in iteritems(key_to_eid) } def include(typ, state_key): @@ -346,7 +348,7 @@ def filter_events_for_server(store, server_name, events): event_map = yield store.get_events([ e_id - for key, e_id in state_key_to_event_id_set + for e_id, key in iteritems(event_id_to_state_key) if include(key[0], key[1]) ]) |