diff options
Diffstat (limited to 'synapse')
391 files changed, 8907 insertions, 6279 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index ee3313a41c..ec16f54a49 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -36,7 +36,7 @@ try: except ImportError: pass -__version__ = "1.4.1" +__version__ = "1.5.1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/api/auth.py b/synapse/api/auth.py index cd347fbe1b..5d0b7d2801 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -25,7 +25,13 @@ from twisted.internet import defer import synapse.logging.opentracing as opentracing import synapse.types from synapse import event_auth -from synapse.api.constants import EventTypes, JoinRules, Membership, UserTypes +from synapse.api.constants import ( + EventTypes, + JoinRules, + LimitBlockingTypes, + Membership, + UserTypes, +) from synapse.api.errors import ( AuthError, Codes, @@ -491,7 +497,7 @@ class Auth(object): token = self.get_access_token_from_request(request) service = self.store.get_app_service_by_token(token) if not service: - logger.warn("Unrecognised appservice access token.") + logger.warning("Unrecognised appservice access token.") raise InvalidClientTokenError() request.authenticated_entity = service.sender return defer.succeed(service) @@ -726,7 +732,7 @@ class Auth(object): self.hs.config.hs_disabled_message, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, admin_contact=self.hs.config.admin_contact, - limit_type=self.hs.config.hs_disabled_limit_type, + limit_type=LimitBlockingTypes.HS_DISABLED, ) if self.hs.config.limit_usage_by_mau is True: assert not (user_id and threepid) @@ -759,5 +765,5 @@ class Auth(object): "Monthly Active User Limit Exceeded", admin_contact=self.hs.config.admin_contact, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, - limit_type="monthly_active_user", + limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER, ) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 60e99e4663..49c4b85054 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -131,3 +131,17 @@ class RelationTypes(object): ANNOTATION = "m.annotation" REPLACE = "m.replace" REFERENCE = "m.reference" + + +class LimitBlockingTypes(object): + """Reasons that a server may be blocked""" + + MONTHLY_ACTIVE_USER = "monthly_active_user" + HS_DISABLED = "hs_disabled" + + +class EventContentFields(object): + """Fields found in events' content, regardless of type.""" + + # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326 + LABELS = "org.matrix.labels" diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 9f06556bd2..bec13f08d8 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -20,6 +20,7 @@ from jsonschema import FormatChecker from twisted.internet import defer +from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError from synapse.storage.presence import UserPresenceState from synapse.types import RoomID, UserID @@ -66,6 +67,10 @@ ROOM_EVENT_FILTER_SCHEMA = { "contains_url": {"type": "boolean"}, "lazy_load_members": {"type": "boolean"}, "include_redundant_members": {"type": "boolean"}, + # Include or exclude events with the provided labels. + # cf https://github.com/matrix-org/matrix-doc/pull/2326 + "org.matrix.labels": {"type": "array", "items": {"type": "string"}}, + "org.matrix.not_labels": {"type": "array", "items": {"type": "string"}}, }, } @@ -259,6 +264,9 @@ class Filter(object): self.contains_url = self.filter_json.get("contains_url", None) + self.labels = self.filter_json.get("org.matrix.labels", None) + self.not_labels = self.filter_json.get("org.matrix.not_labels", []) + def filters_all_types(self): return "*" in self.not_types @@ -282,6 +290,7 @@ class Filter(object): room_id = None ev_type = "m.presence" contains_url = False + labels = [] else: sender = event.get("sender", None) if not sender: @@ -300,10 +309,11 @@ class Filter(object): content = event.get("content", {}) # check if there is a string url field in the content for filtering purposes contains_url = isinstance(content.get("url"), text_type) + labels = content.get(EventContentFields.LABELS, []) - return self.check_fields(room_id, sender, ev_type, contains_url) + return self.check_fields(room_id, sender, ev_type, labels, contains_url) - def check_fields(self, room_id, sender, event_type, contains_url): + def check_fields(self, room_id, sender, event_type, labels, contains_url): """Checks whether the filter matches the given event fields. Returns: @@ -313,6 +323,7 @@ class Filter(object): "rooms": lambda v: room_id == v, "senders": lambda v: sender == v, "types": lambda v: _matches_wildcard(event_type, v), + "labels": lambda v: v in labels, } for name, match_func in literal_keys.items(): diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index d877c77834..a01bac2997 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -44,6 +44,8 @@ def check_bind_error(e, address, bind_addresses): bind_addresses (list): Addresses on which the service listens. """ if address == "0.0.0.0" and "::" in bind_addresses: - logger.warn("Failed to listen on 0.0.0.0, continuing because listening on [::]") + logger.warning( + "Failed to listen on 0.0.0.0, continuing because listening on [::]" + ) else: raise e diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 767b87d2db..02b900f382 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -94,7 +94,7 @@ class AppserviceServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -103,7 +103,7 @@ class AppserviceServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index dbcc414c42..dadb487d5f 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -153,7 +153,7 @@ class ClientReaderServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -162,7 +162,7 @@ class ClientReaderServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index c67fe69a50..d110599a35 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -56,8 +56,8 @@ from synapse.rest.client.v1.room import ( RoomStateEventRestServlet, ) from synapse.server import HomeServer +from synapse.storage.data_stores.main.user_directory import UserDirectoryStore from synapse.storage.engines import create_engine -from synapse.storage.user_directory import UserDirectoryStore from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -147,7 +147,7 @@ class EventCreatorServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -156,7 +156,7 @@ class EventCreatorServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 1ef027a88c..418c086254 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -132,7 +132,7 @@ class FederationReaderServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -141,7 +141,7 @@ class FederationReaderServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 04fbb407af..139221ad34 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -123,7 +123,7 @@ class FederationSenderServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -132,7 +132,7 @@ class FederationSenderServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index 9504bfbc70..e647459d0e 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -204,7 +204,7 @@ class FrontendProxyServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -213,7 +213,7 @@ class FrontendProxyServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index eb54f56853..73e2c29d06 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -19,12 +19,13 @@ from __future__ import print_function import gc import logging +import math import os +import resource import sys from six import iteritems -import psutil from prometheus_client import Gauge from twisted.application import service @@ -282,7 +283,7 @@ class SynapseHomeServer(HomeServer): reactor.addSystemEventTrigger("before", "shutdown", s.stopListening) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -291,7 +292,7 @@ class SynapseHomeServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) def run_startup_checks(self, db_conn, database_engine): all_users_native = are_all_users_on_domain( @@ -471,6 +472,87 @@ class SynapseService(service.Service): return self._port.stopListening() +# Contains the list of processes we will be monitoring +# currently either 0 or 1 +_stats_process = [] + + +@defer.inlineCallbacks +def phone_stats_home(hs, stats, stats_process=_stats_process): + logger.info("Gathering stats for reporting") + now = int(hs.get_clock().time()) + uptime = int(now - hs.start_time) + if uptime < 0: + uptime = 0 + + stats["homeserver"] = hs.config.server_name + stats["server_context"] = hs.config.server_context + stats["timestamp"] = now + stats["uptime_seconds"] = uptime + version = sys.version_info + stats["python_version"] = "{}.{}.{}".format( + version.major, version.minor, version.micro + ) + stats["total_users"] = yield hs.get_datastore().count_all_users() + + total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users() + stats["total_nonbridged_users"] = total_nonbridged_users + + daily_user_type_results = yield hs.get_datastore().count_daily_user_type() + for name, count in iteritems(daily_user_type_results): + stats["daily_user_type_" + name] = count + + room_count = yield hs.get_datastore().get_room_count() + stats["total_room_count"] = room_count + + stats["daily_active_users"] = yield hs.get_datastore().count_daily_users() + stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users() + stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms() + stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() + + r30_results = yield hs.get_datastore().count_r30_users() + for name, count in iteritems(r30_results): + stats["r30_users_" + name] = count + + daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() + stats["daily_sent_messages"] = daily_sent_messages + stats["cache_factor"] = CACHE_SIZE_FACTOR + stats["event_cache_size"] = hs.config.event_cache_size + + # + # Performance statistics + # + old = stats_process[0] + new = (now, resource.getrusage(resource.RUSAGE_SELF)) + stats_process[0] = new + + # Get RSS in bytes + stats["memory_rss"] = new[1].ru_maxrss + + # Get CPU time in % of a single core, not % of all cores + used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - ( + old[1].ru_utime + old[1].ru_stime + ) + if used_cpu_time == 0 or new[0] == old[0]: + stats["cpu_average"] = 0 + else: + stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100) + + # + # Database version + # + + stats["database_engine"] = hs.get_datastore().database_engine_name + stats["database_server_version"] = hs.get_datastore().get_server_version() + logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) + try: + yield hs.get_proxied_http_client().put_json( + hs.config.report_stats_endpoint, stats + ) + except Exception as e: + logger.warning("Error reporting stats: %s", e) + + def run(hs): PROFILE_SYNAPSE = False if PROFILE_SYNAPSE: @@ -497,91 +579,19 @@ def run(hs): reactor.run = profile(reactor.run) clock = hs.get_clock() - start_time = clock.time() stats = {} - # Contains the list of processes we will be monitoring - # currently either 0 or 1 - stats_process = [] + def performance_stats_init(): + _stats_process.clear() + _stats_process.append( + (int(hs.get_clock().time(), resource.getrusage(resource.RUSAGE_SELF))) + ) def start_phone_stats_home(): - return run_as_background_process("phone_stats_home", phone_stats_home) - - @defer.inlineCallbacks - def phone_stats_home(): - logger.info("Gathering stats for reporting") - now = int(hs.get_clock().time()) - uptime = int(now - start_time) - if uptime < 0: - uptime = 0 - - stats["homeserver"] = hs.config.server_name - stats["server_context"] = hs.config.server_context - stats["timestamp"] = now - stats["uptime_seconds"] = uptime - version = sys.version_info - stats["python_version"] = "{}.{}.{}".format( - version.major, version.minor, version.micro - ) - stats["total_users"] = yield hs.get_datastore().count_all_users() - - total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users() - stats["total_nonbridged_users"] = total_nonbridged_users - - daily_user_type_results = yield hs.get_datastore().count_daily_user_type() - for name, count in iteritems(daily_user_type_results): - stats["daily_user_type_" + name] = count - - room_count = yield hs.get_datastore().get_room_count() - stats["total_room_count"] = room_count - - stats["daily_active_users"] = yield hs.get_datastore().count_daily_users() - stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users() - stats[ - "daily_active_rooms" - ] = yield hs.get_datastore().count_daily_active_rooms() - stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() - - r30_results = yield hs.get_datastore().count_r30_users() - for name, count in iteritems(r30_results): - stats["r30_users_" + name] = count - - daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() - stats["daily_sent_messages"] = daily_sent_messages - stats["cache_factor"] = CACHE_SIZE_FACTOR - stats["event_cache_size"] = hs.config.event_cache_size - - if len(stats_process) > 0: - stats["memory_rss"] = 0 - stats["cpu_average"] = 0 - for process in stats_process: - stats["memory_rss"] += process.memory_info().rss - stats["cpu_average"] += int(process.cpu_percent(interval=None)) - - stats["database_engine"] = hs.get_datastore().database_engine_name - stats["database_server_version"] = hs.get_datastore().get_server_version() - logger.info( - "Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats) + return run_as_background_process( + "phone_stats_home", phone_stats_home, hs, stats ) - try: - yield hs.get_simple_http_client().put_json( - hs.config.report_stats_endpoint, stats - ) - except Exception as e: - logger.warn("Error reporting stats: %s", e) - - def performance_stats_init(): - try: - process = psutil.Process() - # Ensure we can fetch both, and make the initial request for cpu_percent - # so the next request will use this as the initial point. - process.memory_info().rss - process.cpu_percent(interval=None) - logger.info("report_stats can use psutil") - stats_process.append(process) - except (AttributeError): - logger.warning("Unable to read memory/cpu stats. Disabling reporting.") def generate_user_daily_visit_stats(): return run_as_background_process( diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index 2ac783ffa3..2c6dd3ef02 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -39,8 +39,8 @@ from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.admin import register_servlets_for_media_repo from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.server import HomeServer +from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore from synapse.storage.engines import create_engine -from synapse.storage.media_repository import MediaRepositoryStore from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -120,7 +120,7 @@ class MediaRepositoryServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -129,7 +129,7 @@ class MediaRepositoryServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index d84732ee3c..01a5ffc363 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -114,7 +114,7 @@ class PusherServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -123,7 +123,7 @@ class PusherServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 473026fce5..b14da09f47 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -54,8 +54,8 @@ from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet from synapse.rest.client.v1.room import RoomInitialSyncRestServlet from synapse.rest.client.v2_alpha import sync from synapse.server import HomeServer +from synapse.storage.data_stores.main.presence import UserPresenceState from synapse.storage.engines import create_engine -from synapse.storage.presence import UserPresenceState from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.stringutils import random_string @@ -326,7 +326,7 @@ class SynchrotronServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -335,7 +335,7 @@ class SynchrotronServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index e01afb39f2..6cb100319f 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -42,8 +42,8 @@ from synapse.replication.tcp.streams.events import ( ) from synapse.rest.client.v2_alpha import user_directory from synapse.server import HomeServer +from synapse.storage.data_stores.main.user_directory import UserDirectoryStore from synapse.storage.engines import create_engine -from synapse.storage.user_directory import UserDirectoryStore from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole @@ -150,7 +150,7 @@ class UserDirectoryServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -159,7 +159,7 @@ class UserDirectoryServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 33b3579425..aea3985a5f 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -94,7 +94,9 @@ class ApplicationService(object): ip_range_whitelist=None, ): self.token = token - self.url = url + self.url = ( + url.rstrip("/") if isinstance(url, str) else None + ) # url must not end with a slash self.hs_token = hs_token self.sender = sender self.server_name = hostname diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 9b4682222d..e77d3387ff 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -48,7 +48,7 @@ class AppServiceConfig(Config): # Uncomment to enable tracking of application service IP addresses. Implicitly # enables MAU tracking for application service users. # - #track_appservice_user_ips: True + #track_appservice_user_ips: true """ diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py index 62c4c44d60..aec9c4bbce 100644 --- a/synapse/config/consent_config.py +++ b/synapse/config/consent_config.py @@ -62,11 +62,11 @@ DEFAULT_CONFIG = """\ # body: >- # To continue using this homeserver you must review and agree to the # terms and conditions at %(consent_uri)s -# send_server_notice_to_guests: True +# send_server_notice_to_guests: true # block_events_error: >- # To continue using this homeserver you must review and agree to the # terms and conditions at %(consent_uri)s -# require_at_registration: False +# require_at_registration: false # policy_name: Privacy Policy # """ diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 658897a77e..39e7a1dddb 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -304,13 +304,13 @@ class EmailConfig(Config): # smtp_port: 25 # SSL: 465, STARTTLS: 587 # smtp_user: "exampleusername" # smtp_pass: "examplepassword" - # require_transport_security: False + # require_transport_security: false # notif_from: "Your Friendly %(app)s Home Server <noreply@example.com>" # app_name: Matrix # # # Enable email notifications by default # # - # notif_for_new_users: True + # notif_for_new_users: true # # # Defining a custom URL for Riot is only needed if email notifications # # should contain links to a self-hosted installation of Riot; when set diff --git a/synapse/config/key.py b/synapse/config/key.py index ec5d430afb..52ff1b2621 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -125,7 +125,7 @@ class KeyConfig(Config): # if neither trusted_key_servers nor perspectives are given, use the default. if "perspectives" not in config and "trusted_key_servers" not in config: - logger.warn(TRUSTED_KEY_SERVER_NOT_CONFIGURED_WARN) + logger.warning(TRUSTED_KEY_SERVER_NOT_CONFIGURED_WARN) key_servers = [{"server_name": "matrix.org"}] else: key_servers = config.get("trusted_key_servers", []) @@ -156,7 +156,7 @@ class KeyConfig(Config): if not self.macaroon_secret_key: # Unfortunately, there are people out there that don't have this # set. Lets just be "nice" and derive one from their secret key. - logger.warn("Config is missing macaroon_secret_key") + logger.warning("Config is missing macaroon_secret_key") seed = bytes(self.signing_key[0]) self.macaroon_secret_key = hashlib.sha256(seed).digest() diff --git a/synapse/config/logger.py b/synapse/config/logger.py index be92e33f93..75bb904718 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -182,7 +182,7 @@ def _reload_stdlib_logging(*args, log_config=None): logger = logging.getLogger("") if not log_config: - logger.warn("Reloaded a blank config?") + logger.warning("Reloaded a blank config?") logging.config.dictConfig(log_config) @@ -234,8 +234,8 @@ def setup_logging( # make sure that the first thing we log is a thing we can grep backwards # for - logging.warn("***** STARTING SERVER *****") - logging.warn("Server %s version %s", sys.argv[0], get_version_string(synapse)) + logging.warning("***** STARTING SERVER *****") + logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse)) logging.info("Server hostname: %s", config.server_name) return logger diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index 282a43bddb..22538153e1 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -70,7 +70,7 @@ class MetricsConfig(Config): # Enable collection and rendering of performance metrics # - #enable_metrics: False + #enable_metrics: false # Enable sentry integration # NOTE: While attempts are made to ensure that the logs don't contain diff --git a/synapse/config/registration.py b/synapse/config/registration.py index b3e3e6dda2..1f6dac69da 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -180,7 +180,7 @@ class RegistrationConfig(Config): # where d is equal to 10%% of the validity period. # #account_validity: - # enabled: True + # enabled: true # period: 6w # renew_at: 1w # renew_email_subject: "Renew your %%(app)s account" @@ -300,7 +300,7 @@ class RegistrationConfig(Config): # If a delegate is specified, the config option public_baseurl must also be filled out. # account_threepid_delegates: - #email: https://example.com # Delegate email sending to example.org + #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process # Users who register on this homeserver will automatically be joined diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index c407e13680..c5ea2d43a1 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -176,7 +176,7 @@ class SAML2Config(Config): # - url: https://our_idp/metadata.xml # # # By default, the user has to go to our login page first. If you'd like - # # to allow IdP-initiated login, set 'allow_unsolicited: True' in a + # # to allow IdP-initiated login, set 'allow_unsolicited: true' in a # # 'service.sp' section: # # # #service: diff --git a/synapse/config/server.py b/synapse/config/server.py index 26e6d84c09..9015ca5f87 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -171,6 +171,7 @@ class ServerConfig(Config): ) self.mau_trial_days = config.get("mau_trial_days", 0) + self.mau_limit_alerting = config.get("mau_limit_alerting", True) # How long to keep redacted events in the database in unredacted form # before redacting them. @@ -192,7 +193,6 @@ class ServerConfig(Config): # Options to disable HS self.hs_disabled = config.get("hs_disabled", False) self.hs_disabled_message = config.get("hs_disabled_message", "") - self.hs_disabled_limit_type = config.get("hs_disabled_limit_type", "") # Admin uri to direct users at should their instance become blocked # due to resource constraints @@ -532,7 +532,7 @@ class ServerConfig(Config): # Whether room invites to users on this server should be blocked # (except those sent by local server admins). The default is False. # - #block_non_admin_invites: True + #block_non_admin_invites: true # Room searching # @@ -673,9 +673,8 @@ class ServerConfig(Config): # Global blocking # - #hs_disabled: False + #hs_disabled: false #hs_disabled_message: 'Human readable reason for why the HS is blocked' - #hs_disabled_limit_type: 'error code(str), to help clients decode reason' # Monthly Active User Blocking # @@ -695,15 +694,22 @@ class ServerConfig(Config): # sign up in a short space of time never to return after their initial # session. # - #limit_usage_by_mau: False + # 'mau_limit_alerting' is a means of limiting client side alerting + # should the mau limit be reached. This is useful for small instances + # where the admin has 5 mau seats (say) for 5 specific people and no + # interest increasing the mau limit further. Defaults to True, which + # means that alerting is enabled + # + #limit_usage_by_mau: false #max_mau_value: 50 #mau_trial_days: 2 + #mau_limit_alerting: false # If enabled, the metrics for the number of monthly active users will # be populated, however no one will be limited. If limit_usage_by_mau # is true, this is implied to be true. # - #mau_stats_only: False + #mau_stats_only: false # Sometimes the server admin will want to ensure certain accounts are # never blocked by mau checking. These accounts are specified here. @@ -728,7 +734,7 @@ class ServerConfig(Config): # # Uncomment the below lines to enable: #limit_remote_rooms: - # enabled: True + # enabled: true # complexity: 1.0 # complexity_error: "This room is too complex." diff --git a/synapse/config/tls.py b/synapse/config/tls.py index f06341eb67..2e9e478a2a 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -289,6 +289,9 @@ class TlsConfig(Config): "http://localhost:8009/.well-known/acme-challenge" ) + # flake8 doesn't recognise that variables are used in the below string + _ = tls_enabled, proxypassline, acme_enabled, default_acme_account_file + return ( """\ ## TLS ## @@ -451,7 +454,11 @@ class TlsConfig(Config): #tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}] """ - % locals() + # Lowercase the string representation of boolean values + % { + x[0]: str(x[1]).lower() if isinstance(x[1], bool) else x[1] + for x in locals().items() + } ) def read_tls_certificate(self): diff --git a/synapse/config/voip.py b/synapse/config/voip.py index a68a3068aa..b313bff140 100644 --- a/synapse/config/voip.py +++ b/synapse/config/voip.py @@ -56,5 +56,5 @@ class VoipConfig(Config): # connect to arbitrary endpoints without having first signed up for a # valid account (e.g. by passing a CAPTCHA). # - #turn_allow_guests: True + #turn_allow_guests: true """ diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index 694fb2c816..ccaa8a9920 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -125,9 +125,11 @@ def compute_event_signature(event_dict, signature_name, signing_key): redact_json = prune_event_dict(event_dict) redact_json.pop("age_ts", None) redact_json.pop("unsigned", None) - logger.debug("Signing event: %s", encode_canonical_json(redact_json)) + if logger.isEnabledFor(logging.DEBUG): + logger.debug("Signing event: %s", encode_canonical_json(redact_json)) redact_json = sign_json(redact_json, signature_name, signing_key) - logger.debug("Signed event: %s", encode_canonical_json(redact_json)) + if logger.isEnabledFor(logging.DEBUG): + logger.debug("Signed event: %s", encode_canonical_json(redact_json)) return redact_json["signatures"] diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 4e91df60e6..ec3243b27b 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -77,7 +77,7 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru if auth_events is None: # Oh, we don't know what the state of the room was, so we # are trusting that this is allowed (at least for now) - logger.warn("Trusting event: %s", event.event_id) + logger.warning("Trusting event: %s", event.event_id) return if event.type == EventTypes.Create: @@ -493,8 +493,7 @@ def _check_power_levels(event, auth_events): new_level_too_big = new_level is not None and new_level > user_level if old_level_too_big or new_level_too_big: raise AuthError( - 403, - "You don't have permission to add ops level greater " "than your own", + 403, "You don't have permission to add ops level greater than your own" ) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index acbcbeeced..64e898f40c 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -12,104 +12,125 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Optional, Tuple, Union from six import iteritems +import attr from frozendict import frozendict from twisted.internet import defer +from synapse.appservice import ApplicationService from synapse.logging.context import make_deferred_yieldable, run_in_background -class EventContext(object): +@attr.s(slots=True) +class EventContext: """ + Holds information relevant to persisting an event + Attributes: - state_group (int|None): state group id, if the state has been stored - as a state group. This is usually only None if e.g. the event is - an outlier. - rejected (bool|str): A rejection reason if the event was rejected, else - False - - push_actions (list[(str, list[object])]): list of (user_id, actions) - tuples - - prev_group (int): Previously persisted state group. ``None`` for an - outlier. - delta_ids (dict[(str, str), str]): Delta from ``prev_group``. - (type, state_key) -> event_id. ``None`` for an outlier. - - prev_state_events (?): XXX: is this ever set to anything other than - the empty list? - - _current_state_ids (dict[(str, str), str]|None): - The current state map including the current event. None if outlier - or we haven't fetched the state from DB yet. - (type, state_key) -> event_id + rejected: A rejection reason if the event was rejected, else False + + _state_group: The ID of the state group for this event. Note that state events + are persisted with a state group which includes the new event, so this is + effectively the state *after* the event in question. + + For a *rejected* state event, where the state of the rejected event is + ignored, this state_group should never make it into the + event_to_state_groups table. Indeed, inspecting this value for a rejected + state event is almost certainly incorrect. + + For an outlier, where we don't have the state at the event, this will be + None. + + Note that this is a private attribute: it should be accessed via + the ``state_group`` property. + + state_group_before_event: The ID of the state group representing the state + of the room before this event. + + If this is a non-state event, this will be the same as ``state_group``. If + it's a state event, it will be the same as ``prev_group``. + + If ``state_group`` is None (ie, the event is an outlier), + ``state_group_before_event`` will always also be ``None``. + + prev_group: If it is known, ``state_group``'s prev_group. Note that this being + None does not necessarily mean that ``state_group`` does not have + a prev_group! + + If the event is a state event, this is normally the same as ``prev_group``. + + If ``state_group`` is None (ie, the event is an outlier), ``prev_group`` + will always also be ``None``. + + Note that this *not* (necessarily) the state group associated with + ``_prev_state_ids``. + + delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group`` + and ``state_group``. + + app_service: If this event is being sent by a (local) application service, that + app service. + + _current_state_ids: The room state map, including this event - ie, the state + in ``state_group``. - _prev_state_ids (dict[(str, str), str]|None): - The current state map excluding the current event. None if outlier - or we haven't fetched the state from DB yet. (type, state_key) -> event_id - _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have - been calculated. None if we haven't started calculating yet + FIXME: what is this for an outlier? it seems ill-defined. It seems like + it could be either {}, or the state we were given by the remote + server, depending on $THINGS - _event_type (str): The type of the event the context is associated with. - Only set when state has not been fetched yet. + Note that this is a private attribute: it should be accessed via + ``get_current_state_ids``. _AsyncEventContext impl calculates this + on-demand: it will be None until that happens. - _event_state_key (str|None): The state_key of the event the context is - associated with. Only set when state has not been fetched yet. + _prev_state_ids: The room state map, excluding this event - ie, the state + in ``state_group_before_event``. For a non-state + event, this will be the same as _current_state_events. - _prev_state_id (str|None): If the event associated with the context is - a state event, then `_prev_state_id` is the event_id of the state - that was replaced. - Only set when state has not been fetched yet. - """ + Note that it is a completely different thing to prev_group! - __slots__ = [ - "state_group", - "rejected", - "prev_group", - "delta_ids", - "prev_state_events", - "app_service", - "_current_state_ids", - "_prev_state_ids", - "_prev_state_id", - "_event_type", - "_event_state_key", - "_fetching_state_deferred", - ] - - def __init__(self): - self.prev_state_events = [] - self.rejected = False - self.app_service = None + (type, state_key) -> event_id - @staticmethod - def with_state( - state_group, current_state_ids, prev_state_ids, prev_group=None, delta_ids=None - ): - context = EventContext() + FIXME: again, what is this for an outlier? - # The current state including the current event - context._current_state_ids = current_state_ids - # The current state excluding the current event - context._prev_state_ids = prev_state_ids - context.state_group = state_group + As with _current_state_ids, this is a private attribute. It should be + accessed via get_prev_state_ids. + """ - context._prev_state_id = None - context._event_type = None - context._event_state_key = None - context._fetching_state_deferred = defer.succeed(None) + rejected = attr.ib(default=False, type=Union[bool, str]) + _state_group = attr.ib(default=None, type=Optional[int]) + state_group_before_event = attr.ib(default=None, type=Optional[int]) + prev_group = attr.ib(default=None, type=Optional[int]) + delta_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]]) + app_service = attr.ib(default=None, type=Optional[ApplicationService]) - # A previously persisted state group and a delta between that - # and this state. - context.prev_group = prev_group - context.delta_ids = delta_ids + _current_state_ids = attr.ib( + default=None, type=Optional[Dict[Tuple[str, str], str]] + ) + _prev_state_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]]) - return context + @staticmethod + def with_state( + state_group, + state_group_before_event, + current_state_ids, + prev_state_ids, + prev_group=None, + delta_ids=None, + ): + return EventContext( + current_state_ids=current_state_ids, + prev_state_ids=prev_state_ids, + state_group=state_group, + state_group_before_event=state_group_before_event, + prev_group=prev_group, + delta_ids=delta_ids, + ) @defer.inlineCallbacks def serialize(self, event, store): @@ -137,11 +158,11 @@ class EventContext(object): "prev_state_id": prev_state_id, "event_type": event.type, "event_state_key": event.state_key if event.is_state() else None, - "state_group": self.state_group, + "state_group": self._state_group, + "state_group_before_event": self.state_group_before_event, "rejected": self.rejected, "prev_group": self.prev_group, "delta_ids": _encode_state_dict(self.delta_ids), - "prev_state_events": self.prev_state_events, "app_service_id": self.app_service.id if self.app_service else None, } @@ -157,24 +178,18 @@ class EventContext(object): Returns: EventContext """ - context = EventContext() - - # We use the state_group and prev_state_id stuff to pull the - # current_state_ids out of the DB and construct prev_state_ids. - context._prev_state_id = input["prev_state_id"] - context._event_type = input["event_type"] - context._event_state_key = input["event_state_key"] - - context._current_state_ids = None - context._prev_state_ids = None - context._fetching_state_deferred = None - - context.state_group = input["state_group"] - context.prev_group = input["prev_group"] - context.delta_ids = _decode_state_dict(input["delta_ids"]) - - context.rejected = input["rejected"] - context.prev_state_events = input["prev_state_events"] + context = _AsyncEventContextImpl( + # We use the state_group and prev_state_id stuff to pull the + # current_state_ids out of the DB and construct prev_state_ids. + prev_state_id=input["prev_state_id"], + event_type=input["event_type"], + event_state_key=input["event_state_key"], + state_group=input["state_group"], + state_group_before_event=input["state_group_before_event"], + prev_group=input["prev_group"], + delta_ids=_decode_state_dict(input["delta_ids"]), + rejected=input["rejected"], + ) app_service_id = input["app_service_id"] if app_service_id: @@ -182,29 +197,52 @@ class EventContext(object): return context + @property + def state_group(self) -> Optional[int]: + """The ID of the state group for this event. + + Note that state events are persisted with a state group which includes the new + event, so this is effectively the state *after* the event in question. + + For an outlier, where we don't have the state at the event, this will be None. + + It is an error to access this for a rejected event, since rejected state should + not make it into the room state. Accessing this property will raise an exception + if ``rejected`` is set. + """ + if self.rejected: + raise RuntimeError("Attempt to access state_group of rejected event") + + return self._state_group + @defer.inlineCallbacks def get_current_state_ids(self, store): - """Gets the current state IDs + """ + Gets the room state map, including this event - ie, the state in ``state_group`` + + It is an error to access this for a rejected event, since rejected state should + not make it into the room state. This method will raise an exception if + ``rejected`` is set. Returns: Deferred[dict[(str, str), str]|None]: Returns None if state_group is None, which happens when the associated event is an outlier. + Maps a (type, state_key) to the event ID of the state event matching this tuple. """ + if self.rejected: + raise RuntimeError("Attempt to access state_ids of rejected event") - if not self._fetching_state_deferred: - self._fetching_state_deferred = run_in_background( - self._fill_out_state, store - ) - - yield make_deferred_yieldable(self._fetching_state_deferred) - + yield self._ensure_fetched(store) return self._current_state_ids @defer.inlineCallbacks def get_prev_state_ids(self, store): - """Gets the prev state IDs + """ + Gets the room state map, excluding this event. + + For a non-state event, this will be the same as get_current_state_ids(). Returns: Deferred[dict[(str, str), str]|None]: Returns None if state_group @@ -212,27 +250,64 @@ class EventContext(object): Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - - if not self._fetching_state_deferred: - self._fetching_state_deferred = run_in_background( - self._fill_out_state, store - ) - - yield make_deferred_yieldable(self._fetching_state_deferred) - + yield self._ensure_fetched(store) return self._prev_state_ids def get_cached_current_state_ids(self): """Gets the current state IDs if we have them already cached. + It is an error to access this for a rejected event, since rejected state should + not make it into the room state. This method will raise an exception if + ``rejected`` is set. + Returns: dict[(str, str), str]|None: Returns None if we haven't cached the state or if state_group is None, which happens when the associated event is an outlier. """ + if self.rejected: + raise RuntimeError("Attempt to access state_ids of rejected event") return self._current_state_ids + def _ensure_fetched(self, store): + return defer.succeed(None) + + +@attr.s(slots=True) +class _AsyncEventContextImpl(EventContext): + """ + An implementation of EventContext which fetches _current_state_ids and + _prev_state_ids from the database on demand. + + Attributes: + + _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have + been calculated. None if we haven't started calculating yet + + _event_type (str): The type of the event the context is associated with. + + _event_state_key (str): The state_key of the event the context is + associated with. + + _prev_state_id (str|None): If the event associated with the context is + a state event, then `_prev_state_id` is the event_id of the state + that was replaced. + """ + + _prev_state_id = attr.ib(default=None) + _event_type = attr.ib(default=None) + _event_state_key = attr.ib(default=None) + _fetching_state_deferred = attr.ib(default=None) + + def _ensure_fetched(self, store): + if not self._fetching_state_deferred: + self._fetching_state_deferred = run_in_background( + self._fill_out_state, store + ) + + return make_deferred_yieldable(self._fetching_state_deferred) + @defer.inlineCallbacks def _fill_out_state(self, store): """Called to populate the _current_state_ids and _prev_state_ids @@ -250,27 +325,6 @@ class EventContext(object): else: self._prev_state_ids = self._current_state_ids - @defer.inlineCallbacks - def update_state( - self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids - ): - """Replace the state in the context - """ - - # We need to make sure we wait for any ongoing fetching of state - # to complete so that the updated state doesn't get clobbered - if self._fetching_state_deferred: - yield make_deferred_yieldable(self._fetching_state_deferred) - - self.state_group = state_group - self._prev_state_ids = prev_state_ids - self.prev_group = prev_group - self._current_state_ids = current_state_ids - self.delta_ids = delta_ids - - # We need to ensure that that we've marked as having fetched the state - self._fetching_state_deferred = defer.succeed(None) - def _encode_state_dict(state_dict): """Since dicts of (type, state_key) -> event_id cannot be serialized in diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 129771f183..5a907718d6 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect + +from synapse.spam_checker_api import SpamCheckerApi + class SpamChecker(object): def __init__(self, hs): @@ -26,7 +31,14 @@ class SpamChecker(object): pass if module is not None: - self.spam_checker = module(config=config) + # Older spam checkers don't accept the `api` argument, so we + # try and detect support. + spam_args = inspect.getfullargspec(module) + if "api" in spam_args.args: + api = SpamCheckerApi(hs) + self.spam_checker = module(config=config, api=api) + else: + self.spam_checker = module(config=config) def check_event_for_spam(self, event): """Checks if a given event is considered "spammy" by this server. diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 5a1e23a145..0e22183280 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -102,7 +102,7 @@ class FederationBase(object): pass if not res: - logger.warn( + logger.warning( "Failed to find copy of %s with valid signature", pdu.event_id ) @@ -173,7 +173,7 @@ class FederationBase(object): return redacted_event if self.spam_checker.check_event_for_spam(pdu): - logger.warn( + logger.warning( "Event contains spam, redacting %s: %s", pdu.event_id, pdu.get_pdu_json(), @@ -185,7 +185,7 @@ class FederationBase(object): def errback(failure, pdu): failure.trap(SynapseError) with PreserveLoggingContext(ctx): - logger.warn( + logger.warning( "Signature check failed for %s: %s", pdu.event_id, failure.getErrorMessage(), @@ -278,9 +278,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): pdu_to_check.sender_domain, e.getErrorMessage(), ) - # XX not really sure if these are the right codes, but they are what - # we've done for ages - raise SynapseError(400, errmsg, Codes.UNAUTHORIZED) + raise SynapseError(403, errmsg, Codes.FORBIDDEN) for p, d in zip(pdus_to_check_sender, more_deferreds): d.addErrback(sender_err, p) @@ -314,8 +312,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): "event id %s: unable to verify signature for event id domain: %s" % (pdu_to_check.pdu.event_id, e.getErrorMessage()) ) - # XX as above: not really sure if these are the right codes - raise SynapseError(400, errmsg, Codes.UNAUTHORIZED) + raise SynapseError(403, errmsg, Codes.FORBIDDEN) for p, d in zip(pdus_to_check_event_id, more_deferreds): d.addErrback(event_err, p) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 5b22a39b7f..545d719652 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -196,7 +196,7 @@ class FederationClient(FederationBase): dest, room_id, extremities, limit ) - logger.debug("backfill transaction_data=%s", repr(transaction_data)) + logger.debug("backfill transaction_data=%r", transaction_data) room_version = yield self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) @@ -522,12 +522,12 @@ class FederationClient(FederationBase): res = yield callback(destination) return res except InvalidResponseError as e: - logger.warn("Failed to %s via %s: %s", description, destination, e) + logger.warning("Failed to %s via %s: %s", description, destination, e) except HttpResponseException as e: if not 500 <= e.code < 600: raise e.to_synapse_error() else: - logger.warn( + logger.warning( "Failed to %s via %s: %i %s", description, destination, @@ -535,7 +535,9 @@ class FederationClient(FederationBase): e.args[0], ) except Exception: - logger.warn("Failed to %s via %s", description, destination, exc_info=1) + logger.warning( + "Failed to %s via %s", description, destination, exc_info=1 + ) raise SynapseError(502, "Failed to %s via any server" % (description,)) @@ -553,7 +555,7 @@ class FederationClient(FederationBase): Note that this does not append any events to any graphs. Args: - destinations (str): Candidate homeservers which are probably + destinations (Iterable[str]): Candidate homeservers which are probably participating in the room. room_id (str): The room in which the event will happen. user_id (str): The user whose membership is being evented. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 21e52c9695..d942d77a72 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -21,7 +21,6 @@ from six import iteritems from canonicaljson import json from prometheus_client import Counter -from twisted.internet import defer from twisted.internet.abstract import isIPAddress from twisted.python import failure @@ -86,14 +85,12 @@ class FederationServer(FederationBase): # come in waves. self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000) - @defer.inlineCallbacks - @log_function - def on_backfill_request(self, origin, room_id, versions, limit): - with (yield self._server_linearizer.queue((origin, room_id))): + async def on_backfill_request(self, origin, room_id, versions, limit): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - pdus = yield self.handler.on_backfill_request( + pdus = await self.handler.on_backfill_request( origin, room_id, versions, limit ) @@ -101,9 +98,7 @@ class FederationServer(FederationBase): return 200, res - @defer.inlineCallbacks - @log_function - def on_incoming_transaction(self, origin, transaction_data): + async def on_incoming_transaction(self, origin, transaction_data): # keep this as early as possible to make the calculated origin ts as # accurate as possible. request_time = self._clock.time_msec() @@ -118,18 +113,17 @@ class FederationServer(FederationBase): # use a linearizer to ensure that we don't process the same transaction # multiple times in parallel. with ( - yield self._transaction_linearizer.queue( + await self._transaction_linearizer.queue( (origin, transaction.transaction_id) ) ): - result = yield self._handle_incoming_transaction( + result = await self._handle_incoming_transaction( origin, transaction, request_time ) return result - @defer.inlineCallbacks - def _handle_incoming_transaction(self, origin, transaction, request_time): + async def _handle_incoming_transaction(self, origin, transaction, request_time): """ Process an incoming transaction and return the HTTP response Args: @@ -140,7 +134,7 @@ class FederationServer(FederationBase): Returns: Deferred[(int, object)]: http response code and body """ - response = yield self.transaction_actions.have_responded(origin, transaction) + response = await self.transaction_actions.have_responded(origin, transaction) if response: logger.debug( @@ -151,7 +145,7 @@ class FederationServer(FederationBase): logger.debug("[%s] Transaction is new", transaction.transaction_id) - # Reject if PDU count > 50 and EDU count > 100 + # Reject if PDU count > 50 or EDU count > 100 if len(transaction.pdus) > 50 or ( hasattr(transaction, "edus") and len(transaction.edus) > 100 ): @@ -159,7 +153,7 @@ class FederationServer(FederationBase): logger.info("Transaction PDU or EDU count too large. Returning 400") response = {} - yield self.transaction_actions.set_response( + await self.transaction_actions.set_response( origin, transaction, 400, response ) return 400, response @@ -195,7 +189,7 @@ class FederationServer(FederationBase): continue try: - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) except NotFoundError: logger.info("Ignoring PDU for unknown room_id: %s", room_id) continue @@ -221,13 +215,12 @@ class FederationServer(FederationBase): # require callouts to other servers to fetch missing events), but # impose a limit to avoid going too crazy with ram/cpu. - @defer.inlineCallbacks - def process_pdus_for_room(room_id): + async def process_pdus_for_room(room_id): logger.debug("Processing PDUs for %s", room_id) try: - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) except AuthError as e: - logger.warn("Ignoring PDUs for room %s from banned server", room_id) + logger.warning("Ignoring PDUs for room %s from banned server", room_id) for pdu in pdus_by_room[room_id]: event_id = pdu.event_id pdu_results[event_id] = e.error_dict() @@ -237,10 +230,10 @@ class FederationServer(FederationBase): event_id = pdu.event_id with nested_logging_context(event_id): try: - yield self._handle_received_pdu(origin, pdu) + await self._handle_received_pdu(origin, pdu) pdu_results[event_id] = {} except FederationError as e: - logger.warn("Error handling PDU %s: %s", event_id, e) + logger.warning("Error handling PDU %s: %s", event_id, e) pdu_results[event_id] = {"error": str(e)} except Exception as e: f = failure.Failure() @@ -251,36 +244,33 @@ class FederationServer(FederationBase): exc_info=(f.type, f.value, f.getTracebackObject()), ) - yield concurrently_execute( + await concurrently_execute( process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT ) if hasattr(transaction, "edus"): for edu in (Edu(**x) for x in transaction.edus): - yield self.received_edu(origin, edu.edu_type, edu.content) + await self.received_edu(origin, edu.edu_type, edu.content) response = {"pdus": pdu_results} logger.debug("Returning: %s", str(response)) - yield self.transaction_actions.set_response(origin, transaction, 200, response) + await self.transaction_actions.set_response(origin, transaction, 200, response) return 200, response - @defer.inlineCallbacks - def received_edu(self, origin, edu_type, content): + async def received_edu(self, origin, edu_type, content): received_edus_counter.inc() - yield self.registry.on_edu(edu_type, origin, content) + await self.registry.on_edu(edu_type, origin, content) - @defer.inlineCallbacks - @log_function - def on_context_state_request(self, origin, room_id, event_id): + async def on_context_state_request(self, origin, room_id, event_id): if not event_id: raise NotImplementedError("Specify an event") origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - in_room = yield self.auth.check_host_in_room(room_id, origin) + in_room = await self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -289,8 +279,8 @@ class FederationServer(FederationBase): # in the cache so we could return it without waiting for the linearizer # - but that's non-trivial to get right, and anyway somewhat defeats # the point of the linearizer. - with (yield self._server_linearizer.queue((origin, room_id))): - resp = yield self._state_resp_cache.wrap( + with (await self._server_linearizer.queue((origin, room_id))): + resp = await self._state_resp_cache.wrap( (room_id, event_id), self._on_context_state_request_compute, room_id, @@ -299,65 +289,60 @@ class FederationServer(FederationBase): return 200, resp - @defer.inlineCallbacks - def on_state_ids_request(self, origin, room_id, event_id): + async def on_state_ids_request(self, origin, room_id, event_id): if not event_id: raise NotImplementedError("Specify an event") origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - in_room = yield self.auth.check_host_in_room(room_id, origin) + in_room = await self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") - state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id) - auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids) + state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) + auth_chain_ids = await self.store.get_auth_chain_ids(state_ids) return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} - @defer.inlineCallbacks - def _on_context_state_request_compute(self, room_id, event_id): - pdus = yield self.handler.get_state_for_pdu(room_id, event_id) - auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus]) + async def _on_context_state_request_compute(self, room_id, event_id): + pdus = await self.handler.get_state_for_pdu(room_id, event_id) + auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus]) return { "pdus": [pdu.get_pdu_json() for pdu in pdus], "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], } - @defer.inlineCallbacks - @log_function - def on_pdu_request(self, origin, event_id): - pdu = yield self.handler.get_persisted_pdu(origin, event_id) + async def on_pdu_request(self, origin, event_id): + pdu = await self.handler.get_persisted_pdu(origin, event_id) if pdu: return 200, self._transaction_from_pdus([pdu]).get_dict() else: return 404, "" - @defer.inlineCallbacks - def on_query_request(self, query_type, args): + async def on_query_request(self, query_type, args): received_queries_counter.labels(query_type).inc() - resp = yield self.registry.on_query(query_type, args) + resp = await self.registry.on_query(query_type, args) return 200, resp - @defer.inlineCallbacks - def on_make_join_request(self, origin, room_id, user_id, supported_versions): + async def on_make_join_request(self, origin, room_id, user_id, supported_versions): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) if room_version not in supported_versions: - logger.warn("Room version %s not in %s", room_version, supported_versions) + logger.warning( + "Room version %s not in %s", room_version, supported_versions + ) raise IncompatibleRoomVersionError(room_version=room_version) - pdu = yield self.handler.on_make_join_request(origin, room_id, user_id) + pdu = await self.handler.on_make_join_request(origin, room_id, user_id) time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - @defer.inlineCallbacks - def on_invite_request(self, origin, content, room_version): + async def on_invite_request(self, origin, content, room_version): if room_version not in KNOWN_ROOM_VERSIONS: raise SynapseError( 400, @@ -369,24 +354,27 @@ class FederationServer(FederationBase): pdu = event_from_pdu_json(content, format_ver) origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, pdu.room_id) - ret_pdu = yield self.handler.on_invite_request(origin, pdu) + await self.check_server_matches_acl(origin_host, pdu.room_id) + pdu = await self._check_sigs_and_hash(room_version, pdu) + ret_pdu = await self.handler.on_invite_request(origin, pdu) time_now = self._clock.time_msec() return {"event": ret_pdu.get_pdu_json(time_now)} - @defer.inlineCallbacks - def on_send_join_request(self, origin, content, room_id): + async def on_send_join_request(self, origin, content, room_id): logger.debug("on_send_join_request: content: %s", content) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) pdu = event_from_pdu_json(content, format_ver) origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, pdu.room_id) + await self.check_server_matches_acl(origin_host, pdu.room_id) logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) - res_pdus = yield self.handler.on_send_join_request(origin, pdu) + + pdu = await self._check_sigs_and_hash(room_version, pdu) + + res_pdus = await self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() return ( 200, @@ -398,45 +386,44 @@ class FederationServer(FederationBase): }, ) - @defer.inlineCallbacks - def on_make_leave_request(self, origin, room_id, user_id): + async def on_make_leave_request(self, origin, room_id, user_id): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) - pdu = yield self.handler.on_make_leave_request(origin, room_id, user_id) + await self.check_server_matches_acl(origin_host, room_id) + pdu = await self.handler.on_make_leave_request(origin, room_id, user_id) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - @defer.inlineCallbacks - def on_send_leave_request(self, origin, content, room_id): + async def on_send_leave_request(self, origin, content, room_id): logger.debug("on_send_leave_request: content: %s", content) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) pdu = event_from_pdu_json(content, format_ver) origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, pdu.room_id) + await self.check_server_matches_acl(origin_host, pdu.room_id) logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) - yield self.handler.on_send_leave_request(origin, pdu) + + pdu = await self._check_sigs_and_hash(room_version, pdu) + + await self.handler.on_send_leave_request(origin, pdu) return 200, {} - @defer.inlineCallbacks - def on_event_auth(self, origin, room_id, event_id): - with (yield self._server_linearizer.queue((origin, room_id))): + async def on_event_auth(self, origin, room_id, event_id): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) time_now = self._clock.time_msec() - auth_pdus = yield self.handler.on_event_auth(event_id) + auth_pdus = await self.handler.on_event_auth(event_id) res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]} return 200, res - @defer.inlineCallbacks - def on_query_auth_request(self, origin, content, room_id, event_id): + async def on_query_auth_request(self, origin, content, room_id, event_id): """ Content is a dict with keys:: auth_chain (list): A list of events that give the auth chain. @@ -455,22 +442,22 @@ class FederationServer(FederationBase): Returns: Deferred: Results in `dict` with the same format as `content` """ - with (yield self._server_linearizer.queue((origin, room_id))): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) auth_chain = [ event_from_pdu_json(e, format_ver) for e in content["auth_chain"] ] - signed_auth = yield self._check_sigs_and_hash_and_fetch( + signed_auth = await self._check_sigs_and_hash_and_fetch( origin, auth_chain, outlier=True, room_version=room_version ) - ret = yield self.handler.on_query_auth( + ret = await self.handler.on_query_auth( origin, event_id, room_id, @@ -496,16 +483,14 @@ class FederationServer(FederationBase): return self.on_query_request("user_devices", user_id) @trace - @defer.inlineCallbacks - @log_function - def on_claim_client_keys(self, origin, content): + async def on_claim_client_keys(self, origin, content): query = [] for user_id, device_keys in content.get("one_time_keys", {}).items(): for device_id, algorithm in device_keys.items(): query.append((user_id, device_id, algorithm)) log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) - results = yield self.store.claim_e2e_one_time_keys(query) + results = await self.store.claim_e2e_one_time_keys(query) json_result = {} for user_id, device_keys in results.items(): @@ -529,14 +514,12 @@ class FederationServer(FederationBase): return {"one_time_keys": json_result} - @defer.inlineCallbacks - @log_function - def on_get_missing_events( + async def on_get_missing_events( self, origin, room_id, earliest_events, latest_events, limit ): - with (yield self._server_linearizer.queue((origin, room_id))): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) logger.info( "on_get_missing_events: earliest_events: %r, latest_events: %r," @@ -546,7 +529,7 @@ class FederationServer(FederationBase): limit, ) - missing_events = yield self.handler.on_get_missing_events( + missing_events = await self.handler.on_get_missing_events( origin, room_id, earliest_events, latest_events, limit ) @@ -579,8 +562,7 @@ class FederationServer(FederationBase): destination=None, ) - @defer.inlineCallbacks - def _handle_received_pdu(self, origin, pdu): + async def _handle_received_pdu(self, origin, pdu): """ Process a PDU received in a federation /send/ transaction. If the event is invalid, then this method throws a FederationError. @@ -633,37 +615,34 @@ class FederationServer(FederationBase): logger.info("Accepting join PDU %s from %s", pdu.event_id, origin) # We've already checked that we know the room version by this point - room_version = yield self.store.get_room_version(pdu.room_id) + room_version = await self.store.get_room_version(pdu.room_id) # Check signature. try: - pdu = yield self._check_sigs_and_hash(room_version, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) except SynapseError as e: raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id) - yield self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) + await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) def __str__(self): return "<ReplicationLayer(%s)>" % self.server_name - @defer.inlineCallbacks - def exchange_third_party_invite( + async def exchange_third_party_invite( self, sender_user_id, target_user_id, room_id, signed ): - ret = yield self.handler.exchange_third_party_invite( + ret = await self.handler.exchange_third_party_invite( sender_user_id, target_user_id, room_id, signed ) return ret - @defer.inlineCallbacks - def on_exchange_third_party_invite_request(self, room_id, event_dict): - ret = yield self.handler.on_exchange_third_party_invite_request( + async def on_exchange_third_party_invite_request(self, room_id, event_dict): + ret = await self.handler.on_exchange_third_party_invite_request( room_id, event_dict ) return ret - @defer.inlineCallbacks - def check_server_matches_acl(self, server_name, room_id): + async def check_server_matches_acl(self, server_name, room_id): """Check if the given server is allowed by the server ACLs in the room Args: @@ -673,13 +652,13 @@ class FederationServer(FederationBase): Raises: AuthError if the server does not match the ACL """ - state_ids = yield self.store.get_current_state_ids(room_id) + state_ids = await self.store.get_current_state_ids(room_id) acl_event_id = state_ids.get((EventTypes.ServerACL, "")) if not acl_event_id: return - acl_event = yield self.store.get_event(acl_event_id) + acl_event = await self.store.get_event(acl_event_id) if server_matches_acl_event(server_name, acl_event): return @@ -702,7 +681,7 @@ def server_matches_acl_event(server_name, acl_event): # server name is a literal IP allow_ip_literals = acl_event.content.get("allow_ip_literals", True) if not isinstance(allow_ip_literals, bool): - logger.warn("Ignorning non-bool allow_ip_literals flag") + logger.warning("Ignorning non-bool allow_ip_literals flag") allow_ip_literals = True if not allow_ip_literals: # check for ipv6 literals. These start with '['. @@ -716,7 +695,7 @@ def server_matches_acl_event(server_name, acl_event): # next, check the deny list deny = acl_event.content.get("deny", []) if not isinstance(deny, (list, tuple)): - logger.warn("Ignorning non-list deny ACL %s", deny) + logger.warning("Ignorning non-list deny ACL %s", deny) deny = [] for e in deny: if _acl_entry_matches(server_name, e): @@ -726,7 +705,7 @@ def server_matches_acl_event(server_name, acl_event): # then the allow list. allow = acl_event.content.get("allow", []) if not isinstance(allow, (list, tuple)): - logger.warn("Ignorning non-list allow ACL %s", allow) + logger.warning("Ignorning non-list allow ACL %s", allow) allow = [] for e in allow: if _acl_entry_matches(server_name, e): @@ -740,7 +719,7 @@ def server_matches_acl_event(server_name, acl_event): def _acl_entry_matches(server_name, acl_entry): if not isinstance(acl_entry, six.string_types): - logger.warn( + logger.warning( "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry) ) return False @@ -792,15 +771,14 @@ class FederationHandlerRegistry(object): self.query_handlers[query_type] = handler - @defer.inlineCallbacks - def on_edu(self, edu_type, origin, content): + async def on_edu(self, edu_type, origin, content): handler = self.edu_handlers.get(edu_type) if not handler: - logger.warn("No handler registered for EDU type %s", edu_type) + logger.warning("No handler registered for EDU type %s", edu_type) with start_active_span_from_edu(content, "handle_edu"): try: - yield handler(origin, content) + await handler(origin, content) except SynapseError as e: logger.info("Failed to handle edu %r: %r", edu_type, e) except Exception: @@ -809,7 +787,7 @@ class FederationHandlerRegistry(object): def on_query(self, query_type, args): handler = self.query_handlers.get(query_type) if not handler: - logger.warn("No handler registered for query type %s", query_type) + logger.warning("No handler registered for query type %s", query_type) raise NotFoundError("No handler for Query type '%s'" % (query_type,)) return handler(args) @@ -833,7 +811,7 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): super(ReplicationFederationHandlerRegistry, self).__init__() - def on_edu(self, edu_type, origin, content): + async def on_edu(self, edu_type, origin, content): """Overrides FederationHandlerRegistry """ if not self.config.use_presence and edu_type == "m.presence": @@ -841,17 +819,17 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): handler = self.edu_handlers.get(edu_type) if handler: - return super(ReplicationFederationHandlerRegistry, self).on_edu( + return await super(ReplicationFederationHandlerRegistry, self).on_edu( edu_type, origin, content ) - return self._send_edu(edu_type=edu_type, origin=origin, content=content) + return await self._send_edu(edu_type=edu_type, origin=origin, content=content) - def on_query(self, query_type, args): + async def on_query(self, query_type, args): """Overrides FederationHandlerRegistry """ handler = self.query_handlers.get(query_type) if handler: - return handler(args) + return await handler(args) - return self._get_query_client(query_type=query_type, args=args) + return await self._get_query_client(query_type=query_type, args=args) diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 454456a52d..ced4925a98 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -36,6 +36,8 @@ from six import iteritems from sortedcontainers import SortedDict +from twisted.internet import defer + from synapse.metrics import LaterGauge from synapse.storage.presence import UserPresenceState from synapse.util.metrics import Measure @@ -212,7 +214,7 @@ class FederationRemoteSendQueue(object): receipt (synapse.types.ReadReceipt): """ # nothing to do here: the replication listener will handle it. - pass + return defer.succeed(None) def send_presence(self, states): """As per FederationSender diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index fad980b893..a5b36b1827 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -30,7 +30,7 @@ from synapse.federation.units import Edu from synapse.handlers.presence import format_user_presence_state from synapse.metrics import sent_transactions_counter from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage import UserPresenceState +from synapse.storage.presence import UserPresenceState from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter # This is defined in the Matrix spec and enforced by the receiver. @@ -192,15 +192,16 @@ class PerDestinationQueue(object): # We have to keep 2 free slots for presence and rr_edus limit = MAX_EDUS_PER_TRANSACTION - 2 - device_update_edus, dev_list_id = ( - yield self._get_device_update_edus(limit) + device_update_edus, dev_list_id = yield self._get_device_update_edus( + limit ) limit -= len(device_update_edus) - to_device_edus, device_stream_id = ( - yield self._get_to_device_message_edus(limit) - ) + ( + to_device_edus, + device_stream_id, + ) = yield self._get_to_device_message_edus(limit) pending_edus = device_update_edus + to_device_edus @@ -359,20 +360,20 @@ class PerDestinationQueue(object): last_device_list = self._last_device_list_stream_id # Retrieve list of new device updates to send to the destination - now_stream_id, results = yield self._store.get_devices_by_remote( + now_stream_id, results = yield self._store.get_device_updates_by_remote( self._destination, last_device_list, limit=limit ) edus = [ Edu( origin=self._server_name, destination=self._destination, - edu_type="m.device_list_update", + edu_type=edu_type, content=content, ) - for content in results + for (edu_type, content) in results ] - assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs" + assert len(edus) <= limit, "get_device_updates_by_remote returned too many EDUs" return (edus, now_stream_id) diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 5b6c79c51a..67b3e1ab6e 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -146,7 +146,7 @@ class TransactionManager(object): if code == 200: for e_id, r in response.get("pdus", {}).items(): if "error" in r: - logger.warn( + logger.warning( "TX [%s] {%s} Remote returned error for %s: %s", destination, txn_id, @@ -155,7 +155,7 @@ class TransactionManager(object): ) else: for p in pdus: - logger.warn( + logger.warning( "TX [%s] {%s} Failed to send event %s", destination, txn_id, diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 7b18408144..920fa86853 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -122,10 +122,10 @@ class TransportLayerClient(object): Deferred: Results in a dict received from the remote homeserver. """ logger.debug( - "backfill dest=%s, room_id=%s, event_tuples=%s, limit=%s", + "backfill dest=%s, room_id=%s, event_tuples=%r, limit=%s", destination, room_id, - repr(event_tuples), + event_tuples, str(limit), ) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 0f16f21c2d..d6c23f22bd 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -202,7 +202,7 @@ def _parse_auth_header(header_bytes): sig = strip_quotes(param_dict["sig"]) return origin, key, sig except Exception as e: - logger.warn( + logger.warning( "Error parsing auth header '%s': %s", header_bytes.decode("ascii", "replace"), e, @@ -287,10 +287,12 @@ class BaseFederationServlet(object): except NoAuthenticationError: origin = None if self.REQUIRE_AUTH: - logger.warn("authenticate_request failed: missing authentication") + logger.warning( + "authenticate_request failed: missing authentication" + ) raise except Exception as e: - logger.warn("authenticate_request failed: %s", e) + logger.warning("authenticate_request failed: %s", e) raise request_tags = { diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index dfd7ae041b..d950a8b246 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -181,7 +181,7 @@ class GroupAttestionRenewer(object): elif not self.is_mine_id(user_id): destination = get_domain_from_id(user_id) else: - logger.warn( + logger.warning( "Incorrectly trying to do attestations for user: %r in %r", user_id, group_id, diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 8f10b6adbb..29e8ffc295 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -488,7 +488,7 @@ class GroupsServerHandler(object): profile = yield self.profile_handler.get_profile_from_cache(user_id) user_profile.update(profile) except Exception as e: - logger.warn("Error getting profile for %s: %s", user_id, e) + logger.warning("Error getting profile for %s: %s", user_id, e) user_profiles.append(user_profile) return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)} diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 38bc67191c..2d7e6df6e4 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -38,9 +38,10 @@ class AccountDataEventSource(object): {"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id} ) - account_data, room_account_data = ( - yield self.store.get_updated_account_data_for_user(user_id, last_stream_id) - ) + ( + account_data, + room_account_data, + ) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id) for account_data_type, content in account_data.items(): results.append({"type": account_data_type, "content": content}) diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 1a87b58838..6407d56f8e 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -30,6 +30,9 @@ class AdminHandler(BaseHandler): def __init__(self, hs): super(AdminHandler, self).__init__(hs) + self.storage = hs.get_storage() + self.state_store = self.storage.state + @defer.inlineCallbacks def get_whois(self, user): connections = [] @@ -205,7 +208,7 @@ class AdminHandler(BaseHandler): from_key = events[-1].internal_metadata.after - events = yield filter_events_for_client(self.store, user_id, events) + events = yield filter_events_for_client(self.storage, user_id, events) writer.write_events(room_id, events) @@ -241,7 +244,7 @@ class AdminHandler(BaseHandler): for event_id in extremities: if not event_to_unseen_prevs[event_id]: continue - state = yield self.store.get_state_for_event(event_id) + state = yield self.state_store.get_state_for_event(event_id) writer.write_state(room_id, event_id, state) return writer.finished() diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 3e9b298154..fe62f78e67 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -73,7 +73,10 @@ class ApplicationServicesHandler(object): try: limit = 100 while True: - upper_bound, events = yield self.store.get_new_events_for_appservice( + ( + upper_bound, + events, + ) = yield self.store.get_new_events_for_appservice( self.current_max, limit ) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 333eb30625..7a0f54ca24 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -525,7 +525,7 @@ class AuthHandler(BaseHandler): result = None if not user_infos: - logger.warn("Attempted to login as %s but they do not exist", user_id) + logger.warning("Attempted to login as %s but they do not exist", user_id) elif len(user_infos) == 1: # a single match (possibly not exact) result = user_infos.popitem() @@ -534,7 +534,7 @@ class AuthHandler(BaseHandler): result = (user_id, user_infos[user_id]) else: # multiple matches, none of them exact - logger.warn( + logger.warning( "Attempted to login as %s but it matches more than one user " "inexactly: %r", user_id, @@ -728,7 +728,7 @@ class AuthHandler(BaseHandler): result = yield self.validate_hash(password, password_hash) if not result: - logger.warn("Failed password login for user %s", user_id) + logger.warning("Failed password login for user %s", user_id) return None return user_id diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 5f23ee4488..26ef5e150c 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -46,6 +46,7 @@ class DeviceWorkerHandler(BaseHandler): self.hs = hs self.state = hs.get_state_handler() + self.state_store = hs.get_storage().state self._auth_handler = hs.get_auth_handler() @trace @@ -178,7 +179,7 @@ class DeviceWorkerHandler(BaseHandler): continue # mapping from event_id -> state_dict - prev_state_ids = yield self.store.get_state_ids_for_events(event_ids) + prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids) # Check if we've joined the room? If so we just blindly add all the users to # the "possibly changed" users. @@ -458,7 +459,18 @@ class DeviceHandler(DeviceWorkerHandler): @defer.inlineCallbacks def on_federation_query_user_devices(self, user_id): stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) - return {"user_id": user_id, "stream_id": stream_id, "devices": devices} + master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master") + self_signing_key = yield self.store.get_e2e_cross_signing_key( + user_id, "self_signing" + ) + + return { + "user_id": user_id, + "stream_id": stream_id, + "devices": devices, + "master_key": master_key, + "self_signing_key": self_signing_key, + } @defer.inlineCallbacks def user_left_room(self, user, room_id): @@ -656,7 +668,7 @@ class DeviceListUpdater(object): except (NotRetryingDestination, RequestSendFailed, HttpResponseException): # TODO: Remember that we are now out of sync and try again # later - logger.warn("Failed to handle device list update for %s", user_id) + logger.warning("Failed to handle device list update for %s", user_id) # We abort on exceptions rather than accepting the update # as otherwise synapse will 'forget' that its device list # is out of date. If we bail then we will retry the resync @@ -694,7 +706,7 @@ class DeviceListUpdater(object): # up on storing the total list of devices and only handle the # delta instead. if len(devices) > 1000: - logger.warn( + logger.warning( "Ignoring device list snapshot for %s as it has >1K devs (%d)", user_id, len(devices), diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 0043cbea17..73b9e120f5 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -52,7 +52,7 @@ class DeviceMessageHandler(object): local_messages = {} sender_user_id = content["sender"] if origin != get_domain_from_id(sender_user_id): - logger.warn( + logger.warning( "Dropping device message from %r with spoofed sender %r", origin, sender_user_id, diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 526379c6f7..c4632f8984 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -250,7 +250,7 @@ class DirectoryHandler(BaseHandler): ignore_backoff=True, ) except CodeMessageException as e: - logging.warn("Error retrieving alias") + logging.warning("Error retrieving alias") if e.code == 404: result = None else: diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index be675197f1..f09a0b73c8 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -19,12 +19,15 @@ import logging from six import iteritems +import attr from canonicaljson import encode_canonical_json, json +from signedjson.key import decode_verify_key_bytes from signedjson.sign import SignatureVerifyException, verify_signed_json +from unpaddedbase64 import decode_base64 from twisted.internet import defer -from synapse.api.errors import CodeMessageException, Codes, SynapseError +from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace from synapse.types import ( @@ -33,6 +36,8 @@ from synapse.types import ( get_verify_key_from_cross_signing_key, ) from synapse.util import unwrapFirstError +from synapse.util.async_helpers import Linearizer +from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination logger = logging.getLogger(__name__) @@ -46,10 +51,19 @@ class E2eKeysHandler(object): self.is_mine = hs.is_mine self.clock = hs.get_clock() + self._edu_updater = SigningKeyEduUpdater(hs, self) + + federation_registry = hs.get_federation_registry() + + # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec + federation_registry.register_edu_handler( + "org.matrix.signing_key_update", + self._edu_updater.incoming_signing_key_update, + ) # doesn't really work as part of the generic query API, because the # query request requires an object POST, but we abuse the # "query handler" interface. - hs.get_federation_registry().register_query_handler( + federation_registry.register_query_handler( "client_keys", self.on_federation_query_client_keys ) @@ -116,9 +130,10 @@ class E2eKeysHandler(object): else: query_list.append((user_id, None)) - user_ids_not_in_cache, remote_results = ( - yield self.store.get_user_devices_from_cache(query_list) - ) + ( + user_ids_not_in_cache, + remote_results, + ) = yield self.store.get_user_devices_from_cache(query_list) for user_id, devices in iteritems(remote_results): user_devices = results.setdefault(user_id, {}) for device_id, device in iteritems(devices): @@ -204,13 +219,15 @@ class E2eKeysHandler(object): if user_id in destination_query: results[user_id] = keys - for user_id, key in remote_result["master_keys"].items(): - if user_id in destination_query: - cross_signing_keys["master_keys"][user_id] = key + if "master_keys" in remote_result: + for user_id, key in remote_result["master_keys"].items(): + if user_id in destination_query: + cross_signing_keys["master_keys"][user_id] = key - for user_id, key in remote_result["self_signing_keys"].items(): - if user_id in destination_query: - cross_signing_keys["self_signing_keys"][user_id] = key + if "self_signing_keys" in remote_result: + for user_id, key in remote_result["self_signing_keys"].items(): + if user_id in destination_query: + cross_signing_keys["self_signing_keys"][user_id] = key except Exception as e: failure = _exception_to_failure(e) @@ -248,7 +265,7 @@ class E2eKeysHandler(object): Returns: defer.Deferred[dict[str, dict[str, dict]]]: map from - (master|self_signing|user_signing) -> user_id -> key + (master_keys|self_signing_keys|user_signing_keys) -> user_id -> key """ master_keys = {} self_signing_keys = {} @@ -340,7 +357,16 @@ class E2eKeysHandler(object): """ device_keys_query = query_body.get("device_keys", {}) res = yield self.query_local_devices(device_keys_query) - return {"device_keys": res} + ret = {"device_keys": res} + + # add in the cross-signing keys + cross_signing_keys = yield self.get_cross_signing_keys_from_cache( + device_keys_query, None + ) + + ret.update(cross_signing_keys) + + return ret @trace @defer.inlineCallbacks @@ -602,6 +628,349 @@ class E2eKeysHandler(object): return {} + @defer.inlineCallbacks + def upload_signatures_for_device_keys(self, user_id, signatures): + """Upload device signatures for cross-signing + + Args: + user_id (string): the user uploading the signatures + signatures (dict[string, dict[string, dict]]): map of users to + devices to signed keys. This is the submission from the user; an + exception will be raised if it is malformed. + Returns: + dict: response to be sent back to the client. The response will have + a "failures" key, which will be a dict mapping users to devices + to errors for the signatures that failed. + Raises: + SynapseError: if the signatures dict is not valid. + """ + failures = {} + + # signatures to be stored. Each item will be a SignatureListItem + signature_list = [] + + # split between checking signatures for own user and signatures for + # other users, since we verify them with different keys + self_signatures = signatures.get(user_id, {}) + other_signatures = {k: v for k, v in signatures.items() if k != user_id} + + self_signature_list, self_failures = yield self._process_self_signatures( + user_id, self_signatures + ) + signature_list.extend(self_signature_list) + failures.update(self_failures) + + other_signature_list, other_failures = yield self._process_other_signatures( + user_id, other_signatures + ) + signature_list.extend(other_signature_list) + failures.update(other_failures) + + # store the signature, and send the appropriate notifications for sync + logger.debug("upload signature failures: %r", failures) + yield self.store.store_e2e_cross_signing_signatures(user_id, signature_list) + + self_device_ids = [item.target_device_id for item in self_signature_list] + if self_device_ids: + yield self.device_handler.notify_device_update(user_id, self_device_ids) + signed_users = [item.target_user_id for item in other_signature_list] + if signed_users: + yield self.device_handler.notify_user_signature_update( + user_id, signed_users + ) + + return {"failures": failures} + + @defer.inlineCallbacks + def _process_self_signatures(self, user_id, signatures): + """Process uploaded signatures of the user's own keys. + + Signatures of the user's own keys from this API come in two forms: + - signatures of the user's devices by the user's self-signing key, + - signatures of the user's master key by the user's devices. + + Args: + user_id (string): the user uploading the keys + signatures (dict[string, dict]): map of devices to signed keys + + Returns: + (list[SignatureListItem], dict[string, dict[string, dict]]): + a list of signatures to store, and a map of users to devices to failure + reasons + + Raises: + SynapseError: if the input is malformed + """ + signature_list = [] + failures = {} + if not signatures: + return signature_list, failures + + if not isinstance(signatures, dict): + raise SynapseError(400, "Invalid parameter", Codes.INVALID_PARAM) + + try: + # get our self-signing key to verify the signatures + ( + _, + self_signing_key_id, + self_signing_verify_key, + ) = yield self._get_e2e_cross_signing_verify_key(user_id, "self_signing") + + # get our master key, since we may have received a signature of it. + # We need to fetch it here so that we know what its key ID is, so + # that we can check if a signature that was sent is a signature of + # the master key or of a device + ( + master_key, + _, + master_verify_key, + ) = yield self._get_e2e_cross_signing_verify_key(user_id, "master") + + # fetch our stored devices. This is used to 1. verify + # signatures on the master key, and 2. to compare with what + # was sent if the device was signed + devices = yield self.store.get_e2e_device_keys([(user_id, None)]) + + if user_id not in devices: + raise NotFoundError("No device keys found") + + devices = devices[user_id] + except SynapseError as e: + failure = _exception_to_failure(e) + failures[user_id] = {device: failure for device in signatures.keys()} + return signature_list, failures + + for device_id, device in signatures.items(): + # make sure submitted data is in the right form + if not isinstance(device, dict): + raise SynapseError(400, "Invalid parameter", Codes.INVALID_PARAM) + + try: + if "signatures" not in device or user_id not in device["signatures"]: + # no signature was sent + raise SynapseError( + 400, "Invalid signature", Codes.INVALID_SIGNATURE + ) + + if device_id == master_verify_key.version: + # The signature is of the master key. This needs to be + # handled differently from signatures of normal devices. + master_key_signature_list = self._check_master_key_signature( + user_id, device_id, device, master_key, devices + ) + signature_list.extend(master_key_signature_list) + continue + + # at this point, we have a device that should be signed + # by the self-signing key + if self_signing_key_id not in device["signatures"][user_id]: + # no signature was sent + raise SynapseError( + 400, "Invalid signature", Codes.INVALID_SIGNATURE + ) + + try: + stored_device = devices[device_id] + except KeyError: + raise NotFoundError("Unknown device") + if self_signing_key_id in stored_device.get("signatures", {}).get( + user_id, {} + ): + # we already have a signature on this device, so we + # can skip it, since it should be exactly the same + continue + + _check_device_signature( + user_id, self_signing_verify_key, device, stored_device + ) + + signature = device["signatures"][user_id][self_signing_key_id] + signature_list.append( + SignatureListItem( + self_signing_key_id, user_id, device_id, signature + ) + ) + except SynapseError as e: + failures.setdefault(user_id, {})[device_id] = _exception_to_failure(e) + + return signature_list, failures + + def _check_master_key_signature( + self, user_id, master_key_id, signed_master_key, stored_master_key, devices + ): + """Check signatures of a user's master key made by their devices. + + Args: + user_id (string): the user whose master key is being checked + master_key_id (string): the ID of the user's master key + signed_master_key (dict): the user's signed master key that was uploaded + stored_master_key (dict): our previously-stored copy of the user's master key + devices (iterable(dict)): the user's devices + + Returns: + list[SignatureListItem]: a list of signatures to store + + Raises: + SynapseError: if a signature is invalid + """ + # for each device that signed the master key, check the signature. + master_key_signature_list = [] + sigs = signed_master_key["signatures"] + for signing_key_id, signature in sigs[user_id].items(): + _, signing_device_id = signing_key_id.split(":", 1) + if ( + signing_device_id not in devices + or signing_key_id not in devices[signing_device_id]["keys"] + ): + # signed by an unknown device, or the + # device does not have the key + raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE) + + # get the key and check the signature + pubkey = devices[signing_device_id]["keys"][signing_key_id] + verify_key = decode_verify_key_bytes(signing_key_id, decode_base64(pubkey)) + _check_device_signature( + user_id, verify_key, signed_master_key, stored_master_key + ) + + master_key_signature_list.append( + SignatureListItem(signing_key_id, user_id, master_key_id, signature) + ) + + return master_key_signature_list + + @defer.inlineCallbacks + def _process_other_signatures(self, user_id, signatures): + """Process uploaded signatures of other users' keys. These will be the + target user's master keys, signed by the uploading user's user-signing + key. + + Args: + user_id (string): the user uploading the keys + signatures (dict[string, dict]): map of users to devices to signed keys + + Returns: + (list[SignatureListItem], dict[string, dict[string, dict]]): + a list of signatures to store, and a map of users to devices to failure + reasons + + Raises: + SynapseError: if the input is malformed + """ + signature_list = [] + failures = {} + if not signatures: + return signature_list, failures + + try: + # get our user-signing key to verify the signatures + ( + user_signing_key, + user_signing_key_id, + user_signing_verify_key, + ) = yield self._get_e2e_cross_signing_verify_key(user_id, "user_signing") + except SynapseError as e: + failure = _exception_to_failure(e) + for user, devicemap in signatures.items(): + failures[user] = {device_id: failure for device_id in devicemap.keys()} + return signature_list, failures + + for target_user, devicemap in signatures.items(): + # make sure submitted data is in the right form + if not isinstance(devicemap, dict): + raise SynapseError(400, "Invalid parameter", Codes.INVALID_PARAM) + for device in devicemap.values(): + if not isinstance(device, dict): + raise SynapseError(400, "Invalid parameter", Codes.INVALID_PARAM) + + device_id = None + try: + # get the target user's master key, to make sure it matches + # what was sent + ( + master_key, + master_key_id, + _, + ) = yield self._get_e2e_cross_signing_verify_key( + target_user, "master", user_id + ) + + # make sure that the target user's master key is the one that + # was signed (and no others) + device_id = master_key_id.split(":", 1)[1] + if device_id not in devicemap: + logger.debug( + "upload signature: could not find signature for device %s", + device_id, + ) + # set device to None so that the failure gets + # marked on all the signatures + device_id = None + raise NotFoundError("Unknown device") + key = devicemap[device_id] + other_devices = [k for k in devicemap.keys() if k != device_id] + if other_devices: + # other devices were signed -- mark those as failures + logger.debug("upload signature: too many devices specified") + failure = _exception_to_failure(NotFoundError("Unknown device")) + failures[target_user] = { + device: failure for device in other_devices + } + + if user_signing_key_id in master_key.get("signatures", {}).get( + user_id, {} + ): + # we already have the signature, so we can skip it + continue + + _check_device_signature( + user_id, user_signing_verify_key, key, master_key + ) + + signature = key["signatures"][user_id][user_signing_key_id] + signature_list.append( + SignatureListItem( + user_signing_key_id, target_user, device_id, signature + ) + ) + except SynapseError as e: + failure = _exception_to_failure(e) + if device_id is None: + failures[target_user] = { + device_id: failure for device_id in devicemap.keys() + } + else: + failures.setdefault(target_user, {})[device_id] = failure + + return signature_list, failures + + @defer.inlineCallbacks + def _get_e2e_cross_signing_verify_key(self, user_id, key_type, from_user_id=None): + """Fetch the cross-signing public key from storage and interpret it. + + Args: + user_id (str): the user whose key should be fetched + key_type (str): the type of key to fetch + from_user_id (str): the user that we are fetching the keys for. + This affects what signatures are fetched. + + Returns: + dict, str, VerifyKey: the raw key data, the key ID, and the + signedjson verify key + + Raises: + NotFoundError: if the key is not found + """ + key = yield self.store.get_e2e_cross_signing_key( + user_id, key_type, from_user_id + ) + if key is None: + logger.debug("no %s key found for %s", key_type, user_id) + raise NotFoundError("No %s key found for %s" % (key_type, user_id)) + key_id, verify_key = get_verify_key_from_cross_signing_key(key) + return key, key_id, verify_key + def _check_cross_signing_key(key, user_id, key_type, signing_key=None): """Check a cross-signing key uploaded by a user. Performs some basic sanity @@ -630,7 +999,49 @@ def _check_cross_signing_key(key, user_id, key_type, signing_key=None): ) +def _check_device_signature(user_id, verify_key, signed_device, stored_device): + """Check that a signature on a device or cross-signing key is correct and + matches the copy of the device/key that we have stored. Throws an + exception if an error is detected. + + Args: + user_id (str): the user ID whose signature is being checked + verify_key (VerifyKey): the key to verify the device with + signed_device (dict): the uploaded signed device data + stored_device (dict): our previously stored copy of the device + + Raises: + SynapseError: if the signature was invalid or the sent device is not the + same as the stored device + + """ + + # make sure that the device submitted matches what we have stored + stripped_signed_device = { + k: v for k, v in signed_device.items() if k not in ["signatures", "unsigned"] + } + stripped_stored_device = { + k: v for k, v in stored_device.items() if k not in ["signatures", "unsigned"] + } + if stripped_signed_device != stripped_stored_device: + logger.debug( + "upload signatures: key does not match %s vs %s", + signed_device, + stored_device, + ) + raise SynapseError(400, "Key does not match") + + try: + verify_signed_json(signed_device, user_id, verify_key) + except SignatureVerifyException: + logger.debug("invalid signature on key") + raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE) + + def _exception_to_failure(e): + if isinstance(e, SynapseError): + return {"status": e.code, "errcode": e.errcode, "message": str(e)} + if isinstance(e, CodeMessageException): return {"status": e.code, "message": str(e)} @@ -658,3 +1069,111 @@ def _one_time_keys_match(old_key_json, new_key): new_key_copy.pop("signatures", None) return old_key == new_key_copy + + +@attr.s +class SignatureListItem: + """An item in the signature list as used by upload_signatures_for_device_keys. + """ + + signing_key_id = attr.ib() + target_user_id = attr.ib() + target_device_id = attr.ib() + signature = attr.ib() + + +class SigningKeyEduUpdater(object): + """Handles incoming signing key updates from federation and updates the DB""" + + def __init__(self, hs, e2e_keys_handler): + self.store = hs.get_datastore() + self.federation = hs.get_federation_client() + self.clock = hs.get_clock() + self.e2e_keys_handler = e2e_keys_handler + + self._remote_edu_linearizer = Linearizer(name="remote_signing_key") + + # user_id -> list of updates waiting to be handled. + self._pending_updates = {} + + # Recently seen stream ids. We don't bother keeping these in the DB, + # but they're useful to have them about to reduce the number of spurious + # resyncs. + self._seen_updates = ExpiringCache( + cache_name="signing_key_update_edu", + clock=self.clock, + max_len=10000, + expiry_ms=30 * 60 * 1000, + iterable=True, + ) + + @defer.inlineCallbacks + def incoming_signing_key_update(self, origin, edu_content): + """Called on incoming signing key update from federation. Responsible for + parsing the EDU and adding to pending updates list. + + Args: + origin (string): the server that sent the EDU + edu_content (dict): the contents of the EDU + """ + + user_id = edu_content.pop("user_id") + master_key = edu_content.pop("master_key", None) + self_signing_key = edu_content.pop("self_signing_key", None) + + if get_domain_from_id(user_id) != origin: + logger.warning("Got signing key update edu for %r from %r", user_id, origin) + return + + room_ids = yield self.store.get_rooms_for_user(user_id) + if not room_ids: + # We don't share any rooms with this user. Ignore update, as we + # probably won't get any further updates. + return + + self._pending_updates.setdefault(user_id, []).append( + (master_key, self_signing_key) + ) + + yield self._handle_signing_key_updates(user_id) + + @defer.inlineCallbacks + def _handle_signing_key_updates(self, user_id): + """Actually handle pending updates. + + Args: + user_id (string): the user whose updates we are processing + """ + + device_handler = self.e2e_keys_handler.device_handler + + with (yield self._remote_edu_linearizer.queue(user_id)): + pending_updates = self._pending_updates.pop(user_id, []) + if not pending_updates: + # This can happen since we batch updates + return + + device_ids = [] + + logger.info("pending updates: %r", pending_updates) + + for master_key, self_signing_key in pending_updates: + if master_key: + yield self.store.set_e2e_cross_signing_key( + user_id, "master", master_key + ) + _, verify_key = get_verify_key_from_cross_signing_key(master_key) + # verify_key is a VerifyKey from signedjson, which uses + # .version to denote the portion of the key ID after the + # algorithm and colon, which is the device ID + device_ids.append(verify_key.version) + if self_signing_key: + yield self.store.set_e2e_cross_signing_key( + user_id, "self_signing", self_signing_key + ) + _, verify_key = get_verify_key_from_cross_signing_key( + self_signing_key + ) + device_ids.append(verify_key.version) + + yield device_handler.notify_device_update(user_id, device_ids) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 5e748687e3..45fe13c62f 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -147,6 +147,10 @@ class EventStreamHandler(BaseHandler): class EventHandler(BaseHandler): + def __init__(self, hs): + super(EventHandler, self).__init__(hs) + self.storage = hs.get_storage() + @defer.inlineCallbacks def get_event(self, user, room_id, event_id): """Retrieve a single specified event. @@ -172,7 +176,7 @@ class EventHandler(BaseHandler): is_peeking = user.to_string() not in users filtered = yield filter_events_for_client( - self.store, user.to_string(), [event], is_peeking=is_peeking + self.storage, user.to_string(), [event], is_peeking=is_peeking ) if not filtered: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4b4c6c15f9..05dd8d2671 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -45,6 +45,7 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.crypto.event_signing import compute_event_signature from synapse.event_auth import auth_types_for_event +from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.logging.context import ( make_deferred_yieldable, @@ -109,6 +110,8 @@ class FederationHandler(BaseHandler): self.hs = hs self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_store = self.storage.state self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() self.server_name = hs.hostname @@ -180,7 +183,7 @@ class FederationHandler(BaseHandler): try: self._sanity_check_event(pdu) except SynapseError as err: - logger.warn( + logger.warning( "[%s %s] Received event failed sanity checks", room_id, event_id ) raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id) @@ -301,7 +304,7 @@ class FederationHandler(BaseHandler): # following. if sent_to_us_directly: - logger.warn( + logger.warning( "[%s %s] Rejecting: failed to fetch %d prev events: %s", room_id, event_id, @@ -324,7 +327,7 @@ class FederationHandler(BaseHandler): event_map = {event_id: pdu} try: # Get the state of the events we know about - ours = yield self.store.get_state_groups_ids(room_id, seen) + ours = yield self.state_store.get_state_groups_ids(room_id, seen) # state_maps is a list of mappings from (type, state_key) to event_id state_maps = list( @@ -350,10 +353,11 @@ class FederationHandler(BaseHandler): # note that if any of the missing prevs share missing state or # auth events, the requests to fetch those events are deduped # by the get_pdu_cache in federation_client. - remote_state, got_auth_chain = ( - yield self.federation_client.get_state_for_room( - origin, room_id, p - ) + ( + remote_state, + got_auth_chain, + ) = yield self.federation_client.get_state_for_room( + origin, room_id, p ) # we want the state *after* p; get_state_for_room returns the @@ -405,7 +409,7 @@ class FederationHandler(BaseHandler): state = [event_map[e] for e in six.itervalues(state_map)] auth_chain = list(auth_chains) except Exception: - logger.warn( + logger.warning( "[%s %s] Error attempting to resolve state at missing " "prev_events", room_id, @@ -518,7 +522,9 @@ class FederationHandler(BaseHandler): # We failed to get the missing events, but since we need to handle # the case of `get_missing_events` not returning the necessary # events anyway, it is safe to simply log the error and continue. - logger.warn("[%s %s]: Failed to get prev_events: %s", room_id, event_id, e) + logger.warning( + "[%s %s]: Failed to get prev_events: %s", room_id, event_id, e + ) return logger.info( @@ -545,7 +551,7 @@ class FederationHandler(BaseHandler): yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False) except FederationError as e: if e.code == 403: - logger.warn( + logger.warning( "[%s %s] Received prev_event %s failed history check.", room_id, event_id, @@ -888,7 +894,7 @@ class FederationHandler(BaseHandler): # We set `check_history_visibility_only` as we might otherwise get false # positives from users having been erased. filtered_extremities = yield filter_events_for_server( - self.store, + self.storage, self.server_name, list(extremities_events.values()), redact=False, @@ -1059,7 +1065,7 @@ class FederationHandler(BaseHandler): SynapseError if the event does not pass muster """ if len(ev.prev_event_ids()) > 20: - logger.warn( + logger.warning( "Rejecting event %s which has %i prev_events", ev.event_id, len(ev.prev_event_ids()), @@ -1067,7 +1073,7 @@ class FederationHandler(BaseHandler): raise SynapseError(http_client.BAD_REQUEST, "Too many prev_events") if len(ev.auth_event_ids()) > 10: - logger.warn( + logger.warning( "Rejecting event %s which has %i auth_events", ev.event_id, len(ev.auth_event_ids()), @@ -1101,7 +1107,7 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def do_invite_join(self, target_hosts, room_id, joinee, content): """ Attempts to join the `joinee` to the room `room_id` via the - server `target_host`. + servers contained in `target_hosts`. This first triggers a /make_join/ request that returns a partial event that we can fill out and sign. This is then sent to the @@ -1110,6 +1116,15 @@ class FederationHandler(BaseHandler): We suspend processing of any received events from this room until we have finished processing the join. + + Args: + target_hosts (Iterable[str]): List of servers to attempt to join the room with. + + room_id (str): The ID of the room to join. + + joinee (str): The User ID of the joining user. + + content (dict): The event content to use for the join event. """ logger.debug("Joining %s to %s", joinee, room_id) @@ -1169,6 +1184,22 @@ class FederationHandler(BaseHandler): yield self._persist_auth_tree(origin, auth_chain, state, event) + # Check whether this room is the result of an upgrade of a room we already know + # about. If so, migrate over user information + predecessor = yield self.store.get_room_predecessor(room_id) + if not predecessor: + return + old_room_id = predecessor["room_id"] + logger.debug( + "Found predecessor for %s during remote join: %s", room_id, old_room_id + ) + + # We retrieve the room member handler here as to not cause a cyclic dependency + member_handler = self.hs.get_room_member_handler() + yield member_handler.transfer_room_state_on_room_upgrade( + old_room_id, room_id + ) + logger.debug("Finished joining %s to %s", joinee, room_id) finally: room_queue = self.room_queues[room_id] @@ -1203,7 +1234,7 @@ class FederationHandler(BaseHandler): with nested_logging_context(p.event_id): yield self.on_receive_pdu(origin, p, sent_to_us_directly=True) except Exception as e: - logger.warn( + logger.warning( "Error handling queued PDU %s from %s: %s", p.event_id, origin, e ) @@ -1222,7 +1253,6 @@ class FederationHandler(BaseHandler): Returns: Deferred[FrozenEvent] """ - if get_domain_from_id(user_id) != origin: logger.info( "Got /make_join request for user %r from different origin %s, ignoring", @@ -1251,7 +1281,7 @@ class FederationHandler(BaseHandler): builder=builder ) except AuthError as e: - logger.warn("Failed to create join %r because %s", event, e) + logger.warning("Failed to create join to %s because %s", room_id, e) raise e event_allowed = yield self.third_party_event_rules.check_event_allowed( @@ -1280,11 +1310,20 @@ class FederationHandler(BaseHandler): event = pdu logger.debug( - "on_send_join_request: Got event: %s, signatures: %s", + "on_send_join_request from %s: Got event: %s, signatures: %s", + origin, event.event_id, event.signatures, ) + if get_domain_from_id(event.sender) != origin: + logger.info( + "Got /send_join request for user %r from different origin %s", + event.sender, + origin, + ) + raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) + event.internal_metadata.outlier = False # Send this event on behalf of the origin server. # @@ -1486,7 +1525,7 @@ class FederationHandler(BaseHandler): room_version, event, context, do_sig_check=False ) except AuthError as e: - logger.warn("Failed to create new leave %r because %s", event, e) + logger.warning("Failed to create new leave %r because %s", event, e) raise e return event @@ -1503,6 +1542,14 @@ class FederationHandler(BaseHandler): event.signatures, ) + if get_domain_from_id(event.sender) != origin: + logger.info( + "Got /send_leave request for user %r from different origin %s", + event.sender, + origin, + ) + raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) + event.internal_metadata.outlier = False context = yield self._handle_new_event(origin, event) @@ -1533,7 +1580,7 @@ class FederationHandler(BaseHandler): event_id, allow_none=False, check_room_id=room_id ) - state_groups = yield self.store.get_state_groups(room_id, [event_id]) + state_groups = yield self.state_store.get_state_groups(room_id, [event_id]) if state_groups: _, state = list(iteritems(state_groups)).pop() @@ -1562,7 +1609,7 @@ class FederationHandler(BaseHandler): event_id, allow_none=False, check_room_id=room_id ) - state_groups = yield self.store.get_state_groups_ids(room_id, [event_id]) + state_groups = yield self.state_store.get_state_groups_ids(room_id, [event_id]) if state_groups: _, state = list(state_groups.items()).pop() @@ -1590,7 +1637,7 @@ class FederationHandler(BaseHandler): events = yield self.store.get_backfill_events(room_id, pdu_list, limit) - events = yield filter_events_for_server(self.store, origin, events) + events = yield filter_events_for_server(self.storage, origin, events) return events @@ -1620,7 +1667,7 @@ class FederationHandler(BaseHandler): if not in_room: raise AuthError(403, "Host not in room.") - events = yield filter_events_for_server(self.store, origin, [event]) + events = yield filter_events_for_server(self.storage, origin, [event]) event = events[0] return event else: @@ -1641,7 +1688,11 @@ class FederationHandler(BaseHandler): # hack around with a try/finally instead. success = False try: - if not event.internal_metadata.is_outlier() and not backfilled: + if ( + not event.internal_metadata.is_outlier() + and not backfilled + and not context.rejected + ): yield self.action_generator.handle_push_actions_for_event( event, context ) @@ -1772,7 +1823,7 @@ class FederationHandler(BaseHandler): # cause SynapseErrors in auth.check. We don't want to give up # the attempt to federate altogether in such cases. - logger.warn("Rejecting %s because %s", e.event_id, err.msg) + logger.warning("Rejecting %s because %s", e.event_id, err.msg) if e == event: raise @@ -1825,12 +1876,7 @@ class FederationHandler(BaseHandler): if c and c.type == EventTypes.Create: auth_events[(c.type, c.state_key)] = c - try: - yield self.do_auth(origin, event, context, auth_events=auth_events) - except AuthError as e: - logger.warn("[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg) - - context.rejected = RejectedReason.AUTH_ERROR + context = yield self.do_auth(origin, event, context, auth_events=auth_events) if not context.rejected: yield self._check_for_soft_fail(event, state, backfilled) @@ -1886,7 +1932,7 @@ class FederationHandler(BaseHandler): # given state at the event. This should correctly handle cases # like bans, especially with state res v2. - state_sets = yield self.store.get_state_groups( + state_sets = yield self.state_store.get_state_groups( event.room_id, extrem_ids ) state_sets = list(state_sets.values()) @@ -1922,7 +1968,7 @@ class FederationHandler(BaseHandler): try: event_auth.check(room_version, event, auth_events=current_auth_events) except AuthError as e: - logger.warn("Soft-failing %r because %s", event, e) + logger.warning("Soft-failing %r because %s", event, e) event.internal_metadata.soft_failed = True @defer.inlineCallbacks @@ -1977,7 +2023,7 @@ class FederationHandler(BaseHandler): ) missing_events = yield filter_events_for_server( - self.store, origin, missing_events + self.storage, origin, missing_events ) return missing_events @@ -1999,12 +2045,12 @@ class FederationHandler(BaseHandler): Also NB that this function adds entries to it. Returns: - defer.Deferred[None] + defer.Deferred[EventContext]: updated context object """ room_version = yield self.store.get_room_version(event.room_id) try: - yield self._update_auth_events_and_context_for_auth( + context = yield self._update_auth_events_and_context_for_auth( origin, event, context, auth_events ) except Exception: @@ -2021,8 +2067,10 @@ class FederationHandler(BaseHandler): try: event_auth.check(room_version, event, auth_events=auth_events) except AuthError as e: - logger.warn("Failed auth resolution for %r because %s", event, e) - raise e + logger.warning("Failed auth resolution for %r because %s", event, e) + context.rejected = RejectedReason.AUTH_ERROR + + return context @defer.inlineCallbacks def _update_auth_events_and_context_for_auth( @@ -2046,7 +2094,7 @@ class FederationHandler(BaseHandler): auth_events (dict[(str, str)->synapse.events.EventBase]): Returns: - defer.Deferred[None] + defer.Deferred[EventContext]: updated context """ event_auth_events = set(event.auth_event_ids()) @@ -2085,7 +2133,7 @@ class FederationHandler(BaseHandler): # The other side isn't around or doesn't implement the # endpoint, so lets just bail out. logger.info("Failed to get event auth from remote: %s", e) - return + return context seen_remotes = yield self.store.have_seen_events( [e.event_id for e in remote_auth_chain] @@ -2126,7 +2174,7 @@ class FederationHandler(BaseHandler): if event.internal_metadata.is_outlier(): logger.info("Skipping auth_event fetch for outlier") - return + return context # FIXME: Assumes we have and stored all the state for all the # prev_events @@ -2135,7 +2183,7 @@ class FederationHandler(BaseHandler): ) if not different_auth: - return + return context logger.info( "auth_events refers to events which are not in our calculated auth " @@ -2182,10 +2230,12 @@ class FederationHandler(BaseHandler): auth_events.update(new_state) - yield self._update_context_for_auth_events( + context = yield self._update_context_for_auth_events( event, context, auth_events, event_key ) + return context + @defer.inlineCallbacks def _update_context_for_auth_events(self, event, context, auth_events, event_key): """Update the state_ids in an event context after auth event resolution, @@ -2194,14 +2244,16 @@ class FederationHandler(BaseHandler): Args: event (Event): The event we're handling the context for - context (synapse.events.snapshot.EventContext): event context - to be updated + context (synapse.events.snapshot.EventContext): initial event context auth_events (dict[(str, str)->str]): Events to update in the event context. event_key ((str, str)): (type, state_key) for the current event. this will not be included in the current_state in the context. + + Returns: + Deferred[EventContext]: new event context """ state_updates = { k: a.event_id for k, a in iteritems(auth_events) if k != event_key @@ -2218,7 +2270,7 @@ class FederationHandler(BaseHandler): # create a new state group as a delta from the existing one. prev_group = context.state_group - state_group = yield self.store.store_state_group( + state_group = yield self.state_store.store_state_group( event.event_id, event.room_id, prev_group=prev_group, @@ -2226,8 +2278,9 @@ class FederationHandler(BaseHandler): current_state_ids=current_state_ids, ) - yield context.update_state( + return EventContext.with_state( state_group=state_group, + state_group_before_event=context.state_group_before_event, current_state_ids=current_state_ids, prev_state_ids=prev_state_ids, prev_group=prev_group, @@ -2415,10 +2468,12 @@ class FederationHandler(BaseHandler): try: yield self.auth.check_from_context(room_version, event, context) except AuthError as e: - logger.warn("Denying new third party invite %r because %s", event, e) + logger.warning("Denying new third party invite %r because %s", event, e) raise e yield self._check_signature(event, context) + + # We retrieve the room member handler here as to not cause a cyclic dependency member_handler = self.hs.get_room_member_handler() yield member_handler.send_membership_event(None, event, context) else: @@ -2471,7 +2526,7 @@ class FederationHandler(BaseHandler): try: yield self.auth.check_from_context(room_version, event, context) except AuthError as e: - logger.warn("Denying third party invite %r because %s", event, e) + logger.warning("Denying third party invite %r because %s", event, e) raise e yield self._check_signature(event, context) @@ -2479,6 +2534,7 @@ class FederationHandler(BaseHandler): # though the sender isn't a local user. event.internal_metadata.send_on_behalf_of = get_domain_from_id(event.sender) + # We retrieve the room member handler here as to not cause a cyclic dependency member_handler = self.hs.get_room_member_handler() yield member_handler.send_membership_event(None, event, context) @@ -2648,7 +2704,7 @@ class FederationHandler(BaseHandler): backfilled=backfilled, ) else: - max_stream_id = yield self.store.persist_events( + max_stream_id = yield self.storage.persistence.persist_events( event_and_contexts, backfilled=backfilled ) diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 46eb9ee88b..92fecbfc44 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -392,7 +392,7 @@ class GroupsLocalHandler(object): try: user_profile = yield self.profile_handler.get_profile(user_id) except Exception as e: - logger.warn("No profile for user %s: %s", user_id, e) + logger.warning("No profile for user %s: %s", user_id, e) user_profile = {} return {"state": "invite", "user_profile": user_profile} diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index ba99ddf76d..000fbf090f 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -272,7 +272,7 @@ class IdentityHandler(BaseHandler): changed = False if e.code in (400, 404, 501): # The remote server probably doesn't support unbinding (yet) - logger.warn("Received %d response while unbinding threepid", e.code) + logger.warning("Received %d response while unbinding threepid", e.code) else: logger.error("Failed to unbind threepid on identity server: %s", e) raise SynapseError(500, "Failed to contact identity server") @@ -403,7 +403,7 @@ class IdentityHandler(BaseHandler): if self.hs.config.using_identity_server_from_trusted_list: # Warn that a deprecated config option is in use - logger.warn( + logger.warning( 'The config option "trust_identity_server_for_password_resets" ' 'has been replaced by "account_threepid_delegate". ' "Please consult the sample config at docs/sample_config.yaml for " @@ -457,7 +457,7 @@ class IdentityHandler(BaseHandler): if self.hs.config.using_identity_server_from_trusted_list: # Warn that a deprecated config option is in use - logger.warn( + logger.warning( 'The config option "trust_identity_server_for_password_resets" ' 'has been replaced by "account_threepid_delegate". ' "Please consult the sample config at docs/sample_config.yaml for " diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index f991efeee3..81dce96f4b 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -43,6 +43,8 @@ class InitialSyncHandler(BaseHandler): self.validator = EventValidator() self.snapshot_cache = SnapshotCache() self._event_serializer = hs.get_event_client_serializer() + self.storage = hs.get_storage() + self.state_store = self.storage.state def snapshot_all_rooms( self, @@ -126,8 +128,8 @@ class InitialSyncHandler(BaseHandler): tags_by_room = yield self.store.get_tags_for_user(user_id) - account_data, account_data_by_room = ( - yield self.store.get_account_data_for_user(user_id) + account_data, account_data_by_room = yield self.store.get_account_data_for_user( + user_id ) public_room_ids = yield self.store.get_public_room_ids() @@ -169,7 +171,7 @@ class InitialSyncHandler(BaseHandler): elif event.membership == Membership.LEAVE: room_end_token = "s%d" % (event.stream_ordering,) deferred_room_state = run_in_background( - self.store.get_state_for_events, [event.event_id] + self.state_store.get_state_for_events, [event.event_id] ) deferred_room_state.addCallback( lambda states: states[event.event_id] @@ -189,7 +191,9 @@ class InitialSyncHandler(BaseHandler): ) ).addErrback(unwrapFirstError) - messages = yield filter_events_for_client(self.store, user_id, messages) + messages = yield filter_events_for_client( + self.storage, user_id, messages + ) start_token = now_token.copy_and_replace("room_key", token) end_token = now_token.copy_and_replace("room_key", room_end_token) @@ -307,7 +311,7 @@ class InitialSyncHandler(BaseHandler): def _room_initial_sync_parted( self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking ): - room_state = yield self.store.get_state_for_events([member_event_id]) + room_state = yield self.state_store.get_state_for_events([member_event_id]) room_state = room_state[member_event_id] @@ -322,7 +326,7 @@ class InitialSyncHandler(BaseHandler): ) messages = yield filter_events_for_client( - self.store, user_id, messages, is_peeking=is_peeking + self.storage, user_id, messages, is_peeking=is_peeking ) start_token = StreamToken.START.copy_and_replace("room_key", token) @@ -414,7 +418,7 @@ class InitialSyncHandler(BaseHandler): ) messages = yield filter_events_for_client( - self.store, user_id, messages, is_peeking=is_peeking + self.storage, user_id, messages, is_peeking=is_peeking ) start_token = now_token.copy_and_replace("room_key", token) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0f8cce8ffe..d682dc2b7a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -59,6 +59,8 @@ class MessageHandler(object): self.clock = hs.get_clock() self.state = hs.get_state_handler() self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_store = self.storage.state self._event_serializer = hs.get_event_client_serializer() @defer.inlineCallbacks @@ -74,15 +76,16 @@ class MessageHandler(object): Raises: SynapseError if something went wrong. """ - membership, membership_event_id = yield self.auth.check_in_room_or_world_readable( - room_id, user_id - ) + ( + membership, + membership_event_id, + ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id) if membership == Membership.JOIN: data = yield self.state.get_current_state(room_id, event_type, state_key) elif membership == Membership.LEAVE: key = (event_type, state_key) - room_state = yield self.store.get_state_for_events( + room_state = yield self.state_store.get_state_for_events( [membership_event_id], StateFilter.from_types([key]) ) data = room_state[membership_event_id].get(key) @@ -135,12 +138,12 @@ class MessageHandler(object): raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = yield filter_events_for_client( - self.store, user_id, last_events + self.storage, user_id, last_events ) event = last_events[0] if visible_events: - room_state = yield self.store.get_state_for_events( + room_state = yield self.state_store.get_state_for_events( [event.event_id], state_filter=state_filter ) room_state = room_state[event.event_id] @@ -151,9 +154,10 @@ class MessageHandler(object): % (user_id, room_id, at_token), ) else: - membership, membership_event_id = ( - yield self.auth.check_in_room_or_world_readable(room_id, user_id) - ) + ( + membership, + membership_event_id, + ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id) if membership == Membership.JOIN: state_ids = yield self.store.get_filtered_current_state_ids( @@ -161,7 +165,7 @@ class MessageHandler(object): ) room_state = yield self.store.get_events(state_ids.values()) elif membership == Membership.LEAVE: - room_state = yield self.store.get_state_for_events( + room_state = yield self.state_store.get_state_for_events( [membership_event_id], state_filter=state_filter ) room_state = room_state[membership_event_id] @@ -234,6 +238,7 @@ class EventCreationHandler(object): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() + self.storage = hs.get_storage() self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() @@ -687,7 +692,7 @@ class EventCreationHandler(object): try: yield self.auth.check_from_context(room_version, event, context) except AuthError as err: - logger.warn("Denying new event %r because %s", event, err) + logger.warning("Denying new event %r because %s", event, err) raise err # Ensure that we can round trip before trying to persist in db @@ -868,7 +873,7 @@ class EventCreationHandler(object): if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") - (event_stream_id, max_stream_id) = yield self.store.persist_event( + event_stream_id, max_stream_id = yield self.storage.persistence.persist_event( event, context=context ) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 5744f4579d..260a4351ca 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -69,6 +69,8 @@ class PaginationHandler(object): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_store = self.storage.state self.clock = hs.get_clock() self._server_name = hs.hostname @@ -125,7 +127,9 @@ class PaginationHandler(object): self._purges_in_progress_by_room.add(room_id) try: with (yield self.pagination_lock.write(room_id)): - yield self.store.purge_history(room_id, token, delete_local_events) + yield self.storage.purge_events.purge_history( + room_id, token, delete_local_events + ) logger.info("[purge] complete") self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE except Exception: @@ -168,7 +172,7 @@ class PaginationHandler(object): if joined: raise SynapseError(400, "Users are still joined to this room") - await self.store.purge_room(room_id) + await self.storage.purge_events.purge_room(room_id) @defer.inlineCallbacks def get_messages( @@ -210,9 +214,10 @@ class PaginationHandler(object): source_config = pagin_config.get_source_config("room") with (yield self.pagination_lock.read(room_id)): - membership, member_event_id = yield self.auth.check_in_room_or_world_readable( - room_id, user_id - ) + ( + membership, + member_event_id, + ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id) if source_config.direction == "b": # if we're going backwards, we might need to backfill. This @@ -255,7 +260,7 @@ class PaginationHandler(object): events = event_filter.filter(events) events = yield filter_events_for_client( - self.store, user_id, events, is_peeking=(member_event_id is None) + self.storage, user_id, events, is_peeking=(member_event_id is None) ) if not events: @@ -274,7 +279,7 @@ class PaginationHandler(object): (EventTypes.Member, event.sender) for event in events ) - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( events[0].event_id, state_filter=state_filter ) @@ -295,10 +300,8 @@ class PaginationHandler(object): } if state: - chunk["state"] = ( - yield self._event_serializer.serialize_events( - state, time_now, as_client_event=as_client_event - ) + chunk["state"] = yield self._event_serializer.serialize_events( + state, time_now, as_client_event=as_client_event ) return chunk diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 8690f69d45..22e0a04da4 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -275,7 +275,7 @@ class BaseProfileHandler(BaseHandler): ratelimit=False, # Try to hide that these events aren't atomic. ) except Exception as e: - logger.warn( + logger.warning( "Failed to update join event for room %s - %s", room_id, str(e) ) diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index 3e4d8c93a4..e3b528d271 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.util.async_helpers import Linearizer from ._base import BaseHandler @@ -32,8 +30,7 @@ class ReadMarkerHandler(BaseHandler): self.read_marker_linearizer = Linearizer(name="read_marker") self.notifier = hs.get_notifier() - @defer.inlineCallbacks - def received_client_read_marker(self, room_id, user_id, event_id): + async def received_client_read_marker(self, room_id, user_id, event_id): """Updates the read marker for a given user in a given room if the event ID given is ahead in the stream relative to the current read marker. @@ -41,8 +38,8 @@ class ReadMarkerHandler(BaseHandler): the read marker has changed. """ - with (yield self.read_marker_linearizer.queue((room_id, user_id))): - existing_read_marker = yield self.store.get_account_data_for_room_and_type( + with await self.read_marker_linearizer.queue((room_id, user_id)): + existing_read_marker = await self.store.get_account_data_for_room_and_type( user_id, room_id, "m.fully_read" ) @@ -50,13 +47,13 @@ class ReadMarkerHandler(BaseHandler): if existing_read_marker: # Only update if the new marker is ahead in the stream - should_update = yield self.store.is_event_after( + should_update = await self.store.is_event_after( event_id, existing_read_marker["event_id"] ) if should_update: content = {"event_id": event_id} - max_id = yield self.store.add_account_data_to_room( + max_id = await self.store.add_account_data_to_room( user_id, room_id, "m.fully_read", content ) self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 6854c751a6..9283c039e3 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.handlers._base import BaseHandler from synapse.types import ReadReceipt, get_domain_from_id +from synapse.util.async_helpers import maybe_awaitable logger = logging.getLogger(__name__) @@ -36,8 +37,7 @@ class ReceiptsHandler(BaseHandler): self.clock = self.hs.get_clock() self.state = hs.get_state_handler() - @defer.inlineCallbacks - def _received_remote_receipt(self, origin, content): + async def _received_remote_receipt(self, origin, content): """Called when we receive an EDU of type m.receipt from a remote HS. """ receipts = [] @@ -62,17 +62,16 @@ class ReceiptsHandler(BaseHandler): ) ) - yield self._handle_new_receipts(receipts) + await self._handle_new_receipts(receipts) - @defer.inlineCallbacks - def _handle_new_receipts(self, receipts): + async def _handle_new_receipts(self, receipts): """Takes a list of receipts, stores them and informs the notifier. """ min_batch_id = None max_batch_id = None for receipt in receipts: - res = yield self.store.insert_receipt( + res = await self.store.insert_receipt( receipt.room_id, receipt.receipt_type, receipt.user_id, @@ -99,14 +98,15 @@ class ReceiptsHandler(BaseHandler): self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids) # Note that the min here shouldn't be relied upon to be accurate. - yield self.hs.get_pusherpool().on_new_receipts( - min_batch_id, max_batch_id, affected_room_ids + await maybe_awaitable( + self.hs.get_pusherpool().on_new_receipts( + min_batch_id, max_batch_id, affected_room_ids + ) ) return True - @defer.inlineCallbacks - def received_client_receipt(self, room_id, receipt_type, user_id, event_id): + async def received_client_receipt(self, room_id, receipt_type, user_id, event_id): """Called when a client tells us a local user has read up to the given event_id in the room. """ @@ -118,24 +118,11 @@ class ReceiptsHandler(BaseHandler): data={"ts": int(self.clock.time_msec())}, ) - is_new = yield self._handle_new_receipts([receipt]) + is_new = await self._handle_new_receipts([receipt]) if not is_new: return - yield self.federation.send_read_receipt(receipt) - - @defer.inlineCallbacks - def get_receipts_for_room(self, room_id, to_key): - """Gets all receipts for a room, upto the given key. - """ - result = yield self.store.get_linearized_receipts_for_room( - room_id, to_key=to_key - ) - - if not result: - return [] - - return result + await self.federation.send_read_receipt(receipt) class ReceiptEventSource(object): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 53410f120b..235f11c322 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -24,7 +24,6 @@ from synapse.api.errors import ( AuthError, Codes, ConsentNotGivenError, - LimitExceededError, RegistrationError, SynapseError, ) @@ -168,6 +167,7 @@ class RegistrationHandler(BaseHandler): Raises: RegistrationError if there was a problem registering. """ + yield self.check_registration_ratelimit(address) yield self.auth.check_auth_blocking(threepid=threepid) password_hash = None @@ -217,8 +217,13 @@ class RegistrationHandler(BaseHandler): else: # autogen a sequential user ID + fail_count = 0 user = None while not user: + # Fail after being unable to find a suitable ID a few times + if fail_count > 10: + raise SynapseError(500, "Unable to find a suitable guest user ID") + localpart = yield self._generate_user_id() user = UserID(localpart, self.hs.hostname) user_id = user.to_string() @@ -233,10 +238,14 @@ class RegistrationHandler(BaseHandler): create_profile_with_displayname=default_display_name, address=address, ) + + # Successfully registered + break except SynapseError: # if user id is taken, just generate another user = None user_id = None + fail_count += 1 if not self.hs.config.user_consent_at_registration: yield self._auto_join_rooms(user_id) @@ -396,8 +405,8 @@ class RegistrationHandler(BaseHandler): room_id = room_identifier elif RoomAlias.is_valid(room_identifier): room_alias = RoomAlias.from_string(room_identifier) - room_id, remote_room_hosts = ( - yield room_member_handler.lookup_room_alias(room_alias) + room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias( + room_alias ) room_id = room_id.to_string() else: @@ -414,6 +423,29 @@ class RegistrationHandler(BaseHandler): ratelimit=False, ) + def check_registration_ratelimit(self, address): + """A simple helper method to check whether the registration rate limit has been hit + for a given IP address + + Args: + address (str|None): the IP address used to perform the registration. If this is + None, no ratelimiting will be performed. + + Raises: + LimitExceededError: If the rate limit has been exceeded. + """ + if not address: + return + + time_now = self.clock.time() + + self.ratelimiter.ratelimit( + address, + time_now_s=time_now, + rate_hz=self.hs.config.rc_registration.per_second, + burst_count=self.hs.config.rc_registration.burst_count, + ) + def register_with_store( self, user_id, @@ -446,22 +478,6 @@ class RegistrationHandler(BaseHandler): Returns: Deferred """ - # Don't rate limit for app services - if appservice_id is None and address is not None: - time_now = self.clock.time() - - allowed, time_allowed = self.ratelimiter.can_do_action( - address, - time_now_s=time_now, - rate_hz=self.hs.config.rc_registration.per_second, - burst_count=self.hs.config.rc_registration.burst_count, - ) - - if not allowed: - raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)) - ) - if self.hs.config.worker_app: return self._register_client( user_id=user_id, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 2816bd8f87..e92b2eafd5 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -129,6 +129,7 @@ class RoomCreationHandler(BaseHandler): old_room_id, new_version, # args for _upgrade_room ) + return ret @defer.inlineCallbacks @@ -147,21 +148,22 @@ class RoomCreationHandler(BaseHandler): # we create and auth the tombstone event before properly creating the new # room, to check our user has perms in the old room. - tombstone_event, tombstone_context = ( - yield self.event_creation_handler.create_event( - requester, - { - "type": EventTypes.Tombstone, - "state_key": "", - "room_id": old_room_id, - "sender": user_id, - "content": { - "body": "This room has been replaced", - "replacement_room": new_room_id, - }, + ( + tombstone_event, + tombstone_context, + ) = yield self.event_creation_handler.create_event( + requester, + { + "type": EventTypes.Tombstone, + "state_key": "", + "room_id": old_room_id, + "sender": user_id, + "content": { + "body": "This room has been replaced", + "replacement_room": new_room_id, }, - token_id=requester.access_token_id, - ) + }, + token_id=requester.access_token_id, ) old_room_version = yield self.store.get_room_version(old_room_id) yield self.auth.check_from_context( @@ -188,7 +190,12 @@ class RoomCreationHandler(BaseHandler): requester, old_room_id, new_room_id, old_room_state ) - # and finally, shut down the PLs in the old room, and update them in the new + # Copy over user push rules, tags and migrate room directory state + yield self.room_member_handler.transfer_room_state_on_room_upgrade( + old_room_id, new_room_id + ) + + # finally, shut down the PLs in the old room, and update them in the new # room. yield self._update_upgraded_room_pls( requester, old_room_id, new_room_id, old_room_state @@ -822,6 +829,8 @@ class RoomContextHandler(object): def __init__(self, hs): self.hs = hs self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_store = self.storage.state @defer.inlineCallbacks def get_event_context(self, user, room_id, event_id, limit, event_filter): @@ -848,7 +857,7 @@ class RoomContextHandler(object): def filter_evts(events): return filter_events_for_client( - self.store, user.to_string(), events, is_peeking=is_peeking + self.storage, user.to_string(), events, is_peeking=is_peeking ) event = yield self.store.get_event( @@ -890,7 +899,7 @@ class RoomContextHandler(object): # first? Shouldn't we be consistent with /sync? # https://github.com/matrix-org/matrix-doc/issues/687 - state = yield self.store.get_state_for_events( + state = yield self.state_store.get_state_for_events( [last_event_id], state_filter=state_filter ) results["state"] = list(state[last_event_id].values()) @@ -922,7 +931,7 @@ class RoomEventSource(object): from_token = RoomStreamToken.parse(from_key) if from_token.topological: - logger.warn("Stream has topological part!!!! %r", from_key) + logger.warning("Stream has topological part!!!! %r", from_key) from_key = "s%s" % (from_token.stream,) app_service = self.store.get_app_service_by_user_id(user.to_string()) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 380e2fad5e..6cfee4b361 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -203,10 +203,6 @@ class RoomMemberHandler(object): prev_member_event = yield self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: - # Copy over user state if we're joining an upgraded room - yield self.copy_user_state_if_room_upgrade( - room_id, requester.user.to_string() - ) yield self._user_joined_room(target, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: @@ -455,11 +451,6 @@ class RoomMemberHandler(object): requester, remote_room_hosts, room_id, target, content ) - # Copy over user state if this is a join on an remote upgraded room - yield self.copy_user_state_if_room_upgrade( - room_id, requester.user.to_string() - ) - return remote_join_response elif effective_membership_state == Membership.LEAVE: @@ -498,36 +489,81 @@ class RoomMemberHandler(object): return res @defer.inlineCallbacks - def copy_user_state_if_room_upgrade(self, new_room_id, user_id): - """Copy user-specific information when they join a new room if that new room is the + def transfer_room_state_on_room_upgrade(self, old_room_id, room_id): + """Upon our server becoming aware of an upgraded room, either by upgrading a room + ourselves or joining one, we can transfer over information from the previous room. + + Copies user state (tags/push rules) for every local user that was in the old room, as + well as migrating the room directory state. + + Args: + old_room_id (str): The ID of the old room + + room_id (str): The ID of the new room + + Returns: + Deferred + """ + # Find all local users that were in the old room and copy over each user's state + users = yield self.store.get_users_in_room(old_room_id) + yield self.copy_user_state_on_room_upgrade(old_room_id, room_id, users) + + # Add new room to the room directory if the old room was there + # Remove old room from the room directory + old_room = yield self.store.get_room(old_room_id) + if old_room and old_room["is_public"]: + yield self.store.set_room_is_public(old_room_id, False) + yield self.store.set_room_is_public(room_id, True) + + # Check if any groups we own contain the predecessor room + local_group_ids = yield self.store.get_local_groups_for_room(old_room_id) + for group_id in local_group_ids: + # Add new the new room to those groups + yield self.store.add_room_to_group(group_id, room_id, old_room["is_public"]) + + # Remove the old room from those groups + yield self.store.remove_room_from_group(group_id, old_room_id) + + @defer.inlineCallbacks + def copy_user_state_on_room_upgrade(self, old_room_id, new_room_id, user_ids): + """Copy user-specific information when they join a new room when that new room is the result of a room upgrade Args: - new_room_id (str): The ID of the room the user is joining - user_id (str): The ID of the user + old_room_id (str): The ID of upgraded room + new_room_id (str): The ID of the new room + user_ids (Iterable[str]): User IDs to copy state for Returns: Deferred """ - # Check if the new room is an upgraded room - predecessor = yield self.store.get_room_predecessor(new_room_id) - if not predecessor: - return logger.debug( - "Found predecessor for %s: %s. Copying over room tags and push " "rules", + "Copying over room tags and push rules from %s to %s for users %s", + old_room_id, new_room_id, - predecessor, + user_ids, ) - # It is an upgraded room. Copy over old tags - yield self.copy_room_tags_and_direct_to_room( - predecessor["room_id"], new_room_id, user_id - ) - # Copy over push rules - yield self.store.copy_push_rules_from_room_to_room_for_user( - predecessor["room_id"], new_room_id, user_id - ) + for user_id in user_ids: + try: + # It is an upgraded room. Copy over old tags + yield self.copy_room_tags_and_direct_to_room( + old_room_id, new_room_id, user_id + ) + # Copy over push rules + yield self.store.copy_push_rules_from_room_to_room_for_user( + old_room_id, new_room_id, user_id + ) + except Exception: + logger.exception( + "Error copying tags and/or push rules from rooms %s to %s for user %s. " + "Skipping...", + old_room_id, + new_room_id, + user_id, + ) + continue @defer.inlineCallbacks def send_membership_event(self, requester, event, context, ratelimit=True): @@ -759,22 +795,25 @@ class RoomMemberHandler(object): if room_avatar_event: room_avatar_url = room_avatar_event.content.get("url", "") - token, public_keys, fallback_public_key, display_name = ( - yield self.identity_handler.ask_id_server_for_third_party_invite( - requester=requester, - id_server=id_server, - medium=medium, - address=address, - room_id=room_id, - inviter_user_id=user.to_string(), - room_alias=canonical_room_alias, - room_avatar_url=room_avatar_url, - room_join_rules=room_join_rules, - room_name=room_name, - inviter_display_name=inviter_display_name, - inviter_avatar_url=inviter_avatar_url, - id_access_token=id_access_token, - ) + ( + token, + public_keys, + fallback_public_key, + display_name, + ) = yield self.identity_handler.ask_id_server_for_third_party_invite( + requester=requester, + id_server=id_server, + medium=medium, + address=address, + room_id=room_id, + inviter_user_id=user.to_string(), + room_alias=canonical_room_alias, + room_avatar_url=room_avatar_url, + room_join_rules=room_join_rules, + room_name=room_name, + inviter_display_name=inviter_display_name, + inviter_avatar_url=inviter_avatar_url, + id_access_token=id_access_token, ) yield self.event_creation_handler.create_and_send_nonmember_event( diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index cd5e90bacb..56ed262a1f 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -35,6 +35,8 @@ class SearchHandler(BaseHandler): def __init__(self, hs): super(SearchHandler, self).__init__(hs) self._event_serializer = hs.get_event_client_serializer() + self.storage = hs.get_storage() + self.state_store = self.storage.state @defer.inlineCallbacks def get_old_rooms_from_upgraded_room(self, room_id): @@ -221,7 +223,7 @@ class SearchHandler(BaseHandler): filtered_events = search_filter.filter([r["event"] for r in results]) events = yield filter_events_for_client( - self.store, user.to_string(), filtered_events + self.storage, user.to_string(), filtered_events ) events.sort(key=lambda e: -rank_map[e.event_id]) @@ -271,7 +273,7 @@ class SearchHandler(BaseHandler): filtered_events = search_filter.filter([r["event"] for r in results]) events = yield filter_events_for_client( - self.store, user.to_string(), filtered_events + self.storage, user.to_string(), filtered_events ) room_events.extend(events) @@ -340,11 +342,11 @@ class SearchHandler(BaseHandler): ) res["events_before"] = yield filter_events_for_client( - self.store, user.to_string(), res["events_before"] + self.storage, user.to_string(), res["events_before"] ) res["events_after"] = yield filter_events_for_client( - self.store, user.to_string(), res["events_after"] + self.storage, user.to_string(), res["events_after"] ) res["start"] = now_token.copy_and_replace( @@ -372,7 +374,7 @@ class SearchHandler(BaseHandler): [(EventTypes.Member, sender) for sender in senders] ) - state = yield self.store.get_state_for_event( + state = yield self.state_store.get_state_for_event( last_event_id, state_filter ) @@ -394,15 +396,11 @@ class SearchHandler(BaseHandler): time_now = self.clock.time_msec() for context in contexts.values(): - context["events_before"] = ( - yield self._event_serializer.serialize_events( - context["events_before"], time_now - ) + context["events_before"] = yield self._event_serializer.serialize_events( + context["events_before"], time_now ) - context["events_after"] = ( - yield self._event_serializer.serialize_events( - context["events_after"], time_now - ) + context["events_after"] = yield self._event_serializer.serialize_events( + context["events_after"], time_now ) state_results = {} diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 466daf9202..7f7d56390e 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -45,6 +45,8 @@ class StatsHandler(StateDeltasHandler): self.is_mine_id = hs.is_mine_id self.stats_bucket_size = hs.config.stats_bucket_size + self.stats_enabled = hs.config.stats_enabled + # The current position in the current_state_delta stream self.pos = None @@ -61,7 +63,7 @@ class StatsHandler(StateDeltasHandler): def notify_new_event(self): """Called when there may be more deltas to process """ - if not self.hs.config.stats_enabled or self._is_processing: + if not self.stats_enabled or self._is_processing: return self._is_processing = True @@ -106,7 +108,10 @@ class StatsHandler(StateDeltasHandler): user_deltas = {} # Then count deltas for total_events and total_event_bytes. - room_count, user_count = yield self.store.get_changes_room_total_events_and_bytes( + ( + room_count, + user_count, + ) = yield self.store.get_changes_room_total_events_and_bytes( self.pos, max_pos ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index d99160e9d7..b536d410e5 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -230,6 +230,8 @@ class SyncHandler(object): self.response_cache = ResponseCache(hs, "sync") self.state = hs.get_state_handler() self.auth = hs.get_auth() + self.storage = hs.get_storage() + self.state_store = self.storage.state # ExpiringCache((User, Device)) -> LruCache(state_key => event_id) self.lazy_loaded_members_cache = ExpiringCache( @@ -417,7 +419,7 @@ class SyncHandler(object): current_state_ids = frozenset(itervalues(current_state_ids)) recents = yield filter_events_for_client( - self.store, + self.storage, sync_config.user.to_string(), recents, always_include_ids=current_state_ids, @@ -470,7 +472,7 @@ class SyncHandler(object): current_state_ids = frozenset(itervalues(current_state_ids)) loaded_recents = yield filter_events_for_client( - self.store, + self.storage, sync_config.user.to_string(), loaded_recents, always_include_ids=current_state_ids, @@ -509,7 +511,7 @@ class SyncHandler(object): Returns: A Deferred map from ((type, state_key)->Event) """ - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( event.event_id, state_filter=state_filter ) if event.is_state(): @@ -580,7 +582,7 @@ class SyncHandler(object): return None last_event = last_events[-1] - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( last_event.event_id, state_filter=StateFilter.from_types( [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")] @@ -757,11 +759,11 @@ class SyncHandler(object): if full_state: if batch: - current_state_ids = yield self.store.get_state_ids_for_event( + current_state_ids = yield self.state_store.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) @@ -781,7 +783,7 @@ class SyncHandler(object): ) elif batch.limited: if batch: - state_at_timeline_start = yield self.store.get_state_ids_for_event( + state_at_timeline_start = yield self.state_store.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) else: @@ -810,7 +812,7 @@ class SyncHandler(object): ) if batch: - current_state_ids = yield self.store.get_state_ids_for_event( + current_state_ids = yield self.state_store.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) else: @@ -841,7 +843,7 @@ class SyncHandler(object): # So we fish out all the member events corresponding to the # timeline here, and then dedupe any redundant ones below. - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( batch.events[0].event_id, # we only want members! state_filter=StateFilter.from_types( @@ -1204,10 +1206,11 @@ class SyncHandler(object): since_token = sync_result_builder.since_token if since_token and not sync_result_builder.full_state: - account_data, account_data_by_room = ( - yield self.store.get_updated_account_data_for_user( - user_id, since_token.account_data_key - ) + ( + account_data, + account_data_by_room, + ) = yield self.store.get_updated_account_data_for_user( + user_id, since_token.account_data_key ) push_rules_changed = yield self.store.have_push_rules_changed_for_user( @@ -1219,9 +1222,10 @@ class SyncHandler(object): sync_config.user ) else: - account_data, account_data_by_room = ( - yield self.store.get_account_data_for_user(sync_config.user.to_string()) - ) + ( + account_data, + account_data_by_room, + ) = yield self.store.get_account_data_for_user(sync_config.user.to_string()) account_data["m.push_rules"] = yield self.push_rules_for_user( sync_config.user diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 29aa1e5aaf..8363d887a9 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -81,7 +81,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): def __init__(self, hs): super().__init__(hs) self._enabled = bool(hs.config.recaptcha_private_key) - self._http_client = hs.get_simple_http_client() + self._http_client = hs.get_proxied_http_client() self._url = hs.config.recaptcha_siteverify_api self._secret = hs.config.recaptcha_private_key diff --git a/synapse/http/client.py b/synapse/http/client.py index cdf828a4ff..d4c285445e 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -45,6 +45,7 @@ from synapse.http import ( cancelled_to_request_timed_out_error, redact_uri, ) +from synapse.http.proxyagent import ProxyAgent from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag, start_active_span, tags from synapse.util.async_helpers import timeout_deferred @@ -183,7 +184,15 @@ class SimpleHttpClient(object): using HTTP in Matrix """ - def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None): + def __init__( + self, + hs, + treq_args={}, + ip_whitelist=None, + ip_blacklist=None, + http_proxy=None, + https_proxy=None, + ): """ Args: hs (synapse.server.HomeServer) @@ -192,6 +201,8 @@ class SimpleHttpClient(object): we may not request. ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can request if it were otherwise caught in a blacklist. + http_proxy (bytes): proxy server to use for http connections. host[:port] + https_proxy (bytes): proxy server to use for https connections. host[:port] """ self.hs = hs @@ -236,11 +247,13 @@ class SimpleHttpClient(object): # The default context factory in Twisted 14.0.0 (which we require) is # BrowserLikePolicyForHTTPS which will do regular cert validation # 'like a browser' - self.agent = Agent( + self.agent = ProxyAgent( self.reactor, connectTimeout=15, contextFactory=self.hs.get_http_client_context_factory(), pool=pool, + http_proxy=http_proxy, + https_proxy=https_proxy, ) if self._ip_blacklist: @@ -535,7 +548,7 @@ class SimpleHttpClient(object): b"Content-Length" in resp_headers and int(resp_headers[b"Content-Length"][0]) > max_size ): - logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) + logger.warning("Requested URL is too large > %r bytes" % (self.max_size,)) raise SynapseError( 502, "Requested file is too large > %r bytes" % (self.max_size,), @@ -543,7 +556,7 @@ class SimpleHttpClient(object): ) if response.code > 299: - logger.warn("Got %d when downloading %s" % (response.code, url)) + logger.warning("Got %d when downloading %s" % (response.code, url)) raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN) # TODO: if our Content-Type is HTML or something, just read the first diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py new file mode 100644 index 0000000000..be7b2ceb8e --- /dev/null +++ b/synapse/http/connectproxyclient.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from zope.interface import implementer + +from twisted.internet import defer, protocol +from twisted.internet.error import ConnectError +from twisted.internet.interfaces import IStreamClientEndpoint +from twisted.internet.protocol import connectionDone +from twisted.web import http + +logger = logging.getLogger(__name__) + + +class ProxyConnectError(ConnectError): + pass + + +@implementer(IStreamClientEndpoint) +class HTTPConnectProxyEndpoint(object): + """An Endpoint implementation which will send a CONNECT request to an http proxy + + Wraps an existing HostnameEndpoint for the proxy. + + When we get the connect() request from the connection pool (via the TLS wrapper), + we'll first connect to the proxy endpoint with a ProtocolFactory which will make the + CONNECT request. Once that completes, we invoke the protocolFactory which was passed + in. + + Args: + reactor: the Twisted reactor to use for the connection + proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the + proxy + host (bytes): hostname that we want to CONNECT to + port (int): port that we want to connect to + """ + + def __init__(self, reactor, proxy_endpoint, host, port): + self._reactor = reactor + self._proxy_endpoint = proxy_endpoint + self._host = host + self._port = port + + def __repr__(self): + return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,) + + def connect(self, protocolFactory): + f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory) + d = self._proxy_endpoint.connect(f) + # once the tcp socket connects successfully, we need to wait for the + # CONNECT to complete. + d.addCallback(lambda conn: f.on_connection) + return d + + +class HTTPProxiedClientFactory(protocol.ClientFactory): + """ClientFactory wrapper that triggers an HTTP proxy CONNECT on connect. + + Once the CONNECT completes, invokes the original ClientFactory to build the + HTTP Protocol object and run the rest of the connection. + + Args: + dst_host (bytes): hostname that we want to CONNECT to + dst_port (int): port that we want to connect to + wrapped_factory (protocol.ClientFactory): The original Factory + """ + + def __init__(self, dst_host, dst_port, wrapped_factory): + self.dst_host = dst_host + self.dst_port = dst_port + self.wrapped_factory = wrapped_factory + self.on_connection = defer.Deferred() + + def startedConnecting(self, connector): + return self.wrapped_factory.startedConnecting(connector) + + def buildProtocol(self, addr): + wrapped_protocol = self.wrapped_factory.buildProtocol(addr) + + return HTTPConnectProtocol( + self.dst_host, self.dst_port, wrapped_protocol, self.on_connection + ) + + def clientConnectionFailed(self, connector, reason): + logger.debug("Connection to proxy failed: %s", reason) + if not self.on_connection.called: + self.on_connection.errback(reason) + return self.wrapped_factory.clientConnectionFailed(connector, reason) + + def clientConnectionLost(self, connector, reason): + logger.debug("Connection to proxy lost: %s", reason) + if not self.on_connection.called: + self.on_connection.errback(reason) + return self.wrapped_factory.clientConnectionLost(connector, reason) + + +class HTTPConnectProtocol(protocol.Protocol): + """Protocol that wraps an existing Protocol to do a CONNECT handshake at connect + + Args: + host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal + to put in the CONNECT request + + port (int): The original HTTP(s) port to put in the CONNECT request + + wrapped_protocol (interfaces.IProtocol): the original protocol (probably + HTTPChannel or TLSMemoryBIOProtocol, but could be anything really) + + connected_deferred (Deferred): a Deferred which will be callbacked with + wrapped_protocol when the CONNECT completes + """ + + def __init__(self, host, port, wrapped_protocol, connected_deferred): + self.host = host + self.port = port + self.wrapped_protocol = wrapped_protocol + self.connected_deferred = connected_deferred + self.http_setup_client = HTTPConnectSetupClient(self.host, self.port) + self.http_setup_client.on_connected.addCallback(self.proxyConnected) + + def connectionMade(self): + self.http_setup_client.makeConnection(self.transport) + + def connectionLost(self, reason=connectionDone): + if self.wrapped_protocol.connected: + self.wrapped_protocol.connectionLost(reason) + + self.http_setup_client.connectionLost(reason) + + if not self.connected_deferred.called: + self.connected_deferred.errback(reason) + + def proxyConnected(self, _): + self.wrapped_protocol.makeConnection(self.transport) + + self.connected_deferred.callback(self.wrapped_protocol) + + # Get any pending data from the http buf and forward it to the original protocol + buf = self.http_setup_client.clearLineBuffer() + if buf: + self.wrapped_protocol.dataReceived(buf) + + def dataReceived(self, data): + # if we've set up the HTTP protocol, we can send the data there + if self.wrapped_protocol.connected: + return self.wrapped_protocol.dataReceived(data) + + # otherwise, we must still be setting up the connection: send the data to the + # setup client + return self.http_setup_client.dataReceived(data) + + +class HTTPConnectSetupClient(http.HTTPClient): + """HTTPClient protocol to send a CONNECT message for proxies and read the response. + + Args: + host (bytes): The hostname to send in the CONNECT message + port (int): The port to send in the CONNECT message + """ + + def __init__(self, host, port): + self.host = host + self.port = port + self.on_connected = defer.Deferred() + + def connectionMade(self): + logger.debug("Connected to proxy, sending CONNECT") + self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port)) + self.endHeaders() + + def handleStatus(self, version, status, message): + logger.debug("Got Status: %s %s %s", status, message, version) + if status != b"200": + raise ProxyConnectError("Unexpected status on CONNECT: %s" % status) + + def handleEndHeaders(self): + logger.debug("End Headers") + self.on_connected.callback(None) + + def handleResponse(self, body): + pass diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index 3fe4ffb9e5..021b233a7d 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -148,7 +148,7 @@ class SrvResolver(object): # Try something in the cache, else rereaise cache_entry = self._cache.get(service_name, None) if cache_entry: - logger.warn( + logger.warning( "Failed to resolve %r, falling back to cache. %r", service_name, e ) return list(cache_entry) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 3f7c93ffcb..691380abda 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -149,7 +149,7 @@ def _handle_json_response(reactor, timeout_sec, request, response): body = yield make_deferred_yieldable(d) except Exception as e: - logger.warn( + logger.warning( "{%s} [%s] Error reading response: %s", request.txn_id, request.destination, @@ -457,7 +457,7 @@ class MatrixFederationHttpClient(object): except Exception as e: # Eh, we're already going to raise an exception so lets # ignore if this fails. - logger.warn( + logger.warning( "{%s} [%s] Failed to get error response: %s %s: %s", request.txn_id, request.destination, @@ -478,7 +478,7 @@ class MatrixFederationHttpClient(object): break except RequestSendFailed as e: - logger.warn( + logger.warning( "{%s} [%s] Request failed: %s %s: %s", request.txn_id, request.destination, @@ -513,7 +513,7 @@ class MatrixFederationHttpClient(object): raise except Exception as e: - logger.warn( + logger.warning( "{%s} [%s] Request failed: %s %s: %s", request.txn_id, request.destination, @@ -889,7 +889,7 @@ class MatrixFederationHttpClient(object): d.addTimeout(self.default_timeout, self.reactor) length = yield make_deferred_yieldable(d) except Exception as e: - logger.warn( + logger.warning( "{%s} [%s] Error reading response: %s", request.txn_id, request.destination, diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py new file mode 100644 index 0000000000..332da02a8d --- /dev/null +++ b/synapse/http/proxyagent.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import re + +from zope.interface import implementer + +from twisted.internet import defer +from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS +from twisted.python.failure import Failure +from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase +from twisted.web.error import SchemeNotSupported +from twisted.web.iweb import IAgent + +from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint + +logger = logging.getLogger(__name__) + +_VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z") + + +@implementer(IAgent) +class ProxyAgent(_AgentBase): + """An Agent implementation which will use an HTTP proxy if one was requested + + Args: + reactor: twisted reactor to place outgoing + connections. + + contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the + verification parameters of OpenSSL. The default is to use a + `BrowserLikePolicyForHTTPS`, so unless you have special + requirements you can leave this as-is. + + connectTimeout (float): The amount of time that this Agent will wait + for the peer to accept a connection. + + bindAddress (bytes): The local address for client sockets to bind to. + + pool (HTTPConnectionPool|None): connection pool to be used. If None, a + non-persistent pool instance will be created. + """ + + def __init__( + self, + reactor, + contextFactory=BrowserLikePolicyForHTTPS(), + connectTimeout=None, + bindAddress=None, + pool=None, + http_proxy=None, + https_proxy=None, + ): + _AgentBase.__init__(self, reactor, pool) + + self._endpoint_kwargs = {} + if connectTimeout is not None: + self._endpoint_kwargs["timeout"] = connectTimeout + if bindAddress is not None: + self._endpoint_kwargs["bindAddress"] = bindAddress + + self.http_proxy_endpoint = _http_proxy_endpoint( + http_proxy, reactor, **self._endpoint_kwargs + ) + + self.https_proxy_endpoint = _http_proxy_endpoint( + https_proxy, reactor, **self._endpoint_kwargs + ) + + self._policy_for_https = contextFactory + self._reactor = reactor + + def request(self, method, uri, headers=None, bodyProducer=None): + """ + Issue a request to the server indicated by the given uri. + + Supports `http` and `https` schemes. + + An existing connection from the connection pool may be used or a new one may be + created. + + See also: twisted.web.iweb.IAgent.request + + Args: + method (bytes): The request method to use, such as `GET`, `POST`, etc + + uri (bytes): The location of the resource to request. + + headers (Headers|None): Extra headers to send with the request + + bodyProducer (IBodyProducer|None): An object which can generate bytes to + make up the body of this request (for example, the properly encoded + contents of a file for a file upload). Or, None if the request is to + have no body. + + Returns: + Deferred[IResponse]: completes when the header of the response has + been received (regardless of the response status code). + """ + uri = uri.strip() + if not _VALID_URI.match(uri): + raise ValueError("Invalid URI {!r}".format(uri)) + + parsed_uri = URI.fromBytes(uri) + pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port) + request_path = parsed_uri.originForm + + if parsed_uri.scheme == b"http" and self.http_proxy_endpoint: + # Cache *all* connections under the same key, since we are only + # connecting to a single destination, the proxy: + pool_key = ("http-proxy", self.http_proxy_endpoint) + endpoint = self.http_proxy_endpoint + request_path = uri + elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint: + endpoint = HTTPConnectProxyEndpoint( + self._reactor, + self.https_proxy_endpoint, + parsed_uri.host, + parsed_uri.port, + ) + else: + # not using a proxy + endpoint = HostnameEndpoint( + self._reactor, parsed_uri.host, parsed_uri.port, **self._endpoint_kwargs + ) + + logger.debug("Requesting %s via %s", uri, endpoint) + + if parsed_uri.scheme == b"https": + tls_connection_creator = self._policy_for_https.creatorForNetloc( + parsed_uri.host, parsed_uri.port + ) + endpoint = wrapClientTLS(tls_connection_creator, endpoint) + elif parsed_uri.scheme == b"http": + pass + else: + return defer.fail( + Failure( + SchemeNotSupported("Unsupported scheme: %r" % (parsed_uri.scheme,)) + ) + ) + + return self._requestWithEndpoint( + pool_key, endpoint, method, parsed_uri, headers, bodyProducer, request_path + ) + + +def _http_proxy_endpoint(proxy, reactor, **kwargs): + """Parses an http proxy setting and returns an endpoint for the proxy + + Args: + proxy (bytes|None): the proxy setting + reactor: reactor to be used to connect to the proxy + kwargs: other args to be passed to HostnameEndpoint + + Returns: + interfaces.IStreamClientEndpoint|None: endpoint to use to connect to the proxy, + or None + """ + if proxy is None: + return None + + # currently we only support hostname:port. Some apps also support + # protocol://<host>[:port], which allows a way of requiring a TLS connection to the + # proxy. + + host, port = parse_host_port(proxy, default_port=1080) + return HostnameEndpoint(reactor, host, port, **kwargs) + + +def parse_host_port(hostport, default_port=None): + # could have sworn we had one of these somewhere else... + if b":" in hostport: + host, port = hostport.rsplit(b":", 1) + try: + port = int(port) + return host, port + except ValueError: + # the thing after the : wasn't a valid port; presumably this is an + # IPv6 address. + pass + + return hostport, default_port diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py index 46af27c8f6..58f9cc61c8 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py @@ -170,7 +170,7 @@ class RequestMetrics(object): tag = context.tag if context != self.start_context: - logger.warn( + logger.warning( "Context have unexpectedly changed %r, %r", context, self.start_context, diff --git a/synapse/http/server.py b/synapse/http/server.py index 2ccb210fd6..943d12c907 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -454,7 +454,7 @@ def respond_with_json( # the Deferred fires, but since the flag is RIGHT THERE it seems like # a waste. if request._disconnected: - logger.warn( + logger.warning( "Not sending response to request %s, already disconnected.", request ) return diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 274c1a6a87..e9a5e46ced 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -219,13 +219,13 @@ def parse_json_value_from_request(request, allow_empty_body=False): try: content_unicode = content_bytes.decode("utf8") except UnicodeDecodeError: - logger.warn("Unable to decode UTF-8") + logger.warning("Unable to decode UTF-8") raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) try: content = json.loads(content_unicode) except Exception as e: - logger.warn("Unable to parse JSON: %s", e) + logger.warning("Unable to parse JSON: %s", e) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) return content diff --git a/synapse/http/site.py b/synapse/http/site.py index df5274c177..ff8184a3d0 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -199,7 +199,7 @@ class SynapseRequest(Request): # It's useful to log it here so that we can get an idea of when # the client disconnects. with PreserveLoggingContext(self.logcontext): - logger.warn( + logger.warning( "Error processing request %r: %s %s", self, reason.type, reason.value ) @@ -305,7 +305,7 @@ class SynapseRequest(Request): try: self.request_metrics.stop(self.finish_time, self.code, self.sentLength) except Exception as e: - logger.warn("Failed to stop metrics: %r", e) + logger.warning("Failed to stop metrics: %r", e) class XForwardedForRequest(SynapseRequest): diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py index 3220e985a9..334ddaf39a 100644 --- a/synapse/logging/_structured.py +++ b/synapse/logging/_structured.py @@ -185,7 +185,7 @@ DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}} def parse_drain_configs( - drains: dict + drains: dict, ) -> typing.Generator[DrainConfiguration, None, None]: """ Parse the drain configurations. diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 370000e377..2c1fb9ddac 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -294,7 +294,7 @@ class LoggingContext(object): """Enters this logging context into thread local storage""" old_context = self.set_current_context(self) if self.previous_context != old_context: - logger.warn( + logger.warning( "Expected previous context %r, found %r", self.previous_context, old_context, diff --git a/synapse/notifier.py b/synapse/notifier.py index 4e091314e6..af161a81d7 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -159,6 +159,7 @@ class Notifier(object): self.room_to_user_streams = {} self.hs = hs + self.storage = hs.get_storage() self.event_sources = hs.get_event_sources() self.store = hs.get_datastore() self.pending_new_room_events = [] @@ -425,7 +426,10 @@ class Notifier(object): if name == "room": new_events = yield filter_events_for_client( - self.store, user.to_string(), new_events, is_peeking=is_peeking + self.storage, + user.to_string(), + new_events, + is_peeking=is_peeking, ) elif name == "presence": now = self.clock.time_msec() diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 22491f3700..1ba7bcd4d8 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -79,7 +79,7 @@ class BulkPushRuleEvaluator(object): dict of user_id -> push_rules """ room_id = event.room_id - rules_for_room = self._get_rules_for_room(room_id) + rules_for_room = yield self._get_rules_for_room(room_id) rules_by_user = yield rules_for_room.get_rules(event, context) @@ -149,9 +149,10 @@ class BulkPushRuleEvaluator(object): room_members = yield self.store.get_joined_users_from_context(event, context) - (power_levels, sender_power_level) = ( - yield self._get_power_levels_and_sender_level(event, context) - ) + ( + power_levels, + sender_power_level, + ) = yield self._get_power_levels_and_sender_level(event, context) evaluator = PushRuleEvaluatorForEvent( event, len(room_members), sender_power_level, power_levels diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 42e5b0c0a5..8c818a86bf 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -234,14 +234,12 @@ class EmailPusher(object): return self.last_stream_ordering = last_stream_ordering - pusher_still_exists = ( - yield self.store.update_pusher_last_stream_ordering_and_success( - self.app_id, - self.email, - self.user_id, - last_stream_ordering, - self.clock.time_msec(), - ) + pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success( + self.app_id, + self.email, + self.user_id, + last_stream_ordering, + self.clock.time_msec(), ) if not pusher_still_exists: # The pusher has been deleted while we were processing, so diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 6299587808..e994037be6 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -64,6 +64,7 @@ class HttpPusher(object): def __init__(self, hs, pusherdict): self.hs = hs self.store = self.hs.get_datastore() + self.storage = self.hs.get_storage() self.clock = self.hs.get_clock() self.state_handler = self.hs.get_state_handler() self.user_id = pusherdict["user_name"] @@ -102,7 +103,7 @@ class HttpPusher(object): if "url" not in self.data: raise PusherConfigException("'url' required in data for HTTP pusher") self.url = self.data["url"] - self.http_client = hs.get_simple_http_client() + self.http_client = hs.get_proxied_http_client() self.data_minus_url = {} self.data_minus_url.update(self.data) del self.data_minus_url["url"] @@ -210,14 +211,12 @@ class HttpPusher(object): http_push_processed_counter.inc() self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action["stream_ordering"] - pusher_still_exists = ( - yield self.store.update_pusher_last_stream_ordering_and_success( - self.app_id, - self.pushkey, - self.user_id, - self.last_stream_ordering, - self.clock.time_msec(), - ) + pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success( + self.app_id, + self.pushkey, + self.user_id, + self.last_stream_ordering, + self.clock.time_msec(), ) if not pusher_still_exists: # The pusher has been deleted while we were processing, so @@ -246,7 +245,7 @@ class HttpPusher(object): # we really only give up so that if the URL gets # fixed, we don't suddenly deliver a load # of old notifications. - logger.warn( + logger.warning( "Giving up on a notification to user %s, " "pushkey %s", self.user_id, self.pushkey, @@ -299,7 +298,7 @@ class HttpPusher(object): if pk != self.pushkey: # for sanity, we only remove the pushkey if it # was the one we actually sent... - logger.warn( + logger.warning( ("Ignoring rejected pushkey %s because we" " didn't send it"), pk, ) @@ -329,7 +328,7 @@ class HttpPusher(object): return d ctx = yield push_tools.get_context_for_event( - self.store, self.state_handler, event, self.user_id + self.storage, self.state_handler, event, self.user_id ) d = { diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 5b16ab4ae8..1d15a06a58 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -119,6 +119,7 @@ class Mailer(object): self.store = self.hs.get_datastore() self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() + self.storage = hs.get_storage() self.app_name = app_name logger.info("Created Mailer for app_name %s" % app_name) @@ -389,7 +390,7 @@ class Mailer(object): } the_events = yield filter_events_for_client( - self.store, user_id, results["events_before"] + self.storage, user_id, results["events_before"] ) the_events.append(notif_event) diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 5ed9147de4..b1587183a8 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -117,7 +117,7 @@ class PushRuleEvaluatorForEvent(object): pattern = UserID.from_string(user_id).localpart if not pattern: - logger.warn("event_match condition with no pattern") + logger.warning("event_match condition with no pattern") return False # XXX: optimisation: cache our pattern regexps @@ -173,7 +173,7 @@ def _glob_matches(glob, value, word_boundary=False): regex_cache[(glob, word_boundary)] = r return r.search(value) except re.error: - logger.warn("Failed to parse glob to regex: %r", glob) + logger.warning("Failed to parse glob to regex: %r", glob) return False diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index a54051a726..de5c101a58 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -16,6 +16,7 @@ from twisted.internet import defer from synapse.push.presentable_names import calculate_room_name, name_from_member_event +from synapse.storage import Storage @defer.inlineCallbacks @@ -43,22 +44,22 @@ def get_badge_count(store, user_id): @defer.inlineCallbacks -def get_context_for_event(store, state_handler, ev, user_id): +def get_context_for_event(storage: Storage, state_handler, ev, user_id): ctx = {} - room_state_ids = yield store.get_state_ids_for_event(ev.event_id) + room_state_ids = yield storage.state.get_state_ids_for_event(ev.event_id) # we no longer bother setting room_alias, and make room_name the # human-readable name instead, be that m.room.name, an alias or # a list of people in the room name = yield calculate_room_name( - store, room_state_ids, user_id, fallback_to_single_member=False + storage.main, room_state_ids, user_id, fallback_to_single_member=False ) if name: ctx["name"] = name sender_state_event_id = room_state_ids[("m.room.member", ev.sender)] - sender_state_event = yield store.get_event(sender_state_event_id) + sender_state_event = yield storage.main.get_event(sender_state_event_id) ctx["sender_display_name"] = name_from_member_event(sender_state_event) return ctx diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 08e840fdc2..0f6992202d 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -103,9 +103,7 @@ class PusherPool: # create the pusher setting last_stream_ordering to the current maximum # stream ordering in event_push_actions, so it will process # pushes from this point onwards. - last_stream_ordering = ( - yield self.store.get_latest_push_action_stream_ordering() - ) + last_stream_ordering = yield self.store.get_latest_push_action_stream_ordering() yield self.store.add_pusher( user_id=user_id, diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index aa7da1c543..5871feaafd 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -61,7 +61,6 @@ REQUIREMENTS = [ "bcrypt>=3.1.0", "pillow>=4.3.0", "sortedcontainers>=1.4.4", - "psutil>=2.0.0", "pymacaroons>=0.13.0", "msgpack>=0.5.2", "phonenumbers>=8.2.0", diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 03560c1f0e..c8056b0c0c 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -110,14 +110,14 @@ class ReplicationEndpoint(object): return {} @abc.abstractmethod - def _handle_request(self, request, **kwargs): + async def _handle_request(self, request, **kwargs): """Handle incoming request. This is called with the request object and PATH_ARGS. Returns: - Deferred[dict]: A JSON serialisable dict to be used as response - body of request. + tuple[int, dict]: HTTP status code and a JSON serialisable dict + to be used as response body of request. """ pass @@ -180,7 +180,7 @@ class ReplicationEndpoint(object): if e.code != 504 or not cls.RETRY_ON_TIMEOUT: raise - logger.warn("%s request timed out", cls.NAME) + logger.warning("%s request timed out", cls.NAME) # If we timed out we probably don't need to worry about backing # off too much, but lets just wait a little anyway. diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 2f16955954..9af4e7e173 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -82,8 +82,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): return payload - @defer.inlineCallbacks - def _handle_request(self, request): + async def _handle_request(self, request): with Measure(self.clock, "repl_fed_send_events_parse"): content = parse_json_object_from_request(request) @@ -101,15 +100,13 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): EventType = event_type_from_format_version(format_ver) event = EventType(event_dict, internal_metadata, rejected_reason) - context = yield EventContext.deserialize( - self.store, event_payload["context"] - ) + context = EventContext.deserialize(self.store, event_payload["context"]) event_and_contexts.append((event, context)) logger.info("Got %d events from federation", len(event_and_contexts)) - yield self.federation_handler.persist_events_and_notify( + await self.federation_handler.persist_events_and_notify( event_and_contexts, backfilled ) @@ -144,8 +141,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): def _serialize_payload(edu_type, origin, content): return {"origin": origin, "content": content} - @defer.inlineCallbacks - def _handle_request(self, request, edu_type): + async def _handle_request(self, request, edu_type): with Measure(self.clock, "repl_fed_send_edu_parse"): content = parse_json_object_from_request(request) @@ -154,7 +150,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): logger.info("Got %r edu from %s", edu_type, origin) - result = yield self.registry.on_edu(edu_type, origin, edu_content) + result = await self.registry.on_edu(edu_type, origin, edu_content) return 200, result @@ -193,8 +189,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): """ return {"args": args} - @defer.inlineCallbacks - def _handle_request(self, request, query_type): + async def _handle_request(self, request, query_type): with Measure(self.clock, "repl_fed_query_parse"): content = parse_json_object_from_request(request) @@ -202,7 +197,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): logger.info("Got %r query", query_type) - result = yield self.registry.on_query(query_type, args) + result = await self.registry.on_query(query_type, args) return 200, result @@ -234,9 +229,8 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): """ return {} - @defer.inlineCallbacks - def _handle_request(self, request, room_id): - yield self.store.clean_room_for_join(room_id) + async def _handle_request(self, request, room_id): + await self.store.clean_room_for_join(room_id) return 200, {} diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 786f5232b2..798b9d3af5 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint @@ -52,15 +50,14 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): "is_guest": is_guest, } - @defer.inlineCallbacks - def _handle_request(self, request, user_id): + async def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) device_id = content["device_id"] initial_display_name = content["initial_display_name"] is_guest = content["is_guest"] - device_id, access_token = yield self.registration_handler.register_device( + device_id, access_token = await self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest ) diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index b9ce3477ad..cc1f249740 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import Requester, UserID @@ -65,8 +63,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): "content": content, } - @defer.inlineCallbacks - def _handle_request(self, request, room_id, user_id): + async def _handle_request(self, request, room_id, user_id): content = parse_json_object_from_request(request) remote_room_hosts = content["remote_room_hosts"] @@ -79,7 +76,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): logger.info("remote_join: %s into room: %s", user_id, room_id) - yield self.federation_handler.do_invite_join( + await self.federation_handler.do_invite_join( remote_room_hosts, room_id, user_id, event_content ) @@ -123,8 +120,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): "remote_room_hosts": remote_room_hosts, } - @defer.inlineCallbacks - def _handle_request(self, request, room_id, user_id): + async def _handle_request(self, request, room_id, user_id): content = parse_json_object_from_request(request) remote_room_hosts = content["remote_room_hosts"] @@ -137,7 +133,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id) try: - event = yield self.federation_handler.do_remotely_reject_invite( + event = await self.federation_handler.do_remotely_reject_invite( remote_room_hosts, room_id, user_id ) ret = event.get_pdu_json() @@ -148,9 +144,9 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): # The 'except' clause is very broad, but we need to # capture everything from DNS failures upwards # - logger.warn("Failed to reject invite: %s", e) + logger.warning("Failed to reject invite: %s", e) - yield self.store.locally_reject_invite(user_id, room_id) + await self.store.locally_reject_invite(user_id, room_id) ret = {} return 200, ret diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 38260256cf..0c4aca1291 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint @@ -74,11 +72,12 @@ class ReplicationRegisterServlet(ReplicationEndpoint): "address": address, } - @defer.inlineCallbacks - def _handle_request(self, request, user_id): + async def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) - yield self.registration_handler.register_with_store( + self.registration_handler.check_registration_ratelimit(content["address"]) + + await self.registration_handler.register_with_store( user_id=user_id, password_hash=content["password_hash"], was_guest=content["was_guest"], @@ -117,14 +116,13 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): """ return {"auth_result": auth_result, "access_token": access_token} - @defer.inlineCallbacks - def _handle_request(self, request, user_id): + async def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) auth_result = content["auth_result"] access_token = content["access_token"] - yield self.registration_handler.post_registration_actions( + await self.registration_handler.post_registration_actions( user_id=user_id, auth_result=auth_result, access_token=access_token ) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index adb9b2f7f4..9bafd60b14 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -87,8 +87,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): return payload - @defer.inlineCallbacks - def _handle_request(self, request, event_id): + async def _handle_request(self, request, event_id): with Measure(self.clock, "repl_send_event_parse"): content = parse_json_object_from_request(request) @@ -101,7 +100,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): event = EventType(event_dict, internal_metadata, rejected_reason) requester = Requester.deserialize(self.store, content["requester"]) - context = yield EventContext.deserialize(self.store, content["context"]) + context = EventContext.deserialize(self.store, content["context"]) ratelimit = content["ratelimit"] extra_users = [UserID.from_string(u) for u in content["extra_users"]] @@ -113,7 +112,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "Got event to send with ID: %s into room: %s", event.event_id, event.room_id ) - yield self.event_creation_handler.persist_and_notify_client_event( + await self.event_creation_handler.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 182cb2a1d8..456bc005a0 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Dict import six @@ -44,7 +45,14 @@ class BaseSlavedStore(SQLBaseStore): self.hs = hs - def stream_positions(self): + def stream_positions(self) -> Dict[str, int]: + """ + Get the current positions of all the streams this store wants to subscribe to + + Returns: + map from stream name to the most recent update we have for + that stream (ie, the point we want to start replicating from) + """ pos = {} if self._cache_id_gen: pos["caches"] = self._cache_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index 3c44d1d48d..bc2f6a12ae 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -16,8 +16,8 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage.account_data import AccountDataWorkerStore -from synapse.storage.tags import TagsWorkerStore +from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore +from synapse.storage.data_stores.main.tags import TagsWorkerStore class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py index cda12ea70d..a67fbeffb7 100644 --- a/synapse/replication/slave/storage/appservice.py +++ b/synapse/replication/slave/storage/appservice.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.appservice import ( +from synapse.storage.data_stores.main.appservice import ( ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore, ) diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 14ced32333..b4f58cea19 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.client_ips import LAST_SEEN_GRANULARITY +from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 284fd30d89..9fb6c5c6ff 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -15,7 +15,7 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage.deviceinbox import DeviceInboxWorkerStore +from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.stream_change_cache import StreamChangeCache diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index f045e1b937..de50748c30 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -15,8 +15,9 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage.devices import DeviceWorkerStore -from synapse.storage.end_to_end_keys import EndToEndKeyWorkerStore +from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream +from synapse.storage.data_stores.main.devices import DeviceWorkerStore +from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -42,14 +43,22 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto def stream_positions(self): result = super(SlavedDeviceStore, self).stream_positions() - result["device_lists"] = self._device_list_id_gen.get_current_token() + # The user signature stream uses the same stream ID generator as the + # device list stream, so set them both to the device list ID + # generator's current token. + current_token = self._device_list_id_gen.get_current_token() + result[DeviceListsStream.NAME] = current_token + result[UserSignatureStream.NAME] = current_token return result def process_replication_rows(self, stream_name, token, rows): - if stream_name == "device_lists": + if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(token) for row in rows: self._invalidate_caches_for_devices(token, row.user_id, row.destination) + elif stream_name == UserSignatureStream.NAME: + for row in rows: + self._user_signature_stream_cache.entity_has_changed(row.user_id, token) return super(SlavedDeviceStore, self).process_replication_rows( stream_name, token, rows ) diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py index 1d1d48709a..8b9717c46f 100644 --- a/synapse/replication/slave/storage/directory.py +++ b/synapse/replication/slave/storage/directory.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.directory import DirectoryWorkerStore +from synapse.storage.data_stores.main.directory import DirectoryWorkerStore from ._base import BaseSlavedStore diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index ab5937e638..d0a0eaf75b 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -20,15 +20,17 @@ from synapse.replication.tcp.streams.events import ( EventsStreamCurrentStateRow, EventsStreamEventRow, ) -from synapse.storage.event_federation import EventFederationWorkerStore -from synapse.storage.event_push_actions import EventPushActionsWorkerStore -from synapse.storage.events_worker import EventsWorkerStore -from synapse.storage.relations import RelationsWorkerStore -from synapse.storage.roommember import RoomMemberWorkerStore -from synapse.storage.signatures import SignatureWorkerStore -from synapse.storage.state import StateGroupWorkerStore -from synapse.storage.stream import StreamWorkerStore -from synapse.storage.user_erasure_store import UserErasureWorkerStore +from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore +from synapse.storage.data_stores.main.event_push_actions import ( + EventPushActionsWorkerStore, +) +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.data_stores.main.relations import RelationsWorkerStore +from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore +from synapse.storage.data_stores.main.signatures import SignatureWorkerStore +from synapse.storage.data_stores.main.state import StateGroupWorkerStore +from synapse.storage.data_stores.main.stream import StreamWorkerStore +from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py index 456a14cd5c..5c84ebd125 100644 --- a/synapse/replication/slave/storage/filtering.py +++ b/synapse/replication/slave/storage/filtering.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.filtering import FilteringStore +from synapse.storage.data_stores.main.filtering import FilteringStore from ._base import BaseSlavedStore diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py index cc6f7f009f..3def367ae9 100644 --- a/synapse/replication/slave/storage/keys.py +++ b/synapse/replication/slave/storage/keys.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage import KeyStore +from synapse.storage.data_stores.main.keys import KeyStore # KeyStore isn't really safe to use from a worker, but for now we do so and hope that # the races it creates aren't too bad. diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 82d808af4c..747ced0c84 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -14,7 +14,7 @@ # limitations under the License. from synapse.storage import DataStore -from synapse.storage.presence import PresenceStore +from synapse.storage.data_stores.main.presence import PresenceStore from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore, __func__ diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py index 46c28d4171..28c508aad3 100644 --- a/synapse/replication/slave/storage/profile.py +++ b/synapse/replication/slave/storage/profile.py @@ -14,7 +14,7 @@ # limitations under the License. from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.storage.profile import ProfileWorkerStore +from synapse.storage.data_stores.main.profile import ProfileWorkerStore class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore): diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index af7012702e..3655f05e54 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.push_rule import PushRulesWorkerStore +from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore from ._slaved_id_tracker import SlavedIdTracker from .events import SlavedEventStore diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index 8eeb267d61..b4331d0799 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.pusher import PusherWorkerStore +from synapse.storage.data_stores.main.pusher import PusherWorkerStore from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 91afa5a72b..43d823c601 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.receipts import ReceiptsWorkerStore +from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py index 408d91df1c..4b8553e250 100644 --- a/synapse/replication/slave/storage/registration.py +++ b/synapse/replication/slave/storage/registration.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.registration import RegistrationWorkerStore +from synapse.storage.data_stores.main.registration import RegistrationWorkerStore from ._base import BaseSlavedStore diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index f68b3378e3..d9ad386b28 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.room import RoomWorkerStore +from synapse.storage.data_stores.main.room import RoomWorkerStore from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py index 3527beb3c9..ac88e6b8c3 100644 --- a/synapse/replication/slave/storage/transactions.py +++ b/synapse/replication/slave/storage/transactions.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.transactions import TransactionStore +from synapse.storage.data_stores.main.transactions import TransactionStore from ._base import BaseSlavedStore diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index a44ceb00e7..fead78388c 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -16,10 +16,17 @@ """ import logging +from typing import Dict from twisted.internet import defer from twisted.internet.protocol import ReconnectingClientFactory +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.tcp.protocol import ( + AbstractReplicationClientHandler, + ClientReplicationStreamProtocol, +) + from .commands import ( FederationAckCommand, InvalidateCacheCommand, @@ -27,7 +34,6 @@ from .commands import ( UserIpCommand, UserSyncCommand, ) -from .protocol import ClientReplicationStreamProtocol logger = logging.getLogger(__name__) @@ -42,7 +48,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): maxDelay = 30 # Try at least once every N seconds - def __init__(self, hs, client_name, handler): + def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler): self.client_name = client_name self.handler = handler self.server_name = hs.config.server_name @@ -68,13 +74,13 @@ class ReplicationClientFactory(ReconnectingClientFactory): ReconnectingClientFactory.clientConnectionFailed(self, connector, reason) -class ReplicationClientHandler(object): +class ReplicationClientHandler(AbstractReplicationClientHandler): """A base handler that can be passed to the ReplicationClientFactory. By default proxies incoming replication data to the SlaveStore. """ - def __init__(self, store): + def __init__(self, store: BaseSlavedStore): self.store = store # The current connection. None if we are currently (re)connecting @@ -138,11 +144,13 @@ class ReplicationClientHandler(object): if d: d.callback(data) - def get_streams_to_replicate(self): + def get_streams_to_replicate(self) -> Dict[str, int]: """Called when a new connection has been established and we need to subscribe to streams. - Returns a dictionary of stream name to token. + Returns: + map from stream name to the most recent update we have for + that stream (ie, the point we want to start replicating from) """ args = self.store.stream_positions() user_account_data = args.pop("user_account_data", None) @@ -168,7 +176,7 @@ class ReplicationClientHandler(object): if self.connection: self.connection.send_command(cmd) else: - logger.warn("Queuing command as not connected: %r", cmd.NAME) + logger.warning("Queuing command as not connected: %r", cmd.NAME) self.pending_commands.append(cmd) def send_federation_ack(self, token): diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 5ffdf2675d..afaf002fe6 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -48,7 +48,7 @@ indicate which side is sending, these are *not* included on the wire:: > ERROR server stopping * connection closed by server * """ - +import abc import fcntl import logging import struct @@ -65,6 +65,7 @@ from twisted.python.failure import Failure from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util import Clock from synapse.util.stringutils import random_string from .commands import ( @@ -249,7 +250,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): return handler(cmd) def close(self): - logger.warn("[%s] Closing connection", self.id()) + logger.warning("[%s] Closing connection", self.id()) self.time_we_closed = self.clock.time_msec() self.transport.loseConnection() self.on_connection_closed() @@ -558,11 +559,80 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.streamer.lost_connection(self) +class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): + """ + The interface for the handler that should be passed to + ClientReplicationStreamProtocol + """ + + @abc.abstractmethod + def on_rdata(self, stream_name, token, rows): + """Called to handle a batch of replication data with a given stream token. + + Args: + stream_name (str): name of the replication stream for this batch of rows + token (int): stream token for this batch of rows + rows (list): a list of Stream.ROW_TYPE objects as returned by + Stream.parse_row. + + Returns: + Deferred|None + """ + raise NotImplementedError() + + @abc.abstractmethod + def on_position(self, stream_name, token): + """Called when we get new position data.""" + raise NotImplementedError() + + @abc.abstractmethod + def on_sync(self, data): + """Called when get a new SYNC command.""" + raise NotImplementedError() + + @abc.abstractmethod + def get_streams_to_replicate(self): + """Called when a new connection has been established and we need to + subscribe to streams. + + Returns: + map from stream name to the most recent update we have for + that stream (ie, the point we want to start replicating from) + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_currently_syncing_users(self): + """Get the list of currently syncing users (if any). This is called + when a connection has been established and we need to send the + currently syncing users.""" + raise NotImplementedError() + + @abc.abstractmethod + def update_connection(self, connection): + """Called when a connection has been established (or lost with None). + """ + raise NotImplementedError() + + @abc.abstractmethod + def finished_connecting(self): + """Called when we have successfully subscribed and caught up to all + streams we're interested in. + """ + raise NotImplementedError() + + class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS - def __init__(self, client_name, server_name, clock, handler): + def __init__( + self, + client_name: str, + server_name: str, + clock: Clock, + handler: AbstractReplicationClientHandler, + ): BaseReplicationStreamProtocol.__init__(self, clock) self.client_name = client_name diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 634f636dc9..5f52264e84 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -45,5 +45,6 @@ STREAMS_MAP = { _base.TagAccountDataStream, _base.AccountDataStream, _base.GroupServerStream, + _base.UserSignatureStream, ) } diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index f03111c259..9e45429d49 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -95,6 +95,7 @@ GroupsStreamRow = namedtuple( "GroupsStreamRow", ("group_id", "user_id", "type", "content"), # str # str # str # dict ) +UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str class Stream(object): @@ -438,3 +439,20 @@ class GroupServerStream(Stream): self.update_function = store.get_all_groups_changes super(GroupServerStream, self).__init__(hs) + + +class UserSignatureStream(Stream): + """A user has signed their own device with their user-signing key + """ + + NAME = "user_signature" + _LIMITED = False + ROW_TYPE = UserSignatureStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_device_stream_token + self.update_function = store.get_all_user_signature_changes_for_remotes + + super(UserSignatureStream, self).__init__(hs) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 939418ee2b..5c2a2eb593 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -286,7 +286,7 @@ class PurgeHistoryRestServlet(RestServlet): room_id, stream_ordering ) if not r: - logger.warn( + logger.warning( "[purge] purging events not possible: No event found " "(received_ts %i => stream_ordering %i)", ts, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 8414af08cb..24a0ce74f2 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -203,10 +203,11 @@ class LoginRestServlet(RestServlet): address = address.lower() # Check for login providers that support 3pid login types - canonical_user_id, callback_3pid = ( - yield self.auth_handler.check_password_provider_3pid( - medium, address, login_submission["password"] - ) + ( + canonical_user_id, + callback_3pid, + ) = yield self.auth_handler.check_password_provider_3pid( + medium, address, login_submission["password"] ) if canonical_user_id: # Authentication through password provider and 3pid succeeded @@ -221,7 +222,7 @@ class LoginRestServlet(RestServlet): medium, address ) if not user_id: - logger.warn( + logger.warning( "unknown 3pid identifier medium %s, address %r", medium, address ) raise LoginError(403, "", errcode=Codes.FORBIDDEN) @@ -280,8 +281,8 @@ class LoginRestServlet(RestServlet): def do_token_login(self, login_submission): token = login_submission["token"] auth_handler = self.auth_handler - user_id = ( - yield auth_handler.validate_short_term_login_token_and_get_user_id(token) + user_id = yield auth_handler.validate_short_term_login_token_and_get_user_id( + token ) result = yield self._register_device_with_callback(user_id, login_submission) @@ -380,7 +381,7 @@ class CasTicketServlet(RestServlet): self.cas_displayname_attribute = hs.config.cas_displayname_attribute self.cas_required_attributes = hs.config.cas_required_attributes self._sso_auth_handler = SSOAuthHandler(hs) - self._http_client = hs.get_simple_http_client() + self._http_client = hs.get_proxied_http_client() @defer.inlineCallbacks def on_GET(self, request): diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 9c1d41421c..86bbcc0eea 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -21,8 +21,6 @@ from six.moves.urllib import parse as urlparse from canonicaljson import json -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, @@ -85,11 +83,10 @@ class RoomCreateRestServlet(TransactionRestServlet): set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request(request, self.on_POST, request) - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) - info = yield self._room_creation_handler.create_room( + info = await self._room_creation_handler.create_room( requester, self.get_room_config(request) ) @@ -154,15 +151,14 @@ class RoomStateEventRestServlet(TransactionRestServlet): def on_PUT_no_state_key(self, request, room_id, event_type): return self.on_PUT(request, room_id, event_type, "") - @defer.inlineCallbacks - def on_GET(self, request, room_id, event_type, state_key): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id, event_type, state_key): + requester = await self.auth.get_user_by_req(request, allow_guest=True) format = parse_string( request, "format", default="content", allowed_values=["content", "event"] ) msg_handler = self.message_handler - data = yield msg_handler.get_room_data( + data = await msg_handler.get_room_data( user_id=requester.user.to_string(), room_id=room_id, event_type=event_type, @@ -179,9 +175,8 @@ class RoomStateEventRestServlet(TransactionRestServlet): elif format == "content": return 200, data.get_dict()["content"] - @defer.inlineCallbacks - def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): + requester = await self.auth.get_user_by_req(request) if txn_id: set_tag("txn_id", txn_id) @@ -200,7 +195,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): if event_type == EventTypes.Member: membership = content.get("membership", None) - event = yield self.room_member_handler.update_membership( + event = await self.room_member_handler.update_membership( requester, target=UserID.from_string(state_key), room_id=room_id, @@ -208,7 +203,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): content=content, ) else: - event = yield self.event_creation_handler.create_and_send_nonmember_event( + event = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id ) @@ -231,9 +226,8 @@ class RoomSendEventRestServlet(TransactionRestServlet): PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)" register_txn_path(self, PATTERNS, http_server, with_get=True) - @defer.inlineCallbacks - def on_POST(self, request, room_id, event_type, txn_id=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request, room_id, event_type, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) event_dict = { @@ -246,7 +240,7 @@ class RoomSendEventRestServlet(TransactionRestServlet): if b"ts" in request.args and requester.app_service: event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) - event = yield self.event_creation_handler.create_and_send_nonmember_event( + event = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id ) @@ -276,9 +270,8 @@ class JoinRoomAliasServlet(TransactionRestServlet): PATTERNS = "/join/(?P<room_identifier>[^/]*)" register_txn_path(self, PATTERNS, http_server) - @defer.inlineCallbacks - def on_POST(self, request, room_identifier, txn_id=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request, room_identifier, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=True) try: content = parse_json_object_from_request(request) @@ -298,14 +291,14 @@ class JoinRoomAliasServlet(TransactionRestServlet): elif RoomAlias.is_valid(room_identifier): handler = self.room_member_handler room_alias = RoomAlias.from_string(room_identifier) - room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias) + room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias) room_id = room_id.to_string() else: raise SynapseError( 400, "%s was not legal room ID or room alias" % (room_identifier,) ) - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( requester=requester, target=requester.user, room_id=room_id, @@ -335,12 +328,11 @@ class PublicRoomListRestServlet(TransactionRestServlet): self.hs = hs self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): server = parse_string(request, "server", default=None) try: - yield self.auth.get_user_by_req(request, allow_guest=True) + await self.auth.get_user_by_req(request, allow_guest=True) except InvalidClientCredentialsError as e: # Option to allow servers to require auth when accessing # /publicRooms via CS API. This is especially helpful in private @@ -367,19 +359,18 @@ class PublicRoomListRestServlet(TransactionRestServlet): handler = self.hs.get_room_list_handler() if server: - data = yield handler.get_remote_public_room_list( + data = await handler.get_remote_public_room_list( server, limit=limit, since_token=since_token ) else: - data = yield handler.get_local_public_room_list( + data = await handler.get_local_public_room_list( limit=limit, since_token=since_token ) return 200, data - @defer.inlineCallbacks - def on_POST(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request): + await self.auth.get_user_by_req(request, allow_guest=True) server = parse_string(request, "server", default=None) content = parse_json_object_from_request(request) @@ -408,7 +399,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): handler = self.hs.get_room_list_handler() if server: - data = yield handler.get_remote_public_room_list( + data = await handler.get_remote_public_room_list( server, limit=limit, since_token=since_token, @@ -417,7 +408,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): third_party_instance_id=third_party_instance_id, ) else: - data = yield handler.get_local_public_room_list( + data = await handler.get_local_public_room_list( limit=limit, since_token=since_token, search_filter=search_filter, @@ -436,10 +427,9 @@ class RoomMemberListRestServlet(RestServlet): self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): + async def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) handler = self.message_handler # request the state as of a given event, as identified by a stream token, @@ -459,7 +449,7 @@ class RoomMemberListRestServlet(RestServlet): membership = parse_string(request, "membership") not_membership = parse_string(request, "not_membership") - events = yield handler.get_state_events( + events = await handler.get_state_events( room_id=room_id, user_id=requester.user.to_string(), at_token=at_token, @@ -488,11 +478,10 @@ class JoinedRoomMemberListRestServlet(RestServlet): self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request) - users_with_profile = yield self.message_handler.get_joined_members( + users_with_profile = await self.message_handler.get_joined_members( requester, room_id ) @@ -508,9 +497,8 @@ class RoomMessageListRestServlet(RestServlet): self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request(request, default_limit=10) as_client_event = b"raw" not in request.args filter_bytes = parse_string(request, b"filter", encoding=None) @@ -521,7 +509,7 @@ class RoomMessageListRestServlet(RestServlet): as_client_event = False else: event_filter = None - msgs = yield self.pagination_handler.get_messages( + msgs = await self.pagination_handler.get_messages( room_id=room_id, requester=requester, pagin_config=pagination_config, @@ -541,11 +529,10 @@ class RoomStateRestServlet(RestServlet): self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) # Get all the current state for this room - events = yield self.message_handler.get_state_events( + events = await self.message_handler.get_state_events( room_id=room_id, user_id=requester.user.to_string(), is_guest=requester.is_guest, @@ -562,11 +549,10 @@ class RoomInitialSyncRestServlet(RestServlet): self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request(request) - content = yield self.initial_sync_handler.room_initial_sync( + content = await self.initial_sync_handler.room_initial_sync( room_id=room_id, requester=requester, pagin_config=pagination_config ) return 200, content @@ -584,11 +570,10 @@ class RoomEventServlet(RestServlet): self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id, event_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id, event_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) try: - event = yield self.event_handler.get_event( + event = await self.event_handler.get_event( requester.user, room_id, event_id ) except AuthError: @@ -599,7 +584,7 @@ class RoomEventServlet(RestServlet): time_now = self.clock.time_msec() if event: - event = yield self._event_serializer.serialize_event(event, time_now) + event = await self._event_serializer.serialize_event(event, time_now) return 200, event return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) @@ -617,9 +602,8 @@ class RoomEventContextServlet(RestServlet): self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id, event_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id, event_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) limit = parse_integer(request, "limit", default=10) @@ -631,7 +615,7 @@ class RoomEventContextServlet(RestServlet): else: event_filter = None - results = yield self.room_context_handler.get_event_context( + results = await self.room_context_handler.get_event_context( requester.user, room_id, event_id, limit, event_filter ) @@ -639,16 +623,16 @@ class RoomEventContextServlet(RestServlet): raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() - results["events_before"] = yield self._event_serializer.serialize_events( + results["events_before"] = await self._event_serializer.serialize_events( results["events_before"], time_now ) - results["event"] = yield self._event_serializer.serialize_event( + results["event"] = await self._event_serializer.serialize_event( results["event"], time_now ) - results["events_after"] = yield self._event_serializer.serialize_events( + results["events_after"] = await self._event_serializer.serialize_events( results["events_after"], time_now ) - results["state"] = yield self._event_serializer.serialize_events( + results["state"] = await self._event_serializer.serialize_events( results["state"], time_now ) @@ -665,11 +649,10 @@ class RoomForgetRestServlet(TransactionRestServlet): PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget" register_txn_path(self, PATTERNS, http_server) - @defer.inlineCallbacks - def on_POST(self, request, room_id, txn_id=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + async def on_POST(self, request, room_id, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=False) - yield self.room_member_handler.forget(user=requester.user, room_id=room_id) + await self.room_member_handler.forget(user=requester.user, room_id=room_id) return 200, {} @@ -696,9 +679,8 @@ class RoomMembershipRestServlet(TransactionRestServlet): ) register_txn_path(self, PATTERNS, http_server) - @defer.inlineCallbacks - def on_POST(self, request, room_id, membership_action, txn_id=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request, room_id, membership_action, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=True) if requester.is_guest and membership_action not in { Membership.JOIN, @@ -714,7 +696,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): content = {} if membership_action == "invite" and self._has_3pid_invite_keys(content): - yield self.room_member_handler.do_3pid_invite( + await self.room_member_handler.do_3pid_invite( room_id, requester.user, content["medium"], @@ -735,7 +717,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): if "reason" in content and membership_action in ["kick", "ban"]: event_content = {"reason": content["reason"]} - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( requester=requester, target=target, room_id=room_id, @@ -777,12 +759,11 @@ class RoomRedactEventRestServlet(TransactionRestServlet): PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)" register_txn_path(self, PATTERNS, http_server) - @defer.inlineCallbacks - def on_POST(self, request, room_id, event_id, txn_id=None): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, room_id, event_id, txn_id=None): + requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) - event = yield self.event_creation_handler.create_and_send_nonmember_event( + event = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Redaction, @@ -816,29 +797,28 @@ class RoomTypingRestServlet(RestServlet): self.typing_handler = hs.get_typing_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_PUT(self, request, room_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, room_id, user_id): + requester = await self.auth.get_user_by_req(request) room_id = urlparse.unquote(room_id) target_user = UserID.from_string(urlparse.unquote(user_id)) content = parse_json_object_from_request(request) - yield self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time(requester.user) # Limit timeout to stop people from setting silly typing timeouts. timeout = min(content.get("timeout", 30000), 120000) if content["typing"]: - yield self.typing_handler.started_typing( + await self.typing_handler.started_typing( target_user=target_user, auth_user=requester.user, room_id=room_id, timeout=timeout, ) else: - yield self.typing_handler.stopped_typing( + await self.typing_handler.stopped_typing( target_user=target_user, auth_user=requester.user, room_id=room_id ) @@ -853,14 +833,13 @@ class SearchRestServlet(RestServlet): self.handlers = hs.get_handlers() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) batch = parse_string(request, "next_batch") - results = yield self.handlers.search_handler.search( + results = await self.handlers.search_handler.search( requester.user, content, batch ) @@ -875,11 +854,10 @@ class JoinedRoomsRestServlet(RestServlet): self.store = hs.get_datastore() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) - room_ids = yield self.store.get_rooms_for_user(requester.user.to_string()) + room_ids = await self.store.get_rooms_for_user(requester.user.to_string()) return 200, {"joined_rooms": list(room_ids)} diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 80cf7126a0..f26eae794c 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -71,7 +71,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): def on_POST(self, request): if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "User password resets have been disabled due to lack of email config" ) raise SynapseError( @@ -148,7 +148,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self.failure_email_template, = load_jinja2_templates( + (self.failure_email_template,) = load_jinja2_templates( self.config.email_template_dir, [self.config.email_password_reset_template_failure_html], ) @@ -162,7 +162,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): ) if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "Password reset emails have been disabled due to lack of an email config" ) raise SynapseError( @@ -183,7 +183,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): # Perform a 302 redirect if next_link is set if next_link: if next_link.startswith("file:///"): - logger.warn( + logger.warning( "Not redirecting to next_link as it is a local file: address" ) else: @@ -350,7 +350,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): def on_POST(self, request): if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "Adding emails have been disabled due to lack of an email config" ) raise SynapseError( @@ -441,7 +441,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) if not self.hs.config.account_threepid_delegate_msisdn: - logger.warn( + logger.warning( "No upstream msisdn account_threepid_delegate configured on the server to " "handle this request" ) @@ -479,7 +479,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self.failure_email_template, = load_jinja2_templates( + (self.failure_email_template,) = load_jinja2_templates( self.config.email_template_dir, [self.config.email_add_threepid_template_failure_html], ) @@ -488,7 +488,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): def on_GET(self, request): if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "Adding emails have been disabled due to lack of an email config" ) raise SynapseError( @@ -515,7 +515,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): # Perform a 302 redirect if next_link is set if next_link: if next_link.startswith("file:///"): - logger.warn( + logger.warning( "Not redirecting to next_link as it is a local file: address" ) else: diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 151a70d449..341567ae21 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -274,7 +274,57 @@ class SigningKeyUploadServlet(RestServlet): ) result = yield self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) - return (200, result) + return 200, result + + +class SignaturesUploadServlet(RestServlet): + """ + POST /keys/signatures/upload HTTP/1.1 + Content-Type: application/json + + { + "@alice:example.com": { + "<device_id>": { + "user_id": "<user_id>", + "device_id": "<device_id>", + "algorithms": [ + "m.olm.curve25519-aes-sha256", + "m.megolm.v1.aes-sha" + ], + "keys": { + "<algorithm>:<device_id>": "<key_base64>", + }, + "signatures": { + "<signing_user_id>": { + "<algorithm>:<signing_key_base64>": "<signature_base64>>" + } + } + } + } + } + """ + + PATTERNS = client_patterns("/keys/signatures/upload$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(SignaturesUploadServlet, self).__init__() + self.auth = hs.get_auth() + self.e2e_keys_handler = hs.get_e2e_keys_handler() + + @defer.inlineCallbacks + def on_POST(self, request): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + user_id = requester.user.to_string() + body = parse_json_object_from_request(request) + + result = yield self.e2e_keys_handler.upload_signatures_for_device_keys( + user_id, body + ) + return 200, result def register_servlets(hs, http_server): @@ -283,3 +333,4 @@ def register_servlets(hs, http_server): KeyChangesServlet(hs).register(http_server) OneTimeKeyServlet(hs).register(http_server) SigningKeyUploadServlet(hs).register(http_server) + SignaturesUploadServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py index b3bf8567e1..67cbc37312 100644 --- a/synapse/rest/client/v2_alpha/read_marker.py +++ b/synapse/rest/client/v2_alpha/read_marker.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import RestServlet, parse_json_object_from_request from ._base import client_patterns @@ -34,17 +32,16 @@ class ReadMarkerRestServlet(RestServlet): self.read_marker_handler = hs.get_read_marker_handler() self.presence_handler = hs.get_presence_handler() - @defer.inlineCallbacks - def on_POST(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, room_id): + requester = await self.auth.get_user_by_req(request) - yield self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time(requester.user) body = parse_json_object_from_request(request) read_event_id = body.get("m.read", None) if read_event_id: - yield self.receipts_handler.received_client_receipt( + await self.receipts_handler.received_client_receipt( room_id, "m.read", user_id=requester.user.to_string(), @@ -53,7 +50,7 @@ class ReadMarkerRestServlet(RestServlet): read_marker_event_id = body.get("m.fully_read", None) if read_marker_event_id: - yield self.read_marker_handler.received_client_read_marker( + await self.read_marker_handler.received_client_read_marker( room_id, user_id=requester.user.to_string(), event_id=read_marker_event_id, diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 0dab03d227..92555bd4a9 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet @@ -39,16 +37,15 @@ class ReceiptRestServlet(RestServlet): self.receipts_handler = hs.get_receipts_handler() self.presence_handler = hs.get_presence_handler() - @defer.inlineCallbacks - def on_POST(self, request, room_id, receipt_type, event_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, room_id, receipt_type, event_id): + requester = await self.auth.get_user_by_req(request) if receipt_type != "m.read": raise SynapseError(400, "Receipt type must be 'm.read'") - yield self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time(requester.user) - yield self.receipts_handler.received_client_receipt( + await self.receipts_handler.received_client_receipt( room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id ) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 4f24a124a6..91db923814 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -106,7 +106,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): def on_POST(self, request): if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.hs.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "Email registration has been disabled due to lack of email config" ) raise SynapseError( @@ -207,7 +207,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): ) if not self.hs.config.account_threepid_delegate_msisdn: - logger.warn( + logger.warning( "No upstream msisdn account_threepid_delegate configured on the server to " "handle this request" ) @@ -247,13 +247,13 @@ class RegistrationSubmitTokenServlet(RestServlet): self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self.failure_email_template, = load_jinja2_templates( + (self.failure_email_template,) = load_jinja2_templates( self.config.email_template_dir, [self.config.email_registration_template_failure_html], ) if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self.failure_email_template, = load_jinja2_templates( + (self.failure_email_template,) = load_jinja2_templates( self.config.email_template_dir, [self.config.email_registration_template_failure_html], ) @@ -266,7 +266,7 @@ class RegistrationSubmitTokenServlet(RestServlet): ) if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "User registration via email has been disabled due to lack of email config" ) raise SynapseError( @@ -287,7 +287,7 @@ class RegistrationSubmitTokenServlet(RestServlet): # Perform a 302 redirect if next_link is set if next_link: if next_link.startswith("file:///"): - logger.warn( + logger.warning( "Not redirecting to next_link as it is a local file: address" ) else: @@ -480,7 +480,7 @@ class RegisterRestServlet(RestServlet): # a password to work around a client bug where it sent # the 'initial_device_display_name' param alone, wiping out # the original registration params - logger.warn("Ignoring initial_device_display_name without password") + logger.warning("Ignoring initial_device_display_name without password") del body["initial_device_display_name"] session_id = self.auth_handler.get_session_id(body) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index a883c8adda..ccd8b17b23 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -112,9 +112,14 @@ class SyncRestServlet(RestServlet): full_state = parse_boolean(request, "full_state", default=False) logger.debug( - "/sync: user=%r, timeout=%r, since=%r," - " set_presence=%r, filter_id=%r, device_id=%r" - % (user, timeout, since, set_presence, filter_id, device_id) + "/sync: user=%r, timeout=%r, since=%r, " + "set_presence=%r, filter_id=%r, device_id=%r", + user, + timeout, + since, + set_presence, + filter_id, + device_id, ) request_key = (user, timeout, since, filter_id, full_state, device_id) @@ -389,7 +394,7 @@ class SyncRestServlet(RestServlet): # We've had bug reports that events were coming down under the # wrong room. if event.room_id != room.room_id: - logger.warn( + logger.warning( "Event %r is under room %r instead of %r", event.event_id, room.room_id, diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 1044ae7b4e..bb30ce3f34 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -65,6 +65,9 @@ class VersionsRestServlet(RestServlet): "m.require_identity_server": False, # as per MSC2290 "m.separate_add_and_bind": True, + # Implements support for label-based filtering as described in + # MSC2326. + "org.matrix.label_based_filtering": True, }, }, ) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 55580bc59e..e7fc3f0431 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -102,7 +102,7 @@ class RemoteKey(DirectServeResource): @wrap_json_request_handler async def _async_render_GET(self, request): if len(request.postpath) == 1: - server, = request.postpath + (server,) = request.postpath query = {server.decode("ascii"): {}} elif len(request.postpath) == 2: server, key_id = request.postpath diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index b972e152a9..bd9186fe50 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -363,7 +363,7 @@ class MediaRepository(object): }, ) except RequestSendFailed as e: - logger.warn( + logger.warning( "Request failed fetching remote media %s/%s: %r", server_name, media_id, @@ -372,7 +372,7 @@ class MediaRepository(object): raise SynapseError(502, "Failed to fetch remote media") except HttpResponseException as e: - logger.warn( + logger.warning( "HTTP error fetching remote media %s/%s: %s", server_name, media_id, @@ -383,10 +383,12 @@ class MediaRepository(object): raise SynapseError(502, "Failed to fetch remote media") except SynapseError: - logger.warn("Failed to fetch remote media %s/%s", server_name, media_id) + logger.warning( + "Failed to fetch remote media %s/%s", server_name, media_id + ) raise except NotRetryingDestination: - logger.warn("Not retrying destination %r", server_name) + logger.warning("Not retrying destination %r", server_name) raise SynapseError(502, "Failed to fetch remote media") except Exception: logger.exception( @@ -691,7 +693,7 @@ class MediaRepository(object): try: os.remove(full_path) except OSError as e: - logger.warn("Failed to remove file: %r", full_path) + logger.warning("Failed to remove file: %r", full_path) if e.errno == errno.ENOENT: pass else: diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 0c68c3aad5..15c15a12f5 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -56,6 +56,9 @@ logger = logging.getLogger(__name__) _charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I) _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) +OG_TAG_NAME_MAXLEN = 50 +OG_TAG_VALUE_MAXLEN = 1000 + class PreviewUrlResource(DirectServeResource): isLeaf = True @@ -74,6 +77,8 @@ class PreviewUrlResource(DirectServeResource): treq_args={"browser_like_redirects": True}, ip_whitelist=hs.config.url_preview_ip_range_whitelist, ip_blacklist=hs.config.url_preview_ip_range_blacklist, + http_proxy=os.getenv("http_proxy"), + https_proxy=os.getenv("HTTPS_PROXY"), ) self.media_repo = media_repo self.primary_base_path = media_repo.primary_base_path @@ -117,8 +122,10 @@ class PreviewUrlResource(DirectServeResource): pattern = entry[attrib] value = getattr(url_tuple, attrib) logger.debug( - ("Matching attrib '%s' with value '%s' against" " pattern '%s'") - % (attrib, value, pattern) + "Matching attrib '%s' with value '%s' against" " pattern '%s'", + attrib, + value, + pattern, ) if value is None: @@ -134,7 +141,7 @@ class PreviewUrlResource(DirectServeResource): match = False continue if match: - logger.warn("URL %s blocked by url_blacklist entry %s", url, entry) + logger.warning("URL %s blocked by url_blacklist entry %s", url, entry) raise SynapseError( 403, "URL blocked by url pattern blacklist entry", Codes.UNKNOWN ) @@ -167,7 +174,7 @@ class PreviewUrlResource(DirectServeResource): ts (int): Returns: - Deferred[str]: json-encoded og data + Deferred[bytes]: json-encoded og data """ # check the URL cache in the DB (which will also provide us with # historical previews, if we have any) @@ -186,7 +193,7 @@ class PreviewUrlResource(DirectServeResource): media_info = yield self._download_url(url, user) - logger.debug("got media_info of '%s'" % media_info) + logger.debug("got media_info of '%s'", media_info) if _is_media(media_info["media_type"]): file_id = media_info["filesystem_id"] @@ -206,7 +213,7 @@ class PreviewUrlResource(DirectServeResource): og["og:image:width"] = dims["width"] og["og:image:height"] = dims["height"] else: - logger.warn("Couldn't get dims for %s" % url) + logger.warning("Couldn't get dims for %s" % url) # define our OG response for this media elif _is_html(media_info["media_type"]): @@ -254,7 +261,7 @@ class PreviewUrlResource(DirectServeResource): og["og:image:width"] = dims["width"] og["og:image:height"] = dims["height"] else: - logger.warn("Couldn't get dims for %s" % og["og:image"]) + logger.warning("Couldn't get dims for %s", og["og:image"]) og["og:image"] = "mxc://%s/%s" % ( self.server_name, @@ -265,10 +272,22 @@ class PreviewUrlResource(DirectServeResource): else: del og["og:image"] else: - logger.warn("Failed to find any OG data in %s", url) + logger.warning("Failed to find any OG data in %s", url) og = {} - logger.debug("Calculated OG for %s as %s" % (url, og)) + # filter out any stupidly long values + keys_to_remove = [] + for k, v in og.items(): + # values can be numeric as well as strings, hence the cast to str + if len(k) > OG_TAG_NAME_MAXLEN or len(str(v)) > OG_TAG_VALUE_MAXLEN: + logger.warning( + "Pruning overlong tag %s from OG data", k[:OG_TAG_NAME_MAXLEN] + ) + keys_to_remove.append(k) + for k in keys_to_remove: + del og[k] + + logger.debug("Calculated OG for %s as %s", url, og) jsonog = json.dumps(og) @@ -297,7 +316,7 @@ class PreviewUrlResource(DirectServeResource): with self.media_storage.store_into_file(file_info) as (f, fname, finish): try: - logger.debug("Trying to get url '%s'" % url) + logger.debug("Trying to get url '%s'", url) length, headers, uri, code = yield self.client.get_file( url, output_stream=f, max_size=self.max_spider_size ) @@ -317,7 +336,7 @@ class PreviewUrlResource(DirectServeResource): ) except Exception as e: # FIXME: pass through 404s and other error messages nicely - logger.warn("Error downloading %s: %r", url, e) + logger.warning("Error downloading %s: %r", url, e) raise SynapseError( 500, @@ -398,7 +417,7 @@ class PreviewUrlResource(DirectServeResource): except OSError as e: # If the path doesn't exist, meh if e.errno != errno.ENOENT: - logger.warn("Failed to remove media: %r: %s", media_id, e) + logger.warning("Failed to remove media: %r: %s", media_id, e) continue removed_media.append(media_id) @@ -430,7 +449,7 @@ class PreviewUrlResource(DirectServeResource): except OSError as e: # If the path doesn't exist, meh if e.errno != errno.ENOENT: - logger.warn("Failed to remove media: %r: %s", media_id, e) + logger.warning("Failed to remove media: %r: %s", media_id, e) continue try: @@ -446,7 +465,7 @@ class PreviewUrlResource(DirectServeResource): except OSError as e: # If the path doesn't exist, meh if e.errno != errno.ENOENT: - logger.warn("Failed to remove media: %r: %s", media_id, e) + logger.warning("Failed to remove media: %r: %s", media_id, e) continue removed_media.append(media_id) @@ -502,6 +521,10 @@ def _calc_og(tree, media_uri): og = {} for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): if "content" in tag.attrib: + # if we've got more than 50 tags, someone is taking the piss + if len(og) >= 50: + logger.warning("Skipping OG for page with too many 'og:' tags") + return {} og[tag.attrib["property"]] = tag.attrib["content"] # TODO: grab article: meta tags too, e.g.: diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 08329884ac..931ce79be8 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -182,7 +182,7 @@ class ThumbnailResource(DirectServeResource): if file_path: yield respond_with_file(request, desired_type, file_path) else: - logger.warn("Failed to generate thumbnail") + logger.warning("Failed to generate thumbnail") respond_404(request) @defer.inlineCallbacks @@ -245,7 +245,7 @@ class ThumbnailResource(DirectServeResource): if file_path: yield respond_with_file(request, desired_type, file_path) else: - logger.warn("Failed to generate thumbnail") + logger.warning("Failed to generate thumbnail") respond_404(request) @defer.inlineCallbacks diff --git a/synapse/server.py b/synapse/server.py index 1fcc7375d3..90c3b072e8 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -23,6 +23,7 @@ # Imports required for the default HomeServer() implementation import abc import logging +import os from twisted.enterprise import adbapi from twisted.mail.smtp import sendmail @@ -95,6 +96,7 @@ from synapse.server_notices.worker_server_notices_sender import ( WorkerServerNoticesSender, ) from synapse.state import StateHandler, StateResolutionHandler +from synapse.storage import DataStores, Storage from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor @@ -167,6 +169,7 @@ class HomeServer(object): "filtering", "http_client_context_factory", "simple_http_client", + "proxied_http_client", "media_repository", "media_repository_resource", "federation_transport_client", @@ -196,6 +199,7 @@ class HomeServer(object): "account_validity_handler", "saml_handler", "event_client_serializer", + "storage", ] REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] @@ -217,6 +221,7 @@ class HomeServer(object): self.hostname = hostname self._building = {} self._listening_services = [] + self.start_time = None self.clock = Clock(reactor) self.distributor = Distributor() @@ -224,7 +229,7 @@ class HomeServer(object): self.admin_redaction_ratelimiter = Ratelimiter() self.registration_ratelimiter = Ratelimiter() - self.datastore = None + self.datastores = None # Other kwargs are explicit dependencies for depname in kwargs: @@ -233,8 +238,10 @@ class HomeServer(object): def setup(self): logger.info("Setting up.") with self.get_db_conn() as conn: - self.datastore = self.DATASTORE_CLASS(conn, self) + datastore = self.DATASTORE_CLASS(conn, self) + self.datastores = DataStores(datastore, conn, self) conn.commit() + self.start_time = int(self.get_clock().time()) logger.info("Finished setting up.") def setup_master(self): @@ -266,7 +273,7 @@ class HomeServer(object): return self.clock def get_datastore(self): - return self.datastore + return self.datastores.main def get_config(self): return self.config @@ -308,6 +315,13 @@ class HomeServer(object): def build_simple_http_client(self): return SimpleHttpClient(self) + def build_proxied_http_client(self): + return SimpleHttpClient( + self, + http_proxy=os.getenv("http_proxy"), + https_proxy=os.getenv("HTTPS_PROXY"), + ) + def build_room_creation_handler(self): return RoomCreationHandler(self) @@ -537,6 +551,9 @@ class HomeServer(object): def build_event_client_serializer(self): return EventClientSerializer(self) + def build_storage(self) -> Storage: + return Storage(self, self.datastores) + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/server.pyi b/synapse/server.pyi index 16f8f6b573..b5e0b57095 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -12,6 +12,7 @@ import synapse.handlers.message import synapse.handlers.room import synapse.handlers.room_member import synapse.handlers.set_password +import synapse.http.client import synapse.rest.media.v1.media_repository import synapse.server_notices.server_notices_manager import synapse.server_notices.server_notices_sender @@ -38,8 +39,16 @@ class HomeServer(object): pass def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler: pass + def get_simple_http_client(self) -> synapse.http.client.SimpleHttpClient: + """Fetch an HTTP client implementation which doesn't do any blacklisting + or support any HTTP_PROXY settings""" + pass + def get_proxied_http_client(self) -> synapse.http.client.SimpleHttpClient: + """Fetch an HTTP client implementation which doesn't do any blacklisting + but does support HTTP_PROXY settings""" + pass def get_deactivate_account_handler( - self + self, ) -> synapse.handlers.deactivate_account.DeactivateAccountHandler: pass def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler: @@ -47,32 +56,32 @@ class HomeServer(object): def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler: pass def get_event_creation_handler( - self + self, ) -> synapse.handlers.message.EventCreationHandler: pass def get_set_password_handler( - self + self, ) -> synapse.handlers.set_password.SetPasswordHandler: pass def get_federation_sender(self) -> synapse.federation.sender.FederationSender: pass def get_federation_transport_client( - self + self, ) -> synapse.federation.transport.client.TransportLayerClient: pass def get_media_repository_resource( - self + self, ) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource: pass def get_media_repository( - self + self, ) -> synapse.rest.media.v1.media_repository.MediaRepository: pass def get_server_notices_manager( - self + self, ) -> synapse.server_notices.server_notices_manager.ServerNoticesManager: pass def get_server_notices_sender( - self + self, ) -> synapse.server_notices.server_notices_sender.ServerNoticesSender: pass diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 81c4aff496..9fae2e0afe 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -20,6 +20,7 @@ from twisted.internet import defer from synapse.api.constants import ( EventTypes, + LimitBlockingTypes, ServerNoticeLimitReached, ServerNoticeMsgType, ) @@ -70,7 +71,7 @@ class ResourceLimitsServerNotices(object): return if not self._server_notices_manager.is_enabled(): - # Don't try and send server notices unles they've been enabled + # Don't try and send server notices unless they've been enabled return timestamp = yield self._store.user_last_seen_monthly_active(user_id) @@ -79,60 +80,93 @@ class ResourceLimitsServerNotices(object): # In practice, not sure we can ever get here return - # Determine current state of room - room_id = yield self._server_notices_manager.get_notice_room_for_user(user_id) if not room_id: - logger.warn("Failed to get server notices room") + logger.warning("Failed to get server notices room") return yield self._check_and_set_tags(user_id, room_id) + + # Determine current state of room currently_blocked, ref_events = yield self._is_room_currently_blocked(room_id) + limit_msg = None + limit_type = None try: - # Normally should always pass in user_id if you have it, but in - # this case are checking what would happen to other users if they - # were to arrive. - try: - yield self._auth.check_auth_blocking() - is_auth_blocking = False - except ResourceLimitError as e: - is_auth_blocking = True - event_content = e.msg - event_limit_type = e.limit_type - - if currently_blocked and not is_auth_blocking: - # Room is notifying of a block, when it ought not to be. - # Remove block notification - content = {"pinned": ref_events} - yield self._server_notices_manager.send_notice( - user_id, content, EventTypes.Pinned, "" - ) + # Normally should always pass in user_id to check_auth_blocking + # if you have it, but in this case are checking what would happen + # to other users if they were to arrive. + yield self._auth.check_auth_blocking() + except ResourceLimitError as e: + limit_msg = e.msg + limit_type = e.limit_type - elif not currently_blocked and is_auth_blocking: + try: + if ( + limit_type == LimitBlockingTypes.MONTHLY_ACTIVE_USER + and not self._config.mau_limit_alerting + ): + # We have hit the MAU limit, but MAU alerting is disabled: + # reset room if necessary and return + if currently_blocked: + self._remove_limit_block_notification(user_id, ref_events) + return + + if currently_blocked and not limit_msg: + # Room is notifying of a block, when it ought not to be. + yield self._remove_limit_block_notification(user_id, ref_events) + elif not currently_blocked and limit_msg: # Room is not notifying of a block, when it ought to be. - # Add block notification - content = { - "body": event_content, - "msgtype": ServerNoticeMsgType, - "server_notice_type": ServerNoticeLimitReached, - "admin_contact": self._config.admin_contact, - "limit_type": event_limit_type, - } - event = yield self._server_notices_manager.send_notice( - user_id, content, EventTypes.Message + yield self._apply_limit_block_notification( + user_id, limit_msg, limit_type ) - - content = {"pinned": [event.event_id]} - yield self._server_notices_manager.send_notice( - user_id, content, EventTypes.Pinned, "" - ) - except SynapseError as e: logger.error("Error sending resource limits server notice: %s", e) @defer.inlineCallbacks + def _remove_limit_block_notification(self, user_id, ref_events): + """Utility method to remove limit block notifications from the server + notices room. + + Args: + user_id (str): user to notify + ref_events (list[str]): The event_ids of pinned events that are unrelated to + limit blocking and need to be preserved. + """ + content = {"pinned": ref_events} + yield self._server_notices_manager.send_notice( + user_id, content, EventTypes.Pinned, "" + ) + + @defer.inlineCallbacks + def _apply_limit_block_notification(self, user_id, event_body, event_limit_type): + """Utility method to apply limit block notifications in the server + notices room. + + Args: + user_id (str): user to notify + event_body(str): The human readable text that describes the block. + event_limit_type(str): Specifies the type of block e.g. monthly active user + limit has been exceeded. + """ + content = { + "body": event_body, + "msgtype": ServerNoticeMsgType, + "server_notice_type": ServerNoticeLimitReached, + "admin_contact": self._config.admin_contact, + "limit_type": event_limit_type, + } + event = yield self._server_notices_manager.send_notice( + user_id, content, EventTypes.Message + ) + + content = {"pinned": [event.event_id]} + yield self._server_notices_manager.send_notice( + user_id, content, EventTypes.Pinned, "" + ) + + @defer.inlineCallbacks def _check_and_set_tags(self, user_id, room_id): """ Since server notices rooms were originally not with tags, diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py new file mode 100644 index 0000000000..efcc10f808 --- /dev/null +++ b/synapse/spam_checker_api/__init__.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.internet import defer + +from synapse.storage.state import StateFilter + +logger = logging.getLogger(__name__) + + +class SpamCheckerApi(object): + """A proxy object that gets passed to spam checkers so they can get + access to rooms and other relevant information. + """ + + def __init__(self, hs): + self.hs = hs + + self._store = hs.get_datastore() + + @defer.inlineCallbacks + def get_state_events_in_room(self, room_id, types): + """Gets state events for the given room. + + Args: + room_id (string): The room ID to get state events in. + types (tuple): The event type and state key (using None + to represent 'any') of the room state to acquire. + + Returns: + twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]: + The filtered state events in the room. + """ + state_ids = yield self._store.get_filtered_current_state_ids( + room_id=room_id, state_filter=StateFilter.from_types(types) + ) + state = yield self._store.get_events(state_ids.values()) + return state.values() diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index dc9f5a9008..139beef8ed 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -16,6 +16,7 @@ import logging from collections import namedtuple +from typing import Iterable, Optional from six import iteritems, itervalues @@ -27,6 +28,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions +from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.logging.utils import log_function from synapse.state import v1, v2 @@ -103,6 +105,7 @@ class StateHandler(object): def __init__(self, hs): self.clock = hs.get_clock() self.store = hs.get_datastore() + self.state_store = hs.get_storage().state self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() @@ -211,15 +214,17 @@ class StateHandler(object): return joined_hosts @defer.inlineCallbacks - def compute_event_context(self, event, old_state=None): + def compute_event_context( + self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None + ): """Build an EventContext structure for the event. This works out what the current state should be for the event, and generates a new state group if necessary. Args: - event (synapse.events.EventBase): - old_state (dict|None): The state at the event if it can't be + event: + old_state: The state at the event if it can't be calculated from existing events. This is normally only specified when receiving an event from federation where we don't have the prev events for, e.g. when backfilling. @@ -231,6 +236,9 @@ class StateHandler(object): # If this is an outlier, then we know it shouldn't have any current # state. Certainly store.get_current_state won't return any, and # persisting the event won't store the state group. + + # FIXME: why do we populate current_state_ids? I thought the point was + # that we weren't supposed to have any state for outliers? if old_state: prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state} if event.is_state(): @@ -247,113 +255,103 @@ class StateHandler(object): # group for it. context = EventContext.with_state( state_group=None, + state_group_before_event=None, current_state_ids=current_state_ids, prev_state_ids=prev_state_ids, ) return context + # + # first of all, figure out the state before the event + # + if old_state: - # We already have the state, so we don't need to calculate it. - # Let's just correctly fill out the context and create a - # new state group for it. - - prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state} - - if event.is_state(): - key = (event.type, event.state_key) - if key in prev_state_ids: - replaces = prev_state_ids[key] - if replaces != event.event_id: # Paranoia check - event.unsigned["replaces_state"] = replaces - current_state_ids = dict(prev_state_ids) - current_state_ids[key] = event.event_id - else: - current_state_ids = prev_state_ids + # if we're given the state before the event, then we use that + state_ids_before_event = { + (s.type, s.state_key): s.event_id for s in old_state + } + state_group_before_event = None + state_group_before_event_prev_group = None + deltas_to_state_group_before_event = None - state_group = yield self.store.store_state_group( - event.event_id, - event.room_id, - prev_group=None, - delta_ids=None, - current_state_ids=current_state_ids, - ) + else: + # otherwise, we'll need to resolve the state across the prev_events. + logger.debug("calling resolve_state_groups from compute_event_context") - context = EventContext.with_state( - state_group=state_group, - current_state_ids=current_state_ids, - prev_state_ids=prev_state_ids, + entry = yield self.resolve_state_groups_for_events( + event.room_id, event.prev_event_ids() ) - return context + state_ids_before_event = entry.state + state_group_before_event = entry.state_group + state_group_before_event_prev_group = entry.prev_group + deltas_to_state_group_before_event = entry.delta_ids - logger.debug("calling resolve_state_groups from compute_event_context") + # + # make sure that we have a state group at that point. If it's not a state event, + # that will be the state group for the new event. If it *is* a state event, + # it might get rejected (in which case we'll need to persist it with the + # previous state group) + # - entry = yield self.resolve_state_groups_for_events( - event.room_id, event.prev_event_ids() - ) + if not state_group_before_event: + state_group_before_event = yield self.state_store.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event_prev_group, + delta_ids=deltas_to_state_group_before_event, + current_state_ids=state_ids_before_event, + ) - prev_state_ids = entry.state - prev_group = None - delta_ids = None + # XXX: can we update the state cache entry for the new state group? or + # could we set a flag on resolve_state_groups_for_events to tell it to + # always make a state group? + + # + # now if it's not a state event, we're done + # + + if not event.is_state(): + return EventContext.with_state( + state_group_before_event=state_group_before_event, + state_group=state_group_before_event, + current_state_ids=state_ids_before_event, + prev_state_ids=state_ids_before_event, + prev_group=state_group_before_event_prev_group, + delta_ids=deltas_to_state_group_before_event, + ) - if event.is_state(): - # If this is a state event then we need to create a new state - # group for the state after this event. + # + # otherwise, we'll need to create a new state group for after the event + # - key = (event.type, event.state_key) - if key in prev_state_ids: - replaces = prev_state_ids[key] + key = (event.type, event.state_key) + if key in state_ids_before_event: + replaces = state_ids_before_event[key] + if replaces != event.event_id: event.unsigned["replaces_state"] = replaces - current_state_ids = dict(prev_state_ids) - current_state_ids[key] = event.event_id - - if entry.state_group: - # If the state at the event has a state group assigned then - # we can use that as the prev group - prev_group = entry.state_group - delta_ids = {key: event.event_id} - elif entry.prev_group: - # If the state at the event only has a prev group, then we can - # use that as a prev group too. - prev_group = entry.prev_group - delta_ids = dict(entry.delta_ids) - delta_ids[key] = event.event_id - - state_group = yield self.store.store_state_group( - event.event_id, - event.room_id, - prev_group=prev_group, - delta_ids=delta_ids, - current_state_ids=current_state_ids, - ) - else: - current_state_ids = prev_state_ids - prev_group = entry.prev_group - delta_ids = entry.delta_ids - - if entry.state_group is None: - entry.state_group = yield self.store.store_state_group( - event.event_id, - event.room_id, - prev_group=entry.prev_group, - delta_ids=entry.delta_ids, - current_state_ids=current_state_ids, - ) - entry.state_id = entry.state_group - - state_group = entry.state_group - - context = EventContext.with_state( - state_group=state_group, - current_state_ids=current_state_ids, - prev_state_ids=prev_state_ids, - prev_group=prev_group, + state_ids_after_event = dict(state_ids_before_event) + state_ids_after_event[key] = event.event_id + delta_ids = {key: event.event_id} + + state_group_after_event = yield self.state_store.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event, delta_ids=delta_ids, + current_state_ids=state_ids_after_event, ) - return context + return EventContext.with_state( + state_group=state_group_after_event, + state_group_before_event=state_group_before_event, + current_state_ids=state_ids_after_event, + prev_state_ids=state_ids_before_event, + prev_group=state_group_before_event, + delta_ids=delta_ids, + ) @measure_func() @defer.inlineCallbacks @@ -376,14 +374,16 @@ class StateHandler(object): # map from state group id to the state in that state group (where # 'state' is a map from state key to event id) # dict[int, dict[(str, str), str]] - state_groups_ids = yield self.store.get_state_groups_ids(room_id, event_ids) + state_groups_ids = yield self.state_store.get_state_groups_ids( + room_id, event_ids + ) if len(state_groups_ids) == 0: return _StateCacheEntry(state={}, state_group=None) elif len(state_groups_ids) == 1: name, state_list = list(state_groups_ids.items()).pop() - prev_group, delta_ids = yield self.store.get_state_group_delta(name) + prev_group, delta_ids = yield self.state_store.get_state_group_delta(name) return _StateCacheEntry( state=state_list, diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index e9a9c2cd8d..0460fe8cc9 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -14,515 +14,41 @@ # See the License for the specific language governing permissions and # limitations under the License. -import calendar -import logging -import time - -from twisted.internet import defer - -from synapse.api.constants import PresenceState -from synapse.storage.devices import DeviceStore -from synapse.storage.user_erasure_store import UserErasureStore -from synapse.util.caches.stream_change_cache import StreamChangeCache - -from .account_data import AccountDataStore -from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore -from .client_ips import ClientIpStore -from .deviceinbox import DeviceInboxStore -from .directory import DirectoryStore -from .e2e_room_keys import EndToEndRoomKeyStore -from .end_to_end_keys import EndToEndKeyStore -from .engines import PostgresEngine -from .event_federation import EventFederationStore -from .event_push_actions import EventPushActionsStore -from .events import EventsStore -from .events_bg_updates import EventsBackgroundUpdatesStore -from .filtering import FilteringStore -from .group_server import GroupServerStore -from .keys import KeyStore -from .media_repository import MediaRepositoryStore -from .monthly_active_users import MonthlyActiveUsersStore -from .openid import OpenIdStore -from .presence import PresenceStore, UserPresenceState -from .profile import ProfileStore -from .push_rule import PushRuleStore -from .pusher import PusherStore -from .receipts import ReceiptsStore -from .registration import RegistrationStore -from .rejections import RejectionsStore -from .relations import RelationsStore -from .room import RoomStore -from .roommember import RoomMemberStore -from .search import SearchStore -from .signatures import SignatureStore -from .state import StateStore -from .stats import StatsStore -from .stream import StreamStore -from .tags import TagsStore -from .transactions import TransactionStore -from .user_directory import UserDirectoryStore -from .util.id_generators import ChainedIdGenerator, IdGenerator, StreamIdGenerator - -logger = logging.getLogger(__name__) - - -class DataStore( - EventsBackgroundUpdatesStore, - RoomMemberStore, - RoomStore, - RegistrationStore, - StreamStore, - ProfileStore, - PresenceStore, - TransactionStore, - DirectoryStore, - KeyStore, - StateStore, - SignatureStore, - ApplicationServiceStore, - EventsStore, - EventFederationStore, - MediaRepositoryStore, - RejectionsStore, - FilteringStore, - PusherStore, - PushRuleStore, - ApplicationServiceTransactionStore, - ReceiptsStore, - EndToEndKeyStore, - EndToEndRoomKeyStore, - SearchStore, - TagsStore, - AccountDataStore, - EventPushActionsStore, - OpenIdStore, - ClientIpStore, - DeviceStore, - DeviceInboxStore, - UserDirectoryStore, - GroupServerStore, - UserErasureStore, - MonthlyActiveUsersStore, - StatsStore, - RelationsStore, -): - def __init__(self, db_conn, hs): - self.hs = hs - self._clock = hs.get_clock() - self.database_engine = hs.database_engine - - self._stream_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - extra_tables=[("local_invites", "stream_id")], - ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], - ) - self._presence_id_gen = StreamIdGenerator( - db_conn, "presence_stream", "stream_id" - ) - self._device_inbox_id_gen = StreamIdGenerator( - db_conn, "device_max_stream_id", "stream_id" - ) - self._public_room_id_gen = StreamIdGenerator( - db_conn, "public_room_list_stream", "stream_id" - ) - self._device_list_id_gen = StreamIdGenerator( - db_conn, "device_lists_stream", "stream_id" - ) - self._cross_signing_id_gen = StreamIdGenerator( - db_conn, "e2e_cross_signing_keys", "stream_id" - ) - - self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") - self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") - self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") - self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") - self._push_rules_stream_id_gen = ChainedIdGenerator( - self._stream_id_gen, db_conn, "push_rules_stream", "stream_id" - ) - self._pushers_id_gen = StreamIdGenerator( - db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] - ) - self._group_updates_id_gen = StreamIdGenerator( - db_conn, "local_group_updates", "stream_id" - ) - - if isinstance(self.database_engine, PostgresEngine): - self._cache_id_gen = StreamIdGenerator( - db_conn, "cache_invalidation_stream", "stream_id" - ) - else: - self._cache_id_gen = None - - self._presence_on_startup = self._get_active_presence(db_conn) - - presence_cache_prefill, min_presence_val = self._get_cache_dict( - db_conn, - "presence_stream", - entity_column="user_id", - stream_column="stream_id", - max_value=self._presence_id_gen.get_current_token(), - ) - self.presence_stream_cache = StreamChangeCache( - "PresenceStreamChangeCache", - min_presence_val, - prefilled_cache=presence_cache_prefill, - ) - - max_device_inbox_id = self._device_inbox_id_gen.get_current_token() - device_inbox_prefill, min_device_inbox_id = self._get_cache_dict( - db_conn, - "device_inbox", - entity_column="user_id", - stream_column="stream_id", - max_value=max_device_inbox_id, - limit=1000, - ) - self._device_inbox_stream_cache = StreamChangeCache( - "DeviceInboxStreamChangeCache", - min_device_inbox_id, - prefilled_cache=device_inbox_prefill, - ) - # The federation outbox and the local device inbox uses the same - # stream_id generator. - device_outbox_prefill, min_device_outbox_id = self._get_cache_dict( - db_conn, - "device_federation_outbox", - entity_column="destination", - stream_column="stream_id", - max_value=max_device_inbox_id, - limit=1000, - ) - self._device_federation_outbox_stream_cache = StreamChangeCache( - "DeviceFederationOutboxStreamChangeCache", - min_device_outbox_id, - prefilled_cache=device_outbox_prefill, - ) - - device_list_max = self._device_list_id_gen.get_current_token() - self._device_list_stream_cache = StreamChangeCache( - "DeviceListStreamChangeCache", device_list_max - ) - self._user_signature_stream_cache = StreamChangeCache( - "UserSignatureStreamChangeCache", device_list_max - ) - self._device_list_federation_stream_cache = StreamChangeCache( - "DeviceListFederationStreamChangeCache", device_list_max - ) - - events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( - db_conn, - "current_state_delta_stream", - entity_column="room_id", - stream_column="stream_id", - max_value=events_max, # As we share the stream id with events token - limit=1000, - ) - self._curr_state_delta_stream_cache = StreamChangeCache( - "_curr_state_delta_stream_cache", - min_curr_state_delta_id, - prefilled_cache=curr_state_delta_prefill, - ) - - _group_updates_prefill, min_group_updates_id = self._get_cache_dict( - db_conn, - "local_group_updates", - entity_column="user_id", - stream_column="stream_id", - max_value=self._group_updates_id_gen.get_current_token(), - limit=1000, - ) - self._group_updates_stream_cache = StreamChangeCache( - "_group_updates_stream_cache", - min_group_updates_id, - prefilled_cache=_group_updates_prefill, - ) - - self._stream_order_on_start = self.get_room_max_stream_ordering() - self._min_stream_order_on_start = self.get_room_min_stream_ordering() - - # Used in _generate_user_daily_visits to keep track of progress - self._last_user_visit_update = self._get_start_of_day() - - super(DataStore, self).__init__(db_conn, hs) - - def take_presence_startup_info(self): - active_on_startup = self._presence_on_startup - self._presence_on_startup = None - return active_on_startup - - def _get_active_presence(self, db_conn): - """Fetch non-offline presence from the database so that we can register - the appropriate time outs. - """ - - sql = ( - "SELECT user_id, state, last_active_ts, last_federation_update_ts," - " last_user_sync_ts, status_msg, currently_active FROM presence_stream" - " WHERE state != ?" - ) - sql = self.database_engine.convert_param_style(sql) - - txn = db_conn.cursor() - txn.execute(sql, (PresenceState.OFFLINE,)) - rows = self.cursor_to_dict(txn) - txn.close() - - for row in rows: - row["currently_active"] = bool(row["currently_active"]) - - return [UserPresenceState(**row) for row in rows] - - def count_daily_users(self): - """ - Counts the number of users who used this homeserver in the last 24 hours. - """ - yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - return self.runInteraction("count_daily_users", self._count_users, yesterday) - - def count_monthly_users(self): - """ - Counts the number of users who used this homeserver in the last 30 days. - Note this method is intended for phonehome metrics only and is different - from the mau figure in synapse.storage.monthly_active_users which, - amongst other things, includes a 3 day grace period before a user counts. - """ - thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - return self.runInteraction( - "count_monthly_users", self._count_users, thirty_days_ago - ) - - def _count_users(self, txn, time_from): - """ - Returns number of users seen in the past time_from period - """ - sql = """ - SELECT COALESCE(count(*), 0) FROM ( - SELECT user_id FROM user_ips - WHERE last_seen > ? - GROUP BY user_id - ) u - """ - txn.execute(sql, (time_from,)) - count, = txn.fetchone() - return count - - def count_r30_users(self): - """ - Counts the number of 30 day retained users, defined as:- - * Users who have created their accounts more than 30 days ago - * Where last seen at most 30 days ago - * Where account creation and last_seen are > 30 days apart - - Returns counts globaly for a given user as well as breaking - by platform - """ - - def _count_r30_users(txn): - thirty_days_in_secs = 86400 * 30 - now = int(self._clock.time()) - thirty_days_ago_in_secs = now - thirty_days_in_secs - - sql = """ - SELECT platform, COALESCE(count(*), 0) FROM ( - SELECT - users.name, platform, users.creation_ts * 1000, - MAX(uip.last_seen) - FROM users - INNER JOIN ( - SELECT - user_id, - last_seen, - CASE - WHEN user_agent LIKE '%%Android%%' THEN 'android' - WHEN user_agent LIKE '%%iOS%%' THEN 'ios' - WHEN user_agent LIKE '%%Electron%%' THEN 'electron' - WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' - WHEN user_agent LIKE '%%Gecko%%' THEN 'web' - ELSE 'unknown' - END - AS platform - FROM user_ips - ) uip - ON users.name = uip.user_id - AND users.appservice_id is NULL - AND users.creation_ts < ? - AND uip.last_seen/1000 > ? - AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 - GROUP BY users.name, platform, users.creation_ts - ) u GROUP BY platform - """ - - results = {} - txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - - for row in txn: - if row[0] == "unknown": - pass - results[row[0]] = row[1] - - sql = """ - SELECT COALESCE(count(*), 0) FROM ( - SELECT users.name, users.creation_ts * 1000, - MAX(uip.last_seen) - FROM users - INNER JOIN ( - SELECT - user_id, - last_seen - FROM user_ips - ) uip - ON users.name = uip.user_id - AND appservice_id is NULL - AND users.creation_ts < ? - AND uip.last_seen/1000 > ? - AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 - GROUP BY users.name, users.creation_ts - ) u - """ - - txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - - count, = txn.fetchone() - results["all"] = count - - return results - - return self.runInteraction("count_r30_users", _count_r30_users) - - def _get_start_of_day(self): - """ - Returns millisecond unixtime for start of UTC day. - """ - now = time.gmtime() - today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) - return today_start * 1000 - - def generate_user_daily_visits(self): - """ - Generates daily visit data for use in cohort/ retention analysis - """ - - def _generate_user_daily_visits(txn): - logger.info("Calling _generate_user_daily_visits") - today_start = self._get_start_of_day() - a_day_in_milliseconds = 24 * 60 * 60 * 1000 - now = self.clock.time_msec() - - sql = """ - INSERT INTO user_daily_visits (user_id, device_id, timestamp) - SELECT u.user_id, u.device_id, ? - FROM user_ips AS u - LEFT JOIN ( - SELECT user_id, device_id, timestamp FROM user_daily_visits - WHERE timestamp = ? - ) udv - ON u.user_id = udv.user_id AND u.device_id=udv.device_id - INNER JOIN users ON users.name=u.user_id - WHERE last_seen > ? AND last_seen <= ? - AND udv.timestamp IS NULL AND users.is_guest=0 - AND users.appservice_id IS NULL - GROUP BY u.user_id, u.device_id - """ - - # This means that the day has rolled over but there could still - # be entries from the previous day. There is an edge case - # where if the user logs in at 23:59 and overwrites their - # last_seen at 00:01 then they will not be counted in the - # previous day's stats - it is important that the query is run - # often to minimise this case. - if today_start > self._last_user_visit_update: - yesterday_start = today_start - a_day_in_milliseconds - txn.execute( - sql, - ( - yesterday_start, - yesterday_start, - self._last_user_visit_update, - today_start, - ), - ) - self._last_user_visit_update = today_start - - txn.execute( - sql, (today_start, today_start, self._last_user_visit_update, now) - ) - # Update _last_user_visit_update to now. The reason to do this - # rather just clamping to the beginning of the day is to limit - # the size of the join - meaning that the query can be run more - # frequently - self._last_user_visit_update = now - - return self.runInteraction( - "generate_user_daily_visits", _generate_user_daily_visits - ) - - def get_users(self): - """Function to reterive a list of users in users table. - - Args: - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - return self._simple_select_list( - table="users", - keyvalues={}, - retcols=["name", "password_hash", "is_guest", "admin", "user_type"], - desc="get_users", - ) - - @defer.inlineCallbacks - def get_users_paginate(self, order, start, limit): - """Function to reterive a paginated list of users from - users list. This will return a json object, which contains - list of users and the total number of users in users table. - - Args: - order (str): column name to order the select by this column - start (int): start number to begin the query from - limit (int): number of rows to reterive - Returns: - defer.Deferred: resolves to json object {list[dict[str, Any]], count} - """ - users = yield self.runInteraction( - "get_users_paginate", - self._simple_select_list_paginate_txn, - table="users", - keyvalues={"is_guest": False}, - orderby=order, - start=start, - limit=limit, - retcols=["name", "password_hash", "is_guest", "admin", "user_type"], - ) - count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn) - retval = {"users": users, "total": count} - return retval - - def search_users(self, term): - """Function to search users list for one or more users with - the matched term. - - Args: - term (str): search term - col (str): column to query term should be matched to - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - return self._simple_search_list( - table="users", - term=term, - col="name", - retcols=["name", "password_hash", "is_guest", "admin", "user_type"], - desc="search_users", - ) +""" +The storage layer is split up into multiple parts to allow Synapse to run +against different configurations of databases (e.g. single or multiple +databases). The `data_stores` are classes that talk directly to a single +database and have associated schemas, background updates, etc. On top of those +there are (or will be) classes that provide high level interfaces that combine +calls to multiple `data_stores`. + +There are also schemas that get applied to every database, regardless of the +data stores associated with them (e.g. the schema version tables), which are +stored in `synapse.storage.schema`. +""" + +from synapse.storage.data_stores import DataStores +from synapse.storage.data_stores.main import DataStore +from synapse.storage.persist_events import EventsPersistenceStorage +from synapse.storage.purge_events import PurgeEventsStorage +from synapse.storage.state import StateGroupStorage + +__all__ = ["DataStores", "DataStore"] + + +class Storage(object): + """The high level interfaces for talking to various storage layers. + """ + + def __init__(self, hs, stores: DataStores): + # We include the main data store here mainly so that we don't have to + # rewrite all the existing code to split it into high vs low level + # interfaces. + self.main = stores.main + + self.persistence = EventsPersistenceStorage(hs, stores) + self.purge_events = PurgeEventsStorage(hs, stores) + self.state = StateGroupStorage(hs, stores) def are_all_users_on_domain(txn, database_engine, domain): diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index f5906fcd54..1a2b7ebe25 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -494,7 +494,7 @@ class SQLBaseStore(object): exception_callbacks = [] if LoggingContext.current_context() == LoggingContext.sentinel: - logger.warn("Starting db txn '%s' from sentinel context", desc) + logger.warning("Starting db txn '%s' from sentinel context", desc) try: result = yield self.runWithConnection( @@ -532,7 +532,7 @@ class SQLBaseStore(object): """ parent_context = LoggingContext.current_context() if parent_context == LoggingContext.sentinel: - logger.warn( + logger.warning( "Starting db connection from sentinel context: metrics will be lost" ) parent_context = None @@ -719,7 +719,7 @@ class SQLBaseStore(object): raise # presumably we raced with another transaction: let's retry. - logger.warn( + logger.warning( "IntegrityError when upserting into %s; retrying: %s", table, e ) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 80b57a948c..37d469ffd7 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -94,13 +94,16 @@ class BackgroundUpdateStore(SQLBaseStore): self._all_done = False def start_doing_background_updates(self): - run_as_background_process("background_updates", self._run_background_updates) + run_as_background_process("background_updates", self.run_background_updates) @defer.inlineCallbacks - def _run_background_updates(self): + def run_background_updates(self, sleep=True): logger.info("Starting background schema updates") while True: - yield self.hs.get_clock().sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0) + if sleep: + yield self.hs.get_clock().sleep( + self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0 + ) try: result = yield self.do_next_background_update( diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py new file mode 100644 index 0000000000..cb184a98cc --- /dev/null +++ b/synapse/storage/data_stores/__init__.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class DataStores(object): + """The various data stores. + + These are low level interfaces to physical databases. + """ + + def __init__(self, main_store, db_conn, hs): + # Note we pass in the main store here as workers use a different main + # store. + self.main = main_store diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py new file mode 100644 index 0000000000..10c940df1e --- /dev/null +++ b/synapse/storage/data_stores/main/__init__.py @@ -0,0 +1,533 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import calendar +import logging +import time + +from twisted.internet import defer + +from synapse.api.constants import PresenceState +from synapse.storage.engines import PostgresEngine +from synapse.storage.util.id_generators import ( + ChainedIdGenerator, + IdGenerator, + StreamIdGenerator, +) +from synapse.util.caches.stream_change_cache import StreamChangeCache + +from .account_data import AccountDataStore +from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore +from .client_ips import ClientIpStore +from .deviceinbox import DeviceInboxStore +from .devices import DeviceStore +from .directory import DirectoryStore +from .e2e_room_keys import EndToEndRoomKeyStore +from .end_to_end_keys import EndToEndKeyStore +from .event_federation import EventFederationStore +from .event_push_actions import EventPushActionsStore +from .events import EventsStore +from .events_bg_updates import EventsBackgroundUpdatesStore +from .filtering import FilteringStore +from .group_server import GroupServerStore +from .keys import KeyStore +from .media_repository import MediaRepositoryStore +from .monthly_active_users import MonthlyActiveUsersStore +from .openid import OpenIdStore +from .presence import PresenceStore, UserPresenceState +from .profile import ProfileStore +from .push_rule import PushRuleStore +from .pusher import PusherStore +from .receipts import ReceiptsStore +from .registration import RegistrationStore +from .rejections import RejectionsStore +from .relations import RelationsStore +from .room import RoomStore +from .roommember import RoomMemberStore +from .search import SearchStore +from .signatures import SignatureStore +from .state import StateStore +from .stats import StatsStore +from .stream import StreamStore +from .tags import TagsStore +from .transactions import TransactionStore +from .user_directory import UserDirectoryStore +from .user_erasure_store import UserErasureStore + +logger = logging.getLogger(__name__) + + +class DataStore( + EventsBackgroundUpdatesStore, + RoomMemberStore, + RoomStore, + RegistrationStore, + StreamStore, + ProfileStore, + PresenceStore, + TransactionStore, + DirectoryStore, + KeyStore, + StateStore, + SignatureStore, + ApplicationServiceStore, + EventsStore, + EventFederationStore, + MediaRepositoryStore, + RejectionsStore, + FilteringStore, + PusherStore, + PushRuleStore, + ApplicationServiceTransactionStore, + ReceiptsStore, + EndToEndKeyStore, + EndToEndRoomKeyStore, + SearchStore, + TagsStore, + AccountDataStore, + EventPushActionsStore, + OpenIdStore, + ClientIpStore, + DeviceStore, + DeviceInboxStore, + UserDirectoryStore, + GroupServerStore, + UserErasureStore, + MonthlyActiveUsersStore, + StatsStore, + RelationsStore, +): + def __init__(self, db_conn, hs): + self.hs = hs + self._clock = hs.get_clock() + self.database_engine = hs.database_engine + + self._stream_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + extra_tables=[("local_invites", "stream_id")], + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + ) + self._presence_id_gen = StreamIdGenerator( + db_conn, "presence_stream", "stream_id" + ) + self._device_inbox_id_gen = StreamIdGenerator( + db_conn, "device_max_stream_id", "stream_id" + ) + self._public_room_id_gen = StreamIdGenerator( + db_conn, "public_room_list_stream", "stream_id" + ) + self._device_list_id_gen = StreamIdGenerator( + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[("user_signature_stream", "stream_id")], + ) + self._cross_signing_id_gen = StreamIdGenerator( + db_conn, "e2e_cross_signing_keys", "stream_id" + ) + + self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") + self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") + self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") + self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") + self._push_rules_stream_id_gen = ChainedIdGenerator( + self._stream_id_gen, db_conn, "push_rules_stream", "stream_id" + ) + self._pushers_id_gen = StreamIdGenerator( + db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] + ) + self._group_updates_id_gen = StreamIdGenerator( + db_conn, "local_group_updates", "stream_id" + ) + + if isinstance(self.database_engine, PostgresEngine): + self._cache_id_gen = StreamIdGenerator( + db_conn, "cache_invalidation_stream", "stream_id" + ) + else: + self._cache_id_gen = None + + self._presence_on_startup = self._get_active_presence(db_conn) + + presence_cache_prefill, min_presence_val = self._get_cache_dict( + db_conn, + "presence_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=self._presence_id_gen.get_current_token(), + ) + self.presence_stream_cache = StreamChangeCache( + "PresenceStreamChangeCache", + min_presence_val, + prefilled_cache=presence_cache_prefill, + ) + + max_device_inbox_id = self._device_inbox_id_gen.get_current_token() + device_inbox_prefill, min_device_inbox_id = self._get_cache_dict( + db_conn, + "device_inbox", + entity_column="user_id", + stream_column="stream_id", + max_value=max_device_inbox_id, + limit=1000, + ) + self._device_inbox_stream_cache = StreamChangeCache( + "DeviceInboxStreamChangeCache", + min_device_inbox_id, + prefilled_cache=device_inbox_prefill, + ) + # The federation outbox and the local device inbox uses the same + # stream_id generator. + device_outbox_prefill, min_device_outbox_id = self._get_cache_dict( + db_conn, + "device_federation_outbox", + entity_column="destination", + stream_column="stream_id", + max_value=max_device_inbox_id, + limit=1000, + ) + self._device_federation_outbox_stream_cache = StreamChangeCache( + "DeviceFederationOutboxStreamChangeCache", + min_device_outbox_id, + prefilled_cache=device_outbox_prefill, + ) + + device_list_max = self._device_list_id_gen.get_current_token() + self._device_list_stream_cache = StreamChangeCache( + "DeviceListStreamChangeCache", device_list_max + ) + self._user_signature_stream_cache = StreamChangeCache( + "UserSignatureStreamChangeCache", device_list_max + ) + self._device_list_federation_stream_cache = StreamChangeCache( + "DeviceListFederationStreamChangeCache", device_list_max + ) + + events_max = self._stream_id_gen.get_current_token() + curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( + db_conn, + "current_state_delta_stream", + entity_column="room_id", + stream_column="stream_id", + max_value=events_max, # As we share the stream id with events token + limit=1000, + ) + self._curr_state_delta_stream_cache = StreamChangeCache( + "_curr_state_delta_stream_cache", + min_curr_state_delta_id, + prefilled_cache=curr_state_delta_prefill, + ) + + _group_updates_prefill, min_group_updates_id = self._get_cache_dict( + db_conn, + "local_group_updates", + entity_column="user_id", + stream_column="stream_id", + max_value=self._group_updates_id_gen.get_current_token(), + limit=1000, + ) + self._group_updates_stream_cache = StreamChangeCache( + "_group_updates_stream_cache", + min_group_updates_id, + prefilled_cache=_group_updates_prefill, + ) + + self._stream_order_on_start = self.get_room_max_stream_ordering() + self._min_stream_order_on_start = self.get_room_min_stream_ordering() + + # Used in _generate_user_daily_visits to keep track of progress + self._last_user_visit_update = self._get_start_of_day() + + super(DataStore, self).__init__(db_conn, hs) + + def take_presence_startup_info(self): + active_on_startup = self._presence_on_startup + self._presence_on_startup = None + return active_on_startup + + def _get_active_presence(self, db_conn): + """Fetch non-offline presence from the database so that we can register + the appropriate time outs. + """ + + sql = ( + "SELECT user_id, state, last_active_ts, last_federation_update_ts," + " last_user_sync_ts, status_msg, currently_active FROM presence_stream" + " WHERE state != ?" + ) + sql = self.database_engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, (PresenceState.OFFLINE,)) + rows = self.cursor_to_dict(txn) + txn.close() + + for row in rows: + row["currently_active"] = bool(row["currently_active"]) + + return [UserPresenceState(**row) for row in rows] + + def count_daily_users(self): + """ + Counts the number of users who used this homeserver in the last 24 hours. + """ + yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) + return self.runInteraction("count_daily_users", self._count_users, yesterday) + + def count_monthly_users(self): + """ + Counts the number of users who used this homeserver in the last 30 days. + Note this method is intended for phonehome metrics only and is different + from the mau figure in synapse.storage.monthly_active_users which, + amongst other things, includes a 3 day grace period before a user counts. + """ + thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) + return self.runInteraction( + "count_monthly_users", self._count_users, thirty_days_ago + ) + + def _count_users(self, txn, time_from): + """ + Returns number of users seen in the past time_from period + """ + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT user_id FROM user_ips + WHERE last_seen > ? + GROUP BY user_id + ) u + """ + txn.execute(sql, (time_from,)) + (count,) = txn.fetchone() + return count + + def count_r30_users(self): + """ + Counts the number of 30 day retained users, defined as:- + * Users who have created their accounts more than 30 days ago + * Where last seen at most 30 days ago + * Where account creation and last_seen are > 30 days apart + + Returns counts globaly for a given user as well as breaking + by platform + """ + + def _count_r30_users(txn): + thirty_days_in_secs = 86400 * 30 + now = int(self._clock.time()) + thirty_days_ago_in_secs = now - thirty_days_in_secs + + sql = """ + SELECT platform, COALESCE(count(*), 0) FROM ( + SELECT + users.name, platform, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen, + CASE + WHEN user_agent LIKE '%%Android%%' THEN 'android' + WHEN user_agent LIKE '%%iOS%%' THEN 'ios' + WHEN user_agent LIKE '%%Electron%%' THEN 'electron' + WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' + WHEN user_agent LIKE '%%Gecko%%' THEN 'web' + ELSE 'unknown' + END + AS platform + FROM user_ips + ) uip + ON users.name = uip.user_id + AND users.appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, platform, users.creation_ts + ) u GROUP BY platform + """ + + results = {} + txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) + + for row in txn: + if row[0] == "unknown": + pass + results[row[0]] = row[1] + + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT users.name, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen + FROM user_ips + ) uip + ON users.name = uip.user_id + AND appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, users.creation_ts + ) u + """ + + txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) + + (count,) = txn.fetchone() + results["all"] = count + + return results + + return self.runInteraction("count_r30_users", _count_r30_users) + + def _get_start_of_day(self): + """ + Returns millisecond unixtime for start of UTC day. + """ + now = time.gmtime() + today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) + return today_start * 1000 + + def generate_user_daily_visits(self): + """ + Generates daily visit data for use in cohort/ retention analysis + """ + + def _generate_user_daily_visits(txn): + logger.info("Calling _generate_user_daily_visits") + today_start = self._get_start_of_day() + a_day_in_milliseconds = 24 * 60 * 60 * 1000 + now = self.clock.time_msec() + + sql = """ + INSERT INTO user_daily_visits (user_id, device_id, timestamp) + SELECT u.user_id, u.device_id, ? + FROM user_ips AS u + LEFT JOIN ( + SELECT user_id, device_id, timestamp FROM user_daily_visits + WHERE timestamp = ? + ) udv + ON u.user_id = udv.user_id AND u.device_id=udv.device_id + INNER JOIN users ON users.name=u.user_id + WHERE last_seen > ? AND last_seen <= ? + AND udv.timestamp IS NULL AND users.is_guest=0 + AND users.appservice_id IS NULL + GROUP BY u.user_id, u.device_id + """ + + # This means that the day has rolled over but there could still + # be entries from the previous day. There is an edge case + # where if the user logs in at 23:59 and overwrites their + # last_seen at 00:01 then they will not be counted in the + # previous day's stats - it is important that the query is run + # often to minimise this case. + if today_start > self._last_user_visit_update: + yesterday_start = today_start - a_day_in_milliseconds + txn.execute( + sql, + ( + yesterday_start, + yesterday_start, + self._last_user_visit_update, + today_start, + ), + ) + self._last_user_visit_update = today_start + + txn.execute( + sql, (today_start, today_start, self._last_user_visit_update, now) + ) + # Update _last_user_visit_update to now. The reason to do this + # rather just clamping to the beginning of the day is to limit + # the size of the join - meaning that the query can be run more + # frequently + self._last_user_visit_update = now + + return self.runInteraction( + "generate_user_daily_visits", _generate_user_daily_visits + ) + + def get_users(self): + """Function to reterive a list of users in users table. + + Args: + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self._simple_select_list( + table="users", + keyvalues={}, + retcols=["name", "password_hash", "is_guest", "admin", "user_type"], + desc="get_users", + ) + + @defer.inlineCallbacks + def get_users_paginate(self, order, start, limit): + """Function to reterive a paginated list of users from + users list. This will return a json object, which contains + list of users and the total number of users in users table. + + Args: + order (str): column name to order the select by this column + start (int): start number to begin the query from + limit (int): number of rows to reterive + Returns: + defer.Deferred: resolves to json object {list[dict[str, Any]], count} + """ + users = yield self.runInteraction( + "get_users_paginate", + self._simple_select_list_paginate_txn, + table="users", + keyvalues={"is_guest": False}, + orderby=order, + start=start, + limit=limit, + retcols=["name", "password_hash", "is_guest", "admin", "user_type"], + ) + count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn) + retval = {"users": users, "total": count} + return retval + + def search_users(self, term): + """Function to search users list for one or more users with + the matched term. + + Args: + term (str): search term + col (str): column to query term should be matched to + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self._simple_search_list( + table="users", + term=term, + col="name", + retcols=["name", "password_hash", "is_guest", "admin", "user_type"], + desc="search_users", + ) diff --git a/synapse/storage/account_data.py b/synapse/storage/data_stores/main/account_data.py index 6afbfc0d74..6afbfc0d74 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py diff --git a/synapse/storage/appservice.py b/synapse/storage/data_stores/main/appservice.py index 435b2acd4d..81babf2029 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/data_stores/main/appservice.py @@ -22,9 +22,8 @@ from twisted.internet import defer from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices -from synapse.storage.events_worker import EventsWorkerStore - -from ._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore logger = logging.getLogger(__name__) diff --git a/synapse/storage/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 067820a5da..706c6a1f3f 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -20,11 +20,10 @@ from six import iteritems from twisted.internet import defer from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.storage import background_updates +from synapse.storage._base import Cache from synapse.util.caches import CACHE_SIZE_FACTOR -from . import background_updates -from ._base import Cache - logger = logging.getLogger(__name__) # Number of msec of granularity to store the user IP 'last seen' time. Smaller diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index f04aad0743..f04aad0743 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py diff --git a/synapse/storage/devices.py b/synapse/storage/data_stores/main/devices.py index f7a3542348..71f62036c0 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -37,6 +37,7 @@ from synapse.storage._base import ( make_in_list_sql_clause, ) from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.types import get_verify_key_from_cross_signing_key from synapse.util import batch_iter from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList @@ -90,13 +91,18 @@ class DeviceWorkerStore(SQLBaseStore): @trace @defer.inlineCallbacks - def get_devices_by_remote(self, destination, from_stream_id, limit): - """Get stream of updates to send to remote servers + def get_device_updates_by_remote(self, destination, from_stream_id, limit): + """Get a stream of device updates to send to the given remote server. + Args: + destination (str): The host the device updates are intended for + from_stream_id (int): The minimum stream_id to filter updates by, exclusive + limit (int): Maximum number of device updates to return Returns: - Deferred[tuple[int, list[dict]]]: + Deferred[tuple[int, list[tuple[string,dict]]]]: current stream id (ie, the stream id of the last update included in the - response), and the list of updates + response), and the list of updates, where each update is a pair of EDU + type and EDU contents """ now_stream_id = self._device_list_id_gen.get_current_token() @@ -117,8 +123,8 @@ class DeviceWorkerStore(SQLBaseStore): # stream_id; the rationale being that such a large device list update # is likely an error. updates = yield self.runInteraction( - "get_devices_by_remote", - self._get_devices_by_remote_txn, + "get_device_updates_by_remote", + self._get_device_updates_by_remote_txn, destination, from_stream_id, now_stream_id, @@ -129,6 +135,37 @@ class DeviceWorkerStore(SQLBaseStore): if not updates: return now_stream_id, [] + # get the cross-signing keys of the users in the list, so that we can + # determine which of the device changes were cross-signing keys + users = set(r[0] for r in updates) + master_key_by_user = {} + self_signing_key_by_user = {} + for user in users: + cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master") + if cross_signing_key: + key_id, verify_key = get_verify_key_from_cross_signing_key( + cross_signing_key + ) + # verify_key is a VerifyKey from signedjson, which uses + # .version to denote the portion of the key ID after the + # algorithm and colon, which is the device ID + master_key_by_user[user] = { + "key_info": cross_signing_key, + "device_id": verify_key.version, + } + + cross_signing_key = yield self.get_e2e_cross_signing_key( + user, "self_signing" + ) + if cross_signing_key: + key_id, verify_key = get_verify_key_from_cross_signing_key( + cross_signing_key + ) + self_signing_key_by_user[user] = { + "key_info": cross_signing_key, + "device_id": verify_key.version, + } + # if we have exceeded the limit, we need to exclude any results with the # same stream_id as the last row. if len(updates) > limit: @@ -153,20 +190,33 @@ class DeviceWorkerStore(SQLBaseStore): # context which created the Edu. query_map = {} - for update in updates: - if stream_id_cutoff is not None and update[2] >= stream_id_cutoff: + cross_signing_keys_by_user = {} + for user_id, device_id, update_stream_id, update_context in updates: + if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff: # Stop processing updates break - key = (update[0], update[1]) - - update_context = update[3] - update_stream_id = update[2] + if ( + user_id in master_key_by_user + and device_id == master_key_by_user[user_id]["device_id"] + ): + result = cross_signing_keys_by_user.setdefault(user_id, {}) + result["master_key"] = master_key_by_user[user_id]["key_info"] + elif ( + user_id in self_signing_key_by_user + and device_id == self_signing_key_by_user[user_id]["device_id"] + ): + result = cross_signing_keys_by_user.setdefault(user_id, {}) + result["self_signing_key"] = self_signing_key_by_user[user_id][ + "key_info" + ] + else: + key = (user_id, device_id) - previous_update_stream_id, _ = query_map.get(key, (0, None)) + previous_update_stream_id, _ = query_map.get(key, (0, None)) - if update_stream_id > previous_update_stream_id: - query_map[key] = (update_stream_id, update_context) + if update_stream_id > previous_update_stream_id: + query_map[key] = (update_stream_id, update_context) # If we didn't find any updates with a stream_id lower than the cutoff, it # means that there are more than limit updates all of which have the same @@ -176,16 +226,22 @@ class DeviceWorkerStore(SQLBaseStore): # devices, in which case E2E isn't going to work well anyway. We'll just # skip that stream_id and return an empty list, and continue with the next # stream_id next time. - if not query_map: + if not query_map and not cross_signing_keys_by_user: return stream_id_cutoff, [] results = yield self._get_device_update_edus_by_remote( destination, from_stream_id, query_map ) + # add the updated cross-signing keys to the results list + for user_id, result in iteritems(cross_signing_keys_by_user): + result["user_id"] = user_id + # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec + results.append(("org.matrix.signing_key_update", result)) + return now_stream_id, results - def _get_devices_by_remote_txn( + def _get_device_updates_by_remote_txn( self, txn, destination, from_stream_id, now_stream_id, limit ): """Return device update information for a given remote destination @@ -200,6 +256,7 @@ class DeviceWorkerStore(SQLBaseStore): Returns: List: List of device updates """ + # get the list of device updates that need to be sent sql = """ SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? @@ -225,12 +282,16 @@ class DeviceWorkerStore(SQLBaseStore): List[Dict]: List of objects representing an device update EDU """ - devices = yield self.runInteraction( - "_get_e2e_device_keys_txn", - self._get_e2e_device_keys_txn, - query_map.keys(), - include_all_devices=True, - include_deleted_devices=True, + devices = ( + yield self.runInteraction( + "_get_e2e_device_keys_txn", + self._get_e2e_device_keys_txn, + query_map.keys(), + include_all_devices=True, + include_deleted_devices=True, + ) + if query_map + else {} ) results = [] @@ -262,7 +323,7 @@ class DeviceWorkerStore(SQLBaseStore): else: result["deleted"] = True - results.append(result) + results.append(("m.device_list_update", result)) return results diff --git a/synapse/storage/directory.py b/synapse/storage/data_stores/main/directory.py index eed7757ed5..297966d9f4 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/data_stores/main/directory.py @@ -18,10 +18,9 @@ from collections import namedtuple from twisted.internet import defer from synapse.api.errors import SynapseError +from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cached -from ._base import SQLBaseStore - RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers")) diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py index be2fe2bab6..1cbbae5b63 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/data_stores/main/e2e_room_keys.py @@ -19,8 +19,7 @@ from twisted.internet import defer from synapse.api.errors import StoreError from synapse.logging.opentracing import log_kv, trace - -from ._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore class EndToEndRoomKeyStore(SQLBaseStore): @@ -322,9 +321,17 @@ class EndToEndRoomKeyStore(SQLBaseStore): def _delete_e2e_room_keys_version_txn(txn): if version is None: this_version = self._get_current_version(txn, user_id) + if this_version is None: + raise StoreError(404, "No current backup version") else: this_version = version + self._simple_delete_txn( + txn, + table="e2e_room_keys", + keyvalues={"user_id": user_id, "version": this_version}, + ) + return self._simple_update_one_txn( txn, table="e2e_room_keys_versions", diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index ebcd1c9ea2..073412a78d 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -21,10 +21,9 @@ from canonicaljson import encode_canonical_json, json from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.util.caches.descriptors import cached -from ._base import SQLBaseStore, db_to_json - class EndToEndKeyWorkerStore(SQLBaseStore): @trace @@ -68,6 +67,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore): display_name = device_info["device_display_name"] if display_name is not None: r["unsigned"]["device_display_name"] = display_name + if "signatures" in device_info: + for sig_user_id, sigs in device_info["signatures"].items(): + r.setdefault("signatures", {}).setdefault( + sig_user_id, {} + ).update(sigs) rv[user_id][device_id] = r return rv @@ -81,6 +85,8 @@ class EndToEndKeyWorkerStore(SQLBaseStore): query_clauses = [] query_params = [] + signature_query_clauses = [] + signature_query_params = [] if include_all_devices is False: include_deleted_devices = False @@ -91,12 +97,20 @@ class EndToEndKeyWorkerStore(SQLBaseStore): for (user_id, device_id) in query_list: query_clause = "user_id = ?" query_params.append(user_id) + signature_query_clause = "target_user_id = ?" + signature_query_params.append(user_id) if device_id is not None: query_clause += " AND device_id = ?" query_params.append(device_id) + signature_query_clause += " AND target_device_id = ?" + signature_query_params.append(device_id) + + signature_query_clause += " AND user_id = ?" + signature_query_params.append(user_id) query_clauses.append(query_clause) + signature_query_clauses.append(signature_query_clause) sql = ( "SELECT user_id, device_id, " @@ -123,6 +137,22 @@ class EndToEndKeyWorkerStore(SQLBaseStore): for user_id, device_id in deleted_devices: result.setdefault(user_id, {})[device_id] = None + # get signatures on the device + signature_sql = ( + "SELECT * " " FROM e2e_cross_signing_signatures " " WHERE %s" + ) % (" OR ".join("(" + q + ")" for q in signature_query_clauses)) + + txn.execute(signature_sql, signature_query_params) + rows = self.cursor_to_dict(txn) + + for row in rows: + target_user_id = row["target_user_id"] + target_device_id = row["target_device_id"] + if target_user_id in result and target_device_id in result[target_user_id]: + result[target_user_id][target_device_id].setdefault( + "signatures", {} + ).setdefault(row["user_id"], {})[row["key_id"]] = row["signature"] + log_kv(result) return result @@ -218,6 +248,97 @@ class EndToEndKeyWorkerStore(SQLBaseStore): return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys) + def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None): + """Returns a user's cross-signing key. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + user_id (str): the user whose key is being requested + key_type (str): the type of key that is being set: either 'master' + for a master key, 'self_signing' for a self-signing key, or + 'user_signing' for a user-signing key + from_user_id (str): if specified, signatures made by this user on + the key will be included in the result + + Returns: + dict of the key data or None if not found + """ + sql = ( + "SELECT keydata " + " FROM e2e_cross_signing_keys " + " WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1" + ) + txn.execute(sql, (user_id, key_type)) + row = txn.fetchone() + if not row: + return None + key = json.loads(row[0]) + + device_id = None + for k in key["keys"].values(): + device_id = k + + if from_user_id is not None: + sql = ( + "SELECT key_id, signature " + " FROM e2e_cross_signing_signatures " + " WHERE user_id = ? " + " AND target_user_id = ? " + " AND target_device_id = ? " + ) + txn.execute(sql, (from_user_id, user_id, device_id)) + row = txn.fetchone() + if row: + key.setdefault("signatures", {}).setdefault(from_user_id, {})[ + row[0] + ] = row[1] + + return key + + def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None): + """Returns a user's cross-signing key. + + Args: + user_id (str): the user whose self-signing key is being requested + key_type (str): the type of cross-signing key to get + from_user_id (str): if specified, signatures made by this user on + the self-signing key will be included in the result + + Returns: + dict of the key data or None if not found + """ + return self.runInteraction( + "get_e2e_cross_signing_key", + self._get_e2e_cross_signing_key_txn, + user_id, + key_type, + from_user_id, + ) + + def get_all_user_signature_changes_for_remotes(self, from_key, to_key): + """Return a list of changes from the user signature stream to notify remotes. + Note that the user signature stream represents when a user signs their + device with their user-signing key, which is not published to other + users or servers, so no `destination` is needed in the returned + list. However, this is needed to poke workers. + + Args: + from_key (int): the stream ID to start at (exclusive) + to_key (int): the stream ID to end at (inclusive) + + Returns: + Deferred[list[(int,str)]] a list of `(stream_id, user_id)` + """ + sql = """ + SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id + FROM user_signature_stream + WHERE ? < stream_id AND stream_id <= ? + GROUP BY user_id + """ + return self._execute( + "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key + ) + class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): @@ -396,96 +517,24 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): key, ) - def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None): - """Returns a user's cross-signing key. - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - user_id (str): the user whose key is being requested - key_type (str): the type of key that is being set: either 'master' - for a master key, 'self_signing' for a self-signing key, or - 'user_signing' for a user-signing key - from_user_id (str): if specified, signatures made by this user on - the key will be included in the result - - Returns: - dict of the key data or None if not found - """ - sql = ( - "SELECT keydata " - " FROM e2e_cross_signing_keys " - " WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1" - ) - txn.execute(sql, (user_id, key_type)) - row = txn.fetchone() - if not row: - return None - key = json.loads(row[0]) - - device_id = None - for k in key["keys"].values(): - device_id = k - - if from_user_id is not None: - sql = ( - "SELECT key_id, signature " - " FROM e2e_cross_signing_signatures " - " WHERE user_id = ? " - " AND target_user_id = ? " - " AND target_device_id = ? " - ) - txn.execute(sql, (from_user_id, user_id, device_id)) - row = txn.fetchone() - if row: - key.setdefault("signatures", {}).setdefault(from_user_id, {})[ - row[0] - ] = row[1] - - return key - - def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None): - """Returns a user's cross-signing key. - - Args: - user_id (str): the user whose self-signing key is being requested - key_type (str): the type of cross-signing key to get - from_user_id (str): if specified, signatures made by this user on - the self-signing key will be included in the result - - Returns: - dict of the key data or None if not found - """ - return self.runInteraction( - "get_e2e_cross_signing_key", - self._get_e2e_cross_signing_key_txn, - user_id, - key_type, - from_user_id, - ) - def store_e2e_cross_signing_signatures(self, user_id, signatures): """Stores cross-signing signatures. Args: user_id (str): the user who made the signatures - signatures (iterable[(str, str, str, str)]): signatures to add - each - a tuple of (key_id, target_user_id, target_device_id, signature), - where key_id is the ID of the key (including the signature - algorithm) that made the signature, target_user_id and - target_device_id indicate the device being signed, and signature - is the signature of the device + signatures (iterable[SignatureListItem]): signatures to add """ return self._simple_insert_many( "e2e_cross_signing_signatures", [ { "user_id": user_id, - "key_id": key_id, - "target_user_id": target_user_id, - "target_device_id": target_device_id, - "signature": signature, + "key_id": item.signing_key_id, + "target_user_id": item.target_user_id, + "target_device_id": item.target_device_id, + "signature": item.signature, } - for (key_id, target_user_id, target_device_id, signature) in signatures + for item in signatures ], "add_e2e_signing_key", ) diff --git a/synapse/storage/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 47cc10d32a..90bef0cd2c 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -26,8 +26,8 @@ from twisted.internet import defer from synapse.api.errors import StoreError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.events_worker import EventsWorkerStore -from synapse.storage.signatures import SignatureWorkerStore +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.data_stores.main.signatures import SignatureWorkerStore from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -364,9 +364,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas ) def _get_backfill_events(self, txn, room_id, event_list, limit): - logger.debug( - "_get_backfill_events: %s, %s, %s", room_id, repr(event_list), limit - ) + logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit) event_results = set() diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 22025effbc..04ce21ac66 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -863,7 +863,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): ) stream_row = txn.fetchone() if stream_row: - offset_stream_ordering, = stream_row + (offset_stream_ordering,) = stream_row rotate_to_stream_ordering = min( self.stream_ordering_day_ago, offset_stream_ordering ) diff --git a/synapse/storage/events.py b/synapse/storage/data_stores/main/events.py index ee49ef235d..878f7568a6 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -17,39 +17,35 @@ import itertools import logging -from collections import Counter as c_counter, OrderedDict, deque, namedtuple +from collections import Counter as c_counter, OrderedDict, namedtuple from functools import wraps from six import iteritems, text_type from six.moves import range from canonicaljson import json -from prometheus_client import Counter, Histogram +from prometheus_client import Counter from twisted.internet import defer import synapse.metrics -from synapse.api.constants import EventTypes +from synapse.api.constants import EventContentFields, EventTypes from synapse.api.errors import SynapseError from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.utils import prune_event_dict -from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.logging.utils import log_function from synapse.metrics import BucketCollector from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.state import StateResolutionStore from synapse.storage._base import make_in_list_sql_clause from synapse.storage.background_updates import BackgroundUpdateStore -from synapse.storage.event_federation import EventFederationStore -from synapse.storage.events_worker import EventsWorkerStore -from synapse.storage.state import StateGroupWorkerStore +from synapse.storage.data_stores.main.event_federation import EventFederationStore +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.data_stores.main.state import StateGroupWorkerStore from synapse.types import RoomStreamToken, get_domain_from_id from synapse.util import batch_iter -from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.frozenutils import frozendict_json_encoder -from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -60,37 +56,6 @@ event_counter = Counter( ["type", "origin_type", "origin_entity"], ) -# The number of times we are recalculating the current state -state_delta_counter = Counter("synapse_storage_events_state_delta", "") - -# The number of times we are recalculating state when there is only a -# single forward extremity -state_delta_single_event_counter = Counter( - "synapse_storage_events_state_delta_single_event", "" -) - -# The number of times we are reculating state when we could have resonably -# calculated the delta when we calculated the state for an event we were -# persisting. -state_delta_reuse_delta_counter = Counter( - "synapse_storage_events_state_delta_reuse_delta", "" -) - -# The number of forward extremities for each new event. -forward_extremities_counter = Histogram( - "synapse_storage_events_forward_extremities_persisted", - "Number of forward extremities for each new event", - buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), -) - -# The number of stale forward extremities for each new event. Stale extremities -# are those that were in the previous set of extremities as well as the new. -stale_forward_extremities_counter = Histogram( - "synapse_storage_events_stale_forward_extremities_persisted", - "Number of unchanged forward extremities for each new event", - buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), -) - def encode_json(json_object): """ @@ -102,110 +67,6 @@ def encode_json(json_object): return out -class _EventPeristenceQueue(object): - """Queues up events so that they can be persisted in bulk with only one - concurrent transaction per room. - """ - - _EventPersistQueueItem = namedtuple( - "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred") - ) - - def __init__(self): - self._event_persist_queues = {} - self._currently_persisting_rooms = set() - - def add_to_queue(self, room_id, events_and_contexts, backfilled): - """Add events to the queue, with the given persist_event options. - - NB: due to the normal usage pattern of this method, it does *not* - follow the synapse logcontext rules, and leaves the logcontext in - place whether or not the returned deferred is ready. - - Args: - room_id (str): - events_and_contexts (list[(EventBase, EventContext)]): - backfilled (bool): - - Returns: - defer.Deferred: a deferred which will resolve once the events are - persisted. Runs its callbacks *without* a logcontext. - """ - queue = self._event_persist_queues.setdefault(room_id, deque()) - if queue: - # if the last item in the queue has the same `backfilled` setting, - # we can just add these new events to that item. - end_item = queue[-1] - if end_item.backfilled == backfilled: - end_item.events_and_contexts.extend(events_and_contexts) - return end_item.deferred.observe() - - deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) - - queue.append( - self._EventPersistQueueItem( - events_and_contexts=events_and_contexts, - backfilled=backfilled, - deferred=deferred, - ) - ) - - return deferred.observe() - - def handle_queue(self, room_id, per_item_callback): - """Attempts to handle the queue for a room if not already being handled. - - The given callback will be invoked with for each item in the queue, - of type _EventPersistQueueItem. The per_item_callback will continuously - be called with new items, unless the queue becomnes empty. The return - value of the function will be given to the deferreds waiting on the item, - exceptions will be passed to the deferreds as well. - - This function should therefore be called whenever anything is added - to the queue. - - If another callback is currently handling the queue then it will not be - invoked. - """ - - if room_id in self._currently_persisting_rooms: - return - - self._currently_persisting_rooms.add(room_id) - - @defer.inlineCallbacks - def handle_queue_loop(): - try: - queue = self._get_drainining_queue(room_id) - for item in queue: - try: - ret = yield per_item_callback(item) - except Exception: - with PreserveLoggingContext(): - item.deferred.errback() - else: - with PreserveLoggingContext(): - item.deferred.callback(ret) - finally: - queue = self._event_persist_queues.pop(room_id, None) - if queue: - self._event_persist_queues[room_id] = queue - self._currently_persisting_rooms.discard(room_id) - - # set handle_queue_loop off in the background - run_as_background_process("persist_events", handle_queue_loop) - - def _get_drainining_queue(self, room_id): - queue = self._event_persist_queues.setdefault(room_id, deque()) - - try: - while True: - yield queue.popleft() - except IndexError: - # Queue has been drained. - pass - - _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) @@ -221,7 +82,7 @@ def _retry_on_integrity_error(func): @defer.inlineCallbacks def f(self, *args, **kwargs): try: - res = yield func(self, *args, **kwargs) + res = yield func(self, *args, delete_existing=False, **kwargs) except self.database_engine.module.IntegrityError: logger.exception("IntegrityError, retrying.") res = yield func(self, *args, delete_existing=True, **kwargs) @@ -241,9 +102,6 @@ class EventsStore( def __init__(self, db_conn, hs): super(EventsStore, self).__init__(db_conn, hs) - self._event_persist_queue = _EventPeristenceQueue() - self._state_resolution_handler = hs.get_state_resolution_handler() - # Collect metrics on the number of forward extremities that exist. # Counter of number of extremities to count self._current_forward_extremities_amount = c_counter() @@ -286,340 +144,106 @@ class EventsStore( res = yield self.runInteraction("read_forward_extremities", fetch) self._current_forward_extremities_amount = c_counter(list(x[0] for x in res)) - @defer.inlineCallbacks - def persist_events(self, events_and_contexts, backfilled=False): - """ - Write events to the database - Args: - events_and_contexts: list of tuples of (event, context) - backfilled (bool): Whether the results are retrieved from federation - via backfill or not. Used to determine if they're "new" events - which might update the current state etc. - - Returns: - Deferred[int]: the stream ordering of the latest persisted event - """ - partitioned = {} - for event, ctx in events_and_contexts: - partitioned.setdefault(event.room_id, []).append((event, ctx)) - - deferreds = [] - for room_id, evs_ctxs in iteritems(partitioned): - d = self._event_persist_queue.add_to_queue( - room_id, evs_ctxs, backfilled=backfilled - ) - deferreds.append(d) - - for room_id in partitioned: - self._maybe_start_persisting(room_id) - - yield make_deferred_yieldable( - defer.gatherResults(deferreds, consumeErrors=True) - ) - - max_persisted_id = yield self._stream_id_gen.get_current_token() - - return max_persisted_id - - @defer.inlineCallbacks - @log_function - def persist_event(self, event, context, backfilled=False): - """ - - Args: - event (EventBase): - context (EventContext): - backfilled (bool): - - Returns: - Deferred: resolves to (int, int): the stream ordering of ``event``, - and the stream ordering of the latest persisted event - """ - deferred = self._event_persist_queue.add_to_queue( - event.room_id, [(event, context)], backfilled=backfilled - ) - - self._maybe_start_persisting(event.room_id) - - yield make_deferred_yieldable(deferred) - - max_persisted_id = yield self._stream_id_gen.get_current_token() - return (event.internal_metadata.stream_ordering, max_persisted_id) - - def _maybe_start_persisting(self, room_id): - @defer.inlineCallbacks - def persisting_queue(item): - with Measure(self._clock, "persist_events"): - yield self._persist_events( - item.events_and_contexts, backfilled=item.backfilled - ) - - self._event_persist_queue.handle_queue(room_id, persisting_queue) - @_retry_on_integrity_error @defer.inlineCallbacks - def _persist_events( - self, events_and_contexts, backfilled=False, delete_existing=False + def _persist_events_and_state_updates( + self, + events_and_contexts, + current_state_for_room, + state_delta_for_room, + new_forward_extremeties, + backfilled=False, + delete_existing=False, ): - """Persist events to db + """Persist a set of events alongside updates to the current state and + forward extremities tables. Args: events_and_contexts (list[(EventBase, EventContext)]): - backfilled (bool): + current_state_for_room (dict[str, dict]): Map from room_id to the + current state of the room based on forward extremities + state_delta_for_room (dict[str, tuple]): Map from room_id to tuple + of `(to_delete, to_insert)` where to_delete is a list + of type/state keys to remove from current state, and to_insert + is a map (type,key)->event_id giving the state delta in each + room. + new_forward_extremities (dict[str, list[str]]): Map from room_id + to list of event IDs that are the new forward extremities of + the room. + backfilled (bool) delete_existing (bool): Returns: Deferred: resolves when the events have been persisted """ - if not events_and_contexts: - return - chunks = [ - events_and_contexts[x : x + 100] - for x in range(0, len(events_and_contexts), 100) - ] - - for chunk in chunks: - # We can't easily parallelize these since different chunks - # might contain the same event. :( - - # NB: Assumes that we are only persisting events for one room - # at a time. - - # map room_id->list[event_ids] giving the new forward - # extremities in each room - new_forward_extremeties = {} + # We want to calculate the stream orderings as late as possible, as + # we only notify after all events with a lesser stream ordering have + # been persisted. I.e. if we spend 10s inside the with block then + # that will delay all subsequent events from being notified about. + # Hence why we do it down here rather than wrapping the entire + # function. + # + # Its safe to do this after calculating the state deltas etc as we + # only need to protect the *persistence* of the events. This is to + # ensure that queries of the form "fetch events since X" don't + # return events and stream positions after events that are still in + # flight, as otherwise subsequent requests "fetch event since Y" + # will not return those events. + # + # Note: Multiple instances of this function cannot be in flight at + # the same time for the same room. + if backfilled: + stream_ordering_manager = self._backfill_id_gen.get_next_mult( + len(events_and_contexts) + ) + else: + stream_ordering_manager = self._stream_id_gen.get_next_mult( + len(events_and_contexts) + ) - # map room_id->(type,state_key)->event_id tracking the full - # state in each room after adding these events. - # This is simply used to prefill the get_current_state_ids - # cache - current_state_for_room = {} + with stream_ordering_manager as stream_orderings: + for (event, context), stream in zip(events_and_contexts, stream_orderings): + event.internal_metadata.stream_ordering = stream - # map room_id->(to_delete, to_insert) where to_delete is a list - # of type/state keys to remove from current state, and to_insert - # is a map (type,key)->event_id giving the state delta in each - # room - state_delta_for_room = {} + yield self.runInteraction( + "persist_events", + self._persist_events_txn, + events_and_contexts=events_and_contexts, + backfilled=backfilled, + delete_existing=delete_existing, + state_delta_for_room=state_delta_for_room, + new_forward_extremeties=new_forward_extremeties, + ) + persist_event_counter.inc(len(events_and_contexts)) if not backfilled: - with Measure(self._clock, "_calculate_state_and_extrem"): - # Work out the new "current state" for each room. - # We do this by working out what the new extremities are and then - # calculating the state from that. - events_by_room = {} - for event, context in chunk: - events_by_room.setdefault(event.room_id, []).append( - (event, context) - ) - - for room_id, ev_ctx_rm in iteritems(events_by_room): - latest_event_ids = yield self.get_latest_event_ids_in_room( - room_id - ) - new_latest_event_ids = yield self._calculate_new_extremities( - room_id, ev_ctx_rm, latest_event_ids - ) - - latest_event_ids = set(latest_event_ids) - if new_latest_event_ids == latest_event_ids: - # No change in extremities, so no change in state - continue - - # there should always be at least one forward extremity. - # (except during the initial persistence of the send_join - # results, in which case there will be no existing - # extremities, so we'll `continue` above and skip this bit.) - assert new_latest_event_ids, "No forward extremities left!" - - new_forward_extremeties[room_id] = new_latest_event_ids - - len_1 = ( - len(latest_event_ids) == 1 - and len(new_latest_event_ids) == 1 - ) - if len_1: - all_single_prev_not_state = all( - len(event.prev_event_ids()) == 1 - and not event.is_state() - for event, ctx in ev_ctx_rm - ) - # Don't bother calculating state if they're just - # a long chain of single ancestor non-state events. - if all_single_prev_not_state: - continue - - state_delta_counter.inc() - if len(new_latest_event_ids) == 1: - state_delta_single_event_counter.inc() - - # This is a fairly handwavey check to see if we could - # have guessed what the delta would have been when - # processing one of these events. - # What we're interested in is if the latest extremities - # were the same when we created the event as they are - # now. When this server creates a new event (as opposed - # to receiving it over federation) it will use the - # forward extremities as the prev_events, so we can - # guess this by looking at the prev_events and checking - # if they match the current forward extremities. - for ev, _ in ev_ctx_rm: - prev_event_ids = set(ev.prev_event_ids()) - if latest_event_ids == prev_event_ids: - state_delta_reuse_delta_counter.inc() - break - - logger.info("Calculating state delta for room %s", room_id) - with Measure( - self._clock, "persist_events.get_new_state_after_events" - ): - res = yield self._get_new_state_after_events( - room_id, - ev_ctx_rm, - latest_event_ids, - new_latest_event_ids, - ) - current_state, delta_ids = res - - # If either are not None then there has been a change, - # and we need to work out the delta (or use that - # given) - if delta_ids is not None: - # If there is a delta we know that we've - # only added or replaced state, never - # removed keys entirely. - state_delta_for_room[room_id] = ([], delta_ids) - elif current_state is not None: - with Measure( - self._clock, "persist_events.calculate_state_delta" - ): - delta = yield self._calculate_state_delta( - room_id, current_state - ) - state_delta_for_room[room_id] = delta - - # If we have the current_state then lets prefill - # the cache with it. - if current_state is not None: - current_state_for_room[room_id] = current_state - - # We want to calculate the stream orderings as late as possible, as - # we only notify after all events with a lesser stream ordering have - # been persisted. I.e. if we spend 10s inside the with block then - # that will delay all subsequent events from being notified about. - # Hence why we do it down here rather than wrapping the entire - # function. - # - # Its safe to do this after calculating the state deltas etc as we - # only need to protect the *persistence* of the events. This is to - # ensure that queries of the form "fetch events since X" don't - # return events and stream positions after events that are still in - # flight, as otherwise subsequent requests "fetch event since Y" - # will not return those events. - # - # Note: Multiple instances of this function cannot be in flight at - # the same time for the same room. - if backfilled: - stream_ordering_manager = self._backfill_id_gen.get_next_mult( - len(chunk) + # backfilled events have negative stream orderings, so we don't + # want to set the event_persisted_position to that. + synapse.metrics.event_persisted_position.set( + events_and_contexts[-1][0].internal_metadata.stream_ordering ) - else: - stream_ordering_manager = self._stream_id_gen.get_next_mult(len(chunk)) - - with stream_ordering_manager as stream_orderings: - for (event, context), stream in zip(chunk, stream_orderings): - event.internal_metadata.stream_ordering = stream - - yield self.runInteraction( - "persist_events", - self._persist_events_txn, - events_and_contexts=chunk, - backfilled=backfilled, - delete_existing=delete_existing, - state_delta_for_room=state_delta_for_room, - new_forward_extremeties=new_forward_extremeties, - ) - persist_event_counter.inc(len(chunk)) - if not backfilled: - # backfilled events have negative stream orderings, so we don't - # want to set the event_persisted_position to that. - synapse.metrics.event_persisted_position.set( - chunk[-1][0].internal_metadata.stream_ordering - ) - - for event, context in chunk: - if context.app_service: - origin_type = "local" - origin_entity = context.app_service.id - elif self.hs.is_mine_id(event.sender): - origin_type = "local" - origin_entity = "*client*" - else: - origin_type = "remote" - origin_entity = get_domain_from_id(event.sender) - - event_counter.labels(event.type, origin_type, origin_entity).inc() - - for room_id, new_state in iteritems(current_state_for_room): - self.get_current_state_ids.prefill((room_id,), new_state) - - for room_id, latest_event_ids in iteritems(new_forward_extremeties): - self.get_latest_event_ids_in_room.prefill( - (room_id,), list(latest_event_ids) - ) - - @defer.inlineCallbacks - def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids): - """Calculates the new forward extremities for a room given events to - persist. - - Assumes that we are only persisting events for one room at a time. - """ - - # we're only interested in new events which aren't outliers and which aren't - # being rejected. - new_events = [ - event - for event, ctx in event_contexts - if not event.internal_metadata.is_outlier() - and not ctx.rejected - and not event.internal_metadata.is_soft_failed() - ] - - latest_event_ids = set(latest_event_ids) - - # start with the existing forward extremities - result = set(latest_event_ids) - - # add all the new events to the list - result.update(event.event_id for event in new_events) - - # Now remove all events which are prev_events of any of the new events - result.difference_update( - e_id for event in new_events for e_id in event.prev_event_ids() - ) - - # Remove any events which are prev_events of any existing events. - existing_prevs = yield self._get_events_which_are_prevs(result) - result.difference_update(existing_prevs) + for event, context in events_and_contexts: + if context.app_service: + origin_type = "local" + origin_entity = context.app_service.id + elif self.hs.is_mine_id(event.sender): + origin_type = "local" + origin_entity = "*client*" + else: + origin_type = "remote" + origin_entity = get_domain_from_id(event.sender) - # Finally handle the case where the new events have soft-failed prev - # events. If they do we need to remove them and their prev events, - # otherwise we end up with dangling extremities. - existing_prevs = yield self._get_prevs_before_rejected( - e_id for event in new_events for e_id in event.prev_event_ids() - ) - result.difference_update(existing_prevs) + event_counter.labels(event.type, origin_type, origin_entity).inc() - # We only update metrics for events that change forward extremities - # (e.g. we ignore backfill/outliers/etc) - if result != latest_event_ids: - forward_extremities_counter.observe(len(result)) - stale = latest_event_ids & result - stale_forward_extremities_counter.observe(len(stale)) + for room_id, new_state in iteritems(current_state_for_room): + self.get_current_state_ids.prefill((room_id,), new_state) - return result + for room_id, latest_event_ids in iteritems(new_forward_extremeties): + self.get_latest_event_ids_in_room.prefill( + (room_id,), list(latest_event_ids) + ) @defer.inlineCallbacks def _get_events_which_are_prevs(self, event_ids): @@ -725,188 +349,6 @@ class EventsStore( return existing_prevs - @defer.inlineCallbacks - def _get_new_state_after_events( - self, room_id, events_context, old_latest_event_ids, new_latest_event_ids - ): - """Calculate the current state dict after adding some new events to - a room - - Args: - room_id (str): - room to which the events are being added. Used for logging etc - - events_context (list[(EventBase, EventContext)]): - events and contexts which are being added to the room - - old_latest_event_ids (iterable[str]): - the old forward extremities for the room. - - new_latest_event_ids (iterable[str]): - the new forward extremities for the room. - - Returns: - Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]: - Returns a tuple of two state maps, the first being the full new current - state and the second being the delta to the existing current state. - If both are None then there has been no change. - - If there has been a change then we only return the delta if its - already been calculated. Conversely if we do know the delta then - the new current state is only returned if we've already calculated - it. - """ - # map from state_group to ((type, key) -> event_id) state map - state_groups_map = {} - - # Map from (prev state group, new state group) -> delta state dict - state_group_deltas = {} - - for ev, ctx in events_context: - if ctx.state_group is None: - # This should only happen for outlier events. - if not ev.internal_metadata.is_outlier(): - raise Exception( - "Context for new event %s has no state " - "group" % (ev.event_id,) - ) - continue - - if ctx.state_group in state_groups_map: - continue - - # We're only interested in pulling out state that has already - # been cached in the context. We'll pull stuff out of the DB later - # if necessary. - current_state_ids = ctx.get_cached_current_state_ids() - if current_state_ids is not None: - state_groups_map[ctx.state_group] = current_state_ids - - if ctx.prev_group: - state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids - - # We need to map the event_ids to their state groups. First, let's - # check if the event is one we're persisting, in which case we can - # pull the state group from its context. - # Otherwise we need to pull the state group from the database. - - # Set of events we need to fetch groups for. (We know none of the old - # extremities are going to be in events_context). - missing_event_ids = set(old_latest_event_ids) - - event_id_to_state_group = {} - for event_id in new_latest_event_ids: - # First search in the list of new events we're adding. - for ev, ctx in events_context: - if event_id == ev.event_id and ctx.state_group is not None: - event_id_to_state_group[event_id] = ctx.state_group - break - else: - # If we couldn't find it, then we'll need to pull - # the state from the database - missing_event_ids.add(event_id) - - if missing_event_ids: - # Now pull out the state groups for any missing events from DB - event_to_groups = yield self._get_state_group_for_events(missing_event_ids) - event_id_to_state_group.update(event_to_groups) - - # State groups of old_latest_event_ids - old_state_groups = set( - event_id_to_state_group[evid] for evid in old_latest_event_ids - ) - - # State groups of new_latest_event_ids - new_state_groups = set( - event_id_to_state_group[evid] for evid in new_latest_event_ids - ) - - # If they old and new groups are the same then we don't need to do - # anything. - if old_state_groups == new_state_groups: - return None, None - - if len(new_state_groups) == 1 and len(old_state_groups) == 1: - # If we're going from one state group to another, lets check if - # we have a delta for that transition. If we do then we can just - # return that. - - new_state_group = next(iter(new_state_groups)) - old_state_group = next(iter(old_state_groups)) - - delta_ids = state_group_deltas.get((old_state_group, new_state_group), None) - if delta_ids is not None: - # We have a delta from the existing to new current state, - # so lets just return that. If we happen to already have - # the current state in memory then lets also return that, - # but it doesn't matter if we don't. - new_state = state_groups_map.get(new_state_group) - return new_state, delta_ids - - # Now that we have calculated new_state_groups we need to get - # their state IDs so we can resolve to a single state set. - missing_state = new_state_groups - set(state_groups_map) - if missing_state: - group_to_state = yield self._get_state_for_groups(missing_state) - state_groups_map.update(group_to_state) - - if len(new_state_groups) == 1: - # If there is only one state group, then we know what the current - # state is. - return state_groups_map[new_state_groups.pop()], None - - # Ok, we need to defer to the state handler to resolve our state sets. - - state_groups = {sg: state_groups_map[sg] for sg in new_state_groups} - - events_map = {ev.event_id: ev for ev, _ in events_context} - - # We need to get the room version, which is in the create event. - # Normally that'd be in the database, but its also possible that we're - # currently trying to persist it. - room_version = None - for ev, _ in events_context: - if ev.type == EventTypes.Create and ev.state_key == "": - room_version = ev.content.get("room_version", "1") - break - - if not room_version: - room_version = yield self.get_room_version(room_id) - - logger.debug("calling resolve_state_groups from preserve_events") - res = yield self._state_resolution_handler.resolve_state_groups( - room_id, - room_version, - state_groups, - events_map, - state_res_store=StateResolutionStore(self), - ) - - return res.state, None - - @defer.inlineCallbacks - def _calculate_state_delta(self, room_id, current_state): - """Calculate the new state deltas for a room. - - Assumes that we are only persisting events for one room at a time. - - Returns: - tuple[list, dict] (to_delete, to_insert): where to_delete are the - type/state_keys to remove from current_state_events and `to_insert` - are the updates to current_state_events. - """ - existing_state = yield self.get_current_state_ids(room_id) - - to_delete = [key for key in existing_state if key not in current_state] - - to_insert = { - key: ev_id - for key, ev_id in iteritems(current_state) - if ev_id != existing_state.get(key) - } - - return to_delete, to_insert - @log_function def _persist_events_txn( self, @@ -1490,6 +932,13 @@ class EventsStore( self._handle_event_relations(txn, event) + # Store the labels for this event. + labels = event.content.get(EventContentFields.LABELS) + if labels: + self.insert_labels_for_event_txn( + txn, event.event_id, labels, event.room_id, event.depth + ) + # Insert into the room_memberships table. self._store_room_members_txn( txn, @@ -1683,7 +1132,7 @@ class EventsStore( AND stream_ordering > ? """ txn.execute(sql, (self.stream_ordering_day_ago,)) - count, = txn.fetchone() + (count,) = txn.fetchone() return count ret = yield self.runInteraction("count_messages", _count_messages) @@ -1704,7 +1153,7 @@ class EventsStore( """ txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) - count, = txn.fetchone() + (count,) = txn.fetchone() return count ret = yield self.runInteraction("count_daily_sent_messages", _count_messages) @@ -1719,7 +1168,7 @@ class EventsStore( AND stream_ordering > ? """ txn.execute(sql, (self.stream_ordering_day_ago,)) - count, = txn.fetchone() + (count,) = txn.fetchone() return count ret = yield self.runInteraction("count_daily_active_rooms", _count) @@ -1926,6 +1375,10 @@ class EventsStore( if True, we will delete local events as well as remote ones (instead of just marking them as outliers and deleting their state groups). + + Returns: + Deferred[set[int]]: The set of state groups that are referenced by + deleted events. """ return self.runInteraction( @@ -2062,11 +1515,10 @@ class EventsStore( [(room_id, event_id) for event_id, in new_backwards_extrems], ) - logger.info("[purge] finding redundant state groups") + logger.info("[purge] finding state groups referenced by deleted events") # Get all state groups that are referenced by events that are to be - # deleted. We then go and check if they are referenced by other events - # or state groups, and if not we delete them. + # deleted. txn.execute( """ SELECT DISTINCT state_group FROM events_to_purge @@ -2079,60 +1531,6 @@ class EventsStore( "[purge] found %i referenced state groups", len(referenced_state_groups) ) - logger.info("[purge] finding state groups that can be deleted") - - _ = self._find_unreferenced_groups_during_purge(txn, referenced_state_groups) - state_groups_to_delete, remaining_state_groups = _ - - logger.info( - "[purge] found %i state groups to delete", len(state_groups_to_delete) - ) - - logger.info( - "[purge] de-delta-ing %i remaining state groups", - len(remaining_state_groups), - ) - - # Now we turn the state groups that reference to-be-deleted state - # groups to non delta versions. - for sg in remaining_state_groups: - logger.info("[purge] de-delta-ing remaining state group %s", sg) - curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) - curr_state = curr_state[sg] - - self._simple_delete_txn( - txn, table="state_groups_state", keyvalues={"state_group": sg} - ) - - self._simple_delete_txn( - txn, table="state_group_edges", keyvalues={"state_group": sg} - ) - - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": sg, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(curr_state) - ], - ) - - logger.info("[purge] removing redundant state groups") - txn.executemany( - "DELETE FROM state_groups_state WHERE state_group = ?", - ((sg,) for sg in state_groups_to_delete), - ) - txn.executemany( - "DELETE FROM state_groups WHERE id = ?", - ((sg,) for sg in state_groups_to_delete), - ) - logger.info("[purge] removing events from event_to_state_groups") txn.execute( "DELETE FROM event_to_state_groups " @@ -2204,7 +1602,7 @@ class EventsStore( """, (room_id,), ) - min_depth, = txn.fetchone() + (min_depth,) = txn.fetchone() logger.info("[purge] updating room_depth to %d", min_depth) @@ -2219,138 +1617,35 @@ class EventsStore( logger.info("[purge] done") - def _find_unreferenced_groups_during_purge(self, txn, state_groups): - """Used when purging history to figure out which state groups can be - deleted and which need to be de-delta'ed (due to one of its prev groups - being scheduled for deletion). - - Args: - txn - state_groups (set[int]): Set of state groups referenced by events - that are going to be deleted. - - Returns: - tuple[set[int], set[int]]: The set of state groups that can be - deleted and the set of state groups that need to be de-delta'ed - """ - # Graph of state group -> previous group - graph = {} - - # Set of events that we have found to be referenced by events - referenced_groups = set() - - # Set of state groups we've already seen - state_groups_seen = set(state_groups) - - # Set of state groups to handle next. - next_to_search = set(state_groups) - while next_to_search: - # We bound size of groups we're looking up at once, to stop the - # SQL query getting too big - if len(next_to_search) < 100: - current_search = next_to_search - next_to_search = set() - else: - current_search = set(itertools.islice(next_to_search, 100)) - next_to_search -= current_search - - # Check if state groups are referenced - sql = """ - SELECT DISTINCT state_group FROM event_to_state_groups - LEFT JOIN events_to_purge AS ep USING (event_id) - WHERE ep.event_id IS NULL AND - """ - clause, args = make_in_list_sql_clause( - txn.database_engine, "state_group", current_search - ) - txn.execute(sql + clause, list(args)) - - referenced = set(sg for sg, in txn) - referenced_groups |= referenced - - # We don't continue iterating up the state group graphs for state - # groups that are referenced. - current_search -= referenced - - rows = self._simple_select_many_txn( - txn, - table="state_group_edges", - column="prev_state_group", - iterable=current_search, - keyvalues={}, - retcols=("prev_state_group", "state_group"), - ) - - prevs = set(row["state_group"] for row in rows) - # We don't bother re-handling groups we've already seen - prevs -= state_groups_seen - next_to_search |= prevs - state_groups_seen |= prevs - - for row in rows: - # Note: Each state group can have at most one prev group - graph[row["state_group"]] = row["prev_state_group"] - - to_delete = state_groups_seen - referenced_groups - - to_dedelta = set() - for sg in referenced_groups: - prev_sg = graph.get(sg) - if prev_sg and prev_sg in to_delete: - to_dedelta.add(sg) - - return to_delete, to_dedelta + return referenced_state_groups def purge_room(self, room_id): """Deletes all record of a room Args: - room_id (str): + room_id (str) + + Returns: + Deferred[List[int]]: The list of state groups to delete. """ return self.runInteraction("purge_room", self._purge_room_txn, room_id) def _purge_room_txn(self, txn, room_id): - # first we have to delete the state groups states - logger.info("[purge] removing %s from state_groups_state", room_id) - + # First we fetch all the state groups that should be deleted, before + # we delete that information. txn.execute( """ - DELETE FROM state_groups_state WHERE state_group IN ( - SELECT state_group FROM events JOIN event_to_state_groups USING(event_id) - WHERE events.room_id=? - ) + SELECT DISTINCT state_group FROM events + INNER JOIN event_to_state_groups USING(event_id) + WHERE events.room_id = ? """, (room_id,), ) - # ... and the state group edges - logger.info("[purge] removing %s from state_group_edges", room_id) - - txn.execute( - """ - DELETE FROM state_group_edges WHERE state_group IN ( - SELECT state_group FROM events JOIN event_to_state_groups USING(event_id) - WHERE events.room_id=? - ) - """, - (room_id,), - ) - - # ... and the state groups - logger.info("[purge] removing %s from state_groups", room_id) - - txn.execute( - """ - DELETE FROM state_groups WHERE id IN ( - SELECT state_group FROM events JOIN event_to_state_groups USING(event_id) - WHERE events.room_id=? - ) - """, - (room_id,), - ) + state_groups = [row[0] for row in txn] - # and then tables which lack an index on room_id but have one on event_id + # Now we delete tables which lack an index on room_id but have one on event_id for table in ( "event_auth", "event_edges", @@ -2396,7 +1691,6 @@ class EventsStore( "room_stats_earliest_token", "rooms", "stream_ordering_to_exterm", - "topics", "users_in_public_rooms", "users_who_share_private_rooms", # no useful index, but let's clear them anyway @@ -2439,12 +1733,170 @@ class EventsStore( logger.info("[purge] done") + return state_groups + + def purge_unreferenced_state_groups( + self, room_id: str, state_groups_to_delete + ) -> defer.Deferred: + """Deletes no longer referenced state groups and de-deltas any state + groups that reference them. + + Args: + room_id: The room the state groups belong to (must all be in the + same room). + state_groups_to_delete (Collection[int]): Set of all state groups + to delete. + """ + + return self.runInteraction( + "purge_unreferenced_state_groups", + self._purge_unreferenced_state_groups, + room_id, + state_groups_to_delete, + ) + + def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete): + logger.info( + "[purge] found %i state groups to delete", len(state_groups_to_delete) + ) + + rows = self._simple_select_many_txn( + txn, + table="state_group_edges", + column="prev_state_group", + iterable=state_groups_to_delete, + keyvalues={}, + retcols=("state_group",), + ) + + remaining_state_groups = set( + row["state_group"] + for row in rows + if row["state_group"] not in state_groups_to_delete + ) + + logger.info( + "[purge] de-delta-ing %i remaining state groups", + len(remaining_state_groups), + ) + + # Now we turn the state groups that reference to-be-deleted state + # groups to non delta versions. + for sg in remaining_state_groups: + logger.info("[purge] de-delta-ing remaining state group %s", sg) + curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) + curr_state = curr_state[sg] + + self._simple_delete_txn( + txn, table="state_groups_state", keyvalues={"state_group": sg} + ) + + self._simple_delete_txn( + txn, table="state_group_edges", keyvalues={"state_group": sg} + ) + + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": sg, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(curr_state) + ], + ) + + logger.info("[purge] removing redundant state groups") + txn.executemany( + "DELETE FROM state_groups_state WHERE state_group = ?", + ((sg,) for sg in state_groups_to_delete), + ) + txn.executemany( + "DELETE FROM state_groups WHERE id = ?", + ((sg,) for sg in state_groups_to_delete), + ) + @defer.inlineCallbacks - def is_event_after(self, event_id1, event_id2): + def get_previous_state_groups(self, state_groups): + """Fetch the previous groups of the given state groups. + + Args: + state_groups (Iterable[int]) + + Returns: + Deferred[dict[int, int]]: mapping from state group to previous + state group. + """ + + rows = yield self._simple_select_many_batch( + table="state_group_edges", + column="prev_state_group", + iterable=state_groups, + keyvalues={}, + retcols=("prev_state_group", "state_group"), + desc="get_previous_state_groups", + ) + + return {row["state_group"]: row["prev_state_group"] for row in rows} + + def purge_room_state(self, room_id, state_groups_to_delete): + """Deletes all record of a room from state tables + + Args: + room_id (str): + state_groups_to_delete (list[int]): State groups to delete + """ + + return self.runInteraction( + "purge_room_state", + self._purge_room_state_txn, + room_id, + state_groups_to_delete, + ) + + def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete): + # first we have to delete the state groups states + logger.info("[purge] removing %s from state_groups_state", room_id) + + self._simple_delete_many_txn( + txn, + table="state_groups_state", + column="state_group", + iterable=state_groups_to_delete, + keyvalues={}, + ) + + # ... and the state group edges + logger.info("[purge] removing %s from state_group_edges", room_id) + + self._simple_delete_many_txn( + txn, + table="state_group_edges", + column="state_group", + iterable=state_groups_to_delete, + keyvalues={}, + ) + + # ... and the state groups + logger.info("[purge] removing %s from state_groups", room_id) + + self._simple_delete_many_txn( + txn, + table="state_groups", + column="id", + iterable=state_groups_to_delete, + keyvalues={}, + ) + + async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream """ - to_1, so_1 = yield self._get_event_ordering(event_id1) - to_2, so_2 = yield self._get_event_ordering(event_id2) + to_1, so_1 = await self._get_event_ordering(event_id1) + to_2, so_2 = await self._get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) @cachedInlineCallbacks(max_entries=5000) @@ -2477,6 +1929,33 @@ class EventsStore( get_all_updated_current_state_deltas_txn, ) + def insert_labels_for_event_txn( + self, txn, event_id, labels, room_id, topological_ordering + ): + """Store the mapping between an event's ID and its labels, with one row per + (event_id, label) tuple. + + Args: + txn (LoggingTransaction): The transaction to execute. + event_id (str): The event's ID. + labels (list[str]): A list of text labels. + room_id (str): The ID of the room the event was sent to. + topological_ordering (int): The position of the event in the room's topology. + """ + return self._simple_insert_many_txn( + txn=txn, + table="event_labels", + values=[ + { + "event_id": event_id, + "label": label, + "room_id": room_id, + "topological_ordering": topological_ordering, + } + for label in labels + ], + ) + AllNewEventsResult = namedtuple( "AllNewEventsResult", diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index 31ea6f917f..0ed59ef48e 100644 --- a/synapse/storage/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -21,6 +21,7 @@ from canonicaljson import json from twisted.internet import defer +from synapse.api.constants import EventContentFields from synapse.storage._base import make_in_list_sql_clause from synapse.storage.background_updates import BackgroundUpdateStore @@ -85,6 +86,10 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): "event_fix_redactions_bytes", self._event_fix_redactions_bytes ) + self.register_background_update_handler( + "event_store_labels", self._event_store_labels + ) + @defer.inlineCallbacks def _background_reindex_fields_sender(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] @@ -438,7 +443,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): if not rows: return 0 - upper_event_id, = rows[-1] + (upper_event_id,) = rows[-1] # Update the redactions with the received_ts. # @@ -503,3 +508,61 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): yield self._end_background_update("event_fix_redactions_bytes") return 1 + + @defer.inlineCallbacks + def _event_store_labels(self, progress, batch_size): + """Background update handler which will store labels for existing events.""" + last_event_id = progress.get("last_event_id", "") + + def _event_store_labels_txn(txn): + txn.execute( + """ + SELECT event_id, json FROM event_json + LEFT JOIN event_labels USING (event_id) + WHERE event_id > ? AND label IS NULL + ORDER BY event_id LIMIT ? + """, + (last_event_id, batch_size), + ) + + results = list(txn) + + nbrows = 0 + last_row_event_id = "" + for (event_id, event_json_raw) in results: + event_json = json.loads(event_json_raw) + + self._simple_insert_many_txn( + txn=txn, + table="event_labels", + values=[ + { + "event_id": event_id, + "label": label, + "room_id": event_json["room_id"], + "topological_ordering": event_json["depth"], + } + for label in event_json["content"].get( + EventContentFields.LABELS, [] + ) + if isinstance(label, str) + ], + ) + + nbrows += 1 + last_row_event_id = event_id + + self._background_update_progress_txn( + txn, "event_store_labels", {"last_event_id": last_row_event_id} + ) + + return nbrows + + num_rows = yield self.runInteraction( + desc="event_store_labels", func=_event_store_labels_txn + ) + + if not num_rows: + yield self._end_background_update("event_store_labels") + + return num_rows diff --git a/synapse/storage/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 4c4b76bd93..4c4b76bd93 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py diff --git a/synapse/storage/filtering.py b/synapse/storage/data_stores/main/filtering.py index 7c2a7da836..a2a2a67927 100644 --- a/synapse/storage/filtering.py +++ b/synapse/storage/data_stores/main/filtering.py @@ -16,10 +16,9 @@ from canonicaljson import encode_canonical_json from synapse.api.errors import Codes, SynapseError +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.util.caches.descriptors import cachedInlineCallbacks -from ._base import SQLBaseStore, db_to_json - class FilteringStore(SQLBaseStore): @cachedInlineCallbacks(num_args=2) diff --git a/synapse/storage/group_server.py b/synapse/storage/data_stores/main/group_server.py index 15b01c6958..5ded539af8 100644 --- a/synapse/storage/group_server.py +++ b/synapse/storage/data_stores/main/group_server.py @@ -19,8 +19,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.errors import SynapseError - -from ._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore # The category ID for the "default" category. We don't store as null in the # database to avoid the fun of null != null @@ -250,7 +249,7 @@ class GroupServerStore(SQLBaseStore): WHERE group_id = ? AND category_id = ? """ txn.execute(sql, (group_id, category_id)) - order, = txn.fetchone() + (order,) = txn.fetchone() if existing: to_update = {} @@ -510,7 +509,7 @@ class GroupServerStore(SQLBaseStore): WHERE group_id = ? AND role_id = ? """ txn.execute(sql, (group_id, role_id)) - order, = txn.fetchone() + (order,) = txn.fetchone() if existing: to_update = {} @@ -554,6 +553,21 @@ class GroupServerStore(SQLBaseStore): desc="remove_user_from_summary", ) + def get_local_groups_for_room(self, room_id): + """Get all of the local group that contain a given room + Args: + room_id (str): The ID of a room + Returns: + Deferred[list[str]]: A twisted.Deferred containing a list of group ids + containing this room + """ + return self._simple_select_onecol( + table="group_rooms", + keyvalues={"room_id": room_id}, + retcol="group_id", + desc="get_local_groups_for_room", + ) + def get_users_for_summary_by_role(self, group_id, include_private=False): """Get the users and roles that should be included in a summary request diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py new file mode 100644 index 0000000000..ebc7db3ed6 --- /dev/null +++ b/synapse/storage/data_stores/main/keys.py @@ -0,0 +1,214 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import logging + +import six + +from signedjson.key import decode_verify_key_bytes + +from synapse.storage._base import SQLBaseStore +from synapse.storage.keys import FetchKeyResult +from synapse.util import batch_iter +from synapse.util.caches.descriptors import cached, cachedList + +logger = logging.getLogger(__name__) + +# py2 sqlite has buffer hardcoded as only binary type, so we must use it, +# despite being deprecated and removed in favor of memoryview +if six.PY2: + db_binary_type = six.moves.builtins.buffer +else: + db_binary_type = memoryview + + +class KeyStore(SQLBaseStore): + """Persistence for signature verification keys + """ + + @cached() + def _get_server_verify_key(self, server_name_and_key_id): + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids" + ) + def get_server_verify_keys(self, server_name_and_key_ids): + """ + Args: + server_name_and_key_ids (iterable[Tuple[str, str]]): + iterable of (server_name, key-id) tuples to fetch keys for + + Returns: + Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]: + map from (server_name, key_id) -> FetchKeyResult, or None if the key is + unknown + """ + keys = {} + + def _get_keys(txn, batch): + """Processes a batch of keys to fetch, and adds the result to `keys`.""" + + # batch_iter always returns tuples so it's safe to do len(batch) + sql = ( + "SELECT server_name, key_id, verify_key, ts_valid_until_ms " + "FROM server_signature_keys WHERE 1=0" + ) + " OR (server_name=? AND key_id=?)" * len(batch) + + txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) + + for row in txn: + server_name, key_id, key_bytes, ts_valid_until_ms = row + + if ts_valid_until_ms is None: + # Old keys may be stored with a ts_valid_until_ms of null, + # in which case we treat this as if it was set to `0`, i.e. + # it won't match key requests that define a minimum + # `ts_valid_until_ms`. + ts_valid_until_ms = 0 + + res = FetchKeyResult( + verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), + valid_until_ts=ts_valid_until_ms, + ) + keys[(server_name, key_id)] = res + + def _txn(txn): + for batch in batch_iter(server_name_and_key_ids, 50): + _get_keys(txn, batch) + return keys + + return self.runInteraction("get_server_verify_keys", _txn) + + def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): + """Stores NACL verification keys for remote servers. + Args: + from_server (str): Where the verification keys were looked up + ts_added_ms (int): The time to record that the key was added + verify_keys (iterable[tuple[str, str, FetchKeyResult]]): + keys to be stored. Each entry is a triplet of + (server_name, key_id, key). + """ + key_values = [] + value_values = [] + invalidations = [] + for server_name, key_id, fetch_result in verify_keys: + key_values.append((server_name, key_id)) + value_values.append( + ( + from_server, + ts_added_ms, + fetch_result.valid_until_ts, + db_binary_type(fetch_result.verify_key.encode()), + ) + ) + # invalidate takes a tuple corresponding to the params of + # _get_server_verify_key. _get_server_verify_key only takes one + # param, which is itself the 2-tuple (server_name, key_id). + invalidations.append((server_name, key_id)) + + def _invalidate(res): + f = self._get_server_verify_key.invalidate + for i in invalidations: + f((i,)) + return res + + return self.runInteraction( + "store_server_verify_keys", + self._simple_upsert_many_txn, + table="server_signature_keys", + key_names=("server_name", "key_id"), + key_values=key_values, + value_names=( + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "verify_key", + ), + value_values=value_values, + ).addCallback(_invalidate) + + def store_server_keys_json( + self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes + ): + """Stores the JSON bytes for a set of keys from a server + The JSON should be signed by the originating server, the intermediate + server, and by this server. Updates the value for the + (server_name, key_id, from_server) triplet if one already existed. + Args: + server_name (str): The name of the server. + key_id (str): The identifer of the key this JSON is for. + from_server (str): The server this JSON was fetched from. + ts_now_ms (int): The time now in milliseconds. + ts_valid_until_ms (int): The time when this json stops being valid. + key_json (bytes): The encoded JSON. + """ + return self._simple_upsert( + table="server_keys_json", + keyvalues={ + "server_name": server_name, + "key_id": key_id, + "from_server": from_server, + }, + values={ + "server_name": server_name, + "key_id": key_id, + "from_server": from_server, + "ts_added_ms": ts_now_ms, + "ts_valid_until_ms": ts_expires_ms, + "key_json": db_binary_type(key_json_bytes), + }, + desc="store_server_keys_json", + ) + + def get_server_keys_json(self, server_keys): + """Retrive the key json for a list of server_keys and key ids. + If no keys are found for a given server, key_id and source then + that server, key_id, and source triplet entry will be an empty list. + The JSON is returned as a byte array so that it can be efficiently + used in an HTTP response. + Args: + server_keys (list): List of (server_name, key_id, source) triplets. + Returns: + Deferred[dict[Tuple[str, str, str|None], list[dict]]]: + Dict mapping (server_name, key_id, source) triplets to lists of dicts + """ + + def _get_server_keys_json_txn(txn): + results = {} + for server_name, key_id, from_server in server_keys: + keyvalues = {"server_name": server_name} + if key_id is not None: + keyvalues["key_id"] = key_id + if from_server is not None: + keyvalues["from_server"] = from_server + rows = self._simple_select_list_txn( + txn, + "server_keys_json", + keyvalues=keyvalues, + retcols=( + "key_id", + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "key_json", + ), + ) + results[(server_name, key_id, from_server)] = rows + return results + + return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn) diff --git a/synapse/storage/media_repository.py b/synapse/storage/data_stores/main/media_repository.py index 84b5f3ad5e..84b5f3ad5e 100644 --- a/synapse/storage/media_repository.py +++ b/synapse/storage/data_stores/main/media_repository.py diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py index 3803604be7..b41c3d317a 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/data_stores/main/monthly_active_users.py @@ -16,10 +16,9 @@ import logging from twisted.internet import defer +from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cached -from ._base import SQLBaseStore - logger = logging.getLogger(__name__) # Number of msec of granularity to store the monthly_active_user timestamp @@ -172,7 +171,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" txn.execute(sql) - count, = txn.fetchone() + (count,) = txn.fetchone() return count return self.runInteraction("count_users", _count_users) diff --git a/synapse/storage/openid.py b/synapse/storage/data_stores/main/openid.py index b3318045ee..79b40044d9 100644 --- a/synapse/storage/openid.py +++ b/synapse/storage/data_stores/main/openid.py @@ -1,4 +1,4 @@ -from ._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore class OpenIdStore(SQLBaseStore): diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py new file mode 100644 index 0000000000..523ed6575e --- /dev/null +++ b/synapse/storage/data_stores/main/presence.py @@ -0,0 +1,150 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.presence import UserPresenceState +from synapse.util import batch_iter +from synapse.util.caches.descriptors import cached, cachedList + + +class PresenceStore(SQLBaseStore): + @defer.inlineCallbacks + def update_presence(self, presence_states): + stream_ordering_manager = self._presence_id_gen.get_next_mult( + len(presence_states) + ) + + with stream_ordering_manager as stream_orderings: + yield self.runInteraction( + "update_presence", + self._update_presence_txn, + stream_orderings, + presence_states, + ) + + return stream_orderings[-1], self._presence_id_gen.get_current_token() + + def _update_presence_txn(self, txn, stream_orderings, presence_states): + for stream_id, state in zip(stream_orderings, presence_states): + txn.call_after( + self.presence_stream_cache.entity_has_changed, state.user_id, stream_id + ) + txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,)) + + # Actually insert new rows + self._simple_insert_many_txn( + txn, + table="presence_stream", + values=[ + { + "stream_id": stream_id, + "user_id": state.user_id, + "state": state.state, + "last_active_ts": state.last_active_ts, + "last_federation_update_ts": state.last_federation_update_ts, + "last_user_sync_ts": state.last_user_sync_ts, + "status_msg": state.status_msg, + "currently_active": state.currently_active, + } + for state in presence_states + ], + ) + + # Delete old rows to stop database from getting really big + sql = "DELETE FROM presence_stream WHERE stream_id < ? AND " + + for states in batch_iter(presence_states, 50): + clause, args = make_in_list_sql_clause( + self.database_engine, "user_id", [s.user_id for s in states] + ) + txn.execute(sql + clause, [stream_id] + list(args)) + + def get_all_presence_updates(self, last_id, current_id): + if last_id == current_id: + return defer.succeed([]) + + def get_all_presence_updates_txn(txn): + sql = ( + "SELECT stream_id, user_id, state, last_active_ts," + " last_federation_update_ts, last_user_sync_ts, status_msg," + " currently_active" + " FROM presence_stream" + " WHERE ? < stream_id AND stream_id <= ?" + ) + txn.execute(sql, (last_id, current_id)) + return txn.fetchall() + + return self.runInteraction( + "get_all_presence_updates", get_all_presence_updates_txn + ) + + @cached() + def _get_presence_for_user(self, user_id): + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_presence_for_user", + list_name="user_ids", + num_args=1, + inlineCallbacks=True, + ) + def get_presence_for_users(self, user_ids): + rows = yield self._simple_select_many_batch( + table="presence_stream", + column="user_id", + iterable=user_ids, + keyvalues={}, + retcols=( + "user_id", + "state", + "last_active_ts", + "last_federation_update_ts", + "last_user_sync_ts", + "status_msg", + "currently_active", + ), + desc="get_presence_for_users", + ) + + for row in rows: + row["currently_active"] = bool(row["currently_active"]) + + return {row["user_id"]: UserPresenceState(**row) for row in rows} + + def get_current_presence_token(self): + return self._presence_id_gen.get_current_token() + + def allow_presence_visible(self, observed_localpart, observer_userid): + return self._simple_insert( + table="presence_allow_inbound", + values={ + "observed_user_id": observed_localpart, + "observer_user_id": observer_userid, + }, + desc="allow_presence_visible", + or_ignore=True, + ) + + def disallow_presence_visible(self, observed_localpart, observer_userid): + return self._simple_delete_one( + table="presence_allow_inbound", + keyvalues={ + "observed_user_id": observed_localpart, + "observer_user_id": observer_userid, + }, + desc="disallow_presence_visible", + ) diff --git a/synapse/storage/profile.py b/synapse/storage/data_stores/main/profile.py index 912c1df6be..e4e8a1c1d6 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/data_stores/main/profile.py @@ -16,9 +16,8 @@ from twisted.internet import defer from synapse.api.errors import StoreError -from synapse.storage.roommember import ProfileInfo - -from ._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.roommember import ProfileInfo class ProfileWorkerStore(SQLBaseStore): diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py new file mode 100644 index 0000000000..b520062d84 --- /dev/null +++ b/synapse/storage/data_stores/main/push_rule.py @@ -0,0 +1,713 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import logging + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.push.baserules import list_with_base_rules +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore +from synapse.storage.data_stores.main.pusher import PusherWorkerStore +from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore +from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore +from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException +from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList +from synapse.util.caches.stream_change_cache import StreamChangeCache + +logger = logging.getLogger(__name__) + + +def _load_rules(rawrules, enabled_map): + ruleslist = [] + for rawrule in rawrules: + rule = dict(rawrule) + rule["conditions"] = json.loads(rawrule["conditions"]) + rule["actions"] = json.loads(rawrule["actions"]) + ruleslist.append(rule) + + # We're going to be mutating this a lot, so do a deep copy + rules = list(list_with_base_rules(ruleslist)) + + for i, rule in enumerate(rules): + rule_id = rule["rule_id"] + if rule_id in enabled_map: + if rule.get("enabled", True) != bool(enabled_map[rule_id]): + # Rules are cached across users. + rule = dict(rule) + rule["enabled"] = bool(enabled_map[rule_id]) + rules[i] = rule + + return rules + + +class PushRulesWorkerStore( + ApplicationServiceWorkerStore, + ReceiptsWorkerStore, + PusherWorkerStore, + RoomMemberWorkerStore, + SQLBaseStore, +): + """This is an abstract base class where subclasses must implement + `get_max_push_rules_stream_id` which can be called in the initializer. + """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. + __metaclass__ = abc.ABCMeta + + def __init__(self, db_conn, hs): + super(PushRulesWorkerStore, self).__init__(db_conn, hs) + + push_rules_prefill, push_rules_id = self._get_cache_dict( + db_conn, + "push_rules_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=self.get_max_push_rules_stream_id(), + ) + + self.push_rules_stream_cache = StreamChangeCache( + "PushRulesStreamChangeCache", + push_rules_id, + prefilled_cache=push_rules_prefill, + ) + + @abc.abstractmethod + def get_max_push_rules_stream_id(self): + """Get the position of the push rules stream. + + Returns: + int + """ + raise NotImplementedError() + + @cachedInlineCallbacks(max_entries=5000) + def get_push_rules_for_user(self, user_id): + rows = yield self._simple_select_list( + table="push_rules", + keyvalues={"user_name": user_id}, + retcols=( + "user_name", + "rule_id", + "priority_class", + "priority", + "conditions", + "actions", + ), + desc="get_push_rules_enabled_for_user", + ) + + rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) + + enabled_map = yield self.get_push_rules_enabled_for_user(user_id) + + rules = _load_rules(rows, enabled_map) + + return rules + + @cachedInlineCallbacks(max_entries=5000) + def get_push_rules_enabled_for_user(self, user_id): + results = yield self._simple_select_list( + table="push_rules_enable", + keyvalues={"user_name": user_id}, + retcols=("user_name", "rule_id", "enabled"), + desc="get_push_rules_enabled_for_user", + ) + return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} + + def have_push_rules_changed_for_user(self, user_id, last_id): + if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): + return defer.succeed(False) + else: + + def have_push_rules_changed_txn(txn): + sql = ( + "SELECT COUNT(stream_id) FROM push_rules_stream" + " WHERE user_id = ? AND ? < stream_id" + ) + txn.execute(sql, (user_id, last_id)) + (count,) = txn.fetchone() + return bool(count) + + return self.runInteraction( + "have_push_rules_changed", have_push_rules_changed_txn + ) + + @cachedList( + cached_method_name="get_push_rules_for_user", + list_name="user_ids", + num_args=1, + inlineCallbacks=True, + ) + def bulk_get_push_rules(self, user_ids): + if not user_ids: + return {} + + results = {user_id: [] for user_id in user_ids} + + rows = yield self._simple_select_many_batch( + table="push_rules", + column="user_name", + iterable=user_ids, + retcols=("*",), + desc="bulk_get_push_rules", + ) + + rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) + + for row in rows: + results.setdefault(row["user_name"], []).append(row) + + enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) + + for user_id, rules in results.items(): + results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) + + return results + + @defer.inlineCallbacks + def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule): + """Copy a single push rule from one room to another for a specific user. + + Args: + new_room_id (str): ID of the new room. + user_id (str): ID of user the push rule belongs to. + rule (Dict): A push rule. + """ + # Create new rule id + rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1]) + new_rule_id = rule_id_scope + "/" + new_room_id + + # Change room id in each condition + for condition in rule.get("conditions", []): + if condition.get("key") == "room_id": + condition["pattern"] = new_room_id + + # Add the rule for the new room + yield self.add_push_rule( + user_id=user_id, + rule_id=new_rule_id, + priority_class=rule["priority_class"], + conditions=rule["conditions"], + actions=rule["actions"], + ) + + @defer.inlineCallbacks + def copy_push_rules_from_room_to_room_for_user( + self, old_room_id, new_room_id, user_id + ): + """Copy all of the push rules from one room to another for a specific + user. + + Args: + old_room_id (str): ID of the old room. + new_room_id (str): ID of the new room. + user_id (str): ID of user to copy push rules for. + """ + # Retrieve push rules for this user + user_push_rules = yield self.get_push_rules_for_user(user_id) + + # Get rules relating to the old room and copy them to the new room + for rule in user_push_rules: + conditions = rule.get("conditions", []) + if any( + (c.get("key") == "room_id" and c.get("pattern") == old_room_id) + for c in conditions + ): + yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule) + + @defer.inlineCallbacks + def bulk_get_push_rules_for_room(self, event, context): + state_group = context.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + current_state_ids = yield context.get_current_state_ids(self) + result = yield self._bulk_get_push_rules_for_room( + event.room_id, state_group, current_state_ids, event=event + ) + return result + + @cachedInlineCallbacks(num_args=2, cache_context=True) + def _bulk_get_push_rules_for_room( + self, room_id, state_group, current_state_ids, cache_context, event=None + ): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + # We also will want to generate notifs for other people in the room so + # their unread countss are correct in the event stream, but to avoid + # generating them for bot / AS users etc, we only do so for people who've + # sent a read receipt into the room. + + users_in_room = yield self._get_joined_users_from_context( + room_id, + state_group, + current_state_ids, + on_invalidate=cache_context.invalidate, + event=event, + ) + + # We ignore app service users for now. This is so that we don't fill + # up the `get_if_users_have_pushers` cache with AS entries that we + # know don't have pushers, nor even read receipts. + local_users_in_room = set( + u + for u in users_in_room + if self.hs.is_mine_id(u) + and not self.get_if_app_services_interested_in_user(u) + ) + + # users in the room who have pushers need to get push rules run because + # that's how their pushers work + if_users_with_pushers = yield self.get_if_users_have_pushers( + local_users_in_room, on_invalidate=cache_context.invalidate + ) + user_ids = set( + uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher + ) + + users_with_receipts = yield self.get_users_with_read_receipts_in_room( + room_id, on_invalidate=cache_context.invalidate + ) + + # any users with pushers must be ours: they have pushers + for uid in users_with_receipts: + if uid in local_users_in_room: + user_ids.add(uid) + + rules_by_user = yield self.bulk_get_push_rules( + user_ids, on_invalidate=cache_context.invalidate + ) + + rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} + + return rules_by_user + + @cachedList( + cached_method_name="get_push_rules_enabled_for_user", + list_name="user_ids", + num_args=1, + inlineCallbacks=True, + ) + def bulk_get_push_rules_enabled(self, user_ids): + if not user_ids: + return {} + + results = {user_id: {} for user_id in user_ids} + + rows = yield self._simple_select_many_batch( + table="push_rules_enable", + column="user_name", + iterable=user_ids, + retcols=("user_name", "rule_id", "enabled"), + desc="bulk_get_push_rules_enabled", + ) + for row in rows: + enabled = bool(row["enabled"]) + results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled + return results + + +class PushRuleStore(PushRulesWorkerStore): + @defer.inlineCallbacks + def add_push_rule( + self, + user_id, + rule_id, + priority_class, + conditions, + actions, + before=None, + after=None, + ): + conditions_json = json.dumps(conditions) + actions_json = json.dumps(actions) + with self._push_rules_stream_id_gen.get_next() as ids: + stream_id, event_stream_ordering = ids + if before or after: + yield self.runInteraction( + "_add_push_rule_relative_txn", + self._add_push_rule_relative_txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + conditions_json, + actions_json, + before, + after, + ) + else: + yield self.runInteraction( + "_add_push_rule_highest_priority_txn", + self._add_push_rule_highest_priority_txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + conditions_json, + actions_json, + ) + + def _add_push_rule_relative_txn( + self, + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + conditions_json, + actions_json, + before, + after, + ): + # Lock the table since otherwise we'll have annoying races between the + # SELECT here and the UPSERT below. + self.database_engine.lock_table(txn, "push_rules") + + relative_to_rule = before or after + + res = self._simple_select_one_txn( + txn, + table="push_rules", + keyvalues={"user_name": user_id, "rule_id": relative_to_rule}, + retcols=["priority_class", "priority"], + allow_none=True, + ) + + if not res: + raise RuleNotFoundException( + "before/after rule not found: %s" % (relative_to_rule,) + ) + + base_priority_class = res["priority_class"] + base_rule_priority = res["priority"] + + if base_priority_class != priority_class: + raise InconsistentRuleException( + "Given priority class does not match class of relative rule" + ) + + if before: + # Higher priority rules are executed first, So adding a rule before + # a rule means giving it a higher priority than that rule. + new_rule_priority = base_rule_priority + 1 + else: + # We increment the priority of the existing rules to make space for + # the new rule. Therefore if we want this rule to appear after + # an existing rule we give it the priority of the existing rule, + # and then increment the priority of the existing rule. + new_rule_priority = base_rule_priority + + sql = ( + "UPDATE push_rules SET priority = priority + 1" + " WHERE user_name = ? AND priority_class = ? AND priority >= ?" + ) + + txn.execute(sql, (user_id, priority_class, new_rule_priority)) + + self._upsert_push_rule_txn( + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + new_rule_priority, + conditions_json, + actions_json, + ) + + def _add_push_rule_highest_priority_txn( + self, + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + conditions_json, + actions_json, + ): + # Lock the table since otherwise we'll have annoying races between the + # SELECT here and the UPSERT below. + self.database_engine.lock_table(txn, "push_rules") + + # find the highest priority rule in that class + sql = ( + "SELECT COUNT(*), MAX(priority) FROM push_rules" + " WHERE user_name = ? and priority_class = ?" + ) + txn.execute(sql, (user_id, priority_class)) + res = txn.fetchall() + (how_many, highest_prio) = res[0] + + new_prio = 0 + if how_many > 0: + new_prio = highest_prio + 1 + + self._upsert_push_rule_txn( + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + new_prio, + conditions_json, + actions_json, + ) + + def _upsert_push_rule_txn( + self, + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + priority, + conditions_json, + actions_json, + update_stream=True, + ): + """Specialised version of _simple_upsert_txn that picks a push_rule_id + using the _push_rule_id_gen if it needs to insert the rule. It assumes + that the "push_rules" table is locked""" + + sql = ( + "UPDATE push_rules" + " SET priority_class = ?, priority = ?, conditions = ?, actions = ?" + " WHERE user_name = ? AND rule_id = ?" + ) + + txn.execute( + sql, + (priority_class, priority, conditions_json, actions_json, user_id, rule_id), + ) + + if txn.rowcount == 0: + # We didn't update a row with the given rule_id so insert one + push_rule_id = self._push_rule_id_gen.get_next() + + self._simple_insert_txn( + txn, + table="push_rules", + values={ + "id": push_rule_id, + "user_name": user_id, + "rule_id": rule_id, + "priority_class": priority_class, + "priority": priority, + "conditions": conditions_json, + "actions": actions_json, + }, + ) + + if update_stream: + self._insert_push_rules_update_txn( + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + op="ADD", + data={ + "priority_class": priority_class, + "priority": priority, + "conditions": conditions_json, + "actions": actions_json, + }, + ) + + @defer.inlineCallbacks + def delete_push_rule(self, user_id, rule_id): + """ + Delete a push rule. Args specify the row to be deleted and can be + any of the columns in the push_rule table, but below are the + standard ones + + Args: + user_id (str): The matrix ID of the push rule owner + rule_id (str): The rule_id of the rule to be deleted + """ + + def delete_push_rule_txn(txn, stream_id, event_stream_ordering): + self._simple_delete_one_txn( + txn, "push_rules", {"user_name": user_id, "rule_id": rule_id} + ) + + self._insert_push_rules_update_txn( + txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" + ) + + with self._push_rules_stream_id_gen.get_next() as ids: + stream_id, event_stream_ordering = ids + yield self.runInteraction( + "delete_push_rule", + delete_push_rule_txn, + stream_id, + event_stream_ordering, + ) + + @defer.inlineCallbacks + def set_push_rule_enabled(self, user_id, rule_id, enabled): + with self._push_rules_stream_id_gen.get_next() as ids: + stream_id, event_stream_ordering = ids + yield self.runInteraction( + "_set_push_rule_enabled_txn", + self._set_push_rule_enabled_txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + enabled, + ) + + def _set_push_rule_enabled_txn( + self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled + ): + new_id = self._push_rules_enable_id_gen.get_next() + self._simple_upsert_txn( + txn, + "push_rules_enable", + {"user_name": user_id, "rule_id": rule_id}, + {"enabled": 1 if enabled else 0}, + {"id": new_id}, + ) + + self._insert_push_rules_update_txn( + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + op="ENABLE" if enabled else "DISABLE", + ) + + @defer.inlineCallbacks + def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule): + actions_json = json.dumps(actions) + + def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): + if is_default_rule: + # Add a dummy rule to the rules table with the user specified + # actions. + priority_class = -1 + priority = 1 + self._upsert_push_rule_txn( + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + priority, + "[]", + actions_json, + update_stream=False, + ) + else: + self._simple_update_one_txn( + txn, + "push_rules", + {"user_name": user_id, "rule_id": rule_id}, + {"actions": actions_json}, + ) + + self._insert_push_rules_update_txn( + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + op="ACTIONS", + data={"actions": actions_json}, + ) + + with self._push_rules_stream_id_gen.get_next() as ids: + stream_id, event_stream_ordering = ids + yield self.runInteraction( + "set_push_rule_actions", + set_push_rule_actions_txn, + stream_id, + event_stream_ordering, + ) + + def _insert_push_rules_update_txn( + self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None + ): + values = { + "stream_id": stream_id, + "event_stream_ordering": event_stream_ordering, + "user_id": user_id, + "rule_id": rule_id, + "op": op, + } + if data is not None: + values.update(data) + + self._simple_insert_txn(txn, "push_rules_stream", values=values) + + txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,)) + txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,)) + txn.call_after( + self.push_rules_stream_cache.entity_has_changed, user_id, stream_id + ) + + def get_all_push_rule_updates(self, last_id, current_id, limit): + """Get all the push rules changes that have happend on the server""" + if last_id == current_id: + return defer.succeed([]) + + def get_all_push_rule_updates_txn(txn): + sql = ( + "SELECT stream_id, event_stream_ordering, user_id, rule_id," + " op, priority_class, priority, conditions, actions" + " FROM push_rules_stream" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + + return self.runInteraction( + "get_all_push_rule_updates", get_all_push_rule_updates_txn + ) + + def get_push_rules_stream_token(self): + """Get the position of the push rules stream. + Returns a pair of a stream id for the push_rules stream and the + room stream ordering it corresponds to.""" + return self._push_rules_stream_id_gen.get_current_token() + + def get_max_push_rules_stream_id(self): + return self.get_push_rules_stream_token()[0] diff --git a/synapse/storage/pusher.py b/synapse/storage/data_stores/main/pusher.py index b12e80440a..d76861cdc0 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/data_stores/main/pusher.py @@ -22,10 +22,9 @@ from canonicaljson import encode_canonical_json, json from twisted.internet import defer +from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList -from ._base import SQLBaseStore - logger = logging.getLogger(__name__) if six.PY2: @@ -45,7 +44,7 @@ class PusherWorkerStore(SQLBaseStore): r["data"] = json.loads(dataJson) except Exception as e: - logger.warn( + logger.warning( "Invalid JSON in data for pusher %d: %s, %s", r["id"], dataJson, diff --git a/synapse/storage/receipts.py b/synapse/storage/data_stores/main/receipts.py index 0c24430f28..0c24430f28 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/data_stores/main/receipts.py diff --git a/synapse/storage/registration.py b/synapse/storage/data_stores/main/registration.py index 6c5b29288a..ee1b2b2bbf 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -459,7 +459,7 @@ class RegistrationWorkerStore(SQLBaseStore): WHERE appservice_id IS NULL """ ) - count, = txn.fetchone() + (count,) = txn.fetchone() return count ret = yield self.runInteraction("count_users", _count_users) @@ -488,14 +488,14 @@ class RegistrationWorkerStore(SQLBaseStore): we can. Unfortunately, it's possible some of them are already taken by existing users, and there may be gaps in the already taken range. This function returns the start of the first allocatable gap. This is to - avoid the case of ID 10000000 being pre-allocated, so us wasting the - first (and shortest) many generated user IDs. + avoid the case of ID 1000 being pre-allocated and starting at 1001 while + 0-999 are available. """ def _find_next_generated_user_id(txn): - # We bound between '@1' and '@a' to avoid pulling the entire table + # We bound between '@0' and '@a' to avoid pulling the entire table # out. - txn.execute("SELECT name FROM users WHERE '@1' <= name AND name < '@a'") + txn.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'") regex = re.compile(r"^@(\d+):") diff --git a/synapse/storage/rejections.py b/synapse/storage/data_stores/main/rejections.py index f4c1c2a457..7d5de0ea2e 100644 --- a/synapse/storage/rejections.py +++ b/synapse/storage/data_stores/main/rejections.py @@ -15,7 +15,7 @@ import logging -from ._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore logger = logging.getLogger(__name__) diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py new file mode 100644 index 0000000000..858f65582b --- /dev/null +++ b/synapse/storage/data_stores/main/relations.py @@ -0,0 +1,385 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import attr + +from synapse.api.constants import RelationTypes +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.stream import generate_pagination_where_clause +from synapse.storage.relations import ( + AggregationPaginationToken, + PaginationChunk, + RelationPaginationToken, +) +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks + +logger = logging.getLogger(__name__) + + +class RelationsWorkerStore(SQLBaseStore): + @cached(tree=True) + def get_relations_for_event( + self, + event_id, + relation_type=None, + event_type=None, + aggregation_key=None, + limit=5, + direction="b", + from_token=None, + to_token=None, + ): + """Get a list of relations for an event, ordered by topological ordering. + + Args: + event_id (str): Fetch events that relate to this event ID. + relation_type (str|None): Only fetch events with this relation + type, if given. + event_type (str|None): Only fetch events with this event type, if + given. + aggregation_key (str|None): Only fetch events with this aggregation + key, if given. + limit (int): Only fetch the most recent `limit` events. + direction (str): Whether to fetch the most recent first (`"b"`) or + the oldest first (`"f"`). + from_token (RelationPaginationToken|None): Fetch rows from the given + token, or from the start if None. + to_token (RelationPaginationToken|None): Fetch rows up to the given + token, or up to the end if None. + + Returns: + Deferred[PaginationChunk]: List of event IDs that match relations + requested. The rows are of the form `{"event_id": "..."}`. + """ + + where_clause = ["relates_to_id = ?"] + where_args = [event_id] + + if relation_type is not None: + where_clause.append("relation_type = ?") + where_args.append(relation_type) + + if event_type is not None: + where_clause.append("type = ?") + where_args.append(event_type) + + if aggregation_key: + where_clause.append("aggregation_key = ?") + where_args.append(aggregation_key) + + pagination_clause = generate_pagination_where_clause( + direction=direction, + column_names=("topological_ordering", "stream_ordering"), + from_token=attr.astuple(from_token) if from_token else None, + to_token=attr.astuple(to_token) if to_token else None, + engine=self.database_engine, + ) + + if pagination_clause: + where_clause.append(pagination_clause) + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + sql = """ + SELECT event_id, topological_ordering, stream_ordering + FROM event_relations + INNER JOIN events USING (event_id) + WHERE %s + ORDER BY topological_ordering %s, stream_ordering %s + LIMIT ? + """ % ( + " AND ".join(where_clause), + order, + order, + ) + + def _get_recent_references_for_event_txn(txn): + txn.execute(sql, where_args + [limit + 1]) + + last_topo_id = None + last_stream_id = None + events = [] + for row in txn: + events.append({"event_id": row[0]}) + last_topo_id = row[1] + last_stream_id = row[2] + + next_batch = None + if len(events) > limit and last_topo_id and last_stream_id: + next_batch = RelationPaginationToken(last_topo_id, last_stream_id) + + return PaginationChunk( + chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token + ) + + return self.runInteraction( + "get_recent_references_for_event", _get_recent_references_for_event_txn + ) + + @cached(tree=True) + def get_aggregation_groups_for_event( + self, + event_id, + event_type=None, + limit=5, + direction="b", + from_token=None, + to_token=None, + ): + """Get a list of annotations on the event, grouped by event type and + aggregation key, sorted by count. + + This is used e.g. to get the what and how many reactions have happend + on an event. + + Args: + event_id (str): Fetch events that relate to this event ID. + event_type (str|None): Only fetch events with this event type, if + given. + limit (int): Only fetch the `limit` groups. + direction (str): Whether to fetch the highest count first (`"b"`) or + the lowest count first (`"f"`). + from_token (AggregationPaginationToken|None): Fetch rows from the + given token, or from the start if None. + to_token (AggregationPaginationToken|None): Fetch rows up to the + given token, or up to the end if None. + + + Returns: + Deferred[PaginationChunk]: List of groups of annotations that + match. Each row is a dict with `type`, `key` and `count` fields. + """ + + where_clause = ["relates_to_id = ?", "relation_type = ?"] + where_args = [event_id, RelationTypes.ANNOTATION] + + if event_type: + where_clause.append("type = ?") + where_args.append(event_type) + + having_clause = generate_pagination_where_clause( + direction=direction, + column_names=("COUNT(*)", "MAX(stream_ordering)"), + from_token=attr.astuple(from_token) if from_token else None, + to_token=attr.astuple(to_token) if to_token else None, + engine=self.database_engine, + ) + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + if having_clause: + having_clause = "HAVING " + having_clause + else: + having_clause = "" + + sql = """ + SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering) + FROM event_relations + INNER JOIN events USING (event_id) + WHERE {where_clause} + GROUP BY relation_type, type, aggregation_key + {having_clause} + ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order} + LIMIT ? + """.format( + where_clause=" AND ".join(where_clause), + order=order, + having_clause=having_clause, + ) + + def _get_aggregation_groups_for_event_txn(txn): + txn.execute(sql, where_args + [limit + 1]) + + next_batch = None + events = [] + for row in txn: + events.append({"type": row[0], "key": row[1], "count": row[2]}) + next_batch = AggregationPaginationToken(row[2], row[3]) + + if len(events) <= limit: + next_batch = None + + return PaginationChunk( + chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token + ) + + return self.runInteraction( + "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn + ) + + @cachedInlineCallbacks() + def get_applicable_edit(self, event_id): + """Get the most recent edit (if any) that has happened for the given + event. + + Correctly handles checking whether edits were allowed to happen. + + Args: + event_id (str): The original event ID + + Returns: + Deferred[EventBase|None]: Returns the most recent edit, if any. + """ + + # We only allow edits for `m.room.message` events that have the same sender + # and event type. We can't assert these things during regular event auth so + # we have to do the checks post hoc. + + # Fetches latest edit that has the same type and sender as the + # original, and is an `m.room.message`. + sql = """ + SELECT edit.event_id FROM events AS edit + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS original ON + original.event_id = relates_to_id + AND edit.type = original.type + AND edit.sender = original.sender + WHERE + relates_to_id = ? + AND relation_type = ? + AND edit.type = 'm.room.message' + ORDER by edit.origin_server_ts DESC, edit.event_id DESC + LIMIT 1 + """ + + def _get_applicable_edit_txn(txn): + txn.execute(sql, (event_id, RelationTypes.REPLACE)) + row = txn.fetchone() + if row: + return row[0] + + edit_id = yield self.runInteraction( + "get_applicable_edit", _get_applicable_edit_txn + ) + + if not edit_id: + return + + edit_event = yield self.get_event(edit_id, allow_none=True) + return edit_event + + def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): + """Check if a user has already annotated an event with the same key + (e.g. already liked an event). + + Args: + parent_id (str): The event being annotated + event_type (str): The event type of the annotation + aggregation_key (str): The aggregation key of the annotation + sender (str): The sender of the annotation + + Returns: + Deferred[bool] + """ + + sql = """ + SELECT 1 FROM event_relations + INNER JOIN events USING (event_id) + WHERE + relates_to_id = ? + AND relation_type = ? + AND type = ? + AND sender = ? + AND aggregation_key = ? + LIMIT 1; + """ + + def _get_if_user_has_annotated_event(txn): + txn.execute( + sql, + ( + parent_id, + RelationTypes.ANNOTATION, + event_type, + sender, + aggregation_key, + ), + ) + + return bool(txn.fetchone()) + + return self.runInteraction( + "get_if_user_has_annotated_event", _get_if_user_has_annotated_event + ) + + +class RelationsStore(RelationsWorkerStore): + def _handle_event_relations(self, txn, event): + """Handles inserting relation data during peristence of events + + Args: + txn + event (EventBase) + """ + relation = event.content.get("m.relates_to") + if not relation: + # No relations + return + + rel_type = relation.get("rel_type") + if rel_type not in ( + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + RelationTypes.REPLACE, + ): + # Unknown relation type + return + + parent_id = relation.get("event_id") + if not parent_id: + # Invalid relation + return + + aggregation_key = relation.get("key") + + self._simple_insert_txn( + txn, + table="event_relations", + values={ + "event_id": event.event_id, + "relates_to_id": parent_id, + "relation_type": rel_type, + "aggregation_key": aggregation_key, + }, + ) + + txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,)) + txn.call_after( + self.get_aggregation_groups_for_event.invalidate_many, (parent_id,) + ) + + if rel_type == RelationTypes.REPLACE: + txn.call_after(self.get_applicable_edit.invalidate, (parent_id,)) + + def _handle_redaction(self, txn, redacted_event_id): + """Handles receiving a redaction and checking whether we need to remove + any redacted relations from the database. + + Args: + txn + redacted_event_id (str): The event that was redacted. + """ + + self._simple_delete_txn( + txn, table="event_relations", keyvalues={"event_id": redacted_event_id} + ) diff --git a/synapse/storage/room.py b/synapse/storage/data_stores/main/room.py index 43cc56fa6f..67bb1b6f60 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -25,7 +25,7 @@ from twisted.internet import defer from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore -from synapse.storage.search import SearchStore +from synapse.storage.data_stores.main.search import SearchStore from synapse.types import ThirdPartyInstanceID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -201,13 +201,17 @@ class RoomWorkerStore(SQLBaseStore): where_clauses.append( """ ( - name LIKE ? - OR topic LIKE ? - OR canonical_alias LIKE ? + LOWER(name) LIKE ? + OR LOWER(topic) LIKE ? + OR LOWER(canonical_alias) LIKE ? ) """ ) - query_args += [search_term, search_term, search_term] + query_args += [ + search_term.lower(), + search_term.lower(), + search_term.lower(), + ] where_clause = "" if where_clauses: diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py new file mode 100644 index 0000000000..2af24a20b7 --- /dev/null +++ b/synapse/storage/data_stores/main/roommember.py @@ -0,0 +1,1145 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from six import iteritems, itervalues + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.metrics import LaterGauge +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import LoggingTransaction, make_in_list_sql_clause +from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.engines import Sqlite3Engine +from synapse.storage.roommember import ( + GetRoomsForUserWithStreamOrdering, + MemberSummary, + ProfileInfo, + RoomsForUser, +) +from synapse.types import get_domain_from_id +from synapse.util.async_helpers import Linearizer +from synapse.util.caches import intern_string +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.metrics import Measure +from synapse.util.stringutils import to_ascii + +logger = logging.getLogger(__name__) + + +_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" +_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" + + +class RoomMemberWorkerStore(EventsWorkerStore): + def __init__(self, db_conn, hs): + super(RoomMemberWorkerStore, self).__init__(db_conn, hs) + + # Is the current_state_events.membership up to date? Or is the + # background update still running? + self._current_state_events_membership_up_to_date = False + + txn = LoggingTransaction( + db_conn.cursor(), + name="_check_safe_current_state_events_membership_updated", + database_engine=self.database_engine, + ) + self._check_safe_current_state_events_membership_updated_txn(txn) + txn.close() + + if self.hs.config.metrics_flags.known_servers: + self._known_servers_count = 1 + self.hs.get_clock().looping_call( + run_as_background_process, + 60 * 1000, + "_count_known_servers", + self._count_known_servers, + ) + self.hs.get_clock().call_later( + 1000, + run_as_background_process, + "_count_known_servers", + self._count_known_servers, + ) + LaterGauge( + "synapse_federation_known_servers", + "", + [], + lambda: self._known_servers_count, + ) + + @defer.inlineCallbacks + def _count_known_servers(self): + """ + Count the servers that this server knows about. + + The statistic is stored on the class for the + `synapse_federation_known_servers` LaterGauge to collect. + """ + + def _transact(txn): + if isinstance(self.database_engine, Sqlite3Engine): + query = """ + SELECT COUNT(DISTINCT substr(out.user_id, pos+1)) + FROM ( + SELECT rm.user_id as user_id, instr(rm.user_id, ':') + AS pos FROM room_memberships as rm + INNER JOIN current_state_events as c ON rm.event_id = c.event_id + WHERE c.type = 'm.room.member' + ) as out + """ + else: + query = """ + SELECT COUNT(DISTINCT split_part(state_key, ':', 2)) + FROM current_state_events + WHERE type = 'm.room.member' AND membership = 'join'; + """ + txn.execute(query) + return list(txn)[0][0] + + count = yield self.runInteraction("get_known_servers", _transact) + + # We always know about ourselves, even if we have nothing in + # room_memberships (for example, the server is new). + self._known_servers_count = max([count, 1]) + return self._known_servers_count + + def _check_safe_current_state_events_membership_updated_txn(self, txn): + """Checks if it is safe to assume the new current_state_events + membership column is up to date + """ + + pending_update = self._simple_select_one_txn( + txn, + table="background_updates", + keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME}, + retcols=["update_name"], + allow_none=True, + ) + + self._current_state_events_membership_up_to_date = not pending_update + + # If the update is still running, reschedule to run. + if pending_update: + self._clock.call_later( + 15.0, + run_as_background_process, + "_check_safe_current_state_events_membership_updated", + self.runInteraction, + "_check_safe_current_state_events_membership_updated", + self._check_safe_current_state_events_membership_updated_txn, + ) + + @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True) + def get_hosts_in_room(self, room_id, cache_context): + """Returns the set of all hosts currently in the room + """ + user_ids = yield self.get_users_in_room( + room_id, on_invalidate=cache_context.invalidate + ) + hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids) + return hosts + + @cached(max_entries=100000, iterable=True) + def get_users_in_room(self, room_id): + return self.runInteraction( + "get_users_in_room", self.get_users_in_room_txn, room_id + ) + + def get_users_in_room_txn(self, txn, room_id): + # If we can assume current_state_events.membership is up to date + # then we can avoid a join, which is a Very Good Thing given how + # frequently this function gets called. + if self._current_state_events_membership_up_to_date: + sql = """ + SELECT state_key FROM current_state_events + WHERE type = 'm.room.member' AND room_id = ? AND membership = ? + """ + else: + sql = """ + SELECT state_key FROM room_memberships as m + INNER JOIN current_state_events as c + ON m.event_id = c.event_id + AND m.room_id = c.room_id + AND m.user_id = c.state_key + WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? + """ + + txn.execute(sql, (room_id, Membership.JOIN)) + return [to_ascii(r[0]) for r in txn] + + @cached(max_entries=100000) + def get_room_summary(self, room_id): + """ Get the details of a room roughly suitable for use by the room + summary extension to /sync. Useful when lazy loading room members. + Args: + room_id (str): The room ID to query + Returns: + Deferred[dict[str, MemberSummary]: + dict of membership states, pointing to a MemberSummary named tuple. + """ + + def _get_room_summary_txn(txn): + # first get counts. + # We do this all in one transaction to keep the cache small. + # FIXME: get rid of this when we have room_stats + + # If we can assume current_state_events.membership is up to date + # then we can avoid a join, which is a Very Good Thing given how + # frequently this function gets called. + if self._current_state_events_membership_up_to_date: + # Note, rejected events will have a null membership field, so + # we we manually filter them out. + sql = """ + SELECT count(*), membership FROM current_state_events + WHERE type = 'm.room.member' AND room_id = ? + AND membership IS NOT NULL + GROUP BY membership + """ + else: + sql = """ + SELECT count(*), m.membership FROM room_memberships as m + INNER JOIN current_state_events as c + ON m.event_id = c.event_id + AND m.room_id = c.room_id + AND m.user_id = c.state_key + WHERE c.type = 'm.room.member' AND c.room_id = ? + GROUP BY m.membership + """ + + txn.execute(sql, (room_id,)) + res = {} + for count, membership in txn: + summary = res.setdefault(to_ascii(membership), MemberSummary([], count)) + + # we order by membership and then fairly arbitrarily by event_id so + # heroes are consistent + if self._current_state_events_membership_up_to_date: + # Note, rejected events will have a null membership field, so + # we we manually filter them out. + sql = """ + SELECT state_key, membership, event_id + FROM current_state_events + WHERE type = 'm.room.member' AND room_id = ? + AND membership IS NOT NULL + ORDER BY + CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC, + event_id ASC + LIMIT ? + """ + else: + sql = """ + SELECT c.state_key, m.membership, c.event_id + FROM room_memberships as m + INNER JOIN current_state_events as c USING (room_id, event_id) + WHERE c.type = 'm.room.member' AND c.room_id = ? + ORDER BY + CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC, + c.event_id ASC + LIMIT ? + """ + + # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user. + txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6)) + for user_id, membership, event_id in txn: + summary = res[to_ascii(membership)] + # we will always have a summary for this membership type at this + # point given the summary currently contains the counts. + members = summary.members + members.append((to_ascii(user_id), to_ascii(event_id))) + + return res + + return self.runInteraction("get_room_summary", _get_room_summary_txn) + + def _get_user_counts_in_room_txn(self, txn, room_id): + """ + Get the user count in a room by membership. + + Args: + room_id (str) + membership (Membership) + + Returns: + Deferred[int] + """ + sql = """ + SELECT m.membership, count(*) FROM room_memberships as m + INNER JOIN current_state_events as c USING(event_id) + WHERE c.type = 'm.room.member' AND c.room_id = ? + GROUP BY m.membership + """ + + txn.execute(sql, (room_id,)) + return {row[0]: row[1] for row in txn} + + @cached() + def get_invited_rooms_for_user(self, user_id): + """ Get all the rooms the user is invited to + Args: + user_id (str): The user ID. + Returns: + A deferred list of RoomsForUser. + """ + + return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE]) + + @defer.inlineCallbacks + def get_invite_for_user_in_room(self, user_id, room_id): + """Gets the invite for the given user and room + + Args: + user_id (str) + room_id (str) + + Returns: + Deferred: Resolves to either a RoomsForUser or None if no invite was + found. + """ + invites = yield self.get_invited_rooms_for_user(user_id) + for invite in invites: + if invite.room_id == room_id: + return invite + return None + + @defer.inlineCallbacks + def get_rooms_for_user_where_membership_is(self, user_id, membership_list): + """ Get all the rooms for this user where the membership for this user + matches one in the membership list. + + Filters out forgotten rooms. + + Args: + user_id (str): The user ID. + membership_list (list): A list of synapse.api.constants.Membership + values which the user must be in. + + Returns: + Deferred[list[RoomsForUser]] + """ + if not membership_list: + return defer.succeed(None) + + rooms = yield self.runInteraction( + "get_rooms_for_user_where_membership_is", + self._get_rooms_for_user_where_membership_is_txn, + user_id, + membership_list, + ) + + # Now we filter out forgotten rooms + forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id) + return [room for room in rooms if room.room_id not in forgotten_rooms] + + def _get_rooms_for_user_where_membership_is_txn( + self, txn, user_id, membership_list + ): + + do_invite = Membership.INVITE in membership_list + membership_list = [m for m in membership_list if m != Membership.INVITE] + + results = [] + if membership_list: + if self._current_state_events_membership_up_to_date: + clause, args = make_in_list_sql_clause( + self.database_engine, "c.membership", membership_list + ) + sql = """ + SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering + FROM current_state_events AS c + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND state_key = ? + AND %s + """ % ( + clause, + ) + else: + clause, args = make_in_list_sql_clause( + self.database_engine, "m.membership", membership_list + ) + sql = """ + SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering + FROM current_state_events AS c + INNER JOIN room_memberships AS m USING (room_id, event_id) + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND state_key = ? + AND %s + """ % ( + clause, + ) + + txn.execute(sql, (user_id, *args)) + results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)] + + if do_invite: + sql = ( + "SELECT i.room_id, inviter, i.event_id, e.stream_ordering" + " FROM local_invites as i" + " INNER JOIN events as e USING (event_id)" + " WHERE invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + txn.execute(sql, (user_id,)) + results.extend( + RoomsForUser( + room_id=r["room_id"], + sender=r["inviter"], + event_id=r["event_id"], + stream_ordering=r["stream_ordering"], + membership=Membership.INVITE, + ) + for r in self.cursor_to_dict(txn) + ) + + return results + + @cachedInlineCallbacks(max_entries=500000, iterable=True) + def get_rooms_for_user_with_stream_ordering(self, user_id): + """Returns a set of room_ids the user is currently joined to + + Args: + user_id (str) + + Returns: + Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns + the rooms the user is in currently, along with the stream ordering + of the most recent join for that user and room. + """ + rooms = yield self.get_rooms_for_user_where_membership_is( + user_id, membership_list=[Membership.JOIN] + ) + return frozenset( + GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering) + for r in rooms + ) + + @defer.inlineCallbacks + def get_rooms_for_user(self, user_id, on_invalidate=None): + """Returns a set of room_ids the user is currently joined to + """ + rooms = yield self.get_rooms_for_user_with_stream_ordering( + user_id, on_invalidate=on_invalidate + ) + return frozenset(r.room_id for r in rooms) + + @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) + def get_users_who_share_room_with_user(self, user_id, cache_context): + """Returns the set of users who share a room with `user_id` + """ + room_ids = yield self.get_rooms_for_user( + user_id, on_invalidate=cache_context.invalidate + ) + + user_who_share_room = set() + for room_id in room_ids: + user_ids = yield self.get_users_in_room( + room_id, on_invalidate=cache_context.invalidate + ) + user_who_share_room.update(user_ids) + + return user_who_share_room + + @defer.inlineCallbacks + def get_joined_users_from_context(self, event, context): + state_group = context.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + current_state_ids = yield context.get_current_state_ids(self) + result = yield self._get_joined_users_from_context( + event.room_id, state_group, current_state_ids, event=event, context=context + ) + return result + + @defer.inlineCallbacks + def get_joined_users_from_state(self, room_id, state_entry): + state_group = state_entry.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + with Measure(self._clock, "get_joined_users_from_state"): + return ( + yield self._get_joined_users_from_context( + room_id, state_group, state_entry.state, context=state_entry + ) + ) + + @cachedInlineCallbacks( + num_args=2, cache_context=True, iterable=True, max_entries=100000 + ) + def _get_joined_users_from_context( + self, + room_id, + state_group, + current_state_ids, + cache_context, + event=None, + context=None, + ): + # We don't use `state_group`, it's there so that we can cache based + # on it. However, it's important that it's never None, since two current_states + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + users_in_room = {} + member_event_ids = [ + e_id + for key, e_id in iteritems(current_state_ids) + if key[0] == EventTypes.Member + ] + + if context is not None: + # If we have a context with a delta from a previous state group, + # check if we also have the result from the previous group in cache. + # If we do then we can reuse that result and simply update it with + # any membership changes in `delta_ids` + if context.prev_group and context.delta_ids: + prev_res = self._get_joined_users_from_context.cache.get( + (room_id, context.prev_group), None + ) + if prev_res and isinstance(prev_res, dict): + users_in_room = dict(prev_res) + member_event_ids = [ + e_id + for key, e_id in iteritems(context.delta_ids) + if key[0] == EventTypes.Member + ] + for etype, state_key in context.delta_ids: + users_in_room.pop(state_key, None) + + # We check if we have any of the member event ids in the event cache + # before we ask the DB + + # We don't update the event cache hit ratio as it completely throws off + # the hit ratio counts. After all, we don't populate the cache if we + # miss it here + event_map = self._get_events_from_cache( + member_event_ids, allow_rejected=False, update_metrics=False + ) + + missing_member_event_ids = [] + for event_id in member_event_ids: + ev_entry = event_map.get(event_id) + if ev_entry: + if ev_entry.event.membership == Membership.JOIN: + users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo( + display_name=to_ascii( + ev_entry.event.content.get("displayname", None) + ), + avatar_url=to_ascii( + ev_entry.event.content.get("avatar_url", None) + ), + ) + else: + missing_member_event_ids.append(event_id) + + if missing_member_event_ids: + event_to_memberships = yield self._get_joined_profiles_from_event_ids( + missing_member_event_ids + ) + users_in_room.update((row for row in event_to_memberships.values() if row)) + + if event is not None and event.type == EventTypes.Member: + if event.membership == Membership.JOIN: + if event.event_id in member_event_ids: + users_in_room[to_ascii(event.state_key)] = ProfileInfo( + display_name=to_ascii(event.content.get("displayname", None)), + avatar_url=to_ascii(event.content.get("avatar_url", None)), + ) + + return users_in_room + + @cached(max_entries=10000) + def _get_joined_profile_from_event_id(self, event_id): + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_joined_profile_from_event_id", + list_name="event_ids", + inlineCallbacks=True, + ) + def _get_joined_profiles_from_event_ids(self, event_ids): + """For given set of member event_ids check if they point to a join + event and if so return the associated user and profile info. + + Args: + event_ids (Iterable[str]): The member event IDs to lookup + + Returns: + Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID + to `user_id` and ProfileInfo (or None if not join event). + """ + + rows = yield self._simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=event_ids, + retcols=("user_id", "display_name", "avatar_url", "event_id"), + keyvalues={"membership": Membership.JOIN}, + batch_size=500, + desc="_get_membership_from_event_ids", + ) + + return { + row["event_id"]: ( + row["user_id"], + ProfileInfo( + avatar_url=row["avatar_url"], display_name=row["display_name"] + ), + ) + for row in rows + } + + @cachedInlineCallbacks(max_entries=10000) + def is_host_joined(self, room_id, host): + if "%" in host or "_" in host: + raise Exception("Invalid host name") + + sql = """ + SELECT state_key FROM current_state_events AS c + INNER JOIN room_memberships AS m USING (event_id) + WHERE m.membership = 'join' + AND type = 'm.room.member' + AND c.room_id = ? + AND state_key LIKE ? + LIMIT 1 + """ + + # We do need to be careful to ensure that host doesn't have any wild cards + # in it, but we checked above for known ones and we'll check below that + # the returned user actually has the correct domain. + like_clause = "%:" + host + + rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause) + + if not rows: + return False + + user_id = rows[0][0] + if get_domain_from_id(user_id) != host: + # This can only happen if the host name has something funky in it + raise Exception("Invalid host name") + + return True + + @cachedInlineCallbacks() + def was_host_joined(self, room_id, host): + """Check whether the server is or ever was in the room. + + Args: + room_id (str) + host (str) + + Returns: + Deferred: Resolves to True if the host is/was in the room, otherwise + False. + """ + if "%" in host or "_" in host: + raise Exception("Invalid host name") + + sql = """ + SELECT user_id FROM room_memberships + WHERE room_id = ? + AND user_id LIKE ? + AND membership = 'join' + LIMIT 1 + """ + + # We do need to be careful to ensure that host doesn't have any wild cards + # in it, but we checked above for known ones and we'll check below that + # the returned user actually has the correct domain. + like_clause = "%:" + host + + rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause) + + if not rows: + return False + + user_id = rows[0][0] + if get_domain_from_id(user_id) != host: + # This can only happen if the host name has something funky in it + raise Exception("Invalid host name") + + return True + + @defer.inlineCallbacks + def get_joined_hosts(self, room_id, state_entry): + state_group = state_entry.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + with Measure(self._clock, "get_joined_hosts"): + return ( + yield self._get_joined_hosts( + room_id, state_group, state_entry.state, state_entry=state_entry + ) + ) + + @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) + # @defer.inlineCallbacks + def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + cache = yield self._get_joined_hosts_cache(room_id) + joined_hosts = yield cache.get_destinations(state_entry) + + return joined_hosts + + @cached(max_entries=10000) + def _get_joined_hosts_cache(self, room_id): + return _JoinedHostsCache(self, room_id) + + @cachedInlineCallbacks(num_args=2) + def did_forget(self, user_id, room_id): + """Returns whether user_id has elected to discard history for room_id. + + Returns False if they have since re-joined.""" + + def f(txn): + sql = ( + "SELECT" + " COUNT(*)" + " FROM" + " room_memberships" + " WHERE" + " user_id = ?" + " AND" + " room_id = ?" + " AND" + " forgotten = 0" + ) + txn.execute(sql, (user_id, room_id)) + rows = txn.fetchall() + return rows[0][0] + + count = yield self.runInteraction("did_forget_membership", f) + return count == 0 + + @cached() + def get_forgotten_rooms_for_user(self, user_id): + """Gets all rooms the user has forgotten. + + Args: + user_id (str) + + Returns: + Deferred[set[str]] + """ + + def _get_forgotten_rooms_for_user_txn(txn): + # This is a slightly convoluted query that first looks up all rooms + # that the user has forgotten in the past, then rechecks that list + # to see if any have subsequently been updated. This is done so that + # we can use a partial index on `forgotten = 1` on the assumption + # that few users will actually forget many rooms. + # + # Note that a room is considered "forgotten" if *all* membership + # events for that user and room have the forgotten field set (as + # when a user forgets a room we update all rows for that user and + # room, not just the current one). + sql = """ + SELECT room_id, ( + SELECT count(*) FROM room_memberships + WHERE room_id = m.room_id AND user_id = m.user_id AND forgotten = 0 + ) AS count + FROM room_memberships AS m + WHERE user_id = ? AND forgotten = 1 + GROUP BY room_id, user_id; + """ + txn.execute(sql, (user_id,)) + return set(row[0] for row in txn if row[1] == 0) + + return self.runInteraction( + "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn + ) + + @defer.inlineCallbacks + def get_rooms_user_has_been_in(self, user_id): + """Get all rooms that the user has ever been in. + + Args: + user_id (str) + + Returns: + Deferred[set[str]]: Set of room IDs. + """ + + room_ids = yield self._simple_select_onecol( + table="room_memberships", + keyvalues={"membership": Membership.JOIN, "user_id": user_id}, + retcol="room_id", + desc="get_rooms_user_has_been_in", + ) + + return set(room_ids) + + +class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): + def __init__(self, db_conn, hs): + super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs) + self.register_background_update_handler( + _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile + ) + self.register_background_update_handler( + _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, + self._background_current_state_membership, + ) + self.register_background_index_update( + "room_membership_forgotten_idx", + index_name="room_memberships_user_room_forgotten", + table="room_memberships", + columns=["user_id", "room_id"], + where_clause="forgotten = 1", + ) + + @defer.inlineCallbacks + def _background_add_membership_profile(self, progress, batch_size): + target_min_stream_id = progress.get( + "target_min_stream_id_inclusive", self._min_stream_order_on_start + ) + max_stream_id = progress.get( + "max_stream_id_exclusive", self._stream_order_on_start + 1 + ) + + INSERT_CLUMP_SIZE = 1000 + + def add_membership_profile_txn(txn): + sql = """ + SELECT stream_ordering, event_id, events.room_id, event_json.json + FROM events + INNER JOIN event_json USING (event_id) + INNER JOIN room_memberships USING (event_id) + WHERE ? <= stream_ordering AND stream_ordering < ? + AND type = 'm.room.member' + ORDER BY stream_ordering DESC + LIMIT ? + """ + + txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) + + rows = self.cursor_to_dict(txn) + if not rows: + return 0 + + min_stream_id = rows[-1]["stream_ordering"] + + to_update = [] + for row in rows: + event_id = row["event_id"] + room_id = row["room_id"] + try: + event_json = json.loads(row["json"]) + content = event_json["content"] + except Exception: + continue + + display_name = content.get("displayname", None) + avatar_url = content.get("avatar_url", None) + + if display_name or avatar_url: + to_update.append((display_name, avatar_url, event_id, room_id)) + + to_update_sql = """ + UPDATE room_memberships SET display_name = ?, avatar_url = ? + WHERE event_id = ? AND room_id = ? + """ + for index in range(0, len(to_update), INSERT_CLUMP_SIZE): + clump = to_update[index : index + INSERT_CLUMP_SIZE] + txn.executemany(to_update_sql, clump) + + progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + } + + self._background_update_progress_txn( + txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress + ) + + return len(rows) + + result = yield self.runInteraction( + _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn + ) + + if not result: + yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME) + + return result + + @defer.inlineCallbacks + def _background_current_state_membership(self, progress, batch_size): + """Update the new membership column on current_state_events. + + This works by iterating over all rooms in alphebetical order. + """ + + def _background_current_state_membership_txn(txn, last_processed_room): + processed = 0 + while processed < batch_size: + txn.execute( + """ + SELECT MIN(room_id) FROM current_state_events WHERE room_id > ? + """, + (last_processed_room,), + ) + row = txn.fetchone() + if not row or not row[0]: + return processed, True + + (next_room,) = row + + sql = """ + UPDATE current_state_events + SET membership = ( + SELECT membership FROM room_memberships + WHERE event_id = current_state_events.event_id + ) + WHERE room_id = ? + """ + txn.execute(sql, (next_room,)) + processed += txn.rowcount + + last_processed_room = next_room + + self._background_update_progress_txn( + txn, + _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, + {"last_processed_room": last_processed_room}, + ) + + return processed, False + + # If we haven't got a last processed room then just use the empty + # string, which will compare before all room IDs correctly. + last_processed_room = progress.get("last_processed_room", "") + + row_count, finished = yield self.runInteraction( + "_background_current_state_membership_update", + _background_current_state_membership_txn, + last_processed_room, + ) + + if finished: + yield self._end_background_update(_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME) + + return row_count + + +class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): + def __init__(self, db_conn, hs): + super(RoomMemberStore, self).__init__(db_conn, hs) + + def _store_room_members_txn(self, txn, events, backfilled): + """Store a room member in the database. + """ + self._simple_insert_many_txn( + txn, + table="room_memberships", + values=[ + { + "event_id": event.event_id, + "user_id": event.state_key, + "sender": event.user_id, + "room_id": event.room_id, + "membership": event.membership, + "display_name": event.content.get("displayname", None), + "avatar_url": event.content.get("avatar_url", None), + } + for event in events + ], + ) + + for event in events: + txn.call_after( + self._membership_stream_cache.entity_has_changed, + event.state_key, + event.internal_metadata.stream_ordering, + ) + txn.call_after( + self.get_invited_rooms_for_user.invalidate, (event.state_key,) + ) + + # We update the local_invites table only if the event is "current", + # i.e., its something that has just happened. If the event is an + # outlier it is only current if its an "out of band membership", + # like a remote invite or a rejection of a remote invite. + is_new_state = not backfilled and ( + not event.internal_metadata.is_outlier() + or event.internal_metadata.is_out_of_band_membership() + ) + is_mine = self.hs.is_mine_id(event.state_key) + if is_new_state and is_mine: + if event.membership == Membership.INVITE: + self._simple_insert_txn( + txn, + table="local_invites", + values={ + "event_id": event.event_id, + "invitee": event.state_key, + "inviter": event.sender, + "room_id": event.room_id, + "stream_id": event.internal_metadata.stream_ordering, + }, + ) + else: + sql = ( + "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + txn.execute( + sql, + ( + event.internal_metadata.stream_ordering, + event.event_id, + event.room_id, + event.state_key, + ), + ) + + @defer.inlineCallbacks + def locally_reject_invite(self, user_id, room_id): + sql = ( + "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + def f(txn, stream_ordering): + txn.execute(sql, (stream_ordering, True, room_id, user_id)) + + with self._stream_id_gen.get_next() as stream_ordering: + yield self.runInteraction("locally_reject_invite", f, stream_ordering) + + def forget(self, user_id, room_id): + """Indicate that user_id wishes to discard history for room_id.""" + + def f(txn): + sql = ( + "UPDATE" + " room_memberships" + " SET" + " forgotten = 1" + " WHERE" + " user_id = ?" + " AND" + " room_id = ?" + ) + txn.execute(sql, (user_id, room_id)) + + self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id)) + self._invalidate_cache_and_stream( + txn, self.get_forgotten_rooms_for_user, (user_id,) + ) + + return self.runInteraction("forget_membership", f) + + +class _JoinedHostsCache(object): + """Cache for joined hosts in a room that is optimised to handle updates + via state deltas. + """ + + def __init__(self, store, room_id): + self.store = store + self.room_id = room_id + + self.hosts_to_joined_users = {} + + self.state_group = object() + + self.linearizer = Linearizer("_JoinedHostsCache") + + self._len = 0 + + @defer.inlineCallbacks + def get_destinations(self, state_entry): + """Get set of destinations for a state entry + + Args: + state_entry(synapse.state._StateCacheEntry) + """ + if state_entry.state_group == self.state_group: + return frozenset(self.hosts_to_joined_users) + + with (yield self.linearizer.queue(())): + if state_entry.state_group == self.state_group: + pass + elif state_entry.prev_group == self.state_group: + for (typ, state_key), event_id in iteritems(state_entry.delta_ids): + if typ != EventTypes.Member: + continue + + host = intern_string(get_domain_from_id(state_key)) + user_id = state_key + known_joins = self.hosts_to_joined_users.setdefault(host, set()) + + event = yield self.store.get_event(event_id) + if event.membership == Membership.JOIN: + known_joins.add(user_id) + else: + known_joins.discard(user_id) + + if not known_joins: + self.hosts_to_joined_users.pop(host, None) + else: + joined_users = yield self.store.get_joined_users_from_state( + self.room_id, state_entry + ) + + self.hosts_to_joined_users = {} + for user_id in joined_users: + host = intern_string(get_domain_from_id(user_id)) + self.hosts_to_joined_users.setdefault(host, set()).add(user_id) + + if state_entry.state_group: + self.state_group = state_entry.state_group + else: + self.state_group = object() + self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users)) + return frozenset(self.hosts_to_joined_users) + + def __len__(self): + return self._len diff --git a/synapse/storage/schema/delta/12/v12.sql b/synapse/storage/data_stores/main/schema/delta/12/v12.sql index 5964c5aaac..5964c5aaac 100644 --- a/synapse/storage/schema/delta/12/v12.sql +++ b/synapse/storage/data_stores/main/schema/delta/12/v12.sql diff --git a/synapse/storage/schema/delta/13/v13.sql b/synapse/storage/data_stores/main/schema/delta/13/v13.sql index f8649e5d99..f8649e5d99 100644 --- a/synapse/storage/schema/delta/13/v13.sql +++ b/synapse/storage/data_stores/main/schema/delta/13/v13.sql diff --git a/synapse/storage/schema/delta/14/v14.sql b/synapse/storage/data_stores/main/schema/delta/14/v14.sql index a831920da6..a831920da6 100644 --- a/synapse/storage/schema/delta/14/v14.sql +++ b/synapse/storage/data_stores/main/schema/delta/14/v14.sql diff --git a/synapse/storage/schema/delta/15/appservice_txns.sql b/synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql index e4f5e76aec..e4f5e76aec 100644 --- a/synapse/storage/schema/delta/15/appservice_txns.sql +++ b/synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql diff --git a/synapse/storage/schema/delta/15/presence_indices.sql b/synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql index 6b8d0f1ca7..6b8d0f1ca7 100644 --- a/synapse/storage/schema/delta/15/presence_indices.sql +++ b/synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql diff --git a/synapse/storage/schema/delta/15/v15.sql b/synapse/storage/data_stores/main/schema/delta/15/v15.sql index 9523d2bcc3..9523d2bcc3 100644 --- a/synapse/storage/schema/delta/15/v15.sql +++ b/synapse/storage/data_stores/main/schema/delta/15/v15.sql diff --git a/synapse/storage/schema/delta/16/events_order_index.sql b/synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql index a48f215170..a48f215170 100644 --- a/synapse/storage/schema/delta/16/events_order_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql diff --git a/synapse/storage/schema/delta/16/remote_media_cache_index.sql b/synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql index 7a15265cb1..7a15265cb1 100644 --- a/synapse/storage/schema/delta/16/remote_media_cache_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql diff --git a/synapse/storage/schema/delta/16/remove_duplicates.sql b/synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql index 65c97b5e2f..65c97b5e2f 100644 --- a/synapse/storage/schema/delta/16/remove_duplicates.sql +++ b/synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql diff --git a/synapse/storage/schema/delta/16/room_alias_index.sql b/synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql index f82486132b..f82486132b 100644 --- a/synapse/storage/schema/delta/16/room_alias_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql diff --git a/synapse/storage/schema/delta/16/unique_constraints.sql b/synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql index 5b8de52c33..5b8de52c33 100644 --- a/synapse/storage/schema/delta/16/unique_constraints.sql +++ b/synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql diff --git a/synapse/storage/schema/delta/16/users.sql b/synapse/storage/data_stores/main/schema/delta/16/users.sql index cd0709250d..cd0709250d 100644 --- a/synapse/storage/schema/delta/16/users.sql +++ b/synapse/storage/data_stores/main/schema/delta/16/users.sql diff --git a/synapse/storage/schema/delta/17/drop_indexes.sql b/synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql index 7c9a90e27f..7c9a90e27f 100644 --- a/synapse/storage/schema/delta/17/drop_indexes.sql +++ b/synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql diff --git a/synapse/storage/schema/delta/17/server_keys.sql b/synapse/storage/data_stores/main/schema/delta/17/server_keys.sql index 70b247a06b..70b247a06b 100644 --- a/synapse/storage/schema/delta/17/server_keys.sql +++ b/synapse/storage/data_stores/main/schema/delta/17/server_keys.sql diff --git a/synapse/storage/schema/delta/17/user_threepids.sql b/synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql index c17715ac80..c17715ac80 100644 --- a/synapse/storage/schema/delta/17/user_threepids.sql +++ b/synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql diff --git a/synapse/storage/schema/delta/18/server_keys_bigger_ints.sql b/synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql index 6e0871c92b..6e0871c92b 100644 --- a/synapse/storage/schema/delta/18/server_keys_bigger_ints.sql +++ b/synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql diff --git a/synapse/storage/schema/delta/19/event_index.sql b/synapse/storage/data_stores/main/schema/delta/19/event_index.sql index 18b97b4332..18b97b4332 100644 --- a/synapse/storage/schema/delta/19/event_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/19/event_index.sql diff --git a/synapse/storage/schema/delta/20/dummy.sql b/synapse/storage/data_stores/main/schema/delta/20/dummy.sql index e0ac49d1ec..e0ac49d1ec 100644 --- a/synapse/storage/schema/delta/20/dummy.sql +++ b/synapse/storage/data_stores/main/schema/delta/20/dummy.sql diff --git a/synapse/storage/schema/delta/20/pushers.py b/synapse/storage/data_stores/main/schema/delta/20/pushers.py index 3edfcfd783..3edfcfd783 100644 --- a/synapse/storage/schema/delta/20/pushers.py +++ b/synapse/storage/data_stores/main/schema/delta/20/pushers.py diff --git a/synapse/storage/schema/delta/21/end_to_end_keys.sql b/synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql index 4c2fb20b77..4c2fb20b77 100644 --- a/synapse/storage/schema/delta/21/end_to_end_keys.sql +++ b/synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql diff --git a/synapse/storage/schema/delta/21/receipts.sql b/synapse/storage/data_stores/main/schema/delta/21/receipts.sql index d070845477..d070845477 100644 --- a/synapse/storage/schema/delta/21/receipts.sql +++ b/synapse/storage/data_stores/main/schema/delta/21/receipts.sql diff --git a/synapse/storage/schema/delta/22/receipts_index.sql b/synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql index bfc0b3bcaa..bfc0b3bcaa 100644 --- a/synapse/storage/schema/delta/22/receipts_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql diff --git a/synapse/storage/schema/delta/22/user_threepids_unique.sql b/synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql index 87edfa454c..87edfa454c 100644 --- a/synapse/storage/schema/delta/22/user_threepids_unique.sql +++ b/synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql diff --git a/synapse/storage/schema/delta/23/drop_state_index.sql b/synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql index ae09fa0065..ae09fa0065 100644 --- a/synapse/storage/schema/delta/23/drop_state_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql diff --git a/synapse/storage/schema/delta/24/stats_reporting.sql b/synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql index acea7483bd..acea7483bd 100644 --- a/synapse/storage/schema/delta/24/stats_reporting.sql +++ b/synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/data_stores/main/schema/delta/25/fts.py index 4b2ffd35fd..4b2ffd35fd 100644 --- a/synapse/storage/schema/delta/25/fts.py +++ b/synapse/storage/data_stores/main/schema/delta/25/fts.py diff --git a/synapse/storage/schema/delta/25/guest_access.sql b/synapse/storage/data_stores/main/schema/delta/25/guest_access.sql index 1ea389b471..1ea389b471 100644 --- a/synapse/storage/schema/delta/25/guest_access.sql +++ b/synapse/storage/data_stores/main/schema/delta/25/guest_access.sql diff --git a/synapse/storage/schema/delta/25/history_visibility.sql b/synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql index f468fc1897..f468fc1897 100644 --- a/synapse/storage/schema/delta/25/history_visibility.sql +++ b/synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql diff --git a/synapse/storage/schema/delta/25/tags.sql b/synapse/storage/data_stores/main/schema/delta/25/tags.sql index 7a32ce68e4..7a32ce68e4 100644 --- a/synapse/storage/schema/delta/25/tags.sql +++ b/synapse/storage/data_stores/main/schema/delta/25/tags.sql diff --git a/synapse/storage/schema/delta/26/account_data.sql b/synapse/storage/data_stores/main/schema/delta/26/account_data.sql index e395de2b5e..e395de2b5e 100644 --- a/synapse/storage/schema/delta/26/account_data.sql +++ b/synapse/storage/data_stores/main/schema/delta/26/account_data.sql diff --git a/synapse/storage/schema/delta/27/account_data.sql b/synapse/storage/data_stores/main/schema/delta/27/account_data.sql index bf0558b5b3..bf0558b5b3 100644 --- a/synapse/storage/schema/delta/27/account_data.sql +++ b/synapse/storage/data_stores/main/schema/delta/27/account_data.sql diff --git a/synapse/storage/schema/delta/27/forgotten_memberships.sql b/synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql index e2094f37fe..e2094f37fe 100644 --- a/synapse/storage/schema/delta/27/forgotten_memberships.sql +++ b/synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/data_stores/main/schema/delta/27/ts.py index 414f9f5aa0..414f9f5aa0 100644 --- a/synapse/storage/schema/delta/27/ts.py +++ b/synapse/storage/data_stores/main/schema/delta/27/ts.py diff --git a/synapse/storage/schema/delta/28/event_push_actions.sql b/synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql index 4d519849df..4d519849df 100644 --- a/synapse/storage/schema/delta/28/event_push_actions.sql +++ b/synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql diff --git a/synapse/storage/schema/delta/28/events_room_stream.sql b/synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql index 36609475f1..36609475f1 100644 --- a/synapse/storage/schema/delta/28/events_room_stream.sql +++ b/synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql diff --git a/synapse/storage/schema/delta/28/public_roms_index.sql b/synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql index 6c1fd68c5b..6c1fd68c5b 100644 --- a/synapse/storage/schema/delta/28/public_roms_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql diff --git a/synapse/storage/schema/delta/28/receipts_user_id_index.sql b/synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql index cb84c69baa..cb84c69baa 100644 --- a/synapse/storage/schema/delta/28/receipts_user_id_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql diff --git a/synapse/storage/schema/delta/28/upgrade_times.sql b/synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql index 3e4a9ab455..3e4a9ab455 100644 --- a/synapse/storage/schema/delta/28/upgrade_times.sql +++ b/synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql diff --git a/synapse/storage/schema/delta/28/users_is_guest.sql b/synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql index 21d2b420bf..21d2b420bf 100644 --- a/synapse/storage/schema/delta/28/users_is_guest.sql +++ b/synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql diff --git a/synapse/storage/schema/delta/29/push_actions.sql b/synapse/storage/data_stores/main/schema/delta/29/push_actions.sql index 84b21cf813..84b21cf813 100644 --- a/synapse/storage/schema/delta/29/push_actions.sql +++ b/synapse/storage/data_stores/main/schema/delta/29/push_actions.sql diff --git a/synapse/storage/schema/delta/30/alias_creator.sql b/synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql index c9d0dde638..c9d0dde638 100644 --- a/synapse/storage/schema/delta/30/alias_creator.sql +++ b/synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/data_stores/main/schema/delta/30/as_users.py index 9b95411fb6..9b95411fb6 100644 --- a/synapse/storage/schema/delta/30/as_users.py +++ b/synapse/storage/data_stores/main/schema/delta/30/as_users.py diff --git a/synapse/storage/schema/delta/30/deleted_pushers.sql b/synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql index 712c454aa1..712c454aa1 100644 --- a/synapse/storage/schema/delta/30/deleted_pushers.sql +++ b/synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql diff --git a/synapse/storage/schema/delta/30/presence_stream.sql b/synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql index 606bbb037d..606bbb037d 100644 --- a/synapse/storage/schema/delta/30/presence_stream.sql +++ b/synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql diff --git a/synapse/storage/schema/delta/30/public_rooms.sql b/synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql index f09db4faa6..f09db4faa6 100644 --- a/synapse/storage/schema/delta/30/public_rooms.sql +++ b/synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql diff --git a/synapse/storage/schema/delta/30/push_rule_stream.sql b/synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql index 735aa8d5f6..735aa8d5f6 100644 --- a/synapse/storage/schema/delta/30/push_rule_stream.sql +++ b/synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql diff --git a/synapse/storage/schema/delta/30/state_stream.sql b/synapse/storage/data_stores/main/schema/delta/30/state_stream.sql index e85699e82e..e85699e82e 100644 --- a/synapse/storage/schema/delta/30/state_stream.sql +++ b/synapse/storage/data_stores/main/schema/delta/30/state_stream.sql diff --git a/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql b/synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql index 0dd2f1360c..0dd2f1360c 100644 --- a/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql +++ b/synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql diff --git a/synapse/storage/schema/delta/31/invites.sql b/synapse/storage/data_stores/main/schema/delta/31/invites.sql index 2c57846d5a..2c57846d5a 100644 --- a/synapse/storage/schema/delta/31/invites.sql +++ b/synapse/storage/data_stores/main/schema/delta/31/invites.sql diff --git a/synapse/storage/schema/delta/31/local_media_repository_url_cache.sql b/synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql index 9efb4280eb..9efb4280eb 100644 --- a/synapse/storage/schema/delta/31/local_media_repository_url_cache.sql +++ b/synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql diff --git a/synapse/storage/schema/delta/31/pushers.py b/synapse/storage/data_stores/main/schema/delta/31/pushers.py index 9bb504aad5..9bb504aad5 100644 --- a/synapse/storage/schema/delta/31/pushers.py +++ b/synapse/storage/data_stores/main/schema/delta/31/pushers.py diff --git a/synapse/storage/schema/delta/31/pushers_index.sql b/synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql index a82add88fd..a82add88fd 100644 --- a/synapse/storage/schema/delta/31/pushers_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql diff --git a/synapse/storage/schema/delta/31/search_update.py b/synapse/storage/data_stores/main/schema/delta/31/search_update.py index 7d8ca5f93f..7d8ca5f93f 100644 --- a/synapse/storage/schema/delta/31/search_update.py +++ b/synapse/storage/data_stores/main/schema/delta/31/search_update.py diff --git a/synapse/storage/schema/delta/32/events.sql b/synapse/storage/data_stores/main/schema/delta/32/events.sql index 1dd0f9e170..1dd0f9e170 100644 --- a/synapse/storage/schema/delta/32/events.sql +++ b/synapse/storage/data_stores/main/schema/delta/32/events.sql diff --git a/synapse/storage/schema/delta/32/openid.sql b/synapse/storage/data_stores/main/schema/delta/32/openid.sql index 36f37b11c8..36f37b11c8 100644 --- a/synapse/storage/schema/delta/32/openid.sql +++ b/synapse/storage/data_stores/main/schema/delta/32/openid.sql diff --git a/synapse/storage/schema/delta/32/pusher_throttle.sql b/synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql index d86d30c13c..d86d30c13c 100644 --- a/synapse/storage/schema/delta/32/pusher_throttle.sql +++ b/synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql diff --git a/synapse/storage/schema/delta/32/remove_indices.sql b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql index 4219cdd06a..4219cdd06a 100644 --- a/synapse/storage/schema/delta/32/remove_indices.sql +++ b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql diff --git a/synapse/storage/schema/delta/32/reports.sql b/synapse/storage/data_stores/main/schema/delta/32/reports.sql index d13609776f..d13609776f 100644 --- a/synapse/storage/schema/delta/32/reports.sql +++ b/synapse/storage/data_stores/main/schema/delta/32/reports.sql diff --git a/synapse/storage/schema/delta/33/access_tokens_device_index.sql b/synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql index 61ad3fe3e8..61ad3fe3e8 100644 --- a/synapse/storage/schema/delta/33/access_tokens_device_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql diff --git a/synapse/storage/schema/delta/33/devices.sql b/synapse/storage/data_stores/main/schema/delta/33/devices.sql index eca7268d82..eca7268d82 100644 --- a/synapse/storage/schema/delta/33/devices.sql +++ b/synapse/storage/data_stores/main/schema/delta/33/devices.sql diff --git a/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql index aa4a3b9f2f..aa4a3b9f2f 100644 --- a/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql +++ b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql diff --git a/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql index 6671573398..6671573398 100644 --- a/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql +++ b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql diff --git a/synapse/storage/schema/delta/33/event_fields.py b/synapse/storage/data_stores/main/schema/delta/33/event_fields.py index bff1256a7b..bff1256a7b 100644 --- a/synapse/storage/schema/delta/33/event_fields.py +++ b/synapse/storage/data_stores/main/schema/delta/33/event_fields.py diff --git a/synapse/storage/schema/delta/33/remote_media_ts.py b/synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py index a26057dfb6..a26057dfb6 100644 --- a/synapse/storage/schema/delta/33/remote_media_ts.py +++ b/synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py diff --git a/synapse/storage/schema/delta/33/user_ips_index.sql b/synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql index 473f75a78e..473f75a78e 100644 --- a/synapse/storage/schema/delta/33/user_ips_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql diff --git a/synapse/storage/schema/delta/34/appservice_stream.sql b/synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql index 69e16eda0f..69e16eda0f 100644 --- a/synapse/storage/schema/delta/34/appservice_stream.sql +++ b/synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql diff --git a/synapse/storage/schema/delta/34/cache_stream.py b/synapse/storage/data_stores/main/schema/delta/34/cache_stream.py index cf09e43e2b..cf09e43e2b 100644 --- a/synapse/storage/schema/delta/34/cache_stream.py +++ b/synapse/storage/data_stores/main/schema/delta/34/cache_stream.py diff --git a/synapse/storage/schema/delta/34/device_inbox.sql b/synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql index e68844c74a..e68844c74a 100644 --- a/synapse/storage/schema/delta/34/device_inbox.sql +++ b/synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql diff --git a/synapse/storage/schema/delta/34/push_display_name_rename.sql b/synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql index 0d9fe1a99a..0d9fe1a99a 100644 --- a/synapse/storage/schema/delta/34/push_display_name_rename.sql +++ b/synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql diff --git a/synapse/storage/schema/delta/34/received_txn_purge.py b/synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py index 67d505e68b..67d505e68b 100644 --- a/synapse/storage/schema/delta/34/received_txn_purge.py +++ b/synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py diff --git a/synapse/storage/schema/delta/35/add_state_index.sql b/synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql index 0fce26345b..33980d02f0 100644 --- a/synapse/storage/schema/delta/35/add_state_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql @@ -13,8 +13,5 @@ * limitations under the License. */ - -ALTER TABLE background_updates ADD COLUMN depends_on TEXT; - INSERT into background_updates (update_name, progress_json, depends_on) VALUES ('state_group_state_type_index', '{}', 'state_group_state_deduplication'); diff --git a/synapse/storage/schema/delta/35/contains_url.sql b/synapse/storage/data_stores/main/schema/delta/35/contains_url.sql index 6cd123027b..6cd123027b 100644 --- a/synapse/storage/schema/delta/35/contains_url.sql +++ b/synapse/storage/data_stores/main/schema/delta/35/contains_url.sql diff --git a/synapse/storage/schema/delta/35/device_outbox.sql b/synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql index 17e6c43105..17e6c43105 100644 --- a/synapse/storage/schema/delta/35/device_outbox.sql +++ b/synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql diff --git a/synapse/storage/schema/delta/35/device_stream_id.sql b/synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql index 7ab7d942e2..7ab7d942e2 100644 --- a/synapse/storage/schema/delta/35/device_stream_id.sql +++ b/synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql diff --git a/synapse/storage/schema/delta/35/event_push_actions_index.sql b/synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql index 2e836d8e9c..2e836d8e9c 100644 --- a/synapse/storage/schema/delta/35/event_push_actions_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql diff --git a/synapse/storage/schema/delta/35/public_room_list_change_stream.sql b/synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql index dd2bf2e28a..dd2bf2e28a 100644 --- a/synapse/storage/schema/delta/35/public_room_list_change_stream.sql +++ b/synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql diff --git a/synapse/storage/schema/delta/35/state.sql b/synapse/storage/data_stores/main/schema/delta/35/state.sql index 0f1fa68a89..0f1fa68a89 100644 --- a/synapse/storage/schema/delta/35/state.sql +++ b/synapse/storage/data_stores/main/schema/delta/35/state.sql diff --git a/synapse/storage/schema/delta/35/state_dedupe.sql b/synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql index 97e5067ef4..97e5067ef4 100644 --- a/synapse/storage/schema/delta/35/state_dedupe.sql +++ b/synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql diff --git a/synapse/storage/schema/delta/35/stream_order_to_extrem.sql b/synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql index 2b945d8a57..2b945d8a57 100644 --- a/synapse/storage/schema/delta/35/stream_order_to_extrem.sql +++ b/synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql diff --git a/synapse/storage/schema/delta/36/readd_public_rooms.sql b/synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql index 90d8fd18f9..90d8fd18f9 100644 --- a/synapse/storage/schema/delta/36/readd_public_rooms.sql +++ b/synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql diff --git a/synapse/storage/schema/delta/37/remove_auth_idx.py b/synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py index a377884169..a377884169 100644 --- a/synapse/storage/schema/delta/37/remove_auth_idx.py +++ b/synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py diff --git a/synapse/storage/schema/delta/37/user_threepids.sql b/synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql index cf7a90dd10..cf7a90dd10 100644 --- a/synapse/storage/schema/delta/37/user_threepids.sql +++ b/synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql diff --git a/synapse/storage/schema/delta/38/postgres_fts_gist.sql b/synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql index 515e6b8e84..515e6b8e84 100644 --- a/synapse/storage/schema/delta/38/postgres_fts_gist.sql +++ b/synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql diff --git a/synapse/storage/schema/delta/39/appservice_room_list.sql b/synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql index 74bdc49073..74bdc49073 100644 --- a/synapse/storage/schema/delta/39/appservice_room_list.sql +++ b/synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql diff --git a/synapse/storage/schema/delta/39/device_federation_stream_idx.sql b/synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql index 00be801e90..00be801e90 100644 --- a/synapse/storage/schema/delta/39/device_federation_stream_idx.sql +++ b/synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql diff --git a/synapse/storage/schema/delta/39/event_push_index.sql b/synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql index de2ad93e5c..de2ad93e5c 100644 --- a/synapse/storage/schema/delta/39/event_push_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql diff --git a/synapse/storage/schema/delta/39/federation_out_position.sql b/synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql index 5af814290b..5af814290b 100644 --- a/synapse/storage/schema/delta/39/federation_out_position.sql +++ b/synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql diff --git a/synapse/storage/schema/delta/39/membership_profile.sql b/synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql index 1bf911c8ab..1bf911c8ab 100644 --- a/synapse/storage/schema/delta/39/membership_profile.sql +++ b/synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql diff --git a/synapse/storage/schema/delta/40/current_state_idx.sql b/synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql index 7ffa189f39..7ffa189f39 100644 --- a/synapse/storage/schema/delta/40/current_state_idx.sql +++ b/synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql diff --git a/synapse/storage/schema/delta/40/device_inbox.sql b/synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql index b9fe1f0480..b9fe1f0480 100644 --- a/synapse/storage/schema/delta/40/device_inbox.sql +++ b/synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql index dd6dcb65f1..dd6dcb65f1 100644 --- a/synapse/storage/schema/delta/40/device_list_streams.sql +++ b/synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql diff --git a/synapse/storage/schema/delta/40/event_push_summary.sql b/synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql index 3918f0b794..3918f0b794 100644 --- a/synapse/storage/schema/delta/40/event_push_summary.sql +++ b/synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql diff --git a/synapse/storage/schema/delta/40/pushers.sql b/synapse/storage/data_stores/main/schema/delta/40/pushers.sql index 054a223f14..054a223f14 100644 --- a/synapse/storage/schema/delta/40/pushers.sql +++ b/synapse/storage/data_stores/main/schema/delta/40/pushers.sql diff --git a/synapse/storage/schema/delta/41/device_list_stream_idx.sql b/synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql index b7bee8b692..b7bee8b692 100644 --- a/synapse/storage/schema/delta/41/device_list_stream_idx.sql +++ b/synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql diff --git a/synapse/storage/schema/delta/41/device_outbound_index.sql b/synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql index 62f0b9892b..62f0b9892b 100644 --- a/synapse/storage/schema/delta/41/device_outbound_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql diff --git a/synapse/storage/schema/delta/41/event_search_event_id_idx.sql b/synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql index 5d9cfecf36..5d9cfecf36 100644 --- a/synapse/storage/schema/delta/41/event_search_event_id_idx.sql +++ b/synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql diff --git a/synapse/storage/schema/delta/41/ratelimit.sql b/synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql index a194bf0238..a194bf0238 100644 --- a/synapse/storage/schema/delta/41/ratelimit.sql +++ b/synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql diff --git a/synapse/storage/schema/delta/42/current_state_delta.sql b/synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql index d28851aff8..d28851aff8 100644 --- a/synapse/storage/schema/delta/42/current_state_delta.sql +++ b/synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql diff --git a/synapse/storage/schema/delta/42/device_list_last_id.sql b/synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql index 9ab8c14fa3..9ab8c14fa3 100644 --- a/synapse/storage/schema/delta/42/device_list_last_id.sql +++ b/synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql diff --git a/synapse/storage/schema/delta/42/event_auth_state_only.sql b/synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql index b8821ac759..b8821ac759 100644 --- a/synapse/storage/schema/delta/42/event_auth_state_only.sql +++ b/synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql diff --git a/synapse/storage/schema/delta/42/user_dir.py b/synapse/storage/data_stores/main/schema/delta/42/user_dir.py index 506f326f4d..506f326f4d 100644 --- a/synapse/storage/schema/delta/42/user_dir.py +++ b/synapse/storage/data_stores/main/schema/delta/42/user_dir.py diff --git a/synapse/storage/schema/delta/43/blocked_rooms.sql b/synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql index 0e3cd143ff..0e3cd143ff 100644 --- a/synapse/storage/schema/delta/43/blocked_rooms.sql +++ b/synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql diff --git a/synapse/storage/schema/delta/43/quarantine_media.sql b/synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql index 630907ec4f..630907ec4f 100644 --- a/synapse/storage/schema/delta/43/quarantine_media.sql +++ b/synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql diff --git a/synapse/storage/schema/delta/43/url_cache.sql b/synapse/storage/data_stores/main/schema/delta/43/url_cache.sql index 45ebe020da..45ebe020da 100644 --- a/synapse/storage/schema/delta/43/url_cache.sql +++ b/synapse/storage/data_stores/main/schema/delta/43/url_cache.sql diff --git a/synapse/storage/schema/delta/43/user_share.sql b/synapse/storage/data_stores/main/schema/delta/43/user_share.sql index ee7062abe4..ee7062abe4 100644 --- a/synapse/storage/schema/delta/43/user_share.sql +++ b/synapse/storage/data_stores/main/schema/delta/43/user_share.sql diff --git a/synapse/storage/schema/delta/44/expire_url_cache.sql b/synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql index b12f9b2ebf..b12f9b2ebf 100644 --- a/synapse/storage/schema/delta/44/expire_url_cache.sql +++ b/synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql diff --git a/synapse/storage/schema/delta/45/group_server.sql b/synapse/storage/data_stores/main/schema/delta/45/group_server.sql index b2333848a0..b2333848a0 100644 --- a/synapse/storage/schema/delta/45/group_server.sql +++ b/synapse/storage/data_stores/main/schema/delta/45/group_server.sql diff --git a/synapse/storage/schema/delta/45/profile_cache.sql b/synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql index e5ddc84df0..e5ddc84df0 100644 --- a/synapse/storage/schema/delta/45/profile_cache.sql +++ b/synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql diff --git a/synapse/storage/schema/delta/46/drop_refresh_tokens.sql b/synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql index 68c48a89a9..68c48a89a9 100644 --- a/synapse/storage/schema/delta/46/drop_refresh_tokens.sql +++ b/synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql diff --git a/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql b/synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql index bb307889c1..bb307889c1 100644 --- a/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql +++ b/synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql diff --git a/synapse/storage/schema/delta/46/group_server.sql b/synapse/storage/data_stores/main/schema/delta/46/group_server.sql index 097679bc9a..097679bc9a 100644 --- a/synapse/storage/schema/delta/46/group_server.sql +++ b/synapse/storage/data_stores/main/schema/delta/46/group_server.sql diff --git a/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql b/synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql index bbfc7f5d1a..bbfc7f5d1a 100644 --- a/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql +++ b/synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql diff --git a/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql index cb0d5a2576..cb0d5a2576 100644 --- a/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql +++ b/synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql diff --git a/synapse/storage/schema/delta/46/user_dir_typos.sql b/synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql index d9505f8da1..d9505f8da1 100644 --- a/synapse/storage/schema/delta/46/user_dir_typos.sql +++ b/synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql diff --git a/synapse/storage/schema/delta/47/last_access_media.sql b/synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql index f505fb22b5..f505fb22b5 100644 --- a/synapse/storage/schema/delta/47/last_access_media.sql +++ b/synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql diff --git a/synapse/storage/schema/delta/47/postgres_fts_gin.sql b/synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql index 31d7a817eb..31d7a817eb 100644 --- a/synapse/storage/schema/delta/47/postgres_fts_gin.sql +++ b/synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql diff --git a/synapse/storage/schema/delta/47/push_actions_staging.sql b/synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql index edccf4a96f..edccf4a96f 100644 --- a/synapse/storage/schema/delta/47/push_actions_staging.sql +++ b/synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql diff --git a/synapse/storage/schema/delta/47/state_group_seq.py b/synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py index 9fd1ccf6f7..9fd1ccf6f7 100644 --- a/synapse/storage/schema/delta/47/state_group_seq.py +++ b/synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py diff --git a/synapse/storage/schema/delta/48/add_user_consent.sql b/synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql index 5237491506..5237491506 100644 --- a/synapse/storage/schema/delta/48/add_user_consent.sql +++ b/synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql diff --git a/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql b/synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql index 9248b0b24a..9248b0b24a 100644 --- a/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql diff --git a/synapse/storage/schema/delta/48/deactivated_users.sql b/synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql index e9013a6969..e9013a6969 100644 --- a/synapse/storage/schema/delta/48/deactivated_users.sql +++ b/synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql diff --git a/synapse/storage/schema/delta/48/group_unique_indexes.py b/synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py index 49f5f2c003..49f5f2c003 100644 --- a/synapse/storage/schema/delta/48/group_unique_indexes.py +++ b/synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py diff --git a/synapse/storage/schema/delta/48/groups_joinable.sql b/synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql index ce26eaf0c9..ce26eaf0c9 100644 --- a/synapse/storage/schema/delta/48/groups_joinable.sql +++ b/synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql diff --git a/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql b/synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql index 14dcf18d73..14dcf18d73 100644 --- a/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql +++ b/synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql diff --git a/synapse/storage/schema/delta/49/add_user_daily_visits.sql b/synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql index 3dd478196f..3dd478196f 100644 --- a/synapse/storage/schema/delta/49/add_user_daily_visits.sql +++ b/synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql diff --git a/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql b/synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql index 3a4ed59b5b..3a4ed59b5b 100644 --- a/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql diff --git a/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql b/synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql index c93ae47532..c93ae47532 100644 --- a/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql diff --git a/synapse/storage/schema/delta/50/erasure_store.sql b/synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql index 5d8641a9ab..5d8641a9ab 100644 --- a/synapse/storage/schema/delta/50/erasure_store.sql +++ b/synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql diff --git a/synapse/storage/schema/delta/50/make_event_content_nullable.py b/synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py index b1684a8441..b1684a8441 100644 --- a/synapse/storage/schema/delta/50/make_event_content_nullable.py +++ b/synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py diff --git a/synapse/storage/schema/delta/51/e2e_room_keys.sql b/synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql index c0e66a697d..c0e66a697d 100644 --- a/synapse/storage/schema/delta/51/e2e_room_keys.sql +++ b/synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql diff --git a/synapse/storage/schema/delta/51/monthly_active_users.sql b/synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql index c9d537d5a3..c9d537d5a3 100644 --- a/synapse/storage/schema/delta/51/monthly_active_users.sql +++ b/synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql diff --git a/synapse/storage/schema/delta/52/add_event_to_state_group_index.sql b/synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql index 91e03d13e1..91e03d13e1 100644 --- a/synapse/storage/schema/delta/52/add_event_to_state_group_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql diff --git a/synapse/storage/schema/delta/52/device_list_streams_unique_idx.sql b/synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql index bfa49e6f92..bfa49e6f92 100644 --- a/synapse/storage/schema/delta/52/device_list_streams_unique_idx.sql +++ b/synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql diff --git a/synapse/storage/schema/delta/52/e2e_room_keys.sql b/synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql index db687cccae..db687cccae 100644 --- a/synapse/storage/schema/delta/52/e2e_room_keys.sql +++ b/synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql diff --git a/synapse/storage/schema/delta/53/add_user_type_to_users.sql b/synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql index 88ec2f83e5..88ec2f83e5 100644 --- a/synapse/storage/schema/delta/53/add_user_type_to_users.sql +++ b/synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql diff --git a/synapse/storage/schema/delta/53/drop_sent_transactions.sql b/synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql index e372f5a44a..e372f5a44a 100644 --- a/synapse/storage/schema/delta/53/drop_sent_transactions.sql +++ b/synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql diff --git a/synapse/storage/schema/delta/53/event_format_version.sql b/synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql index 1d977c2834..1d977c2834 100644 --- a/synapse/storage/schema/delta/53/event_format_version.sql +++ b/synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql diff --git a/synapse/storage/schema/delta/53/user_dir_populate.sql b/synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql index ffcc896b58..ffcc896b58 100644 --- a/synapse/storage/schema/delta/53/user_dir_populate.sql +++ b/synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql diff --git a/synapse/storage/schema/delta/53/user_ips_index.sql b/synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql index b812c5794f..b812c5794f 100644 --- a/synapse/storage/schema/delta/53/user_ips_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql diff --git a/synapse/storage/schema/delta/53/user_share.sql b/synapse/storage/data_stores/main/schema/delta/53/user_share.sql index 5831b1a6f8..5831b1a6f8 100644 --- a/synapse/storage/schema/delta/53/user_share.sql +++ b/synapse/storage/data_stores/main/schema/delta/53/user_share.sql diff --git a/synapse/storage/schema/delta/53/user_threepid_id.sql b/synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql index 80c2c573b6..80c2c573b6 100644 --- a/synapse/storage/schema/delta/53/user_threepid_id.sql +++ b/synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql diff --git a/synapse/storage/schema/delta/53/users_in_public_rooms.sql b/synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql index f7827ca6d2..f7827ca6d2 100644 --- a/synapse/storage/schema/delta/53/users_in_public_rooms.sql +++ b/synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql diff --git a/synapse/storage/schema/delta/54/account_validity_with_renewal.sql b/synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql index 0adb2ad55e..0adb2ad55e 100644 --- a/synapse/storage/schema/delta/54/account_validity_with_renewal.sql +++ b/synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql diff --git a/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql b/synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql index c01aa9d2d9..c01aa9d2d9 100644 --- a/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql +++ b/synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql diff --git a/synapse/storage/schema/delta/54/delete_forward_extremities.sql b/synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql index b062ec840c..b062ec840c 100644 --- a/synapse/storage/schema/delta/54/delete_forward_extremities.sql +++ b/synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql diff --git a/synapse/storage/schema/delta/54/drop_legacy_tables.sql b/synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql index dbbe682697..dbbe682697 100644 --- a/synapse/storage/schema/delta/54/drop_legacy_tables.sql +++ b/synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql diff --git a/synapse/storage/schema/delta/54/drop_presence_list.sql b/synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql index e6ee70c623..e6ee70c623 100644 --- a/synapse/storage/schema/delta/54/drop_presence_list.sql +++ b/synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql diff --git a/synapse/storage/schema/delta/54/relations.sql b/synapse/storage/data_stores/main/schema/delta/54/relations.sql index 134862b870..134862b870 100644 --- a/synapse/storage/schema/delta/54/relations.sql +++ b/synapse/storage/data_stores/main/schema/delta/54/relations.sql diff --git a/synapse/storage/schema/delta/54/stats.sql b/synapse/storage/data_stores/main/schema/delta/54/stats.sql index 652e58308e..652e58308e 100644 --- a/synapse/storage/schema/delta/54/stats.sql +++ b/synapse/storage/data_stores/main/schema/delta/54/stats.sql diff --git a/synapse/storage/schema/delta/54/stats2.sql b/synapse/storage/data_stores/main/schema/delta/54/stats2.sql index 3b2d48447f..3b2d48447f 100644 --- a/synapse/storage/schema/delta/54/stats2.sql +++ b/synapse/storage/data_stores/main/schema/delta/54/stats2.sql diff --git a/synapse/storage/schema/delta/55/access_token_expiry.sql b/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql index 4590604bfd..4590604bfd 100644 --- a/synapse/storage/schema/delta/55/access_token_expiry.sql +++ b/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql diff --git a/synapse/storage/schema/delta/55/track_threepid_validations.sql b/synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql index a8eced2e0a..a8eced2e0a 100644 --- a/synapse/storage/schema/delta/55/track_threepid_validations.sql +++ b/synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql diff --git a/synapse/storage/schema/delta/55/users_alter_deactivated.sql b/synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql index dabdde489b..dabdde489b 100644 --- a/synapse/storage/schema/delta/55/users_alter_deactivated.sql +++ b/synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql diff --git a/synapse/storage/schema/delta/56/add_spans_to_device_lists.sql b/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql index 41807eb1e7..41807eb1e7 100644 --- a/synapse/storage/schema/delta/56/add_spans_to_device_lists.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql diff --git a/synapse/storage/schema/delta/56/current_state_events_membership.sql b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql index 473018676f..473018676f 100644 --- a/synapse/storage/schema/delta/56/current_state_events_membership.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql diff --git a/synapse/storage/schema/delta/56/current_state_events_membership_mk2.sql b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql index 3133d42d4a..3133d42d4a 100644 --- a/synapse/storage/schema/delta/56/current_state_events_membership_mk2.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql b/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql new file mode 100644 index 0000000000..1d2ddb1b1a --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql @@ -0,0 +1,25 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* delete room keys that belong to deleted room key version, or to room key + * versions that don't exist (anymore) + */ +DELETE FROM e2e_room_keys +WHERE version NOT IN ( + SELECT version + FROM e2e_room_keys_versions + WHERE e2e_room_keys.user_id = e2e_room_keys_versions.user_id + AND e2e_room_keys_versions.deleted = 0 +); diff --git a/synapse/storage/schema/delta/56/destinations_failure_ts.sql b/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql index f00889290b..f00889290b 100644 --- a/synapse/storage/schema/delta/56/destinations_failure_ts.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql diff --git a/synapse/storage/schema/delta/56/destinations_retry_interval_type.sql.postgres b/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres index b9bbb18a91..b9bbb18a91 100644 --- a/synapse/storage/schema/delta/56/destinations_retry_interval_type.sql.postgres +++ b/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres diff --git a/synapse/storage/schema/delta/56/devices_last_seen.sql b/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql index dfa902d0ba..dfa902d0ba 100644 --- a/synapse/storage/schema/delta/56/devices_last_seen.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql diff --git a/synapse/storage/schema/delta/56/drop_unused_event_tables.sql b/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql index 9f09922c67..9f09922c67 100644 --- a/synapse/storage/schema/delta/56/drop_unused_event_tables.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql new file mode 100644 index 0000000000..5e29c1da19 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql @@ -0,0 +1,30 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- room_id and topoligical_ordering are denormalised from the events table in order to +-- make the index work. +CREATE TABLE IF NOT EXISTS event_labels ( + event_id TEXT, + label TEXT, + room_id TEXT NOT NULL, + topological_ordering BIGINT NOT NULL, + PRIMARY KEY(event_id, label) +); + + +-- This index enables an event pagination looking for a particular label to index the +-- event_labels table first, which is much quicker than scanning the events table and then +-- filtering by label, if the label is rarely used relative to the size of the room. +CREATE INDEX event_labels_room_id_label_idx ON event_labels(room_id, label, topological_ordering); diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql b/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql new file mode 100644 index 0000000000..5f5e0499ae --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql @@ -0,0 +1,17 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('event_store_labels', '{}'); diff --git a/synapse/storage/schema/delta/56/fix_room_keys_index.sql b/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql index 014cb3b538..014cb3b538 100644 --- a/synapse/storage/schema/delta/56/fix_room_keys_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql diff --git a/synapse/storage/schema/delta/56/hidden_devices.sql b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql index 67f8b20297..67f8b20297 100644 --- a/synapse/storage/schema/delta/56/hidden_devices.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite new file mode 100644 index 0000000000..e8b1fd35d8 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite @@ -0,0 +1,42 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Change the hidden column from a default value of FALSE to a default value of + * 0, because sqlite3 prior to 3.23.0 caused the hidden column to contain the + * string 'FALSE', which is truthy. + * + * Since sqlite doesn't allow us to just change the default value, we have to + * recreate the table, copy the data, fix the rows that have incorrect data, and + * replace the old table with the new table. + */ + +CREATE TABLE IF NOT EXISTS devices2 ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + display_name TEXT, + last_seen BIGINT, + ip TEXT, + user_agent TEXT, + hidden BOOLEAN DEFAULT 0, + CONSTRAINT device_uniqueness UNIQUE (user_id, device_id) +); + +INSERT INTO devices2 SELECT * FROM devices; + +UPDATE devices2 SET hidden = 0 WHERE hidden = 'FALSE'; + +DROP TABLE devices; + +ALTER TABLE devices2 RENAME TO devices; diff --git a/synapse/storage/schema/delta/56/public_room_list_idx.sql b/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql index 7be31ffebb..7be31ffebb 100644 --- a/synapse/storage/schema/delta/56/public_room_list_idx.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql diff --git a/synapse/storage/schema/delta/56/redaction_censor.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql index fe51b02309..fe51b02309 100644 --- a/synapse/storage/schema/delta/56/redaction_censor.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql diff --git a/synapse/storage/schema/delta/56/redaction_censor2.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql index 77a5eca499..77a5eca499 100644 --- a/synapse/storage/schema/delta/56/redaction_censor2.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql diff --git a/synapse/storage/schema/delta/56/redaction_censor3_fix_update.sql.postgres b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres index 67471f3ef5..67471f3ef5 100644 --- a/synapse/storage/schema/delta/56/redaction_censor3_fix_update.sql.postgres +++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres diff --git a/synapse/storage/schema/delta/56/room_membership_idx.sql b/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql index 92ab1f5e65..92ab1f5e65 100644 --- a/synapse/storage/schema/delta/56/room_membership_idx.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql diff --git a/synapse/storage/schema/delta/56/signing_keys.sql b/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql index 27a96123e3..27a96123e3 100644 --- a/synapse/storage/schema/delta/56/signing_keys.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql diff --git a/synapse/storage/schema/delta/56/stats_separated.sql b/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql index 163529c071..163529c071 100644 --- a/synapse/storage/schema/delta/56/stats_separated.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql diff --git a/synapse/storage/schema/delta/56/unique_user_filter_index.py b/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py index 1de8b54961..1de8b54961 100644 --- a/synapse/storage/schema/delta/56/unique_user_filter_index.py +++ b/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py diff --git a/synapse/storage/schema/delta/56/user_external_ids.sql b/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql index 91390c4527..91390c4527 100644 --- a/synapse/storage/schema/delta/56/user_external_ids.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql diff --git a/synapse/storage/schema/delta/56/users_in_public_rooms_idx.sql b/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql index 149f8be8b6..149f8be8b6 100644 --- a/synapse/storage/schema/delta/56/users_in_public_rooms_idx.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql diff --git a/synapse/storage/schema/full_schemas/16/application_services.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql index 883fcd10b2..883fcd10b2 100644 --- a/synapse/storage/schema/full_schemas/16/application_services.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql diff --git a/synapse/storage/schema/full_schemas/16/event_edges.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql index 10ce2aa7a0..10ce2aa7a0 100644 --- a/synapse/storage/schema/full_schemas/16/event_edges.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql diff --git a/synapse/storage/schema/full_schemas/16/event_signatures.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql index 95826da431..95826da431 100644 --- a/synapse/storage/schema/full_schemas/16/event_signatures.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql diff --git a/synapse/storage/schema/full_schemas/16/im.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/im.sql index a1a2aa8e5b..a1a2aa8e5b 100644 --- a/synapse/storage/schema/full_schemas/16/im.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/im.sql diff --git a/synapse/storage/schema/full_schemas/16/keys.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql index 11cdffdbb3..11cdffdbb3 100644 --- a/synapse/storage/schema/full_schemas/16/keys.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql diff --git a/synapse/storage/schema/full_schemas/16/media_repository.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql index 8f3759bb2a..8f3759bb2a 100644 --- a/synapse/storage/schema/full_schemas/16/media_repository.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql diff --git a/synapse/storage/schema/full_schemas/16/presence.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql index 01d2d8f833..01d2d8f833 100644 --- a/synapse/storage/schema/full_schemas/16/presence.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql diff --git a/synapse/storage/schema/full_schemas/16/profiles.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql index c04f4747d9..c04f4747d9 100644 --- a/synapse/storage/schema/full_schemas/16/profiles.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql diff --git a/synapse/storage/schema/full_schemas/16/push.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/push.sql index e44465cf45..e44465cf45 100644 --- a/synapse/storage/schema/full_schemas/16/push.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/push.sql diff --git a/synapse/storage/schema/full_schemas/16/redactions.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql index 318f0d9aa5..318f0d9aa5 100644 --- a/synapse/storage/schema/full_schemas/16/redactions.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql diff --git a/synapse/storage/schema/full_schemas/16/room_aliases.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql index d47da3b12f..d47da3b12f 100644 --- a/synapse/storage/schema/full_schemas/16/room_aliases.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql diff --git a/synapse/storage/schema/full_schemas/16/state.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/state.sql index 96391a8f0e..96391a8f0e 100644 --- a/synapse/storage/schema/full_schemas/16/state.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/state.sql diff --git a/synapse/storage/schema/full_schemas/16/transactions.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql index 17e67bedac..17e67bedac 100644 --- a/synapse/storage/schema/full_schemas/16/transactions.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql diff --git a/synapse/storage/schema/full_schemas/16/users.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/users.sql index f013aa8b18..f013aa8b18 100644 --- a/synapse/storage/schema/full_schemas/16/users.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/users.sql diff --git a/synapse/storage/schema/full_schemas/54/full.sql.postgres b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres index 098434356f..4ad2929f32 100644 --- a/synapse/storage/schema/full_schemas/54/full.sql.postgres +++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres @@ -70,15 +70,6 @@ CREATE TABLE appservice_stream_position ( ); - -CREATE TABLE background_updates ( - update_name text NOT NULL, - progress_json text NOT NULL, - depends_on text -); - - - CREATE TABLE blocked_rooms ( room_id text NOT NULL, user_id text NOT NULL @@ -1202,11 +1193,6 @@ ALTER TABLE ONLY appservice_stream_position -ALTER TABLE ONLY background_updates - ADD CONSTRAINT background_updates_uniqueness UNIQUE (update_name); - - - ALTER TABLE ONLY current_state_events ADD CONSTRAINT current_state_events_event_id_key UNIQUE (event_id); @@ -2047,6 +2033,3 @@ CREATE INDEX users_who_share_private_rooms_r_idx ON users_who_share_private_room CREATE UNIQUE INDEX users_who_share_private_rooms_u_idx ON users_who_share_private_rooms USING btree (user_id, other_user_id, room_id); - - - diff --git a/synapse/storage/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite index be9295e4c9..bad33291e7 100644 --- a/synapse/storage/schema/full_schemas/54/full.sql.sqlite +++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite @@ -67,7 +67,6 @@ CREATE INDEX receipts_linearized_id ON receipts_linearized( stream_id ); CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( room_id, stream_id ); CREATE TABLE IF NOT EXISTS "user_threepids" ( user_id TEXT NOT NULL, medium TEXT NOT NULL, address TEXT NOT NULL, validated_at BIGINT NOT NULL, added_at BIGINT NOT NULL, CONSTRAINT medium_address UNIQUE (medium, address) ); CREATE INDEX user_threepids_user_id ON user_threepids(user_id); -CREATE TABLE background_updates( update_name TEXT NOT NULL, progress_json TEXT NOT NULL, depends_on TEXT, CONSTRAINT background_updates_uniqueness UNIQUE (update_name) ); CREATE VIRTUAL TABLE event_search USING fts4 ( event_id, room_id, sender, key, value ) /* event_search(event_id,room_id,sender,"key",value) */; CREATE TABLE IF NOT EXISTS 'event_search_content'(docid INTEGER PRIMARY KEY, 'c0event_id', 'c1room_id', 'c2sender', 'c3key', 'c4value'); diff --git a/synapse/storage/schema/full_schemas/54/stream_positions.sql b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql index c265fd20e2..c265fd20e2 100644 --- a/synapse/storage/schema/full_schemas/54/stream_positions.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql diff --git a/synapse/storage/schema/full_schemas/README.txt b/synapse/storage/data_stores/main/schema/full_schemas/README.txt index d3f6401344..d3f6401344 100644 --- a/synapse/storage/schema/full_schemas/README.txt +++ b/synapse/storage/data_stores/main/schema/full_schemas/README.txt diff --git a/synapse/storage/search.py b/synapse/storage/data_stores/main/search.py index 7695bf09fc..d1d7c6863d 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -25,10 +25,9 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.storage._base import make_in_list_sql_clause +from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from .background_updates import BackgroundUpdateStore - logger = logging.getLogger(__name__) SearchEntry = namedtuple( @@ -197,7 +196,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): " ON event_search USING GIN (vector)" ) except psycopg2.ProgrammingError as e: - logger.warn( + logger.warning( "Ignoring error %r when trying to switch from GIST to GIN", e ) @@ -673,7 +672,7 @@ class SearchStore(SearchBackgroundUpdateStore): ) ) txn.execute(query, (value, search_query)) - headline, = txn.fetchall()[0] + (headline,) = txn.fetchall()[0] # Now we need to pick the possible highlights out of the haedline # result. diff --git a/synapse/storage/signatures.py b/synapse/storage/data_stores/main/signatures.py index fb83218f90..556191b76f 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/data_stores/main/signatures.py @@ -20,10 +20,9 @@ from unpaddedbase64 import encode_base64 from twisted.internet import defer from synapse.crypto.event_signing import compute_event_reference_hash +from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedList -from ._base import SQLBaseStore - # py2 sqlite has buffer hardcoded as only binary type, so we must use it, # despite being deprecated and removed in favor of memoryview if six.PY2: diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py new file mode 100644 index 0000000000..6a90daea31 --- /dev/null +++ b/synapse/storage/data_stores/main/state.py @@ -0,0 +1,1278 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections import namedtuple +from typing import Iterable, Tuple + +from six import iteritems, itervalues +from six.moves import range + +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.api.errors import NotFoundError +from synapse.events import EventBase +from synapse.events.snapshot import EventContext +from synapse.storage._base import SQLBaseStore +from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.engines import PostgresEngine +from synapse.storage.state import StateFilter +from synapse.util.caches import get_cache_factor_for, intern_string +from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.caches.dictionary_cache import DictionaryCache +from synapse.util.stringutils import to_ascii + +logger = logging.getLogger(__name__) + + +MAX_STATE_DELTA_HOPS = 100 + + +class _GetStateGroupDelta( + namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) +): + """Return type of get_state_group_delta that implements __len__, which lets + us use the itrable flag when caching + """ + + __slots__ = [] + + def __len__(self): + return len(self.delta_ids) if self.delta_ids else 0 + + +class StateGroupBackgroundUpdateStore(SQLBaseStore): + """Defines functions related to state groups needed to run the state backgroud + updates. + """ + + def _count_state_group_hops_txn(self, txn, state_group): + """Given a state group, count how many hops there are in the tree. + + This is used to ensure the delta chains don't get too long. + """ + if isinstance(self.database_engine, PostgresEngine): + sql = """ + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT count(*) FROM state; + """ + + txn.execute(sql, (state_group,)) + row = txn.fetchone() + if row and row[0]: + return row[0] + else: + return 0 + else: + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + next_group = state_group + count = 0 + + while next_group: + next_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + if next_group: + count += 1 + + return count + + def _get_state_groups_from_groups_txn( + self, txn, groups, state_filter=StateFilter.all() + ): + results = {group: {} for group in groups} + + where_clause, where_args = state_filter.make_sql_filter_clause() + + # Unless the filter clause is empty, we're going to append it after an + # existing where clause + if where_clause: + where_clause = " AND (%s)" % (where_clause,) + + if isinstance(self.database_engine, PostgresEngine): + # Temporarily disable sequential scans in this transaction. This is + # a temporary hack until we can add the right indices in + txn.execute("SET LOCAL enable_seqscan=off") + + # The below query walks the state_group tree so that the "state" + # table includes all state_groups in the tree. It then joins + # against `state_groups_state` to fetch the latest state. + # It assumes that previous state groups are always numerically + # lesser. + # The PARTITION is used to get the event_id in the greatest state + # group for the given type, state_key. + # This may return multiple rows per (type, state_key), but last_value + # should be the same. + sql = """ + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT DISTINCT type, state_key, last_value(event_id) OVER ( + PARTITION BY type, state_key ORDER BY state_group ASC + ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + ) AS event_id FROM state_groups_state + WHERE state_group IN ( + SELECT state_group FROM state + ) + """ + + for group in groups: + args = [group] + args.extend(where_args) + + txn.execute(sql + where_clause, args) + for row in txn: + typ, state_key, event_id = row + key = (typ, state_key) + results[group][key] = event_id + else: + max_entries_returned = state_filter.max_entries_returned() + + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + for group in groups: + next_group = group + + while next_group: + # We did this before by getting the list of group ids, and + # then passing that list to sqlite to get latest event for + # each (type, state_key). However, that was terribly slow + # without the right indices (which we can't add until + # after we finish deduping state, which requires this func) + args = [next_group] + args.extend(where_args) + + txn.execute( + "SELECT type, state_key, event_id FROM state_groups_state" + " WHERE state_group = ? " + where_clause, + args, + ) + results[group].update( + ((typ, state_key), event_id) + for typ, state_key, event_id in txn + if (typ, state_key) not in results[group] + ) + + # If the number of entries in the (type,state_key)->event_id dict + # matches the number of (type,state_keys) types we were searching + # for, then we must have found them all, so no need to go walk + # further down the tree... UNLESS our types filter contained + # wildcards (i.e. Nones) in which case we have to do an exhaustive + # search + if ( + max_entries_returned is not None + and len(results[group]) == max_entries_returned + ): + break + + next_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + + return results + + +# this inherits from EventsWorkerStore because it calls self.get_events +class StateGroupWorkerStore( + EventsWorkerStore, StateGroupBackgroundUpdateStore, SQLBaseStore +): + """The parts of StateGroupStore that can be called from workers. + """ + + STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" + STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" + CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" + + def __init__(self, db_conn, hs): + super(StateGroupWorkerStore, self).__init__(db_conn, hs) + + # Originally the state store used a single DictionaryCache to cache the + # event IDs for the state types in a given state group to avoid hammering + # on the state_group* tables. + # + # The point of using a DictionaryCache is that it can cache a subset + # of the state events for a given state group (i.e. a subset of the keys for a + # given dict which is an entry in the cache for a given state group ID). + # + # However, this poses problems when performing complicated queries + # on the store - for instance: "give me all the state for this group, but + # limit members to this subset of users", as DictionaryCache's API isn't + # rich enough to say "please cache any of these fields, apart from this subset". + # This is problematic when lazy loading members, which requires this behaviour, + # as without it the cache has no choice but to speculatively load all + # state events for the group, which negates the efficiency being sought. + # + # Rather than overcomplicating DictionaryCache's API, we instead split the + # state_group_cache into two halves - one for tracking non-member events, + # and the other for tracking member_events. This means that lazy loading + # queries can be made in a cache-friendly manner by querying both caches + # separately and then merging the result. So for the example above, you + # would query the members cache for a specific subset of state keys + # (which DictionaryCache will handle efficiently and fine) and the non-members + # cache for all state (which DictionaryCache will similarly handle fine) + # and then just merge the results together. + # + # We size the non-members cache to be smaller than the members cache as the + # vast majority of state in Matrix (today) is member events. + + self._state_group_cache = DictionaryCache( + "*stateGroupCache*", + # TODO: this hasn't been tuned yet + 50000 * get_cache_factor_for("stateGroupCache"), + ) + self._state_group_members_cache = DictionaryCache( + "*stateGroupMembersCache*", + 500000 * get_cache_factor_for("stateGroupMembersCache"), + ) + + @defer.inlineCallbacks + def get_room_version(self, room_id): + """Get the room_version of a given room + + Args: + room_id (str) + + Returns: + Deferred[str] + + Raises: + NotFoundError if the room is unknown + """ + # for now we do this by looking at the create event. We may want to cache this + # more intelligently in future. + + # Retrieve the room's create event + create_event = yield self.get_create_event_for_room(room_id) + return create_event.content.get("room_version", "1") + + @defer.inlineCallbacks + def get_room_predecessor(self, room_id): + """Get the predecessor room of an upgraded room if one exists. + Otherwise return None. + + Args: + room_id (str) + + Returns: + Deferred[dict|None]: A dictionary containing the structure of the predecessor + field from the room's create event. The structure is subject to other servers, + but it is expected to be: + * room_id (str): The room ID of the predecessor room + * event_id (str): The ID of the tombstone event in the predecessor room + + Raises: + NotFoundError if the room is unknown + """ + # Retrieve the room's create event + create_event = yield self.get_create_event_for_room(room_id) + + # Return predecessor if present + return create_event.content.get("predecessor", None) + + @defer.inlineCallbacks + def get_create_event_for_room(self, room_id): + """Get the create state event for a room. + + Args: + room_id (str) + + Returns: + Deferred[EventBase]: The room creation event. + + Raises: + NotFoundError if the room is unknown + """ + state_ids = yield self.get_current_state_ids(room_id) + create_id = state_ids.get((EventTypes.Create, "")) + + # If we can't find the create event, assume we've hit a dead end + if not create_id: + raise NotFoundError("Unknown room %s" % (room_id)) + + # Retrieve the room's create event and return + create_event = yield self.get_event(create_id) + return create_event + + @cached(max_entries=100000, iterable=True) + def get_current_state_ids(self, room_id): + """Get the current state event ids for a room based on the + current_state_events table. + + Args: + room_id (str) + + Returns: + deferred: dict of (type, state_key) -> event_id + """ + + def _get_current_state_ids_txn(txn): + txn.execute( + """SELECT type, state_key, event_id FROM current_state_events + WHERE room_id = ? + """, + (room_id,), + ) + + return { + (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn + } + + return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn) + + # FIXME: how should this be cached? + def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()): + """Get the current state event of a given type for a room based on the + current_state_events table. This may not be as up-to-date as the result + of doing a fresh state resolution as per state_handler.get_current_state + + Args: + room_id (str) + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + Deferred[dict[tuple[str, str], str]]: Map from type/state_key to + event ID. + """ + + where_clause, where_args = state_filter.make_sql_filter_clause() + + if not where_clause: + # We delegate to the cached version + return self.get_current_state_ids(room_id) + + def _get_filtered_current_state_ids_txn(txn): + results = {} + sql = """ + SELECT type, state_key, event_id FROM current_state_events + WHERE room_id = ? + """ + + if where_clause: + sql += " AND (%s)" % (where_clause,) + + args = [room_id] + args.extend(where_args) + txn.execute(sql, args) + for row in txn: + typ, state_key, event_id = row + key = (intern_string(typ), intern_string(state_key)) + results[key] = event_id + + return results + + return self.runInteraction( + "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn + ) + + @defer.inlineCallbacks + def get_canonical_alias_for_room(self, room_id): + """Get canonical alias for room, if any + + Args: + room_id (str) + + Returns: + Deferred[str|None]: The canonical alias, if any + """ + + state = yield self.get_filtered_current_state_ids( + room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) + ) + + event_id = state.get((EventTypes.CanonicalAlias, "")) + if not event_id: + return + + event = yield self.get_event(event_id, allow_none=True) + if not event: + return + + return event.content.get("canonical_alias") + + @cached(max_entries=10000, iterable=True) + def get_state_group_delta(self, state_group): + """Given a state group try to return a previous group and a delta between + the old and the new. + + Returns: + (prev_group, delta_ids), where both may be None. + """ + + def _get_state_group_delta_txn(txn): + prev_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": state_group}, + retcol="prev_state_group", + allow_none=True, + ) + + if not prev_group: + return _GetStateGroupDelta(None, None) + + delta_ids = self._simple_select_list_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": state_group}, + retcols=("type", "state_key", "event_id"), + ) + + return _GetStateGroupDelta( + prev_group, + {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, + ) + + return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn) + + @defer.inlineCallbacks + def get_state_groups_ids(self, _room_id, event_ids): + """Get the event IDs of all the state for the state groups for the given events + + Args: + _room_id (str): id of the room for these events + event_ids (iterable[str]): ids of the events + + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + if not event_ids: + return {} + + event_to_groups = yield self._get_state_group_for_events(event_ids) + + groups = set(itervalues(event_to_groups)) + group_to_state = yield self._get_state_for_groups(groups) + + return group_to_state + + @defer.inlineCallbacks + def get_state_ids_for_group(self, state_group): + """Get the event IDs of all the state in the given state group + + Args: + state_group (int) + + Returns: + Deferred[dict]: Resolves to a map of (type, state_key) -> event_id + """ + group_to_state = yield self._get_state_for_groups((state_group,)) + + return group_to_state[state_group] + + @defer.inlineCallbacks + def get_state_groups(self, room_id, event_ids): + """ Get the state groups for the given list of event_ids + + Returns: + Deferred[dict[int, list[EventBase]]]: + dict of state_group_id -> list of state events. + """ + if not event_ids: + return {} + + group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) + + state_event_map = yield self.get_events( + [ + ev_id + for group_ids in itervalues(group_to_ids) + for ev_id in itervalues(group_ids) + ], + get_prev_content=False, + ) + + return { + group: [ + state_event_map[v] + for v in itervalues(event_id_map) + if v in state_event_map + ] + for group, event_id_map in iteritems(group_to_ids) + } + + @defer.inlineCallbacks + def _get_state_groups_from_groups(self, groups, state_filter): + """Returns the state groups for a given set of groups, filtering on + types of state events. + + Args: + groups(list[int]): list of state group IDs to query + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + results = {} + + chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] + for chunk in chunks: + res = yield self.runInteraction( + "_get_state_groups_from_groups", + self._get_state_groups_from_groups_txn, + chunk, + state_filter, + ) + results.update(res) + + return results + + @defer.inlineCallbacks + def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): + """Given a list of event_ids and type tuples, return a list of state + dicts for each event. + + Args: + event_ids (list[string]) + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + deferred: A dict of (event_id) -> (type, state_key) -> [state_events] + """ + event_to_groups = yield self._get_state_group_for_events(event_ids) + + groups = set(itervalues(event_to_groups)) + group_to_state = yield self._get_state_for_groups(groups, state_filter) + + state_event_map = yield self.get_events( + [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], + get_prev_content=False, + ) + + event_to_state = { + event_id: { + k: state_event_map[v] + for k, v in iteritems(group_to_state[group]) + if v in state_event_map + } + for event_id, group in iteritems(event_to_groups) + } + + return {event: event_to_state[event] for event in event_ids} + + @defer.inlineCallbacks + def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): + """ + Get the state dicts corresponding to a list of events, containing the event_ids + of the state events (as opposed to the events themselves) + + Args: + event_ids(list(str)): events whose state should be returned + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + A deferred dict from event_id -> (type, state_key) -> event_id + """ + event_to_groups = yield self._get_state_group_for_events(event_ids) + + groups = set(itervalues(event_to_groups)) + group_to_state = yield self._get_state_for_groups(groups, state_filter) + + event_to_state = { + event_id: group_to_state[group] + for event_id, group in iteritems(event_to_groups) + } + + return {event: event_to_state[event] for event in event_ids} + + @defer.inlineCallbacks + def get_state_for_event(self, event_id, state_filter=StateFilter.all()): + """ + Get the state dict corresponding to a particular event + + Args: + event_id(str): event whose state should be returned + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + A deferred dict from (type, state_key) -> state_event + """ + state_map = yield self.get_state_for_events([event_id], state_filter) + return state_map[event_id] + + @defer.inlineCallbacks + def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): + """ + Get the state dict corresponding to a particular event + + Args: + event_id(str): event whose state should be returned + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + A deferred dict from (type, state_key) -> state_event + """ + state_map = yield self.get_state_ids_for_events([event_id], state_filter) + return state_map[event_id] + + @cached(max_entries=50000) + def _get_state_group_for_event(self, event_id): + return self._simple_select_one_onecol( + table="event_to_state_groups", + keyvalues={"event_id": event_id}, + retcol="state_group", + allow_none=True, + desc="_get_state_group_for_event", + ) + + @cachedList( + cached_method_name="_get_state_group_for_event", + list_name="event_ids", + num_args=1, + inlineCallbacks=True, + ) + def _get_state_group_for_events(self, event_ids): + """Returns mapping event_id -> state_group + """ + rows = yield self._simple_select_many_batch( + table="event_to_state_groups", + column="event_id", + iterable=event_ids, + keyvalues={}, + retcols=("event_id", "state_group"), + desc="_get_state_group_for_events", + ) + + return {row["event_id"]: row["state_group"] for row in rows} + + def _get_state_for_group_using_cache(self, cache, group, state_filter): + """Checks if group is in cache. See `_get_state_for_groups` + + Args: + cache(DictionaryCache): the state group cache to use + group(int): The state group to lookup + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns 2-tuple (`state_dict`, `got_all`). + `got_all` is a bool indicating if we successfully retrieved all + requests state from the cache, if False we need to query the DB for the + missing state. + """ + is_all, known_absent, state_dict_ids = cache.get(group) + + if is_all or state_filter.is_full(): + # Either we have everything or want everything, either way + # `is_all` tells us whether we've gotten everything. + return state_filter.filter_state(state_dict_ids), is_all + + # tracks whether any of our requested types are missing from the cache + missing_types = False + + if state_filter.has_wildcards(): + # We don't know if we fetched all the state keys for the types in + # the filter that are wildcards, so we have to assume that we may + # have missed some. + missing_types = True + else: + # There aren't any wild cards, so `concrete_types()` returns the + # complete list of event types we're wanting. + for key in state_filter.concrete_types(): + if key not in state_dict_ids and key not in known_absent: + missing_types = True + break + + return state_filter.filter_state(state_dict_ids), not missing_types + + @defer.inlineCallbacks + def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + + member_filter, non_member_filter = state_filter.get_member_split() + + # Now we look them up in the member and non-member caches + ( + non_member_state, + incomplete_groups_nm, + ) = yield self._get_state_for_groups_using_cache( + groups, self._state_group_cache, state_filter=non_member_filter + ) + + ( + member_state, + incomplete_groups_m, + ) = yield self._get_state_for_groups_using_cache( + groups, self._state_group_members_cache, state_filter=member_filter + ) + + state = dict(non_member_state) + for group in groups: + state[group].update(member_state[group]) + + # Now fetch any missing groups from the database + + incomplete_groups = incomplete_groups_m | incomplete_groups_nm + + if not incomplete_groups: + return state + + cache_sequence_nm = self._state_group_cache.sequence + cache_sequence_m = self._state_group_members_cache.sequence + + # Help the cache hit ratio by expanding the filter a bit + db_state_filter = state_filter.return_expanded() + + group_to_state_dict = yield self._get_state_groups_from_groups( + list(incomplete_groups), state_filter=db_state_filter + ) + + # Now lets update the caches + self._insert_into_cache( + group_to_state_dict, + db_state_filter, + cache_seq_num_members=cache_sequence_m, + cache_seq_num_non_members=cache_sequence_nm, + ) + + # And finally update the result dict, by filtering out any extra + # stuff we pulled out of the database. + for group, group_state_dict in iteritems(group_to_state_dict): + # We just replace any existing entries, as we will have loaded + # everything we need from the database anyway. + state[group] = state_filter.filter_state(group_state_dict) + + return state + + def _get_state_for_groups_using_cache(self, groups, cache, state_filter): + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key, querying from a specific cache. + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + cache (DictionaryCache): the cache of group ids to state dicts which + we will pass through - either the normal state cache or the specific + members state cache. + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of + dict of state_group_id -> (dict of (type, state_key) -> event id) + of entries in the cache, and the state group ids either missing + from the cache or incomplete. + """ + results = {} + incomplete_groups = set() + for group in set(groups): + state_dict_ids, got_all = self._get_state_for_group_using_cache( + cache, group, state_filter + ) + results[group] = state_dict_ids + + if not got_all: + incomplete_groups.add(group) + + return results, incomplete_groups + + def _insert_into_cache( + self, + group_to_state_dict, + state_filter, + cache_seq_num_members, + cache_seq_num_non_members, + ): + """Inserts results from querying the database into the relevant cache. + + Args: + group_to_state_dict (dict): The new entries pulled from database. + Map from state group to state dict + state_filter (StateFilter): The state filter used to fetch state + from the database. + cache_seq_num_members (int): Sequence number of member cache since + last lookup in cache + cache_seq_num_non_members (int): Sequence number of member cache since + last lookup in cache + """ + + # We need to work out which types we've fetched from the DB for the + # member vs non-member caches. This should be as accurate as possible, + # but can be an underestimate (e.g. when we have wild cards) + + member_filter, non_member_filter = state_filter.get_member_split() + if member_filter.is_full(): + # We fetched all member events + member_types = None + else: + # `concrete_types()` will only return a subset when there are wild + # cards in the filter, but that's fine. + member_types = member_filter.concrete_types() + + if non_member_filter.is_full(): + # We fetched all non member events + non_member_types = None + else: + non_member_types = non_member_filter.concrete_types() + + for group, group_state_dict in iteritems(group_to_state_dict): + state_dict_members = {} + state_dict_non_members = {} + + for k, v in iteritems(group_state_dict): + if k[0] == EventTypes.Member: + state_dict_members[k] = v + else: + state_dict_non_members[k] = v + + self._state_group_members_cache.update( + cache_seq_num_members, + key=group, + value=state_dict_members, + fetched_keys=member_types, + ) + + self._state_group_cache.update( + cache_seq_num_non_members, + key=group, + value=state_dict_non_members, + fetched_keys=non_member_types, + ) + + def store_state_group( + self, event_id, room_id, prev_group, delta_ids, current_state_ids + ): + """Store a new set of state, returning a newly assigned state group. + + Args: + event_id (str): The event ID for which the state was calculated + room_id (str) + prev_group (int|None): A previous state group for the room, optional. + delta_ids (dict|None): The delta between state at `prev_group` and + `current_state_ids`, if `prev_group` was given. Same format as + `current_state_ids`. + current_state_ids (dict): The state to store. Map of (type, state_key) + to event_id. + + Returns: + Deferred[int]: The state group ID + """ + + def _store_state_group_txn(txn): + if current_state_ids is None: + # AFAIK, this can never happen + raise Exception("current_state_ids cannot be None") + + state_group = self.database_engine.get_next_state_group_id(txn) + + self._simple_insert_txn( + txn, + table="state_groups", + values={"id": state_group, "room_id": room_id, "event_id": event_id}, + ) + + # We persist as a delta if we can, while also ensuring the chain + # of deltas isn't tooo long, as otherwise read performance degrades. + if prev_group: + is_in_db = self._simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (prev_group,) + ) + + potential_hops = self._count_state_group_hops_txn(txn, prev_group) + if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: + self._simple_insert_txn( + txn, + table="state_group_edges", + values={"state_group": state_group, "prev_state_group": prev_group}, + ) + + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(delta_ids) + ], + ) + else: + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(current_state_ids) + ], + ) + + # Prefill the state group caches with this group. + # It's fine to use the sequence like this as the state group map + # is immutable. (If the map wasn't immutable then this prefill could + # race with another update) + + current_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] == EventTypes.Member + } + txn.call_after( + self._state_group_members_cache.update, + self._state_group_members_cache.sequence, + key=state_group, + value=dict(current_member_state_ids), + ) + + current_non_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] != EventTypes.Member + } + txn.call_after( + self._state_group_cache.update, + self._state_group_cache.sequence, + key=state_group, + value=dict(current_non_member_state_ids), + ) + + return state_group + + return self.runInteraction("store_state_group", _store_state_group_txn) + + @defer.inlineCallbacks + def get_referenced_state_groups(self, state_groups): + """Check if the state groups are referenced by events. + + Args: + state_groups (Iterable[int]) + + Returns: + Deferred[set[int]]: The subset of state groups that are + referenced. + """ + + rows = yield self._simple_select_many_batch( + table="event_to_state_groups", + column="state_group", + iterable=state_groups, + keyvalues={}, + retcols=("DISTINCT state_group",), + desc="get_referenced_state_groups", + ) + + return set(row["state_group"] for row in rows) + + +class StateBackgroundUpdateStore( + StateGroupBackgroundUpdateStore, BackgroundUpdateStore +): + + STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" + STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" + CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" + EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" + + def __init__(self, db_conn, hs): + super(StateBackgroundUpdateStore, self).__init__(db_conn, hs) + self.register_background_update_handler( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, + self._background_deduplicate_state, + ) + self.register_background_update_handler( + self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state + ) + self.register_background_index_update( + self.CURRENT_STATE_INDEX_UPDATE_NAME, + index_name="current_state_events_member_index", + table="current_state_events", + columns=["state_key"], + where_clause="type='m.room.member'", + ) + self.register_background_index_update( + self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME, + index_name="event_to_state_groups_sg_index", + table="event_to_state_groups", + columns=["state_group"], + ) + + @defer.inlineCallbacks + def _background_deduplicate_state(self, progress, batch_size): + """This background update will slowly deduplicate state by reencoding + them as deltas. + """ + last_state_group = progress.get("last_state_group", 0) + rows_inserted = progress.get("rows_inserted", 0) + max_group = progress.get("max_group", None) + + BATCH_SIZE_SCALE_FACTOR = 100 + + batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) + + if max_group is None: + rows = yield self._execute( + "_background_deduplicate_state", + None, + "SELECT coalesce(max(id), 0) FROM state_groups", + ) + max_group = rows[0][0] + + def reindex_txn(txn): + new_last_state_group = last_state_group + for count in range(batch_size): + txn.execute( + "SELECT id, room_id FROM state_groups" + " WHERE ? < id AND id <= ?" + " ORDER BY id ASC" + " LIMIT 1", + (new_last_state_group, max_group), + ) + row = txn.fetchone() + if row: + state_group, room_id = row + + if not row or not state_group: + return True, count + + txn.execute( + "SELECT state_group FROM state_group_edges" + " WHERE state_group = ?", + (state_group,), + ) + + # If we reach a point where we've already started inserting + # edges we should stop. + if txn.fetchall(): + return True, count + + txn.execute( + "SELECT coalesce(max(id), 0) FROM state_groups" + " WHERE id < ? AND room_id = ?", + (state_group, room_id), + ) + (prev_group,) = txn.fetchone() + new_last_state_group = state_group + + if prev_group: + potential_hops = self._count_state_group_hops_txn(txn, prev_group) + if potential_hops >= MAX_STATE_DELTA_HOPS: + # We want to ensure chains are at most this long,# + # otherwise read performance degrades. + continue + + prev_state = self._get_state_groups_from_groups_txn( + txn, [prev_group] + ) + prev_state = prev_state[prev_group] + + curr_state = self._get_state_groups_from_groups_txn( + txn, [state_group] + ) + curr_state = curr_state[state_group] + + if not set(prev_state.keys()) - set(curr_state.keys()): + # We can only do a delta if the current has a strict super set + # of keys + + delta_state = { + key: value + for key, value in iteritems(curr_state) + if prev_state.get(key, None) != value + } + + self._simple_delete_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": state_group}, + ) + + self._simple_insert_txn( + txn, + table="state_group_edges", + values={ + "state_group": state_group, + "prev_state_group": prev_group, + }, + ) + + self._simple_delete_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": state_group}, + ) + + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(delta_state) + ], + ) + + progress = { + "last_state_group": state_group, + "rows_inserted": rows_inserted + batch_size, + "max_group": max_group, + } + + self._background_update_progress_txn( + txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress + ) + + return False, batch_size + + finished, result = yield self.runInteraction( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn + ) + + if finished: + yield self._end_background_update( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME + ) + + return result * BATCH_SIZE_SCALE_FACTOR + + @defer.inlineCallbacks + def _background_index_state(self, progress, batch_size): + def reindex_txn(conn): + conn.rollback() + if isinstance(self.database_engine, PostgresEngine): + # postgres insists on autocommit for the index + conn.set_session(autocommit=True) + try: + txn = conn.cursor() + txn.execute( + "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" + " ON state_groups_state(state_group, type, state_key)" + ) + txn.execute("DROP INDEX IF EXISTS state_groups_state_id") + finally: + conn.set_session(autocommit=False) + else: + txn = conn.cursor() + txn.execute( + "CREATE INDEX state_groups_state_type_idx" + " ON state_groups_state(state_group, type, state_key)" + ) + txn.execute("DROP INDEX IF EXISTS state_groups_state_id") + + yield self.runWithConnection(reindex_txn) + + yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) + + return 1 + + +class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): + """ Keeps track of the state at a given event. + + This is done by the concept of `state groups`. Every event is a assigned + a state group (identified by an arbitrary string), which references a + collection of state events. The current state of an event is then the + collection of state events referenced by the event's state group. + + Hence, every change in the current state causes a new state group to be + generated. However, if no change happens (e.g., if we get a message event + with only one parent it inherits the state group from its parent.) + + There are three tables: + * `state_groups`: Stores group name, first event with in the group and + room id. + * `event_to_state_groups`: Maps events to state groups. + * `state_groups_state`: Maps state group to state events. + """ + + def __init__(self, db_conn, hs): + super(StateStore, self).__init__(db_conn, hs) + + def _store_event_state_mappings_txn( + self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]] + ): + state_groups = {} + for event, context in events_and_contexts: + if event.internal_metadata.is_outlier(): + continue + + # if the event was rejected, just give it the same state as its + # predecessor. + if context.rejected: + state_groups[event.event_id] = context.state_group_before_event + continue + + state_groups[event.event_id] = context.state_group + + self._simple_insert_many_txn( + txn, + table="event_to_state_groups", + values=[ + {"state_group": state_group_id, "event_id": event_id} + for event_id, state_group_id in iteritems(state_groups) + ], + ) + + for event_id, state_group_id in iteritems(state_groups): + txn.call_after( + self._get_state_group_for_event.prefill, (event_id,), state_group_id + ) diff --git a/synapse/storage/state_deltas.py b/synapse/storage/data_stores/main/state_deltas.py index 28f33ec18f..28f33ec18f 100644 --- a/synapse/storage/state_deltas.py +++ b/synapse/storage/data_stores/main/state_deltas.py diff --git a/synapse/storage/stats.py b/synapse/storage/data_stores/main/stats.py index 7c224cd3d9..45b3de7d56 100644 --- a/synapse/storage/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -21,8 +21,8 @@ from twisted.internet import defer from twisted.internet.defer import DeferredLock from synapse.api.constants import EventTypes, Membership -from synapse.storage import PostgresEngine -from synapse.storage.state_deltas import StateDeltasStore +from synapse.storage.data_stores.main.state_deltas import StateDeltasStore +from synapse.storage.engines import PostgresEngine from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -332,7 +332,7 @@ class StatsStore(StateDeltasStore): def _bulk_update_stats_delta_txn(txn): for stats_type, stats_updates in updates.items(): for stats_id, fields in stats_updates.items(): - logger.info( + logger.debug( "Updating %s stats for %s: %s", stats_type, stats_id, fields ) self._update_stats_delta_txn( @@ -773,7 +773,7 @@ class StatsStore(StateDeltasStore): (room_id,), ) - current_state_events_count, = txn.fetchone() + (current_state_events_count,) = txn.fetchone() users_in_room = self.get_users_in_room_txn(txn, room_id) @@ -863,7 +863,7 @@ class StatsStore(StateDeltasStore): """, (user_id,), ) - count, = txn.fetchone() + (count,) = txn.fetchone() return count, pos joined_rooms, pos = yield self.runInteraction( diff --git a/synapse/storage/stream.py b/synapse/storage/data_stores/main/stream.py index 490454f19a..8780fdd989 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -43,8 +43,8 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.engines import PostgresEngine -from synapse.storage.events_worker import EventsWorkerStore from synapse.types import RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -229,6 +229,14 @@ def filter_to_clause(event_filter): clauses.append("contains_url = ?") args.append(event_filter.contains_url) + # We're only applying the "labels" filter on the database query, because applying the + # "not_labels" filter via a SQL query is non-trivial. Instead, we let + # event_filter.check_fields apply it, which is not as efficient but makes the + # implementation simpler. + if event_filter.labels: + clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels)) + args.extend(event_filter.labels) + return " AND ".join(clauses), args @@ -863,13 +871,38 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): args.append(int(limit)) - sql = ( - "SELECT event_id, topological_ordering, stream_ordering" - " FROM events" - " WHERE outlier = ? AND room_id = ? AND %(bounds)s" - " ORDER BY topological_ordering %(order)s," - " stream_ordering %(order)s LIMIT ?" - ) % {"bounds": bounds, "order": order} + select_keywords = "SELECT" + join_clause = "" + if event_filter and event_filter.labels: + # If we're not filtering on a label, then joining on event_labels will + # return as many row for a single event as the number of labels it has. To + # avoid this, only join if we're filtering on at least one label. + join_clause = """ + LEFT JOIN event_labels + USING (event_id, room_id, topological_ordering) + """ + if len(event_filter.labels) > 1: + # Using DISTINCT in this SELECT query is quite expensive, because it + # requires the engine to sort on the entire (not limited) result set, + # i.e. the entire events table. We only need to use it when we're + # filtering on more than two labels, because that's the only scenario + # in which we can possibly to get multiple times the same event ID in + # the results. + select_keywords += "DISTINCT" + + sql = """ + %(select_keywords)s event_id, topological_ordering, stream_ordering + FROM events + %(join_clause)s + WHERE outlier = ? AND room_id = ? AND %(bounds)s + ORDER BY topological_ordering %(order)s, + stream_ordering %(order)s LIMIT ? + """ % { + "select_keywords": select_keywords, + "join_clause": join_clause, + "bounds": bounds, + "order": order, + } txn.execute(sql, args) diff --git a/synapse/storage/tags.py b/synapse/storage/data_stores/main/tags.py index 20dd6bd53d..10d1887f75 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/data_stores/main/tags.py @@ -22,7 +22,7 @@ from canonicaljson import json from twisted.internet import defer -from synapse.storage.account_data import AccountDataWorkerStore +from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) diff --git a/synapse/storage/transactions.py b/synapse/storage/data_stores/main/transactions.py index 289c117396..01b1be5e14 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/data_stores/main/transactions.py @@ -23,10 +23,9 @@ from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.util.caches.expiringcache import ExpiringCache -from ._base import SQLBaseStore, db_to_json - # py2 sqlite has buffer hardcoded as only binary type, so we must use it, # despite being deprecated and removed in favor of memoryview if six.PY2: diff --git a/synapse/storage/user_directory.py b/synapse/storage/data_stores/main/user_directory.py index 1b1e4751b9..652abe0e6a 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/data_stores/main/user_directory.py @@ -20,9 +20,9 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage.data_stores.main.state import StateFilter +from synapse.storage.data_stores.main.state_deltas import StateDeltasStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from synapse.storage.state import StateFilter -from synapse.storage.state_deltas import StateDeltasStore from synapse.types import get_domain_from_id, get_localpart_from_id from synapse.util.caches.descriptors import cached diff --git a/synapse/storage/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py index aa4f0da5f0..aa4f0da5f0 100644 --- a/synapse/storage/user_erasure_store.py +++ b/synapse/storage/data_stores/main/user_erasure_store.py diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index e72f89e446..4769b21529 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -14,208 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import itertools import logging -import six - import attr -from signedjson.key import decode_verify_key_bytes - -from synapse.util import batch_iter -from synapse.util.caches.descriptors import cached, cachedList - -from ._base import SQLBaseStore logger = logging.getLogger(__name__) -# py2 sqlite has buffer hardcoded as only binary type, so we must use it, -# despite being deprecated and removed in favor of memoryview -if six.PY2: - db_binary_type = six.moves.builtins.buffer -else: - db_binary_type = memoryview - @attr.s(slots=True, frozen=True) class FetchKeyResult(object): verify_key = attr.ib() # VerifyKey: the key itself valid_until_ts = attr.ib() # int: how long we can use this key for - - -class KeyStore(SQLBaseStore): - """Persistence for signature verification keys - """ - - @cached() - def _get_server_verify_key(self, server_name_and_key_id): - raise NotImplementedError() - - @cachedList( - cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids" - ) - def get_server_verify_keys(self, server_name_and_key_ids): - """ - Args: - server_name_and_key_ids (iterable[Tuple[str, str]]): - iterable of (server_name, key-id) tuples to fetch keys for - - Returns: - Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]: - map from (server_name, key_id) -> FetchKeyResult, or None if the key is - unknown - """ - keys = {} - - def _get_keys(txn, batch): - """Processes a batch of keys to fetch, and adds the result to `keys`.""" - - # batch_iter always returns tuples so it's safe to do len(batch) - sql = ( - "SELECT server_name, key_id, verify_key, ts_valid_until_ms " - "FROM server_signature_keys WHERE 1=0" - ) + " OR (server_name=? AND key_id=?)" * len(batch) - - txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) - - for row in txn: - server_name, key_id, key_bytes, ts_valid_until_ms = row - - if ts_valid_until_ms is None: - # Old keys may be stored with a ts_valid_until_ms of null, - # in which case we treat this as if it was set to `0`, i.e. - # it won't match key requests that define a minimum - # `ts_valid_until_ms`. - ts_valid_until_ms = 0 - - res = FetchKeyResult( - verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), - valid_until_ts=ts_valid_until_ms, - ) - keys[(server_name, key_id)] = res - - def _txn(txn): - for batch in batch_iter(server_name_and_key_ids, 50): - _get_keys(txn, batch) - return keys - - return self.runInteraction("get_server_verify_keys", _txn) - - def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): - """Stores NACL verification keys for remote servers. - Args: - from_server (str): Where the verification keys were looked up - ts_added_ms (int): The time to record that the key was added - verify_keys (iterable[tuple[str, str, FetchKeyResult]]): - keys to be stored. Each entry is a triplet of - (server_name, key_id, key). - """ - key_values = [] - value_values = [] - invalidations = [] - for server_name, key_id, fetch_result in verify_keys: - key_values.append((server_name, key_id)) - value_values.append( - ( - from_server, - ts_added_ms, - fetch_result.valid_until_ts, - db_binary_type(fetch_result.verify_key.encode()), - ) - ) - # invalidate takes a tuple corresponding to the params of - # _get_server_verify_key. _get_server_verify_key only takes one - # param, which is itself the 2-tuple (server_name, key_id). - invalidations.append((server_name, key_id)) - - def _invalidate(res): - f = self._get_server_verify_key.invalidate - for i in invalidations: - f((i,)) - return res - - return self.runInteraction( - "store_server_verify_keys", - self._simple_upsert_many_txn, - table="server_signature_keys", - key_names=("server_name", "key_id"), - key_values=key_values, - value_names=( - "from_server", - "ts_added_ms", - "ts_valid_until_ms", - "verify_key", - ), - value_values=value_values, - ).addCallback(_invalidate) - - def store_server_keys_json( - self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes - ): - """Stores the JSON bytes for a set of keys from a server - The JSON should be signed by the originating server, the intermediate - server, and by this server. Updates the value for the - (server_name, key_id, from_server) triplet if one already existed. - Args: - server_name (str): The name of the server. - key_id (str): The identifer of the key this JSON is for. - from_server (str): The server this JSON was fetched from. - ts_now_ms (int): The time now in milliseconds. - ts_valid_until_ms (int): The time when this json stops being valid. - key_json (bytes): The encoded JSON. - """ - return self._simple_upsert( - table="server_keys_json", - keyvalues={ - "server_name": server_name, - "key_id": key_id, - "from_server": from_server, - }, - values={ - "server_name": server_name, - "key_id": key_id, - "from_server": from_server, - "ts_added_ms": ts_now_ms, - "ts_valid_until_ms": ts_expires_ms, - "key_json": db_binary_type(key_json_bytes), - }, - desc="store_server_keys_json", - ) - - def get_server_keys_json(self, server_keys): - """Retrive the key json for a list of server_keys and key ids. - If no keys are found for a given server, key_id and source then - that server, key_id, and source triplet entry will be an empty list. - The JSON is returned as a byte array so that it can be efficiently - used in an HTTP response. - Args: - server_keys (list): List of (server_name, key_id, source) triplets. - Returns: - Deferred[dict[Tuple[str, str, str|None], list[dict]]]: - Dict mapping (server_name, key_id, source) triplets to lists of dicts - """ - - def _get_server_keys_json_txn(txn): - results = {} - for server_name, key_id, from_server in server_keys: - keyvalues = {"server_name": server_name} - if key_id is not None: - keyvalues["key_id"] = key_id - if from_server is not None: - keyvalues["from_server"] = from_server - rows = self._simple_select_list_txn( - txn, - "server_keys_json", - keyvalues=keyvalues, - retcols=( - "key_id", - "from_server", - "ts_added_ms", - "ts_valid_until_ms", - "key_json", - ), - ) - results[(server_name, key_id, from_server)] = rows - return results - - return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn) diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py new file mode 100644 index 0000000000..fa03ca9ff7 --- /dev/null +++ b/synapse/storage/persist_events.py @@ -0,0 +1,649 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections import deque, namedtuple + +from six import iteritems +from six.moves import range + +from prometheus_client import Counter, Histogram + +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.state import StateResolutionStore +from synapse.storage.data_stores import DataStores +from synapse.util.async_helpers import ObservableDeferred +from synapse.util.metrics import Measure + +logger = logging.getLogger(__name__) + +# The number of times we are recalculating the current state +state_delta_counter = Counter("synapse_storage_events_state_delta", "") + +# The number of times we are recalculating state when there is only a +# single forward extremity +state_delta_single_event_counter = Counter( + "synapse_storage_events_state_delta_single_event", "" +) + +# The number of times we are reculating state when we could have resonably +# calculated the delta when we calculated the state for an event we were +# persisting. +state_delta_reuse_delta_counter = Counter( + "synapse_storage_events_state_delta_reuse_delta", "" +) + +# The number of forward extremities for each new event. +forward_extremities_counter = Histogram( + "synapse_storage_events_forward_extremities_persisted", + "Number of forward extremities for each new event", + buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), +) + +# The number of stale forward extremities for each new event. Stale extremities +# are those that were in the previous set of extremities as well as the new. +stale_forward_extremities_counter = Histogram( + "synapse_storage_events_stale_forward_extremities_persisted", + "Number of unchanged forward extremities for each new event", + buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), +) + + +class _EventPeristenceQueue(object): + """Queues up events so that they can be persisted in bulk with only one + concurrent transaction per room. + """ + + _EventPersistQueueItem = namedtuple( + "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred") + ) + + def __init__(self): + self._event_persist_queues = {} + self._currently_persisting_rooms = set() + + def add_to_queue(self, room_id, events_and_contexts, backfilled): + """Add events to the queue, with the given persist_event options. + + NB: due to the normal usage pattern of this method, it does *not* + follow the synapse logcontext rules, and leaves the logcontext in + place whether or not the returned deferred is ready. + + Args: + room_id (str): + events_and_contexts (list[(EventBase, EventContext)]): + backfilled (bool): + + Returns: + defer.Deferred: a deferred which will resolve once the events are + persisted. Runs its callbacks *without* a logcontext. + """ + queue = self._event_persist_queues.setdefault(room_id, deque()) + if queue: + # if the last item in the queue has the same `backfilled` setting, + # we can just add these new events to that item. + end_item = queue[-1] + if end_item.backfilled == backfilled: + end_item.events_and_contexts.extend(events_and_contexts) + return end_item.deferred.observe() + + deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) + + queue.append( + self._EventPersistQueueItem( + events_and_contexts=events_and_contexts, + backfilled=backfilled, + deferred=deferred, + ) + ) + + return deferred.observe() + + def handle_queue(self, room_id, per_item_callback): + """Attempts to handle the queue for a room if not already being handled. + + The given callback will be invoked with for each item in the queue, + of type _EventPersistQueueItem. The per_item_callback will continuously + be called with new items, unless the queue becomnes empty. The return + value of the function will be given to the deferreds waiting on the item, + exceptions will be passed to the deferreds as well. + + This function should therefore be called whenever anything is added + to the queue. + + If another callback is currently handling the queue then it will not be + invoked. + """ + + if room_id in self._currently_persisting_rooms: + return + + self._currently_persisting_rooms.add(room_id) + + @defer.inlineCallbacks + def handle_queue_loop(): + try: + queue = self._get_drainining_queue(room_id) + for item in queue: + try: + ret = yield per_item_callback(item) + except Exception: + with PreserveLoggingContext(): + item.deferred.errback() + else: + with PreserveLoggingContext(): + item.deferred.callback(ret) + finally: + queue = self._event_persist_queues.pop(room_id, None) + if queue: + self._event_persist_queues[room_id] = queue + self._currently_persisting_rooms.discard(room_id) + + # set handle_queue_loop off in the background + run_as_background_process("persist_events", handle_queue_loop) + + def _get_drainining_queue(self, room_id): + queue = self._event_persist_queues.setdefault(room_id, deque()) + + try: + while True: + yield queue.popleft() + except IndexError: + # Queue has been drained. + pass + + +class EventsPersistenceStorage(object): + """High level interface for handling persisting newly received events. + + Takes care of batching up events by room, and calculating the necessary + current state and forward extremity changes. + """ + + def __init__(self, hs, stores: DataStores): + # We ultimately want to split out the state store from the main store, + # so we use separate variables here even though they point to the same + # store for now. + self.main_store = stores.main + self.state_store = stores.main + + self._clock = hs.get_clock() + self.is_mine_id = hs.is_mine_id + self._event_persist_queue = _EventPeristenceQueue() + self._state_resolution_handler = hs.get_state_resolution_handler() + + @defer.inlineCallbacks + def persist_events(self, events_and_contexts, backfilled=False): + """ + Write events to the database + Args: + events_and_contexts: list of tuples of (event, context) + backfilled (bool): Whether the results are retrieved from federation + via backfill or not. Used to determine if they're "new" events + which might update the current state etc. + + Returns: + Deferred[int]: the stream ordering of the latest persisted event + """ + partitioned = {} + for event, ctx in events_and_contexts: + partitioned.setdefault(event.room_id, []).append((event, ctx)) + + deferreds = [] + for room_id, evs_ctxs in iteritems(partitioned): + d = self._event_persist_queue.add_to_queue( + room_id, evs_ctxs, backfilled=backfilled + ) + deferreds.append(d) + + for room_id in partitioned: + self._maybe_start_persisting(room_id) + + yield make_deferred_yieldable( + defer.gatherResults(deferreds, consumeErrors=True) + ) + + max_persisted_id = yield self.main_store.get_current_events_token() + + return max_persisted_id + + @defer.inlineCallbacks + def persist_event(self, event, context, backfilled=False): + """ + + Args: + event (EventBase): + context (EventContext): + backfilled (bool): + + Returns: + Deferred: resolves to (int, int): the stream ordering of ``event``, + and the stream ordering of the latest persisted event + """ + deferred = self._event_persist_queue.add_to_queue( + event.room_id, [(event, context)], backfilled=backfilled + ) + + self._maybe_start_persisting(event.room_id) + + yield make_deferred_yieldable(deferred) + + max_persisted_id = yield self.main_store.get_current_events_token() + return (event.internal_metadata.stream_ordering, max_persisted_id) + + def _maybe_start_persisting(self, room_id): + @defer.inlineCallbacks + def persisting_queue(item): + with Measure(self._clock, "persist_events"): + yield self._persist_events( + item.events_and_contexts, backfilled=item.backfilled + ) + + self._event_persist_queue.handle_queue(room_id, persisting_queue) + + @defer.inlineCallbacks + def _persist_events(self, events_and_contexts, backfilled=False): + """Calculates the change to current state and forward extremities, and + persists the given events and with those updates. + + Args: + events_and_contexts (list[(EventBase, EventContext)]): + backfilled (bool): + delete_existing (bool): + + Returns: + Deferred: resolves when the events have been persisted + """ + if not events_and_contexts: + return + + chunks = [ + events_and_contexts[x : x + 100] + for x in range(0, len(events_and_contexts), 100) + ] + + for chunk in chunks: + # We can't easily parallelize these since different chunks + # might contain the same event. :( + + # NB: Assumes that we are only persisting events for one room + # at a time. + + # map room_id->list[event_ids] giving the new forward + # extremities in each room + new_forward_extremeties = {} + + # map room_id->(type,state_key)->event_id tracking the full + # state in each room after adding these events. + # This is simply used to prefill the get_current_state_ids + # cache + current_state_for_room = {} + + # map room_id->(to_delete, to_insert) where to_delete is a list + # of type/state keys to remove from current state, and to_insert + # is a map (type,key)->event_id giving the state delta in each + # room + state_delta_for_room = {} + + if not backfilled: + with Measure(self._clock, "_calculate_state_and_extrem"): + # Work out the new "current state" for each room. + # We do this by working out what the new extremities are and then + # calculating the state from that. + events_by_room = {} + for event, context in chunk: + events_by_room.setdefault(event.room_id, []).append( + (event, context) + ) + + for room_id, ev_ctx_rm in iteritems(events_by_room): + latest_event_ids = yield self.main_store.get_latest_event_ids_in_room( + room_id + ) + new_latest_event_ids = yield self._calculate_new_extremities( + room_id, ev_ctx_rm, latest_event_ids + ) + + latest_event_ids = set(latest_event_ids) + if new_latest_event_ids == latest_event_ids: + # No change in extremities, so no change in state + continue + + # there should always be at least one forward extremity. + # (except during the initial persistence of the send_join + # results, in which case there will be no existing + # extremities, so we'll `continue` above and skip this bit.) + assert new_latest_event_ids, "No forward extremities left!" + + new_forward_extremeties[room_id] = new_latest_event_ids + + len_1 = ( + len(latest_event_ids) == 1 + and len(new_latest_event_ids) == 1 + ) + if len_1: + all_single_prev_not_state = all( + len(event.prev_event_ids()) == 1 + and not event.is_state() + for event, ctx in ev_ctx_rm + ) + # Don't bother calculating state if they're just + # a long chain of single ancestor non-state events. + if all_single_prev_not_state: + continue + + state_delta_counter.inc() + if len(new_latest_event_ids) == 1: + state_delta_single_event_counter.inc() + + # This is a fairly handwavey check to see if we could + # have guessed what the delta would have been when + # processing one of these events. + # What we're interested in is if the latest extremities + # were the same when we created the event as they are + # now. When this server creates a new event (as opposed + # to receiving it over federation) it will use the + # forward extremities as the prev_events, so we can + # guess this by looking at the prev_events and checking + # if they match the current forward extremities. + for ev, _ in ev_ctx_rm: + prev_event_ids = set(ev.prev_event_ids()) + if latest_event_ids == prev_event_ids: + state_delta_reuse_delta_counter.inc() + break + + logger.info("Calculating state delta for room %s", room_id) + with Measure( + self._clock, "persist_events.get_new_state_after_events" + ): + res = yield self._get_new_state_after_events( + room_id, + ev_ctx_rm, + latest_event_ids, + new_latest_event_ids, + ) + current_state, delta_ids = res + + # If either are not None then there has been a change, + # and we need to work out the delta (or use that + # given) + if delta_ids is not None: + # If there is a delta we know that we've + # only added or replaced state, never + # removed keys entirely. + state_delta_for_room[room_id] = ([], delta_ids) + elif current_state is not None: + with Measure( + self._clock, "persist_events.calculate_state_delta" + ): + delta = yield self._calculate_state_delta( + room_id, current_state + ) + state_delta_for_room[room_id] = delta + + # If we have the current_state then lets prefill + # the cache with it. + if current_state is not None: + current_state_for_room[room_id] = current_state + + yield self.main_store._persist_events_and_state_updates( + chunk, + current_state_for_room=current_state_for_room, + state_delta_for_room=state_delta_for_room, + new_forward_extremeties=new_forward_extremeties, + backfilled=backfilled, + ) + + @defer.inlineCallbacks + def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids): + """Calculates the new forward extremities for a room given events to + persist. + + Assumes that we are only persisting events for one room at a time. + """ + + # we're only interested in new events which aren't outliers and which aren't + # being rejected. + new_events = [ + event + for event, ctx in event_contexts + if not event.internal_metadata.is_outlier() + and not ctx.rejected + and not event.internal_metadata.is_soft_failed() + ] + + latest_event_ids = set(latest_event_ids) + + # start with the existing forward extremities + result = set(latest_event_ids) + + # add all the new events to the list + result.update(event.event_id for event in new_events) + + # Now remove all events which are prev_events of any of the new events + result.difference_update( + e_id for event in new_events for e_id in event.prev_event_ids() + ) + + # Remove any events which are prev_events of any existing events. + existing_prevs = yield self.main_store._get_events_which_are_prevs(result) + result.difference_update(existing_prevs) + + # Finally handle the case where the new events have soft-failed prev + # events. If they do we need to remove them and their prev events, + # otherwise we end up with dangling extremities. + existing_prevs = yield self.main_store._get_prevs_before_rejected( + e_id for event in new_events for e_id in event.prev_event_ids() + ) + result.difference_update(existing_prevs) + + # We only update metrics for events that change forward extremities + # (e.g. we ignore backfill/outliers/etc) + if result != latest_event_ids: + forward_extremities_counter.observe(len(result)) + stale = latest_event_ids & result + stale_forward_extremities_counter.observe(len(stale)) + + return result + + @defer.inlineCallbacks + def _get_new_state_after_events( + self, room_id, events_context, old_latest_event_ids, new_latest_event_ids + ): + """Calculate the current state dict after adding some new events to + a room + + Args: + room_id (str): + room to which the events are being added. Used for logging etc + + events_context (list[(EventBase, EventContext)]): + events and contexts which are being added to the room + + old_latest_event_ids (iterable[str]): + the old forward extremities for the room. + + new_latest_event_ids (iterable[str]): + the new forward extremities for the room. + + Returns: + Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]: + Returns a tuple of two state maps, the first being the full new current + state and the second being the delta to the existing current state. + If both are None then there has been no change. + + If there has been a change then we only return the delta if its + already been calculated. Conversely if we do know the delta then + the new current state is only returned if we've already calculated + it. + """ + # map from state_group to ((type, key) -> event_id) state map + state_groups_map = {} + + # Map from (prev state group, new state group) -> delta state dict + state_group_deltas = {} + + for ev, ctx in events_context: + if ctx.state_group is None: + # This should only happen for outlier events. + if not ev.internal_metadata.is_outlier(): + raise Exception( + "Context for new event %s has no state " + "group" % (ev.event_id,) + ) + continue + + if ctx.state_group in state_groups_map: + continue + + # We're only interested in pulling out state that has already + # been cached in the context. We'll pull stuff out of the DB later + # if necessary. + current_state_ids = ctx.get_cached_current_state_ids() + if current_state_ids is not None: + state_groups_map[ctx.state_group] = current_state_ids + + if ctx.prev_group: + state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids + + # We need to map the event_ids to their state groups. First, let's + # check if the event is one we're persisting, in which case we can + # pull the state group from its context. + # Otherwise we need to pull the state group from the database. + + # Set of events we need to fetch groups for. (We know none of the old + # extremities are going to be in events_context). + missing_event_ids = set(old_latest_event_ids) + + event_id_to_state_group = {} + for event_id in new_latest_event_ids: + # First search in the list of new events we're adding. + for ev, ctx in events_context: + if event_id == ev.event_id and ctx.state_group is not None: + event_id_to_state_group[event_id] = ctx.state_group + break + else: + # If we couldn't find it, then we'll need to pull + # the state from the database + missing_event_ids.add(event_id) + + if missing_event_ids: + # Now pull out the state groups for any missing events from DB + event_to_groups = yield self.main_store._get_state_group_for_events( + missing_event_ids + ) + event_id_to_state_group.update(event_to_groups) + + # State groups of old_latest_event_ids + old_state_groups = set( + event_id_to_state_group[evid] for evid in old_latest_event_ids + ) + + # State groups of new_latest_event_ids + new_state_groups = set( + event_id_to_state_group[evid] for evid in new_latest_event_ids + ) + + # If they old and new groups are the same then we don't need to do + # anything. + if old_state_groups == new_state_groups: + return None, None + + if len(new_state_groups) == 1 and len(old_state_groups) == 1: + # If we're going from one state group to another, lets check if + # we have a delta for that transition. If we do then we can just + # return that. + + new_state_group = next(iter(new_state_groups)) + old_state_group = next(iter(old_state_groups)) + + delta_ids = state_group_deltas.get((old_state_group, new_state_group), None) + if delta_ids is not None: + # We have a delta from the existing to new current state, + # so lets just return that. If we happen to already have + # the current state in memory then lets also return that, + # but it doesn't matter if we don't. + new_state = state_groups_map.get(new_state_group) + return new_state, delta_ids + + # Now that we have calculated new_state_groups we need to get + # their state IDs so we can resolve to a single state set. + missing_state = new_state_groups - set(state_groups_map) + if missing_state: + group_to_state = yield self.state_store._get_state_for_groups(missing_state) + state_groups_map.update(group_to_state) + + if len(new_state_groups) == 1: + # If there is only one state group, then we know what the current + # state is. + return state_groups_map[new_state_groups.pop()], None + + # Ok, we need to defer to the state handler to resolve our state sets. + + state_groups = {sg: state_groups_map[sg] for sg in new_state_groups} + + events_map = {ev.event_id: ev for ev, _ in events_context} + + # We need to get the room version, which is in the create event. + # Normally that'd be in the database, but its also possible that we're + # currently trying to persist it. + room_version = None + for ev, _ in events_context: + if ev.type == EventTypes.Create and ev.state_key == "": + room_version = ev.content.get("room_version", "1") + break + + if not room_version: + room_version = yield self.main_store.get_room_version(room_id) + + logger.debug("calling resolve_state_groups from preserve_events") + res = yield self._state_resolution_handler.resolve_state_groups( + room_id, + room_version, + state_groups, + events_map, + state_res_store=StateResolutionStore(self.main_store), + ) + + return res.state, None + + @defer.inlineCallbacks + def _calculate_state_delta(self, room_id, current_state): + """Calculate the new state deltas for a room. + + Assumes that we are only persisting events for one room at a time. + + Returns: + tuple[list, dict] (to_delete, to_insert): where to_delete are the + type/state_keys to remove from current_state_events and `to_insert` + are the updates to current_state_events. + """ + existing_state = yield self.main_store.get_current_state_ids(room_id) + + to_delete = [key for key in existing_state if key not in current_state] + + to_insert = { + key: ev_id + for key, ev_id in iteritems(current_state) + if ev_id != existing_state.get(key) + } + + return to_delete, to_insert diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index e96eed8a6d..2e7753820e 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -14,12 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import fnmatch import imp import logging import os import re +import attr + from synapse.storage.engines.postgres import PostgresEngine logger = logging.getLogger(__name__) @@ -54,6 +55,10 @@ def prepare_database(db_conn, database_engine, config): application config, or None if we are connecting to an existing database which we expect to be configured already """ + + # For now we only have the one datastore. + data_stores = ["main"] + try: cur = db_conn.cursor() version_info = _get_or_create_schema_state(cur, database_engine) @@ -68,10 +73,16 @@ def prepare_database(db_conn, database_engine, config): raise UpgradeDatabaseException("Database needs to be upgraded") else: _upgrade_existing_database( - cur, user_version, delta_files, upgraded, database_engine, config + cur, + user_version, + delta_files, + upgraded, + database_engine, + config, + data_stores=data_stores, ) else: - _setup_new_database(cur, database_engine) + _setup_new_database(cur, database_engine, data_stores=data_stores) # check if any of our configured dynamic modules want a database if config is not None: @@ -84,9 +95,10 @@ def prepare_database(db_conn, database_engine, config): raise -def _setup_new_database(cur, database_engine): +def _setup_new_database(cur, database_engine, data_stores): """Sets up the database by finding a base set of "full schemas" and then - applying any necessary deltas. + applying any necessary deltas, including schemas from the given data + stores. The "full_schemas" directory has subdirectories named after versions. This function searches for the highest version less than or equal to @@ -111,52 +123,78 @@ def _setup_new_database(cur, database_engine): In the example foo.sql and bar.sql would be run, and then any delta files for versions strictly greater than 11. + + Note: we apply the full schemas and deltas from the top level `schema/` + folder as well those in the data stores specified. + + Args: + cur (Cursor): a database cursor + database_engine (DatabaseEngine) + data_stores (list[str]): The names of the data stores to instantiate + on the given database. """ current_dir = os.path.join(dir_path, "schema", "full_schemas") directory_entries = os.listdir(current_dir) - valid_dirs = [] - pattern = re.compile(r"^\d+(\.sql)?$") - - if isinstance(database_engine, PostgresEngine): - specific = "postgres" - else: - specific = "sqlite" - - specific_pattern = re.compile(r"^\d+(\.sql." + specific + r")?$") + # First we find the highest full schema version we have + valid_versions = [] for filename in directory_entries: - match = pattern.match(filename) or specific_pattern.match(filename) - abs_path = os.path.join(current_dir, filename) - if match and os.path.isdir(abs_path): - ver = int(match.group(0)) - if ver <= SCHEMA_VERSION: - valid_dirs.append((ver, abs_path)) - else: - logger.debug("Ignoring entry '%s' in 'full_schemas'", filename) + try: + ver = int(filename) + except ValueError: + continue + + if ver <= SCHEMA_VERSION: + valid_versions.append(ver) - if not valid_dirs: + if not valid_versions: raise PrepareDatabaseException( "Could not find a suitable base set of full schemas" ) - max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0]) + max_current_ver = max(valid_versions) logger.debug("Initialising schema v%d", max_current_ver) - directory_entries = os.listdir(sql_dir) + # Now lets find all the full schema files, both in the global schema and + # in data store schemas. + directories = [os.path.join(current_dir, str(max_current_ver))] + directories.extend( + os.path.join( + dir_path, + "data_stores", + data_store, + "schema", + "full_schemas", + str(max_current_ver), + ) + for data_store in data_stores + ) + + directory_entries = [] + for directory in directories: + directory_entries.extend( + _DirectoryListing(file_name, os.path.join(directory, file_name)) + for file_name in os.listdir(directory) + ) + + if isinstance(database_engine, PostgresEngine): + specific = "postgres" + else: + specific = "sqlite" - for filename in sorted( - fnmatch.filter(directory_entries, "*.sql") - + fnmatch.filter(directory_entries, "*.sql." + specific) - ): - sql_loc = os.path.join(sql_dir, filename) - logger.debug("Applying schema %s", sql_loc) - executescript(cur, sql_loc) + directory_entries.sort() + for entry in directory_entries: + if entry.file_name.endswith(".sql") or entry.file_name.endswith( + ".sql." + specific + ): + logger.debug("Applying schema %s", entry.absolute_path) + executescript(cur, entry.absolute_path) cur.execute( database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded)" " VALUES (?,?)" + "INSERT INTO schema_version (version, upgraded) VALUES (?,?)" ), (max_current_ver, False), ) @@ -168,6 +206,7 @@ def _setup_new_database(cur, database_engine): upgraded=False, database_engine=database_engine, config=None, + data_stores=data_stores, is_empty=True, ) @@ -179,6 +218,7 @@ def _upgrade_existing_database( upgraded, database_engine, config, + data_stores, is_empty=False, ): """Upgrades an existing database. @@ -215,6 +255,10 @@ def _upgrade_existing_database( only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in some arbitrary order. + Note: we apply the delta files from the specified data stores as well as + those in the top-level schema. We apply all delta files across data stores + for a version before applying those in the next version. + Args: cur (Cursor) current_version (int): The current version of the schema. @@ -224,6 +268,14 @@ def _upgrade_existing_database( applied deltas or from full schema file. If `True` the function will never apply delta files for the given `current_version`, since the current_version wasn't generated by applying those delta files. + database_engine (DatabaseEngine) + config (synapse.config.homeserver.HomeServerConfig|None): + application config, or None if we are connecting to an existing + database which we expect to be configured already + data_stores (list[str]): The names of the data stores to instantiate + on the given database. + is_empty (bool): Is this a blank database? I.e. do we need to run the + upgrade portions of the delta scripts. """ if current_version > SCHEMA_VERSION: @@ -248,24 +300,49 @@ def _upgrade_existing_database( for v in range(start_ver, SCHEMA_VERSION + 1): logger.info("Upgrading schema to v%d", v) - delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) + # We need to search both the global and per data store schema + # directories for schema updates. - try: - directory_entries = os.listdir(delta_dir) - except OSError: - logger.exception("Could not open delta dir for version %d", v) - raise UpgradeDatabaseException( - "Could not open delta dir for version %d" % (v,) + # First we find the directories to search in + delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) + directories = [delta_dir] + for data_store in data_stores: + directories.append( + os.path.join( + dir_path, "data_stores", data_store, "schema", "delta", str(v) + ) ) + # Now find which directories have anything of interest. + directory_entries = [] + for directory in directories: + logger.debug("Looking for schema deltas in %s", directory) + try: + file_names = os.listdir(directory) + directory_entries.extend( + _DirectoryListing(file_name, os.path.join(directory, file_name)) + for file_name in file_names + ) + except FileNotFoundError: + # Data stores can have empty entries for a given version delta. + pass + except OSError: + raise UpgradeDatabaseException( + "Could not open delta dir for version %d: %s" % (v, directory) + ) + + # We sort to ensure that we apply the delta files in a consistent + # order (to avoid bugs caused by inconsistent directory listing order) directory_entries.sort() - for file_name in directory_entries: + for entry in directory_entries: + file_name = entry.file_name relative_path = os.path.join(str(v), file_name) - logger.debug("Found file: %s", relative_path) + absolute_path = entry.absolute_path + + logger.debug("Found file: %s (%s)", relative_path, absolute_path) if relative_path in applied_delta_files: continue - absolute_path = os.path.join(dir_path, "schema", "delta", relative_path) root_name, ext = os.path.splitext(file_name) if ext == ".py": # This is a python upgrade module. We need to import into some @@ -448,3 +525,16 @@ def _get_or_create_schema_state(txn, database_engine): return current_version, applied_deltas, upgraded return None + + +@attr.s() +class _DirectoryListing(object): + """Helper class to store schema file name and the + absolute path to it. + + These entries get sorted, so for consistency we want to ensure that + `file_name` attr is kept first. + """ + + file_name = attr.ib() + absolute_path = attr.ib() diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 3a641f538b..18a462f0ee 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -15,12 +15,7 @@ from collections import namedtuple -from twisted.internet import defer - from synapse.api.constants import PresenceState -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.util import batch_iter -from synapse.util.caches.descriptors import cached, cachedList class UserPresenceState( @@ -72,132 +67,3 @@ class UserPresenceState( status_msg=None, currently_active=False, ) - - -class PresenceStore(SQLBaseStore): - @defer.inlineCallbacks - def update_presence(self, presence_states): - stream_ordering_manager = self._presence_id_gen.get_next_mult( - len(presence_states) - ) - - with stream_ordering_manager as stream_orderings: - yield self.runInteraction( - "update_presence", - self._update_presence_txn, - stream_orderings, - presence_states, - ) - - return stream_orderings[-1], self._presence_id_gen.get_current_token() - - def _update_presence_txn(self, txn, stream_orderings, presence_states): - for stream_id, state in zip(stream_orderings, presence_states): - txn.call_after( - self.presence_stream_cache.entity_has_changed, state.user_id, stream_id - ) - txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,)) - - # Actually insert new rows - self._simple_insert_many_txn( - txn, - table="presence_stream", - values=[ - { - "stream_id": stream_id, - "user_id": state.user_id, - "state": state.state, - "last_active_ts": state.last_active_ts, - "last_federation_update_ts": state.last_federation_update_ts, - "last_user_sync_ts": state.last_user_sync_ts, - "status_msg": state.status_msg, - "currently_active": state.currently_active, - } - for state in presence_states - ], - ) - - # Delete old rows to stop database from getting really big - sql = "DELETE FROM presence_stream WHERE stream_id < ? AND " - - for states in batch_iter(presence_states, 50): - clause, args = make_in_list_sql_clause( - self.database_engine, "user_id", [s.user_id for s in states] - ) - txn.execute(sql + clause, [stream_id] + list(args)) - - def get_all_presence_updates(self, last_id, current_id): - if last_id == current_id: - return defer.succeed([]) - - def get_all_presence_updates_txn(txn): - sql = ( - "SELECT stream_id, user_id, state, last_active_ts," - " last_federation_update_ts, last_user_sync_ts, status_msg," - " currently_active" - " FROM presence_stream" - " WHERE ? < stream_id AND stream_id <= ?" - ) - txn.execute(sql, (last_id, current_id)) - return txn.fetchall() - - return self.runInteraction( - "get_all_presence_updates", get_all_presence_updates_txn - ) - - @cached() - def _get_presence_for_user(self, user_id): - raise NotImplementedError() - - @cachedList( - cached_method_name="_get_presence_for_user", - list_name="user_ids", - num_args=1, - inlineCallbacks=True, - ) - def get_presence_for_users(self, user_ids): - rows = yield self._simple_select_many_batch( - table="presence_stream", - column="user_id", - iterable=user_ids, - keyvalues={}, - retcols=( - "user_id", - "state", - "last_active_ts", - "last_federation_update_ts", - "last_user_sync_ts", - "status_msg", - "currently_active", - ), - desc="get_presence_for_users", - ) - - for row in rows: - row["currently_active"] = bool(row["currently_active"]) - - return {row["user_id"]: UserPresenceState(**row) for row in rows} - - def get_current_presence_token(self): - return self._presence_id_gen.get_current_token() - - def allow_presence_visible(self, observed_localpart, observer_userid): - return self._simple_insert( - table="presence_allow_inbound", - values={ - "observed_user_id": observed_localpart, - "observer_user_id": observer_userid, - }, - desc="allow_presence_visible", - or_ignore=True, - ) - - def disallow_presence_visible(self, observed_localpart, observer_userid): - return self._simple_delete_one( - table="presence_allow_inbound", - keyvalues={ - "observed_user_id": observed_localpart, - "observer_user_id": observer_userid, - }, - desc="disallow_presence_visible", - ) diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py new file mode 100644 index 0000000000..a368182034 --- /dev/null +++ b/synapse/storage/purge_events.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import logging + +from twisted.internet import defer + +logger = logging.getLogger(__name__) + + +class PurgeEventsStorage(object): + """High level interface for purging rooms and event history. + """ + + def __init__(self, hs, stores): + self.stores = stores + + @defer.inlineCallbacks + def purge_room(self, room_id: str): + """Deletes all record of a room + """ + + state_groups_to_delete = yield self.stores.main.purge_room(room_id) + yield self.stores.main.purge_room_state(room_id, state_groups_to_delete) + + @defer.inlineCallbacks + def purge_history(self, room_id, token, delete_local_events): + """Deletes room history before a certain point + + Args: + room_id (str): + + token (str): A topological token to delete events before + + delete_local_events (bool): + if True, we will delete local events as well as remote ones + (instead of just marking them as outliers and deleting their + state groups). + """ + state_groups = yield self.stores.main.purge_history( + room_id, token, delete_local_events + ) + + logger.info("[purge] finding state groups that can be deleted") + + sg_to_delete = yield self._find_unreferenced_groups(state_groups) + + yield self.stores.main.purge_unreferenced_state_groups(room_id, sg_to_delete) + + @defer.inlineCallbacks + def _find_unreferenced_groups(self, state_groups): + """Used when purging history to figure out which state groups can be + deleted. + + Args: + state_groups (set[int]): Set of state groups referenced by events + that are going to be deleted. + + Returns: + Deferred[set[int]] The set of state groups that can be deleted. + """ + # Graph of state group -> previous group + graph = {} + + # Set of events that we have found to be referenced by events + referenced_groups = set() + + # Set of state groups we've already seen + state_groups_seen = set(state_groups) + + # Set of state groups to handle next. + next_to_search = set(state_groups) + while next_to_search: + # We bound size of groups we're looking up at once, to stop the + # SQL query getting too big + if len(next_to_search) < 100: + current_search = next_to_search + next_to_search = set() + else: + current_search = set(itertools.islice(next_to_search, 100)) + next_to_search -= current_search + + referenced = yield self.stores.main.get_referenced_state_groups( + current_search + ) + referenced_groups |= referenced + + # We don't continue iterating up the state group graphs for state + # groups that are referenced. + current_search -= referenced + + edges = yield self.stores.main.get_previous_state_groups(current_search) + + prevs = set(edges.values()) + # We don't bother re-handling groups we've already seen + prevs -= state_groups_seen + next_to_search |= prevs + state_groups_seen |= prevs + + graph.update(edges) + + to_delete = state_groups_seen - referenced_groups + + return to_delete diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index c4e24edff2..f47cec0d86 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -14,704 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc -import logging - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.push.baserules import list_with_base_rules -from synapse.storage.appservice import ApplicationServiceWorkerStore -from synapse.storage.pusher import PusherWorkerStore -from synapse.storage.receipts import ReceiptsWorkerStore -from synapse.storage.roommember import RoomMemberWorkerStore -from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList -from synapse.util.caches.stream_change_cache import StreamChangeCache - -from ._base import SQLBaseStore - -logger = logging.getLogger(__name__) - - -def _load_rules(rawrules, enabled_map): - ruleslist = [] - for rawrule in rawrules: - rule = dict(rawrule) - rule["conditions"] = json.loads(rawrule["conditions"]) - rule["actions"] = json.loads(rawrule["actions"]) - ruleslist.append(rule) - - # We're going to be mutating this a lot, so do a deep copy - rules = list(list_with_base_rules(ruleslist)) - - for i, rule in enumerate(rules): - rule_id = rule["rule_id"] - if rule_id in enabled_map: - if rule.get("enabled", True) != bool(enabled_map[rule_id]): - # Rules are cached across users. - rule = dict(rule) - rule["enabled"] = bool(enabled_map[rule_id]) - rules[i] = rule - - return rules - - -class PushRulesWorkerStore( - ApplicationServiceWorkerStore, - ReceiptsWorkerStore, - PusherWorkerStore, - RoomMemberWorkerStore, - SQLBaseStore, -): - """This is an abstract base class where subclasses must implement - `get_max_push_rules_stream_id` which can be called in the initializer. - """ - - # This ABCMeta metaclass ensures that we cannot be instantiated without - # the abstract methods being implemented. - __metaclass__ = abc.ABCMeta - - def __init__(self, db_conn, hs): - super(PushRulesWorkerStore, self).__init__(db_conn, hs) - - push_rules_prefill, push_rules_id = self._get_cache_dict( - db_conn, - "push_rules_stream", - entity_column="user_id", - stream_column="stream_id", - max_value=self.get_max_push_rules_stream_id(), - ) - - self.push_rules_stream_cache = StreamChangeCache( - "PushRulesStreamChangeCache", - push_rules_id, - prefilled_cache=push_rules_prefill, - ) - - @abc.abstractmethod - def get_max_push_rules_stream_id(self): - """Get the position of the push rules stream. - - Returns: - int - """ - raise NotImplementedError() - - @cachedInlineCallbacks(max_entries=5000) - def get_push_rules_for_user(self, user_id): - rows = yield self._simple_select_list( - table="push_rules", - keyvalues={"user_name": user_id}, - retcols=( - "user_name", - "rule_id", - "priority_class", - "priority", - "conditions", - "actions", - ), - desc="get_push_rules_enabled_for_user", - ) - - rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) - - enabled_map = yield self.get_push_rules_enabled_for_user(user_id) - - rules = _load_rules(rows, enabled_map) - - return rules - - @cachedInlineCallbacks(max_entries=5000) - def get_push_rules_enabled_for_user(self, user_id): - results = yield self._simple_select_list( - table="push_rules_enable", - keyvalues={"user_name": user_id}, - retcols=("user_name", "rule_id", "enabled"), - desc="get_push_rules_enabled_for_user", - ) - return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} - - def have_push_rules_changed_for_user(self, user_id, last_id): - if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): - return defer.succeed(False) - else: - - def have_push_rules_changed_txn(txn): - sql = ( - "SELECT COUNT(stream_id) FROM push_rules_stream" - " WHERE user_id = ? AND ? < stream_id" - ) - txn.execute(sql, (user_id, last_id)) - count, = txn.fetchone() - return bool(count) - - return self.runInteraction( - "have_push_rules_changed", have_push_rules_changed_txn - ) - - @cachedList( - cached_method_name="get_push_rules_for_user", - list_name="user_ids", - num_args=1, - inlineCallbacks=True, - ) - def bulk_get_push_rules(self, user_ids): - if not user_ids: - return {} - - results = {user_id: [] for user_id in user_ids} - - rows = yield self._simple_select_many_batch( - table="push_rules", - column="user_name", - iterable=user_ids, - retcols=("*",), - desc="bulk_get_push_rules", - ) - - rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) - - for row in rows: - results.setdefault(row["user_name"], []).append(row) - - enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) - - for user_id, rules in results.items(): - results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) - - return results - - @defer.inlineCallbacks - def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule): - """Copy a single push rule from one room to another for a specific user. - - Args: - new_room_id (str): ID of the new room. - user_id (str): ID of user the push rule belongs to. - rule (Dict): A push rule. - """ - # Create new rule id - rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1]) - new_rule_id = rule_id_scope + "/" + new_room_id - - # Change room id in each condition - for condition in rule.get("conditions", []): - if condition.get("key") == "room_id": - condition["pattern"] = new_room_id - - # Add the rule for the new room - yield self.add_push_rule( - user_id=user_id, - rule_id=new_rule_id, - priority_class=rule["priority_class"], - conditions=rule["conditions"], - actions=rule["actions"], - ) - - @defer.inlineCallbacks - def copy_push_rules_from_room_to_room_for_user( - self, old_room_id, new_room_id, user_id - ): - """Copy all of the push rules from one room to another for a specific - user. - - Args: - old_room_id (str): ID of the old room. - new_room_id (str): ID of the new room. - user_id (str): ID of user to copy push rules for. - """ - # Retrieve push rules for this user - user_push_rules = yield self.get_push_rules_for_user(user_id) - - # Get rules relating to the old room and copy them to the new room - for rule in user_push_rules: - conditions = rule.get("conditions", []) - if any( - (c.get("key") == "room_id" and c.get("pattern") == old_room_id) - for c in conditions - ): - yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule) - - @defer.inlineCallbacks - def bulk_get_push_rules_for_room(self, event, context): - state_group = context.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - current_state_ids = yield context.get_current_state_ids(self) - result = yield self._bulk_get_push_rules_for_room( - event.room_id, state_group, current_state_ids, event=event - ) - return result - - @cachedInlineCallbacks(num_args=2, cache_context=True) - def _bulk_get_push_rules_for_room( - self, room_id, state_group, current_state_ids, cache_context, event=None - ): - # We don't use `state_group`, its there so that we can cache based - # on it. However, its important that its never None, since two current_state's - # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. - assert state_group is not None - - # We also will want to generate notifs for other people in the room so - # their unread countss are correct in the event stream, but to avoid - # generating them for bot / AS users etc, we only do so for people who've - # sent a read receipt into the room. - - users_in_room = yield self._get_joined_users_from_context( - room_id, - state_group, - current_state_ids, - on_invalidate=cache_context.invalidate, - event=event, - ) - - # We ignore app service users for now. This is so that we don't fill - # up the `get_if_users_have_pushers` cache with AS entries that we - # know don't have pushers, nor even read receipts. - local_users_in_room = set( - u - for u in users_in_room - if self.hs.is_mine_id(u) - and not self.get_if_app_services_interested_in_user(u) - ) - - # users in the room who have pushers need to get push rules run because - # that's how their pushers work - if_users_with_pushers = yield self.get_if_users_have_pushers( - local_users_in_room, on_invalidate=cache_context.invalidate - ) - user_ids = set( - uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher - ) - - users_with_receipts = yield self.get_users_with_read_receipts_in_room( - room_id, on_invalidate=cache_context.invalidate - ) - - # any users with pushers must be ours: they have pushers - for uid in users_with_receipts: - if uid in local_users_in_room: - user_ids.add(uid) - - rules_by_user = yield self.bulk_get_push_rules( - user_ids, on_invalidate=cache_context.invalidate - ) - - rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} - - return rules_by_user - - @cachedList( - cached_method_name="get_push_rules_enabled_for_user", - list_name="user_ids", - num_args=1, - inlineCallbacks=True, - ) - def bulk_get_push_rules_enabled(self, user_ids): - if not user_ids: - return {} - - results = {user_id: {} for user_id in user_ids} - - rows = yield self._simple_select_many_batch( - table="push_rules_enable", - column="user_name", - iterable=user_ids, - retcols=("user_name", "rule_id", "enabled"), - desc="bulk_get_push_rules_enabled", - ) - for row in rows: - enabled = bool(row["enabled"]) - results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled - return results - - -class PushRuleStore(PushRulesWorkerStore): - @defer.inlineCallbacks - def add_push_rule( - self, - user_id, - rule_id, - priority_class, - conditions, - actions, - before=None, - after=None, - ): - conditions_json = json.dumps(conditions) - actions_json = json.dumps(actions) - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids - if before or after: - yield self.runInteraction( - "_add_push_rule_relative_txn", - self._add_push_rule_relative_txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - conditions_json, - actions_json, - before, - after, - ) - else: - yield self.runInteraction( - "_add_push_rule_highest_priority_txn", - self._add_push_rule_highest_priority_txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - conditions_json, - actions_json, - ) - - def _add_push_rule_relative_txn( - self, - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - conditions_json, - actions_json, - before, - after, - ): - # Lock the table since otherwise we'll have annoying races between the - # SELECT here and the UPSERT below. - self.database_engine.lock_table(txn, "push_rules") - - relative_to_rule = before or after - - res = self._simple_select_one_txn( - txn, - table="push_rules", - keyvalues={"user_name": user_id, "rule_id": relative_to_rule}, - retcols=["priority_class", "priority"], - allow_none=True, - ) - - if not res: - raise RuleNotFoundException( - "before/after rule not found: %s" % (relative_to_rule,) - ) - - base_priority_class = res["priority_class"] - base_rule_priority = res["priority"] - - if base_priority_class != priority_class: - raise InconsistentRuleException( - "Given priority class does not match class of relative rule" - ) - - if before: - # Higher priority rules are executed first, So adding a rule before - # a rule means giving it a higher priority than that rule. - new_rule_priority = base_rule_priority + 1 - else: - # We increment the priority of the existing rules to make space for - # the new rule. Therefore if we want this rule to appear after - # an existing rule we give it the priority of the existing rule, - # and then increment the priority of the existing rule. - new_rule_priority = base_rule_priority - - sql = ( - "UPDATE push_rules SET priority = priority + 1" - " WHERE user_name = ? AND priority_class = ? AND priority >= ?" - ) - - txn.execute(sql, (user_id, priority_class, new_rule_priority)) - - self._upsert_push_rule_txn( - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - new_rule_priority, - conditions_json, - actions_json, - ) - - def _add_push_rule_highest_priority_txn( - self, - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - conditions_json, - actions_json, - ): - # Lock the table since otherwise we'll have annoying races between the - # SELECT here and the UPSERT below. - self.database_engine.lock_table(txn, "push_rules") - - # find the highest priority rule in that class - sql = ( - "SELECT COUNT(*), MAX(priority) FROM push_rules" - " WHERE user_name = ? and priority_class = ?" - ) - txn.execute(sql, (user_id, priority_class)) - res = txn.fetchall() - (how_many, highest_prio) = res[0] - - new_prio = 0 - if how_many > 0: - new_prio = highest_prio + 1 - - self._upsert_push_rule_txn( - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - new_prio, - conditions_json, - actions_json, - ) - - def _upsert_push_rule_txn( - self, - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - priority, - conditions_json, - actions_json, - update_stream=True, - ): - """Specialised version of _simple_upsert_txn that picks a push_rule_id - using the _push_rule_id_gen if it needs to insert the rule. It assumes - that the "push_rules" table is locked""" - - sql = ( - "UPDATE push_rules" - " SET priority_class = ?, priority = ?, conditions = ?, actions = ?" - " WHERE user_name = ? AND rule_id = ?" - ) - - txn.execute( - sql, - (priority_class, priority, conditions_json, actions_json, user_id, rule_id), - ) - - if txn.rowcount == 0: - # We didn't update a row with the given rule_id so insert one - push_rule_id = self._push_rule_id_gen.get_next() - - self._simple_insert_txn( - txn, - table="push_rules", - values={ - "id": push_rule_id, - "user_name": user_id, - "rule_id": rule_id, - "priority_class": priority_class, - "priority": priority, - "conditions": conditions_json, - "actions": actions_json, - }, - ) - - if update_stream: - self._insert_push_rules_update_txn( - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - op="ADD", - data={ - "priority_class": priority_class, - "priority": priority, - "conditions": conditions_json, - "actions": actions_json, - }, - ) - - @defer.inlineCallbacks - def delete_push_rule(self, user_id, rule_id): - """ - Delete a push rule. Args specify the row to be deleted and can be - any of the columns in the push_rule table, but below are the - standard ones - - Args: - user_id (str): The matrix ID of the push rule owner - rule_id (str): The rule_id of the rule to be deleted - """ - - def delete_push_rule_txn(txn, stream_id, event_stream_ordering): - self._simple_delete_one_txn( - txn, "push_rules", {"user_name": user_id, "rule_id": rule_id} - ) - - self._insert_push_rules_update_txn( - txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" - ) - - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids - yield self.runInteraction( - "delete_push_rule", - delete_push_rule_txn, - stream_id, - event_stream_ordering, - ) - - @defer.inlineCallbacks - def set_push_rule_enabled(self, user_id, rule_id, enabled): - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids - yield self.runInteraction( - "_set_push_rule_enabled_txn", - self._set_push_rule_enabled_txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - enabled, - ) - - def _set_push_rule_enabled_txn( - self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled - ): - new_id = self._push_rules_enable_id_gen.get_next() - self._simple_upsert_txn( - txn, - "push_rules_enable", - {"user_name": user_id, "rule_id": rule_id}, - {"enabled": 1 if enabled else 0}, - {"id": new_id}, - ) - - self._insert_push_rules_update_txn( - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - op="ENABLE" if enabled else "DISABLE", - ) - - @defer.inlineCallbacks - def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule): - actions_json = json.dumps(actions) - - def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): - if is_default_rule: - # Add a dummy rule to the rules table with the user specified - # actions. - priority_class = -1 - priority = 1 - self._upsert_push_rule_txn( - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - priority, - "[]", - actions_json, - update_stream=False, - ) - else: - self._simple_update_one_txn( - txn, - "push_rules", - {"user_name": user_id, "rule_id": rule_id}, - {"actions": actions_json}, - ) - - self._insert_push_rules_update_txn( - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - op="ACTIONS", - data={"actions": actions_json}, - ) - - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids - yield self.runInteraction( - "set_push_rule_actions", - set_push_rule_actions_txn, - stream_id, - event_stream_ordering, - ) - - def _insert_push_rules_update_txn( - self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None - ): - values = { - "stream_id": stream_id, - "event_stream_ordering": event_stream_ordering, - "user_id": user_id, - "rule_id": rule_id, - "op": op, - } - if data is not None: - values.update(data) - - self._simple_insert_txn(txn, "push_rules_stream", values=values) - - txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,)) - txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,)) - txn.call_after( - self.push_rules_stream_cache.entity_has_changed, user_id, stream_id - ) - - def get_all_push_rule_updates(self, last_id, current_id, limit): - """Get all the push rules changes that have happend on the server""" - if last_id == current_id: - return defer.succeed([]) - - def get_all_push_rule_updates_txn(txn): - sql = ( - "SELECT stream_id, event_stream_ordering, user_id, rule_id," - " op, priority_class, priority, conditions, actions" - " FROM push_rules_stream" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() - - return self.runInteraction( - "get_all_push_rule_updates", get_all_push_rule_updates_txn - ) - - def get_push_rules_stream_token(self): - """Get the position of the push rules stream. - Returns a pair of a stream id for the push_rules stream and the - room stream ordering it corresponds to.""" - return self._push_rules_stream_id_gen.get_current_token() - - def get_max_push_rules_stream_id(self): - return self.get_push_rules_stream_token()[0] - class RuleNotFoundException(Exception): pass diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index fcb5f2f23a..d471ec9860 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -17,11 +17,7 @@ import logging import attr -from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError -from synapse.storage._base import SQLBaseStore -from synapse.storage.stream import generate_pagination_where_clause -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks logger = logging.getLogger(__name__) @@ -113,358 +109,3 @@ class AggregationPaginationToken(object): def as_tuple(self): return attr.astuple(self) - - -class RelationsWorkerStore(SQLBaseStore): - @cached(tree=True) - def get_relations_for_event( - self, - event_id, - relation_type=None, - event_type=None, - aggregation_key=None, - limit=5, - direction="b", - from_token=None, - to_token=None, - ): - """Get a list of relations for an event, ordered by topological ordering. - - Args: - event_id (str): Fetch events that relate to this event ID. - relation_type (str|None): Only fetch events with this relation - type, if given. - event_type (str|None): Only fetch events with this event type, if - given. - aggregation_key (str|None): Only fetch events with this aggregation - key, if given. - limit (int): Only fetch the most recent `limit` events. - direction (str): Whether to fetch the most recent first (`"b"`) or - the oldest first (`"f"`). - from_token (RelationPaginationToken|None): Fetch rows from the given - token, or from the start if None. - to_token (RelationPaginationToken|None): Fetch rows up to the given - token, or up to the end if None. - - Returns: - Deferred[PaginationChunk]: List of event IDs that match relations - requested. The rows are of the form `{"event_id": "..."}`. - """ - - where_clause = ["relates_to_id = ?"] - where_args = [event_id] - - if relation_type is not None: - where_clause.append("relation_type = ?") - where_args.append(relation_type) - - if event_type is not None: - where_clause.append("type = ?") - where_args.append(event_type) - - if aggregation_key: - where_clause.append("aggregation_key = ?") - where_args.append(aggregation_key) - - pagination_clause = generate_pagination_where_clause( - direction=direction, - column_names=("topological_ordering", "stream_ordering"), - from_token=attr.astuple(from_token) if from_token else None, - to_token=attr.astuple(to_token) if to_token else None, - engine=self.database_engine, - ) - - if pagination_clause: - where_clause.append(pagination_clause) - - if direction == "b": - order = "DESC" - else: - order = "ASC" - - sql = """ - SELECT event_id, topological_ordering, stream_ordering - FROM event_relations - INNER JOIN events USING (event_id) - WHERE %s - ORDER BY topological_ordering %s, stream_ordering %s - LIMIT ? - """ % ( - " AND ".join(where_clause), - order, - order, - ) - - def _get_recent_references_for_event_txn(txn): - txn.execute(sql, where_args + [limit + 1]) - - last_topo_id = None - last_stream_id = None - events = [] - for row in txn: - events.append({"event_id": row[0]}) - last_topo_id = row[1] - last_stream_id = row[2] - - next_batch = None - if len(events) > limit and last_topo_id and last_stream_id: - next_batch = RelationPaginationToken(last_topo_id, last_stream_id) - - return PaginationChunk( - chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token - ) - - return self.runInteraction( - "get_recent_references_for_event", _get_recent_references_for_event_txn - ) - - @cached(tree=True) - def get_aggregation_groups_for_event( - self, - event_id, - event_type=None, - limit=5, - direction="b", - from_token=None, - to_token=None, - ): - """Get a list of annotations on the event, grouped by event type and - aggregation key, sorted by count. - - This is used e.g. to get the what and how many reactions have happend - on an event. - - Args: - event_id (str): Fetch events that relate to this event ID. - event_type (str|None): Only fetch events with this event type, if - given. - limit (int): Only fetch the `limit` groups. - direction (str): Whether to fetch the highest count first (`"b"`) or - the lowest count first (`"f"`). - from_token (AggregationPaginationToken|None): Fetch rows from the - given token, or from the start if None. - to_token (AggregationPaginationToken|None): Fetch rows up to the - given token, or up to the end if None. - - - Returns: - Deferred[PaginationChunk]: List of groups of annotations that - match. Each row is a dict with `type`, `key` and `count` fields. - """ - - where_clause = ["relates_to_id = ?", "relation_type = ?"] - where_args = [event_id, RelationTypes.ANNOTATION] - - if event_type: - where_clause.append("type = ?") - where_args.append(event_type) - - having_clause = generate_pagination_where_clause( - direction=direction, - column_names=("COUNT(*)", "MAX(stream_ordering)"), - from_token=attr.astuple(from_token) if from_token else None, - to_token=attr.astuple(to_token) if to_token else None, - engine=self.database_engine, - ) - - if direction == "b": - order = "DESC" - else: - order = "ASC" - - if having_clause: - having_clause = "HAVING " + having_clause - else: - having_clause = "" - - sql = """ - SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering) - FROM event_relations - INNER JOIN events USING (event_id) - WHERE {where_clause} - GROUP BY relation_type, type, aggregation_key - {having_clause} - ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order} - LIMIT ? - """.format( - where_clause=" AND ".join(where_clause), - order=order, - having_clause=having_clause, - ) - - def _get_aggregation_groups_for_event_txn(txn): - txn.execute(sql, where_args + [limit + 1]) - - next_batch = None - events = [] - for row in txn: - events.append({"type": row[0], "key": row[1], "count": row[2]}) - next_batch = AggregationPaginationToken(row[2], row[3]) - - if len(events) <= limit: - next_batch = None - - return PaginationChunk( - chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token - ) - - return self.runInteraction( - "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn - ) - - @cachedInlineCallbacks() - def get_applicable_edit(self, event_id): - """Get the most recent edit (if any) that has happened for the given - event. - - Correctly handles checking whether edits were allowed to happen. - - Args: - event_id (str): The original event ID - - Returns: - Deferred[EventBase|None]: Returns the most recent edit, if any. - """ - - # We only allow edits for `m.room.message` events that have the same sender - # and event type. We can't assert these things during regular event auth so - # we have to do the checks post hoc. - - # Fetches latest edit that has the same type and sender as the - # original, and is an `m.room.message`. - sql = """ - SELECT edit.event_id FROM events AS edit - INNER JOIN event_relations USING (event_id) - INNER JOIN events AS original ON - original.event_id = relates_to_id - AND edit.type = original.type - AND edit.sender = original.sender - WHERE - relates_to_id = ? - AND relation_type = ? - AND edit.type = 'm.room.message' - ORDER by edit.origin_server_ts DESC, edit.event_id DESC - LIMIT 1 - """ - - def _get_applicable_edit_txn(txn): - txn.execute(sql, (event_id, RelationTypes.REPLACE)) - row = txn.fetchone() - if row: - return row[0] - - edit_id = yield self.runInteraction( - "get_applicable_edit", _get_applicable_edit_txn - ) - - if not edit_id: - return - - edit_event = yield self.get_event(edit_id, allow_none=True) - return edit_event - - def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): - """Check if a user has already annotated an event with the same key - (e.g. already liked an event). - - Args: - parent_id (str): The event being annotated - event_type (str): The event type of the annotation - aggregation_key (str): The aggregation key of the annotation - sender (str): The sender of the annotation - - Returns: - Deferred[bool] - """ - - sql = """ - SELECT 1 FROM event_relations - INNER JOIN events USING (event_id) - WHERE - relates_to_id = ? - AND relation_type = ? - AND type = ? - AND sender = ? - AND aggregation_key = ? - LIMIT 1; - """ - - def _get_if_user_has_annotated_event(txn): - txn.execute( - sql, - ( - parent_id, - RelationTypes.ANNOTATION, - event_type, - sender, - aggregation_key, - ), - ) - - return bool(txn.fetchone()) - - return self.runInteraction( - "get_if_user_has_annotated_event", _get_if_user_has_annotated_event - ) - - -class RelationsStore(RelationsWorkerStore): - def _handle_event_relations(self, txn, event): - """Handles inserting relation data during peristence of events - - Args: - txn - event (EventBase) - """ - relation = event.content.get("m.relates_to") - if not relation: - # No relations - return - - rel_type = relation.get("rel_type") - if rel_type not in ( - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - RelationTypes.REPLACE, - ): - # Unknown relation type - return - - parent_id = relation.get("event_id") - if not parent_id: - # Invalid relation - return - - aggregation_key = relation.get("key") - - self._simple_insert_txn( - txn, - table="event_relations", - values={ - "event_id": event.event_id, - "relates_to_id": parent_id, - "relation_type": rel_type, - "aggregation_key": aggregation_key, - }, - ) - - txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,)) - txn.call_after( - self.get_aggregation_groups_for_event.invalidate_many, (parent_id,) - ) - - if rel_type == RelationTypes.REPLACE: - txn.call_after(self.get_applicable_edit.invalidate, (parent_id,)) - - def _handle_redaction(self, txn, redacted_event_id): - """Handles receiving a redaction and checking whether we need to remove - any redacted relations from the database. - - Args: - txn - redacted_event_id (str): The event that was redacted. - """ - - self._simple_delete_txn( - txn, table="event_relations", keyvalues={"event_id": redacted_event_id} - ) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index ff63487823..8c4a83a840 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -17,26 +17,6 @@ import logging from collections import namedtuple -from six import iteritems, itervalues - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.api.constants import EventTypes, Membership -from synapse.metrics import LaterGauge -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import LoggingTransaction, make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore -from synapse.storage.engines import Sqlite3Engine -from synapse.storage.events_worker import EventsWorkerStore -from synapse.types import get_domain_from_id -from synapse.util.async_helpers import Linearizer -from synapse.util.caches import intern_string -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList -from synapse.util.metrics import Measure -from synapse.util.stringutils import to_ascii - logger = logging.getLogger(__name__) @@ -57,1102 +37,3 @@ ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name")) # a given membership type, suitable for use in calculating heroes for a room. # "count" points to the total numberr of users of a given membership type. MemberSummary = namedtuple("MemberSummary", ("members", "count")) - -_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" -_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" - - -class RoomMemberWorkerStore(EventsWorkerStore): - def __init__(self, db_conn, hs): - super(RoomMemberWorkerStore, self).__init__(db_conn, hs) - - # Is the current_state_events.membership up to date? Or is the - # background update still running? - self._current_state_events_membership_up_to_date = False - - txn = LoggingTransaction( - db_conn.cursor(), - name="_check_safe_current_state_events_membership_updated", - database_engine=self.database_engine, - ) - self._check_safe_current_state_events_membership_updated_txn(txn) - txn.close() - - if self.hs.config.metrics_flags.known_servers: - self._known_servers_count = 1 - self.hs.get_clock().looping_call( - run_as_background_process, - 60 * 1000, - "_count_known_servers", - self._count_known_servers, - ) - self.hs.get_clock().call_later( - 1000, - run_as_background_process, - "_count_known_servers", - self._count_known_servers, - ) - LaterGauge( - "synapse_federation_known_servers", - "", - [], - lambda: self._known_servers_count, - ) - - @defer.inlineCallbacks - def _count_known_servers(self): - """ - Count the servers that this server knows about. - - The statistic is stored on the class for the - `synapse_federation_known_servers` LaterGauge to collect. - """ - - def _transact(txn): - if isinstance(self.database_engine, Sqlite3Engine): - query = """ - SELECT COUNT(DISTINCT substr(out.user_id, pos+1)) - FROM ( - SELECT rm.user_id as user_id, instr(rm.user_id, ':') - AS pos FROM room_memberships as rm - INNER JOIN current_state_events as c ON rm.event_id = c.event_id - WHERE c.type = 'm.room.member' - ) as out - """ - else: - query = """ - SELECT COUNT(DISTINCT split_part(state_key, ':', 2)) - FROM current_state_events - WHERE type = 'm.room.member' AND membership = 'join'; - """ - txn.execute(query) - return list(txn)[0][0] - - count = yield self.runInteraction("get_known_servers", _transact) - - # We always know about ourselves, even if we have nothing in - # room_memberships (for example, the server is new). - self._known_servers_count = max([count, 1]) - return self._known_servers_count - - def _check_safe_current_state_events_membership_updated_txn(self, txn): - """Checks if it is safe to assume the new current_state_events - membership column is up to date - """ - - pending_update = self._simple_select_one_txn( - txn, - table="background_updates", - keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME}, - retcols=["update_name"], - allow_none=True, - ) - - self._current_state_events_membership_up_to_date = not pending_update - - # If the update is still running, reschedule to run. - if pending_update: - self._clock.call_later( - 15.0, - run_as_background_process, - "_check_safe_current_state_events_membership_updated", - self.runInteraction, - "_check_safe_current_state_events_membership_updated", - self._check_safe_current_state_events_membership_updated_txn, - ) - - @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True) - def get_hosts_in_room(self, room_id, cache_context): - """Returns the set of all hosts currently in the room - """ - user_ids = yield self.get_users_in_room( - room_id, on_invalidate=cache_context.invalidate - ) - hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids) - return hosts - - @cached(max_entries=100000, iterable=True) - def get_users_in_room(self, room_id): - return self.runInteraction( - "get_users_in_room", self.get_users_in_room_txn, room_id - ) - - def get_users_in_room_txn(self, txn, room_id): - # If we can assume current_state_events.membership is up to date - # then we can avoid a join, which is a Very Good Thing given how - # frequently this function gets called. - if self._current_state_events_membership_up_to_date: - sql = """ - SELECT state_key FROM current_state_events - WHERE type = 'm.room.member' AND room_id = ? AND membership = ? - """ - else: - sql = """ - SELECT state_key FROM room_memberships as m - INNER JOIN current_state_events as c - ON m.event_id = c.event_id - AND m.room_id = c.room_id - AND m.user_id = c.state_key - WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? - """ - - txn.execute(sql, (room_id, Membership.JOIN)) - return [to_ascii(r[0]) for r in txn] - - @cached(max_entries=100000) - def get_room_summary(self, room_id): - """ Get the details of a room roughly suitable for use by the room - summary extension to /sync. Useful when lazy loading room members. - Args: - room_id (str): The room ID to query - Returns: - Deferred[dict[str, MemberSummary]: - dict of membership states, pointing to a MemberSummary named tuple. - """ - - def _get_room_summary_txn(txn): - # first get counts. - # We do this all in one transaction to keep the cache small. - # FIXME: get rid of this when we have room_stats - - # If we can assume current_state_events.membership is up to date - # then we can avoid a join, which is a Very Good Thing given how - # frequently this function gets called. - if self._current_state_events_membership_up_to_date: - # Note, rejected events will have a null membership field, so - # we we manually filter them out. - sql = """ - SELECT count(*), membership FROM current_state_events - WHERE type = 'm.room.member' AND room_id = ? - AND membership IS NOT NULL - GROUP BY membership - """ - else: - sql = """ - SELECT count(*), m.membership FROM room_memberships as m - INNER JOIN current_state_events as c - ON m.event_id = c.event_id - AND m.room_id = c.room_id - AND m.user_id = c.state_key - WHERE c.type = 'm.room.member' AND c.room_id = ? - GROUP BY m.membership - """ - - txn.execute(sql, (room_id,)) - res = {} - for count, membership in txn: - summary = res.setdefault(to_ascii(membership), MemberSummary([], count)) - - # we order by membership and then fairly arbitrarily by event_id so - # heroes are consistent - if self._current_state_events_membership_up_to_date: - # Note, rejected events will have a null membership field, so - # we we manually filter them out. - sql = """ - SELECT state_key, membership, event_id - FROM current_state_events - WHERE type = 'm.room.member' AND room_id = ? - AND membership IS NOT NULL - ORDER BY - CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC, - event_id ASC - LIMIT ? - """ - else: - sql = """ - SELECT c.state_key, m.membership, c.event_id - FROM room_memberships as m - INNER JOIN current_state_events as c USING (room_id, event_id) - WHERE c.type = 'm.room.member' AND c.room_id = ? - ORDER BY - CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC, - c.event_id ASC - LIMIT ? - """ - - # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user. - txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6)) - for user_id, membership, event_id in txn: - summary = res[to_ascii(membership)] - # we will always have a summary for this membership type at this - # point given the summary currently contains the counts. - members = summary.members - members.append((to_ascii(user_id), to_ascii(event_id))) - - return res - - return self.runInteraction("get_room_summary", _get_room_summary_txn) - - def _get_user_counts_in_room_txn(self, txn, room_id): - """ - Get the user count in a room by membership. - - Args: - room_id (str) - membership (Membership) - - Returns: - Deferred[int] - """ - sql = """ - SELECT m.membership, count(*) FROM room_memberships as m - INNER JOIN current_state_events as c USING(event_id) - WHERE c.type = 'm.room.member' AND c.room_id = ? - GROUP BY m.membership - """ - - txn.execute(sql, (room_id,)) - return {row[0]: row[1] for row in txn} - - @cached() - def get_invited_rooms_for_user(self, user_id): - """ Get all the rooms the user is invited to - Args: - user_id (str): The user ID. - Returns: - A deferred list of RoomsForUser. - """ - - return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE]) - - @defer.inlineCallbacks - def get_invite_for_user_in_room(self, user_id, room_id): - """Gets the invite for the given user and room - - Args: - user_id (str) - room_id (str) - - Returns: - Deferred: Resolves to either a RoomsForUser or None if no invite was - found. - """ - invites = yield self.get_invited_rooms_for_user(user_id) - for invite in invites: - if invite.room_id == room_id: - return invite - return None - - @defer.inlineCallbacks - def get_rooms_for_user_where_membership_is(self, user_id, membership_list): - """ Get all the rooms for this user where the membership for this user - matches one in the membership list. - - Filters out forgotten rooms. - - Args: - user_id (str): The user ID. - membership_list (list): A list of synapse.api.constants.Membership - values which the user must be in. - - Returns: - Deferred[list[RoomsForUser]] - """ - if not membership_list: - return defer.succeed(None) - - rooms = yield self.runInteraction( - "get_rooms_for_user_where_membership_is", - self._get_rooms_for_user_where_membership_is_txn, - user_id, - membership_list, - ) - - # Now we filter out forgotten rooms - forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id) - return [room for room in rooms if room.room_id not in forgotten_rooms] - - def _get_rooms_for_user_where_membership_is_txn( - self, txn, user_id, membership_list - ): - - do_invite = Membership.INVITE in membership_list - membership_list = [m for m in membership_list if m != Membership.INVITE] - - results = [] - if membership_list: - if self._current_state_events_membership_up_to_date: - clause, args = make_in_list_sql_clause( - self.database_engine, "c.membership", membership_list - ) - sql = """ - SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering - FROM current_state_events AS c - INNER JOIN events AS e USING (room_id, event_id) - WHERE - c.type = 'm.room.member' - AND state_key = ? - AND %s - """ % ( - clause, - ) - else: - clause, args = make_in_list_sql_clause( - self.database_engine, "m.membership", membership_list - ) - sql = """ - SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering - FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (room_id, event_id) - INNER JOIN events AS e USING (room_id, event_id) - WHERE - c.type = 'm.room.member' - AND state_key = ? - AND %s - """ % ( - clause, - ) - - txn.execute(sql, (user_id, *args)) - results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)] - - if do_invite: - sql = ( - "SELECT i.room_id, inviter, i.event_id, e.stream_ordering" - " FROM local_invites as i" - " INNER JOIN events as e USING (event_id)" - " WHERE invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) - - txn.execute(sql, (user_id,)) - results.extend( - RoomsForUser( - room_id=r["room_id"], - sender=r["inviter"], - event_id=r["event_id"], - stream_ordering=r["stream_ordering"], - membership=Membership.INVITE, - ) - for r in self.cursor_to_dict(txn) - ) - - return results - - @cachedInlineCallbacks(max_entries=500000, iterable=True) - def get_rooms_for_user_with_stream_ordering(self, user_id): - """Returns a set of room_ids the user is currently joined to - - Args: - user_id (str) - - Returns: - Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns - the rooms the user is in currently, along with the stream ordering - of the most recent join for that user and room. - """ - rooms = yield self.get_rooms_for_user_where_membership_is( - user_id, membership_list=[Membership.JOIN] - ) - return frozenset( - GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering) - for r in rooms - ) - - @defer.inlineCallbacks - def get_rooms_for_user(self, user_id, on_invalidate=None): - """Returns a set of room_ids the user is currently joined to - """ - rooms = yield self.get_rooms_for_user_with_stream_ordering( - user_id, on_invalidate=on_invalidate - ) - return frozenset(r.room_id for r in rooms) - - @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) - def get_users_who_share_room_with_user(self, user_id, cache_context): - """Returns the set of users who share a room with `user_id` - """ - room_ids = yield self.get_rooms_for_user( - user_id, on_invalidate=cache_context.invalidate - ) - - user_who_share_room = set() - for room_id in room_ids: - user_ids = yield self.get_users_in_room( - room_id, on_invalidate=cache_context.invalidate - ) - user_who_share_room.update(user_ids) - - return user_who_share_room - - @defer.inlineCallbacks - def get_joined_users_from_context(self, event, context): - state_group = context.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - current_state_ids = yield context.get_current_state_ids(self) - result = yield self._get_joined_users_from_context( - event.room_id, state_group, current_state_ids, event=event, context=context - ) - return result - - @defer.inlineCallbacks - def get_joined_users_from_state(self, room_id, state_entry): - state_group = state_entry.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - with Measure(self._clock, "get_joined_users_from_state"): - return ( - yield self._get_joined_users_from_context( - room_id, state_group, state_entry.state, context=state_entry - ) - ) - - @cachedInlineCallbacks( - num_args=2, cache_context=True, iterable=True, max_entries=100000 - ) - def _get_joined_users_from_context( - self, - room_id, - state_group, - current_state_ids, - cache_context, - event=None, - context=None, - ): - # We don't use `state_group`, it's there so that we can cache based - # on it. However, it's important that it's never None, since two current_states - # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. - assert state_group is not None - - users_in_room = {} - member_event_ids = [ - e_id - for key, e_id in iteritems(current_state_ids) - if key[0] == EventTypes.Member - ] - - if context is not None: - # If we have a context with a delta from a previous state group, - # check if we also have the result from the previous group in cache. - # If we do then we can reuse that result and simply update it with - # any membership changes in `delta_ids` - if context.prev_group and context.delta_ids: - prev_res = self._get_joined_users_from_context.cache.get( - (room_id, context.prev_group), None - ) - if prev_res and isinstance(prev_res, dict): - users_in_room = dict(prev_res) - member_event_ids = [ - e_id - for key, e_id in iteritems(context.delta_ids) - if key[0] == EventTypes.Member - ] - for etype, state_key in context.delta_ids: - users_in_room.pop(state_key, None) - - # We check if we have any of the member event ids in the event cache - # before we ask the DB - - # We don't update the event cache hit ratio as it completely throws off - # the hit ratio counts. After all, we don't populate the cache if we - # miss it here - event_map = self._get_events_from_cache( - member_event_ids, allow_rejected=False, update_metrics=False - ) - - missing_member_event_ids = [] - for event_id in member_event_ids: - ev_entry = event_map.get(event_id) - if ev_entry: - if ev_entry.event.membership == Membership.JOIN: - users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo( - display_name=to_ascii( - ev_entry.event.content.get("displayname", None) - ), - avatar_url=to_ascii( - ev_entry.event.content.get("avatar_url", None) - ), - ) - else: - missing_member_event_ids.append(event_id) - - if missing_member_event_ids: - event_to_memberships = yield self._get_joined_profiles_from_event_ids( - missing_member_event_ids - ) - users_in_room.update((row for row in event_to_memberships.values() if row)) - - if event is not None and event.type == EventTypes.Member: - if event.membership == Membership.JOIN: - if event.event_id in member_event_ids: - users_in_room[to_ascii(event.state_key)] = ProfileInfo( - display_name=to_ascii(event.content.get("displayname", None)), - avatar_url=to_ascii(event.content.get("avatar_url", None)), - ) - - return users_in_room - - @cached(max_entries=10000) - def _get_joined_profile_from_event_id(self, event_id): - raise NotImplementedError() - - @cachedList( - cached_method_name="_get_joined_profile_from_event_id", - list_name="event_ids", - inlineCallbacks=True, - ) - def _get_joined_profiles_from_event_ids(self, event_ids): - """For given set of member event_ids check if they point to a join - event and if so return the associated user and profile info. - - Args: - event_ids (Iterable[str]): The member event IDs to lookup - - Returns: - Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID - to `user_id` and ProfileInfo (or None if not join event). - """ - - rows = yield self._simple_select_many_batch( - table="room_memberships", - column="event_id", - iterable=event_ids, - retcols=("user_id", "display_name", "avatar_url", "event_id"), - keyvalues={"membership": Membership.JOIN}, - batch_size=500, - desc="_get_membership_from_event_ids", - ) - - return { - row["event_id"]: ( - row["user_id"], - ProfileInfo( - avatar_url=row["avatar_url"], display_name=row["display_name"] - ), - ) - for row in rows - } - - @cachedInlineCallbacks(max_entries=10000) - def is_host_joined(self, room_id, host): - if "%" in host or "_" in host: - raise Exception("Invalid host name") - - sql = """ - SELECT state_key FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (event_id) - WHERE m.membership = 'join' - AND type = 'm.room.member' - AND c.room_id = ? - AND state_key LIKE ? - LIMIT 1 - """ - - # We do need to be careful to ensure that host doesn't have any wild cards - # in it, but we checked above for known ones and we'll check below that - # the returned user actually has the correct domain. - like_clause = "%:" + host - - rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause) - - if not rows: - return False - - user_id = rows[0][0] - if get_domain_from_id(user_id) != host: - # This can only happen if the host name has something funky in it - raise Exception("Invalid host name") - - return True - - @cachedInlineCallbacks() - def was_host_joined(self, room_id, host): - """Check whether the server is or ever was in the room. - - Args: - room_id (str) - host (str) - - Returns: - Deferred: Resolves to True if the host is/was in the room, otherwise - False. - """ - if "%" in host or "_" in host: - raise Exception("Invalid host name") - - sql = """ - SELECT user_id FROM room_memberships - WHERE room_id = ? - AND user_id LIKE ? - AND membership = 'join' - LIMIT 1 - """ - - # We do need to be careful to ensure that host doesn't have any wild cards - # in it, but we checked above for known ones and we'll check below that - # the returned user actually has the correct domain. - like_clause = "%:" + host - - rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause) - - if not rows: - return False - - user_id = rows[0][0] - if get_domain_from_id(user_id) != host: - # This can only happen if the host name has something funky in it - raise Exception("Invalid host name") - - return True - - @defer.inlineCallbacks - def get_joined_hosts(self, room_id, state_entry): - state_group = state_entry.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - with Measure(self._clock, "get_joined_hosts"): - return ( - yield self._get_joined_hosts( - room_id, state_group, state_entry.state, state_entry=state_entry - ) - ) - - @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) - # @defer.inlineCallbacks - def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry): - # We don't use `state_group`, its there so that we can cache based - # on it. However, its important that its never None, since two current_state's - # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. - assert state_group is not None - - cache = self._get_joined_hosts_cache(room_id) - joined_hosts = yield cache.get_destinations(state_entry) - - return joined_hosts - - @cached(max_entries=10000) - def _get_joined_hosts_cache(self, room_id): - return _JoinedHostsCache(self, room_id) - - @cachedInlineCallbacks(num_args=2) - def did_forget(self, user_id, room_id): - """Returns whether user_id has elected to discard history for room_id. - - Returns False if they have since re-joined.""" - - def f(txn): - sql = ( - "SELECT" - " COUNT(*)" - " FROM" - " room_memberships" - " WHERE" - " user_id = ?" - " AND" - " room_id = ?" - " AND" - " forgotten = 0" - ) - txn.execute(sql, (user_id, room_id)) - rows = txn.fetchall() - return rows[0][0] - - count = yield self.runInteraction("did_forget_membership", f) - return count == 0 - - @cached() - def get_forgotten_rooms_for_user(self, user_id): - """Gets all rooms the user has forgotten. - - Args: - user_id (str) - - Returns: - Deferred[set[str]] - """ - - def _get_forgotten_rooms_for_user_txn(txn): - # This is a slightly convoluted query that first looks up all rooms - # that the user has forgotten in the past, then rechecks that list - # to see if any have subsequently been updated. This is done so that - # we can use a partial index on `forgotten = 1` on the assumption - # that few users will actually forget many rooms. - # - # Note that a room is considered "forgotten" if *all* membership - # events for that user and room have the forgotten field set (as - # when a user forgets a room we update all rows for that user and - # room, not just the current one). - sql = """ - SELECT room_id, ( - SELECT count(*) FROM room_memberships - WHERE room_id = m.room_id AND user_id = m.user_id AND forgotten = 0 - ) AS count - FROM room_memberships AS m - WHERE user_id = ? AND forgotten = 1 - GROUP BY room_id, user_id; - """ - txn.execute(sql, (user_id,)) - return set(row[0] for row in txn if row[1] == 0) - - return self.runInteraction( - "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn - ) - - @defer.inlineCallbacks - def get_rooms_user_has_been_in(self, user_id): - """Get all rooms that the user has ever been in. - - Args: - user_id (str) - - Returns: - Deferred[set[str]]: Set of room IDs. - """ - - room_ids = yield self._simple_select_onecol( - table="room_memberships", - keyvalues={"membership": Membership.JOIN, "user_id": user_id}, - retcol="room_id", - desc="get_rooms_user_has_been_in", - ) - - return set(room_ids) - - -class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_update_handler( - _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile - ) - self.register_background_update_handler( - _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, - self._background_current_state_membership, - ) - self.register_background_index_update( - "room_membership_forgotten_idx", - index_name="room_memberships_user_room_forgotten", - table="room_memberships", - columns=["user_id", "room_id"], - where_clause="forgotten = 1", - ) - - @defer.inlineCallbacks - def _background_add_membership_profile(self, progress, batch_size): - target_min_stream_id = progress.get( - "target_min_stream_id_inclusive", self._min_stream_order_on_start - ) - max_stream_id = progress.get( - "max_stream_id_exclusive", self._stream_order_on_start + 1 - ) - - INSERT_CLUMP_SIZE = 1000 - - def add_membership_profile_txn(txn): - sql = """ - SELECT stream_ordering, event_id, events.room_id, event_json.json - FROM events - INNER JOIN event_json USING (event_id) - INNER JOIN room_memberships USING (event_id) - WHERE ? <= stream_ordering AND stream_ordering < ? - AND type = 'm.room.member' - ORDER BY stream_ordering DESC - LIMIT ? - """ - - txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - - rows = self.cursor_to_dict(txn) - if not rows: - return 0 - - min_stream_id = rows[-1]["stream_ordering"] - - to_update = [] - for row in rows: - event_id = row["event_id"] - room_id = row["room_id"] - try: - event_json = json.loads(row["json"]) - content = event_json["content"] - except Exception: - continue - - display_name = content.get("displayname", None) - avatar_url = content.get("avatar_url", None) - - if display_name or avatar_url: - to_update.append((display_name, avatar_url, event_id, room_id)) - - to_update_sql = """ - UPDATE room_memberships SET display_name = ?, avatar_url = ? - WHERE event_id = ? AND room_id = ? - """ - for index in range(0, len(to_update), INSERT_CLUMP_SIZE): - clump = to_update[index : index + INSERT_CLUMP_SIZE] - txn.executemany(to_update_sql, clump) - - progress = { - "target_min_stream_id_inclusive": target_min_stream_id, - "max_stream_id_exclusive": min_stream_id, - } - - self._background_update_progress_txn( - txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress - ) - - return len(rows) - - result = yield self.runInteraction( - _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn - ) - - if not result: - yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME) - - return result - - @defer.inlineCallbacks - def _background_current_state_membership(self, progress, batch_size): - """Update the new membership column on current_state_events. - - This works by iterating over all rooms in alphebetical order. - """ - - def _background_current_state_membership_txn(txn, last_processed_room): - processed = 0 - while processed < batch_size: - txn.execute( - """ - SELECT MIN(room_id) FROM current_state_events WHERE room_id > ? - """, - (last_processed_room,), - ) - row = txn.fetchone() - if not row or not row[0]: - return processed, True - - next_room, = row - - sql = """ - UPDATE current_state_events - SET membership = ( - SELECT membership FROM room_memberships - WHERE event_id = current_state_events.event_id - ) - WHERE room_id = ? - """ - txn.execute(sql, (next_room,)) - processed += txn.rowcount - - last_processed_room = next_room - - self._background_update_progress_txn( - txn, - _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, - {"last_processed_room": last_processed_room}, - ) - - return processed, False - - # If we haven't got a last processed room then just use the empty - # string, which will compare before all room IDs correctly. - last_processed_room = progress.get("last_processed_room", "") - - row_count, finished = yield self.runInteraction( - "_background_current_state_membership_update", - _background_current_state_membership_txn, - last_processed_room, - ) - - if finished: - yield self._end_background_update(_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME) - - return row_count - - -class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(RoomMemberStore, self).__init__(db_conn, hs) - - def _store_room_members_txn(self, txn, events, backfilled): - """Store a room member in the database. - """ - self._simple_insert_many_txn( - txn, - table="room_memberships", - values=[ - { - "event_id": event.event_id, - "user_id": event.state_key, - "sender": event.user_id, - "room_id": event.room_id, - "membership": event.membership, - "display_name": event.content.get("displayname", None), - "avatar_url": event.content.get("avatar_url", None), - } - for event in events - ], - ) - - for event in events: - txn.call_after( - self._membership_stream_cache.entity_has_changed, - event.state_key, - event.internal_metadata.stream_ordering, - ) - txn.call_after( - self.get_invited_rooms_for_user.invalidate, (event.state_key,) - ) - - # We update the local_invites table only if the event is "current", - # i.e., its something that has just happened. If the event is an - # outlier it is only current if its an "out of band membership", - # like a remote invite or a rejection of a remote invite. - is_new_state = not backfilled and ( - not event.internal_metadata.is_outlier() - or event.internal_metadata.is_out_of_band_membership() - ) - is_mine = self.hs.is_mine_id(event.state_key) - if is_new_state and is_mine: - if event.membership == Membership.INVITE: - self._simple_insert_txn( - txn, - table="local_invites", - values={ - "event_id": event.event_id, - "invitee": event.state_key, - "inviter": event.sender, - "room_id": event.room_id, - "stream_id": event.internal_metadata.stream_ordering, - }, - ) - else: - sql = ( - "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" - " room_id = ? AND invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) - - txn.execute( - sql, - ( - event.internal_metadata.stream_ordering, - event.event_id, - event.room_id, - event.state_key, - ), - ) - - @defer.inlineCallbacks - def locally_reject_invite(self, user_id, room_id): - sql = ( - "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" - " room_id = ? AND invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) - - def f(txn, stream_ordering): - txn.execute(sql, (stream_ordering, True, room_id, user_id)) - - with self._stream_id_gen.get_next() as stream_ordering: - yield self.runInteraction("locally_reject_invite", f, stream_ordering) - - def forget(self, user_id, room_id): - """Indicate that user_id wishes to discard history for room_id.""" - - def f(txn): - sql = ( - "UPDATE" - " room_memberships" - " SET" - " forgotten = 1" - " WHERE" - " user_id = ?" - " AND" - " room_id = ?" - ) - txn.execute(sql, (user_id, room_id)) - - self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id)) - self._invalidate_cache_and_stream( - txn, self.get_forgotten_rooms_for_user, (user_id,) - ) - - return self.runInteraction("forget_membership", f) - - -class _JoinedHostsCache(object): - """Cache for joined hosts in a room that is optimised to handle updates - via state deltas. - """ - - def __init__(self, store, room_id): - self.store = store - self.room_id = room_id - - self.hosts_to_joined_users = {} - - self.state_group = object() - - self.linearizer = Linearizer("_JoinedHostsCache") - - self._len = 0 - - @defer.inlineCallbacks - def get_destinations(self, state_entry): - """Get set of destinations for a state entry - - Args: - state_entry(synapse.state._StateCacheEntry) - """ - if state_entry.state_group == self.state_group: - return frozenset(self.hosts_to_joined_users) - - with (yield self.linearizer.queue(())): - if state_entry.state_group == self.state_group: - pass - elif state_entry.prev_group == self.state_group: - for (typ, state_key), event_id in iteritems(state_entry.delta_ids): - if typ != EventTypes.Member: - continue - - host = intern_string(get_domain_from_id(state_key)) - user_id = state_key - known_joins = self.hosts_to_joined_users.setdefault(host, set()) - - event = yield self.store.get_event(event_id) - if event.membership == Membership.JOIN: - known_joins.add(user_id) - else: - known_joins.discard(user_id) - - if not known_joins: - self.hosts_to_joined_users.pop(host, None) - else: - joined_users = yield self.store.get_joined_users_from_state( - self.room_id, state_entry - ) - - self.hosts_to_joined_users = {} - for user_id in joined_users: - host = intern_string(get_domain_from_id(user_id)) - self.hosts_to_joined_users.setdefault(host, set()).add(user_id) - - if state_entry.state_group: - self.state_group = state_entry.state_group - else: - self.state_group = object() - self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users)) - return frozenset(self.hosts_to_joined_users) - - def __len__(self): - return self._len diff --git a/synapse/storage/schema/delta/35/00background_updates_add_col.sql b/synapse/storage/schema/delta/35/00background_updates_add_col.sql new file mode 100644 index 0000000000..c2d2a4f836 --- /dev/null +++ b/synapse/storage/schema/delta/35/00background_updates_add_col.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +ALTER TABLE background_updates ADD COLUMN depends_on TEXT; diff --git a/synapse/storage/schema/full_schemas/54/full.sql b/synapse/storage/schema/full_schemas/54/full.sql new file mode 100644 index 0000000000..1005880466 --- /dev/null +++ b/synapse/storage/schema/full_schemas/54/full.sql @@ -0,0 +1,8 @@ + + +CREATE TABLE background_updates ( + update_name text NOT NULL, + progress_json text NOT NULL, + depends_on text, + CONSTRAINT background_updates_uniqueness UNIQUE (update_name) +); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index a941a5ae3f..3735846899 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -14,45 +14,18 @@ # limitations under the License. import logging -from collections import namedtuple from six import iteritems, itervalues -from six.moves import range import attr from twisted.internet import defer from synapse.api.constants import EventTypes -from synapse.api.errors import NotFoundError -from synapse.storage._base import SQLBaseStore -from synapse.storage.background_updates import BackgroundUpdateStore -from synapse.storage.engines import PostgresEngine -from synapse.storage.events_worker import EventsWorkerStore -from synapse.util.caches import get_cache_factor_for, intern_string -from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.caches.dictionary_cache import DictionaryCache -from synapse.util.stringutils import to_ascii logger = logging.getLogger(__name__) -MAX_STATE_DELTA_HOPS = 100 - - -class _GetStateGroupDelta( - namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) -): - """Return type of get_state_group_delta that implements __len__, which lets - us use the itrable flag when caching - """ - - __slots__ = [] - - def __len__(self): - return len(self.delta_ids) if self.delta_ids else 0 - - @attr.s(slots=True) class StateFilter(object): """A filter used when querying for state. @@ -353,402 +326,23 @@ class StateFilter(object): return member_filter, non_member_filter -class StateGroupBackgroundUpdateStore(SQLBaseStore): - """Defines functions related to state groups needed to run the state backgroud - updates. - """ - - def _count_state_group_hops_txn(self, txn, state_group): - """Given a state group, count how many hops there are in the tree. - - This is used to ensure the delta chains don't get too long. - """ - if isinstance(self.database_engine, PostgresEngine): - sql = """ - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT count(*) FROM state; - """ - - txn.execute(sql, (state_group,)) - row = txn.fetchone() - if row and row[0]: - return row[0] - else: - return 0 - else: - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - next_group = state_group - count = 0 - - while next_group: - next_group = self._simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - if next_group: - count += 1 - - return count - - def _get_state_groups_from_groups_txn( - self, txn, groups, state_filter=StateFilter.all() - ): - results = {group: {} for group in groups} - - where_clause, where_args = state_filter.make_sql_filter_clause() - - # Unless the filter clause is empty, we're going to append it after an - # existing where clause - if where_clause: - where_clause = " AND (%s)" % (where_clause,) - - if isinstance(self.database_engine, PostgresEngine): - # Temporarily disable sequential scans in this transaction. This is - # a temporary hack until we can add the right indices in - txn.execute("SET LOCAL enable_seqscan=off") - - # The below query walks the state_group tree so that the "state" - # table includes all state_groups in the tree. It then joins - # against `state_groups_state` to fetch the latest state. - # It assumes that previous state groups are always numerically - # lesser. - # The PARTITION is used to get the event_id in the greatest state - # group for the given type, state_key. - # This may return multiple rows per (type, state_key), but last_value - # should be the same. - sql = """ - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT DISTINCT type, state_key, last_value(event_id) OVER ( - PARTITION BY type, state_key ORDER BY state_group ASC - ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING - ) AS event_id FROM state_groups_state - WHERE state_group IN ( - SELECT state_group FROM state - ) - """ - - for group in groups: - args = [group] - args.extend(where_args) - - txn.execute(sql + where_clause, args) - for row in txn: - typ, state_key, event_id = row - key = (typ, state_key) - results[group][key] = event_id - else: - max_entries_returned = state_filter.max_entries_returned() - - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - for group in groups: - next_group = group - - while next_group: - # We did this before by getting the list of group ids, and - # then passing that list to sqlite to get latest event for - # each (type, state_key). However, that was terribly slow - # without the right indices (which we can't add until - # after we finish deduping state, which requires this func) - args = [next_group] - args.extend(where_args) - - txn.execute( - "SELECT type, state_key, event_id FROM state_groups_state" - " WHERE state_group = ? " + where_clause, - args, - ) - results[group].update( - ((typ, state_key), event_id) - for typ, state_key, event_id in txn - if (typ, state_key) not in results[group] - ) - - # If the number of entries in the (type,state_key)->event_id dict - # matches the number of (type,state_keys) types we were searching - # for, then we must have found them all, so no need to go walk - # further down the tree... UNLESS our types filter contained - # wildcards (i.e. Nones) in which case we have to do an exhaustive - # search - if ( - max_entries_returned is not None - and len(results[group]) == max_entries_returned - ): - break - - next_group = self._simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - - return results - - -# this inherits from EventsWorkerStore because it calls self.get_events -class StateGroupWorkerStore( - EventsWorkerStore, StateGroupBackgroundUpdateStore, SQLBaseStore -): - """The parts of StateGroupStore that can be called from workers. +class StateGroupStorage(object): + """High level interface to fetching state for event. """ - STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" - STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" - CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - - def __init__(self, db_conn, hs): - super(StateGroupWorkerStore, self).__init__(db_conn, hs) - - # Originally the state store used a single DictionaryCache to cache the - # event IDs for the state types in a given state group to avoid hammering - # on the state_group* tables. - # - # The point of using a DictionaryCache is that it can cache a subset - # of the state events for a given state group (i.e. a subset of the keys for a - # given dict which is an entry in the cache for a given state group ID). - # - # However, this poses problems when performing complicated queries - # on the store - for instance: "give me all the state for this group, but - # limit members to this subset of users", as DictionaryCache's API isn't - # rich enough to say "please cache any of these fields, apart from this subset". - # This is problematic when lazy loading members, which requires this behaviour, - # as without it the cache has no choice but to speculatively load all - # state events for the group, which negates the efficiency being sought. - # - # Rather than overcomplicating DictionaryCache's API, we instead split the - # state_group_cache into two halves - one for tracking non-member events, - # and the other for tracking member_events. This means that lazy loading - # queries can be made in a cache-friendly manner by querying both caches - # separately and then merging the result. So for the example above, you - # would query the members cache for a specific subset of state keys - # (which DictionaryCache will handle efficiently and fine) and the non-members - # cache for all state (which DictionaryCache will similarly handle fine) - # and then just merge the results together. - # - # We size the non-members cache to be smaller than the members cache as the - # vast majority of state in Matrix (today) is member events. - - self._state_group_cache = DictionaryCache( - "*stateGroupCache*", - # TODO: this hasn't been tuned yet - 50000 * get_cache_factor_for("stateGroupCache"), - ) - self._state_group_members_cache = DictionaryCache( - "*stateGroupMembersCache*", - 500000 * get_cache_factor_for("stateGroupMembersCache"), - ) - - @defer.inlineCallbacks - def get_room_version(self, room_id): - """Get the room_version of a given room - - Args: - room_id (str) - - Returns: - Deferred[str] - - Raises: - NotFoundError if the room is unknown - """ - # for now we do this by looking at the create event. We may want to cache this - # more intelligently in future. - - # Retrieve the room's create event - create_event = yield self.get_create_event_for_room(room_id) - return create_event.content.get("room_version", "1") - - @defer.inlineCallbacks - def get_room_predecessor(self, room_id): - """Get the predecessor room of an upgraded room if one exists. - Otherwise return None. - - Args: - room_id (str) - - Returns: - Deferred[unicode|None]: predecessor room id - - Raises: - NotFoundError if the room is unknown - """ - # Retrieve the room's create event - create_event = yield self.get_create_event_for_room(room_id) - - # Return predecessor if present - return create_event.content.get("predecessor", None) - - @defer.inlineCallbacks - def get_create_event_for_room(self, room_id): - """Get the create state event for a room. - - Args: - room_id (str) - - Returns: - Deferred[EventBase]: The room creation event. - - Raises: - NotFoundError if the room is unknown - """ - state_ids = yield self.get_current_state_ids(room_id) - create_id = state_ids.get((EventTypes.Create, "")) - - # If we can't find the create event, assume we've hit a dead end - if not create_id: - raise NotFoundError("Unknown room %s" % (room_id)) - - # Retrieve the room's create event and return - create_event = yield self.get_event(create_id) - return create_event - - @cached(max_entries=100000, iterable=True) - def get_current_state_ids(self, room_id): - """Get the current state event ids for a room based on the - current_state_events table. - - Args: - room_id (str) - - Returns: - deferred: dict of (type, state_key) -> event_id - """ - - def _get_current_state_ids_txn(txn): - txn.execute( - """SELECT type, state_key, event_id FROM current_state_events - WHERE room_id = ? - """, - (room_id,), - ) - - return { - (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn - } - - return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn) - - # FIXME: how should this be cached? - def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()): - """Get the current state event of a given type for a room based on the - current_state_events table. This may not be as up-to-date as the result - of doing a fresh state resolution as per state_handler.get_current_state - - Args: - room_id (str) - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - Deferred[dict[tuple[str, str], str]]: Map from type/state_key to - event ID. - """ - - where_clause, where_args = state_filter.make_sql_filter_clause() - - if not where_clause: - # We delegate to the cached version - return self.get_current_state_ids(room_id) - - def _get_filtered_current_state_ids_txn(txn): - results = {} - sql = """ - SELECT type, state_key, event_id FROM current_state_events - WHERE room_id = ? - """ - - if where_clause: - sql += " AND (%s)" % (where_clause,) - - args = [room_id] - args.extend(where_args) - txn.execute(sql, args) - for row in txn: - typ, state_key, event_id = row - key = (intern_string(typ), intern_string(state_key)) - results[key] = event_id - - return results - - return self.runInteraction( - "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn - ) - - @defer.inlineCallbacks - def get_canonical_alias_for_room(self, room_id): - """Get canonical alias for room, if any - - Args: - room_id (str) - - Returns: - Deferred[str|None]: The canonical alias, if any - """ - - state = yield self.get_filtered_current_state_ids( - room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) - ) - - event_id = state.get((EventTypes.CanonicalAlias, "")) - if not event_id: - return - - event = yield self.get_event(event_id, allow_none=True) - if not event: - return - - return event.content.get("canonical_alias") + def __init__(self, hs, stores): + self.stores = stores - @cached(max_entries=10000, iterable=True) def get_state_group_delta(self, state_group): """Given a state group try to return a previous group and a delta between the old and the new. Returns: - (prev_group, delta_ids), where both may be None. + Deferred[Tuple[Optional[int], Optional[list[dict[tuple[str, str], str]]]]]): + (prev_group, delta_ids) """ - def _get_state_group_delta_txn(txn): - prev_group = self._simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": state_group}, - retcol="prev_state_group", - allow_none=True, - ) - - if not prev_group: - return _GetStateGroupDelta(None, None) - - delta_ids = self._simple_select_list_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": state_group}, - retcols=("type", "state_key", "event_id"), - ) - - return _GetStateGroupDelta( - prev_group, - {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, - ) - - return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn) + return self.stores.main.get_state_group_delta(state_group) @defer.inlineCallbacks def get_state_groups_ids(self, _room_id, event_ids): @@ -765,10 +359,10 @@ class StateGroupWorkerStore( if not event_ids: return {} - event_to_groups = yield self._get_state_group_for_events(event_ids) + event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups) + group_to_state = yield self.stores.main._get_state_for_groups(groups) return group_to_state @@ -789,7 +383,6 @@ class StateGroupWorkerStore( @defer.inlineCallbacks def get_state_groups(self, room_id, event_ids): """ Get the state groups for the given list of event_ids - Returns: Deferred[dict[int, list[EventBase]]]: dict of state_group_id -> list of state events. @@ -799,7 +392,7 @@ class StateGroupWorkerStore( group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) - state_event_map = yield self.get_events( + state_event_map = yield self.stores.main.get_events( [ ev_id for group_ids in itervalues(group_to_ids) @@ -817,7 +410,6 @@ class StateGroupWorkerStore( for group, event_id_map in iteritems(group_to_ids) } - @defer.inlineCallbacks def _get_state_groups_from_groups(self, groups, state_filter): """Returns the state groups for a given set of groups, filtering on types of state events. @@ -830,39 +422,28 @@ class StateGroupWorkerStore( Deferred[dict[int, dict[tuple[str, str], str]]]: dict of state_group_id -> (dict of (type, state_key) -> event id) """ - results = {} - - chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] - for chunk in chunks: - res = yield self.runInteraction( - "_get_state_groups_from_groups", - self._get_state_groups_from_groups_txn, - chunk, - state_filter, - ) - results.update(res) - return results + return self.stores.main._get_state_groups_from_groups(groups, state_filter) @defer.inlineCallbacks def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): """Given a list of event_ids and type tuples, return a list of state dicts for each event. - Args: event_ids (list[string]) state_filter (StateFilter): The state filter used to fetch state from the database. - Returns: deferred: A dict of (event_id) -> (type, state_key) -> [state_events] """ - event_to_groups = yield self._get_state_group_for_events(event_ids) + event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, state_filter) + group_to_state = yield self.stores.main._get_state_for_groups( + groups, state_filter + ) - state_event_map = yield self.get_events( + state_event_map = yield self.stores.main.get_events( [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], get_prev_content=False, ) @@ -892,10 +473,12 @@ class StateGroupWorkerStore( Returns: A deferred dict from event_id -> (type, state_key) -> event_id """ - event_to_groups = yield self._get_state_group_for_events(event_ids) + event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, state_filter) + group_to_state = yield self.stores.main._get_state_for_groups( + groups, state_filter + ) event_to_state = { event_id: group_to_state[group] @@ -936,76 +519,6 @@ class StateGroupWorkerStore( state_map = yield self.get_state_ids_for_events([event_id], state_filter) return state_map[event_id] - @cached(max_entries=50000) - def _get_state_group_for_event(self, event_id): - return self._simple_select_one_onecol( - table="event_to_state_groups", - keyvalues={"event_id": event_id}, - retcol="state_group", - allow_none=True, - desc="_get_state_group_for_event", - ) - - @cachedList( - cached_method_name="_get_state_group_for_event", - list_name="event_ids", - num_args=1, - inlineCallbacks=True, - ) - def _get_state_group_for_events(self, event_ids): - """Returns mapping event_id -> state_group - """ - rows = yield self._simple_select_many_batch( - table="event_to_state_groups", - column="event_id", - iterable=event_ids, - keyvalues={}, - retcols=("event_id", "state_group"), - desc="_get_state_group_for_events", - ) - - return {row["event_id"]: row["state_group"] for row in rows} - - def _get_state_for_group_using_cache(self, cache, group, state_filter): - """Checks if group is in cache. See `_get_state_for_groups` - - Args: - cache(DictionaryCache): the state group cache to use - group(int): The state group to lookup - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns 2-tuple (`state_dict`, `got_all`). - `got_all` is a bool indicating if we successfully retrieved all - requests state from the cache, if False we need to query the DB for the - missing state. - """ - is_all, known_absent, state_dict_ids = cache.get(group) - - if is_all or state_filter.is_full(): - # Either we have everything or want everything, either way - # `is_all` tells us whether we've gotten everything. - return state_filter.filter_state(state_dict_ids), is_all - - # tracks whether any of our requested types are missing from the cache - missing_types = False - - if state_filter.has_wildcards(): - # We don't know if we fetched all the state keys for the types in - # the filter that are wildcards, so we have to assume that we may - # have missed some. - missing_types = True - else: - # There aren't any wild cards, so `concrete_types()` returns the - # complete list of event types we're wanting. - for key in state_filter.concrete_types(): - if key not in state_dict_ids and key not in known_absent: - missing_types = True - break - - return state_filter.filter_state(state_dict_ids), not missing_types - - @defer.inlineCallbacks def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): """Gets the state at each of a list of state groups, optionally filtering by type/state_key @@ -1019,154 +532,7 @@ class StateGroupWorkerStore( Deferred[dict[int, dict[tuple[str, str], str]]]: dict of state_group_id -> (dict of (type, state_key) -> event id) """ - - member_filter, non_member_filter = state_filter.get_member_split() - - # Now we look them up in the member and non-member caches - non_member_state, incomplete_groups_nm, = ( - yield self._get_state_for_groups_using_cache( - groups, self._state_group_cache, state_filter=non_member_filter - ) - ) - - member_state, incomplete_groups_m, = ( - yield self._get_state_for_groups_using_cache( - groups, self._state_group_members_cache, state_filter=member_filter - ) - ) - - state = dict(non_member_state) - for group in groups: - state[group].update(member_state[group]) - - # Now fetch any missing groups from the database - - incomplete_groups = incomplete_groups_m | incomplete_groups_nm - - if not incomplete_groups: - return state - - cache_sequence_nm = self._state_group_cache.sequence - cache_sequence_m = self._state_group_members_cache.sequence - - # Help the cache hit ratio by expanding the filter a bit - db_state_filter = state_filter.return_expanded() - - group_to_state_dict = yield self._get_state_groups_from_groups( - list(incomplete_groups), state_filter=db_state_filter - ) - - # Now lets update the caches - self._insert_into_cache( - group_to_state_dict, - db_state_filter, - cache_seq_num_members=cache_sequence_m, - cache_seq_num_non_members=cache_sequence_nm, - ) - - # And finally update the result dict, by filtering out any extra - # stuff we pulled out of the database. - for group, group_state_dict in iteritems(group_to_state_dict): - # We just replace any existing entries, as we will have loaded - # everything we need from the database anyway. - state[group] = state_filter.filter_state(group_state_dict) - - return state - - def _get_state_for_groups_using_cache(self, groups, cache, state_filter): - """Gets the state at each of a list of state groups, optionally - filtering by type/state_key, querying from a specific cache. - - Args: - groups (iterable[int]): list of state groups for which we want - to get the state. - cache (DictionaryCache): the cache of group ids to state dicts which - we will pass through - either the normal state cache or the specific - members state cache. - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of - dict of state_group_id -> (dict of (type, state_key) -> event id) - of entries in the cache, and the state group ids either missing - from the cache or incomplete. - """ - results = {} - incomplete_groups = set() - for group in set(groups): - state_dict_ids, got_all = self._get_state_for_group_using_cache( - cache, group, state_filter - ) - results[group] = state_dict_ids - - if not got_all: - incomplete_groups.add(group) - - return results, incomplete_groups - - def _insert_into_cache( - self, - group_to_state_dict, - state_filter, - cache_seq_num_members, - cache_seq_num_non_members, - ): - """Inserts results from querying the database into the relevant cache. - - Args: - group_to_state_dict (dict): The new entries pulled from database. - Map from state group to state dict - state_filter (StateFilter): The state filter used to fetch state - from the database. - cache_seq_num_members (int): Sequence number of member cache since - last lookup in cache - cache_seq_num_non_members (int): Sequence number of member cache since - last lookup in cache - """ - - # We need to work out which types we've fetched from the DB for the - # member vs non-member caches. This should be as accurate as possible, - # but can be an underestimate (e.g. when we have wild cards) - - member_filter, non_member_filter = state_filter.get_member_split() - if member_filter.is_full(): - # We fetched all member events - member_types = None - else: - # `concrete_types()` will only return a subset when there are wild - # cards in the filter, but that's fine. - member_types = member_filter.concrete_types() - - if non_member_filter.is_full(): - # We fetched all non member events - non_member_types = None - else: - non_member_types = non_member_filter.concrete_types() - - for group, group_state_dict in iteritems(group_to_state_dict): - state_dict_members = {} - state_dict_non_members = {} - - for k, v in iteritems(group_state_dict): - if k[0] == EventTypes.Member: - state_dict_members[k] = v - else: - state_dict_non_members[k] = v - - self._state_group_members_cache.update( - cache_seq_num_members, - key=group, - value=state_dict_members, - fetched_keys=member_types, - ) - - self._state_group_cache.update( - cache_seq_num_non_members, - key=group, - value=state_dict_non_members, - fetched_keys=non_member_types, - ) + return self.stores.main._get_state_for_groups(groups, state_filter) def store_state_group( self, event_id, room_id, prev_group, delta_ids, current_state_ids @@ -1186,360 +552,6 @@ class StateGroupWorkerStore( Returns: Deferred[int]: The state group ID """ - - def _store_state_group_txn(txn): - if current_state_ids is None: - # AFAIK, this can never happen - raise Exception("current_state_ids cannot be None") - - state_group = self.database_engine.get_next_state_group_id(txn) - - self._simple_insert_txn( - txn, - table="state_groups", - values={"id": state_group, "room_id": room_id, "event_id": event_id}, - ) - - # We persist as a delta if we can, while also ensuring the chain - # of deltas isn't tooo long, as otherwise read performance degrades. - if prev_group: - is_in_db = self._simple_select_one_onecol_txn( - txn, - table="state_groups", - keyvalues={"id": prev_group}, - retcol="id", - allow_none=True, - ) - if not is_in_db: - raise Exception( - "Trying to persist state with unpersisted prev_group: %r" - % (prev_group,) - ) - - potential_hops = self._count_state_group_hops_txn(txn, prev_group) - if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - self._simple_insert_txn( - txn, - table="state_group_edges", - values={"state_group": state_group, "prev_state_group": prev_group}, - ) - - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(delta_ids) - ], - ) - else: - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(current_state_ids) - ], - ) - - # Prefill the state group caches with this group. - # It's fine to use the sequence like this as the state group map - # is immutable. (If the map wasn't immutable then this prefill could - # race with another update) - - current_member_state_ids = { - s: ev - for (s, ev) in iteritems(current_state_ids) - if s[0] == EventTypes.Member - } - txn.call_after( - self._state_group_members_cache.update, - self._state_group_members_cache.sequence, - key=state_group, - value=dict(current_member_state_ids), - ) - - current_non_member_state_ids = { - s: ev - for (s, ev) in iteritems(current_state_ids) - if s[0] != EventTypes.Member - } - txn.call_after( - self._state_group_cache.update, - self._state_group_cache.sequence, - key=state_group, - value=dict(current_non_member_state_ids), - ) - - return state_group - - return self.runInteraction("store_state_group", _store_state_group_txn) - - -class StateBackgroundUpdateStore( - StateGroupBackgroundUpdateStore, BackgroundUpdateStore -): - - STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" - STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" - CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" - - def __init__(self, db_conn, hs): - super(StateBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_update_handler( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, - self._background_deduplicate_state, + return self.stores.main.store_state_group( + event_id, room_id, prev_group, delta_ids, current_state_ids ) - self.register_background_update_handler( - self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state - ) - self.register_background_index_update( - self.CURRENT_STATE_INDEX_UPDATE_NAME, - index_name="current_state_events_member_index", - table="current_state_events", - columns=["state_key"], - where_clause="type='m.room.member'", - ) - self.register_background_index_update( - self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME, - index_name="event_to_state_groups_sg_index", - table="event_to_state_groups", - columns=["state_group"], - ) - - @defer.inlineCallbacks - def _background_deduplicate_state(self, progress, batch_size): - """This background update will slowly deduplicate state by reencoding - them as deltas. - """ - last_state_group = progress.get("last_state_group", 0) - rows_inserted = progress.get("rows_inserted", 0) - max_group = progress.get("max_group", None) - - BATCH_SIZE_SCALE_FACTOR = 100 - - batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) - - if max_group is None: - rows = yield self._execute( - "_background_deduplicate_state", - None, - "SELECT coalesce(max(id), 0) FROM state_groups", - ) - max_group = rows[0][0] - - def reindex_txn(txn): - new_last_state_group = last_state_group - for count in range(batch_size): - txn.execute( - "SELECT id, room_id FROM state_groups" - " WHERE ? < id AND id <= ?" - " ORDER BY id ASC" - " LIMIT 1", - (new_last_state_group, max_group), - ) - row = txn.fetchone() - if row: - state_group, room_id = row - - if not row or not state_group: - return True, count - - txn.execute( - "SELECT state_group FROM state_group_edges" - " WHERE state_group = ?", - (state_group,), - ) - - # If we reach a point where we've already started inserting - # edges we should stop. - if txn.fetchall(): - return True, count - - txn.execute( - "SELECT coalesce(max(id), 0) FROM state_groups" - " WHERE id < ? AND room_id = ?", - (state_group, room_id), - ) - prev_group, = txn.fetchone() - new_last_state_group = state_group - - if prev_group: - potential_hops = self._count_state_group_hops_txn(txn, prev_group) - if potential_hops >= MAX_STATE_DELTA_HOPS: - # We want to ensure chains are at most this long,# - # otherwise read performance degrades. - continue - - prev_state = self._get_state_groups_from_groups_txn( - txn, [prev_group] - ) - prev_state = prev_state[prev_group] - - curr_state = self._get_state_groups_from_groups_txn( - txn, [state_group] - ) - curr_state = curr_state[state_group] - - if not set(prev_state.keys()) - set(curr_state.keys()): - # We can only do a delta if the current has a strict super set - # of keys - - delta_state = { - key: value - for key, value in iteritems(curr_state) - if prev_state.get(key, None) != value - } - - self._simple_delete_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": state_group}, - ) - - self._simple_insert_txn( - txn, - table="state_group_edges", - values={ - "state_group": state_group, - "prev_state_group": prev_group, - }, - ) - - self._simple_delete_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": state_group}, - ) - - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(delta_state) - ], - ) - - progress = { - "last_state_group": state_group, - "rows_inserted": rows_inserted + batch_size, - "max_group": max_group, - } - - self._background_update_progress_txn( - txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress - ) - - return False, batch_size - - finished, result = yield self.runInteraction( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn - ) - - if finished: - yield self._end_background_update( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME - ) - - return result * BATCH_SIZE_SCALE_FACTOR - - @defer.inlineCallbacks - def _background_index_state(self, progress, batch_size): - def reindex_txn(conn): - conn.rollback() - if isinstance(self.database_engine, PostgresEngine): - # postgres insists on autocommit for the index - conn.set_session(autocommit=True) - try: - txn = conn.cursor() - txn.execute( - "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" - " ON state_groups_state(state_group, type, state_key)" - ) - txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - finally: - conn.set_session(autocommit=False) - else: - txn = conn.cursor() - txn.execute( - "CREATE INDEX state_groups_state_type_idx" - " ON state_groups_state(state_group, type, state_key)" - ) - txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - - yield self.runWithConnection(reindex_txn) - - yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) - - return 1 - - -class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): - """ Keeps track of the state at a given event. - - This is done by the concept of `state groups`. Every event is a assigned - a state group (identified by an arbitrary string), which references a - collection of state events. The current state of an event is then the - collection of state events referenced by the event's state group. - - Hence, every change in the current state causes a new state group to be - generated. However, if no change happens (e.g., if we get a message event - with only one parent it inherits the state group from its parent.) - - There are three tables: - * `state_groups`: Stores group name, first event with in the group and - room id. - * `event_to_state_groups`: Maps events to state groups. - * `state_groups_state`: Maps state group to state events. - """ - - def __init__(self, db_conn, hs): - super(StateStore, self).__init__(db_conn, hs) - - def _store_event_state_mappings_txn(self, txn, events_and_contexts): - state_groups = {} - for event, context in events_and_contexts: - if event.internal_metadata.is_outlier(): - continue - - # if the event was rejected, just give it the same state as its - # predecessor. - if context.rejected: - state_groups[event.event_id] = context.prev_group - continue - - state_groups[event.event_id] = context.state_group - - self._simple_insert_many_txn( - txn, - table="event_to_state_groups", - values=[ - {"state_group": state_group_id, "event_id": event_id} - for event_id, state_group_id in iteritems(state_groups) - ], - ) - - for event_id, state_group_id in iteritems(state_groups): - txn.call_after( - self._get_state_group_for_event.prefill, (event_id,), state_group_id - ) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index cbb0a4810a..9d851beaa5 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -46,7 +46,7 @@ def _load_current_id(db_conn, table, column, step=1): cur.execute("SELECT MAX(%s) FROM %s" % (column, table)) else: cur.execute("SELECT MIN(%s) FROM %s" % (column, table)) - val, = cur.fetchone() + (val,) = cur.fetchone() cur.close() current_id = int(val) if val else step return (max if step > 0 else min)(current_id, step) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 804dbca443..5c4de2e69f 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -86,11 +86,12 @@ class ObservableDeferred(object): deferred.addCallbacks(callback, errback) - def observe(self): + def observe(self) -> defer.Deferred: """Observe the underlying deferred. - Can return either a deferred if the underlying deferred is still pending - (or has failed), or the actual value. Callers may need to use maybeDeferred. + This returns a brand new deferred that is resolved when the underlying + deferred is resolved. Interacting with the returned deferred does not + effect the underdlying deferred. """ if not self._result: d = defer.Deferred() @@ -105,7 +106,7 @@ class ObservableDeferred(object): return d else: success, res = self._result - return res if success else defer.fail(res) + return defer.succeed(res) if success else defer.fail(res) def observers(self): return self._observers @@ -138,7 +139,7 @@ def concurrently_execute(func, args, limit): the number of concurrent executions. Args: - func (func): Function to execute, should return a deferred. + func (func): Function to execute, should return a deferred or coroutine. args (list): List of arguments to pass to func, each invocation of func gets a signle argument. limit (int): Maximum number of conccurent executions. @@ -148,11 +149,10 @@ def concurrently_execute(func, args, limit): """ it = iter(args) - @defer.inlineCallbacks - def _concurrently_execute_inner(): + async def _concurrently_execute_inner(): try: while True: - yield func(next(it)) + await maybe_awaitable(func(next(it))) except StopIteration: pass @@ -309,7 +309,7 @@ class Linearizer(object): ) else: - logger.warn( + logger.warning( "Unexpected exception waiting for linearizer lock %r for key %r", self.name, key, diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 43fd65d693..da5077b471 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -107,7 +107,7 @@ def register_cache(cache_type, cache_name, cache, collect_callback=None): if collect_callback: collect_callback() except Exception as e: - logger.warn("Error calculating metrics for %s: %s", cache_name, e) + logger.warning("Error calculating metrics for %s: %s", cache_name, e) raise yield GaugeMetricFamily("__unused", "") diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 5ac2530a6a..84f5ae22c3 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -17,8 +17,8 @@ import functools import inspect import logging import threading -from collections import namedtuple -from typing import Any, cast +from typing import Any, Tuple, Union, cast +from weakref import WeakValueDictionary from six import itervalues @@ -38,6 +38,8 @@ from . import register_cache logger = logging.getLogger(__name__) +CacheKey = Union[Tuple, Any] + class _CachedFunction(Protocol): invalidate = None # type: Any @@ -430,7 +432,7 @@ class CacheDescriptor(_CacheDescriptorBase): # Add our own `cache_context` to argument list if the wrapped function # has asked for one if self.add_cache_context: - kwargs["cache_context"] = _CacheContext(cache, cache_key) + kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key) try: cached_result_d = cache.get(cache_key, callback=invalidate_callback) @@ -438,7 +440,7 @@ class CacheDescriptor(_CacheDescriptorBase): if isinstance(cached_result_d, ObservableDeferred): observer = cached_result_d.observe() else: - observer = cached_result_d + observer = defer.succeed(cached_result_d) except KeyError: ret = defer.maybeDeferred( @@ -482,9 +484,8 @@ class CacheListDescriptor(_CacheDescriptorBase): Given a list of keys it looks in the cache to find any hits, then passes the list of missing keys to the wrapped function. - Once wrapped, the function returns either a Deferred which resolves to - the list of results, or (if all results were cached), just the list of - results. + Once wrapped, the function returns a Deferred which resolves to the list + of results. """ def __init__( @@ -618,21 +619,45 @@ class CacheListDescriptor(_CacheDescriptorBase): ) return make_deferred_yieldable(d) else: - return results + return defer.succeed(results) obj.__dict__[self.orig.__name__] = wrapped return wrapped -class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): - # We rely on _CacheContext implementing __eq__ and __hash__ sensibly, - # which namedtuple does for us (i.e. two _CacheContext are the same if - # their caches and keys match). This is important in particular to - # dedupe when we add callbacks to lru cache nodes, otherwise the number - # of callbacks would grow. - def invalidate(self): - self.cache.invalidate(self.key) +class _CacheContext: + """Holds cache information from the cached function higher in the calling order. + + Can be used to invalidate the higher level cache entry if something changes + on a lower level. + """ + + _cache_context_objects = ( + WeakValueDictionary() + ) # type: WeakValueDictionary[Tuple[Cache, CacheKey], _CacheContext] + + def __init__(self, cache, cache_key): # type: (Cache, CacheKey) -> None + self._cache = cache + self._cache_key = cache_key + + def invalidate(self): # type: () -> None + """Invalidates the cache entry referred to by the context.""" + self._cache.invalidate(self._cache_key) + + @classmethod + def get_instance(cls, cache, cache_key): # type: (Cache, CacheKey) -> _CacheContext + """Returns an instance constructed with the given arguments. + + A new instance is only created if none already exists. + """ + + # We make sure there are no identical _CacheContext instances. This is + # important in particular to dedupe when we add callbacks to lru cache + # nodes, otherwise the number of callbacks would grow. + return cls._cache_context_objects.setdefault( + (cache, cache_key), cls(cache, cache_key) + ) def cached( diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 4b1bcdf23c..3286804322 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -119,7 +119,7 @@ class Measure(object): context = LoggingContext.current_context() if context != self.start_context: - logger.warn( + logger.warning( "Context has unexpectedly changed from '%s' to '%s'. (%r)", self.start_context, context, @@ -128,7 +128,7 @@ class Measure(object): return if not context: - logger.warn("Expected context. (%r)", self.name) + logger.warning("Expected context. (%r)", self.name) return current = context.get_resource_usage() @@ -140,7 +140,7 @@ class Measure(object): block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec) block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec) except ValueError: - logger.warn( + logger.warning( "Failed to save metrics! OLD: %r, NEW: %r", self.start_usage, current ) diff --git a/synapse/util/rlimit.py b/synapse/util/rlimit.py index 6c0f2bb0cf..207cd17c2a 100644 --- a/synapse/util/rlimit.py +++ b/synapse/util/rlimit.py @@ -33,4 +33,4 @@ def change_resource_limit(soft_file_no): resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY) ) except (ValueError, resource.error) as e: - logger.warn("Failed to set file or core limit: %s", e) + logger.warning("Failed to set file or core limit: %s", e) diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py index fa404b9d75..ab7d03af3a 100644 --- a/synapse/util/versionstring.py +++ b/synapse/util/versionstring.py @@ -42,6 +42,7 @@ def get_version_string(module): try: null = open(os.devnull, "w") cwd = os.path.dirname(os.path.abspath(module.__file__)) + try: git_branch = ( subprocess.check_output( @@ -51,7 +52,8 @@ def get_version_string(module): .decode("ascii") ) git_branch = "b=" + git_branch - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): + # FileNotFoundError can arise when git is not installed git_branch = "" try: @@ -63,7 +65,7 @@ def get_version_string(module): .decode("ascii") ) git_tag = "t=" + git_tag - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): git_tag = "" try: @@ -74,7 +76,7 @@ def get_version_string(module): .strip() .decode("ascii") ) - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): git_commit = "" try: @@ -89,7 +91,7 @@ def get_version_string(module): ) git_dirty = "dirty" if is_dirty else "" - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): git_dirty = "" if git_branch or git_tag or git_commit or git_dirty: diff --git a/synapse/visibility.py b/synapse/visibility.py index bf0f1eebd8..8c843febd8 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -23,6 +23,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.events.utils import prune_event +from synapse.storage import Storage from synapse.storage.state import StateFilter from synapse.types import get_domain_from_id @@ -43,14 +44,13 @@ MEMBERSHIP_PRIORITY = ( @defer.inlineCallbacks def filter_events_for_client( - store, user_id, events, is_peeking=False, always_include_ids=frozenset() + storage: Storage, user_id, events, is_peeking=False, always_include_ids=frozenset() ): """ Check which events a user is allowed to see Args: - store (synapse.storage.DataStore): our datastore (can also be a worker - store) + storage user_id(str): user id to be checked events(list[synapse.events.EventBase]): sequence of events to be checked is_peeking(bool): should be True if: @@ -68,12 +68,12 @@ def filter_events_for_client( events = list(e for e in events if not e.internal_metadata.is_soft_failed()) types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) - event_id_to_state = yield store.get_state_for_events( + event_id_to_state = yield storage.state.get_state_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types(types), ) - ignore_dict_content = yield store.get_global_account_data_by_type_for_user( + ignore_dict_content = yield storage.main.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id ) @@ -84,7 +84,7 @@ def filter_events_for_client( else [] ) - erased_senders = yield store.are_users_erased((e.sender for e in events)) + erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) def allowed(event): """ @@ -213,13 +213,17 @@ def filter_events_for_client( @defer.inlineCallbacks def filter_events_for_server( - store, server_name, events, redact=True, check_history_visibility_only=False + storage: Storage, + server_name, + events, + redact=True, + check_history_visibility_only=False, ): """Filter a list of events based on whether given server is allowed to see them. Args: - store (DataStore) + storage server_name (str) events (iterable[FrozenEvent]) redact (bool): Whether to return a redacted version of the event, or @@ -274,7 +278,7 @@ def filter_events_for_server( # Lets check to see if all the events have a history visibility # of "shared" or "world_readable". If thats the case then we don't # need to check membership (as we know the server is in the room). - event_to_state_ids = yield store.get_state_ids_for_events( + event_to_state_ids = yield storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( types=((EventTypes.RoomHistoryVisibility, ""),) @@ -292,14 +296,14 @@ def filter_events_for_server( if not visibility_ids: all_open = True else: - event_map = yield store.get_events(visibility_ids) + event_map = yield storage.main.get_events(visibility_ids) all_open = all( e.content.get("history_visibility") in (None, "shared", "world_readable") for e in itervalues(event_map) ) if not check_history_visibility_only: - erased_senders = yield store.are_users_erased((e.sender for e in events)) + erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) else: # We don't want to check whether users are erased, which is equivalent # to no users having been erased. @@ -328,7 +332,7 @@ def filter_events_for_server( # first, for each event we're wanting to return, get the event_ids # of the history vis and membership state at those events. - event_to_state_ids = yield store.get_state_ids_for_events( + event_to_state_ids = yield storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None)) @@ -358,7 +362,7 @@ def filter_events_for_server( return False return state_key[idx + 1 :] == server_name - event_map = yield store.get_events( + event_map = yield storage.main.get_events( [ e_id for e_id, key in iteritems(event_id_to_state_key) |