diff options
Diffstat (limited to 'synapse')
157 files changed, 5294 insertions, 4648 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index e8112d5f05..0c6c93a87b 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -64,6 +64,8 @@ class Auth(object): self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000) register_cache("cache", "token_cache", self.token_cache) + self._account_validity = hs.config.account_validity + @defer.inlineCallbacks def check_from_context(self, room_version, event, context, do_sig_check=True): prev_state_ids = yield context.get_prev_state_ids(self.store) @@ -226,6 +228,17 @@ class Auth(object): token_id = user_info["token_id"] is_guest = user_info["is_guest"] + # Deny the request if the user account has expired. + if self._account_validity.enabled: + user_id = user.to_string() + expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) + if expiration_ts is not None and self.clock.time_msec() >= expiration_ts: + raise AuthError( + 403, + "User account has expired", + errcode=Codes.EXPIRED_ACCOUNT, + ) + # device_id may not be present if get_user_by_access_token has been # stubbed out. device_id = user_info.get("device_id") @@ -543,7 +556,7 @@ class Auth(object): """ Check if the given user is a local server admin. Args: - user (str): mxid of user to check + user (UserID): user to check Returns: bool: True if the user is an admin diff --git a/synapse/api/constants.py b/synapse/api/constants.py index f47c33a074..0860b75905 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd. +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -69,6 +69,7 @@ class EventTypes(object): Redaction = "m.room.redaction" ThirdPartyInvite = "m.room.third_party_invite" Encryption = "m.room.encryption" + RelatedGroups = "m.room.related_groups" RoomHistoryVisibility = "m.room.history_visibility" CanonicalAlias = "m.room.canonical_alias" @@ -102,46 +103,6 @@ class ThirdPartyEntityKind(object): LOCATION = "location" -class RoomVersions(object): - V1 = "1" - V2 = "2" - V3 = "3" - STATE_V2_TEST = "state-v2-test" - - -class RoomDisposition(object): - STABLE = "stable" - UNSTABLE = "unstable" - - -# the version we will give rooms which are created on this server -DEFAULT_ROOM_VERSION = RoomVersions.V1 - -# vdh-test-version is a placeholder to get room versioning support working and tested -# until we have a working v2. -KNOWN_ROOM_VERSIONS = { - RoomVersions.V1, - RoomVersions.V2, - RoomVersions.V3, - RoomVersions.STATE_V2_TEST, - RoomVersions.V3, -} - - -class EventFormatVersions(object): - """This is an internal enum for tracking the version of the event format, - independently from the room version. - """ - V1 = 1 - V2 = 2 - - -KNOWN_EVENT_FORMAT_VERSIONS = { - EventFormatVersions.V1, - EventFormatVersions.V2, -} - - ServerNoticeMsgType = "m.server_notice" ServerNoticeLimitReached = "m.server_notice.usage_limit_reached" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 0b464834ce..ff89259dec 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd. +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -60,6 +60,7 @@ class Codes(object): UNSUPPORTED_ROOM_VERSION = "M_UNSUPPORTED_ROOM_VERSION" INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION" WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION" + EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT" class CodeMessageException(RuntimeError): diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py new file mode 100644 index 0000000000..e77abe1040 --- /dev/null +++ b/synapse/api/room_versions.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import attr + + +class EventFormatVersions(object): + """This is an internal enum for tracking the version of the event format, + independently from the room version. + """ + V1 = 1 # $id:server format + V2 = 2 # MSC1659-style $hash format: introduced for room v3 + + +KNOWN_EVENT_FORMAT_VERSIONS = { + EventFormatVersions.V1, + EventFormatVersions.V2, +} + + +class StateResolutionVersions(object): + """Enum to identify the state resolution algorithms""" + V1 = 1 # room v1 state res + V2 = 2 # MSC1442 state res: room v2 and later + + +class RoomDisposition(object): + STABLE = "stable" + UNSTABLE = "unstable" + + +@attr.s(slots=True, frozen=True) +class RoomVersion(object): + """An object which describes the unique attributes of a room version.""" + + identifier = attr.ib() # str; the identifier for this version + disposition = attr.ib() # str; one of the RoomDispositions + event_format = attr.ib() # int; one of the EventFormatVersions + state_res = attr.ib() # int; one of the StateResolutionVersions + + +class RoomVersions(object): + V1 = RoomVersion( + "1", + RoomDisposition.STABLE, + EventFormatVersions.V1, + StateResolutionVersions.V1, + ) + STATE_V2_TEST = RoomVersion( + "state-v2-test", + RoomDisposition.UNSTABLE, + EventFormatVersions.V1, + StateResolutionVersions.V2, + ) + V2 = RoomVersion( + "2", + RoomDisposition.STABLE, + EventFormatVersions.V1, + StateResolutionVersions.V2, + ) + V3 = RoomVersion( + "3", + RoomDisposition.STABLE, + EventFormatVersions.V2, + StateResolutionVersions.V2, + ) + + +# the version we will give rooms which are created on this server +DEFAULT_ROOM_VERSION = RoomVersions.V1 + + +KNOWN_ROOM_VERSIONS = { + v.identifier: v for v in ( + RoomVersions.V1, + RoomVersions.V2, + RoomVersions.V3, + RoomVersions.STATE_V2_TEST, + ) +} # type: dict[str, RoomVersion] diff --git a/synapse/api/urls.py b/synapse/api/urls.py index 8102176653..cb71d80875 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd. +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index beaea64a61..864f1eac48 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -45,6 +45,7 @@ from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.client.v1.login import LoginRestServlet +from synapse.rest.client.v1.push_rule import PushRuleRestServlet from synapse.rest.client.v1.room import ( JoinedRoomMemberListRestServlet, PublicRoomListRestServlet, @@ -52,9 +53,11 @@ from synapse.rest.client.v1.room import ( RoomMemberListRestServlet, RoomStateRestServlet, ) +from synapse.rest.client.v1.voip import VoipRestServlet from synapse.rest.client.v2_alpha.account import ThreepidRestServlet from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet from synapse.rest.client.v2_alpha.register import RegisterRestServlet +from synapse.rest.client.versions import VersionsRestServlet from synapse.server import HomeServer from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree @@ -109,12 +112,12 @@ class ClientReaderServer(HomeServer): ThreepidRestServlet(self).register(resource) KeyQueryServlet(self).register(resource) KeyChangesServlet(self).register(resource) + VoipRestServlet(self).register(resource) + PushRuleRestServlet(self).register(resource) + VersionsRestServlet().register(resource) resources.update({ - "/_matrix/client/r0": resource, - "/_matrix/client/unstable": resource, - "/_matrix/client/v2_alpha": resource, - "/_matrix/client/api/v1": resource, + "/_matrix/client": resource, }) root_resource = create_resource_tree(resources, NoResource()) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 9711a7147c..1d43f2b075 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -38,7 +38,7 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.replication.tcp.streams import ReceiptsStream +from synapse.replication.tcp.streams._base import ReceiptsStream from synapse.server import HomeServer from synapse.storage.engines import create_engine from synapse.types import ReadReceipt diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 869c028d1f..1045d28949 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -62,6 +62,7 @@ from synapse.python_dependencies import check_requirements from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.rest import ClientRestResource +from synapse.rest.admin import AdminRestResource from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.well_known import WellKnownResource @@ -180,6 +181,7 @@ class SynapseHomeServer(HomeServer): "/_matrix/client/v2_alpha": client_resource, "/_matrix/client/versions": client_resource, "/.well-known/matrix/client": WellKnownResource(self), + "/_synapse/admin": AdminRestResource(self), }) if self.get_config().saml2_enabled: @@ -518,6 +520,7 @@ def run(hs): uptime = 0 stats["homeserver"] = hs.config.server_name + stats["server_context"] = hs.config.server_context stats["timestamp"] = now stats["uptime_seconds"] = uptime version = sys.version_info @@ -558,7 +561,6 @@ def run(hs): stats["database_engine"] = hs.get_datastore().database_engine_name stats["database_server_version"] = hs.get_datastore().get_server_version() - logger.info("Reporting stats to matrix.org: %s" % (stats,)) try: yield hs.get_simple_http_client().put_json( diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 9163b56d86..5388def28a 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -48,6 +48,7 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.room import RoomStore from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.replication.tcp.streams.events import EventsStreamEventRow from synapse.rest.client.v1 import events from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet from synapse.rest.client.v1.room import RoomInitialSyncRestServlet @@ -369,7 +370,9 @@ class SyncReplicationHandler(ReplicationClientHandler): # We shouldn't get multiple rows per token for events stream, so # we don't need to optimise this for multiple rows. for row in rows: - event = yield self.store.get_event(row.event_id) + if row.type != EventsStreamEventRow.TypeId: + continue + event = yield self.store.get_event(row.data.event_id) extra_users = () if event.type == EventTypes.Member: extra_users = (event.state_key,) diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index d1ab9512cd..355f5aa71d 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -36,6 +36,10 @@ from synapse.replication.slave.storage.client_ips import SlavedClientIpStore from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.replication.tcp.streams.events import ( + EventsStream, + EventsStreamCurrentStateRow, +) from synapse.rest.client.v2_alpha import user_directory from synapse.server import HomeServer from synapse.storage.engines import create_engine @@ -73,19 +77,18 @@ class UserDirectorySlaveStore( prefilled_cache=curr_state_delta_prefill, ) - self._current_state_delta_pos = events_max - def stream_positions(self): result = super(UserDirectorySlaveStore, self).stream_positions() - result["current_state_deltas"] = self._current_state_delta_pos return result def process_replication_rows(self, stream_name, token, rows): - if stream_name == "current_state_deltas": - self._current_state_delta_pos = token + if stream_name == EventsStream.NAME: + self._stream_id_gen.advance(token) for row in rows: + if row.type != EventsStreamCurrentStateRow.TypeId: + continue self._curr_state_delta_stream_cache.entity_has_changed( - row.room_id, token + row.data.room_id, token ) return super(UserDirectorySlaveStore, self).process_replication_rows( stream_name, token, rows @@ -170,7 +173,7 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler): yield super(UserDirectoryReplicationHandler, self).on_rdata( stream_name, token, rows ) - if stream_name == "current_state_deltas": + if stream_name == EventsStream.NAME: run_in_background(self._notify_directory) @defer.inlineCallbacks diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 93d70cff14..342a6ce5fd 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -71,6 +71,12 @@ class EmailConfig(Config): self.email_notif_from = email_config["notif_from"] self.email_notif_template_html = email_config["notif_template_html"] self.email_notif_template_text = email_config["notif_template_text"] + self.email_expiry_template_html = email_config.get( + "expiry_template_html", "notice_expiry.html", + ) + self.email_expiry_template_text = email_config.get( + "expiry_template_text", "notice_expiry.txt", + ) template_dir = email_config.get("template_dir") # we need an absolute path, because we change directory after starting (and @@ -120,7 +126,7 @@ class EmailConfig(Config): def default_config(self, config_dir_path, server_name, **kwargs): return """ - # Enable sending emails for notification events + # Enable sending emails for notification events or expiry notices # Defining a custom URL for Riot is only needed if email notifications # should contain links to a self-hosted installation of Riot; when set # the "app_name" setting is ignored. @@ -142,6 +148,9 @@ class EmailConfig(Config): # #template_dir: res/templates # notif_template_html: notif_mail.html # notif_template_text: notif_mail.txt + # # Templates for account expiry notices. + # expiry_template_html: notice_expiry.html + # expiry_template_text: notice_expiry.txt # notif_for_new_users: True # riot_base_url: "http://localhost/riot" """ diff --git a/synapse/config/key.py b/synapse/config/key.py index 933928885a..eb10259818 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -42,7 +42,8 @@ class KeyConfig(Config): if "signing_key" in config: self.signing_key = read_signing_keys([config["signing_key"]]) else: - self.signing_key = self.read_signing_key(config["signing_key_path"]) + self.signing_key_path = config["signing_key_path"] + self.signing_key = self.read_signing_key(self.signing_key_path) self.old_signing_keys = self.read_old_signing_keys( config.get("old_signing_keys", {}) diff --git a/synapse/config/registration.py b/synapse/config/registration.py index f6b2b9ceee..1309bce3ee 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -20,6 +20,29 @@ from synapse.types import RoomAlias from synapse.util.stringutils import random_string_with_symbols +class AccountValidityConfig(Config): + def __init__(self, config, synapse_config): + self.enabled = config.get("enabled", False) + self.renew_by_email_enabled = ("renew_at" in config) + + if self.enabled: + if "period" in config: + self.period = self.parse_duration(config["period"]) + else: + raise ConfigError("'period' is required when using account validity") + + if "renew_at" in config: + self.renew_at = self.parse_duration(config["renew_at"]) + + if "renew_email_subject" in config: + self.renew_email_subject = config["renew_email_subject"] + else: + self.renew_email_subject = "Renew your %(app)s account" + + if self.renew_by_email_enabled and "public_baseurl" not in synapse_config: + raise ConfigError("Can't send renewal emails without 'public_baseurl'") + + class RegistrationConfig(Config): def read_config(self, config): @@ -31,8 +54,13 @@ class RegistrationConfig(Config): strtobool(str(config["disable_registration"])) ) + self.account_validity = AccountValidityConfig( + config.get("account_validity", {}), config, + ) + self.registrations_require_3pid = config.get("registrations_require_3pid", []) self.allowed_local_3pids = config.get("allowed_local_3pids", []) + self.enable_3pid_lookup = config.get("enable_3pid_lookup", True) self.registration_shared_secret = config.get("registration_shared_secret") self.bcrypt_rounds = config.get("bcrypt_rounds", 12) @@ -75,6 +103,32 @@ class RegistrationConfig(Config): # #enable_registration: false + # Optional account validity configuration. This allows for accounts to be denied + # any request after a given period. + # + # ``enabled`` defines whether the account validity feature is enabled. Defaults + # to False. + # + # ``period`` allows setting the period after which an account is valid + # after its registration. When renewing the account, its validity period + # will be extended by this amount of time. This parameter is required when using + # the account validity feature. + # + # ``renew_at`` is the amount of time before an account's expiry date at which + # Synapse will send an email to the account's email address with a renewal link. + # This needs the ``email`` and ``public_baseurl`` configuration sections to be + # filled. + # + # ``renew_email_subject`` is the subject of the email sent out with the renewal + # link. ``%%(app)s`` can be used as a placeholder for the ``app_name`` parameter + # from the ``email`` section. + # + #account_validity: + # enabled: True + # period: 6w + # renew_at: 1w + # renew_email_subject: "Renew your %%(app)s account" + # The user must provide all of the below types of 3PID when registering. # #registrations_require_3pid: @@ -97,6 +151,10 @@ class RegistrationConfig(Config): # - medium: msisdn # pattern: '\\+44' + # Enable 3PIDs lookup requests to identity servers from this server. + # + #enable_3pid_lookup: true + # If set, allows registration of standard or admin accounts by anyone who # has the shared secret, even if registration is otherwise disabled. # diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index 39b9eb29c2..aa6eac271f 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd. +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/config/server.py b/synapse/config/server.py index 08e4e45482..147a976485 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -37,6 +37,7 @@ class ServerConfig(Config): def read_config(self, config): self.server_name = config["server_name"] + self.server_context = config.get("server_context", None) try: parse_and_validate_server_name(self.server_name) @@ -113,11 +114,13 @@ class ServerConfig(Config): # FIXME: federation_domain_whitelist needs sytests self.federation_domain_whitelist = None federation_domain_whitelist = config.get( - "federation_domain_whitelist", None + "federation_domain_whitelist", None, ) - # turn the whitelist into a hash for speed of lookup + if federation_domain_whitelist is not None: + # turn the whitelist into a hash for speed of lookup self.federation_domain_whitelist = {} + for domain in federation_domain_whitelist: self.federation_domain_whitelist[domain] = True @@ -131,6 +134,12 @@ class ServerConfig(Config): # sending out any replication updates. self.replication_torture_level = config.get("replication_torture_level") + # Whether to require a user to be in the room to add an alias to it. + # Defaults to True. + self.require_membership_for_aliases = config.get( + "require_membership_for_aliases", True, + ) + self.listeners = [] for listener in config.get("listeners", []): if not isinstance(listener.get("port", None), int): @@ -385,8 +394,8 @@ class ServerConfig(Config): # # Valid resource names are: # - # client: the client-server API (/_matrix/client). Also implies 'media' and - # 'static'. + # client: the client-server API (/_matrix/client), and the synapse admin + # API (/_synapse/admin). Also implies 'media' and 'static'. # # consent: user consent forms (/_matrix/consent). See # docs/consent_tracking.md. @@ -484,6 +493,14 @@ class ServerConfig(Config): #mau_limit_reserved_threepids: # - medium: 'email' # address: 'reserved_user@example.com' + + # Used by phonehome stats to group together related servers. + #server_context: context + + # Whether to require a user to be in the room to add an alias to it. + # Defaults to 'true'. + # + #require_membership_for_aliases: false """ % locals() def read_arguments(self, args): diff --git a/synapse/config/tls.py b/synapse/config/tls.py index f0014902da..72dd5926f9 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -24,8 +24,10 @@ import six from unpaddedbase64 import encode_base64 from OpenSSL import crypto +from twisted.internet._sslverify import Certificate, trustRootFromCertificates from synapse.config._base import Config, ConfigError +from synapse.util import glob_to_regex logger = logging.getLogger(__name__) @@ -70,6 +72,53 @@ class TlsConfig(Config): self.tls_fingerprints = list(self._original_tls_fingerprints) + # Whether to verify certificates on outbound federation traffic + self.federation_verify_certificates = config.get( + "federation_verify_certificates", False, + ) + + # Whitelist of domains to not verify certificates for + fed_whitelist_entries = config.get( + "federation_certificate_verification_whitelist", [], + ) + + # Support globs (*) in whitelist values + self.federation_certificate_verification_whitelist = [] + for entry in fed_whitelist_entries: + # Convert globs to regex + entry_regex = glob_to_regex(entry) + self.federation_certificate_verification_whitelist.append(entry_regex) + + # List of custom certificate authorities for federation traffic validation + custom_ca_list = config.get( + "federation_custom_ca_list", None, + ) + + # Read in and parse custom CA certificates + self.federation_ca_trust_root = None + if custom_ca_list is not None: + if len(custom_ca_list) == 0: + # A trustroot cannot be generated without any CA certificates. + # Raise an error if this option has been specified without any + # corresponding certificates. + raise ConfigError("federation_custom_ca_list specified without " + "any certificate files") + + certs = [] + for ca_file in custom_ca_list: + logger.debug("Reading custom CA certificate file: %s", ca_file) + content = self.read_file(ca_file) + + # Parse the CA certificates + try: + cert_base = Certificate.loadPEM(content) + certs.append(cert_base) + except Exception as e: + raise ConfigError("Error parsing custom CA certificate file %s: %s" + % (ca_file, e)) + + self.federation_ca_trust_root = trustRootFromCertificates(certs) + # This config option applies to non-federation HTTP clients # (e.g. for talking to recaptcha, identity servers, and such) # It should never be used in production, and is intended for @@ -99,15 +148,15 @@ class TlsConfig(Config): try: with open(self.tls_certificate_file, 'rb') as f: cert_pem = f.read() - except Exception: - logger.exception("Failed to read existing certificate off disk!") - raise + except Exception as e: + raise ConfigError("Failed to read existing certificate file %s: %s" + % (self.tls_certificate_file, e)) try: tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem) - except Exception: - logger.exception("Failed to parse existing certificate off disk!") - raise + except Exception as e: + raise ConfigError("Failed to parse existing certificate file %s: %s" + % (self.tls_certificate_file, e)) if not allow_self_signed: if tls_certificate.get_subject() == tls_certificate.get_issuer(): @@ -192,6 +241,40 @@ class TlsConfig(Config): # #tls_private_key_path: "%(tls_private_key_path)s" + # Whether to verify TLS certificates when sending federation traffic. + # + # This currently defaults to `false`, however this will change in + # Synapse 1.0 when valid federation certificates will be required. + # + #federation_verify_certificates: true + + # Skip federation certificate verification on the following whitelist + # of domains. + # + # This setting should only be used in very specific cases, such as + # federation over Tor hidden services and similar. For private networks + # of homeservers, you likely want to use a private CA instead. + # + # Only effective if federation_verify_certicates is `true`. + # + #federation_certificate_verification_whitelist: + # - lon.example.com + # - *.domain.com + # - *.onion + + # List of custom certificate authorities for federation traffic. + # + # This setting should only normally be used within a private network of + # homeservers. + # + # Note that this list will replace those that are provided by your + # operating environment. Certificates must be in PEM format. + # + #federation_custom_ca_list: + # - myCA1.pem + # - myCA2.pem + # - myCA3.pem + # ACME support: This will configure Synapse to request a valid TLS certificate # for your configured `server_name` via Let's Encrypt. # diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 49cbc7098f..59ea087e66 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -18,10 +18,10 @@ import logging from zope.interface import implementer from OpenSSL import SSL, crypto -from twisted.internet._sslverify import _defaultCurveName +from twisted.internet._sslverify import ClientTLSOptions, _defaultCurveName from twisted.internet.abstract import isIPAddress, isIPv6Address from twisted.internet.interfaces import IOpenSSLClientConnectionCreator -from twisted.internet.ssl import CertificateOptions, ContextFactory +from twisted.internet.ssl import CertificateOptions, ContextFactory, platformTrust from twisted.python.failure import Failure logger = logging.getLogger(__name__) @@ -90,7 +90,7 @@ def _tolerateErrors(wrapped): @implementer(IOpenSSLClientConnectionCreator) -class ClientTLSOptions(object): +class ClientTLSOptionsNoVerify(object): """ Client creator for TLS without certificate identity verification. This is a copy of twisted.internet._sslverify.ClientTLSOptions with the identity @@ -127,9 +127,30 @@ class ClientTLSOptionsFactory(object): to remote servers for federation.""" def __init__(self, config): - # We don't use config options yet - self._options = CertificateOptions(verify=False) + self._config = config + self._options_noverify = CertificateOptions() + + # Check if we're using a custom list of a CA certificates + trust_root = config.federation_ca_trust_root + if trust_root is None: + # Use CA root certs provided by OpenSSL + trust_root = platformTrust() + + self._options_verify = CertificateOptions(trustRoot=trust_root) def get_options(self, host): # Use _makeContext so that we get a fresh OpenSSL CTX each time. - return ClientTLSOptions(host, self._options._makeContext()) + + # Check if certificate verification has been enabled + should_verify = self._config.federation_verify_certificates + + # Check if we've disabled certificate verification for this host + if should_verify: + for regex in self._config.federation_certificate_verification_whitelist: + if regex.match(host): + should_verify = False + break + + if should_verify: + return ClientTLSOptions(host, self._options_verify._makeContext()) + return ClientTLSOptionsNoVerify(host, self._options_noverify._makeContext()) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 0207cd989a..d8ba870cca 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2017, 2018 New Vector Ltd. +# Copyright 2017, 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ from collections import namedtuple from six import raise_from from six.moves import urllib +import nacl.signing from signedjson.key import ( decode_verify_key_bytes, encode_verify_key_base64, @@ -113,40 +114,54 @@ class Keyring(object): server_name. The deferreds run their callbacks in the sentinel logcontext. """ + # a list of VerifyKeyRequests verify_requests = [] + handle = preserve_fn(_handle_key_deferred) - for server_name, json_object in server_and_json: + def process(server_name, json_object): + """Process an entry in the request list + Given a (server_name, json_object) pair from the request list, + adds a key request to verify_requests, and returns a deferred which will + complete or fail (in the sentinel context) when verification completes. + """ key_ids = signature_ids(json_object, server_name) + if not key_ids: - logger.warn("Request from %s: no supported signature keys", - server_name) - deferred = defer.fail(SynapseError( - 400, - "Not signed with a supported algorithm", - Codes.UNAUTHORIZED, - )) - else: - deferred = defer.Deferred() + return defer.fail( + SynapseError( + 400, + "Not signed by %s" % (server_name,), + Codes.UNAUTHORIZED, + ) + ) logger.debug("Verifying for %s with key_ids %s", server_name, key_ids) + # add the key request to the queue, but don't start it off yet. verify_request = VerifyKeyRequest( - server_name, key_ids, json_object, deferred + server_name, key_ids, json_object, defer.Deferred(), ) - verify_requests.append(verify_request) - run_in_background(self._start_key_lookups, verify_requests) + # now run _handle_key_deferred, which will wait for the key request + # to complete and then do the verification. + # + # We want _handle_key_request to log to the right context, so we + # wrap it with preserve_fn (aka run_in_background) + return handle(verify_request) - # Pass those keys to handle_key_deferred so that the json object - # signatures can be verified - handle = preserve_fn(_handle_key_deferred) - return [ - handle(rq) for rq in verify_requests + results = [ + process(server_name, json_object) + for server_name, json_object in server_and_json ] + if verify_requests: + run_in_background(self._start_key_lookups, verify_requests) + + return results + @defer.inlineCallbacks def _start_key_lookups(self, verify_requests): """Sets off the key fetches for each verify request @@ -274,10 +289,6 @@ class Keyring(object): @defer.inlineCallbacks def do_iterations(): with Measure(self.clock, "get_server_verify_keys"): - # dict[str, dict[str, VerifyKey]]: results so far. - # map server_name -> key_id -> VerifyKey - merged_results = {} - # dict[str, set(str)]: keys to fetch for each server missing_keys = {} for verify_request in verify_requests: @@ -287,29 +298,29 @@ class Keyring(object): for fn in key_fetch_fns: results = yield fn(missing_keys.items()) - merged_results.update(results) # We now need to figure out which verify requests we have keys # for and which we don't missing_keys = {} requests_missing_keys = [] for verify_request in verify_requests: - server_name = verify_request.server_name - result_keys = merged_results[server_name] - if verify_request.deferred.called: # We've already called this deferred, which probably # means that we've already found a key for it. continue + server_name = verify_request.server_name + + # see if any of the keys we got this time are sufficient to + # complete this VerifyKeyRequest. + result_keys = results.get(server_name, {}) for key_id in verify_request.key_ids: - if key_id in result_keys: + key = result_keys.get(key_id) + if key: with PreserveLoggingContext(): - verify_request.deferred.callback(( - server_name, - key_id, - result_keys[key_id], - )) + verify_request.deferred.callback( + (server_name, key_id, key) + ) break else: # The else block is only reached if the loop above @@ -343,27 +354,24 @@ class Keyring(object): @defer.inlineCallbacks def get_keys_from_store(self, server_name_and_key_ids): """ - Args: - server_name_and_key_ids (list[(str, iterable[str])]): + server_name_and_key_ids (iterable(Tuple[str, iterable[str]]): list of (server_name, iterable[key_id]) tuples to fetch keys for Returns: - Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from + Deferred: resolves to dict[str, dict[str, VerifyKey|None]]: map from server_name -> key_id -> VerifyKey """ - res = yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - run_in_background( - self.store.get_server_verify_keys, - server_name, key_ids, - ).addCallback(lambda ks, server: (server, ks), server_name) - for server_name, key_ids in server_name_and_key_ids - ], - consumeErrors=True, - ).addErrback(unwrapFirstError)) - - defer.returnValue(dict(res)) + keys_to_fetch = ( + (server_name, key_id) + for server_name, key_ids in server_name_and_key_ids + for key_id in key_ids + ) + res = yield self.store.get_server_verify_keys(keys_to_fetch) + keys = {} + for (server_name, key_id), key in res.items(): + keys.setdefault(server_name, {})[key_id] = key + defer.returnValue(keys) @defer.inlineCallbacks def get_keys_from_perspectives(self, server_name_and_key_ids): @@ -494,11 +502,11 @@ class Keyring(object): ) processed_response = yield self.process_v2_response( - perspective_name, response, only_from_server=False + perspective_name, response ) + server_name = response["server_name"] - for server_name, response_keys in processed_response.items(): - keys.setdefault(server_name, {}).update(response_keys) + keys.setdefault(server_name, {}).update(processed_response) yield logcontext.make_deferred_yieldable(defer.gatherResults( [ @@ -517,7 +525,7 @@ class Keyring(object): @defer.inlineCallbacks def get_server_verify_key_v2_direct(self, server_name, key_ids): - keys = {} + keys = {} # type: dict[str, nacl.signing.VerifyKey] for requested_key_id in key_ids: if requested_key_id in keys: @@ -542,6 +550,11 @@ class Keyring(object): or server_name not in response[u"signatures"]): raise KeyLookupError("Key response not signed by remote server") + if response["server_name"] != server_name: + raise KeyLookupError("Expected a response for server %r not %r" % ( + server_name, response["server_name"] + )) + response_keys = yield self.process_v2_response( from_server=server_name, requested_ids=[requested_key_id], @@ -550,24 +563,45 @@ class Keyring(object): keys.update(response_keys) - yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - run_in_background( - self.store_keys, - server_name=key_server_name, - from_server=server_name, - verify_keys=verify_keys, - ) - for key_server_name, verify_keys in keys.items() - ], - consumeErrors=True - ).addErrback(unwrapFirstError)) - - defer.returnValue(keys) + yield self.store_keys( + server_name=server_name, + from_server=server_name, + verify_keys=keys, + ) + defer.returnValue({server_name: keys}) @defer.inlineCallbacks - def process_v2_response(self, from_server, response_json, - requested_ids=[], only_from_server=True): + def process_v2_response( + self, from_server, response_json, requested_ids=[], + ): + """Parse a 'Server Keys' structure from the result of a /key request + + This is used to parse either the entirety of the response from + GET /_matrix/key/v2/server, or a single entry from the list returned by + POST /_matrix/key/v2/query. + + Checks that each signature in the response that claims to come from the origin + server is valid. (Does not check that there actually is such a signature, for + some reason.) + + Stores the json in server_keys_json so that it can be used for future responses + to /_matrix/key/v2/query. + + Args: + from_server (str): the name of the server producing this result: either + the origin server for a /_matrix/key/v2/server request, or the notary + for a /_matrix/key/v2/query. + + response_json (dict): the json-decoded Server Keys response object + + requested_ids (iterable[str]): a list of the key IDs that were requested. + We will store the json for these key ids as well as any that are + actually in the response + + Returns: + Deferred[dict[str, nacl.signing.VerifyKey]]: + map from key_id to key object + """ time_now_ms = self.clock.time_msec() response_keys = {} verify_keys = {} @@ -589,15 +623,7 @@ class Keyring(object): verify_key.time_added = time_now_ms old_verify_keys[key_id] = verify_key - results = {} server_name = response_json["server_name"] - if only_from_server: - if server_name != from_server: - raise KeyLookupError( - "Expected a response for server %r not %r" % ( - from_server, server_name - ) - ) for key_id in response_json["signatures"].get(server_name, {}): if key_id not in response_json["verify_keys"]: raise KeyLookupError( @@ -633,7 +659,7 @@ class Keyring(object): self.store.store_server_keys_json, server_name=server_name, key_id=key_id, - from_server=server_name, + from_server=from_server, ts_now_ms=time_now_ms, ts_expires_ms=ts_valid_until_ms, key_json_bytes=signed_key_json_bytes, @@ -643,9 +669,7 @@ class Keyring(object): consumeErrors=True, ).addErrback(unwrapFirstError)) - results[server_name] = response_keys - - defer.returnValue(results) + defer.returnValue(response_keys) def store_keys(self, server_name, from_server, verify_keys): """Store a collection of verify keys for a given server diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 8f9e330da5..203490fc36 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -20,15 +20,9 @@ from signedjson.key import decode_verify_key_bytes from signedjson.sign import SignatureVerifyException, verify_signed_json from unpaddedbase64 import decode_base64 -from synapse.api.constants import ( - KNOWN_ROOM_VERSIONS, - EventFormatVersions, - EventTypes, - JoinRules, - Membership, - RoomVersions, -) +from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.errors import AuthError, EventSizeError, SynapseError +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions from synapse.types import UserID, get_domain_from_id logger = logging.getLogger(__name__) @@ -452,16 +446,18 @@ def check_redaction(room_version, event, auth_events): if user_level >= redact_level: return False - if room_version in (RoomVersions.V1, RoomVersions.V2,): + v = KNOWN_ROOM_VERSIONS.get(room_version) + if not v: + raise RuntimeError("Unrecognized room version %r" % (room_version,)) + + if v.event_format == EventFormatVersions.V1: redacter_domain = get_domain_from_id(event.event_id) redactee_domain = get_domain_from_id(event.redacts) if redacter_domain == redactee_domain: return True - elif room_version == RoomVersions.V3: + else: event.internal_metadata.recheck_redaction = True return True - else: - raise RuntimeError("Unrecognized room version %r" % (room_version,)) raise AuthError( 403, diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index fafa135182..12056d5be2 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -21,7 +21,7 @@ import six from unpaddedbase64 import encode_base64 -from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventFormatVersions, RoomVersions +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze @@ -351,18 +351,13 @@ def room_version_to_event_format(room_version): Returns: int """ - if room_version not in KNOWN_ROOM_VERSIONS: + v = KNOWN_ROOM_VERSIONS.get(room_version) + + if not v: # We should have already checked version, so this should not happen raise RuntimeError("Unrecognized room version %s" % (room_version,)) - if room_version in ( - RoomVersions.V1, RoomVersions.V2, RoomVersions.STATE_V2_TEST, - ): - return EventFormatVersions.V1 - elif room_version in (RoomVersions.V3,): - return EventFormatVersions.V2 - else: - raise RuntimeError("Unrecognized room version %s" % (room_version,)) + return v.event_format def event_type_from_format_version(format_version): diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 06e01be918..fba27177c7 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -17,21 +17,17 @@ import attr from twisted.internet import defer -from synapse.api.constants import ( +from synapse.api.constants import MAX_DEPTH +from synapse.api.room_versions import ( KNOWN_EVENT_FORMAT_VERSIONS, KNOWN_ROOM_VERSIONS, - MAX_DEPTH, EventFormatVersions, ) from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.types import EventID from synapse.util.stringutils import random_string -from . import ( - _EventInternalMetadata, - event_type_from_format_version, - room_version_to_event_format, -) +from . import _EventInternalMetadata, event_type_from_format_version @attr.s(slots=True, cmp=False, frozen=True) @@ -170,21 +166,34 @@ class EventBuilderFactory(object): def new(self, room_version, key_values): """Generate an event builder appropriate for the given room version + Deprecated: use for_room_version with a RoomVersion object instead + Args: - room_version (str): Version of the room that we're creating an - event builder for + room_version (str): Version of the room that we're creating an event builder + for key_values (dict): Fields used as the basis of the new event Returns: EventBuilder """ - - # There's currently only the one event version defined - if room_version not in KNOWN_ROOM_VERSIONS: + v = KNOWN_ROOM_VERSIONS.get(room_version) + if not v: raise Exception( "No event format defined for version %r" % (room_version,) ) + return self.for_room_version(v, key_values) + def for_room_version(self, room_version, key_values): + """Generate an event builder appropriate for the given room version + + Args: + room_version (synapse.api.room_versions.RoomVersion): + Version of the room that we're creating an event builder for + key_values (dict): Fields used as the basis of the new event + + Returns: + EventBuilder + """ return EventBuilder( store=self.store, state=self.state, @@ -192,7 +201,7 @@ class EventBuilderFactory(object): clock=self.clock, hostname=self.hostname, signing_key=self.signing_key, - format_version=room_version_to_event_format(room_version), + format_version=room_version.event_format, type=key_values["type"], state_key=key_values.get("state_key"), room_id=key_values["room_id"], @@ -222,7 +231,6 @@ def create_local_event_from_event_dict(clock, hostname, signing_key, FrozenEvent """ - # There's currently only the one event version defined if format_version not in KNOWN_EVENT_FORMAT_VERSIONS: raise Exception( "No event format defined for version %r" % (format_version,) diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 633e068eb8..6058077f75 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2017 New Vector Ltd. +# Copyright 2017 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/events/validator.py b/synapse/events/validator.py index a072674b02..514273c792 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -15,8 +15,9 @@ from six import string_types -from synapse.api.constants import EventFormatVersions, EventTypes, Membership +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError +from synapse.api.room_versions import EventFormatVersions from synapse.types import EventID, RoomID, UserID diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index a7a2ec4523..cffa831d80 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -20,8 +20,9 @@ import six from twisted.internet import defer from twisted.internet.defer import DeferredList -from synapse.api.constants import MAX_DEPTH, EventTypes, Membership, RoomVersions +from synapse.api.constants import MAX_DEPTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions from synapse.crypto.event_signing import check_event_content_hash from synapse.events import event_type_from_format_version from synapse.events.utils import prune_event @@ -268,15 +269,29 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): for p in pdus_to_check_sender ]) + def sender_err(e, pdu_to_check): + errmsg = "event id %s: unable to verify signature for sender %s: %s" % ( + pdu_to_check.pdu.event_id, + pdu_to_check.sender_domain, + e.getErrorMessage(), + ) + # XX not really sure if these are the right codes, but they are what + # we've done for ages + raise SynapseError(400, errmsg, Codes.UNAUTHORIZED) + for p, d in zip(pdus_to_check_sender, more_deferreds): + d.addErrback(sender_err, p) p.deferreds.append(d) # now let's look for events where the sender's domain is different to the # event id's domain (normally only the case for joins/leaves), and add additional # checks. Only do this if the room version has a concept of event ID domain - if room_version in ( - RoomVersions.V1, RoomVersions.V2, RoomVersions.STATE_V2_TEST, - ): + # (ie, the room version uses old-style non-hash event IDs). + v = KNOWN_ROOM_VERSIONS.get(room_version) + if not v: + raise RuntimeError("Unrecognized room version %s" % (room_version,)) + + if v.event_format == EventFormatVersions.V1: pdus_to_check_event_id = [ p for p in pdus_to_check if p.sender_domain != get_domain_from_id(p.pdu.event_id) @@ -287,12 +302,19 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): for p in pdus_to_check_event_id ]) + def event_err(e, pdu_to_check): + errmsg = ( + "event id %s: unable to verify signature for event id domain: %s" % ( + pdu_to_check.pdu.event_id, + e.getErrorMessage(), + ) + ) + # XX as above: not really sure if these are the right codes + raise SynapseError(400, errmsg, Codes.UNAUTHORIZED) + for p, d in zip(pdus_to_check_event_id, more_deferreds): + d.addErrback(event_err, p) p.deferreds.append(d) - elif room_version in (RoomVersions.V3,): - pass # No further checks needed, as event IDs are hashes here - else: - raise RuntimeError("Unrecognized room version %s" % (room_version,)) # replace lists of deferreds with single Deferreds return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check] diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 58e04d81ab..f3fc897a0a 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -25,12 +25,7 @@ from prometheus_client import Counter from twisted.internet import defer -from synapse.api.constants import ( - KNOWN_ROOM_VERSIONS, - EventTypes, - Membership, - RoomVersions, -) +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( CodeMessageException, Codes, @@ -38,6 +33,11 @@ from synapse.api.errors import ( HttpResponseException, SynapseError, ) +from synapse.api.room_versions import ( + KNOWN_ROOM_VERSIONS, + EventFormatVersions, + RoomVersions, +) from synapse.events import builder, room_version_to_event_format from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.util import logcontext, unwrapFirstError @@ -570,7 +570,7 @@ class FederationClient(FederationBase): Deferred[tuple[str, FrozenEvent, int]]: resolves to a tuple of `(origin, event, event_format)` where origin is the remote homeserver which generated the event, and event_format is one of - `synapse.api.constants.EventFormatVersions`. + `synapse.api.room_versions.EventFormatVersions`. Fails with a ``SynapseError`` if the chosen remote server returns a 300/400 code. @@ -592,7 +592,7 @@ class FederationClient(FederationBase): # Note: If not supplied, the room version may be either v1 or v2, # however either way the event format version will be v1. - room_version = ret.get("room_version", RoomVersions.V1) + room_version = ret.get("room_version", RoomVersions.V1.identifier) event_format = room_version_to_event_format(room_version) pdu_dict = ret.get("event", None) @@ -695,7 +695,9 @@ class FederationClient(FederationBase): room_version = None for e in state: if (e.type, e.state_key) == (EventTypes.Create, ""): - room_version = e.content.get("room_version", RoomVersions.V1) + room_version = e.content.get( + "room_version", RoomVersions.V1.identifier + ) break if room_version is None: @@ -802,11 +804,10 @@ class FederationClient(FederationBase): raise err # Otherwise, we assume that the remote server doesn't understand - # the v2 invite API. - - if room_version in (RoomVersions.V1, RoomVersions.V2): - pass # We'll fall through - else: + # the v2 invite API. That's ok provided the room uses old-style event + # IDs. + v = KNOWN_ROOM_VERSIONS.get(room_version) + if v.event_format != EventFormatVersions.V1: raise SynapseError( 400, "User's homeserver does not support this room version", diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 81f3b4b1ff..df60828dba 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -25,7 +25,7 @@ from twisted.internet import defer from twisted.internet.abstract import isIPAddress from twisted.python import failure -from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventTypes, Membership +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, @@ -34,6 +34,7 @@ from synapse.api.errors import ( NotFoundError, SynapseError, ) +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.crypto.event_signing import compute_event_signature from synapse.events import room_version_to_event_format from synapse.federation.federation_base import FederationBase, event_from_pdu_json diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 04d04a4457..0240b339b0 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -55,7 +55,12 @@ class FederationRemoteSendQueue(object): self.is_mine_id = hs.is_mine_id self.presence_map = {} # Pending presence map user_id -> UserPresenceState - self.presence_changed = SortedDict() # Stream position -> user_id + self.presence_changed = SortedDict() # Stream position -> list[user_id] + + # Stores the destinations we need to explicitly send presence to about a + # given user. + # Stream position -> (user_id, destinations) + self.presence_destinations = SortedDict() self.keyed_edu = {} # (destination, key) -> EDU self.keyed_edu_changed = SortedDict() # stream position -> (destination, key) @@ -77,7 +82,7 @@ class FederationRemoteSendQueue(object): for queue_name in [ "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed", - "edus", "device_messages", "pos_time", + "edus", "device_messages", "pos_time", "presence_destinations", ]: register(queue_name, getattr(self, queue_name)) @@ -121,6 +126,15 @@ class FederationRemoteSendQueue(object): for user_id in uids ) + keys = self.presence_destinations.keys() + i = self.presence_destinations.bisect_left(position_to_delete) + for key in keys[:i]: + del self.presence_destinations[key] + + user_ids.update( + user_id for user_id, _ in self.presence_destinations.values() + ) + to_del = [ user_id for user_id in self.presence_map if user_id not in user_ids ] @@ -209,6 +223,20 @@ class FederationRemoteSendQueue(object): self.notifier.on_new_replication_data() + def send_presence_to_destinations(self, states, destinations): + """As per FederationSender + + Args: + states (list[UserPresenceState]) + destinations (list[str]) + """ + for state in states: + pos = self._next_pos() + self.presence_map.update({state.user_id: state for state in states}) + self.presence_destinations[pos] = (state.user_id, destinations) + + self.notifier.on_new_replication_data() + def send_device_messages(self, destination): """As per FederationSender""" pos = self._next_pos() @@ -261,6 +289,16 @@ class FederationRemoteSendQueue(object): state=self.presence_map[user_id], ))) + # Fetch presence to send to destinations + i = self.presence_destinations.bisect_right(from_token) + j = self.presence_destinations.bisect_right(to_token) + 1 + + for pos, (user_id, dests) in self.presence_destinations.items()[i:j]: + rows.append((pos, PresenceDestinationsRow( + state=self.presence_map[user_id], + destinations=list(dests), + ))) + # Fetch changes keyed edus i = self.keyed_edu_changed.bisect_right(from_token) j = self.keyed_edu_changed.bisect_right(to_token) + 1 @@ -357,6 +395,29 @@ class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", ( buff.presence.append(self.state) +class PresenceDestinationsRow(BaseFederationRow, namedtuple("PresenceDestinationsRow", ( + "state", # UserPresenceState + "destinations", # list[str] +))): + TypeId = "pd" + + @staticmethod + def from_data(data): + return PresenceDestinationsRow( + state=UserPresenceState.from_dict(data["state"]), + destinations=data["dests"], + ) + + def to_data(self): + return { + "state": self.state.as_dict(), + "dests": self.destinations, + } + + def add_to_buffer(self, buff): + buff.presence_destinations.append((self.state, self.destinations)) + + class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", ( "key", # tuple(str) - the edu key passed to send_edu "edu", # Edu @@ -428,6 +489,7 @@ TypeToRow = { Row.TypeId: Row for Row in ( PresenceRow, + PresenceDestinationsRow, KeyedEduRow, EduRow, DeviceRow, @@ -437,6 +499,7 @@ TypeToRow = { ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", ( "presence", # list(UserPresenceState) + "presence_destinations", # list of tuples of UserPresenceState and destinations "keyed_edus", # dict of destination -> { key -> Edu } "edus", # dict of destination -> [Edu] "device_destinations", # set of destinations @@ -458,6 +521,7 @@ def process_rows_for_federation(transaction_queue, rows): buff = ParsedFederationStreamData( presence=[], + presence_destinations=[], keyed_edus={}, edus={}, device_destinations=set(), @@ -476,6 +540,11 @@ def process_rows_for_federation(transaction_queue, rows): if buff.presence: transaction_queue.send_presence(buff.presence) + for state, destinations in buff.presence_destinations: + transaction_queue.send_presence_to_destinations( + states=[state], destinations=destinations, + ) + for destination, edu_map in iteritems(buff.keyed_edus): for key, edu in edu_map.items(): transaction_queue.send_edu(edu, key) diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 1dc041752b..4f0f939102 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -371,7 +371,7 @@ class FederationSender(object): return # First we queue up the new presence by user ID, so multiple presence - # updates in quick successtion are correctly handled + # updates in quick succession are correctly handled. # We only want to send presence for our own users, so lets always just # filter here just in case. self.pending_presence.update({ @@ -402,6 +402,23 @@ class FederationSender(object): finally: self._processing_pending_presence = False + def send_presence_to_destinations(self, states, destinations): + """Send the given presence states to the given destinations. + + Args: + states (list[UserPresenceState]) + destinations (list[str]) + """ + + if not states or not self.hs.config.use_presence: + # No-op if presence is disabled. + return + + for destination in destinations: + if destination == self.server_name: + continue + self._get_per_destination_queue(destination).send_presence(states) + @measure_func("txnqueue._process_presence") @defer.inlineCallbacks def _process_presence_inner(self, states): diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index efb6bdca48..452599e1a1 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -21,8 +21,8 @@ import re from twisted.internet import defer import synapse -from synapse.api.constants import RoomVersions from synapse.api.errors import Codes, FederationDeniedError, SynapseError +from synapse.api.room_versions import RoomVersions from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.server import JsonResource @@ -513,7 +513,7 @@ class FederationV1InviteServlet(BaseFederationServlet): # state resolution algorithm, and we don't use that for processing # invites content = yield self.handler.on_invite_request( - origin, content, room_version=RoomVersions.V1, + origin, content, room_version=RoomVersions.V1.identifier, ) # V1 federation API is defined to return a content of `[200, {...}]` diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index a7eaead56b..817be40360 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -22,6 +22,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.types import GroupID, RoomID, UserID, get_domain_from_id +from synapse.util.async_helpers import concurrently_execute logger = logging.getLogger(__name__) @@ -896,6 +897,78 @@ class GroupsServerHandler(object): "group_id": group_id, }) + @defer.inlineCallbacks + def delete_group(self, group_id, requester_user_id): + """Deletes a group, kicking out all current members. + + Only group admins or server admins can call this request + + Args: + group_id (str) + request_user_id (str) + + Returns: + Deferred + """ + + yield self.check_group_is_ours( + group_id, requester_user_id, + and_exists=True, + ) + + # Only server admins or group admins can delete groups. + + is_admin = yield self.store.is_user_admin_in_group( + group_id, requester_user_id + ) + + if not is_admin: + is_admin = yield self.auth.is_server_admin( + UserID.from_string(requester_user_id), + ) + + if not is_admin: + raise SynapseError(403, "User is not an admin") + + # Before deleting the group lets kick everyone out of it + users = yield self.store.get_users_in_group( + group_id, include_private=True, + ) + + @defer.inlineCallbacks + def _kick_user_from_group(user_id): + if self.hs.is_mine_id(user_id): + groups_local = self.hs.get_groups_local_handler() + yield groups_local.user_removed_from_group(group_id, user_id, {}) + else: + yield self.transport_client.remove_user_from_group_notification( + get_domain_from_id(user_id), group_id, user_id, {} + ) + yield self.store.maybe_delete_remote_profile_cache(user_id) + + # We kick users out in the order of: + # 1. Non-admins + # 2. Other admins + # 3. The requester + # + # This is so that if the deletion fails for some reason other admins or + # the requester still has auth to retry. + non_admins = [] + admins = [] + for u in users: + if u["user_id"] == requester_user_id: + continue + if u["is_admin"]: + admins.append(u["user_id"]) + else: + non_admins.append(u["user_id"]) + + yield concurrently_execute(_kick_user_from_group, non_admins, 10) + yield concurrently_execute(_kick_user_from_group, admins, 10) + yield _kick_user_from_group(requester_user_id) + + yield self.store.delete_group(group_id) + def _parse_join_policy_from_contents(content): """Given a content for a request, return the specified join policy or None diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py new file mode 100644 index 0000000000..261446517d --- /dev/null +++ b/synapse/handlers/account_validity.py @@ -0,0 +1,253 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import email.mime.multipart +import email.utils +import logging +from email.mime.multipart import MIMEMultipart +from email.mime.text import MIMEText + +from twisted.internet import defer + +from synapse.api.errors import StoreError +from synapse.types import UserID +from synapse.util import stringutils +from synapse.util.logcontext import make_deferred_yieldable + +try: + from synapse.push.mailer import load_jinja2_templates +except ImportError: + load_jinja2_templates = None + +logger = logging.getLogger(__name__) + + +class AccountValidityHandler(object): + def __init__(self, hs): + self.hs = hs + self.store = self.hs.get_datastore() + self.sendmail = self.hs.get_sendmail() + self.clock = self.hs.get_clock() + + self._account_validity = self.hs.config.account_validity + + if self._account_validity.renew_by_email_enabled and load_jinja2_templates: + # Don't do email-specific configuration if renewal by email is disabled. + try: + app_name = self.hs.config.email_app_name + + self._subject = self._account_validity.renew_email_subject % { + "app": app_name, + } + + self._from_string = self.hs.config.email_notif_from % { + "app": app_name, + } + except Exception: + # If substitution failed, fall back to the bare strings. + self._subject = self._account_validity.renew_email_subject + self._from_string = self.hs.config.email_notif_from + + self._raw_from = email.utils.parseaddr(self._from_string)[1] + + self._template_html, self._template_text = load_jinja2_templates( + config=self.hs.config, + template_html_name=self.hs.config.email_expiry_template_html, + template_text_name=self.hs.config.email_expiry_template_text, + ) + + # Check the renewal emails to send and send them every 30min. + self.clock.looping_call( + self.send_renewal_emails, + 30 * 60 * 1000, + ) + + @defer.inlineCallbacks + def send_renewal_emails(self): + """Gets the list of users whose account is expiring in the amount of time + configured in the ``renew_at`` parameter from the ``account_validity`` + configuration, and sends renewal emails to all of these users as long as they + have an email 3PID attached to their account. + """ + expiring_users = yield self.store.get_users_expiring_soon() + + if expiring_users: + for user in expiring_users: + yield self._send_renewal_email( + user_id=user["user_id"], + expiration_ts=user["expiration_ts_ms"], + ) + + @defer.inlineCallbacks + def send_renewal_email_to_user(self, user_id): + expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) + yield self._send_renewal_email(user_id, expiration_ts) + + @defer.inlineCallbacks + def _send_renewal_email(self, user_id, expiration_ts): + """Sends out a renewal email to every email address attached to the given user + with a unique link allowing them to renew their account. + + Args: + user_id (str): ID of the user to send email(s) to. + expiration_ts (int): Timestamp in milliseconds for the expiration date of + this user's account (used in the email templates). + """ + addresses = yield self._get_email_addresses_for_user(user_id) + + # Stop right here if the user doesn't have at least one email address. + # In this case, they will have to ask their server admin to renew their + # account manually. + if not addresses: + return + + try: + user_display_name = yield self.store.get_profile_displayname( + UserID.from_string(user_id).localpart + ) + if user_display_name is None: + user_display_name = user_id + except StoreError: + user_display_name = user_id + + renewal_token = yield self._get_renewal_token(user_id) + url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % ( + self.hs.config.public_baseurl, + renewal_token, + ) + + template_vars = { + "display_name": user_display_name, + "expiration_ts": expiration_ts, + "url": url, + } + + html_text = self._template_html.render(**template_vars) + html_part = MIMEText(html_text, "html", "utf8") + + plain_text = self._template_text.render(**template_vars) + text_part = MIMEText(plain_text, "plain", "utf8") + + for address in addresses: + raw_to = email.utils.parseaddr(address)[1] + + multipart_msg = MIMEMultipart('alternative') + multipart_msg['Subject'] = self._subject + multipart_msg['From'] = self._from_string + multipart_msg['To'] = address + multipart_msg['Date'] = email.utils.formatdate() + multipart_msg['Message-ID'] = email.utils.make_msgid() + multipart_msg.attach(text_part) + multipart_msg.attach(html_part) + + logger.info("Sending renewal email to %s", address) + + yield make_deferred_yieldable(self.sendmail( + self.hs.config.email_smtp_host, + self._raw_from, raw_to, multipart_msg.as_string().encode('utf8'), + reactor=self.hs.get_reactor(), + port=self.hs.config.email_smtp_port, + requireAuthentication=self.hs.config.email_smtp_user is not None, + username=self.hs.config.email_smtp_user, + password=self.hs.config.email_smtp_pass, + requireTransportSecurity=self.hs.config.require_transport_security + )) + + yield self.store.set_renewal_mail_status( + user_id=user_id, + email_sent=True, + ) + + @defer.inlineCallbacks + def _get_email_addresses_for_user(self, user_id): + """Retrieve the list of email addresses attached to a user's account. + + Args: + user_id (str): ID of the user to lookup email addresses for. + + Returns: + defer.Deferred[list[str]]: Email addresses for this account. + """ + threepids = yield self.store.user_get_threepids(user_id) + + addresses = [] + for threepid in threepids: + if threepid["medium"] == "email": + addresses.append(threepid["address"]) + + defer.returnValue(addresses) + + @defer.inlineCallbacks + def _get_renewal_token(self, user_id): + """Generates a 32-byte long random string that will be inserted into the + user's renewal email's unique link, then saves it into the database. + + Args: + user_id (str): ID of the user to generate a string for. + + Returns: + defer.Deferred[str]: The generated string. + + Raises: + StoreError(500): Couldn't generate a unique string after 5 attempts. + """ + attempts = 0 + while attempts < 5: + try: + renewal_token = stringutils.random_string(32) + yield self.store.set_renewal_token_for_user(user_id, renewal_token) + defer.returnValue(renewal_token) + except StoreError: + attempts += 1 + raise StoreError(500, "Couldn't generate a unique string as refresh string.") + + @defer.inlineCallbacks + def renew_account(self, renewal_token): + """Renews the account attached to a given renewal token by pushing back the + expiration date by the current validity period in the server's configuration. + + Args: + renewal_token (str): Token sent with the renewal request. + """ + user_id = yield self.store.get_user_from_renewal_token(renewal_token) + logger.debug("Renewing an account for user %s", user_id) + yield self.renew_account_for_user(user_id) + + @defer.inlineCallbacks + def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False): + """Renews the account attached to a given user by pushing back the + expiration date by the current validity period in the server's + configuration. + + Args: + renewal_token (str): Token sent with the renewal request. + expiration_ts (int): New expiration date. Defaults to now + validity period. + email_sent (bool): Whether an email has been sent for this validity period. + Defaults to False. + + Returns: + defer.Deferred[int]: New expiration date for this account, as a timestamp + in milliseconds since epoch. + """ + if expiration_ts is None: + expiration_ts = self.clock.time_msec() + self._account_validity.period + + yield self.store.set_account_validity_for_user( + user_id=user_id, + expiration_ts=expiration_ts, + email_sent=email_sent, + ) + + defer.returnValue(expiration_ts) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 4544de821d..aa5d89a9ac 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -912,7 +912,7 @@ class AuthHandler(BaseHandler): ) @defer.inlineCallbacks - def delete_threepid(self, user_id, medium, address): + def delete_threepid(self, user_id, medium, address, id_server=None): """Attempts to unbind the 3pid on the identity servers and deletes it from the local database. @@ -920,6 +920,10 @@ class AuthHandler(BaseHandler): user_id (str) medium (str) address (str) + id_server (str|None): Use the given identity server when unbinding + any threepids. If None then will attempt to unbind using the + identity server specified when binding (if known). + Returns: Deferred[bool]: Returns True if successfully unbound the 3pid on @@ -937,6 +941,7 @@ class AuthHandler(BaseHandler): { 'medium': medium, 'address': address, + 'id_server': id_server, }, ) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 97d3f31d98..6a91f7698e 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -43,12 +43,15 @@ class DeactivateAccountHandler(BaseHandler): hs.get_reactor().callWhenRunning(self._start_user_parting) @defer.inlineCallbacks - def deactivate_account(self, user_id, erase_data): + def deactivate_account(self, user_id, erase_data, id_server=None): """Deactivate a user's account Args: user_id (str): ID of user to be deactivated erase_data (bool): whether to GDPR-erase the user's data + id_server (str|None): Use the given identity server when unbinding + any threepids. If None then will attempt to unbind using the + identity server specified when binding (if known). Returns: Deferred[bool]: True if identity server supports removing @@ -74,6 +77,7 @@ class DeactivateAccountHandler(BaseHandler): { 'medium': threepid['medium'], 'address': threepid['address'], + 'id_server': id_server, }, ) identity_server_supports_unbinding &= result diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index fe128d9c88..50c587aa61 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -36,6 +36,7 @@ logger = logging.getLogger(__name__) class DirectoryHandler(BaseHandler): + MAX_ALIAS_LENGTH = 255 def __init__(self, hs): super(DirectoryHandler, self).__init__(hs) @@ -43,8 +44,10 @@ class DirectoryHandler(BaseHandler): self.state = hs.get_state_handler() self.appservice_handler = hs.get_application_service_handler() self.event_creation_handler = hs.get_event_creation_handler() + self.store = hs.get_datastore() self.config = hs.config self.enable_room_list_search = hs.config.enable_room_list_search + self.require_membership = hs.config.require_membership_for_aliases self.federation = hs.get_federation_client() hs.get_federation_registry().register_query_handler( @@ -68,7 +71,7 @@ class DirectoryHandler(BaseHandler): # TODO(erikj): Add transactions. # TODO(erikj): Check if there is a current association. if not servers: - users = yield self.state.get_current_user_in_room(room_id) + users = yield self.state.get_current_users_in_room(room_id) servers = set(get_domain_from_id(u) for u in users) if not servers: @@ -83,7 +86,7 @@ class DirectoryHandler(BaseHandler): @defer.inlineCallbacks def create_association(self, requester, room_alias, room_id, servers=None, - send_event=True): + send_event=True, check_membership=True): """Attempt to create a new alias Args: @@ -93,6 +96,8 @@ class DirectoryHandler(BaseHandler): servers (list[str]|None): List of servers that others servers should try and join via send_event (bool): Whether to send an updated m.room.aliases event + check_membership (bool): Whether to check if the user is in the room + before the alias can be set (if the server's config requires it). Returns: Deferred @@ -100,6 +105,13 @@ class DirectoryHandler(BaseHandler): user_id = requester.user.to_string() + if len(room_alias.to_string()) > self.MAX_ALIAS_LENGTH: + raise SynapseError( + 400, + "Can't create aliases longer than %s characters" % self.MAX_ALIAS_LENGTH, + Codes.INVALID_PARAM, + ) + service = requester.app_service if service: if not service.is_interested_in_alias(room_alias.to_string()): @@ -108,6 +120,14 @@ class DirectoryHandler(BaseHandler): " this kind of alias.", errcode=Codes.EXCLUSIVE ) else: + if self.require_membership and check_membership: + rooms_for_user = yield self.store.get_rooms_for_user(user_id) + if room_id not in rooms_for_user: + raise AuthError( + 403, + "You must be in the room to create an alias for it", + ) + if not self.spam_checker.user_may_create_room_alias(user_id, room_alias): raise AuthError( 403, "This user is not permitted to create this alias", @@ -268,7 +288,7 @@ class DirectoryHandler(BaseHandler): Codes.NOT_FOUND ) - users = yield self.state.get_current_user_in_room(room_id) + users = yield self.state.get_current_users_in_room(room_id) extra_servers = set(get_domain_from_id(u) for u in users) servers = set(extra_servers) | set(servers) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index d883e98381..1b4d8c74ae 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -102,7 +102,7 @@ class EventStreamHandler(BaseHandler): # Send down presence. if event.state_key == auth_user_id: # Send down presence for everyone in the room. - users = yield self.state.get_current_user_in_room(event.room_id) + users = yield self.state.get_current_users_in_room(event.room_id) states = yield presence_handler.get_states( users, as_event=True, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 9eaf2d3e18..0684778882 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -29,13 +29,7 @@ from unpaddedbase64 import decode_base64 from twisted.internet import defer -from synapse.api.constants import ( - KNOWN_ROOM_VERSIONS, - EventTypes, - Membership, - RejectedReason, - RoomVersions, -) +from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.api.errors import ( AuthError, CodeMessageException, @@ -44,6 +38,7 @@ from synapse.api.errors import ( StoreError, SynapseError, ) +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.crypto.event_signing import compute_event_signature from synapse.event_auth import auth_types_for_event from synapse.events.validator import EventValidator @@ -1733,7 +1728,9 @@ class FederationHandler(BaseHandler): # invalid, and it would fail auth checks anyway. raise SynapseError(400, "No create event in state") - room_version = create_event.content.get("room_version", RoomVersions.V1) + room_version = create_event.content.get( + "room_version", RoomVersions.V1.identifier, + ) missing_auth_events = set() for e in itertools.chain(auth_events, state, [event]): diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 39184f0e22..22469486d7 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -132,6 +132,14 @@ class IdentityHandler(BaseHandler): } ) logger.debug("bound threepid %r to %s", creds, mxid) + + # Remember where we bound the threepid + yield self.store.add_user_bound_threepid( + user_id=mxid, + medium=data["medium"], + address=data["address"], + id_server=id_server, + ) except CodeMessageException as e: data = json.loads(e.msg) # XXX WAT? defer.returnValue(data) @@ -142,30 +150,61 @@ class IdentityHandler(BaseHandler): Args: mxid (str): Matrix user ID of binding to be removed - threepid (dict): Dict with medium & address of binding to be removed + threepid (dict): Dict with medium & address of binding to be + removed, and an optional id_server. Raises: SynapseError: If we failed to contact the identity server Returns: Deferred[bool]: True on success, otherwise False if the identity - server doesn't support unbinding + server doesn't support unbinding (or no identity server found to + contact). """ - logger.debug("unbinding threepid %r from %s", threepid, mxid) - if not self.trusted_id_servers: - logger.warn("Can't unbind threepid: no trusted ID servers set in config") + if threepid.get("id_server"): + id_servers = [threepid["id_server"]] + else: + id_servers = yield self.store.get_id_servers_user_bound( + user_id=mxid, + medium=threepid["medium"], + address=threepid["address"], + ) + + # We don't know where to unbind, so we don't have a choice but to return + if not id_servers: defer.returnValue(False) - # We don't track what ID server we added 3pids on (perhaps we ought to) - # but we assume that any of the servers in the trusted list are in the - # same ID server federation, so we can pick any one of them to send the - # deletion request to. - id_server = next(iter(self.trusted_id_servers)) + changed = True + for id_server in id_servers: + changed &= yield self.try_unbind_threepid_with_id_server( + mxid, threepid, id_server, + ) + + defer.returnValue(changed) + + @defer.inlineCallbacks + def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server): + """Removes a binding from an identity server + Args: + mxid (str): Matrix user ID of binding to be removed + threepid (dict): Dict with medium & address of binding to be removed + id_server (str): Identity server to unbind from + + Raises: + SynapseError: If we failed to contact the identity server + + Returns: + Deferred[bool]: True on success, otherwise False if the identity + server doesn't support unbinding + """ url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,) content = { "mxid": mxid, - "threepid": threepid, + "threepid": { + "medium": threepid["medium"], + "address": threepid["address"], + }, } # we abuse the federation http client to sign the request, but we have to send it @@ -188,16 +227,24 @@ class IdentityHandler(BaseHandler): content, headers, ) + changed = True except HttpResponseException as e: + changed = False if e.code in (400, 404, 501,): # The remote server probably doesn't support unbinding (yet) logger.warn("Received %d response while unbinding threepid", e.code) - defer.returnValue(False) else: logger.error("Failed to unbind threepid on identity server: %s", e) raise SynapseError(502, "Failed to contact identity server") - defer.returnValue(True) + yield self.store.remove_user_bound_threepid( + user_id=mxid, + medium=threepid["medium"], + address=threepid["address"], + id_server=id_server, + ) + + defer.returnValue(changed) @defer.inlineCallbacks def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs): diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9b41c7b205..224d34ef3a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -22,7 +22,7 @@ from canonicaljson import encode_canonical_json, json from twisted.internet import defer from twisted.internet.defer import succeed -from synapse.api.constants import EventTypes, Membership, RoomVersions +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, @@ -30,6 +30,7 @@ from synapse.api.errors import ( NotFoundError, SynapseError, ) +from synapse.api.room_versions import RoomVersions from synapse.api.urls import ConsentURIBuilder from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator @@ -191,7 +192,7 @@ class MessageHandler(object): "Getting joined members after leaving is not implemented" ) - users_with_profile = yield self.state.get_current_user_in_room(room_id) + users_with_profile = yield self.state.get_current_users_in_room(room_id) # If this is an AS, double check that they are allowed to see the members. # This can either be because the AS user is in the room or because there @@ -603,7 +604,9 @@ class EventCreationHandler(object): """ if event.is_state() and (event.type, event.state_key) == (EventTypes.Create, ""): - room_version = event.content.get("room_version", RoomVersions.V1) + room_version = event.content.get( + "room_version", RoomVersions.V1.identifier + ) else: room_version = yield self.store.get_room_version(event.room_id) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 37e87fc054..59d53f1050 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -31,9 +31,11 @@ from prometheus_client import Counter from twisted.internet import defer -from synapse.api.constants import PresenceState +import synapse.metrics +from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.errors import SynapseError from synapse.metrics import LaterGauge +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.presence import UserPresenceState from synapse.types import UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer @@ -98,6 +100,7 @@ class PresenceHandler(object): self.hs = hs self.is_mine = hs.is_mine self.is_mine_id = hs.is_mine_id + self.server_name = hs.hostname self.clock = hs.get_clock() self.store = hs.get_datastore() self.wheel_timer = WheelTimer() @@ -110,30 +113,6 @@ class PresenceHandler(object): federation_registry.register_edu_handler( "m.presence", self.incoming_presence ) - federation_registry.register_edu_handler( - "m.presence_invite", - lambda origin, content: self.invite_presence( - observed_user=UserID.from_string(content["observed_user"]), - observer_user=UserID.from_string(content["observer_user"]), - ) - ) - federation_registry.register_edu_handler( - "m.presence_accept", - lambda origin, content: self.accept_presence( - observed_user=UserID.from_string(content["observed_user"]), - observer_user=UserID.from_string(content["observer_user"]), - ) - ) - federation_registry.register_edu_handler( - "m.presence_deny", - lambda origin, content: self.deny_presence( - observed_user=UserID.from_string(content["observed_user"]), - observer_user=UserID.from_string(content["observer_user"]), - ) - ) - - distributor = hs.get_distributor() - distributor.observe("user_joined_room", self.user_joined_room) active_presence = self.store.take_presence_startup_info() @@ -220,6 +199,15 @@ class PresenceHandler(object): LaterGauge("synapse_handlers_presence_wheel_timer_size", "", [], lambda: len(self.wheel_timer)) + # Used to handle sending of presence to newly joined users/servers + if hs.config.use_presence: + self.notifier.add_replication_callback(self.notify_new_event) + + # Presence is best effort and quickly heals itself, so lets just always + # stream from the current state when we restart. + self._event_pos = self.store.get_current_events_token() + self._event_processing = False + @defer.inlineCallbacks def _on_shutdown(self): """Gets called when shutting down. This lets us persist any updates that @@ -751,199 +739,178 @@ class PresenceHandler(object): yield self._update_states([prev_state.copy_and_replace(**new_fields)]) @defer.inlineCallbacks - def user_joined_room(self, user, room_id): - """Called (via the distributor) when a user joins a room. This funciton - sends presence updates to servers, either: - 1. the joining user is a local user and we send their presence to - all servers in the room. - 2. the joining user is a remote user and so we send presence for all - local users in the room. + def is_visible(self, observed_user, observer_user): + """Returns whether a user can see another user's presence. """ - # We only need to send presence to servers that don't have it yet. We - # don't need to send to local clients here, as that is done as part - # of the event stream/sync. - # TODO: Only send to servers not already in the room. - if self.is_mine(user): - state = yield self.current_state_for_user(user.to_string()) - - self._push_to_remotes([state]) - else: - user_ids = yield self.store.get_users_in_room(room_id) - user_ids = list(filter(self.is_mine_id, user_ids)) + observer_room_ids = yield self.store.get_rooms_for_user( + observer_user.to_string() + ) + observed_room_ids = yield self.store.get_rooms_for_user( + observed_user.to_string() + ) - states = yield self.current_state_for_users(user_ids) + if observer_room_ids & observed_room_ids: + defer.returnValue(True) - self._push_to_remotes(list(states.values())) + defer.returnValue(False) @defer.inlineCallbacks - def get_presence_list(self, observer_user, accepted=None): - """Returns the presence for all users in their presence list. + def get_all_presence_updates(self, last_id, current_id): """ - if not self.is_mine(observer_user): - raise SynapseError(400, "User is not hosted on this Home Server") - - presence_list = yield self.store.get_presence_list( - observer_user.localpart, accepted=accepted - ) + Gets a list of presence update rows from between the given stream ids. + Each row has: + - stream_id(str) + - user_id(str) + - state(str) + - last_active_ts(int) + - last_federation_update_ts(int) + - last_user_sync_ts(int) + - status_msg(int) + - currently_active(int) + """ + # TODO(markjh): replicate the unpersisted changes. + # This could use the in-memory stores for recent changes. + rows = yield self.store.get_all_presence_updates(last_id, current_id) + defer.returnValue(rows) - results = yield self.get_states( - target_user_ids=[row["observed_user_id"] for row in presence_list], - as_event=False, - ) + def notify_new_event(self): + """Called when new events have happened. Handles users and servers + joining rooms and require being sent presence. + """ - now = self.clock.time_msec() - results[:] = [format_user_presence_state(r, now) for r in results] + if self._event_processing: + return - is_accepted = { - row["observed_user_id"]: row["accepted"] for row in presence_list - } + @defer.inlineCallbacks + def _process_presence(): + assert not self._event_processing - for result in results: - result.update({ - "accepted": is_accepted, - }) + self._event_processing = True + try: + yield self._unsafe_process() + finally: + self._event_processing = False - defer.returnValue(results) + run_as_background_process("presence.notify_new_event", _process_presence) @defer.inlineCallbacks - def send_presence_invite(self, observer_user, observed_user): - """Sends a presence invite. - """ - yield self.store.add_presence_list_pending( - observer_user.localpart, observed_user.to_string() - ) + def _unsafe_process(self): + # Loop round handling deltas until we're up to date + while True: + with Measure(self.clock, "presence_delta"): + deltas = yield self.store.get_current_state_deltas(self._event_pos) + if not deltas: + return - if self.is_mine(observed_user): - yield self.invite_presence(observed_user, observer_user) - else: - yield self.federation.build_and_send_edu( - destination=observed_user.domain, - edu_type="m.presence_invite", - content={ - "observed_user": observed_user.to_string(), - "observer_user": observer_user.to_string(), - } - ) + yield self._handle_state_delta(deltas) + + self._event_pos = deltas[-1]["stream_id"] + + # Expose current event processing position to prometheus + synapse.metrics.event_processing_positions.labels("presence").set( + self._event_pos + ) @defer.inlineCallbacks - def invite_presence(self, observed_user, observer_user): - """Handles new presence invites. + def _handle_state_delta(self, deltas): + """Process current state deltas to find new joins that need to be + handled. """ - if not self.is_mine(observed_user): - raise SynapseError(400, "User is not hosted on this Home Server") + for delta in deltas: + typ = delta["type"] + state_key = delta["state_key"] + room_id = delta["room_id"] + event_id = delta["event_id"] + prev_event_id = delta["prev_event_id"] - # TODO: Don't auto accept - if self.is_mine(observer_user): - yield self.accept_presence(observed_user, observer_user) - else: - self.federation.build_and_send_edu( - destination=observer_user.domain, - edu_type="m.presence_accept", - content={ - "observed_user": observed_user.to_string(), - "observer_user": observer_user.to_string(), - } - ) + logger.debug("Handling: %r %r, %s", typ, state_key, event_id) - state_dict = yield self.get_state(observed_user, as_event=False) - state_dict = format_user_presence_state(state_dict, self.clock.time_msec()) + if typ != EventTypes.Member: + continue - self.federation.build_and_send_edu( - destination=observer_user.domain, - edu_type="m.presence", - content={ - "push": [state_dict] - } - ) + if event_id is None: + # state has been deleted, so this is not a join. We only care about + # joins. + continue - @defer.inlineCallbacks - def accept_presence(self, observed_user, observer_user): - """Handles a m.presence_accept EDU. Mark a presence invite from a - local or remote user as accepted in a local user's presence list. - Starts polling for presence updates from the local or remote user. - Args: - observed_user(UserID): The user to update in the presence list. - observer_user(UserID): The owner of the presence list to update. - """ - yield self.store.set_presence_list_accepted( - observer_user.localpart, observed_user.to_string() - ) + event = yield self.store.get_event(event_id) + if event.content.get("membership") != Membership.JOIN: + # We only care about joins + continue - @defer.inlineCallbacks - def deny_presence(self, observed_user, observer_user): - """Handle a m.presence_deny EDU. Removes a local or remote user from a - local user's presence list. - Args: - observed_user(UserID): The local or remote user to remove from the - list. - observer_user(UserID): The local owner of the presence list. - Returns: - A Deferred. - """ - yield self.store.del_presence_list( - observer_user.localpart, observed_user.to_string() - ) + if prev_event_id: + prev_event = yield self.store.get_event(prev_event_id) + if prev_event.content.get("membership") == Membership.JOIN: + # Ignore changes to join events. + continue - # TODO(paul): Inform the user somehow? + yield self._on_user_joined_room(room_id, state_key) @defer.inlineCallbacks - def drop(self, observed_user, observer_user): - """Remove a local or remote user from a local user's presence list and - unsubscribe the local user from updates that user. + def _on_user_joined_room(self, room_id, user_id): + """Called when we detect a user joining the room via the current state + delta stream. + Args: - observed_user(UserId): The local or remote user to remove from the - list. - observer_user(UserId): The local owner of the presence list. + room_id (str) + user_id (str) + Returns: - A Deferred. + Deferred """ - if not self.is_mine(observer_user): - raise SynapseError(400, "User is not hosted on this Home Server") - yield self.store.del_presence_list( - observer_user.localpart, observed_user.to_string() - ) + if self.is_mine_id(user_id): + # If this is a local user then we need to send their presence + # out to hosts in the room (who don't already have it) - # TODO: Inform the remote that we've dropped the presence list. + # TODO: We should be able to filter the hosts down to those that + # haven't previously seen the user - @defer.inlineCallbacks - def is_visible(self, observed_user, observer_user): - """Returns whether a user can see another user's presence. - """ - observer_room_ids = yield self.store.get_rooms_for_user( - observer_user.to_string() - ) - observed_room_ids = yield self.store.get_rooms_for_user( - observed_user.to_string() - ) + state = yield self.current_state_for_user(user_id) + hosts = yield self.state.get_current_hosts_in_room(room_id) - if observer_room_ids & observed_room_ids: - defer.returnValue(True) + # Filter out ourselves. + hosts = set(host for host in hosts if host != self.server_name) - accepted_observers = yield self.store.get_presence_list_observers_accepted( - observed_user.to_string() - ) + self.federation.send_presence_to_destinations( + states=[state], + destinations=hosts, + ) + else: + # A remote user has joined the room, so we need to: + # 1. Check if this is a new server in the room + # 2. If so send any presence they don't already have for + # local users in the room. - defer.returnValue(observer_user.to_string() in accepted_observers) + # TODO: We should be able to filter the users down to those that + # the server hasn't previously seen - @defer.inlineCallbacks - def get_all_presence_updates(self, last_id, current_id): - """ - Gets a list of presence update rows from between the given stream ids. - Each row has: - - stream_id(str) - - user_id(str) - - state(str) - - last_active_ts(int) - - last_federation_update_ts(int) - - last_user_sync_ts(int) - - status_msg(int) - - currently_active(int) - """ - # TODO(markjh): replicate the unpersisted changes. - # This could use the in-memory stores for recent changes. - rows = yield self.store.get_all_presence_updates(last_id, current_id) - defer.returnValue(rows) + # TODO: Check that this is actually a new server joining the + # room. + + user_ids = yield self.state.get_current_users_in_room(room_id) + user_ids = list(filter(self.is_mine_id, user_ids)) + + states = yield self.current_state_for_users(user_ids) + + # Filter out old presence, i.e. offline presence states where + # the user hasn't been active for a week. We can change this + # depending on what we want the UX to be, but at the least we + # should filter out offline presence where the state is just the + # default state. + now = self.clock.time_msec() + states = [ + state for state in states.values() + if state.state != PresenceState.OFFLINE + or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000 + or state.status_msg is not None + ] + + if states: + self.federation.send_presence_to_destinations( + states=states, + destinations=[get_domain_from_id(user_id)], + ) def should_notify(old_state, new_state): @@ -1086,10 +1053,7 @@ class PresenceEventSource(object): updates for """ user_id = user.to_string() - plist = yield self.store.get_presence_list_accepted( - user.localpart, on_invalidate=cache_context.invalidate, - ) - users_interested_in = set(row["observed_user_id"] for row in plist) + users_interested_in = set() users_interested_in.add(user_id) # So that we receive our own presence users_who_share_room = yield self.store.get_users_who_share_room_with_user( @@ -1294,10 +1258,6 @@ def get_interested_parties(store, states): for room_id in room_ids: room_ids_to_states.setdefault(room_id, []).append(state) - plist = yield store.get_presence_list_observers_accepted(state.user_id) - for u in plist: - users_to_states.setdefault(u, []).append(state) - # Always notify self users_to_states.setdefault(state.user_id, []).append(state) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 58940e0320..a51d11a257 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -153,6 +153,7 @@ class RegistrationHandler(BaseHandler): user_type=None, default_display_name=None, address=None, + bind_emails=[], ): """Registers a new client on the server. @@ -172,6 +173,7 @@ class RegistrationHandler(BaseHandler): default_display_name (unicode|None): if set, the new user's displayname will be set to this. Defaults to 'localpart'. address (str|None): the IP address used to perform the registration. + bind_emails (List[str]): list of emails to bind to this account. Returns: A tuple of (user_id, access_token). Raises: @@ -261,6 +263,21 @@ class RegistrationHandler(BaseHandler): if not self.hs.config.user_consent_at_registration: yield self._auto_join_rooms(user_id) + # Bind any specified emails to this account + current_time = self.hs.get_clock().time_msec() + for email in bind_emails: + # generate threepid dict + threepid_dict = { + "medium": "email", + "address": email, + "validated_at": current_time, + } + + # Bind email to new account + yield self._register_email_threepid( + user_id, threepid_dict, None, False, + ) + defer.returnValue((user_id, token)) @defer.inlineCallbacks diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 67b15697fd..e37ae96899 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -25,14 +25,9 @@ from six import iteritems, string_types from twisted.internet import defer -from synapse.api.constants import ( - DEFAULT_ROOM_VERSION, - KNOWN_ROOM_VERSIONS, - EventTypes, - JoinRules, - RoomCreationPreset, -) +from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError +from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS from synapse.storage.state import StateFilter from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID from synapse.util import stringutils @@ -285,6 +280,7 @@ class RoomCreationHandler(BaseHandler): (EventTypes.RoomAvatar, ""), (EventTypes.Encryption, ""), (EventTypes.ServerACL, ""), + (EventTypes.RelatedGroups, ""), ) old_room_state_ids = yield self.store.get_filtered_current_state_ids( @@ -406,7 +402,7 @@ class RoomCreationHandler(BaseHandler): yield directory_handler.create_association( requester, RoomAlias.from_string(alias), new_room_id, servers=(self.hs.hostname, ), - send_event=False, + send_event=False, check_membership=False, ) logger.info("Moved alias %s to new room", alias) except SynapseError as e: @@ -479,7 +475,7 @@ class RoomCreationHandler(BaseHandler): if ratelimit: yield self.ratelimit(requester) - room_version = config.get("room_version", DEFAULT_ROOM_VERSION) + room_version = config.get("room_version", DEFAULT_ROOM_VERSION.identifier) if not isinstance(room_version, string_types): raise SynapseError( 400, @@ -542,6 +538,7 @@ class RoomCreationHandler(BaseHandler): room_alias=room_alias, servers=[self.hs.hostname], send_event=False, + check_membership=False, ) preset_config = config.get( diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index d6c9d56007..617d1c9ef8 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -167,7 +167,7 @@ class RoomListHandler(BaseHandler): if not latest_event_ids: return - joined_users = yield self.state_handler.get_current_user_in_room( + joined_users = yield self.state_handler.get_current_users_in_room( room_id, latest_event_ids, ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 71ce5b54e5..024d6db27a 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -70,6 +70,7 @@ class RoomMemberHandler(object): self.clock = hs.get_clock() self.spam_checker = hs.get_spam_checker() self._server_notices_mxid = self.config.server_notices_mxid + self._enable_lookup = hs.config.enable_3pid_lookup @abc.abstractmethod def _remote_join(self, requester, remote_room_hosts, room_id, user, content): @@ -421,6 +422,9 @@ class RoomMemberHandler(object): room_id, latest_event_ids=latest_event_ids, ) + # TODO: Refactor into dictionary of explicitly allowed transitions + # between old and new state, with specific error messages for some + # transitions and generic otherwise old_state_id = current_state_ids.get((EventTypes.Member, target.to_string())) if old_state_id: old_state = yield self.store.get_event(old_state_id, allow_none=True) @@ -446,6 +450,9 @@ class RoomMemberHandler(object): if same_sender and same_membership and same_content: defer.returnValue(old_state) + if old_membership in ["ban", "leave"] and action == "kick": + raise AuthError(403, "The target user is not in the room") + # we don't allow people to reject invites to the server notice # room, but they can leave it once they are joined. if ( @@ -459,6 +466,9 @@ class RoomMemberHandler(object): "You cannot reject this invite", errcode=Codes.CANNOT_LEAVE_SERVER_NOTICE_ROOM, ) + else: + if action == "kick": + raise AuthError(403, "The target user is not in the room") is_host_in_room = yield self._is_host_in_room(current_state_ids) @@ -729,6 +739,10 @@ class RoomMemberHandler(object): Returns: str: the matrix ID of the 3pid, or None if it is not recognized. """ + if not self._enable_lookup: + raise SynapseError( + 403, "Looking up third-party identifiers is denied from this server", + ) try: data = yield self.simple_http_client.get_json( "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,), diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 57bb996245..153312e39f 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1049,11 +1049,11 @@ class SyncHandler(object): # TODO: Be more clever than this, i.e. remove users who we already # share a room with? for room_id in newly_joined_rooms: - joined_users = yield self.state.get_current_user_in_room(room_id) + joined_users = yield self.state.get_current_users_in_room(room_id) newly_joined_users.update(joined_users) for room_id in newly_left_rooms: - left_users = yield self.state.get_current_user_in_room(room_id) + left_users = yield self.state.get_current_users_in_room(room_id) newly_left_users.update(left_users) # TODO: Check that these users are actually new, i.e. either they @@ -1213,7 +1213,7 @@ class SyncHandler(object): extra_users_ids = set(newly_joined_users) for room_id in newly_joined_rooms: - users = yield self.state.get_current_user_in_room(room_id) + users = yield self.state.get_current_users_in_room(room_id) extra_users_ids.update(users) extra_users_ids.discard(user.to_string()) @@ -1855,7 +1855,7 @@ class SyncHandler(object): extrems = yield self.store.get_forward_extremeties_for_room( room_id, stream_ordering, ) - users_in_room = yield self.state.get_current_user_in_room( + users_in_room = yield self.state.get_current_users_in_room( room_id, extrems, ) if user_id in users_in_room: diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 39df960c31..972662eb48 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -218,7 +218,7 @@ class TypingHandler(object): @defer.inlineCallbacks def _push_remote(self, member, typing): try: - users = yield self.state.get_current_user_in_room(member.room_id) + users = yield self.state.get_current_users_in_room(member.room_id) self._member_last_federation_poke[member] = self.clock.time_msec() now = self.clock.time_msec() @@ -261,7 +261,7 @@ class TypingHandler(object): ) return - users = yield self.state.get_current_user_in_room(room_id) + users = yield self.state.get_current_users_in_room(room_id) domains = set(get_domain_from_id(u) for u in users) if self.server_name in domains: diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index b689979b4b..5de9630950 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -276,7 +276,7 @@ class UserDirectoryHandler(StateDeltasHandler): # ignore the change return - users_with_profile = yield self.state.get_current_user_in_room(room_id) + users_with_profile = yield self.state.get_current_users_in_room(room_id) # Remove every user from the sharing tables for that room. for user_id in iterkeys(users_with_profile): @@ -325,7 +325,7 @@ class UserDirectoryHandler(StateDeltasHandler): room_id ) # Now we update users who share rooms with users. - users_with_profile = yield self.state.get_current_user_in_room(room_id) + users_with_profile = yield self.state.get_current_users_in_room(room_id) if is_public: yield self.store.add_users_in_public_rooms(room_id, (user_id,)) diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 1334c630cc..b4cbe97b41 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -149,7 +149,7 @@ class MatrixFederationAgent(object): tls_options = None else: tls_options = self._tls_client_options_factory.get_options( - res.tls_server_name.decode("ascii") + res.tls_server_name.decode("ascii"), ) # make sure that the Host header is set correctly diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 235ce8334e..b3abd1b3c6 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -74,14 +74,14 @@ class ModuleApi(object): return self._auth_handler.check_user_exists(user_id) @defer.inlineCallbacks - def register(self, localpart, displayname=None): + def register(self, localpart, displayname=None, emails=[]): """Registers a new user with given localpart and optional - displayname. + displayname, emails. Args: localpart (str): The localpart of the new user. - displayname (str|None): The displayname of the new user. If None, - the user's displayname will default to `localpart`. + displayname (str|None): The displayname of the new user. + emails (List[str]): Emails to bind to the new user. Returns: Deferred: a 2-tuple of (user_id, access_token) @@ -90,6 +90,7 @@ class ModuleApi(object): reg = self.hs.get_registration_handler() user_id, access_token = yield reg.register( localpart=localpart, default_display_name=displayname, + bind_emails=emails, ) defer.returnValue((user_id, access_token)) diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 8f0682c948..3523a40108 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -261,6 +261,23 @@ BASE_APPEND_OVERRIDE_RULES = [ 'value': True, } ] + }, + { + 'rule_id': 'global/override/.m.rule.tombstone', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'type', + 'pattern': 'm.room.tombstone', + '_id': '_tombstone', + } + ], + 'actions': [ + 'notify', { + 'set_tweak': 'highlight', + 'value': True, + } + ] } ] diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 50e1007d84..e8ee67401f 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -72,8 +72,15 @@ class EmailPusher(object): self._is_processing = False - def on_started(self): - if self.mailer is not None: + def on_started(self, should_check_for_notifs): + """Called when this pusher has been started. + + Args: + should_check_for_notifs (bool): Whether we should immediately + check for push to send. Set to False only if it's known there + is nothing to send + """ + if should_check_for_notifs and self.mailer is not None: self._start_processing() def on_stop(self): diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index e65f8c63d3..fac05aa44c 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -112,8 +112,16 @@ class HttpPusher(object): self.data_minus_url.update(self.data) del self.data_minus_url['url'] - def on_started(self): - self._start_processing() + def on_started(self, should_check_for_notifs): + """Called when this pusher has been started. + + Args: + should_check_for_notifs (bool): Whether we should immediately + check for push to send. Set to False only if it's known there + is nothing to send + """ + if should_check_for_notifs: + self._start_processing() def on_new_notifications(self, min_stream_ordering, max_stream_ordering): self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering or 0) diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 1eb5be0957..c269bcf4a4 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -521,11 +521,11 @@ def format_ts_filter(value, format): return time.strftime(format, time.localtime(value / 1000)) -def load_jinja2_templates(config): +def load_jinja2_templates(config, template_html_name, template_text_name): """Load the jinja2 email templates from disk Returns: - (notif_template_html, notif_template_text) + (template_html, template_text) """ logger.info("loading email templates from '%s'", config.email_template_dir) loader = jinja2.FileSystemLoader(config.email_template_dir) @@ -533,14 +533,10 @@ def load_jinja2_templates(config): env.filters["format_ts"] = format_ts_filter env.filters["mxc_to_http"] = _create_mxc_to_http_filter(config) - notif_template_html = env.get_template( - config.email_notif_template_html - ) - notif_template_text = env.get_template( - config.email_notif_template_text - ) + template_html = env.get_template(template_html_name) + template_text = env.get_template(template_text_name) - return notif_template_html, notif_template_text + return template_html, template_text def _create_mxc_to_http_filter(config): diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index b33f2a357b..14bc7823cf 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -44,7 +44,11 @@ class PusherFactory(object): if hs.config.email_enable_notifs: self.mailers = {} # app_name -> Mailer - templates = load_jinja2_templates(hs.config) + templates = load_jinja2_templates( + config=hs.config, + template_html_name=hs.config.email_notif_template_html, + template_text_name=hs.config.email_notif_template_text, + ) self.notif_template_html, self.notif_template_text = templates self.pusher_types["email"] = self._create_email_pusher diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index abf1a1a9c1..40a7709c09 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -21,6 +21,7 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import PusherConfigException from synapse.push.pusher import PusherFactory +from synapse.util.async_helpers import concurrently_execute logger = logging.getLogger(__name__) @@ -197,7 +198,7 @@ class PusherPool: p = r if p: - self._start_pusher(p) + yield self._start_pusher(p) @defer.inlineCallbacks def _start_pushers(self): @@ -208,10 +209,14 @@ class PusherPool: """ pushers = yield self.store.get_all_pushers() logger.info("Starting %d pushers", len(pushers)) - for pusherdict in pushers: - self._start_pusher(pusherdict) + + # Stagger starting up the pushers so we don't completely drown the + # process on start up. + yield concurrently_execute(self._start_pusher, pushers, 10) + logger.info("Started pushers") + @defer.inlineCallbacks def _start_pusher(self, pusherdict): """Start the given pusher @@ -248,7 +253,22 @@ class PusherPool: if appid_pushkey in byuser: byuser[appid_pushkey].on_stop() byuser[appid_pushkey] = p - p.on_started() + + # Check if there *may* be push to process. We do this as this check is a + # lot cheaper to do than actually fetching the exact rows we need to + # push. + user_id = pusherdict["user_name"] + last_stream_ordering = pusherdict["last_stream_ordering"] + if last_stream_ordering: + have_notifs = yield self.store.get_if_maybe_push_in_range_for_user( + user_id, last_stream_ordering, + ) + else: + # We always want to default to starting up the pusher rather than + # risk missing push. + have_notifs = True + + p.on_started(have_notifs) @defer.inlineCallbacks def remove_pusher(self, app_id, pushkey, user_id): diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index c75119a030..2708f5e820 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -82,7 +82,9 @@ REQUIREMENTS = [ CONDITIONAL_REQUIREMENTS = { "email.enable_notifs": ["Jinja2>=2.9", "bleach>=1.4.2"], "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"], - "postgres": ["psycopg2>=2.6"], + + # we use execute_batch, which arrived in psycopg 2.7. + "postgres": ["psycopg2>=2.7"], # ConsentResource uses select_autoescape, which arrived in jinja 2.9 "resources.consent": ["Jinja2>=2.9"], @@ -92,18 +94,22 @@ CONDITIONAL_REQUIREMENTS = { "acme": ["txacme>=0.9.2"], "saml2": ["pysaml2>=4.5.0"], + "systemd": ["systemd-python>=231"], "url_preview": ["lxml>=3.5.0"], "test": ["mock>=2.0", "parameterized"], "sentry": ["sentry-sdk>=0.7.2"], } +ALL_OPTIONAL_REQUIREMENTS = set() -def list_requirements(): - deps = set(REQUIREMENTS) - for opt in CONDITIONAL_REQUIREMENTS.values(): - deps = set(opt) | deps +for name, optional_deps in CONDITIONAL_REQUIREMENTS.items(): + # Exclude systemd as it's a system-based requirement. + if name not in ["systemd"]: + ALL_OPTIONAL_REQUIREMENTS = set(optional_deps) | ALL_OPTIONAL_REQUIREMENTS - return list(deps) + +def list_requirements(): + return list(set(REQUIREMENTS) | ALL_OPTIONAL_REQUIREMENTS) class DependencyException(Exception): diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 4830c68f35..b457c5563f 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -16,6 +16,10 @@ import logging from synapse.api.constants import EventTypes +from synapse.replication.tcp.streams.events import ( + EventsStreamCurrentStateRow, + EventsStreamEventRow, +) from synapse.storage.event_federation import EventFederationWorkerStore from synapse.storage.event_push_actions import EventPushActionsWorkerStore from synapse.storage.events_worker import EventsWorkerStore @@ -79,11 +83,7 @@ class SlavedEventStore(EventFederationWorkerStore, if stream_name == "events": self._stream_id_gen.advance(token) for row in rows: - self.invalidate_caches_for_event( - token, row.event_id, row.room_id, row.type, row.state_key, - row.redacts, - backfilled=False, - ) + self._process_event_stream_row(token, row) elif stream_name == "backfill": self._backfill_id_gen.advance(-token) for row in rows: @@ -96,6 +96,23 @@ class SlavedEventStore(EventFederationWorkerStore, stream_name, token, rows ) + def _process_event_stream_row(self, token, row): + data = row.data + + if row.type == EventsStreamEventRow.TypeId: + self.invalidate_caches_for_event( + token, data.event_id, data.room_id, data.type, data.state_key, + data.redacts, + backfilled=False, + ) + elif row.type == EventsStreamCurrentStateRow.TypeId: + if data.type == EventTypes.Member: + self.get_rooms_for_user_with_stream_ordering.invalidate( + (data.state_key, ), + ) + else: + raise Exception("Unknown events stream row type %s" % (row.type, )) + def invalidate_caches_for_event(self, stream_ordering, event_id, room_id, etype, state_key, redacts, backfilled): self._invalidate_get_event_cache(event_id) diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py index 8032f53fec..cc6f7f009f 100644 --- a/synapse/replication/slave/storage/keys.py +++ b/synapse/replication/slave/storage/keys.py @@ -13,22 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage import DataStore -from synapse.storage.keys import KeyStore +from synapse.storage import KeyStore -from ._base import BaseSlavedStore, __func__ +# KeyStore isn't really safe to use from a worker, but for now we do so and hope that +# the races it creates aren't too bad. - -class SlavedKeyStore(BaseSlavedStore): - _get_server_verify_key = KeyStore.__dict__[ - "_get_server_verify_key" - ] - - get_server_verify_keys = __func__(DataStore.get_server_verify_keys) - store_server_verify_key = __func__(DataStore.store_server_verify_key) - - get_server_certificate = __func__(DataStore.get_server_certificate) - store_server_certificate = __func__(DataStore.store_server_certificate) - - get_server_keys_json = __func__(DataStore.get_server_keys_json) - store_server_keys_json = __func__(DataStore.store_server_keys_json) +SlavedKeyStore = KeyStore diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 9e530defe0..0ec1db25ce 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -39,16 +39,6 @@ class SlavedPresenceStore(BaseSlavedStore): _get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"] get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"] - # XXX: This is a bit broken because we don't persist the accepted list in a - # way that can be replicated. This means that we don't have a way to - # invalidate the cache correctly. - get_presence_list_accepted = PresenceStore.__dict__[ - "get_presence_list_accepted" - ] - get_presence_list_observers_accepted = PresenceStore.__dict__[ - "get_presence_list_observers_accepted" - ] - def get_current_presence_token(self): return self._presence_id_gen.get_current_token() diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e558f90e1a..206dc3b397 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -103,10 +103,19 @@ class ReplicationClientHandler(object): hs.get_reactor().connectTCP(host, port, self.factory) def on_rdata(self, stream_name, token, rows): - """Called when we get new replication data. By default this just pokes - the slave store. + """Called to handle a batch of replication data with a given stream token. - Can be overriden in subclasses to handle more. + By default this just pokes the slave store. Can be overridden in subclasses to + handle more. + + Args: + stream_name (str): name of the replication stream for this batch of rows + token (int): stream token for this batch of rows + rows (list): a list of Stream.ROW_TYPE objects as returned by + Stream.parse_row. + + Returns: + Deferred|None """ logger.debug("Received rdata %s -> %s", stream_name, token) return self.store.process_replication_rows(stream_name, token, rows) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 02e5bf6cc8..b51590cf8f 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -42,8 +42,8 @@ indicate which side is sending, these are *not* included on the wire:: > POSITION backfill 1 > POSITION caches 1 > RDATA caches 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513] - > RDATA events 14 ["$149019767112vOHxz:localhost:8823", - "!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null] + > RDATA events 14 ["ev", ["$149019767112vOHxz:localhost:8823", + "!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null]] < PING 1490197675618 > ERROR server stopping * connection closed by server * @@ -605,7 +605,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): inbound_rdata_count.labels(stream_name).inc() try: - row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row) + row = STREAMS_MAP[stream_name].parse_row(cmd.row) except Exception: logger.exception( "[%s] Failed to parse RDATA: %r %r", diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 7fc346c7b6..f6a38f5140 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -30,7 +30,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.metrics import Measure, measure_func from .protocol import ServerReplicationStreamProtocol -from .streams import STREAMS_MAP, FederationStream +from .streams import STREAMS_MAP +from .streams.federation import FederationStream stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]) diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py new file mode 100644 index 0000000000..634f636dc9 --- /dev/null +++ b/synapse/replication/tcp/streams/__init__.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines all the valid streams that clients can subscribe to, and the format +of the rows returned by each stream. + +Each stream is defined by the following information: + + stream name: The name of the stream + row type: The type that is used to serialise/deserialse the row + current_token: The function that returns the current token for the stream + update_function: The function that returns a list of updates between two tokens +""" + +from . import _base, events, federation + +STREAMS_MAP = { + stream.NAME: stream + for stream in ( + events.EventsStream, + _base.BackfillStream, + _base.PresenceStream, + _base.TypingStream, + _base.ReceiptsStream, + _base.PushRulesStream, + _base.PushersStream, + _base.CachesStream, + _base.PublicRoomsStream, + _base.DeviceListsStream, + _base.ToDeviceStream, + federation.FederationStream, + _base.TagAccountDataStream, + _base.AccountDataStream, + _base.GroupServerStream, + ) +} diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams/_base.py index e23084baae..8971a6a22e 100644 --- a/synapse/replication/tcp/streams.py +++ b/synapse/replication/tcp/streams/_base.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017 Vector Creations Ltd +# Copyright 2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,16 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Defines all the valid streams that clients can subscribe to, and the format -of the rows returned by each stream. -Each stream is defined by the following information: - - stream name: The name of the stream - row type: The type that is used to serialise/deserialse the row - current_token: The function that returns the current token for the stream - update_function: The function that returns a list of updates between two tokens -""" import itertools import logging from collections import namedtuple @@ -34,14 +26,6 @@ logger = logging.getLogger(__name__) MAX_EVENTS_BEHIND = 10000 - -EventStreamRow = namedtuple("EventStreamRow", ( - "event_id", # str - "room_id", # str - "type", # str - "state_key", # str, optional - "redacts", # str, optional -)) BackfillStreamRow = namedtuple("BackfillStreamRow", ( "event_id", # str "room_id", # str @@ -96,10 +80,6 @@ DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", ( ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ( "entity", # str )) -FederationStreamRow = namedtuple("FederationStreamRow", ( - "type", # str, the type of data as defined in the BaseFederationRows - "data", # dict, serialization of a federation.send_queue.BaseFederationRow -)) TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", ( "user_id", # str "room_id", # str @@ -111,12 +91,6 @@ AccountDataStreamRow = namedtuple("AccountDataStream", ( "data_type", # str "data", # dict )) -CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", ( - "room_id", # str - "type", # str - "state_key", # str - "event_id", # str, optional -)) GroupsStreamRow = namedtuple("GroupsStreamRow", ( "group_id", # str "user_id", # str @@ -132,9 +106,24 @@ class Stream(object): time it was called up until the point `advance_current_token` was called. """ NAME = None # The name of the stream - ROW_TYPE = None # The type of the row + ROW_TYPE = None # The type of the row. Used by the default impl of parse_row. _LIMITED = True # Whether the update function takes a limit + @classmethod + def parse_row(cls, row): + """Parse a row received over replication + + By default, assumes that the row data is an array object and passes its contents + to the constructor of the ROW_TYPE for this stream. + + Args: + row: row data from the incoming RDATA command, after json decoding + + Returns: + ROW_TYPE object for this stream + """ + return cls.ROW_TYPE(*row) + def __init__(self, hs): # The token from which we last asked for updates self.last_token = self.current_token() @@ -162,8 +151,10 @@ class Stream(object): until the `upto_token` Returns: - (list(ROW_TYPE), int): list of updates plus the token used as an - upper bound of the updates (i.e. the "current token") + Deferred[Tuple[List[Tuple[int, Any]], int]: + Resolves to a pair ``(updates, current_token)``, where ``updates`` is a + list of ``(token, row)`` entries. ``row`` will be json-serialised and + sent over the replication steam. """ updates, current_token = yield self.get_updates_since(self.last_token) self.last_token = current_token @@ -176,8 +167,10 @@ class Stream(object): stream updates Returns: - (list(ROW_TYPE), int): list of updates plus the token used as an - upper bound of the updates (i.e. the "current token") + Deferred[Tuple[List[Tuple[int, Any]], int]: + Resolves to a pair ``(updates, current_token)``, where ``updates`` is a + list of ``(token, row)`` entries. ``row`` will be json-serialised and + sent over the replication steam. """ if from_token in ("NOW", "now"): defer.returnValue(([], self.upto_token)) @@ -202,7 +195,7 @@ class Stream(object): from_token, current_token, ) - updates = [(row[0], self.ROW_TYPE(*row[1:])) for row in rows] + updates = [(row[0], row[1:]) for row in rows] # check we didn't get more rows than the limit. # doing it like this allows the update_function to be a generator. @@ -232,20 +225,6 @@ class Stream(object): raise NotImplementedError() -class EventsStream(Stream): - """We received a new event, or an event went from being an outlier to not - """ - NAME = "events" - ROW_TYPE = EventStreamRow - - def __init__(self, hs): - store = hs.get_datastore() - self.current_token = store.get_current_events_token - self.update_function = store.get_all_new_forward_event_rows - - super(EventsStream, self).__init__(hs) - - class BackfillStream(Stream): """We fetched some old events and either we had never seen that event before or it went from being an outlier to not. @@ -400,22 +379,6 @@ class ToDeviceStream(Stream): super(ToDeviceStream, self).__init__(hs) -class FederationStream(Stream): - """Data to be sent over federation. Only available when master has federation - sending disabled. - """ - NAME = "federation" - ROW_TYPE = FederationStreamRow - - def __init__(self, hs): - federation_sender = hs.get_federation_sender() - - self.current_token = federation_sender.get_current_token - self.update_function = federation_sender.get_replication_rows - - super(FederationStream, self).__init__(hs) - - class TagAccountDataStream(Stream): """Someone added/removed a tag for a room """ @@ -459,21 +422,6 @@ class AccountDataStream(Stream): defer.returnValue(results) -class CurrentStateDeltaStream(Stream): - """Current state for a room was changed - """ - NAME = "current_state_deltas" - ROW_TYPE = CurrentStateDeltaStreamRow - - def __init__(self, hs): - store = hs.get_datastore() - - self.current_token = store.get_max_current_state_delta_stream_id - self.update_function = store.get_all_updated_current_state_deltas - - super(CurrentStateDeltaStream, self).__init__(hs) - - class GroupServerStream(Stream): NAME = "groups" ROW_TYPE = GroupsStreamRow @@ -485,26 +433,3 @@ class GroupServerStream(Stream): self.update_function = store.get_all_groups_changes super(GroupServerStream, self).__init__(hs) - - -STREAMS_MAP = { - stream.NAME: stream - for stream in ( - EventsStream, - BackfillStream, - PresenceStream, - TypingStream, - ReceiptsStream, - PushRulesStream, - PushersStream, - CachesStream, - PublicRoomsStream, - DeviceListsStream, - ToDeviceStream, - FederationStream, - TagAccountDataStream, - AccountDataStream, - CurrentStateDeltaStream, - GroupServerStream, - ) -} diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py new file mode 100644 index 0000000000..e0f6e29248 --- /dev/null +++ b/synapse/replication/tcp/streams/events.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import heapq + +import attr + +from twisted.internet import defer + +from ._base import Stream + + +"""Handling of the 'events' replication stream + +This stream contains rows of various types. Each row therefore contains a 'type' +identifier before the real data. For example:: + + RDATA events batch ["state", ["!room:id", "m.type", "", "$event:id"]] + RDATA events 12345 ["ev", ["$event:id", "!room:id", "m.type", null, null]] + +An "ev" row is sent for each new event. The fields in the data part are: + + * The new event id + * The room id for the event + * The type of the new event + * The state key of the event, for state events + * The event id of an event which is redacted by this event. + +A "state" row is sent whenever the "current state" in a room changes. The fields in the +data part are: + + * The room id for the state change + * The event type of the state which has changed + * The state_key of the state which has changed + * The event id of the new state + +""" + + +@attr.s(slots=True, frozen=True) +class EventsStreamRow(object): + """A parsed row from the events replication stream""" + type = attr.ib() # str: the TypeId of one of the *EventsStreamRows + data = attr.ib() # BaseEventsStreamRow + + +class BaseEventsStreamRow(object): + """Base class for rows to be sent in the events stream. + + Specifies how to identify, serialize and deserialize the different types. + """ + + TypeId = None # Unique string that ids the type. Must be overriden in sub classes. + + @classmethod + def from_data(cls, data): + """Parse the data from the replication stream into a row. + + By default we just call the constructor with the data list as arguments + + Args: + data: The value of the data object from the replication stream + """ + return cls(*data) + + +@attr.s(slots=True, frozen=True) +class EventsStreamEventRow(BaseEventsStreamRow): + TypeId = "ev" + + event_id = attr.ib() # str + room_id = attr.ib() # str + type = attr.ib() # str + state_key = attr.ib() # str, optional + redacts = attr.ib() # str, optional + + +@attr.s(slots=True, frozen=True) +class EventsStreamCurrentStateRow(BaseEventsStreamRow): + TypeId = "state" + + room_id = attr.ib() # str + type = attr.ib() # str + state_key = attr.ib() # str + event_id = attr.ib() # str, optional + + +TypeToRow = { + Row.TypeId: Row + for Row in ( + EventsStreamEventRow, + EventsStreamCurrentStateRow, + ) +} + + +class EventsStream(Stream): + """We received a new event, or an event went from being an outlier to not + """ + NAME = "events" + + def __init__(self, hs): + self._store = hs.get_datastore() + self.current_token = self._store.get_current_events_token + + super(EventsStream, self).__init__(hs) + + @defer.inlineCallbacks + def update_function(self, from_token, current_token, limit=None): + event_rows = yield self._store.get_all_new_forward_event_rows( + from_token, current_token, limit, + ) + event_updates = ( + (row[0], EventsStreamEventRow.TypeId, row[1:]) + for row in event_rows + ) + + state_rows = yield self._store.get_all_updated_current_state_deltas( + from_token, current_token, limit + ) + state_updates = ( + (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) + for row in state_rows + ) + + all_updates = heapq.merge(event_updates, state_updates) + + defer.returnValue(all_updates) + + @classmethod + def parse_row(cls, row): + (typ, data) = row + data = TypeToRow[typ].from_data(data) + return EventsStreamRow(typ, data) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py new file mode 100644 index 0000000000..9aa43aa8d2 --- /dev/null +++ b/synapse/replication/tcp/streams/federation.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import namedtuple + +from ._base import Stream + +FederationStreamRow = namedtuple("FederationStreamRow", ( + "type", # str, the type of data as defined in the BaseFederationRows + "data", # dict, serialization of a federation.send_queue.BaseFederationRow +)) + + +class FederationStream(Stream): + """Data to be sent over federation. Only available when master has federation + sending disabled. + """ + NAME = "federation" + ROW_TYPE = FederationStreamRow + + def __init__(self, hs): + federation_sender = hs.get_federation_sender() + + self.current_token = federation_sender.get_current_token + self.update_function = federation_sender.get_replication_rows + + super(FederationStream, self).__init__(hs) diff --git a/synapse/res/templates/mail-expiry.css b/synapse/res/templates/mail-expiry.css new file mode 100644 index 0000000000..3dea486467 --- /dev/null +++ b/synapse/res/templates/mail-expiry.css @@ -0,0 +1,4 @@ +.noticetext { + margin-top: 10px; + margin-bottom: 10px; +} diff --git a/synapse/res/templates/notice_expiry.html b/synapse/res/templates/notice_expiry.html new file mode 100644 index 0000000000..f0d7c66e1b --- /dev/null +++ b/synapse/res/templates/notice_expiry.html @@ -0,0 +1,43 @@ +<!doctype html> +<html lang="en"> + <head> + <style type="text/css"> + {% include 'mail.css' without context %} + {% include "mail-%s.css" % app_name ignore missing without context %} + {% include 'mail-expiry.css' without context %} + </style> + </head> + <body> + <table id="page"> + <tr> + <td> </td> + <td id="inner"> + <table class="header"> + <tr> + <td> + <div class="salutation">Hi {{ display_name }},</div> + </td> + <td class="logo"> + {% if app_name == "Riot" %} + <img src="http://riot.im/img/external/riot-logo-email.png" width="83" height="83" alt="[Riot]"/> + {% elif app_name == "Vector" %} + <img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/> + {% else %} + <img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/> + {% endif %} + </td> + </tr> + <tr> + <td colspan="2"> + <div class="noticetext">Your account will expire on {{ expiration_ts|format_ts("%d-%m-%Y") }}. This means that you will lose access to your account after this date.</div> + <div class="noticetext">To extend the validity of your account, please click on the link bellow (or copy and paste it into a new browser tab):</div> + <div class="noticetext"><a href="{{ url }}">{{ url }}</a></div> + </td> + </tr> + </table> + </td> + <td> </td> + </tr> + </table> + </body> +</html> diff --git a/synapse/res/templates/notice_expiry.txt b/synapse/res/templates/notice_expiry.txt new file mode 100644 index 0000000000..41f1c4279c --- /dev/null +++ b/synapse/res/templates/notice_expiry.txt @@ -0,0 +1,7 @@ +Hi {{ display_name }}, + +Your account will expire on {{ expiration_ts|format_ts("%d-%m-%Y") }}. This means that you will lose access to your account after this date. + +To extend the validity of your account, please click on the link bellow (or copy and paste it to a new browser tab): + +{{ url }} diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 91f5247d52..3a24d31d1b 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -13,11 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import synapse.rest.admin from synapse.http.server import JsonResource from synapse.rest.client import versions from synapse.rest.client.v1 import ( - admin, directory, events, initial_sync, @@ -33,6 +32,7 @@ from synapse.rest.client.v1 import ( from synapse.rest.client.v2_alpha import ( account, account_data, + account_validity, auth, capabilities, devices, @@ -57,8 +57,14 @@ from synapse.rest.client.v2_alpha import ( class ClientRestResource(JsonResource): - """A resource for version 1 of the matrix client API.""" + """Matrix Client API REST resource. + This gets mounted at various points under /_matrix/client, including: + * /_matrix/client/r0 + * /_matrix/client/api/v1 + * /_matrix/client/unstable + * etc + """ def __init__(self, hs): JsonResource.__init__(self, hs, canonical_json=False) self.register_servlets(self, hs) @@ -81,7 +87,6 @@ class ClientRestResource(JsonResource): presence.register_servlets(hs, client_resource) directory.register_servlets(hs, client_resource) voip.register_servlets(hs, client_resource) - admin.register_servlets(hs, client_resource) pusher.register_servlets(hs, client_resource) push_rule.register_servlets(hs, client_resource) logout.register_servlets(hs, client_resource) @@ -109,3 +114,9 @@ class ClientRestResource(JsonResource): groups.register_servlets(hs, client_resource) room_upgrade_rest_servlet.register_servlets(hs, client_resource) capabilities.register_servlets(hs, client_resource) + account_validity.register_servlets(hs, client_resource) + + # moving to /_synapse/admin + synapse.rest.admin.register_servlets_for_client_rest_resource( + hs, client_resource + ) diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/admin/__init__.py index e788769639..0ce89741f0 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/admin/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018-2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ import hashlib import hmac import logging import platform +import re from six import text_type from six.moves import http_client @@ -27,39 +28,56 @@ from twisted.internet import defer import synapse from synapse.api.constants import Membership, UserTypes from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.http.server import JsonResource from synapse.http.servlet import ( + RestServlet, assert_params_in_dict, parse_integer, parse_json_object_from_request, parse_string, ) +from synapse.rest.admin._base import assert_requester_is_admin, assert_user_is_admin +from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet from synapse.types import UserID, create_requester from synapse.util.versionstring import get_version_string -from .base import ClientV1RestServlet, client_path_patterns - logger = logging.getLogger(__name__) -class UsersRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/admin/users/(?P<user_id>[^/]*)") +def historical_admin_path_patterns(path_regex): + """Returns the list of patterns for an admin endpoint, including historical ones + + This is a backwards-compatibility hack. Previously, the Admin API was exposed at + various paths under /_matrix/client. This function returns a list of patterns + matching those paths (as well as the new one), so that existing scripts which rely + on the endpoints being available there are not broken. + + Note that this should only be used for existing endpoints: new ones should just + register for the /_synapse/admin path. + """ + return list( + re.compile(prefix + path_regex) + for prefix in ( + "^/_synapse/admin/v1", + "^/_matrix/client/api/v1/admin", + "^/_matrix/client/unstable/admin", + "^/_matrix/client/r0/admin" + ) + ) + + +class UsersRestServlet(RestServlet): + PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)") def __init__(self, hs): - super(UsersRestServlet, self).__init__(hs) + self.hs = hs + self.auth = hs.get_auth() self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_GET(self, request, user_id): target_user = UserID.from_string(user_id) - requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) - - if not is_admin: - raise AuthError(403, "You are not a server admin") - - # To allow all users to get the users list - # if not is_admin and target_user != auth_user: - # raise AuthError(403, "You are not a server admin") + yield assert_requester_is_admin(self.auth, request) if not self.hs.is_mine(target_user): raise SynapseError(400, "Can only users a local user") @@ -69,16 +87,15 @@ class UsersRestServlet(ClientV1RestServlet): defer.returnValue((200, ret)) -class VersionServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/admin/server_version") +class VersionServlet(RestServlet): + PATTERNS = historical_admin_path_patterns("/server_version") + + def __init__(self, hs): + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) - - if not is_admin: - raise AuthError(403, "You are not a server admin") + yield assert_requester_is_admin(self.auth, request) ret = { 'server_version': get_version_string(synapse), @@ -88,18 +105,17 @@ class VersionServlet(ClientV1RestServlet): defer.returnValue((200, ret)) -class UserRegisterServlet(ClientV1RestServlet): +class UserRegisterServlet(RestServlet): """ Attributes: NONCE_TIMEOUT (int): Seconds until a generated nonce won't be accepted nonces (dict[str, int]): The nonces that we will accept. A dict of nonce to the time it was generated, in int seconds. """ - PATTERNS = client_path_patterns("/admin/register") + PATTERNS = historical_admin_path_patterns("/register") NONCE_TIMEOUT = 60 def __init__(self, hs): - super(UserRegisterServlet, self).__init__(hs) self.handlers = hs.get_handlers() self.reactor = hs.get_reactor() self.nonces = {} @@ -226,11 +242,12 @@ class UserRegisterServlet(ClientV1RestServlet): defer.returnValue((200, result)) -class WhoisRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)") +class WhoisRestServlet(RestServlet): + PATTERNS = historical_admin_path_patterns("/whois/(?P<user_id>[^/]*)") def __init__(self, hs): - super(WhoisRestServlet, self).__init__(hs) + self.hs = hs + self.auth = hs.get_auth() self.handlers = hs.get_handlers() @defer.inlineCallbacks @@ -238,10 +255,9 @@ class WhoisRestServlet(ClientV1RestServlet): target_user = UserID.from_string(user_id) requester = yield self.auth.get_user_by_req(request) auth_user = requester.user - is_admin = yield self.auth.is_server_admin(requester.user) - if not is_admin and target_user != auth_user: - raise AuthError(403, "You are not a server admin") + if target_user != auth_user: + yield assert_user_is_admin(self.auth, auth_user) if not self.hs.is_mine(target_user): raise SynapseError(400, "Can only whois a local user") @@ -251,20 +267,16 @@ class WhoisRestServlet(ClientV1RestServlet): defer.returnValue((200, ret)) -class PurgeMediaCacheRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/admin/purge_media_cache") +class PurgeMediaCacheRestServlet(RestServlet): + PATTERNS = historical_admin_path_patterns("/purge_media_cache") def __init__(self, hs): self.media_repository = hs.get_media_repository() - super(PurgeMediaCacheRestServlet, self).__init__(hs) + self.auth = hs.get_auth() @defer.inlineCallbacks def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) - - if not is_admin: - raise AuthError(403, "You are not a server admin") + yield assert_requester_is_admin(self.auth, request) before_ts = parse_integer(request, "before_ts", required=True) logger.info("before_ts: %r", before_ts) @@ -274,9 +286,9 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet): defer.returnValue((200, ret)) -class PurgeHistoryRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns( - "/admin/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?" +class PurgeHistoryRestServlet(RestServlet): + PATTERNS = historical_admin_path_patterns( + "/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?" ) def __init__(self, hs): @@ -285,17 +297,13 @@ class PurgeHistoryRestServlet(ClientV1RestServlet): Args: hs (synapse.server.HomeServer) """ - super(PurgeHistoryRestServlet, self).__init__(hs) self.pagination_handler = hs.get_pagination_handler() self.store = hs.get_datastore() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_POST(self, request, room_id, event_id): - requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) - - if not is_admin: - raise AuthError(403, "You are not a server admin") + yield assert_requester_is_admin(self.auth, request) body = parse_json_object_from_request(request, allow_empty_body=True) @@ -371,9 +379,9 @@ class PurgeHistoryRestServlet(ClientV1RestServlet): })) -class PurgeHistoryStatusRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns( - "/admin/purge_history_status/(?P<purge_id>[^/]+)" +class PurgeHistoryStatusRestServlet(RestServlet): + PATTERNS = historical_admin_path_patterns( + "/purge_history_status/(?P<purge_id>[^/]+)" ) def __init__(self, hs): @@ -382,16 +390,12 @@ class PurgeHistoryStatusRestServlet(ClientV1RestServlet): Args: hs (synapse.server.HomeServer) """ - super(PurgeHistoryStatusRestServlet, self).__init__(hs) self.pagination_handler = hs.get_pagination_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, purge_id): - requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) - - if not is_admin: - raise AuthError(403, "You are not a server admin") + yield assert_requester_is_admin(self.auth, request) purge_status = self.pagination_handler.get_purge_status(purge_id) if purge_status is None: @@ -400,15 +404,16 @@ class PurgeHistoryStatusRestServlet(ClientV1RestServlet): defer.returnValue((200, purge_status.asdict())) -class DeactivateAccountRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)") +class DeactivateAccountRestServlet(RestServlet): + PATTERNS = historical_admin_path_patterns("/deactivate/(?P<target_user_id>[^/]*)") def __init__(self, hs): - super(DeactivateAccountRestServlet, self).__init__(hs) self._deactivate_account_handler = hs.get_deactivate_account_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_POST(self, request, target_user_id): + yield assert_requester_is_admin(self.auth, request) body = parse_json_object_from_request(request, allow_empty_body=True) erase = body.get("erase", False) if not isinstance(erase, bool): @@ -419,11 +424,6 @@ class DeactivateAccountRestServlet(ClientV1RestServlet): ) UserID.from_string(target_user_id) - requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) - - if not is_admin: - raise AuthError(403, "You are not a server admin") result = yield self._deactivate_account_handler.deactivate_account( target_user_id, erase, @@ -438,13 +438,13 @@ class DeactivateAccountRestServlet(ClientV1RestServlet): })) -class ShutdownRoomRestServlet(ClientV1RestServlet): +class ShutdownRoomRestServlet(RestServlet): """Shuts down a room by removing all local users from the room and blocking all future invites and joins to the room. Any local aliases will be repointed to a new room created by `new_room_user_id` and kicked users will be auto joined to the new room. """ - PATTERNS = client_path_patterns("/admin/shutdown_room/(?P<room_id>[^/]+)") + PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)") DEFAULT_MESSAGE = ( "Sharing illegal content on this server is not permitted and rooms in" @@ -452,19 +452,18 @@ class ShutdownRoomRestServlet(ClientV1RestServlet): ) def __init__(self, hs): - super(ShutdownRoomRestServlet, self).__init__(hs) + self.hs = hs self.store = hs.get_datastore() self.state = hs.get_state_handler() self._room_creation_handler = hs.get_room_creation_handler() self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_POST(self, request, room_id): requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) - if not is_admin: - raise AuthError(403, "You are not a server admin") + yield assert_user_is_admin(self.auth, requester.user) content = parse_json_object_from_request(request) assert_params_in_dict(content, ["new_room_user_id"]) @@ -499,7 +498,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet): # desirable in case the first attempt at blocking the room failed below. yield self.store.block_room(room_id, requester_user_id) - users = yield self.state.get_current_user_in_room(room_id) + users = yield self.state.get_current_users_in_room(room_id) kicked_users = [] failed_to_kick_users = [] for user_id in users: @@ -564,22 +563,20 @@ class ShutdownRoomRestServlet(ClientV1RestServlet): })) -class QuarantineMediaInRoom(ClientV1RestServlet): +class QuarantineMediaInRoom(RestServlet): """Quarantines all media in a room so that no one can download it via this server. """ - PATTERNS = client_path_patterns("/admin/quarantine_media/(?P<room_id>[^/]+)") + PATTERNS = historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)") def __init__(self, hs): - super(QuarantineMediaInRoom, self).__init__(hs) self.store = hs.get_datastore() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_POST(self, request, room_id): requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) - if not is_admin: - raise AuthError(403, "You are not a server admin") + yield assert_user_is_admin(self.auth, requester.user) num_quarantined = yield self.store.quarantine_media_ids_in_room( room_id, requester.user.to_string(), @@ -588,13 +585,12 @@ class QuarantineMediaInRoom(ClientV1RestServlet): defer.returnValue((200, {"num_quarantined": num_quarantined})) -class ListMediaInRoom(ClientV1RestServlet): +class ListMediaInRoom(RestServlet): """Lists all of the media in a given room. """ - PATTERNS = client_path_patterns("/admin/room/(?P<room_id>[^/]+)/media") + PATTERNS = historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media") def __init__(self, hs): - super(ListMediaInRoom, self).__init__(hs) self.store = hs.get_datastore() @defer.inlineCallbacks @@ -609,11 +605,11 @@ class ListMediaInRoom(ClientV1RestServlet): defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs})) -class ResetPasswordRestServlet(ClientV1RestServlet): +class ResetPasswordRestServlet(RestServlet): """Post request to allow an administrator reset password for a user. This needs user to have administrator access in Synapse. Example: - http://localhost:8008/_matrix/client/api/v1/admin/reset_password/ + http://localhost:8008/_synapse/admin/v1/reset_password/ @user:to_reset_password?access_token=admin_access_token JsonBodyToSend: { @@ -622,11 +618,10 @@ class ResetPasswordRestServlet(ClientV1RestServlet): Returns: 200 OK with empty object if success otherwise an error. """ - PATTERNS = client_path_patterns("/admin/reset_password/(?P<target_user_id>[^/]*)") + PATTERNS = historical_admin_path_patterns("/reset_password/(?P<target_user_id>[^/]*)") def __init__(self, hs): self.store = hs.get_datastore() - super(ResetPasswordRestServlet, self).__init__(hs) self.hs = hs self.auth = hs.get_auth() self._set_password_handler = hs.get_set_password_handler() @@ -636,39 +631,34 @@ class ResetPasswordRestServlet(ClientV1RestServlet): """Post request to allow an administrator reset password for a user. This needs user to have administrator access in Synapse. """ - UserID.from_string(target_user_id) requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) + yield assert_user_is_admin(self.auth, requester.user) - if not is_admin: - raise AuthError(403, "You are not a server admin") + UserID.from_string(target_user_id) params = parse_json_object_from_request(request) assert_params_in_dict(params, ["new_password"]) new_password = params['new_password'] - logger.info("new_password: %r", new_password) - yield self._set_password_handler.set_password( target_user_id, new_password, requester ) defer.returnValue((200, {})) -class GetUsersPaginatedRestServlet(ClientV1RestServlet): +class GetUsersPaginatedRestServlet(RestServlet): """Get request to get specific number of users from Synapse. This needs user to have administrator access in Synapse. Example: - http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/ + http://localhost:8008/_synapse/admin/v1/users_paginate/ @admin:user?access_token=admin_access_token&start=0&limit=10 Returns: 200 OK with json object {list[dict[str, Any]], count} or empty object. """ - PATTERNS = client_path_patterns("/admin/users_paginate/(?P<target_user_id>[^/]*)") + PATTERNS = historical_admin_path_patterns("/users_paginate/(?P<target_user_id>[^/]*)") def __init__(self, hs): self.store = hs.get_datastore() - super(GetUsersPaginatedRestServlet, self).__init__(hs) self.hs = hs self.auth = hs.get_auth() self.handlers = hs.get_handlers() @@ -678,16 +668,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet): """Get request to get specific number of users from Synapse. This needs user to have administrator access in Synapse. """ - target_user = UserID.from_string(target_user_id) - requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) + yield assert_requester_is_admin(self.auth, request) - if not is_admin: - raise AuthError(403, "You are not a server admin") - - # To allow all users to get the users list - # if not is_admin and target_user != auth_user: - # raise AuthError(403, "You are not a server admin") + target_user = UserID.from_string(target_user_id) if not self.hs.is_mine(target_user): raise SynapseError(400, "Can only users a local user") @@ -708,7 +691,7 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet): """Post request to get specific number of users from Synapse.. This needs user to have administrator access in Synapse. Example: - http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/ + http://localhost:8008/_synapse/admin/v1/users_paginate/ @admin:user?access_token=admin_access_token JsonBodyToSend: { @@ -718,12 +701,8 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet): Returns: 200 OK with json object {list[dict[str, Any]], count} or empty object. """ + yield assert_requester_is_admin(self.auth, request) UserID.from_string(target_user_id) - requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) - - if not is_admin: - raise AuthError(403, "You are not a server admin") order = "name" # order by name in user table params = parse_json_object_from_request(request) @@ -738,21 +717,20 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet): defer.returnValue((200, ret)) -class SearchUsersRestServlet(ClientV1RestServlet): +class SearchUsersRestServlet(RestServlet): """Get request to search user table for specific users according to search term. This needs user to have administrator access in Synapse. Example: - http://localhost:8008/_matrix/client/api/v1/admin/search_users/ + http://localhost:8008/_synapse/admin/v1/search_users/ @admin:user?access_token=admin_access_token&term=alice Returns: 200 OK with json object {list[dict[str, Any]], count} or empty object. """ - PATTERNS = client_path_patterns("/admin/search_users/(?P<target_user_id>[^/]*)") + PATTERNS = historical_admin_path_patterns("/search_users/(?P<target_user_id>[^/]*)") def __init__(self, hs): self.store = hs.get_datastore() - super(SearchUsersRestServlet, self).__init__(hs) self.hs = hs self.auth = hs.get_auth() self.handlers = hs.get_handlers() @@ -763,12 +741,9 @@ class SearchUsersRestServlet(ClientV1RestServlet): search term. This needs user to have a administrator access in Synapse. """ - target_user = UserID.from_string(target_user_id) - requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) + yield assert_requester_is_admin(self.auth, request) - if not is_admin: - raise AuthError(403, "You are not a server admin") + target_user = UserID.from_string(target_user_id) # To allow all users to get the users list # if not is_admin and target_user != auth_user: @@ -786,7 +761,79 @@ class SearchUsersRestServlet(ClientV1RestServlet): defer.returnValue((200, ret)) -def register_servlets(hs, http_server): +class DeleteGroupAdminRestServlet(RestServlet): + """Allows deleting of local groups + """ + PATTERNS = historical_admin_path_patterns("/delete_group/(?P<group_id>[^/]*)") + + def __init__(self, hs): + self.group_server = hs.get_groups_server_handler() + self.is_mine_id = hs.is_mine_id + self.auth = hs.get_auth() + + @defer.inlineCallbacks + def on_POST(self, request, group_id): + requester = yield self.auth.get_user_by_req(request) + yield assert_user_is_admin(self.auth, requester.user) + + if not self.is_mine_id(group_id): + raise SynapseError(400, "Can only delete local groups") + + yield self.group_server.delete_group(group_id, requester.user.to_string()) + defer.returnValue((200, {})) + + +class AccountValidityRenewServlet(RestServlet): + PATTERNS = historical_admin_path_patterns("/account_validity/validity$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + self.hs = hs + self.account_activity_handler = hs.get_account_validity_handler() + self.auth = hs.get_auth() + + @defer.inlineCallbacks + def on_POST(self, request): + yield assert_requester_is_admin(self.auth, request) + + body = parse_json_object_from_request(request) + + if "user_id" not in body: + raise SynapseError(400, "Missing property 'user_id' in the request body") + + expiration_ts = yield self.account_activity_handler.renew_account_for_user( + body["user_id"], body.get("expiration_ts"), + not body.get("enable_renewal_emails", True), + ) + + res = { + "expiration_ts": expiration_ts, + } + defer.returnValue((200, res)) + +######################################################################################## +# +# please don't add more servlets here: this file is already long and unwieldy. Put +# them in separate files within the 'admin' package. +# +######################################################################################## + + +class AdminRestResource(JsonResource): + """The REST resource which gets mounted at /_synapse/admin""" + + def __init__(self, hs): + JsonResource.__init__(self, hs, canonical_json=False) + + register_servlets_for_client_rest_resource(hs, self) + SendServerNoticeServlet(hs).register(self) + + +def register_servlets_for_client_rest_resource(hs, http_server): + """Register only the servlets which need to be exposed on /_matrix/client/xxx""" WhoisRestServlet(hs).register(http_server) PurgeMediaCacheRestServlet(hs).register(http_server) PurgeHistoryStatusRestServlet(hs).register(http_server) @@ -801,3 +848,7 @@ def register_servlets(hs, http_server): ListMediaInRoom(hs).register(http_server) UserRegisterServlet(hs).register(http_server) VersionServlet(hs).register(http_server) + DeleteGroupAdminRestServlet(hs).register(http_server) + AccountValidityRenewServlet(hs).register(http_server) + # don't add more things here: new servlets should only be exposed on + # /_synapse/admin so should not go here. Instead register them in AdminRestResource. diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py new file mode 100644 index 0000000000..881d67b89c --- /dev/null +++ b/synapse/rest/admin/_base.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.internet import defer + +from synapse.api.errors import AuthError + + +@defer.inlineCallbacks +def assert_requester_is_admin(auth, request): + """Verify that the requester is an admin user + + WARNING: MAKE SURE YOU YIELD ON THE RESULT! + + Args: + auth (synapse.api.auth.Auth): + request (twisted.web.server.Request): incoming request + + Returns: + Deferred + + Raises: + AuthError if the requester is not an admin + """ + requester = yield auth.get_user_by_req(request) + yield assert_user_is_admin(auth, requester.user) + + +@defer.inlineCallbacks +def assert_user_is_admin(auth, user_id): + """Verify that the given user is an admin user + + WARNING: MAKE SURE YOU YIELD ON THE RESULT! + + Args: + auth (synapse.api.auth.Auth): + user_id (UserID): + + Returns: + Deferred + + Raises: + AuthError if the user is not an admin + """ + + is_admin = yield auth.is_server_admin(user_id) + if not is_admin: + raise AuthError(403, "You are not a server admin") diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py new file mode 100644 index 0000000000..ae5aca9dac --- /dev/null +++ b/synapse/rest/admin/server_notice_servlet.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re + +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.api.errors import SynapseError +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) +from synapse.rest.admin import assert_requester_is_admin +from synapse.rest.client.transactions import HttpTransactionCache +from synapse.types import UserID + + +class SendServerNoticeServlet(RestServlet): + """Servlet which will send a server notice to a given user + + POST /_synapse/admin/v1/send_server_notice + { + "user_id": "@target_user:server_name", + "content": { + "msgtype": "m.text", + "body": "This is my message" + } + } + + returns: + + { + "event_id": "$1895723857jgskldgujpious" + } + """ + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + self.hs = hs + self.auth = hs.get_auth() + self.txns = HttpTransactionCache(hs) + self.snm = hs.get_server_notices_manager() + + def register(self, json_resource): + PATTERN = "^/_synapse/admin/v1/send_server_notice" + json_resource.register_paths( + "POST", + (re.compile(PATTERN + "$"), ), + self.on_POST, + ) + json_resource.register_paths( + "PUT", + (re.compile(PATTERN + "/(?P<txn_id>[^/]*)$",), ), + self.on_PUT, + ) + + @defer.inlineCallbacks + def on_POST(self, request, txn_id=None): + yield assert_requester_is_admin(self.auth, request) + body = parse_json_object_from_request(request) + assert_params_in_dict(body, ("user_id", "content")) + event_type = body.get("type", EventTypes.Message) + state_key = body.get("state_key") + + if not self.snm.is_enabled(): + raise SynapseError(400, "Server notices are not enabled on this server") + + user_id = body["user_id"] + UserID.from_string(user_id) + if not self.hs.is_mine_id(user_id): + raise SynapseError(400, "Server notices can only be sent to local users") + + event = yield self.snm.send_notice( + user_id=body["user_id"], + type=event_type, + state_key=state_key, + event_content=body["content"], + ) + + defer.returnValue((200, {"event_id": event.event_id})) + + def on_PUT(self, request, txn_id): + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, txn_id, + ) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index b5a6d6aebf..045d5a20ac 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -93,72 +93,5 @@ class PresenceStatusRestServlet(ClientV1RestServlet): return (200, {}) -class PresenceListRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/presence/list/(?P<user_id>[^/]*)") - - def __init__(self, hs): - super(PresenceListRestServlet, self).__init__(hs) - self.presence_handler = hs.get_presence_handler() - - @defer.inlineCallbacks - def on_GET(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) - user = UserID.from_string(user_id) - - if not self.hs.is_mine(user): - raise SynapseError(400, "User not hosted on this Home Server") - - if requester.user != user: - raise SynapseError(400, "Cannot get another user's presence list") - - presence = yield self.presence_handler.get_presence_list( - observer_user=user, accepted=True - ) - - defer.returnValue((200, presence)) - - @defer.inlineCallbacks - def on_POST(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) - user = UserID.from_string(user_id) - - if not self.hs.is_mine(user): - raise SynapseError(400, "User not hosted on this Home Server") - - if requester.user != user: - raise SynapseError( - 400, "Cannot modify another user's presence list") - - content = parse_json_object_from_request(request) - - if "invite" in content: - for u in content["invite"]: - if not isinstance(u, string_types): - raise SynapseError(400, "Bad invite value.") - if len(u) == 0: - continue - invited_user = UserID.from_string(u) - yield self.presence_handler.send_presence_invite( - observer_user=user, observed_user=invited_user - ) - - if "drop" in content: - for u in content["drop"]: - if not isinstance(u, string_types): - raise SynapseError(400, "Bad drop value.") - if len(u) == 0: - continue - dropped_user = UserID.from_string(u) - yield self.presence_handler.drop( - observer_user=user, observed_user=dropped_user - ) - - defer.returnValue((200, {})) - - def on_OPTIONS(self, request): - return (200, {}) - - def register_servlets(hs, http_server): PresenceStatusRestServlet(hs).register(http_server) - PresenceListRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index c654f9b5f0..506ec95ddd 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -31,7 +31,7 @@ from .base import ClientV1RestServlet, client_path_patterns class PushRuleRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/pushrules/.*$") + PATTERNS = client_path_patterns("/(?P<path>pushrules/.*)$") SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( "Unrecognised request: You probably wanted a trailing slash") @@ -39,10 +39,14 @@ class PushRuleRestServlet(ClientV1RestServlet): super(PushRuleRestServlet, self).__init__(hs) self.store = hs.get_datastore() self.notifier = hs.get_notifier() + self._is_worker = hs.config.worker_app is not None @defer.inlineCallbacks - def on_PUT(self, request): - spec = _rule_spec_from_path([x.decode('utf8') for x in request.postpath]) + def on_PUT(self, request, path): + if self._is_worker: + raise Exception("Cannot handle PUT /push_rules on worker") + + spec = _rule_spec_from_path([x for x in path.split("/")]) try: priority_class = _priority_class_from_spec(spec) except InvalidRuleException as e: @@ -102,8 +106,11 @@ class PushRuleRestServlet(ClientV1RestServlet): defer.returnValue((200, {})) @defer.inlineCallbacks - def on_DELETE(self, request): - spec = _rule_spec_from_path([x.decode('utf8') for x in request.postpath]) + def on_DELETE(self, request, path): + if self._is_worker: + raise Exception("Cannot handle DELETE /push_rules on worker") + + spec = _rule_spec_from_path([x for x in path.split("/")]) requester = yield self.auth.get_user_by_req(request) user_id = requester.user.to_string() @@ -123,7 +130,7 @@ class PushRuleRestServlet(ClientV1RestServlet): raise @defer.inlineCallbacks - def on_GET(self, request): + def on_GET(self, request, path): requester = yield self.auth.get_user_by_req(request) user_id = requester.user.to_string() @@ -134,7 +141,7 @@ class PushRuleRestServlet(ClientV1RestServlet): rules = format_push_rules_for_user(requester.user, rules) - path = [x.decode('utf8') for x in request.postpath][1:] + path = [x for x in path.split("/")][1:] if path == []: # we're a reference impl: pedantry is our job. @@ -150,7 +157,7 @@ class PushRuleRestServlet(ClientV1RestServlet): else: raise UnrecognizedRequestError() - def on_OPTIONS(self, _): + def on_OPTIONS(self, request, path): return 200, {} def notify_user(self, user_id): diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 37b32dd37b..ee069179f0 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -215,6 +215,7 @@ class DeactivateAccountRestServlet(RestServlet): ) result = yield self._deactivate_account_handler.deactivate_account( requester.user.to_string(), erase, + id_server=body.get("id_server"), ) if result: id_server_unbind_result = "success" @@ -363,7 +364,7 @@ class ThreepidRestServlet(RestServlet): class ThreepidDeleteRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/3pid/delete$", releases=()) + PATTERNS = client_v2_patterns("/account/3pid/delete$") def __init__(self, hs): super(ThreepidDeleteRestServlet, self).__init__() @@ -380,7 +381,7 @@ class ThreepidDeleteRestServlet(RestServlet): try: ret = yield self.auth_handler.delete_threepid( - user_id, body['medium'], body['address'] + user_id, body['medium'], body['address'], body.get("id_server"), ) except Exception: # NB. This endpoint should succeed if there is nothing to diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py new file mode 100644 index 0000000000..fc8dbeb617 --- /dev/null +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.api.errors import AuthError, SynapseError +from synapse.http.server import finish_request +from synapse.http.servlet import RestServlet + +from ._base import client_v2_patterns + +logger = logging.getLogger(__name__) + + +class AccountValidityRenewServlet(RestServlet): + PATTERNS = client_v2_patterns("/account_validity/renew$") + SUCCESS_HTML = b"<html><body>Your account has been successfully renewed.</body><html>" + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(AccountValidityRenewServlet, self).__init__() + + self.hs = hs + self.account_activity_handler = hs.get_account_validity_handler() + self.auth = hs.get_auth() + + @defer.inlineCallbacks + def on_GET(self, request): + if b"token" not in request.args: + raise SynapseError(400, "Missing renewal token") + renewal_token = request.args[b"token"][0] + + yield self.account_activity_handler.renew_account(renewal_token.decode('utf8')) + + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % ( + len(AccountValidityRenewServlet.SUCCESS_HTML), + )) + request.write(AccountValidityRenewServlet.SUCCESS_HTML) + finish_request(request) + defer.returnValue(None) + + +class AccountValiditySendMailServlet(RestServlet): + PATTERNS = client_v2_patterns("/account_validity/send_mail$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(AccountValiditySendMailServlet, self).__init__() + + self.hs = hs + self.account_activity_handler = hs.get_account_validity_handler() + self.auth = hs.get_auth() + self.account_validity = self.hs.config.account_validity + + @defer.inlineCallbacks + def on_POST(self, request): + if not self.account_validity.renew_by_email_enabled: + raise AuthError(403, "Account renewal via email is disabled on this server.") + + requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + yield self.account_activity_handler.send_renewal_email_to_user(user_id) + + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + AccountValidityRenewServlet(hs).register(http_server) + AccountValiditySendMailServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py index 373f95126e..a868d06098 100644 --- a/synapse/rest/client/v2_alpha/capabilities.py +++ b/synapse/rest/client/v2_alpha/capabilities.py @@ -16,7 +16,7 @@ import logging from twisted.internet import defer -from synapse.api.constants import DEFAULT_ROOM_VERSION, RoomDisposition, RoomVersions +from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS from synapse.http.servlet import RestServlet from ._base import client_v2_patterns @@ -48,12 +48,10 @@ class CapabilitiesRestServlet(RestServlet): response = { "capabilities": { "m.room_versions": { - "default": DEFAULT_ROOM_VERSION, + "default": DEFAULT_ROOM_VERSION.identifier, "available": { - RoomVersions.V1: RoomDisposition.STABLE, - RoomVersions.V2: RoomDisposition.STABLE, - RoomVersions.STATE_V2_TEST: RoomDisposition.UNSTABLE, - RoomVersions.V3: RoomDisposition.STABLE, + v.identifier: v.disposition + for v in KNOWN_ROOM_VERSIONS.values() }, }, "m.change_password": {"enabled": change_password}, diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 6d235262c8..dc3e265bcd 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -391,6 +391,13 @@ class RegisterRestServlet(RestServlet): # the user-facing checks will probably already have happened in # /register/email/requestToken when we requested a 3pid, but that's not # guaranteed. + # + # Also check that we're not trying to register a 3pid that's already + # been registered. + # + # This has probably happened in /register/email/requestToken as well, + # but if a user hits this endpoint twice then clicks on each link from + # the two activation emails, they would register the same 3pid twice. if auth_result: for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: @@ -406,6 +413,17 @@ class RegisterRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + existingUid = yield self.store.get_user_id_by_threepid( + medium, address, + ) + + if existingUid is not None: + raise SynapseError( + 400, + "%s is already in use" % medium, + Codes.THREEPID_IN_USE, + ) + if registered_user_id is not None: logger.info( "Already registered user ID %r for this session", diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py index e6356101fd..3db7ff8d1b 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py @@ -17,8 +17,8 @@ import logging from twisted.internet import defer -from synapse.api.constants import KNOWN_ROOM_VERSIONS from synapse.api.errors import Codes, SynapseError +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.servlet import ( RestServlet, assert_params_in_dict, diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 953d89bd82..2dcc8f74d6 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2019 New Vector Ltd. +# Copyright 2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -191,6 +191,10 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam # in that case. logger.warning("Failed to write to consumer: %s %s", type(e), e) + # Unregister the producer, if it has one, so Twisted doesn't complain + if request.producer: + request.unregisterProducer() + finish_request(request) diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index c0a4ae93e5..a7fa4f39af 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd. +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -68,6 +68,6 @@ class WellKnownResource(Resource): request.setHeader(b"Content-Type", b"text/plain") return b'.well-known not available' - logger.error("returning: %s", r) + logger.debug("returning: %s", r) request.setHeader(b"Content-Type", b"application/json") return json.dumps(r).encode("utf-8") diff --git a/synapse/server.py b/synapse/server.py index dc8f1ccb8c..8c30ac2fa5 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -47,6 +47,7 @@ from synapse.federation.transport.client import TransportLayerClient from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer from synapse.groups.groups_server import GroupsServerHandler from synapse.handlers import Handlers +from synapse.handlers.account_validity import AccountValidityHandler from synapse.handlers.acme import AcmeHandler from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.auth import AuthHandler, MacaroonGenerator @@ -183,6 +184,7 @@ class HomeServer(object): 'room_context_handler', 'sendmail', 'registration_handler', + 'account_validity_handler', ] REQUIRED_ON_MASTER_STARTUP = [ @@ -506,6 +508,9 @@ class HomeServer(object): def build_registration_handler(self): return RegistrationHandler(self) + def build_account_validity_handler(self): + return AccountValidityHandler(self) + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 68058f613c..36684ef9f6 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -24,7 +24,8 @@ from frozendict import frozendict from twisted.internet import defer -from synapse.api.constants import EventTypes, RoomVersions +from synapse.api.constants import EventTypes +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions from synapse.events.snapshot import EventContext from synapse.state import v1, v2 from synapse.util.async_helpers import Linearizer @@ -160,10 +161,21 @@ class StateHandler(object): defer.returnValue(state) @defer.inlineCallbacks - def get_current_user_in_room(self, room_id, latest_event_ids=None): + def get_current_users_in_room(self, room_id, latest_event_ids=None): + """ + Get the users who are currently in a room. + + Args: + room_id (str): The ID of the room. + latest_event_ids (List[str]|None): Precomputed list of latest + event IDs. Will be computed if None. + Returns: + Deferred[Dict[str,ProfileInfo]]: Dictionary of user IDs to their + profileinfo. + """ if not latest_event_ids: latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - logger.debug("calling resolve_state_groups from get_current_user_in_room") + logger.debug("calling resolve_state_groups from get_current_users_in_room") entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) joined_users = yield self.store.get_joined_users_from_state(room_id, entry) defer.returnValue(joined_users) @@ -603,22 +615,15 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. """ - if room_version == RoomVersions.V1: + v = KNOWN_ROOM_VERSIONS[room_version] + if v.state_res == StateResolutionVersions.V1: return v1.resolve_events_with_store( state_sets, event_map, state_res_store.get_events, ) - elif room_version in ( - RoomVersions.STATE_V2_TEST, RoomVersions.V2, RoomVersions.V3, - ): + else: return v2.resolve_events_with_store( room_version, state_sets, event_map, state_res_store, ) - else: - # This should only happen if we added a version but forgot to add it to - # the list above. - raise Exception( - "No state resolution algorithm defined for version %r" % (room_version,) - ) @attr.s diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 6d3afcae7c..29b4e86cfd 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -21,8 +21,9 @@ from six import iteritems, iterkeys, itervalues from twisted.internet import defer from synapse import event_auth -from synapse.api.constants import EventTypes, RoomVersions +from synapse.api.constants import EventTypes from synapse.api.errors import AuthError +from synapse.api.room_versions import RoomVersions logger = logging.getLogger(__name__) @@ -275,7 +276,9 @@ def _resolve_auth_events(events, auth_events): try: # The signatures have already been checked at this point event_auth.check( - RoomVersions.V1, event, auth_events, + RoomVersions.V1.identifier, + event, + auth_events, do_sig_check=False, do_size_check=False, ) @@ -291,7 +294,9 @@ def _resolve_normal_events(events, auth_events): try: # The signatures have already been checked at this point event_auth.check( - RoomVersions.V1, event, auth_events, + RoomVersions.V1.identifier, + event, + auth_events, do_sig_check=False, do_size_check=False, ) diff --git a/synapse/static/client/login/js/login.js b/synapse/static/client/login/js/login.js index 3a958749a1..e02663f50e 100644 --- a/synapse/static/client/login/js/login.js +++ b/synapse/static/client/login/js/login.js @@ -49,7 +49,7 @@ var show_login = function() { $("#loading").hide(); var this_page = window.location.origin + window.location.pathname; - $("#sso_redirect_url").val(encodeURIComponent(this_page)); + $("#sso_redirect_url").val(this_page); if (matrixLogin.serverAcceptsPassword) { $("#password_flow").show(); diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 42cd3c83ad..c432041b4e 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -18,6 +18,8 @@ import calendar import logging import time +from twisted.internet import defer + from synapse.api.constants import PresenceState from synapse.storage.devices import DeviceStore from synapse.storage.user_erasure_store import UserErasureStore @@ -61,48 +63,60 @@ from .util.id_generators import ChainedIdGenerator, IdGenerator, StreamIdGenerat logger = logging.getLogger(__name__) -class DataStore(RoomMemberStore, RoomStore, - RegistrationStore, StreamStore, ProfileStore, - PresenceStore, TransactionStore, - DirectoryStore, KeyStore, StateStore, SignatureStore, - ApplicationServiceStore, - EventsStore, - EventFederationStore, - MediaRepositoryStore, - RejectionsStore, - FilteringStore, - PusherStore, - PushRuleStore, - ApplicationServiceTransactionStore, - ReceiptsStore, - EndToEndKeyStore, - EndToEndRoomKeyStore, - SearchStore, - TagsStore, - AccountDataStore, - EventPushActionsStore, - OpenIdStore, - ClientIpStore, - DeviceStore, - DeviceInboxStore, - UserDirectoryStore, - GroupServerStore, - UserErasureStore, - MonthlyActiveUsersStore, - ): - +class DataStore( + RoomMemberStore, + RoomStore, + RegistrationStore, + StreamStore, + ProfileStore, + PresenceStore, + TransactionStore, + DirectoryStore, + KeyStore, + StateStore, + SignatureStore, + ApplicationServiceStore, + EventsStore, + EventFederationStore, + MediaRepositoryStore, + RejectionsStore, + FilteringStore, + PusherStore, + PushRuleStore, + ApplicationServiceTransactionStore, + ReceiptsStore, + EndToEndKeyStore, + EndToEndRoomKeyStore, + SearchStore, + TagsStore, + AccountDataStore, + EventPushActionsStore, + OpenIdStore, + ClientIpStore, + DeviceStore, + DeviceInboxStore, + UserDirectoryStore, + GroupServerStore, + UserErasureStore, + MonthlyActiveUsersStore, +): def __init__(self, db_conn, hs): self.hs = hs self._clock = hs.get_clock() self.database_engine = hs.database_engine self._stream_id_gen = StreamIdGenerator( - db_conn, "events", "stream_ordering", - extra_tables=[("local_invites", "stream_id")] + db_conn, + "events", + "stream_ordering", + extra_tables=[("local_invites", "stream_id")], ) self._backfill_id_gen = StreamIdGenerator( - db_conn, "events", "stream_ordering", step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")] + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], ) self._presence_id_gen = StreamIdGenerator( db_conn, "presence_stream", "stream_id" @@ -114,7 +128,7 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "public_room_list_stream", "stream_id" ) self._device_list_id_gen = StreamIdGenerator( - db_conn, "device_lists_stream", "stream_id", + db_conn, "device_lists_stream", "stream_id" ) self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") @@ -125,16 +139,15 @@ class DataStore(RoomMemberStore, RoomStore, self._stream_id_gen, db_conn, "push_rules_stream", "stream_id" ) self._pushers_id_gen = StreamIdGenerator( - db_conn, "pushers", "id", - extra_tables=[("deleted_pushers", "stream_id")], + db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] ) self._group_updates_id_gen = StreamIdGenerator( - db_conn, "local_group_updates", "stream_id", + db_conn, "local_group_updates", "stream_id" ) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = StreamIdGenerator( - db_conn, "cache_invalidation_stream", "stream_id", + db_conn, "cache_invalidation_stream", "stream_id" ) else: self._cache_id_gen = None @@ -142,72 +155,82 @@ class DataStore(RoomMemberStore, RoomStore, self._presence_on_startup = self._get_active_presence(db_conn) presence_cache_prefill, min_presence_val = self._get_cache_dict( - db_conn, "presence_stream", + db_conn, + "presence_stream", entity_column="user_id", stream_column="stream_id", max_value=self._presence_id_gen.get_current_token(), ) self.presence_stream_cache = StreamChangeCache( - "PresenceStreamChangeCache", min_presence_val, - prefilled_cache=presence_cache_prefill + "PresenceStreamChangeCache", + min_presence_val, + prefilled_cache=presence_cache_prefill, ) max_device_inbox_id = self._device_inbox_id_gen.get_current_token() device_inbox_prefill, min_device_inbox_id = self._get_cache_dict( - db_conn, "device_inbox", + db_conn, + "device_inbox", entity_column="user_id", stream_column="stream_id", max_value=max_device_inbox_id, limit=1000, ) self._device_inbox_stream_cache = StreamChangeCache( - "DeviceInboxStreamChangeCache", min_device_inbox_id, + "DeviceInboxStreamChangeCache", + min_device_inbox_id, prefilled_cache=device_inbox_prefill, ) # The federation outbox and the local device inbox uses the same # stream_id generator. device_outbox_prefill, min_device_outbox_id = self._get_cache_dict( - db_conn, "device_federation_outbox", + db_conn, + "device_federation_outbox", entity_column="destination", stream_column="stream_id", max_value=max_device_inbox_id, limit=1000, ) self._device_federation_outbox_stream_cache = StreamChangeCache( - "DeviceFederationOutboxStreamChangeCache", min_device_outbox_id, + "DeviceFederationOutboxStreamChangeCache", + min_device_outbox_id, prefilled_cache=device_outbox_prefill, ) device_list_max = self._device_list_id_gen.get_current_token() self._device_list_stream_cache = StreamChangeCache( - "DeviceListStreamChangeCache", device_list_max, + "DeviceListStreamChangeCache", device_list_max ) self._device_list_federation_stream_cache = StreamChangeCache( - "DeviceListFederationStreamChangeCache", device_list_max, + "DeviceListFederationStreamChangeCache", device_list_max ) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( - db_conn, "current_state_delta_stream", + db_conn, + "current_state_delta_stream", entity_column="room_id", stream_column="stream_id", max_value=events_max, # As we share the stream id with events token limit=1000, ) self._curr_state_delta_stream_cache = StreamChangeCache( - "_curr_state_delta_stream_cache", min_curr_state_delta_id, + "_curr_state_delta_stream_cache", + min_curr_state_delta_id, prefilled_cache=curr_state_delta_prefill, ) _group_updates_prefill, min_group_updates_id = self._get_cache_dict( - db_conn, "local_group_updates", + db_conn, + "local_group_updates", entity_column="user_id", stream_column="stream_id", max_value=self._group_updates_id_gen.get_current_token(), limit=1000, ) self._group_updates_stream_cache = StreamChangeCache( - "_group_updates_stream_cache", min_group_updates_id, + "_group_updates_stream_cache", + min_group_updates_id, prefilled_cache=_group_updates_prefill, ) @@ -250,6 +273,7 @@ class DataStore(RoomMemberStore, RoomStore, """ Counts the number of users who used this homeserver in the last 24 hours. """ + def _count_users(txn): yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) @@ -277,6 +301,7 @@ class DataStore(RoomMemberStore, RoomStore, Returns counts globaly for a given user as well as breaking by platform """ + def _count_r30_users(txn): thirty_days_in_secs = 86400 * 30 now = int(self._clock.time()) @@ -313,8 +338,7 @@ class DataStore(RoomMemberStore, RoomStore, """ results = {} - txn.execute(sql, (thirty_days_ago_in_secs, - thirty_days_ago_in_secs)) + txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) for row in txn: if row[0] == 'unknown': @@ -341,8 +365,7 @@ class DataStore(RoomMemberStore, RoomStore, ) u """ - txn.execute(sql, (thirty_days_ago_in_secs, - thirty_days_ago_in_secs)) + txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) count, = txn.fetchone() results['all'] = count @@ -356,15 +379,14 @@ class DataStore(RoomMemberStore, RoomStore, Returns millisecond unixtime for start of UTC day. """ now = time.gmtime() - today_start = calendar.timegm(( - now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0, - )) + today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) return today_start * 1000 def generate_user_daily_visits(self): """ Generates daily visit data for use in cohort/ retention analysis """ + def _generate_user_daily_visits(txn): logger.info("Calling _generate_user_daily_visits") today_start = self._get_start_of_day() @@ -395,25 +417,29 @@ class DataStore(RoomMemberStore, RoomStore, # often to minimise this case. if today_start > self._last_user_visit_update: yesterday_start = today_start - a_day_in_milliseconds - txn.execute(sql, ( - yesterday_start, yesterday_start, - self._last_user_visit_update, today_start - )) + txn.execute( + sql, + ( + yesterday_start, + yesterday_start, + self._last_user_visit_update, + today_start, + ), + ) self._last_user_visit_update = today_start - txn.execute(sql, ( - today_start, today_start, - self._last_user_visit_update, - now - )) + txn.execute( + sql, (today_start, today_start, self._last_user_visit_update, now) + ) # Update _last_user_visit_update to now. The reason to do this # rather just clamping to the beginning of the day is to limit # the size of the join - meaning that the query can be run more # frequently self._last_user_visit_update = now - return self.runInteraction("generate_user_daily_visits", - _generate_user_daily_visits) + return self.runInteraction( + "generate_user_daily_visits", _generate_user_daily_visits + ) def get_users(self): """Function to reterive a list of users in users table. @@ -425,15 +451,11 @@ class DataStore(RoomMemberStore, RoomStore, return self._simple_select_list( table="users", keyvalues={}, - retcols=[ - "name", - "password_hash", - "is_guest", - "admin" - ], + retcols=["name", "password_hash", "is_guest", "admin"], desc="get_users", ) + @defer.inlineCallbacks def get_users_paginate(self, order, start, limit): """Function to reterive a paginated list of users from users list. This will return a json object, which contains @@ -446,27 +468,19 @@ class DataStore(RoomMemberStore, RoomStore, Returns: defer.Deferred: resolves to json object {list[dict[str, Any]], count} """ - is_guest = 0 - i_start = (int)(start) - i_limit = (int)(limit) - return self.get_user_list_paginate( + users = yield self.runInteraction( + "get_users_paginate", + self._simple_select_list_paginate_txn, table="users", - keyvalues={ - "is_guest": is_guest - }, - pagevalues=[ - order, - i_limit, - i_start - ], - retcols=[ - "name", - "password_hash", - "is_guest", - "admin" - ], - desc="get_users_paginate", + keyvalues={"is_guest": False}, + orderby=order, + start=start, + limit=limit, + retcols=["name", "password_hash", "is_guest", "admin"], ) + count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn) + retval = {"users": users, "total": count} + defer.returnValue(retval) def search_users(self, term): """Function to search users list for one or more users with @@ -482,12 +496,7 @@ class DataStore(RoomMemberStore, RoomStore, table="users", term=term, col="name", - retcols=[ - "name", - "password_hash", - "is_guest", - "admin" - ], + retcols=["name", "password_hash", "is_guest", "admin"], desc="search_users", ) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 7e3903859b..983ce026e1 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -41,7 +41,7 @@ try: MAX_TXN_ID = sys.maxint - 1 except AttributeError: # python 3 does not have a maximum int value - MAX_TXN_ID = 2**63 - 1 + MAX_TXN_ID = 2 ** 63 - 1 sql_logger = logging.getLogger("synapse.storage.SQL") transaction_logger = logging.getLogger("synapse.storage.txn") @@ -76,12 +76,18 @@ class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() method.""" + __slots__ = [ - "txn", "name", "database_engine", "after_callbacks", "exception_callbacks", + "txn", + "name", + "database_engine", + "after_callbacks", + "exception_callbacks", ] - def __init__(self, txn, name, database_engine, after_callbacks, - exception_callbacks): + def __init__( + self, txn, name, database_engine, after_callbacks, exception_callbacks + ): object.__setattr__(self, "txn", txn) object.__setattr__(self, "name", name) object.__setattr__(self, "database_engine", database_engine) @@ -110,6 +116,7 @@ class LoggingTransaction(object): def execute_batch(self, sql, args): if isinstance(self.database_engine, PostgresEngine): from psycopg2.extras import execute_batch + self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) else: for val in args: @@ -134,10 +141,7 @@ class LoggingTransaction(object): sql = self.database_engine.convert_param_style(sql) if args: try: - sql_logger.debug( - "[SQL values] {%s} %r", - self.name, args[0] - ) + sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) except Exception: # Don't let logging failures stop SQL from working pass @@ -145,9 +149,7 @@ class LoggingTransaction(object): start = time.time() try: - return func( - sql, *args - ) + return func(sql, *args) except Exception as e: logger.debug("[SQL FAIL] {%s} %s", self.name, e) raise @@ -176,11 +178,9 @@ class PerformanceCounters(object): counters = [] for name, (count, cum_time) in iteritems(self.current_counters): prev_count, prev_time = self.previous_counters.get(name, (0, 0)) - counters.append(( - (cum_time - prev_time) / interval_duration, - count - prev_count, - name - )) + counters.append( + ((cum_time - prev_time) / interval_duration, count - prev_count, name) + ) self.previous_counters = dict(self.current_counters) @@ -212,8 +212,9 @@ class SQLBaseStore(object): self._txn_perf_counters = PerformanceCounters() self._get_event_counters = PerformanceCounters() - self._get_event_cache = Cache("*getEvent*", keylen=3, - max_entries=hs.config.event_cache_size) + self._get_event_cache = Cache( + "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size + ) self._event_fetch_lock = threading.Condition() self._event_fetch_list = [] @@ -239,7 +240,7 @@ class SQLBaseStore(object): 0.0, run_as_background_process, "upsert_safety_check", - self._check_safe_to_upsert + self._check_safe_to_upsert, ) @defer.inlineCallbacks @@ -271,7 +272,7 @@ class SQLBaseStore(object): 15.0, run_as_background_process, "upsert_safety_check", - self._check_safe_to_upsert + self._check_safe_to_upsert, ) def start_profiling(self): @@ -298,13 +299,16 @@ class SQLBaseStore(object): perf_logger.info( "Total database time: %.3f%% {%s} {%s}", - ratio * 100, top_three_counters, top_3_event_counters + ratio * 100, + top_three_counters, + top_3_event_counters, ) self._clock.looping_call(loop, 10000) - def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks, - func, *args, **kwargs): + def _new_transaction( + self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs + ): start = time.time() txn_id = self._TXN_ID @@ -312,7 +316,7 @@ class SQLBaseStore(object): # growing really large. self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID) - name = "%s-%x" % (desc, txn_id, ) + name = "%s-%x" % (desc, txn_id) transaction_logger.debug("[TXN START] {%s}", name) @@ -323,7 +327,10 @@ class SQLBaseStore(object): try: txn = conn.cursor() txn = LoggingTransaction( - txn, name, self.database_engine, after_callbacks, + txn, + name, + self.database_engine, + after_callbacks, exception_callbacks, ) r = func(txn, *args, **kwargs) @@ -334,7 +341,10 @@ class SQLBaseStore(object): # transaction. logger.warning( "[TXN OPERROR] {%s} %s %d/%d", - name, exception_to_unicode(e), i, N + name, + exception_to_unicode(e), + i, + N, ) if i < N: i += 1 @@ -342,8 +352,7 @@ class SQLBaseStore(object): conn.rollback() except self.database_engine.module.Error as e1: logger.warning( - "[TXN EROLL] {%s} %s", - name, exception_to_unicode(e1), + "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1) ) continue raise @@ -357,7 +366,8 @@ class SQLBaseStore(object): except self.database_engine.module.Error as e1: logger.warning( "[TXN EROLL] {%s} %s", - name, exception_to_unicode(e1), + name, + exception_to_unicode(e1), ) continue raise @@ -396,16 +406,17 @@ class SQLBaseStore(object): exception_callbacks = [] if LoggingContext.current_context() == LoggingContext.sentinel: - logger.warn( - "Starting db txn '%s' from sentinel context", - desc, - ) + logger.warn("Starting db txn '%s' from sentinel context", desc) try: result = yield self.runWithConnection( self._new_transaction, - desc, after_callbacks, exception_callbacks, func, - *args, **kwargs + desc, + after_callbacks, + exception_callbacks, + func, + *args, + **kwargs ) for after_callback, after_args, after_kwargs in after_callbacks: @@ -434,7 +445,7 @@ class SQLBaseStore(object): parent_context = LoggingContext.current_context() if parent_context == LoggingContext.sentinel: logger.warn( - "Starting db connection from sentinel context: metrics will be lost", + "Starting db connection from sentinel context: metrics will be lost" ) parent_context = None @@ -453,9 +464,7 @@ class SQLBaseStore(object): return func(conn, *args, **kwargs) with PreserveLoggingContext(): - result = yield self._db_pool.runWithConnection( - inner_func, *args, **kwargs - ) + result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs) defer.returnValue(result) @@ -469,9 +478,7 @@ class SQLBaseStore(object): A list of dicts where the key is the column header. """ col_headers = list(intern(str(column[0])) for column in cursor.description) - results = list( - dict(zip(col_headers, row)) for row in cursor - ) + results = list(dict(zip(col_headers, row)) for row in cursor) return results def _execute(self, desc, decoder, query, *args): @@ -485,6 +492,7 @@ class SQLBaseStore(object): Returns: The result of decoder(results) """ + def interaction(txn): txn.execute(query, args) if decoder: @@ -498,8 +506,7 @@ class SQLBaseStore(object): # no complex WHERE clauses, just a dict of values for columns. @defer.inlineCallbacks - def _simple_insert(self, table, values, or_ignore=False, - desc="_simple_insert"): + def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"): """Executes an INSERT query on the named table. Args: @@ -511,10 +518,7 @@ class SQLBaseStore(object): `or_ignore` is True """ try: - yield self.runInteraction( - desc, - self._simple_insert_txn, table, values, - ) + yield self.runInteraction(desc, self._simple_insert_txn, table, values) except self.database_engine.module.IntegrityError: # We have to do or_ignore flag at this layer, since we can't reuse # a cursor after we receive an error from the db. @@ -530,15 +534,13 @@ class SQLBaseStore(object): sql = "INSERT INTO %s (%s) VALUES(%s)" % ( table, ", ".join(k for k in keys), - ", ".join("?" for _ in keys) + ", ".join("?" for _ in keys), ) txn.execute(sql, vals) def _simple_insert_many(self, table, values, desc): - return self.runInteraction( - desc, self._simple_insert_many_txn, table, values - ) + return self.runInteraction(desc, self._simple_insert_many_txn, table, values) @staticmethod def _simple_insert_many_txn(txn, table, values): @@ -553,24 +555,18 @@ class SQLBaseStore(object): # # The sort is to ensure that we don't rely on dictionary iteration # order. - keys, vals = zip(*[ - zip( - *(sorted(i.items(), key=lambda kv: kv[0])) - ) - for i in values - if i - ]) + keys, vals = zip( + *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i] + ) for k in keys: if k != keys[0]: - raise RuntimeError( - "All items must have the same keys" - ) + raise RuntimeError("All items must have the same keys") sql = "INSERT INTO %s (%s) VALUES(%s)" % ( table, ", ".join(k for k in keys[0]), - ", ".join("?" for _ in keys[0]) + ", ".join("?" for _ in keys[0]), ) txn.executemany(sql, vals) @@ -583,7 +579,7 @@ class SQLBaseStore(object): values, insertion_values={}, desc="_simple_upsert", - lock=True + lock=True, ): """ @@ -599,7 +595,7 @@ class SQLBaseStore(object): Args: table (str): The table to upsert into - keyvalues (dict): The unique key tables and their new values + keyvalues (dict): The unique key columns and their new values values (dict): The nonunique columns and their new values insertion_values (dict): additional key/values to use only when inserting @@ -631,17 +627,11 @@ class SQLBaseStore(object): # presumably we raced with another transaction: let's retry. logger.warn( - "%s when upserting into %s; retrying: %s", e.__name__, table, e + "IntegrityError when upserting into %s; retrying: %s", table, e ) def _simple_upsert_txn( - self, - txn, - table, - keyvalues, - values, - insertion_values={}, - lock=True, + self, txn, table, keyvalues, values, insertion_values={}, lock=True ): """ Pick the UPSERT method which works best on the platform. Either the @@ -665,11 +655,7 @@ class SQLBaseStore(object): and table not in self._unsafe_to_upsert_tables ): return self._simple_upsert_txn_native_upsert( - txn, - table, - keyvalues, - values, - insertion_values=insertion_values, + txn, table, keyvalues, values, insertion_values=insertion_values ) else: return self._simple_upsert_txn_emulated( @@ -714,7 +700,7 @@ class SQLBaseStore(object): # SELECT instead to see if it exists. sql = "SELECT 1 FROM %s WHERE %s" % ( table, - " AND ".join(_getwhere(k) for k in keyvalues) + " AND ".join(_getwhere(k) for k in keyvalues), ) sqlargs = list(keyvalues.values()) txn.execute(sql, sqlargs) @@ -726,7 +712,7 @@ class SQLBaseStore(object): sql = "UPDATE %s SET %s WHERE %s" % ( table, ", ".join("%s = ?" % (k,) for k in values), - " AND ".join(_getwhere(k) for k in keyvalues) + " AND ".join(_getwhere(k) for k in keyvalues), ) sqlargs = list(values.values()) + list(keyvalues.values()) @@ -773,19 +759,14 @@ class SQLBaseStore(object): latter = "NOTHING" else: allvalues.update(values) - latter = ( - "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values) - ) + latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values) - sql = ( - "INSERT INTO %s (%s) VALUES (%s) " - "ON CONFLICT (%s) DO %s" - ) % ( + sql = ("INSERT INTO %s (%s) VALUES (%s) " "ON CONFLICT (%s) DO %s") % ( table, ", ".join(k for k in allvalues), ", ".join("?" for _ in allvalues), ", ".join(k for k in keyvalues), - latter + latter, ) txn.execute(sql, list(allvalues.values())) @@ -870,8 +851,8 @@ class SQLBaseStore(object): latter = "NOTHING" value_values = [() for x in range(len(key_values))] else: - latter = ( - "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in value_names) + latter = "UPDATE SET " + ", ".join( + k + "=EXCLUDED." + k for k in value_names ) sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % ( @@ -889,8 +870,9 @@ class SQLBaseStore(object): return txn.execute_batch(sql, args) - def _simple_select_one(self, table, keyvalues, retcols, - allow_none=False, desc="_simple_select_one"): + def _simple_select_one( + self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one" + ): """Executes a SELECT query on the named table, which is expected to return a single row, returning multiple columns from it. @@ -903,14 +885,17 @@ class SQLBaseStore(object): statement returns no rows """ return self.runInteraction( - desc, - self._simple_select_one_txn, - table, keyvalues, retcols, allow_none, + desc, self._simple_select_one_txn, table, keyvalues, retcols, allow_none ) - def _simple_select_one_onecol(self, table, keyvalues, retcol, - allow_none=False, - desc="_simple_select_one_onecol"): + def _simple_select_one_onecol( + self, + table, + keyvalues, + retcol, + allow_none=False, + desc="_simple_select_one_onecol", + ): """Executes a SELECT query on the named table, which is expected to return a single row, returning a single column from it. @@ -922,17 +907,18 @@ class SQLBaseStore(object): return self.runInteraction( desc, self._simple_select_one_onecol_txn, - table, keyvalues, retcol, allow_none=allow_none, + table, + keyvalues, + retcol, + allow_none=allow_none, ) @classmethod - def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol, - allow_none=False): + def _simple_select_one_onecol_txn( + cls, txn, table, keyvalues, retcol, allow_none=False + ): ret = cls._simple_select_onecol_txn( - txn, - table=table, - keyvalues=keyvalues, - retcol=retcol, + txn, table=table, keyvalues=keyvalues, retcol=retcol ) if ret: @@ -945,12 +931,7 @@ class SQLBaseStore(object): @staticmethod def _simple_select_onecol_txn(txn, table, keyvalues, retcol): - sql = ( - "SELECT %(retcol)s FROM %(table)s" - ) % { - "retcol": retcol, - "table": table, - } + sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} if keyvalues: sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) @@ -960,8 +941,9 @@ class SQLBaseStore(object): return [r[0] for r in txn] - def _simple_select_onecol(self, table, keyvalues, retcol, - desc="_simple_select_onecol"): + def _simple_select_onecol( + self, table, keyvalues, retcol, desc="_simple_select_onecol" + ): """Executes a SELECT query on the named table, which returns a list comprising of the values of the named column from the selected rows. @@ -974,13 +956,12 @@ class SQLBaseStore(object): Deferred: Results in a list """ return self.runInteraction( - desc, - self._simple_select_onecol_txn, - table, keyvalues, retcol + desc, self._simple_select_onecol_txn, table, keyvalues, retcol ) - def _simple_select_list(self, table, keyvalues, retcols, - desc="_simple_select_list"): + def _simple_select_list( + self, table, keyvalues, retcols, desc="_simple_select_list" + ): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -994,9 +975,7 @@ class SQLBaseStore(object): defer.Deferred: resolves to list[dict[str, Any]] """ return self.runInteraction( - desc, - self._simple_select_list_txn, - table, keyvalues, retcols + desc, self._simple_select_list_txn, table, keyvalues, retcols ) @classmethod @@ -1016,22 +995,26 @@ class SQLBaseStore(object): sql = "SELECT %s FROM %s WHERE %s" % ( ", ".join(retcols), table, - " AND ".join("%s = ?" % (k, ) for k in keyvalues) + " AND ".join("%s = ?" % (k,) for k in keyvalues), ) txn.execute(sql, list(keyvalues.values())) else: - sql = "SELECT %s FROM %s" % ( - ", ".join(retcols), - table - ) + sql = "SELECT %s FROM %s" % (", ".join(retcols), table) txn.execute(sql) return cls.cursor_to_dict(txn) @defer.inlineCallbacks - def _simple_select_many_batch(self, table, column, iterable, retcols, - keyvalues={}, desc="_simple_select_many_batch", - batch_size=100): + def _simple_select_many_batch( + self, + table, + column, + iterable, + retcols, + keyvalues={}, + desc="_simple_select_many_batch", + batch_size=100, + ): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1053,14 +1036,17 @@ class SQLBaseStore(object): it_list = list(iterable) chunks = [ - it_list[i:i + batch_size] - for i in range(0, len(it_list), batch_size) + it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size) ] for chunk in chunks: rows = yield self.runInteraction( desc, self._simple_select_many_txn, - table, column, chunk, keyvalues, retcols + table, + column, + chunk, + keyvalues, + retcols, ) results.extend(rows) @@ -1089,9 +1075,7 @@ class SQLBaseStore(object): clauses = [] values = [] - clauses.append( - "%s IN (%s)" % (column, ",".join("?" for _ in iterable)) - ) + clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable))) values.extend(iterable) for key, value in iteritems(keyvalues): @@ -1099,19 +1083,14 @@ class SQLBaseStore(object): values.append(value) if clauses: - sql = "%s WHERE %s" % ( - sql, - " AND ".join(clauses), - ) + sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) txn.execute(sql, values) return cls.cursor_to_dict(txn) def _simple_update(self, table, keyvalues, updatevalues, desc): return self.runInteraction( - desc, - self._simple_update_txn, - table, keyvalues, updatevalues, + desc, self._simple_update_txn, table, keyvalues, updatevalues ) @staticmethod @@ -1127,15 +1106,13 @@ class SQLBaseStore(object): where, ) - txn.execute( - update_sql, - list(updatevalues.values()) + list(keyvalues.values()) - ) + txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values())) return txn.rowcount - def _simple_update_one(self, table, keyvalues, updatevalues, - desc="_simple_update_one"): + def _simple_update_one( + self, table, keyvalues, updatevalues, desc="_simple_update_one" + ): """Executes an UPDATE query on the named table, setting new values for columns in a row matching the key values. @@ -1154,9 +1131,7 @@ class SQLBaseStore(object): the update column in the 'keyvalues' dict as well. """ return self.runInteraction( - desc, - self._simple_update_one_txn, - table, keyvalues, updatevalues, + desc, self._simple_update_one_txn, table, keyvalues, updatevalues ) @classmethod @@ -1169,12 +1144,11 @@ class SQLBaseStore(object): raise StoreError(500, "More than one row matched (%s)" % (table,)) @staticmethod - def _simple_select_one_txn(txn, table, keyvalues, retcols, - allow_none=False): + def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): select_sql = "SELECT %s FROM %s WHERE %s" % ( ", ".join(retcols), table, - " AND ".join("%s = ?" % (k,) for k in keyvalues) + " AND ".join("%s = ?" % (k,) for k in keyvalues), ) txn.execute(select_sql, list(keyvalues.values())) @@ -1197,9 +1171,7 @@ class SQLBaseStore(object): table : string giving the table name keyvalues : dict of column names and values to select the row with """ - return self.runInteraction( - desc, self._simple_delete_one_txn, table, keyvalues - ) + return self.runInteraction(desc, self._simple_delete_one_txn, table, keyvalues) @staticmethod def _simple_delete_one_txn(txn, table, keyvalues): @@ -1212,7 +1184,7 @@ class SQLBaseStore(object): """ sql = "DELETE FROM %s WHERE %s" % ( table, - " AND ".join("%s = ?" % (k, ) for k in keyvalues) + " AND ".join("%s = ?" % (k,) for k in keyvalues), ) txn.execute(sql, list(keyvalues.values())) @@ -1222,15 +1194,13 @@ class SQLBaseStore(object): raise StoreError(500, "More than one row matched (%s)" % (table,)) def _simple_delete(self, table, keyvalues, desc): - return self.runInteraction( - desc, self._simple_delete_txn, table, keyvalues - ) + return self.runInteraction(desc, self._simple_delete_txn, table, keyvalues) @staticmethod def _simple_delete_txn(txn, table, keyvalues): sql = "DELETE FROM %s WHERE %s" % ( table, - " AND ".join("%s = ?" % (k, ) for k in keyvalues) + " AND ".join("%s = ?" % (k,) for k in keyvalues), ) return txn.execute(sql, list(keyvalues.values())) @@ -1260,9 +1230,7 @@ class SQLBaseStore(object): clauses = [] values = [] - clauses.append( - "%s IN (%s)" % (column, ",".join("?" for _ in iterable)) - ) + clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable))) values.extend(iterable) for key, value in iteritems(keyvalues): @@ -1270,14 +1238,12 @@ class SQLBaseStore(object): values.append(value) if clauses: - sql = "%s WHERE %s" % ( - sql, - " AND ".join(clauses), - ) + sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) return txn.execute(sql, values) - def _get_cache_dict(self, db_conn, table, entity_column, stream_column, - max_value, limit=100000): + def _get_cache_dict( + self, db_conn, table, entity_column, stream_column, max_value, limit=100000 + ): # Fetch a mapping of room_id -> max stream position for "recent" rooms. # It doesn't really matter how many we get, the StreamChangeCache will # do the right thing to ensure it respects the max size of cache. @@ -1297,10 +1263,7 @@ class SQLBaseStore(object): txn = db_conn.cursor() txn.execute(sql, (int(max_value),)) - cache = { - row[0]: int(row[1]) - for row in txn - } + cache = {row[0]: int(row[1]) for row in txn} txn.close() @@ -1342,9 +1305,7 @@ class SQLBaseStore(object): # be safe. for chunk in batch_iter(members_changed, 50): keys = itertools.chain([room_id], chunk) - self._send_invalidation_to_replication( - txn, _CURRENT_STATE_CACHE_NAME, keys, - ) + self._send_invalidation_to_replication(txn, _CURRENT_STATE_CACHE_NAME, keys) def _invalidate_state_caches(self, room_id, members_changed): """Invalidates caches that are based on the current state, but does @@ -1355,28 +1316,13 @@ class SQLBaseStore(object): members_changed (iterable[str]): The user_ids of members that have changed """ - for member in members_changed: - self._attempt_to_invalidate_cache( - "get_rooms_for_user_with_stream_ordering", (member,), - ) - for host in set(get_domain_from_id(u) for u in members_changed): - self._attempt_to_invalidate_cache( - "is_host_joined", (room_id, host,), - ) - self._attempt_to_invalidate_cache( - "was_host_joined", (room_id, host,), - ) + self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) + self._attempt_to_invalidate_cache("was_host_joined", (room_id, host)) - self._attempt_to_invalidate_cache( - "get_users_in_room", (room_id,), - ) - self._attempt_to_invalidate_cache( - "get_room_summary", (room_id,), - ) - self._attempt_to_invalidate_cache( - "get_current_state_ids", (room_id,), - ) + self._attempt_to_invalidate_cache("get_users_in_room", (room_id,)) + self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) + self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,)) def _attempt_to_invalidate_cache(self, cache_name, key): """Attempts to invalidate the cache of the given name, ignoring if the @@ -1424,7 +1370,7 @@ class SQLBaseStore(object): "cache_func": cache_name, "keys": list(keys), "invalidation_ts": self.clock.time_msec(), - } + }, ) def get_all_updated_caches(self, last_id, current_id, limit): @@ -1440,11 +1386,10 @@ class SQLBaseStore(object): " FROM cache_invalidation_stream" " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" ) - txn.execute(sql, (last_id, limit,)) + txn.execute(sql, (last_id, limit)) return txn.fetchall() - return self.runInteraction( - "get_all_updated_caches", get_all_updated_caches_txn - ) + + return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn) def get_cache_stream_token(self): if self._cache_id_gen: @@ -1452,33 +1397,61 @@ class SQLBaseStore(object): else: return 0 - def _simple_select_list_paginate(self, table, keyvalues, pagevalues, retcols, - desc="_simple_select_list_paginate"): - """Executes a SELECT query on the named table with start and limit, + def _simple_select_list_paginate( + self, + table, + keyvalues, + orderby, + start, + limit, + retcols, + order_direction="ASC", + desc="_simple_select_list_paginate", + ): + """ + Executes a SELECT query on the named table with start and limit, of row numbers, which may return zero or number of rows from start to limit, returning the result as a list of dicts. Args: table (str): the table name - keyvalues (dict[str, Any] | None): + keyvalues (dict[str, T] | None): column names and values to select the rows with, or None to not apply a WHERE clause. + orderby (str): Column to order the results by. + start (int): Index to begin the query at. + limit (int): Number of results to return. retcols (iterable[str]): the names of the columns to return - order (str): order the select by this column - start (int): start number to begin the query from - limit (int): number of rows to reterive + order_direction (str): Whether the results should be ordered "ASC" or "DESC". Returns: defer.Deferred: resolves to list[dict[str, Any]] """ return self.runInteraction( desc, self._simple_select_list_paginate_txn, - table, keyvalues, pagevalues, retcols + table, + keyvalues, + orderby, + start, + limit, + retcols, + order_direction=order_direction, ) @classmethod - def _simple_select_list_paginate_txn(cls, txn, table, keyvalues, pagevalues, retcols): - """Executes a SELECT query on the named table with start and limit, + def _simple_select_list_paginate_txn( + cls, + txn, + table, + keyvalues, + orderby, + start, + limit, + retcols, + order_direction="ASC", + ): + """ + Executes a SELECT query on the named table with start and limit, of row numbers, which may return zero or number of rows from start to limit, returning the result as a list of dicts. @@ -1488,66 +1461,32 @@ class SQLBaseStore(object): keyvalues (dict[str, T] | None): column names and values to select the rows with, or None to not apply a WHERE clause. - pagevalues ([]): - order (str): order the select by this column - start (int): start number to begin the query from - limit (int): number of rows to reterive + orderby (str): Column to order the results by. + start (int): Index to begin the query at. + limit (int): Number of results to return. retcols (iterable[str]): the names of the columns to return + order_direction (str): Whether the results should be ordered "ASC" or "DESC". Returns: defer.Deferred: resolves to list[dict[str, Any]] - """ + if order_direction not in ["ASC", "DESC"]: + raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") + if keyvalues: - sql = "SELECT %s FROM %s WHERE %s ORDER BY %s" % ( - ", ".join(retcols), - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - " ? ASC LIMIT ? OFFSET ?" - ) - txn.execute(sql, list(keyvalues.values()) + list(pagevalues)) + where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues) else: - sql = "SELECT %s FROM %s ORDER BY %s" % ( - ", ".join(retcols), - table, - " ? ASC LIMIT ? OFFSET ?" - ) - txn.execute(sql, pagevalues) - - return cls.cursor_to_dict(txn) + where_clause = "" - @defer.inlineCallbacks - def get_user_list_paginate(self, table, keyvalues, pagevalues, retcols, - desc="get_user_list_paginate"): - """Get a list of users from start row to a limit number of rows. This will - return a json object with users and total number of users in users list. - - Args: - table (str): the table name - keyvalues (dict[str, Any] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. - pagevalues ([]): - order (str): order the select by this column - start (int): start number to begin the query from - limit (int): number of rows to reterive - retcols (iterable[str]): the names of the columns to return - Returns: - defer.Deferred: resolves to json object {list[dict[str, Any]], count} - """ - users = yield self.runInteraction( - desc, - self._simple_select_list_paginate_txn, - table, keyvalues, pagevalues, retcols - ) - count = yield self.runInteraction( - desc, - self.get_user_count_txn + sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % ( + ", ".join(retcols), + table, + where_clause, + orderby, + order_direction, ) - retval = { - "users": users, - "total": count - } - defer.returnValue(retval) + txn.execute(sql, list(keyvalues.values()) + [limit, start]) + + return cls.cursor_to_dict(txn) def get_user_count_txn(self, txn): """Get a total number of registered users in the users list. @@ -1561,8 +1500,9 @@ class SQLBaseStore(object): txn.execute(sql_count) return txn.fetchone()[0] - def _simple_search_list(self, table, term, col, retcols, - desc="_simple_search_list"): + def _simple_search_list( + self, table, term, col, retcols, desc="_simple_search_list" + ): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1577,9 +1517,7 @@ class SQLBaseStore(object): """ return self.runInteraction( - desc, - self._simple_search_list_txn, - table, term, col, retcols + desc, self._simple_search_list_txn, table, term, col, retcols ) @classmethod @@ -1598,11 +1536,7 @@ class SQLBaseStore(object): defer.Deferred: resolves to list[dict[str, Any]] or None """ if term: - sql = "SELECT %s FROM %s WHERE %s LIKE ?" % ( - ", ".join(retcols), - table, - col - ) + sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col) termvalues = ["%%" + term + "%%"] txn.execute(sql, termvalues) else: @@ -1623,6 +1557,7 @@ class _RollbackButIsFineException(Exception): """ This exception is used to rollback a transaction without implying something went wrong. """ + pass diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index bbc3355c73..8394389073 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -41,7 +41,7 @@ class AccountDataWorkerStore(SQLBaseStore): def __init__(self, db_conn, hs): account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( - "AccountDataAndTagsChangeCache", account_max, + "AccountDataAndTagsChangeCache", account_max ) super(AccountDataWorkerStore, self).__init__(db_conn, hs) @@ -68,8 +68,10 @@ class AccountDataWorkerStore(SQLBaseStore): def get_account_data_for_user_txn(txn): rows = self._simple_select_list_txn( - txn, "account_data", {"user_id": user_id}, - ["account_data_type", "content"] + txn, + "account_data", + {"user_id": user_id}, + ["account_data_type", "content"], ) global_account_data = { @@ -77,8 +79,10 @@ class AccountDataWorkerStore(SQLBaseStore): } rows = self._simple_select_list_txn( - txn, "room_account_data", {"user_id": user_id}, - ["room_id", "account_data_type", "content"] + txn, + "room_account_data", + {"user_id": user_id}, + ["room_id", "account_data_type", "content"], ) by_room = {} @@ -100,10 +104,7 @@ class AccountDataWorkerStore(SQLBaseStore): """ result = yield self._simple_select_one_onecol( table="account_data", - keyvalues={ - "user_id": user_id, - "account_data_type": data_type, - }, + keyvalues={"user_id": user_id, "account_data_type": data_type}, retcol="content", desc="get_global_account_data_by_type_for_user", allow_none=True, @@ -124,10 +125,13 @@ class AccountDataWorkerStore(SQLBaseStore): Returns: A deferred dict of the room account_data """ + def get_account_data_for_room_txn(txn): rows = self._simple_select_list_txn( - txn, "room_account_data", {"user_id": user_id, "room_id": room_id}, - ["account_data_type", "content"] + txn, + "room_account_data", + {"user_id": user_id, "room_id": room_id}, + ["account_data_type", "content"], ) return { @@ -150,6 +154,7 @@ class AccountDataWorkerStore(SQLBaseStore): A deferred of the room account_data for that type, or None if there isn't any set. """ + def get_account_data_for_room_and_type_txn(txn): content_json = self._simple_select_one_onecol_txn( txn, @@ -160,18 +165,18 @@ class AccountDataWorkerStore(SQLBaseStore): "account_data_type": account_data_type, }, retcol="content", - allow_none=True + allow_none=True, ) return json.loads(content_json) if content_json else None return self.runInteraction( - "get_account_data_for_room_and_type", - get_account_data_for_room_and_type_txn, + "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn ) - def get_all_updated_account_data(self, last_global_id, last_room_id, - current_id, limit): + def get_all_updated_account_data( + self, last_global_id, last_room_id, current_id, limit + ): """Get all the client account_data that has changed on the server Args: last_global_id(int): The position to fetch from for top level data @@ -201,6 +206,7 @@ class AccountDataWorkerStore(SQLBaseStore): txn.execute(sql, (last_room_id, current_id, limit)) room_results = txn.fetchall() return (global_results, room_results) + return self.runInteraction( "get_all_updated_account_data_txn", get_updated_account_data_txn ) @@ -224,9 +230,7 @@ class AccountDataWorkerStore(SQLBaseStore): txn.execute(sql, (user_id, stream_id)) - global_account_data = { - row[0]: json.loads(row[1]) for row in txn - } + global_account_data = {row[0]: json.loads(row[1]) for row in txn} sql = ( "SELECT room_id, account_data_type, content FROM room_account_data" @@ -255,7 +259,8 @@ class AccountDataWorkerStore(SQLBaseStore): @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000) def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): ignored_account_data = yield self.get_global_account_data_by_type_for_user( - "m.ignored_user_list", ignorer_user_id, + "m.ignored_user_list", + ignorer_user_id, on_invalidate=cache_context.invalidate, ) if not ignored_account_data: @@ -307,10 +312,7 @@ class AccountDataStore(AccountDataWorkerStore): "room_id": room_id, "account_data_type": account_data_type, }, - values={ - "stream_id": next_id, - "content": content_json, - }, + values={"stream_id": next_id, "content": content_json}, lock=False, ) @@ -324,9 +326,9 @@ class AccountDataStore(AccountDataWorkerStore): self._account_data_stream_cache.entity_has_changed(user_id, next_id) self.get_account_data_for_user.invalidate((user_id,)) - self.get_account_data_for_room.invalidate((user_id, room_id,)) + self.get_account_data_for_room.invalidate((user_id, room_id)) self.get_account_data_for_room_and_type.prefill( - (user_id, room_id, account_data_type,), content, + (user_id, room_id, account_data_type), content ) result = self._account_data_id_gen.get_current_token() @@ -351,14 +353,8 @@ class AccountDataStore(AccountDataWorkerStore): yield self._simple_upsert( desc="add_user_account_data", table="account_data", - keyvalues={ - "user_id": user_id, - "account_data_type": account_data_type, - }, - values={ - "stream_id": next_id, - "content": content_json, - }, + keyvalues={"user_id": user_id, "account_data_type": account_data_type}, + values={"stream_id": next_id, "content": content_json}, lock=False, ) @@ -370,12 +366,10 @@ class AccountDataStore(AccountDataWorkerStore): # transaction. yield self._update_max_stream_id(next_id) - self._account_data_stream_cache.entity_has_changed( - user_id, next_id, - ) + self._account_data_stream_cache.entity_has_changed(user_id, next_id) self.get_account_data_for_user.invalidate((user_id,)) self.get_global_account_data_by_type_for_user.invalidate( - (account_data_type, user_id,) + (account_data_type, user_id) ) result = self._account_data_id_gen.get_current_token() @@ -387,6 +381,7 @@ class AccountDataStore(AccountDataWorkerStore): Args: next_id(int): The the revision to advance to. """ + def _update(txn): update_max_id_sql = ( "UPDATE account_data_max_stream_id" @@ -394,7 +389,5 @@ class AccountDataStore(AccountDataWorkerStore): " WHERE stream_id < ?" ) txn.execute(update_max_id_sql, (next_id, next_id)) - return self.runInteraction( - "update_account_data_max_stream_id", - _update, - ) + + return self.runInteraction("update_account_data_max_stream_id", _update) diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index 31248d5e06..6092f600ba 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -51,8 +51,7 @@ def _make_exclusive_regex(services_cache): class ApplicationServiceWorkerStore(SQLBaseStore): def __init__(self, db_conn, hs): self.services_cache = load_appservices( - hs.hostname, - hs.config.app_service_config_files + hs.hostname, hs.config.app_service_config_files ) self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) @@ -122,8 +121,9 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore): pass -class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore, - EventsWorkerStore): +class ApplicationServiceTransactionWorkerStore( + ApplicationServiceWorkerStore, EventsWorkerStore +): @defer.inlineCallbacks def get_appservices_by_state(self, state): """Get a list of application services based on their state. @@ -135,9 +135,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore, may be empty. """ results = yield self._simple_select_list( - "application_services_state", - dict(state=state), - ["as_id"] + "application_services_state", dict(state=state), ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore as_list = self.get_app_services() @@ -180,9 +178,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore, A Deferred which resolves when the state was set successfully. """ return self._simple_upsert( - "application_services_state", - dict(as_id=service.id), - dict(state=state) + "application_services_state", dict(as_id=service.id), dict(state=state) ) def create_appservice_txn(self, service, events): @@ -195,6 +191,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore, Returns: AppServiceTransaction: A new transaction. """ + def _create_appservice_txn(txn): # work out new txn id (highest txn id for this service += 1) # The highest id may be the last one sent (in which case it is last_txn) @@ -204,7 +201,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore, txn.execute( "SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?", - (service.id,) + (service.id,), ) highest_txn_id = txn.fetchone()[0] if highest_txn_id is None: @@ -217,16 +214,11 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore, txn.execute( "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " "VALUES(?,?,?)", - (service.id, new_txn_id, event_ids) - ) - return AppServiceTransaction( - service=service, id=new_txn_id, events=events + (service.id, new_txn_id, event_ids), ) + return AppServiceTransaction(service=service, id=new_txn_id, events=events) - return self.runInteraction( - "create_appservice_txn", - _create_appservice_txn, - ) + return self.runInteraction("create_appservice_txn", _create_appservice_txn) def complete_appservice_txn(self, txn_id, service): """Completes an application service transaction. @@ -252,26 +244,26 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore, "appservice: Completing a transaction which has an ID > 1 from " "the last ID sent to this AS. We've either dropped events or " "sent it to the AS out of order. FIX ME. last_txn=%s " - "completing_txn=%s service_id=%s", last_txn_id, txn_id, - service.id + "completing_txn=%s service_id=%s", + last_txn_id, + txn_id, + service.id, ) # Set current txn_id for AS to 'txn_id' self._simple_upsert_txn( - txn, "application_services_state", dict(as_id=service.id), - dict(last_txn=txn_id) + txn, + "application_services_state", + dict(as_id=service.id), + dict(last_txn=txn_id), ) # Delete txn self._simple_delete_txn( - txn, "application_services_txns", - dict(txn_id=txn_id, as_id=service.id) + txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id) ) - return self.runInteraction( - "complete_appservice_txn", - _complete_appservice_txn, - ) + return self.runInteraction("complete_appservice_txn", _complete_appservice_txn) @defer.inlineCallbacks def get_oldest_unsent_txn(self, service): @@ -284,13 +276,14 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore, A Deferred which resolves to an AppServiceTransaction or None. """ + def _get_oldest_unsent_txn(txn): # Monotonically increasing txn ids, so just select the smallest # one in the txns table (we delete them when they are sent) txn.execute( "SELECT * FROM application_services_txns WHERE as_id=?" " ORDER BY txn_id ASC LIMIT 1", - (service.id,) + (service.id,), ) rows = self.cursor_to_dict(txn) if not rows: @@ -301,8 +294,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore, return entry entry = yield self.runInteraction( - "get_oldest_unsent_appservice_txn", - _get_oldest_unsent_txn, + "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn ) if not entry: @@ -312,14 +304,14 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore, events = yield self._get_events(event_ids) - defer.returnValue(AppServiceTransaction( - service=service, id=entry["txn_id"], events=events - )) + defer.returnValue( + AppServiceTransaction(service=service, id=entry["txn_id"], events=events) + ) def _get_last_txn(self, txn, service_id): txn.execute( "SELECT last_txn FROM application_services_state WHERE as_id=?", - (service_id,) + (service_id,), ) last_txn_id = txn.fetchone() if last_txn_id is None or last_txn_id[0] is None: # no row exists @@ -332,6 +324,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore, txn.execute( "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) ) + return self.runInteraction( "set_appservice_last_pos", set_appservice_last_pos_txn ) @@ -362,7 +355,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore, return upper_bound, [row[1] for row in rows] upper_bound, event_ids = yield self.runInteraction( - "get_new_events_for_appservice", get_new_events_for_appservice_txn, + "get_new_events_for_appservice", get_new_events_for_appservice_txn ) events = yield self._get_events(event_ids) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index a2f8c23a65..b8b8273f73 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -94,16 +94,13 @@ class BackgroundUpdateStore(SQLBaseStore): self._all_done = False def start_doing_background_updates(self): - run_as_background_process( - "background_updates", self._run_background_updates, - ) + run_as_background_process("background_updates", self._run_background_updates) @defer.inlineCallbacks def _run_background_updates(self): logger.info("Starting background schema updates") while True: - yield self.hs.get_clock().sleep( - self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.) + yield self.hs.get_clock().sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0) try: result = yield self.do_next_background_update( @@ -187,8 +184,7 @@ class BackgroundUpdateStore(SQLBaseStore): @defer.inlineCallbacks def _do_background_update(self, update_name, desired_duration_ms): - logger.info("Starting update batch on background update '%s'", - update_name) + logger.info("Starting update batch on background update '%s'", update_name) update_handler = self._background_update_handlers[update_name] @@ -210,7 +206,7 @@ class BackgroundUpdateStore(SQLBaseStore): progress_json = yield self._simple_select_one_onecol( "background_updates", keyvalues={"update_name": update_name}, - retcol="progress_json" + retcol="progress_json", ) progress = json.loads(progress_json) @@ -224,7 +220,9 @@ class BackgroundUpdateStore(SQLBaseStore): logger.info( "Updating %r. Updated %r items in %rms." " (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)", - update_name, items_updated, duration_ms, + update_name, + items_updated, + duration_ms, performance.total_items_per_ms(), performance.average_items_per_ms(), performance.total_item_count, @@ -264,6 +262,7 @@ class BackgroundUpdateStore(SQLBaseStore): Args: update_name (str): Name of update """ + @defer.inlineCallbacks def noop_update(progress, batch_size): yield self._end_background_update(update_name) @@ -271,10 +270,16 @@ class BackgroundUpdateStore(SQLBaseStore): self.register_background_update_handler(update_name, noop_update) - def register_background_index_update(self, update_name, index_name, - table, columns, where_clause=None, - unique=False, - psql_only=False): + def register_background_index_update( + self, + update_name, + index_name, + table, + columns, + where_clause=None, + unique=False, + psql_only=False, + ): """Helper for store classes to do a background index addition To use: @@ -320,7 +325,7 @@ class BackgroundUpdateStore(SQLBaseStore): "name": index_name, "table": table, "columns": ", ".join(columns), - "where_clause": "WHERE " + where_clause if where_clause else "" + "where_clause": "WHERE " + where_clause if where_clause else "", } logger.debug("[SQL] %s", sql) c.execute(sql) @@ -387,7 +392,7 @@ class BackgroundUpdateStore(SQLBaseStore): return self._simple_insert( "background_updates", - {"update_name": update_name, "progress_json": progress_json} + {"update_name": update_name, "progress_json": progress_json}, ) def _end_background_update(self, update_name): diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 9c21362226..bda68de5be 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -37,9 +37,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): def __init__(self, db_conn, hs): self.client_ip_last_seen = Cache( - name="client_ip_last_seen", - keylen=4, - max_entries=50000 * CACHE_SIZE_FACTOR, + name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR ) super(ClientIpStore, self).__init__(db_conn, hs) @@ -66,13 +64,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): ) self.register_background_update_handler( - "user_ips_analyze", - self._analyze_user_ip, + "user_ips_analyze", self._analyze_user_ip ) self.register_background_update_handler( - "user_ips_remove_dupes", - self._remove_user_ip_dupes, + "user_ips_remove_dupes", self._remove_user_ip_dupes ) # Register a unique index @@ -86,8 +82,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): # Drop the old non-unique index self.register_background_update_handler( - "user_ips_drop_nonunique_index", - self._remove_user_ip_nonunique, + "user_ips_drop_nonunique_index", self._remove_user_ip_nonunique ) # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen) @@ -104,9 +99,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): def _remove_user_ip_nonunique(self, progress, batch_size): def f(conn): txn = conn.cursor() - txn.execute( - "DROP INDEX IF EXISTS user_ips_user_ip" - ) + txn.execute("DROP INDEX IF EXISTS user_ips_user_ip") txn.close() yield self.runWithConnection(f) @@ -124,9 +117,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): def user_ips_analyze(txn): txn.execute("ANALYZE user_ips") - yield self.runInteraction( - "user_ips_analyze", user_ips_analyze - ) + yield self.runInteraction("user_ips_analyze", user_ips_analyze) yield self._end_background_update("user_ips_analyze") @@ -151,7 +142,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): LIMIT 1 OFFSET ? """, - (begin_last_seen, batch_size) + (begin_last_seen, batch_size), ) row = txn.fetchone() if row: @@ -169,7 +160,8 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): logger.info( "Scanning for duplicate 'user_ips' rows in range: %s <= last_seen < %s", - begin_last_seen, end_last_seen, + begin_last_seen, + end_last_seen, ) def remove(txn): @@ -207,8 +199,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): INNER JOIN user_ips USING (user_id, access_token, ip) GROUP BY user_id, access_token, ip HAVING count(*) > 1 - """.format(clause), - args + """.format( + clause + ), + args, ) res = txn.fetchall() @@ -254,7 +248,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): DELETE FROM user_ips WHERE user_id = ? AND access_token = ? AND ip = ? AND last_seen < ? """, - (user_id, access_token, ip, last_seen) + (user_id, access_token, ip, last_seen), ) if txn.rowcount == count - 1: # We deleted all but one of the duplicate rows, i.e. there @@ -263,7 +257,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): continue elif txn.rowcount >= count: raise Exception( - "We deleted more duplicate rows from 'user_ips' than expected", + "We deleted more duplicate rows from 'user_ips' than expected" ) # The previous step didn't delete enough rows, so we fallback to @@ -275,7 +269,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): DELETE FROM user_ips WHERE user_id = ? AND access_token = ? AND ip = ? """, - (user_id, access_token, ip) + (user_id, access_token, ip), ) # Add in one to be the last_seen @@ -285,7 +279,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): (user_id, access_token, ip, device_id, user_agent, last_seen) VALUES (?, ?, ?, ?, ?, ?) """, - (user_id, access_token, ip, device_id, user_agent, last_seen) + (user_id, access_token, ip, device_id, user_agent, last_seen), ) self._background_update_progress_txn( @@ -300,8 +294,9 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): defer.returnValue(batch_size) @defer.inlineCallbacks - def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id, - now=None): + def insert_client_ip( + self, user_id, access_token, ip, user_agent, device_id, now=None + ): if not now: now = int(self._clock.time_msec()) key = (user_id, access_token, ip) @@ -329,13 +324,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): to_update = self._batch_row_update self._batch_row_update = {} return self.runInteraction( - "_update_client_ips_batch", self._update_client_ips_batch_txn, - to_update, + "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update ) - return run_as_background_process( - "update_client_ips", update, - ) + return run_as_background_process("update_client_ips", update) def _update_client_ips_batch_txn(self, txn, to_update): if "user_ips" in self._unsafe_to_upsert_tables or ( @@ -383,7 +375,8 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): res = yield self.runInteraction( "get_last_client_ip_by_device", self._get_last_client_ip_by_device_txn, - user_id, device_id, + user_id, + device_id, retcols=( "user_id", "access_token", @@ -416,7 +409,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): bindings = [] if device_id is None: where_clauses.append("user_id = ?") - bindings.extend((user_id, )) + bindings.extend((user_id,)) else: where_clauses.append("(user_id = ? AND device_id = ?)") bindings.extend((user_id, device_id)) @@ -428,9 +421,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): "SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips " "WHERE %(where)s " "GROUP BY user_id, device_id" - ) % { - "where": " OR ".join(where_clauses), - } + ) % {"where": " OR ".join(where_clauses)} sql = ( "SELECT %(retcols)s FROM user_ips " @@ -462,9 +453,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): rows = yield self._simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, - retcols=[ - "access_token", "ip", "user_agent", "last_seen" - ], + retcols=["access_token", "ip", "user_agent", "last_seen"], desc="get_user_ip_and_agents", ) @@ -472,12 +461,14 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"])) for row in rows ) - defer.returnValue(list( - { - "access_token": access_token, - "ip": ip, - "user_agent": user_agent, - "last_seen": last_seen, - } - for (access_token, ip), (user_agent, last_seen) in iteritems(results) - )) + defer.returnValue( + list( + { + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "last_seen": last_seen, + } + for (access_token, ip), (user_agent, last_seen) in iteritems(results) + ) + ) diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index e6a42a53bb..fed4ea3610 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -57,9 +57,9 @@ class DeviceInboxWorkerStore(SQLBaseStore): " ORDER BY stream_id ASC" " LIMIT ?" ) - txn.execute(sql, ( - user_id, device_id, last_stream_id, current_stream_id, limit - )) + txn.execute( + sql, (user_id, device_id, last_stream_id, current_stream_id, limit) + ) messages = [] for row in txn: stream_pos = row[0] @@ -69,7 +69,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): return (messages, stream_pos) return self.runInteraction( - "get_new_messages_for_device", get_new_messages_for_device_txn, + "get_new_messages_for_device", get_new_messages_for_device_txn ) @defer.inlineCallbacks @@ -146,9 +146,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): " ORDER BY stream_id ASC" " LIMIT ?" ) - txn.execute(sql, ( - destination, last_stream_id, current_stream_id, limit - )) + txn.execute(sql, (destination, last_stream_id, current_stream_id, limit)) messages = [] for row in txn: stream_pos = row[0] @@ -172,6 +170,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): Returns: A deferred that resolves when the messages have been deleted. """ + def delete_messages_for_remote_destination_txn(txn): sql = ( "DELETE FROM device_federation_outbox" @@ -181,8 +180,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): txn.execute(sql, (destination, up_to_stream_id)) return self.runInteraction( - "delete_device_msgs_for_remote", - delete_messages_for_remote_destination_txn + "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) @@ -200,8 +198,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): ) self.register_background_update_handler( - self.DEVICE_INBOX_STREAM_ID, - self._background_drop_index_device_inbox, + self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox ) # Map of (user_id, device_id) to the last stream_id that has been @@ -214,8 +211,9 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): ) @defer.inlineCallbacks - def add_messages_to_device_inbox(self, local_messages_by_user_then_device, - remote_messages_by_destination): + def add_messages_to_device_inbox( + self, local_messages_by_user_then_device, remote_messages_by_destination + ): """Used to send messages from this server. Args: @@ -252,15 +250,10 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() yield self.runInteraction( - "add_messages_to_device_inbox", - add_messages_txn, - now_ms, - stream_id, + "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id ) for user_id in local_messages_by_user_then_device.keys(): - self._device_inbox_stream_cache.entity_has_changed( - user_id, stream_id - ) + self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id) for destination in remote_messages_by_destination.keys(): self._device_federation_outbox_stream_cache.entity_has_changed( destination, stream_id @@ -277,7 +270,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): # origin. This can happen if the origin doesn't receive our # acknowledgement from the first time we received the message. already_inserted = self._simple_select_one_txn( - txn, table="device_federation_inbox", + txn, + table="device_federation_inbox", keyvalues={"origin": origin, "message_id": message_id}, retcols=("message_id",), allow_none=True, @@ -288,7 +282,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): # Add an entry for this message_id so that we know we've processed # it. self._simple_insert_txn( - txn, table="device_federation_inbox", + txn, + table="device_federation_inbox", values={ "origin": origin, "message_id": message_id, @@ -311,19 +306,14 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): stream_id, ) for user_id in local_messages_by_user_then_device.keys(): - self._device_inbox_stream_cache.entity_has_changed( - user_id, stream_id - ) + self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id) defer.returnValue(stream_id) - def _add_messages_to_local_device_inbox_txn(self, txn, stream_id, - messages_by_user_then_device): - sql = ( - "UPDATE device_max_stream_id" - " SET stream_id = ?" - " WHERE stream_id < ?" - ) + def _add_messages_to_local_device_inbox_txn( + self, txn, stream_id, messages_by_user_then_device + ): + sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?" txn.execute(sql, (stream_id, stream_id)) local_by_user_then_device = {} @@ -332,10 +322,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): devices = list(messages_by_device.keys()) if len(devices) == 1 and devices[0] == "*": # Handle wildcard device_ids. - sql = ( - "SELECT device_id FROM devices" - " WHERE user_id = ?" - ) + sql = "SELECT device_id FROM devices" " WHERE user_id = ?" txn.execute(sql, (user_id,)) message_json = json.dumps(messages_by_device["*"]) for row in txn: @@ -428,9 +415,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): def _background_drop_index_device_inbox(self, progress, batch_size): def reindex_txn(conn): txn = conn.cursor() - txn.execute( - "DROP INDEX IF EXISTS device_inbox_stream_id" - ) + txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") txn.close() yield self.runWithConnection(reindex_txn) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index e716dc1437..fd869b934c 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -67,7 +67,7 @@ class DeviceWorkerStore(SQLBaseStore): table="devices", keyvalues={"user_id": user_id}, retcols=("user_id", "device_id", "display_name"), - desc="get_devices_by_user" + desc="get_devices_by_user", ) defer.returnValue({d["device_id"]: d for d in devices}) @@ -87,21 +87,23 @@ class DeviceWorkerStore(SQLBaseStore): return (now_stream_id, []) return self.runInteraction( - "get_devices_by_remote", self._get_devices_by_remote_txn, - destination, from_stream_id, now_stream_id, + "get_devices_by_remote", + self._get_devices_by_remote_txn, + destination, + from_stream_id, + now_stream_id, ) - def _get_devices_by_remote_txn(self, txn, destination, from_stream_id, - now_stream_id): + def _get_devices_by_remote_txn( + self, txn, destination, from_stream_id, now_stream_id + ): sql = """ SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? GROUP BY user_id, device_id LIMIT 20 """ - txn.execute( - sql, (destination, from_stream_id, now_stream_id, False) - ) + txn.execute(sql, (destination, from_stream_id, now_stream_id, False)) # maps (user_id, device_id) -> stream_id query_map = {(r[0], r[1]): r[2] for r in txn} @@ -112,7 +114,10 @@ class DeviceWorkerStore(SQLBaseStore): now_stream_id = max(stream_id for stream_id in itervalues(query_map)) devices = self._get_e2e_device_keys_txn( - txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True + txn, + query_map.keys(), + include_all_devices=True, + include_deleted_devices=True, ) prev_sent_id_sql = """ @@ -157,8 +162,10 @@ class DeviceWorkerStore(SQLBaseStore): """Mark that updates have successfully been sent to the destination. """ return self.runInteraction( - "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, - destination, stream_id, + "mark_as_sent_devices_by_remote", + self._mark_as_sent_devices_by_remote_txn, + destination, + stream_id, ) def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): @@ -173,7 +180,7 @@ class DeviceWorkerStore(SQLBaseStore): WHERE destination = ? AND o.stream_id <= ? GROUP BY user_id """ - txn.execute(sql, (destination, stream_id,)) + txn.execute(sql, (destination, stream_id)) rows = txn.fetchall() sql = """ @@ -181,16 +188,14 @@ class DeviceWorkerStore(SQLBaseStore): SET stream_id = ? WHERE destination = ? AND user_id = ? """ - txn.executemany( - sql, ((row[1], destination, row[0],) for row in rows if row[2]) - ) + txn.executemany(sql, ((row[1], destination, row[0]) for row in rows if row[2])) sql = """ INSERT INTO device_lists_outbound_last_success (destination, user_id, stream_id) VALUES (?, ?, ?) """ txn.executemany( - sql, ((destination, row[0], row[1],) for row in rows if not row[2]) + sql, ((destination, row[0], row[1]) for row in rows if not row[2]) ) # Delete all sent outbound pokes @@ -198,7 +203,7 @@ class DeviceWorkerStore(SQLBaseStore): DELETE FROM device_lists_outbound_pokes WHERE destination = ? AND stream_id <= ? """ - txn.execute(sql, (destination, stream_id,)) + txn.execute(sql, (destination, stream_id)) def get_device_stream_token(self): return self._device_list_id_gen.get_current_token() @@ -240,10 +245,7 @@ class DeviceWorkerStore(SQLBaseStore): def _get_cached_user_device(self, user_id, device_id): content = yield self._simple_select_one_onecol( table="device_lists_remote_cache", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, + keyvalues={"user_id": user_id, "device_id": device_id}, retcol="content", desc="_get_cached_user_device", ) @@ -253,16 +255,13 @@ class DeviceWorkerStore(SQLBaseStore): def _get_cached_devices_for_user(self, user_id): devices = yield self._simple_select_list( table="device_lists_remote_cache", - keyvalues={ - "user_id": user_id, - }, + keyvalues={"user_id": user_id}, retcols=("device_id", "content"), desc="_get_cached_devices_for_user", ) - defer.returnValue({ - device["device_id"]: db_to_json(device["content"]) - for device in devices - }) + defer.returnValue( + {device["device_id"]: db_to_json(device["content"]) for device in devices} + ) def get_devices_with_keys_by_user(self, user_id): """Get all devices (with any device keys) for a user @@ -272,7 +271,8 @@ class DeviceWorkerStore(SQLBaseStore): """ return self.runInteraction( "get_devices_with_keys_by_user", - self._get_devices_with_keys_by_user_txn, user_id, + self._get_devices_with_keys_by_user_txn, + user_id, ) def _get_devices_with_keys_by_user_txn(self, txn, user_id): @@ -286,9 +286,7 @@ class DeviceWorkerStore(SQLBaseStore): user_devices = devices[user_id] results = [] for device_id, device in iteritems(user_devices): - result = { - "device_id": device_id, - } + result = {"device_id": device_id} key_json = device.get("key_json", None) if key_json: @@ -315,7 +313,9 @@ class DeviceWorkerStore(SQLBaseStore): sql = """ SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ? """ - rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) + rows = yield self._execute( + "get_user_whose_devices_changed", None, sql, from_key + ) defer.returnValue(set(row[0] for row in rows)) def get_all_device_list_changes_for_remotes(self, from_key, to_key): @@ -333,8 +333,7 @@ class DeviceWorkerStore(SQLBaseStore): GROUP BY user_id, destination """ return self._execute( - "get_all_device_list_changes_for_remotes", None, - sql, from_key, to_key + "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key ) @cached(max_entries=10000) @@ -350,21 +349,22 @@ class DeviceWorkerStore(SQLBaseStore): allow_none=True, ) - @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote", - list_name="user_ids", inlineCallbacks=True) + @cachedList( + cached_method_name="get_device_list_last_stream_id_for_remote", + list_name="user_ids", + inlineCallbacks=True, + ) def get_device_list_last_stream_id_for_remotes(self, user_ids): rows = yield self._simple_select_many_batch( table="device_lists_remote_extremeties", column="user_id", iterable=user_ids, - retcols=("user_id", "stream_id",), + retcols=("user_id", "stream_id"), desc="get_device_list_last_stream_id_for_remotes", ) results = {user_id: None for user_id in user_ids} - results.update({ - row["user_id"]: row["stream_id"] for row in rows - }) + results.update({row["user_id"]: row["stream_id"] for row in rows}) defer.returnValue(results) @@ -376,14 +376,10 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. self.device_id_exists_cache = Cache( - name="device_id_exists", - keylen=2, - max_entries=10000, + name="device_id_exists", keylen=2, max_entries=10000 ) - self._clock.looping_call( - self._prune_old_outbound_device_pokes, 60 * 60 * 1000 - ) + self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000) self.register_background_index_update( "device_lists_stream_idx", @@ -417,8 +413,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): ) @defer.inlineCallbacks - def store_device(self, user_id, device_id, - initial_device_display_name): + def store_device(self, user_id, device_id, initial_device_display_name): """Ensure the given device is known; add it to the store if not Args: @@ -440,7 +435,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): values={ "user_id": user_id, "device_id": device_id, - "display_name": initial_device_display_name + "display_name": initial_device_display_name, }, desc="store_device", or_ignore=True, @@ -448,12 +443,17 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): self.device_id_exists_cache.prefill(key, True) defer.returnValue(inserted) except Exception as e: - logger.error("store_device with device_id=%s(%r) user_id=%s(%r)" - " display_name=%s(%r) failed: %s", - type(device_id).__name__, device_id, - type(user_id).__name__, user_id, - type(initial_device_display_name).__name__, - initial_device_display_name, e) + logger.error( + "store_device with device_id=%s(%r) user_id=%s(%r)" + " display_name=%s(%r) failed: %s", + type(device_id).__name__, + device_id, + type(user_id).__name__, + user_id, + type(initial_device_display_name).__name__, + initial_device_display_name, + e, + ) raise StoreError(500, "Problem storing device.") @defer.inlineCallbacks @@ -525,15 +525,14 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): """ yield self._simple_delete( table="device_lists_remote_extremeties", - keyvalues={ - "user_id": user_id, - }, + keyvalues={"user_id": user_id}, desc="mark_remote_user_device_list_as_unsubscribed", ) self.get_device_list_last_stream_id_for_remote.invalidate((user_id,)) - def update_remote_device_list_cache_entry(self, user_id, device_id, content, - stream_id): + def update_remote_device_list_cache_entry( + self, user_id, device_id, content, stream_id + ): """Updates a single device in the cache of a remote user's devicelist. Note: assumes that we are the only thread that can be updating this user's @@ -551,42 +550,35 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): return self.runInteraction( "update_remote_device_list_cache_entry", self._update_remote_device_list_cache_entry_txn, - user_id, device_id, content, stream_id, + user_id, + device_id, + content, + stream_id, ) - def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id, - content, stream_id): + def _update_remote_device_list_cache_entry_txn( + self, txn, user_id, device_id, content, stream_id + ): if content.get("deleted"): self._simple_delete_txn( txn, table="device_lists_remote_cache", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, + keyvalues={"user_id": user_id, "device_id": device_id}, ) - txn.call_after( - self.device_id_exists_cache.invalidate, (user_id, device_id,) - ) + txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id)) else: self._simple_upsert_txn( txn, table="device_lists_remote_cache", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, - values={ - "content": json.dumps(content), - }, - + keyvalues={"user_id": user_id, "device_id": device_id}, + values={"content": json.dumps(content)}, # we don't need to lock, because we assume we are the only thread # updating this user's devices. lock=False, ) - txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,)) + txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id)) txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) txn.call_after( self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) @@ -595,13 +587,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): self._simple_upsert_txn( txn, table="device_lists_remote_extremeties", - keyvalues={ - "user_id": user_id, - }, - values={ - "stream_id": stream_id, - }, - + keyvalues={"user_id": user_id}, + values={"stream_id": stream_id}, # again, we can assume we are the only thread updating this user's # extremity. lock=False, @@ -624,17 +611,14 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): return self.runInteraction( "update_remote_device_list_cache", self._update_remote_device_list_cache_txn, - user_id, devices, stream_id, + user_id, + devices, + stream_id, ) - def _update_remote_device_list_cache_txn(self, txn, user_id, devices, - stream_id): + def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id): self._simple_delete_txn( - txn, - table="device_lists_remote_cache", - keyvalues={ - "user_id": user_id, - }, + txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} ) self._simple_insert_many_txn( @@ -647,7 +631,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): "content": json.dumps(content), } for content in devices - ] + ], ) txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) @@ -659,13 +643,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): self._simple_upsert_txn( txn, table="device_lists_remote_extremeties", - keyvalues={ - "user_id": user_id, - }, - values={ - "stream_id": stream_id, - }, - + keyvalues={"user_id": user_id}, + values={"stream_id": stream_id}, # we don't need to lock, because we can assume we are the only thread # updating this user's extremity. lock=False, @@ -678,8 +657,12 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): """ with self._device_list_id_gen.get_next() as stream_id: yield self.runInteraction( - "add_device_change_to_streams", self._add_device_change_txn, - user_id, device_ids, hosts, stream_id, + "add_device_change_to_streams", + self._add_device_change_txn, + user_id, + device_ids, + hosts, + stream_id, ) defer.returnValue(stream_id) @@ -687,13 +670,13 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): now = self._clock.time_msec() txn.call_after( - self._device_list_stream_cache.entity_has_changed, - user_id, stream_id, + self._device_list_stream_cache.entity_has_changed, user_id, stream_id ) for host in hosts: txn.call_after( self._device_list_federation_stream_cache.entity_has_changed, - host, stream_id, + host, + stream_id, ) # Delete older entries in the table, as we really only care about @@ -703,20 +686,16 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): DELETE FROM device_lists_stream WHERE user_id = ? AND device_id = ? AND stream_id < ? """, - [(user_id, device_id, stream_id) for device_id in device_ids] + [(user_id, device_id, stream_id) for device_id in device_ids], ) self._simple_insert_many_txn( txn, table="device_lists_stream", values=[ - { - "stream_id": stream_id, - "user_id": user_id, - "device_id": device_id, - } + {"stream_id": stream_id, "user_id": user_id, "device_id": device_id} for device_id in device_ids - ] + ], ) self._simple_insert_many_txn( @@ -733,7 +712,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): } for destination in hosts for device_id in device_ids - ] + ], ) def _prune_old_outbound_device_pokes(self): @@ -764,11 +743,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): """ txn.executemany( - delete_sql, - ( - (yesterday, row[0], row[1], row[2]) - for row in rows - ) + delete_sql, ((yesterday, row[0], row[1], row[2]) for row in rows) ) # Since we've deleted unsent deltas, we need to remove the entry @@ -792,12 +767,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): def f(conn): txn = conn.cursor() - txn.execute( - "DROP INDEX IF EXISTS device_lists_remote_cache_id" - ) - txn.execute( - "DROP INDEX IF EXISTS device_lists_remote_extremeties_id" - ) + txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id") + txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") txn.close() yield self.runWithConnection(f) diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 61a029a53c..201bbd430c 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -22,10 +22,7 @@ from synapse.util.caches.descriptors import cached from ._base import SQLBaseStore -RoomAliasMapping = namedtuple( - "RoomAliasMapping", - ("room_id", "room_alias", "servers",) -) +RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers")) class DirectoryWorkerStore(SQLBaseStore): @@ -63,16 +60,12 @@ class DirectoryWorkerStore(SQLBaseStore): defer.returnValue(None) return - defer.returnValue( - RoomAliasMapping(room_id, room_alias.to_string(), servers) - ) + defer.returnValue(RoomAliasMapping(room_id, room_alias.to_string(), servers)) def get_room_alias_creator(self, room_alias): return self._simple_select_one_onecol( table="room_aliases", - keyvalues={ - "room_alias": room_alias, - }, + keyvalues={"room_alias": room_alias}, retcol="creator", desc="get_room_alias_creator", ) @@ -101,6 +94,7 @@ class DirectoryStore(DirectoryWorkerStore): Returns: Deferred """ + def alias_txn(txn): self._simple_insert_txn( txn, @@ -115,10 +109,10 @@ class DirectoryStore(DirectoryWorkerStore): self._simple_insert_many_txn( txn, table="room_alias_servers", - values=[{ - "room_alias": room_alias.to_string(), - "server": server, - } for server in servers], + values=[ + {"room_alias": room_alias.to_string(), "server": server} + for server in servers + ], ) self._invalidate_cache_and_stream( @@ -126,9 +120,7 @@ class DirectoryStore(DirectoryWorkerStore): ) try: - ret = yield self.runInteraction( - "create_room_alias_association", alias_txn - ) + ret = yield self.runInteraction("create_room_alias_association", alias_txn) except self.database_engine.module.IntegrityError: raise SynapseError( 409, "Room alias %s already exists" % room_alias.to_string() @@ -138,9 +130,7 @@ class DirectoryStore(DirectoryWorkerStore): @defer.inlineCallbacks def delete_room_alias(self, room_alias): room_id = yield self.runInteraction( - "delete_room_alias", - self._delete_room_alias_txn, - room_alias, + "delete_room_alias", self._delete_room_alias_txn, room_alias ) defer.returnValue(room_id) @@ -148,7 +138,7 @@ class DirectoryStore(DirectoryWorkerStore): def _delete_room_alias_txn(self, txn, room_alias): txn.execute( "SELECT room_id FROM room_aliases WHERE room_alias = ?", - (room_alias.to_string(),) + (room_alias.to_string(),), ) res = txn.fetchone() @@ -158,31 +148,29 @@ class DirectoryStore(DirectoryWorkerStore): return None txn.execute( - "DELETE FROM room_aliases WHERE room_alias = ?", - (room_alias.to_string(),) + "DELETE FROM room_aliases WHERE room_alias = ?", (room_alias.to_string(),) ) txn.execute( "DELETE FROM room_alias_servers WHERE room_alias = ?", - (room_alias.to_string(),) + (room_alias.to_string(),), ) - self._invalidate_cache_and_stream( - txn, self.get_aliases_for_room, (room_id,) - ) + self._invalidate_cache_and_stream(txn, self.get_aliases_for_room, (room_id,)) return room_id def update_aliases_for_room(self, old_room_id, new_room_id, creator): def _update_aliases_for_room_txn(txn): sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?" - txn.execute(sql, (new_room_id, creator, old_room_id,)) + txn.execute(sql, (new_room_id, creator, old_room_id)) self._invalidate_cache_and_stream( txn, self.get_aliases_for_room, (old_room_id,) ) self._invalidate_cache_and_stream( txn, self.get_aliases_for_room, (new_room_id,) ) + return self.runInteraction( "_update_aliases_for_room_txn", _update_aliases_for_room_txn ) diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index 9a3aec759e..521936e3b0 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -23,7 +23,6 @@ from ._base import SQLBaseStore class EndToEndRoomKeyStore(SQLBaseStore): - @defer.inlineCallbacks def get_e2e_room_key(self, user_id, version, room_id, session_id): """Get the encrypted E2E room key for a given session from a given @@ -97,9 +96,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @defer.inlineCallbacks - def get_e2e_room_keys( - self, user_id, version, room_id=None, session_id=None - ): + def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): """Bulk get the E2E room keys for a given backup, optionally filtered to a given room, or a given session. @@ -123,10 +120,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): except ValueError: defer.returnValue({'rooms': {}}) - keyvalues = { - "user_id": user_id, - "version": version, - } + keyvalues = {"user_id": user_id, "version": version} if room_id: keyvalues['room_id'] = room_id if session_id: @@ -160,9 +154,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): defer.returnValue(sessions) @defer.inlineCallbacks - def delete_e2e_room_keys( - self, user_id, version, room_id=None, session_id=None - ): + def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): """Bulk delete the E2E room keys for a given backup, optionally filtered to a given room or a given session. @@ -180,19 +172,14 @@ class EndToEndRoomKeyStore(SQLBaseStore): A deferred of the deletion transaction """ - keyvalues = { - "user_id": user_id, - "version": int(version), - } + keyvalues = {"user_id": user_id, "version": int(version)} if room_id: keyvalues['room_id'] = room_id if session_id: keyvalues['session_id'] = session_id yield self._simple_delete( - table="e2e_room_keys", - keyvalues=keyvalues, - desc="delete_e2e_room_keys", + table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys" ) @staticmethod @@ -200,7 +187,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): txn.execute( "SELECT MAX(version) FROM e2e_room_keys_versions " "WHERE user_id=? AND deleted=0", - (user_id,) + (user_id,), ) row = txn.fetchone() if not row: @@ -238,24 +225,15 @@ class EndToEndRoomKeyStore(SQLBaseStore): result = self._simple_select_one_txn( txn, table="e2e_room_keys_versions", - keyvalues={ - "user_id": user_id, - "version": this_version, - "deleted": 0, - }, - retcols=( - "version", - "algorithm", - "auth_data", - ), + keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, + retcols=("version", "algorithm", "auth_data"), ) result["auth_data"] = json.loads(result["auth_data"]) result["version"] = str(result["version"]) return result return self.runInteraction( - "get_e2e_room_keys_version_info", - _get_e2e_room_keys_version_info_txn + "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn ) def create_e2e_room_keys_version(self, user_id, info): @@ -273,7 +251,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): def _create_e2e_room_keys_version_txn(txn): txn.execute( "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?", - (user_id,) + (user_id,), ) current_version = txn.fetchone()[0] if current_version is None: @@ -309,14 +287,9 @@ class EndToEndRoomKeyStore(SQLBaseStore): return self._simple_update( table="e2e_room_keys_versions", - keyvalues={ - "user_id": user_id, - "version": version, - }, - updatevalues={ - "auth_data": json.dumps(info["auth_data"]), - }, - desc="update_e2e_room_keys_version" + keyvalues={"user_id": user_id, "version": version}, + updatevalues={"auth_data": json.dumps(info["auth_data"])}, + desc="update_e2e_room_keys_version", ) def delete_e2e_room_keys_version(self, user_id, version=None): @@ -341,16 +314,10 @@ class EndToEndRoomKeyStore(SQLBaseStore): return self._simple_update_one_txn( txn, table="e2e_room_keys_versions", - keyvalues={ - "user_id": user_id, - "version": this_version, - }, - updatevalues={ - "deleted": 1, - } + keyvalues={"user_id": user_id, "version": this_version}, + updatevalues={"deleted": 1}, ) return self.runInteraction( - "delete_e2e_room_keys_version", - _delete_e2e_room_keys_version_txn + "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn ) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index e381e472a2..2fabb9e2cb 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -26,8 +26,7 @@ from ._base import SQLBaseStore, db_to_json class EndToEndKeyWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_e2e_device_keys( - self, query_list, include_all_devices=False, - include_deleted_devices=False, + self, query_list, include_all_devices=False, include_deleted_devices=False ): """Fetch a list of device keys. Args: @@ -45,8 +44,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore): defer.returnValue({}) results = yield self.runInteraction( - "get_e2e_device_keys", self._get_e2e_device_keys_txn, - query_list, include_all_devices, include_deleted_devices, + "get_e2e_device_keys", + self._get_e2e_device_keys_txn, + query_list, + include_all_devices, + include_deleted_devices, ) for user_id, device_keys in iteritems(results): @@ -56,8 +58,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): defer.returnValue(results) def _get_e2e_device_keys_txn( - self, txn, query_list, include_all_devices=False, - include_deleted_devices=False, + self, txn, query_list, include_all_devices=False, include_deleted_devices=False ): query_clauses = [] query_params = [] @@ -87,7 +88,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): " WHERE %s" ) % ( "LEFT" if include_all_devices else "INNER", - " OR ".join("(" + q + ")" for q in query_clauses) + " OR ".join("(" + q + ")" for q in query_clauses), ) txn.execute(sql, query_params) @@ -124,17 +125,14 @@ class EndToEndKeyWorkerStore(SQLBaseStore): table="e2e_one_time_keys_json", column="key_id", iterable=key_ids, - retcols=("algorithm", "key_id", "key_json",), - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, + retcols=("algorithm", "key_id", "key_json"), + keyvalues={"user_id": user_id, "device_id": device_id}, desc="add_e2e_one_time_keys_check", ) - defer.returnValue({ - (row["algorithm"], row["key_id"]): row["key_json"] for row in rows - }) + defer.returnValue( + {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows} + ) @defer.inlineCallbacks def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys): @@ -155,7 +153,8 @@ class EndToEndKeyWorkerStore(SQLBaseStore): # `add_e2e_one_time_keys` then they'll conflict and we will only # insert one set. self._simple_insert_many_txn( - txn, table="e2e_one_time_keys_json", + txn, + table="e2e_one_time_keys_json", values=[ { "user_id": user_id, @@ -169,8 +168,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore): ], ) self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id,) + txn, self.count_e2e_one_time_keys, (user_id, device_id) ) + yield self.runInteraction( "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys ) @@ -181,6 +181,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): Returns: Dict mapping from algorithm to number of keys for that algorithm. """ + def _count_e2e_one_time_keys(txn): sql = ( "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json" @@ -192,9 +193,8 @@ class EndToEndKeyWorkerStore(SQLBaseStore): for algorithm, key_count in txn: result[algorithm] = key_count return result - return self.runInteraction( - "count_e2e_one_time_keys", _count_e2e_one_time_keys - ) + + return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys) class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): @@ -202,14 +202,12 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): """Stores device keys for a device. Returns whether there was a change or the keys were already in the database. """ + def _set_e2e_device_keys_txn(txn): old_key_json = self._simple_select_one_onecol_txn( txn, table="e2e_device_keys_json", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, + keyvalues={"user_id": user_id, "device_id": device_id}, retcol="key_json", allow_none=True, ) @@ -224,24 +222,17 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): self._simple_upsert_txn( txn, table="e2e_device_keys_json", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, - values={ - "ts_added_ms": time_now, - "key_json": new_key_json, - } + keyvalues={"user_id": user_id, "device_id": device_id}, + values={"ts_added_ms": time_now, "key_json": new_key_json}, ) return True - return self.runInteraction( - "set_e2e_device_keys", _set_e2e_device_keys_txn - ) + return self.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn) def claim_e2e_one_time_keys(self, query_list): """Take a list of one time keys out of the database""" + def _claim_e2e_one_time_keys(txn): sql = ( "SELECT key_id, key_json FROM e2e_one_time_keys_json" @@ -265,12 +256,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): for user_id, device_id, algorithm, key_id in delete: txn.execute(sql, (user_id, device_id, algorithm, key_id)) self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id,) + txn, self.count_e2e_one_time_keys, (user_id, device_id) ) return result - return self.runInteraction( - "claim_e2e_one_time_keys", _claim_e2e_one_time_keys - ) + + return self.runInteraction("claim_e2e_one_time_keys", _claim_e2e_one_time_keys) def delete_e2e_keys_by_device(self, user_id, device_id): def delete_e2e_keys_by_device_txn(txn): @@ -285,8 +275,9 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): keyvalues={"user_id": user_id, "device_id": device_id}, ) self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id,) + txn, self.count_e2e_one_time_keys, (user_id, device_id) ) + return self.runInteraction( "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index ff5ef97ca8..9d2d519922 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -20,10 +20,7 @@ from ._base import IncorrectDatabaseSetup from .postgres import PostgresEngine from .sqlite import Sqlite3Engine -SUPPORTED_MODULE = { - "sqlite3": Sqlite3Engine, - "psycopg2": PostgresEngine, -} +SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine} def create_engine(database_config): @@ -32,15 +29,12 @@ def create_engine(database_config): if engine_class: # pypy requires psycopg2cffi rather than psycopg2 - if (name == "psycopg2" and - platform.python_implementation() == "PyPy"): + if name == "psycopg2" and platform.python_implementation() == "PyPy": name = "psycopg2cffi" module = importlib.import_module(name) return engine_class(module, database_config) - raise RuntimeError( - "Unsupported database engine '%s'" % (name,) - ) + raise RuntimeError("Unsupported database engine '%s'" % (name,)) __all__ = ["create_engine", "IncorrectDatabaseSetup"] diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index dc3238501c..1b97ee74e3 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -23,7 +23,7 @@ class PostgresEngine(object): self.module = database_module self.module.extensions.register_type(self.module.extensions.UNICODE) self.synchronous_commit = database_config.get("synchronous_commit", True) - self._version = None # unknown as yet + self._version = None # unknown as yet def check_database(self, txn): txn.execute("SHOW SERVER_ENCODING") @@ -31,8 +31,7 @@ class PostgresEngine(object): if rows and rows[0][0] != "UTF8": raise IncorrectDatabaseSetup( "Database has incorrect encoding: '%s' instead of 'UTF8'\n" - "See docs/postgres.rst for more information." - % (rows[0][0],) + "See docs/postgres.rst for more information." % (rows[0][0],) ) def convert_param_style(self, sql): @@ -103,12 +102,6 @@ class PostgresEngine(object): # https://www.postgresql.org/docs/current/libpq-status.html#LIBPQ-PQSERVERVERSION if numver >= 100000: - return "%i.%i" % ( - numver / 10000, numver % 10000, - ) + return "%i.%i" % (numver / 10000, numver % 10000) else: - return "%i.%i.%i" % ( - numver / 10000, - (numver % 10000) / 100, - numver % 100, - ) + return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100) diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 1bcd5b99a4..933bcf42c2 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -82,9 +82,10 @@ class Sqlite3Engine(object): # Following functions taken from: https://github.com/coleifer/peewee + def _parse_match_info(buf): bufsize = len(buf) - return [struct.unpack('@I', buf[i:i + 4])[0] for i in range(0, bufsize, 4)] + return [struct.unpack('@I', buf[i : i + 4])[0] for i in range(0, bufsize, 4)] def _rank(raw_match_info): @@ -98,7 +99,7 @@ def _rank(raw_match_info): phrase_info_idx = 2 + (phrase_num * c * 3) for col_num in range(c): col_idx = phrase_info_idx + (col_num * 3) - x1, x2 = match_info[col_idx:col_idx + 2] + x1, x2 = match_info[col_idx : col_idx + 2] if x1 > 0: score += float(x1) / x2 return score diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index a8d90456e3..956f876572 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -32,8 +32,7 @@ from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) -class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, - SQLBaseStore): +class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): def get_auth_chain(self, event_ids, include_given=False): """Get auth events for given event_ids. The events *must* be state events. @@ -45,7 +44,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, list of events """ return self.get_auth_chain_ids( - event_ids, include_given=include_given, + event_ids, include_given=include_given ).addCallback(self._get_events) def get_auth_chain_ids(self, event_ids, include_given=False): @@ -59,9 +58,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, list of event_ids """ return self.runInteraction( - "get_auth_chain_ids", - self._get_auth_chain_ids_txn, - event_ids, include_given + "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given ) def _get_auth_chain_ids_txn(self, txn, event_ids, include_given): @@ -70,23 +67,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, else: results = set() - base_sql = ( - "SELECT auth_id FROM event_auth WHERE event_id IN (%s)" - ) + base_sql = "SELECT auth_id FROM event_auth WHERE event_id IN (%s)" front = set(event_ids) while front: new_front = set() front_list = list(front) - chunks = [ - front_list[x:x + 100] - for x in range(0, len(front), 100) - ] + chunks = [front_list[x : x + 100] for x in range(0, len(front), 100)] for chunk in chunks: - txn.execute( - base_sql % (",".join(["?"] * len(chunk)),), - chunk - ) + txn.execute(base_sql % (",".join(["?"] * len(chunk)),), chunk) new_front.update([r[0] for r in txn]) new_front -= results @@ -98,9 +87,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, def get_oldest_events_in_room(self, room_id): return self.runInteraction( - "get_oldest_events_in_room", - self._get_oldest_events_in_room_txn, - room_id, + "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id ) def get_oldest_events_with_depth_in_room(self, room_id): @@ -121,7 +108,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, " GROUP BY b.event_id" ) - txn.execute(sql, (room_id, False,)) + txn.execute(sql, (room_id, False)) return dict(txn) @@ -152,9 +139,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, return self._simple_select_onecol_txn( txn, table="event_backward_extremities", - keyvalues={ - "room_id": room_id, - }, + keyvalues={"room_id": room_id}, retcol="event_id", ) @@ -209,9 +194,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, def get_latest_event_ids_in_room(self, room_id): return self._simple_select_onecol( table="event_forward_extremities", - keyvalues={ - "room_id": room_id, - }, + keyvalues={"room_id": room_id}, retcol="event_id", desc="get_latest_event_ids_in_room", ) @@ -225,14 +208,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, "WHERE f.room_id = ?" ) - txn.execute(sql, (room_id, )) + txn.execute(sql, (room_id,)) results = [] for event_id, depth in txn.fetchall(): hashes = self._get_event_reference_hashes_txn(txn, event_id) prev_hashes = { - k: encode_base64(v) for k, v in hashes.items() - if k == "sha256" + k: encode_base64(v) for k, v in hashes.items() if k == "sha256" } results.append((event_id, prev_hashes, depth)) @@ -242,9 +224,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, """ For hte given room, get the minimum depth we have seen for it. """ return self.runInteraction( - "get_min_depth", - self._get_min_depth_interaction, - room_id, + "get_min_depth", self._get_min_depth_interaction, room_id ) def _get_min_depth_interaction(self, txn, room_id): @@ -300,7 +280,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, if stream_ordering <= self.stream_ordering_month_ago: raise StoreError(400, "stream_ordering too old") - sql = (""" + sql = """ SELECT event_id FROM stream_ordering_to_exterm INNER JOIN ( SELECT room_id, MAX(stream_ordering) AS stream_ordering @@ -308,15 +288,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, WHERE stream_ordering <= ? GROUP BY room_id ) AS rms USING (room_id, stream_ordering) WHERE room_id = ? - """) + """ def get_forward_extremeties_for_room_txn(txn): txn.execute(sql, (stream_ordering, room_id)) return [event_id for event_id, in txn] return self.runInteraction( - "get_forward_extremeties_for_room", - get_forward_extremeties_for_room_txn + "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) def get_backfill_events(self, room_id, event_list, limit): @@ -329,19 +308,21 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, event_list (list) limit (int) """ - return self.runInteraction( - "get_backfill_events", - self._get_backfill_events, room_id, event_list, limit - ).addCallback( - self._get_events - ).addCallback( - lambda l: sorted(l, key=lambda e: -e.depth) + return ( + self.runInteraction( + "get_backfill_events", + self._get_backfill_events, + room_id, + event_list, + limit, + ) + .addCallback(self._get_events) + .addCallback(lambda l: sorted(l, key=lambda e: -e.depth)) ) def _get_backfill_events(self, txn, room_id, event_list, limit): logger.debug( - "_get_backfill_events: %s, %s, %s", - room_id, repr(event_list), limit + "_get_backfill_events: %s, %s, %s", room_id, repr(event_list), limit ) event_results = set() @@ -364,10 +345,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, depth = self._simple_select_one_onecol_txn( txn, table="events", - keyvalues={ - "event_id": event_id, - "room_id": room_id, - }, + keyvalues={"event_id": event_id, "room_id": room_id}, retcol="depth", allow_none=True, ) @@ -386,10 +364,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, event_results.add(event_id) - txn.execute( - query, - (event_id, False, limit - len(event_results)) - ) + txn.execute(query, (event_id, False, limit - len(event_results))) for row in txn: if row[1] not in event_results: @@ -398,18 +373,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, return event_results @defer.inlineCallbacks - def get_missing_events(self, room_id, earliest_events, latest_events, - limit): + def get_missing_events(self, room_id, earliest_events, latest_events, limit): ids = yield self.runInteraction( "get_missing_events", self._get_missing_events, - room_id, earliest_events, latest_events, limit, + room_id, + earliest_events, + latest_events, + limit, ) events = yield self._get_events(ids) defer.returnValue(events) - def _get_missing_events(self, txn, room_id, earliest_events, latest_events, - limit): + def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): seen_events = set(earliest_events) front = set(latest_events) - seen_events @@ -425,8 +401,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, new_front = set() for event_id in front: txn.execute( - query, - (room_id, event_id, False, limit - len(event_results)) + query, (room_id, event_id, False, limit - len(event_results)) ) new_results = set(t[0] for t in txn) - seen_events @@ -457,12 +432,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, column="prev_event_id", iterable=event_ids, retcols=("event_id",), - desc="get_successor_events" + desc="get_successor_events", ) - defer.returnValue([ - row["event_id"] for row in rows - ]) + defer.returnValue([row["event_id"] for row in rows]) class EventFederationStore(EventFederationWorkerStore): @@ -481,12 +454,11 @@ class EventFederationStore(EventFederationWorkerStore): super(EventFederationStore, self).__init__(db_conn, hs) self.register_background_update_handler( - self.EVENT_AUTH_STATE_ONLY, - self._background_delete_non_state_event_auth, + self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth ) hs.get_clock().looping_call( - self._delete_old_forward_extrem_cache, 60 * 60 * 1000, + self._delete_old_forward_extrem_cache, 60 * 60 * 1000 ) def _update_min_depth_for_room_txn(self, txn, room_id, depth): @@ -498,12 +470,8 @@ class EventFederationStore(EventFederationWorkerStore): self._simple_upsert_txn( txn, table="room_depth", - keyvalues={ - "room_id": room_id, - }, - values={ - "min_depth": depth, - }, + keyvalues={"room_id": room_id}, + values={"min_depth": depth}, ) def _handle_mult_prev_events(self, txn, events): @@ -553,11 +521,15 @@ class EventFederationStore(EventFederationWorkerStore): " )" ) - txn.executemany(query, [ - (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) - for ev in events for e_id in ev.prev_event_ids() - if not ev.internal_metadata.is_outlier() - ]) + txn.executemany( + query, + [ + (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) + for ev in events + for e_id in ev.prev_event_ids() + if not ev.internal_metadata.is_outlier() + ], + ) query = ( "DELETE FROM event_backward_extremities" @@ -566,16 +538,17 @@ class EventFederationStore(EventFederationWorkerStore): txn.executemany( query, [ - (ev.event_id, ev.room_id) for ev in events + (ev.event_id, ev.room_id) + for ev in events if not ev.internal_metadata.is_outlier() - ] + ], ) def _delete_old_forward_extrem_cache(self): def _delete_old_forward_extrem_cache_txn(txn): # Delete entries older than a month, while making sure we don't delete # the only entries for a room. - sql = (""" + sql = """ DELETE FROM stream_ordering_to_exterm WHERE room_id IN ( @@ -583,11 +556,11 @@ class EventFederationStore(EventFederationWorkerStore): FROM stream_ordering_to_exterm WHERE stream_ordering > ? ) AND stream_ordering < ? - """) + """ txn.execute( - sql, - (self.stream_ordering_month_ago, self.stream_ordering_month_ago,) + sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) ) + return run_as_background_process( "delete_old_forward_extrem_cache", self.runInteraction, @@ -597,9 +570,7 @@ class EventFederationStore(EventFederationWorkerStore): def clean_room_for_join(self, room_id): return self.runInteraction( - "clean_room_for_join", - self._clean_room_for_join_txn, - room_id, + "clean_room_for_join", self._clean_room_for_join_txn, room_id ) def _clean_room_for_join_txn(self, txn, room_id): @@ -635,7 +606,7 @@ class EventFederationStore(EventFederationWorkerStore): ) """ - txn.execute(sql, (min_stream_id, max_stream_id,)) + txn.execute(sql, (min_stream_id, max_stream_id)) new_progress = { "target_min_stream_id_inclusive": target_min_stream_id, diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index 6840320641..a729f3e067 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -31,7 +31,9 @@ logger = logging.getLogger(__name__) DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}] DEFAULT_HIGHLIGHT_ACTION = [ - "notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"} + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight"}, ] @@ -91,25 +93,26 @@ class EventPushActionsWorkerStore(SQLBaseStore): @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000) def get_unread_event_push_actions_by_room_for_user( - self, room_id, user_id, last_read_event_id + self, room_id, user_id, last_read_event_id ): ret = yield self.runInteraction( "get_unread_event_push_actions_by_room", self._get_unread_counts_by_receipt_txn, - room_id, user_id, last_read_event_id + room_id, + user_id, + last_read_event_id, ) defer.returnValue(ret) - def _get_unread_counts_by_receipt_txn(self, txn, room_id, user_id, - last_read_event_id): + def _get_unread_counts_by_receipt_txn( + self, txn, room_id, user_id, last_read_event_id + ): sql = ( "SELECT stream_ordering" " FROM events" " WHERE room_id = ? AND event_id = ?" ) - txn.execute( - sql, (room_id, last_read_event_id) - ) + txn.execute(sql, (room_id, last_read_event_id)) results = txn.fetchall() if len(results) == 0: return {"notify_count": 0, "highlight_count": 0} @@ -138,10 +141,13 @@ class EventPushActionsWorkerStore(SQLBaseStore): row = txn.fetchone() notify_count = row[0] if row else 0 - txn.execute(""" + txn.execute( + """ SELECT notif_count FROM event_push_summary WHERE room_id = ? AND user_id = ? AND stream_ordering > ? - """, (room_id, user_id, stream_ordering,)) + """, + (room_id, user_id, stream_ordering), + ) rows = txn.fetchall() if rows: notify_count += rows[0][0] @@ -161,10 +167,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): row = txn.fetchone() highlight_count = row[0] if row else 0 - return { - "notify_count": notify_count, - "highlight_count": highlight_count, - } + return {"notify_count": notify_count, "highlight_count": highlight_count} @defer.inlineCallbacks def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering): @@ -175,6 +178,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) txn.execute(sql, (min_stream_ordering, max_stream_ordering)) return [r[0] for r in txn] + ret = yield self.runInteraction("get_push_action_users_in_range", f) defer.returnValue(ret) @@ -223,12 +227,10 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.stream_ordering <= ?" " ORDER BY ep.stream_ordering ASC LIMIT ?" ) - args = [ - user_id, user_id, - min_stream_ordering, max_stream_ordering, limit, - ] + args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) return txn.fetchall() + after_read_receipt = yield self.runInteraction( "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt ) @@ -253,12 +255,10 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.stream_ordering <= ?" " ORDER BY ep.stream_ordering ASC LIMIT ?" ) - args = [ - user_id, user_id, - min_stream_ordering, max_stream_ordering, limit, - ] + args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) return txn.fetchall() + no_read_receipt = yield self.runInteraction( "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt ) @@ -269,7 +269,8 @@ class EventPushActionsWorkerStore(SQLBaseStore): "room_id": row[1], "stream_ordering": row[2], "actions": _deserialize_action(row[3], row[4]), - } for row in after_read_receipt + no_read_receipt + } + for row in after_read_receipt + no_read_receipt ] # Now sort it so it's ordered correctly, since currently it will @@ -326,12 +327,10 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.stream_ordering <= ?" " ORDER BY ep.stream_ordering DESC LIMIT ?" ) - args = [ - user_id, user_id, - min_stream_ordering, max_stream_ordering, limit, - ] + args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) return txn.fetchall() + after_read_receipt = yield self.runInteraction( "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt ) @@ -356,12 +355,10 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.stream_ordering <= ?" " ORDER BY ep.stream_ordering DESC LIMIT ?" ) - args = [ - user_id, user_id, - min_stream_ordering, max_stream_ordering, limit, - ] + args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) return txn.fetchall() + no_read_receipt = yield self.runInteraction( "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt ) @@ -374,7 +371,8 @@ class EventPushActionsWorkerStore(SQLBaseStore): "stream_ordering": row[2], "actions": _deserialize_action(row[3], row[4]), "received_ts": row[5], - } for row in after_read_receipt + no_read_receipt + } + for row in after_read_receipt + no_read_receipt ] # Now sort it so it's ordered correctly, since currently it will @@ -386,6 +384,36 @@ class EventPushActionsWorkerStore(SQLBaseStore): # Now return the first `limit` defer.returnValue(notifs[:limit]) + def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering): + """A fast check to see if there might be something to push for the + user since the given stream ordering. May return false positives. + + Useful to know whether to bother starting a pusher on start up or not. + + Args: + user_id (str) + min_stream_ordering (int) + + Returns: + Deferred[bool]: True if there may be push to process, False if + there definitely isn't. + """ + + def _get_if_maybe_push_in_range_for_user_txn(txn): + sql = """ + SELECT 1 FROM event_push_actions + WHERE user_id = ? AND stream_ordering > ? + LIMIT 1 + """ + + txn.execute(sql, (user_id, min_stream_ordering)) + return bool(txn.fetchone()) + + return self.runInteraction( + "get_if_maybe_push_in_range_for_user", + _get_if_maybe_push_in_range_for_user_txn, + ) + def add_push_actions_to_staging(self, event_id, user_id_actions): """Add the push actions for the event to the push action staging area. @@ -424,10 +452,13 @@ class EventPushActionsWorkerStore(SQLBaseStore): VALUES (?, ?, ?, ?, ?) """ - txn.executemany(sql, ( - _gen_entry(user_id, actions) - for user_id, actions in iteritems(user_id_actions) - )) + txn.executemany( + sql, + ( + _gen_entry(user_id, actions) + for user_id, actions in iteritems(user_id_actions) + ), + ) return self.runInteraction( "add_push_actions_to_staging", _add_push_actions_to_staging_txn @@ -445,9 +476,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): try: res = yield self._simple_delete( table="event_push_actions_staging", - keyvalues={ - "event_id": event_id, - }, + keyvalues={"event_id": event_id}, desc="remove_push_actions_from_staging", ) defer.returnValue(res) @@ -456,7 +485,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): # another exception here really isn't helpful - there's nothing # the caller can do about it. Just log the exception and move on. logger.exception( - "Error removing push actions after event persistence failure", + "Error removing push actions after event persistence failure" ) def _find_stream_orderings_for_times(self): @@ -473,16 +502,14 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000 ) logger.info( - "Found stream ordering 1 month ago: it's %d", - self.stream_ordering_month_ago + "Found stream ordering 1 month ago: it's %d", self.stream_ordering_month_ago ) logger.info("Searching for stream ordering 1 day ago") self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn( txn, self._clock.time_msec() - 24 * 60 * 60 * 1000 ) logger.info( - "Found stream ordering 1 day ago: it's %d", - self.stream_ordering_day_ago + "Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago ) def find_first_stream_ordering_after_ts(self, ts): @@ -601,16 +628,17 @@ class EventPushActionsStore(EventPushActionsWorkerStore): index_name="event_push_actions_highlights_index", table="event_push_actions", columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], - where_clause="highlight=1" + where_clause="highlight=1", ) self._doing_notif_rotation = False self._rotate_notif_loop = self._clock.looping_call( - self._start_rotate_notifs, 30 * 60 * 1000, + self._start_rotate_notifs, 30 * 60 * 1000 ) - def _set_push_actions_for_event_and_users_txn(self, txn, events_and_contexts, - all_events_and_contexts): + def _set_push_actions_for_event_and_users_txn( + self, txn, events_and_contexts, all_events_and_contexts + ): """Handles moving push actions from staging table to main event_push_actions table for all events in `events_and_contexts`. @@ -637,43 +665,44 @@ class EventPushActionsStore(EventPushActionsWorkerStore): """ if events_and_contexts: - txn.executemany(sql, ( + txn.executemany( + sql, ( - event.room_id, event.internal_metadata.stream_ordering, - event.depth, event.event_id, - ) - for event, _ in events_and_contexts - )) + ( + event.room_id, + event.internal_metadata.stream_ordering, + event.depth, + event.event_id, + ) + for event, _ in events_and_contexts + ), + ) for event, _ in events_and_contexts: user_ids = self._simple_select_onecol_txn( txn, table="event_push_actions_staging", - keyvalues={ - "event_id": event.event_id, - }, + keyvalues={"event_id": event.event_id}, retcol="user_id", ) for uid in user_ids: txn.call_after( self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (event.room_id, uid,) + (event.room_id, uid), ) # Now we delete the staging area for *all* events that were being # persisted. txn.executemany( "DELETE FROM event_push_actions_staging WHERE event_id = ?", - ( - (event.event_id,) - for event, _ in all_events_and_contexts - ) + ((event.event_id,) for event, _ in all_events_and_contexts), ) @defer.inlineCallbacks - def get_push_actions_for_user(self, user_id, before=None, limit=50, - only_highlight=False): + def get_push_actions_for_user( + self, user_id, before=None, limit=50, only_highlight=False + ): def f(txn): before_clause = "" if before: @@ -697,15 +726,12 @@ class EventPushActionsStore(EventPushActionsWorkerStore): " WHERE epa.event_id = e.event_id" " AND epa.user_id = ? %s" " ORDER BY epa.stream_ordering DESC" - " LIMIT ?" - % (before_clause,) + " LIMIT ?" % (before_clause,) ) txn.execute(sql, args) return self.cursor_to_dict(txn) - push_actions = yield self.runInteraction( - "get_push_actions_for_user", f - ) + push_actions = yield self.runInteraction("get_push_actions_for_user", f) for pa in push_actions: pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) defer.returnValue(push_actions) @@ -723,6 +749,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): ) txn.execute(sql, (stream_ordering,)) return txn.fetchone() + result = yield self.runInteraction("get_time_of_last_push_action_before", f) defer.returnValue(result[0] if result else None) @@ -731,24 +758,24 @@ class EventPushActionsStore(EventPushActionsWorkerStore): def f(txn): txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") return txn.fetchone() - result = yield self.runInteraction( - "get_latest_push_action_stream_ordering", f - ) + + result = yield self.runInteraction("get_latest_push_action_stream_ordering", f) defer.returnValue(result[0] or 0) def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): # Sad that we have to blow away the cache for the whole room here txn.call_after( self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (room_id,) + (room_id,), ) txn.execute( "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?", - (room_id, event_id) + (room_id, event_id), ) - def _remove_old_push_actions_before_txn(self, txn, room_id, user_id, - stream_ordering): + def _remove_old_push_actions_before_txn( + self, txn, room_id, user_id, stream_ordering + ): """ Purges old push actions for a user and room before a given stream_ordering. @@ -765,7 +792,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): """ txn.call_after( self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (room_id, user_id, ) + (room_id, user_id), ) # We need to join on the events table to get the received_ts for @@ -781,13 +808,16 @@ class EventPushActionsStore(EventPushActionsWorkerStore): " WHERE user_id = ? AND room_id = ? AND " " stream_ordering <= ?" " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)", - (user_id, room_id, stream_ordering, self.stream_ordering_month_ago) + (user_id, room_id, stream_ordering, self.stream_ordering_month_ago), ) - txn.execute(""" + txn.execute( + """ DELETE FROM event_push_summary WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? - """, (room_id, user_id, stream_ordering)) + """, + (room_id, user_id, stream_ordering), + ) def _start_rotate_notifs(self): return run_as_background_process("rotate_notifs", self._rotate_notifs) @@ -803,8 +833,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): logger.info("Rotating notifications") caught_up = yield self.runInteraction( - "_rotate_notifs", - self._rotate_notifs_txn + "_rotate_notifs", self._rotate_notifs_txn ) if caught_up: break @@ -826,11 +855,14 @@ class EventPushActionsStore(EventPushActionsWorkerStore): # We don't to try and rotate millions of rows at once, so we cap the # maximum stream ordering we'll rotate before. - txn.execute(""" + txn.execute( + """ SELECT stream_ordering FROM event_push_actions WHERE stream_ordering > ? ORDER BY stream_ordering ASC LIMIT 1 OFFSET ? - """, (old_rotate_stream_ordering, self._rotate_count)) + """, + (old_rotate_stream_ordering, self._rotate_count), + ) stream_row = txn.fetchone() if stream_row: offset_stream_ordering, = stream_row @@ -874,7 +906,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): LEFT JOIN event_push_summary AS old USING (user_id, room_id) """ - txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering,)) + txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering)) rows = txn.fetchall() logger.info("Rotating notifications, handling %d rows", len(rows)) @@ -892,8 +924,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore): "notif_count": row[2], "stream_ordering": row[3], } - for row in rows if row[4] is None - ] + for row in rows + if row[4] is None + ], ) txn.executemany( @@ -901,20 +934,20 @@ class EventPushActionsStore(EventPushActionsWorkerStore): UPDATE event_push_summary SET notif_count = ?, stream_ordering = ? WHERE user_id = ? AND room_id = ? """, - ((row[2], row[3], row[0], row[1],) for row in rows if row[4] is not None) + ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None), ) txn.execute( "DELETE FROM event_push_actions" " WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0", - (old_rotate_stream_ordering, rotate_to_stream_ordering,) + (old_rotate_stream_ordering, rotate_to_stream_ordering), ) logger.info("Rotating notifications, deleted %s push actions", txn.rowcount) txn.execute( "UPDATE event_push_summary_stream_ordering SET stream_ordering = ?", - (rotate_to_stream_ordering,) + (rotate_to_stream_ordering,), ) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 428300ea0a..7a7f841c6c 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -30,7 +30,6 @@ from twisted.internet import defer import synapse.metrics from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError -# these are only included to make the type annotations work from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 from synapse.metrics.background_process_metrics import run_as_background_process @@ -51,8 +50,11 @@ from synapse.util.metrics import Measure logger = logging.getLogger(__name__) persist_event_counter = Counter("synapse_storage_events_persisted_events", "") -event_counter = Counter("synapse_storage_events_persisted_events_sep", "", - ["type", "origin_type", "origin_entity"]) +event_counter = Counter( + "synapse_storage_events_persisted_events_sep", + "", + ["type", "origin_type", "origin_entity"], +) # The number of times we are recalculating the current state state_delta_counter = Counter("synapse_storage_events_state_delta", "") @@ -60,13 +62,15 @@ state_delta_counter = Counter("synapse_storage_events_state_delta", "") # The number of times we are recalculating state when there is only a # single forward extremity state_delta_single_event_counter = Counter( - "synapse_storage_events_state_delta_single_event", "") + "synapse_storage_events_state_delta_single_event", "" +) # The number of times we are reculating state when we could have resonably # calculated the delta when we calculated the state for an event we were # persisting. state_delta_reuse_delta_counter = Counter( - "synapse_storage_events_state_delta_reuse_delta", "") + "synapse_storage_events_state_delta_reuse_delta", "" +) def encode_json(json_object): @@ -75,7 +79,7 @@ def encode_json(json_object): """ out = frozendict_json_encoder.encode(json_object) if isinstance(out, bytes): - out = out.decode('utf8') + out = out.decode("utf8") return out @@ -84,9 +88,9 @@ class _EventPeristenceQueue(object): concurrent transaction per room. """ - _EventPersistQueueItem = namedtuple("_EventPersistQueueItem", ( - "events_and_contexts", "backfilled", "deferred", - )) + _EventPersistQueueItem = namedtuple( + "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred") + ) def __init__(self): self._event_persist_queues = {} @@ -119,11 +123,13 @@ class _EventPeristenceQueue(object): deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) - queue.append(self._EventPersistQueueItem( - events_and_contexts=events_and_contexts, - backfilled=backfilled, - deferred=deferred, - )) + queue.append( + self._EventPersistQueueItem( + events_and_contexts=events_and_contexts, + backfilled=backfilled, + deferred=deferred, + ) + ) return deferred.observe() @@ -191,6 +197,7 @@ def _retry_on_integrity_error(func): Args: func: function that returns a Deferred and accepts a `delete_existing` arg """ + @wraps(func) @defer.inlineCallbacks def f(self, *args, **kwargs): @@ -206,8 +213,12 @@ def _retry_on_integrity_error(func): # inherits from EventFederationStore so that we can call _update_backward_extremities # and _handle_mult_prev_events (though arguably those could both be moved in here) -class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore, - BackgroundUpdateStore): +class EventsStore( + StateGroupWorkerStore, + EventFederationStore, + EventsWorkerStore, + BackgroundUpdateStore, +): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" @@ -265,8 +276,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore deferreds = [] for room_id, evs_ctxs in iteritems(partitioned): d = self._event_persist_queue.add_to_queue( - room_id, evs_ctxs, - backfilled=backfilled, + room_id, evs_ctxs, backfilled=backfilled ) deferreds.append(d) @@ -296,8 +306,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore and the stream ordering of the latest persisted event """ deferred = self._event_persist_queue.add_to_queue( - event.room_id, [(event, context)], - backfilled=backfilled, + event.room_id, [(event, context)], backfilled=backfilled ) self._maybe_start_persisting(event.room_id) @@ -312,16 +321,16 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore def persisting_queue(item): with Measure(self._clock, "persist_events"): yield self._persist_events( - item.events_and_contexts, - backfilled=item.backfilled, + item.events_and_contexts, backfilled=item.backfilled ) self._event_persist_queue.handle_queue(room_id, persisting_queue) @_retry_on_integrity_error @defer.inlineCallbacks - def _persist_events(self, events_and_contexts, backfilled=False, - delete_existing=False): + def _persist_events( + self, events_and_contexts, backfilled=False, delete_existing=False + ): """Persist events to db Args: @@ -345,13 +354,11 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore ) with stream_ordering_manager as stream_orderings: - for (event, context), stream, in zip( - events_and_contexts, stream_orderings - ): + for (event, context), stream in zip(events_and_contexts, stream_orderings): event.internal_metadata.stream_ordering = stream chunks = [ - events_and_contexts[x:x + 100] + events_and_contexts[x : x + 100] for x in range(0, len(events_and_contexts), 100) ] @@ -445,12 +452,9 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore state_delta_reuse_delta_counter.inc() break - logger.info( - "Calculating state delta for room %s", room_id, - ) + logger.info("Calculating state delta for room %s", room_id) with Measure( - self._clock, - "persist_events.get_new_state_after_events", + self._clock, "persist_events.get_new_state_after_events" ): res = yield self._get_new_state_after_events( room_id, @@ -470,11 +474,10 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore state_delta_for_room[room_id] = ([], delta_ids) elif current_state is not None: with Measure( - self._clock, - "persist_events.calculate_state_delta", + self._clock, "persist_events.calculate_state_delta" ): delta = yield self._calculate_state_delta( - room_id, current_state, + room_id, current_state ) state_delta_for_room[room_id] = delta @@ -498,7 +501,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore # backfilled events have negative stream orderings, so we don't # want to set the event_persisted_position to that. synapse.metrics.event_persisted_position.set( - chunk[-1][0].internal_metadata.stream_ordering, + chunk[-1][0].internal_metadata.stream_ordering ) for event, context in chunk: @@ -515,9 +518,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore event_counter.labels(event.type, origin_type, origin_entity).inc() for room_id, new_state in iteritems(current_state_for_room): - self.get_current_state_ids.prefill( - (room_id, ), new_state - ) + self.get_current_state_ids.prefill((room_id,), new_state) for room_id, latest_event_ids in iteritems(new_forward_extremeties): self.get_latest_event_ids_in_room.prefill( @@ -535,8 +536,10 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore # we're only interested in new events which aren't outliers and which aren't # being rejected. new_events = [ - event for event, ctx in event_contexts - if not event.internal_metadata.is_outlier() and not ctx.rejected + event + for event, ctx in event_contexts + if not event.internal_metadata.is_outlier() + and not ctx.rejected and not event.internal_metadata.is_soft_failed() ] @@ -544,15 +547,11 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore result = set(latest_event_ids) # add all the new events to the list - result.update( - event.event_id for event in new_events - ) + result.update(event.event_id for event in new_events) # Now remove all events which are prev_events of any of the new events result.difference_update( - e_id - for event in new_events - for e_id in event.prev_event_ids() + e_id for event in new_events for e_id in event.prev_event_ids() ) # Finally, remove any events which are prev_events of any existing events. @@ -592,17 +591,14 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore results.extend(r[0] for r in txn) for chunk in batch_iter(event_ids, 100): - yield self.runInteraction( - "_get_events_which_are_prevs", - _get_events, - chunk, - ) + yield self.runInteraction("_get_events_which_are_prevs", _get_events, chunk) defer.returnValue(results) @defer.inlineCallbacks - def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids, - new_latest_event_ids): + def _get_new_state_after_events( + self, room_id, events_context, old_latest_event_ids, new_latest_event_ids + ): """Calculate the current state dict after adding some new events to a room @@ -642,7 +638,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore if not ev.internal_metadata.is_outlier(): raise Exception( "Context for new event %s has no state " - "group" % (ev.event_id, ), + "group" % (ev.event_id,) ) continue @@ -682,9 +678,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore if missing_event_ids: # Now pull out the state groups for any missing events from DB - event_to_groups = yield self._get_state_group_for_events( - missing_event_ids, - ) + event_to_groups = yield self._get_state_group_for_events(missing_event_ids) event_id_to_state_group.update(event_to_groups) # State groups of old_latest_event_ids @@ -710,9 +704,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore new_state_group = next(iter(new_state_groups)) old_state_group = next(iter(old_state_groups)) - delta_ids = state_group_deltas.get( - (old_state_group, new_state_group,), None - ) + delta_ids = state_group_deltas.get((old_state_group, new_state_group), None) if delta_ids is not None: # We have a delta from the existing to new current state, # so lets just return that. If we happen to already have @@ -735,9 +727,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore # Ok, we need to defer to the state handler to resolve our state sets. - state_groups = { - sg: state_groups_map[sg] for sg in new_state_groups - } + state_groups = {sg: state_groups_map[sg] for sg in new_state_groups} events_map = {ev.event_id: ev for ev, _ in events_context} @@ -755,8 +745,11 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore logger.debug("calling resolve_state_groups from preserve_events") res = yield self._state_resolution_handler.resolve_state_groups( - room_id, room_version, state_groups, events_map, - state_res_store=StateResolutionStore(self) + room_id, + room_version, + state_groups, + events_map, + state_res_store=StateResolutionStore(self), ) defer.returnValue((res.state, None)) @@ -774,22 +767,26 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore """ existing_state = yield self.get_current_state_ids(room_id) - to_delete = [ - key for key in existing_state - if key not in current_state - ] + to_delete = [key for key in existing_state if key not in current_state] to_insert = { - key: ev_id for key, ev_id in iteritems(current_state) + key: ev_id + for key, ev_id in iteritems(current_state) if ev_id != existing_state.get(key) } defer.returnValue((to_delete, to_insert)) @log_function - def _persist_events_txn(self, txn, events_and_contexts, backfilled, - delete_existing=False, state_delta_for_room={}, - new_forward_extremeties={}): + def _persist_events_txn( + self, + txn, + events_and_contexts, + backfilled, + delete_existing=False, + state_delta_for_room={}, + new_forward_extremeties={}, + ): """Insert some number of room events into the necessary database tables. Rejected events are only inserted into the events table, the events_json table, @@ -816,9 +813,10 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore """ all_events_and_contexts = events_and_contexts + min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering - self._update_current_state_txn(txn, state_delta_for_room, max_stream_order) + self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) self._update_forward_extremities_txn( txn, @@ -828,20 +826,17 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore # Ensure that we don't have the same event twice. events_and_contexts = self._filter_events_and_contexts_for_duplicates( - events_and_contexts, + events_and_contexts ) self._update_room_depths_txn( - txn, - events_and_contexts=events_and_contexts, - backfilled=backfilled, + txn, events_and_contexts=events_and_contexts, backfilled=backfilled ) # _update_outliers_txn filters out any events which have already been # persisted, and returns the filtered list. events_and_contexts = self._update_outliers_txn( - txn, - events_and_contexts=events_and_contexts, + txn, events_and_contexts=events_and_contexts ) # From this point onwards the events are only events that we haven't @@ -852,15 +847,9 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore # for these events so we can reinsert them. # This gets around any problems with some tables already having # entries. - self._delete_existing_rows_txn( - txn, - events_and_contexts=events_and_contexts, - ) + self._delete_existing_rows_txn(txn, events_and_contexts=events_and_contexts) - self._store_event_txn( - txn, - events_and_contexts=events_and_contexts, - ) + self._store_event_txn(txn, events_and_contexts=events_and_contexts) # Insert into event_to_state_groups. self._store_event_state_mappings_txn(txn, events_and_contexts) @@ -889,8 +878,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore # _store_rejected_events_txn filters out any events which were # rejected, and returns the filtered list. events_and_contexts = self._store_rejected_events_txn( - txn, - events_and_contexts=events_and_contexts, + txn, events_and_contexts=events_and_contexts ) # From this point onwards the events are only ones that weren't @@ -903,7 +891,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore backfilled=backfilled, ) - def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order): + def _update_current_state_txn(self, txn, state_delta_by_room, stream_id): for room_id, current_state_tuple in iteritems(state_delta_by_room): to_delete, to_insert = current_state_tuple @@ -912,6 +900,12 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore # that we can use it to calculate the `prev_event_id`. (This # allows us to not have to pull out the existing state # unnecessarily). + # + # The stream_id for the update is chosen to be the minimum of the stream_ids + # for the batch of the events that we are persisting; that means we do not + # end up in a situation where workers see events before the + # current_state_delta updates. + # sql = """ INSERT INTO current_state_delta_stream (stream_id, room_id, type, state_key, event_id, prev_event_id) @@ -920,22 +914,40 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore WHERE room_id = ? AND type = ? AND state_key = ? ) """ - txn.executemany(sql, ( + txn.executemany( + sql, ( - max_stream_order, room_id, etype, state_key, None, - room_id, etype, state_key, - ) - for etype, state_key in to_delete - # We sanity check that we're deleting rather than updating - if (etype, state_key) not in to_insert - )) - txn.executemany(sql, ( + ( + stream_id, + room_id, + etype, + state_key, + None, + room_id, + etype, + state_key, + ) + for etype, state_key in to_delete + # We sanity check that we're deleting rather than updating + if (etype, state_key) not in to_insert + ), + ) + txn.executemany( + sql, ( - max_stream_order, room_id, etype, state_key, ev_id, - room_id, etype, state_key, - ) - for (etype, state_key), ev_id in iteritems(to_insert) - )) + ( + stream_id, + room_id, + etype, + state_key, + ev_id, + room_id, + etype, + state_key, + ) + for (etype, state_key), ev_id in iteritems(to_insert) + ), + ) # Now we actually update the current_state_events table @@ -964,7 +976,8 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore txn.call_after( self._curr_state_delta_stream_cache.entity_has_changed, - room_id, max_stream_order, + room_id, + stream_id, ) # Invalidate the various caches @@ -980,28 +993,27 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore if ev_type == EventTypes.Member ) + for member in members_changed: + txn.call_after( + self.get_rooms_for_user_with_stream_ordering.invalidate, (member,) + ) + self._invalidate_state_caches_and_stream(txn, room_id, members_changed) - def _update_forward_extremities_txn(self, txn, new_forward_extremities, - max_stream_order): + def _update_forward_extremities_txn( + self, txn, new_forward_extremities, max_stream_order + ): for room_id, new_extrem in iteritems(new_forward_extremities): self._simple_delete_txn( - txn, - table="event_forward_extremities", - keyvalues={"room_id": room_id}, - ) - txn.call_after( - self.get_latest_event_ids_in_room.invalidate, (room_id,) + txn, table="event_forward_extremities", keyvalues={"room_id": room_id} ) + txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) self._simple_insert_many_txn( txn, table="event_forward_extremities", values=[ - { - "event_id": ev_id, - "room_id": room_id, - } + {"event_id": ev_id, "room_id": room_id} for room_id, new_extrem in iteritems(new_forward_extremities) for ev_id in new_extrem ], @@ -1021,7 +1033,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore } for room_id, new_extrem in iteritems(new_forward_extremities) for event_id in new_extrem - ] + ], ) @classmethod @@ -1065,7 +1077,8 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore if not backfilled: txn.call_after( self._events_stream_cache.entity_has_changed, - event.room_id, event.internal_metadata.stream_ordering, + event.room_id, + event.internal_metadata.stream_ordering, ) if not event.internal_metadata.is_outlier() and not context.rejected: @@ -1092,16 +1105,12 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore are already in the events table. """ txn.execute( - "SELECT event_id, outlier FROM events WHERE event_id in (%s)" % ( - ",".join(["?"] * len(events_and_contexts)), - ), - [event.event_id for event, _ in events_and_contexts] + "SELECT event_id, outlier FROM events WHERE event_id in (%s)" + % (",".join(["?"] * len(events_and_contexts)),), + [event.event_id for event, _ in events_and_contexts], ) - have_persisted = { - event_id: outlier - for event_id, outlier in txn - } + have_persisted = {event_id: outlier for event_id, outlier in txn} to_remove = set() for event, context in events_and_contexts: @@ -1128,18 +1137,12 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore logger.exception("") raise - metadata_json = encode_json( - event.internal_metadata.get_dict() - ) + metadata_json = encode_json(event.internal_metadata.get_dict()) sql = ( - "UPDATE event_json SET internal_metadata = ?" - " WHERE event_id = ?" - ) - txn.execute( - sql, - (metadata_json, event.event_id,) + "UPDATE event_json SET internal_metadata = ?" " WHERE event_id = ?" ) + txn.execute(sql, (metadata_json, event.event_id)) # Add an entry to the ex_outlier_stream table to replicate the # change in outlier status to our workers. @@ -1152,25 +1155,17 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore "event_stream_ordering": stream_order, "event_id": event.event_id, "state_group": state_group_id, - } + }, ) - sql = ( - "UPDATE events SET outlier = ?" - " WHERE event_id = ?" - ) - txn.execute( - sql, - (False, event.event_id,) - ) + sql = "UPDATE events SET outlier = ?" " WHERE event_id = ?" + txn.execute(sql, (False, event.event_id)) # Update the event_backward_extremities table now that this # event isn't an outlier any more. self._update_backward_extremeties(txn, [event]) - return [ - ec for ec in events_and_contexts if ec[0] not in to_remove - ] + return [ec for ec in events_and_contexts if ec[0] not in to_remove] @classmethod def _delete_existing_rows_txn(cls, txn, events_and_contexts): @@ -1181,39 +1176,33 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore logger.info("Deleting existing") for table in ( - "events", - "event_auth", - "event_json", - "event_content_hashes", - "event_destinations", - "event_edge_hashes", - "event_edges", - "event_forward_extremities", - "event_reference_hashes", - "event_search", - "event_signatures", - "event_to_state_groups", - "guest_access", - "history_visibility", - "local_invites", - "room_names", - "state_events", - "rejections", - "redactions", - "room_memberships", - "topics" + "events", + "event_auth", + "event_json", + "event_edges", + "event_forward_extremities", + "event_reference_hashes", + "event_search", + "event_to_state_groups", + "guest_access", + "history_visibility", + "local_invites", + "room_names", + "state_events", + "rejections", + "redactions", + "room_memberships", + "topics", ): txn.executemany( "DELETE FROM %s WHERE event_id = ?" % (table,), - [(ev.event_id,) for ev, _ in events_and_contexts] + [(ev.event_id,) for ev, _ in events_and_contexts], ) - for table in ( - "event_push_actions", - ): + for table in ("event_push_actions",): txn.executemany( "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,), - [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts] + [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts], ) def _store_event_txn(self, txn, events_and_contexts): @@ -1296,17 +1285,14 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore for event, context in events_and_contexts: if context.rejected: # Insert the event_id into the rejections table - self._store_rejections_txn( - txn, event.event_id, context.rejected - ) + self._store_rejections_txn(txn, event.event_id, context.rejected) to_remove.add(event) - return [ - ec for ec in events_and_contexts if ec[0] not in to_remove - ] + return [ec for ec in events_and_contexts if ec[0] not in to_remove] - def _update_metadata_tables_txn(self, txn, events_and_contexts, - all_events_and_contexts, backfilled): + def _update_metadata_tables_txn( + self, txn, events_and_contexts, all_events_and_contexts, backfilled + ): """Update all the miscellaneous tables for new events Args: @@ -1342,8 +1328,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. self._handle_mult_prev_events( - txn, - events=[event for event, _ in events_and_contexts], + txn, events=[event for event, _ in events_and_contexts] ) for event, _ in events_and_contexts: @@ -1401,11 +1386,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore state_values.append(vals) - self._simple_insert_many_txn( - txn, - table="state_events", - values=state_values, - ) + self._simple_insert_many_txn(txn, table="state_events", values=state_values) # Prefill the event cache self._add_to_cache(txn, events_and_contexts) @@ -1416,10 +1397,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore rows = [] N = 200 for i in range(0, len(events_and_contexts), N): - ev_map = { - e[0].event_id: e[0] - for e in events_and_contexts[i:i + N] - } + ev_map = {e[0].event_id: e[0] for e in events_and_contexts[i : i + N]} if not ev_map: break @@ -1439,14 +1417,14 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore for row in rows: event = ev_map[row["event_id"]] if not row["rejects"] and not row["redacts"]: - to_prefill.append(_EventCacheEntry( - event=event, - redacted_event=None, - )) + to_prefill.append( + _EventCacheEntry(event=event, redacted_event=None) + ) def prefill(): for cache_entry in to_prefill: self._get_event_cache.prefill((cache_entry[0].event_id,), cache_entry) + txn.call_after(prefill) def _store_redaction(self, txn, event): @@ -1454,7 +1432,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore txn.call_after(self._invalidate_get_event_cache, event.redacts) txn.execute( "INSERT INTO redactions (event_id, redacts) VALUES (?,?)", - (event.event_id, event.redacts) + (event.event_id, event.redacts), ) @defer.inlineCallbacks @@ -1465,6 +1443,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore If it has been significantly less or more than one day since the last call to this function, it will return None. """ + def _count_messages(txn): sql = """ SELECT COALESCE(COUNT(*), 0) FROM events @@ -1492,7 +1471,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore AND stream_ordering > ? """ - txn.execute(sql, (like_clause, self.stream_ordering_day_ago,)) + txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) count, = txn.fetchone() return count @@ -1557,18 +1536,16 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore update_rows.append((sender, contains_url, event_id)) - sql = ( - "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?" - ) + sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?" for index in range(0, len(update_rows), INSERT_CLUMP_SIZE): - clump = update_rows[index:index + INSERT_CLUMP_SIZE] + clump = update_rows[index : index + INSERT_CLUMP_SIZE] txn.executemany(sql, clump) progress = { "target_min_stream_id_inclusive": target_min_stream_id, "max_stream_id_exclusive": min_stream_id, - "rows_inserted": rows_inserted + len(rows) + "rows_inserted": rows_inserted + len(rows), } self._background_update_progress_txn( @@ -1613,10 +1590,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore rows_to_update = [] - chunks = [ - event_ids[i:i + 100] - for i in range(0, len(event_ids), 100) - ] + chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] for chunk in chunks: ev_rows = self._simple_select_many_txn( txn, @@ -1639,18 +1613,16 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore rows_to_update.append((origin_server_ts, event_id)) - sql = ( - "UPDATE events SET origin_server_ts = ? WHERE event_id = ?" - ) + sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?" for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE): - clump = rows_to_update[index:index + INSERT_CLUMP_SIZE] + clump = rows_to_update[index : index + INSERT_CLUMP_SIZE] txn.executemany(sql, clump) progress = { "target_min_stream_id_inclusive": target_min_stream_id, "max_stream_id_exclusive": min_stream_id, - "rows_inserted": rows_inserted + len(rows_to_update) + "rows_inserted": rows_inserted + len(rows_to_update), } self._background_update_progress_txn( @@ -1714,6 +1686,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore new_event_updates.extend(txn) return new_event_updates + return self.runInteraction( "get_all_new_forward_event_rows", get_all_new_forward_event_rows ) @@ -1756,13 +1729,20 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore new_event_updates.extend(txn.fetchall()) return new_event_updates + return self.runInteraction( "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows ) @cached(num_args=5, max_entries=10) - def get_all_new_events(self, last_backfill_id, last_forward_id, - current_backfill_id, current_forward_id, limit): + def get_all_new_events( + self, + last_backfill_id, + last_forward_id, + current_backfill_id, + current_forward_id, + limit, + ): """Get all the new events that have arrived at the server either as new events or as backfilled events""" have_backfill_events = last_backfill_id != current_backfill_id @@ -1837,14 +1817,15 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore backward_ex_outliers = [] return AllNewEventsResult( - new_forward_events, new_backfill_events, - forward_ex_outliers, backward_ex_outliers, + new_forward_events, + new_backfill_events, + forward_ex_outliers, + backward_ex_outliers, ) + return self.runInteraction("get_all_new_events", get_all_new_events_txn) - def purge_history( - self, room_id, token, delete_local_events, - ): + def purge_history(self, room_id, token, delete_local_events): """Deletes room history before a certain point Args: @@ -1860,28 +1841,24 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore return self.runInteraction( "purge_history", - self._purge_history_txn, room_id, token, + self._purge_history_txn, + room_id, + token, delete_local_events, ) - def _purge_history_txn( - self, txn, room_id, token_str, delete_local_events, - ): + def _purge_history_txn(self, txn, room_id, token_str, delete_local_events): token = RoomStreamToken.parse(token_str) # Tables that should be pruned: # event_auth # event_backward_extremities - # event_content_hashes - # event_destinations - # event_edge_hashes # event_edges # event_forward_extremities # event_json # event_push_actions # event_reference_hashes # event_search - # event_signatures # event_to_state_groups # events # rejections @@ -1913,7 +1890,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore "ON e.event_id = f.event_id " "AND e.room_id = f.room_id " "WHERE f.room_id = ?", - (room_id,) + (room_id,), ) rows = txn.fetchall() max_depth = max(row[1] for row in rows) @@ -1934,10 +1911,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore should_delete_expr += " AND event_id NOT LIKE ?" # We include the parameter twice since we use the expression twice - should_delete_params += ( - "%:" + self.hs.hostname, - "%:" + self.hs.hostname, - ) + should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname) should_delete_params += (room_id, token.topological) @@ -1948,10 +1922,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore " SELECT event_id, %s" " FROM events AS e LEFT JOIN state_events USING (event_id)" " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?" - % ( - should_delete_expr, - should_delete_expr, - ), + % (should_delete_expr, should_delete_expr), should_delete_params, ) @@ -1961,23 +1932,19 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore # the should_delete / shouldn't_delete subsets txn.execute( "CREATE INDEX events_to_purge_should_delete" - " ON events_to_purge(should_delete)", + " ON events_to_purge(should_delete)" ) # We do joins against events_to_purge for e.g. calculating state # groups to purge, etc., so lets make an index. - txn.execute( - "CREATE INDEX events_to_purge_id" - " ON events_to_purge(event_id)", - ) + txn.execute("CREATE INDEX events_to_purge_id" " ON events_to_purge(event_id)") - txn.execute( - "SELECT event_id, should_delete FROM events_to_purge" - ) + txn.execute("SELECT event_id, should_delete FROM events_to_purge") event_rows = txn.fetchall() logger.info( "[purge] found %i events before cutoff, of which %i can be deleted", - len(event_rows), sum(1 for e in event_rows if e[1]), + len(event_rows), + sum(1 for e in event_rows if e[1]), ) logger.info("[purge] Finding new backward extremities") @@ -1989,24 +1956,21 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore "SELECT DISTINCT e.event_id FROM events_to_purge AS e" " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id" " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id" - " WHERE ep2.event_id IS NULL", + " WHERE ep2.event_id IS NULL" ) new_backwards_extrems = txn.fetchall() logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems) txn.execute( - "DELETE FROM event_backward_extremities WHERE room_id = ?", - (room_id,) + "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,) ) # Update backward extremeties txn.executemany( "INSERT INTO event_backward_extremities (room_id, event_id)" " VALUES (?, ?)", - [ - (room_id, event_id) for event_id, in new_backwards_extrems - ] + [(room_id, event_id) for event_id, in new_backwards_extrems], ) logger.info("[purge] finding redundant state groups") @@ -2014,28 +1978,25 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore # Get all state groups that are referenced by events that are to be # deleted. We then go and check if they are referenced by other events # or state groups, and if not we delete them. - txn.execute(""" + txn.execute( + """ SELECT DISTINCT state_group FROM events_to_purge INNER JOIN event_to_state_groups USING (event_id) - """) + """ + ) referenced_state_groups = set(sg for sg, in txn) logger.info( - "[purge] found %i referenced state groups", - len(referenced_state_groups), + "[purge] found %i referenced state groups", len(referenced_state_groups) ) logger.info("[purge] finding state groups that can be deleted") - state_groups_to_delete, remaining_state_groups = ( - self._find_unreferenced_groups_during_purge( - txn, referenced_state_groups, - ) - ) + _ = self._find_unreferenced_groups_during_purge(txn, referenced_state_groups) + state_groups_to_delete, remaining_state_groups = _ logger.info( - "[purge] found %i state groups to delete", - len(state_groups_to_delete), + "[purge] found %i state groups to delete", len(state_groups_to_delete) ) logger.info( @@ -2047,25 +2008,15 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore # groups to non delta versions. for sg in remaining_state_groups: logger.info("[purge] de-delta-ing remaining state group %s", sg) - curr_state = self._get_state_groups_from_groups_txn( - txn, [sg], - ) + curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) curr_state = curr_state[sg] self._simple_delete_txn( - txn, - table="state_groups_state", - keyvalues={ - "state_group": sg, - } + txn, table="state_groups_state", keyvalues={"state_group": sg} ) self._simple_delete_txn( - txn, - table="state_group_edges", - keyvalues={ - "state_group": sg, - } + txn, table="state_group_edges", keyvalues={"state_group": sg} ) self._simple_insert_many_txn( @@ -2099,23 +2050,17 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore "WHERE event_id IN (SELECT event_id from events_to_purge)" ) for event_id, _ in event_rows: - txn.call_after(self._get_state_group_for_event.invalidate, ( - event_id, - )) + txn.call_after(self._get_state_group_for_event.invalidate, (event_id,)) # Delete all remote non-state events for table in ( "events", "event_json", "event_auth", - "event_content_hashes", - "event_destinations", - "event_edge_hashes", "event_edges", "event_forward_extremities", "event_reference_hashes", "event_search", - "event_signatures", "rejections", ): logger.info("[purge] removing events from %s", table) @@ -2123,21 +2068,19 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore txn.execute( "DELETE FROM %s WHERE event_id IN (" " SELECT event_id FROM events_to_purge WHERE should_delete" - ")" % (table,), + ")" % (table,) ) # event_push_actions lacks an index on event_id, and has one on # (room_id, event_id) instead. - for table in ( - "event_push_actions", - ): + for table in ("event_push_actions",): logger.info("[purge] removing events from %s", table) txn.execute( "DELETE FROM %s WHERE room_id = ? AND event_id IN (" " SELECT event_id FROM events_to_purge WHERE should_delete" ")" % (table,), - (room_id, ) + (room_id,), ) # Mark all state and own events as outliers @@ -2162,27 +2105,28 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore # extremities. However, the events in event_backward_extremities # are ones we don't have yet so we need to look at the events that # point to it via event_edges table. - txn.execute(""" + txn.execute( + """ SELECT COALESCE(MIN(depth), 0) FROM event_backward_extremities AS eb INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id INNER JOIN events AS e ON e.event_id = eg.event_id WHERE eb.room_id = ? - """, (room_id,)) + """, + (room_id,), + ) min_depth, = txn.fetchone() logger.info("[purge] updating room_depth to %d", min_depth) txn.execute( "UPDATE room_depth SET min_depth = ? WHERE room_id = ?", - (min_depth, room_id,) + (min_depth, room_id), ) # finally, drop the temp table. this will commit the txn in sqlite, # so make sure to keep this actually last. - txn.execute( - "DROP TABLE events_to_purge" - ) + txn.execute("DROP TABLE events_to_purge") logger.info("[purge] done") @@ -2226,7 +2170,9 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore SELECT DISTINCT state_group FROM event_to_state_groups LEFT JOIN events_to_purge AS ep USING (event_id) WHERE state_group IN (%s) AND ep.event_id IS NULL - """ % (",".join("?" for _ in current_search),) + """ % ( + ",".join("?" for _ in current_search), + ) txn.execute(sql, list(current_search)) referenced = set(sg for sg, in txn) @@ -2242,7 +2188,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore column="prev_state_group", iterable=current_search, keyvalues={}, - retcols=("prev_state_group", "state_group",), + retcols=("prev_state_group", "state_group"), ) prevs = set(row["state_group"] for row in rows) @@ -2279,16 +2225,15 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore table="events", retcols=["topological_ordering", "stream_ordering"], keyvalues={"event_id": event_id}, - allow_none=True + allow_none=True, ) if not res: raise SynapseError(404, "Could not find event %s" % (event_id,)) - defer.returnValue((int(res["topological_ordering"]), int(res["stream_ordering"]))) - - def get_max_current_state_delta_stream_id(self): - return self._stream_id_gen.get_current_token() + defer.returnValue( + (int(res["topological_ordering"]), int(res["stream_ordering"])) + ) def get_all_updated_current_state_deltas(self, from_token, to_token, limit): def get_all_updated_current_state_deltas_txn(txn): @@ -2300,13 +2245,19 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore """ txn.execute(sql, (from_token, to_token, limit)) return txn.fetchall() + return self.runInteraction( "get_all_updated_current_state_deltas", get_all_updated_current_state_deltas_txn, ) -AllNewEventsResult = namedtuple("AllNewEventsResult", [ - "new_forward_events", "new_backfill_events", - "forward_ex_outliers", "backward_ex_outliers", -]) +AllNewEventsResult = namedtuple( + "AllNewEventsResult", + [ + "new_forward_events", + "new_backfill_events", + "forward_ex_outliers", + "backward_ex_outliers", + ], +) diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 1716be529a..663991a9b6 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -21,8 +21,9 @@ from canonicaljson import json from twisted.internet import defer -from synapse.api.constants import EventFormatVersions, EventTypes +from synapse.api.constants import EventTypes from synapse.api.errors import NotFoundError +from synapse.api.room_versions import EventFormatVersions from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401 # these are only included to make the type annotations work from synapse.events.snapshot import EventContext # noqa: F401 @@ -70,17 +71,21 @@ class EventsWorkerStore(SQLBaseStore): """ return self._simple_select_one_onecol( table="events", - keyvalues={ - "event_id": event_id, - }, + keyvalues={"event_id": event_id}, retcol="received_ts", desc="get_received_ts", ) @defer.inlineCallbacks - def get_event(self, event_id, check_redacted=True, - get_prev_content=False, allow_rejected=False, - allow_none=False, check_room_id=None): + def get_event( + self, + event_id, + check_redacted=True, + get_prev_content=False, + allow_rejected=False, + allow_none=False, + check_room_id=None, + ): """Get an event from the database by event_id. Args: @@ -117,8 +122,13 @@ class EventsWorkerStore(SQLBaseStore): defer.returnValue(event) @defer.inlineCallbacks - def get_events(self, event_ids, check_redacted=True, - get_prev_content=False, allow_rejected=False): + def get_events( + self, + event_ids, + check_redacted=True, + get_prev_content=False, + allow_rejected=False, + ): """Get events from the database Args: @@ -142,8 +152,13 @@ class EventsWorkerStore(SQLBaseStore): defer.returnValue({e.event_id: e for e in events}) @defer.inlineCallbacks - def _get_events(self, event_ids, check_redacted=True, - get_prev_content=False, allow_rejected=False): + def _get_events( + self, + event_ids, + check_redacted=True, + get_prev_content=False, + allow_rejected=False, + ): if not event_ids: defer.returnValue([]) @@ -151,8 +166,7 @@ class EventsWorkerStore(SQLBaseStore): event_ids = set(event_ids) event_entry_map = self._get_events_from_cache( - event_ids, - allow_rejected=allow_rejected, + event_ids, allow_rejected=allow_rejected ) missing_events_ids = [e for e in event_ids if e not in event_entry_map] @@ -168,8 +182,7 @@ class EventsWorkerStore(SQLBaseStore): # # _enqueue_events is a bit of a rubbish name but naming is hard. missing_events = yield self._enqueue_events( - missing_events_ids, - allow_rejected=allow_rejected, + missing_events_ids, allow_rejected=allow_rejected ) event_entry_map.update(missing_events) @@ -213,7 +226,10 @@ class EventsWorkerStore(SQLBaseStore): ) expected_domain = get_domain_from_id(entry.event.sender) - if orig_sender and get_domain_from_id(orig_sender) == expected_domain: + if ( + orig_sender + and get_domain_from_id(orig_sender) == expected_domain + ): # This redaction event is allowed. Mark as not needing a # recheck. entry.event.internal_metadata.recheck_redaction = False @@ -266,8 +282,7 @@ class EventsWorkerStore(SQLBaseStore): for event_id in events: ret = self._get_event_cache.get( - (event_id,), None, - update_metrics=update_metrics, + (event_id,), None, update_metrics=update_metrics ) if not ret: continue @@ -317,19 +332,13 @@ class EventsWorkerStore(SQLBaseStore): with Measure(self._clock, "_fetch_event_list"): try: event_id_lists = list(zip(*event_list))[0] - event_ids = [ - item for sublist in event_id_lists for item in sublist - ] + event_ids = [item for sublist in event_id_lists for item in sublist] rows = self._new_transaction( - conn, "do_fetch", [], [], - self._fetch_event_rows, event_ids, + conn, "do_fetch", [], [], self._fetch_event_rows, event_ids ) - row_dict = { - r["event_id"]: r - for r in rows - } + row_dict = {r["event_id"]: r for r in rows} # We only want to resolve deferreds from the main thread def fire(lst, res): @@ -337,13 +346,10 @@ class EventsWorkerStore(SQLBaseStore): if not d.called: try: with PreserveLoggingContext(): - d.callback([ - res[i] - for i in ids - if i in res - ]) + d.callback([res[i] for i in ids if i in res]) except Exception: logger.exception("Failed to callback") + with PreserveLoggingContext(): self.hs.get_reactor().callFromThread(fire, event_list, row_dict) except Exception as e: @@ -370,9 +376,7 @@ class EventsWorkerStore(SQLBaseStore): events_d = defer.Deferred() with self._event_fetch_lock: - self._event_fetch_list.append( - (events, events_d) - ) + self._event_fetch_list.append((events, events_d)) self._event_fetch_lock.notify() @@ -384,9 +388,7 @@ class EventsWorkerStore(SQLBaseStore): if should_start: run_as_background_process( - "fetch_events", - self.runWithConnection, - self._do_fetch, + "fetch_events", self.runWithConnection, self._do_fetch ) logger.debug("Loading %d events", len(events)) @@ -397,29 +399,30 @@ class EventsWorkerStore(SQLBaseStore): if not allow_rejected: rows[:] = [r for r in rows if not r["rejects"]] - res = yield make_deferred_yieldable(defer.gatherResults( - [ - run_in_background( - self._get_event_from_row, - row["internal_metadata"], row["json"], row["redacts"], - rejected_reason=row["rejects"], - format_version=row["format_version"], - ) - for row in rows - ], - consumeErrors=True - )) + res = yield make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background( + self._get_event_from_row, + row["internal_metadata"], + row["json"], + row["redacts"], + rejected_reason=row["rejects"], + format_version=row["format_version"], + ) + for row in rows + ], + consumeErrors=True, + ) + ) - defer.returnValue({ - e.event.event_id: e - for e in res if e - }) + defer.returnValue({e.event.event_id: e for e in res if e}) def _fetch_event_rows(self, txn, events): rows = [] N = 200 for i in range(1 + len(events) // N): - evs = events[i * N:(i + 1) * N] + evs = events[i * N : (i + 1) * N] if not evs: break @@ -443,8 +446,9 @@ class EventsWorkerStore(SQLBaseStore): return rows @defer.inlineCallbacks - def _get_event_from_row(self, internal_metadata, js, redacted, - format_version, rejected_reason=None): + def _get_event_from_row( + self, internal_metadata, js, redacted, format_version, rejected_reason=None + ): with Measure(self._clock, "_get_event_from_row"): d = json.loads(js) internal_metadata = json.loads(internal_metadata) @@ -483,9 +487,7 @@ class EventsWorkerStore(SQLBaseStore): # Get the redaction event. because = yield self.get_event( - redaction_id, - check_redacted=False, - allow_none=True, + redaction_id, check_redacted=False, allow_none=True ) if because: @@ -507,8 +509,7 @@ class EventsWorkerStore(SQLBaseStore): redacted_event = None cache_entry = _EventCacheEntry( - event=original_ev, - redacted_event=redacted_event, + event=original_ev, redacted_event=redacted_event ) self._get_event_cache.prefill((original_ev.event_id,), cache_entry) @@ -544,23 +545,17 @@ class EventsWorkerStore(SQLBaseStore): results = set() def have_seen_events_txn(txn, chunk): - sql = ( - "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" - % (",".join("?" * len(chunk)), ) + sql = "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" % ( + ",".join("?" * len(chunk)), ) txn.execute(sql, chunk) - for (event_id, ) in txn: + for (event_id,) in txn: results.add(event_id) # break the input up into chunks of 100 input_iterator = iter(event_ids) - for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), - []): - yield self.runInteraction( - "have_seen_events", - have_seen_events_txn, - chunk, - ) + for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): + yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk) defer.returnValue(results) def get_seen_events_with_rejections(self, event_ids): diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py index 6ddcc909bf..b195dc66a0 100644 --- a/synapse/storage/filtering.py +++ b/synapse/storage/filtering.py @@ -35,10 +35,7 @@ class FilteringStore(SQLBaseStore): def_json = yield self._simple_select_one_onecol( table="user_filters", - keyvalues={ - "user_id": user_localpart, - "filter_id": filter_id, - }, + keyvalues={"user_id": user_localpart, "filter_id": filter_id}, retcol="filter_json", allow_none=False, desc="get_user_filter", @@ -61,10 +58,7 @@ class FilteringStore(SQLBaseStore): if filter_id_response is not None: return filter_id_response[0] - sql = ( - "SELECT MAX(filter_id) FROM user_filters " - "WHERE user_id = ?" - ) + sql = "SELECT MAX(filter_id) FROM user_filters " "WHERE user_id = ?" txn.execute(sql, (user_localpart,)) max_id = txn.fetchone()[0] if max_id is None: diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py index 592d1b4c2a..dce6a43ac1 100644 --- a/synapse/storage/group_server.py +++ b/synapse/storage/group_server.py @@ -38,24 +38,22 @@ class GroupServerStore(SQLBaseStore): """ return self._simple_update_one( table="groups", - keyvalues={ - "group_id": group_id, - }, - updatevalues={ - "join_policy": join_policy, - }, + keyvalues={"group_id": group_id}, + updatevalues={"join_policy": join_policy}, desc="set_group_join_policy", ) def get_group(self, group_id): return self._simple_select_one( table="groups", - keyvalues={ - "group_id": group_id, - }, + keyvalues={"group_id": group_id}, retcols=( - "name", "short_description", "long_description", - "avatar_url", "is_public", "join_policy", + "name", + "short_description", + "long_description", + "avatar_url", + "is_public", + "join_policy", ), allow_none=True, desc="get_group", @@ -64,16 +62,14 @@ class GroupServerStore(SQLBaseStore): def get_users_in_group(self, group_id, include_private=False): # TODO: Pagination - keyvalues = { - "group_id": group_id, - } + keyvalues = {"group_id": group_id} if not include_private: keyvalues["is_public"] = True return self._simple_select_list( table="group_users", keyvalues=keyvalues, - retcols=("user_id", "is_public", "is_admin",), + retcols=("user_id", "is_public", "is_admin"), desc="get_users_in_group", ) @@ -82,9 +78,7 @@ class GroupServerStore(SQLBaseStore): return self._simple_select_onecol( table="group_invites", - keyvalues={ - "group_id": group_id, - }, + keyvalues={"group_id": group_id}, retcol="user_id", desc="get_invited_users_in_group", ) @@ -92,16 +86,14 @@ class GroupServerStore(SQLBaseStore): def get_rooms_in_group(self, group_id, include_private=False): # TODO: Pagination - keyvalues = { - "group_id": group_id, - } + keyvalues = {"group_id": group_id} if not include_private: keyvalues["is_public"] = True return self._simple_select_list( table="group_rooms", keyvalues=keyvalues, - retcols=("room_id", "is_public",), + retcols=("room_id", "is_public"), desc="get_rooms_in_group", ) @@ -110,10 +102,9 @@ class GroupServerStore(SQLBaseStore): Returns ([rooms], [categories]) """ + def _get_rooms_for_summary_txn(txn): - keyvalues = { - "group_id": group_id, - } + keyvalues = {"group_id": group_id} if not include_private: keyvalues["is_public"] = True @@ -162,18 +153,23 @@ class GroupServerStore(SQLBaseStore): } return rooms, categories - return self.runInteraction( - "get_rooms_for_summary", _get_rooms_for_summary_txn - ) + + return self.runInteraction("get_rooms_for_summary", _get_rooms_for_summary_txn) def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): return self.runInteraction( - "add_room_to_summary", self._add_room_to_summary_txn, - group_id, room_id, category_id, order, is_public, + "add_room_to_summary", + self._add_room_to_summary_txn, + group_id, + room_id, + category_id, + order, + is_public, ) - def _add_room_to_summary_txn(self, txn, group_id, room_id, category_id, order, - is_public): + def _add_room_to_summary_txn( + self, txn, group_id, room_id, category_id, order, is_public + ): """Add (or update) room's entry in summary. Args: @@ -188,10 +184,7 @@ class GroupServerStore(SQLBaseStore): room_in_group = self._simple_select_one_onecol_txn( txn, table="group_rooms", - keyvalues={ - "group_id": group_id, - "room_id": room_id, - }, + keyvalues={"group_id": group_id, "room_id": room_id}, retcol="room_id", allow_none=True, ) @@ -204,10 +197,7 @@ class GroupServerStore(SQLBaseStore): cat_exists = self._simple_select_one_onecol_txn( txn, table="group_room_categories", - keyvalues={ - "group_id": group_id, - "category_id": category_id, - }, + keyvalues={"group_id": group_id, "category_id": category_id}, retcol="group_id", allow_none=True, ) @@ -218,22 +208,22 @@ class GroupServerStore(SQLBaseStore): cat_exists = self._simple_select_one_onecol_txn( txn, table="group_summary_room_categories", - keyvalues={ - "group_id": group_id, - "category_id": category_id, - }, + keyvalues={"group_id": group_id, "category_id": category_id}, retcol="group_id", allow_none=True, ) if not cat_exists: # If not, add it with an order larger than all others - txn.execute(""" + txn.execute( + """ INSERT INTO group_summary_room_categories (group_id, category_id, cat_order) SELECT ?, ?, COALESCE(MAX(cat_order), 0) + 1 FROM group_summary_room_categories WHERE group_id = ? AND category_id = ? - """, (group_id, category_id, group_id, category_id)) + """, + (group_id, category_id, group_id, category_id), + ) existing = self._simple_select_one_txn( txn, @@ -243,7 +233,7 @@ class GroupServerStore(SQLBaseStore): "room_id": room_id, "category_id": category_id, }, - retcols=("room_order", "is_public",), + retcols=("room_order", "is_public"), allow_none=True, ) @@ -253,13 +243,13 @@ class GroupServerStore(SQLBaseStore): UPDATE group_summary_rooms SET room_order = room_order + 1 WHERE group_id = ? AND category_id = ? AND room_order >= ? """ - txn.execute(sql, (group_id, category_id, order,)) + txn.execute(sql, (group_id, category_id, order)) elif not existing: sql = """ SELECT COALESCE(MAX(room_order), 0) + 1 FROM group_summary_rooms WHERE group_id = ? AND category_id = ? """ - txn.execute(sql, (group_id, category_id,)) + txn.execute(sql, (group_id, category_id)) order, = txn.fetchone() if existing: @@ -312,29 +302,26 @@ class GroupServerStore(SQLBaseStore): def get_group_categories(self, group_id): rows = yield self._simple_select_list( table="group_room_categories", - keyvalues={ - "group_id": group_id, - }, + keyvalues={"group_id": group_id}, retcols=("category_id", "is_public", "profile"), desc="get_group_categories", ) - defer.returnValue({ - row["category_id"]: { - "is_public": row["is_public"], - "profile": json.loads(row["profile"]), + defer.returnValue( + { + row["category_id"]: { + "is_public": row["is_public"], + "profile": json.loads(row["profile"]), + } + for row in rows } - for row in rows - }) + ) @defer.inlineCallbacks def get_group_category(self, group_id, category_id): category = yield self._simple_select_one( table="group_room_categories", - keyvalues={ - "group_id": group_id, - "category_id": category_id, - }, + keyvalues={"group_id": group_id, "category_id": category_id}, retcols=("is_public", "profile"), desc="get_group_category", ) @@ -361,10 +348,7 @@ class GroupServerStore(SQLBaseStore): return self._simple_upsert( table="group_room_categories", - keyvalues={ - "group_id": group_id, - "category_id": category_id, - }, + keyvalues={"group_id": group_id, "category_id": category_id}, values=update_values, insertion_values=insertion_values, desc="upsert_group_category", @@ -373,10 +357,7 @@ class GroupServerStore(SQLBaseStore): def remove_group_category(self, group_id, category_id): return self._simple_delete( table="group_room_categories", - keyvalues={ - "group_id": group_id, - "category_id": category_id, - }, + keyvalues={"group_id": group_id, "category_id": category_id}, desc="remove_group_category", ) @@ -384,29 +365,26 @@ class GroupServerStore(SQLBaseStore): def get_group_roles(self, group_id): rows = yield self._simple_select_list( table="group_roles", - keyvalues={ - "group_id": group_id, - }, + keyvalues={"group_id": group_id}, retcols=("role_id", "is_public", "profile"), desc="get_group_roles", ) - defer.returnValue({ - row["role_id"]: { - "is_public": row["is_public"], - "profile": json.loads(row["profile"]), + defer.returnValue( + { + row["role_id"]: { + "is_public": row["is_public"], + "profile": json.loads(row["profile"]), + } + for row in rows } - for row in rows - }) + ) @defer.inlineCallbacks def get_group_role(self, group_id, role_id): role = yield self._simple_select_one( table="group_roles", - keyvalues={ - "group_id": group_id, - "role_id": role_id, - }, + keyvalues={"group_id": group_id, "role_id": role_id}, retcols=("is_public", "profile"), desc="get_group_role", ) @@ -433,10 +411,7 @@ class GroupServerStore(SQLBaseStore): return self._simple_upsert( table="group_roles", - keyvalues={ - "group_id": group_id, - "role_id": role_id, - }, + keyvalues={"group_id": group_id, "role_id": role_id}, values=update_values, insertion_values=insertion_values, desc="upsert_group_role", @@ -445,21 +420,24 @@ class GroupServerStore(SQLBaseStore): def remove_group_role(self, group_id, role_id): return self._simple_delete( table="group_roles", - keyvalues={ - "group_id": group_id, - "role_id": role_id, - }, + keyvalues={"group_id": group_id, "role_id": role_id}, desc="remove_group_role", ) def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): return self.runInteraction( - "add_user_to_summary", self._add_user_to_summary_txn, - group_id, user_id, role_id, order, is_public, + "add_user_to_summary", + self._add_user_to_summary_txn, + group_id, + user_id, + role_id, + order, + is_public, ) - def _add_user_to_summary_txn(self, txn, group_id, user_id, role_id, order, - is_public): + def _add_user_to_summary_txn( + self, txn, group_id, user_id, role_id, order, is_public + ): """Add (or update) user's entry in summary. Args: @@ -474,10 +452,7 @@ class GroupServerStore(SQLBaseStore): user_in_group = self._simple_select_one_onecol_txn( txn, table="group_users", - keyvalues={ - "group_id": group_id, - "user_id": user_id, - }, + keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", allow_none=True, ) @@ -490,10 +465,7 @@ class GroupServerStore(SQLBaseStore): role_exists = self._simple_select_one_onecol_txn( txn, table="group_roles", - keyvalues={ - "group_id": group_id, - "role_id": role_id, - }, + keyvalues={"group_id": group_id, "role_id": role_id}, retcol="group_id", allow_none=True, ) @@ -504,32 +476,28 @@ class GroupServerStore(SQLBaseStore): role_exists = self._simple_select_one_onecol_txn( txn, table="group_summary_roles", - keyvalues={ - "group_id": group_id, - "role_id": role_id, - }, + keyvalues={"group_id": group_id, "role_id": role_id}, retcol="group_id", allow_none=True, ) if not role_exists: # If not, add it with an order larger than all others - txn.execute(""" + txn.execute( + """ INSERT INTO group_summary_roles (group_id, role_id, role_order) SELECT ?, ?, COALESCE(MAX(role_order), 0) + 1 FROM group_summary_roles WHERE group_id = ? AND role_id = ? - """, (group_id, role_id, group_id, role_id)) + """, + (group_id, role_id, group_id, role_id), + ) existing = self._simple_select_one_txn( txn, table="group_summary_users", - keyvalues={ - "group_id": group_id, - "user_id": user_id, - "role_id": role_id, - }, - retcols=("user_order", "is_public",), + keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id}, + retcols=("user_order", "is_public"), allow_none=True, ) @@ -539,13 +507,13 @@ class GroupServerStore(SQLBaseStore): UPDATE group_summary_users SET user_order = user_order + 1 WHERE group_id = ? AND role_id = ? AND user_order >= ? """ - txn.execute(sql, (group_id, role_id, order,)) + txn.execute(sql, (group_id, role_id, order)) elif not existing: sql = """ SELECT COALESCE(MAX(user_order), 0) + 1 FROM group_summary_users WHERE group_id = ? AND role_id = ? """ - txn.execute(sql, (group_id, role_id,)) + txn.execute(sql, (group_id, role_id)) order, = txn.fetchone() if existing: @@ -586,11 +554,7 @@ class GroupServerStore(SQLBaseStore): return self._simple_delete( table="group_summary_users", - keyvalues={ - "group_id": group_id, - "role_id": role_id, - "user_id": user_id, - }, + keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id}, desc="remove_user_from_summary", ) @@ -599,10 +563,9 @@ class GroupServerStore(SQLBaseStore): Returns ([users], [roles]) """ + def _get_users_for_summary_txn(txn): - keyvalues = { - "group_id": group_id, - } + keyvalues = {"group_id": group_id} if not include_private: keyvalues["is_public"] = True @@ -651,6 +614,7 @@ class GroupServerStore(SQLBaseStore): } return users, roles + return self.runInteraction( "get_users_for_summary_by_role", _get_users_for_summary_txn ) @@ -658,10 +622,7 @@ class GroupServerStore(SQLBaseStore): def is_user_in_group(self, user_id, group_id): return self._simple_select_one_onecol( table="group_users", - keyvalues={ - "group_id": group_id, - "user_id": user_id, - }, + keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", allow_none=True, desc="is_user_in_group", @@ -670,10 +631,7 @@ class GroupServerStore(SQLBaseStore): def is_user_admin_in_group(self, group_id, user_id): return self._simple_select_one_onecol( table="group_users", - keyvalues={ - "group_id": group_id, - "user_id": user_id, - }, + keyvalues={"group_id": group_id, "user_id": user_id}, retcol="is_admin", allow_none=True, desc="is_user_admin_in_group", @@ -684,10 +642,7 @@ class GroupServerStore(SQLBaseStore): """ return self._simple_insert( table="group_invites", - values={ - "group_id": group_id, - "user_id": user_id, - }, + values={"group_id": group_id, "user_id": user_id}, desc="add_group_invite", ) @@ -696,10 +651,7 @@ class GroupServerStore(SQLBaseStore): """ return self._simple_select_one_onecol( table="group_invites", - keyvalues={ - "group_id": group_id, - "user_id": user_id, - }, + keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", desc="is_user_invited_to_local_group", allow_none=True, @@ -718,14 +670,12 @@ class GroupServerStore(SQLBaseStore): Returns an empty dict if the user is not join/invite/etc """ + def _get_users_membership_in_group_txn(txn): row = self._simple_select_one_txn( txn, table="group_users", - keyvalues={ - "group_id": group_id, - "user_id": user_id, - }, + keyvalues={"group_id": group_id, "user_id": user_id}, retcols=("is_admin", "is_public"), allow_none=True, ) @@ -740,27 +690,29 @@ class GroupServerStore(SQLBaseStore): row = self._simple_select_one_onecol_txn( txn, table="group_invites", - keyvalues={ - "group_id": group_id, - "user_id": user_id, - }, + keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", allow_none=True, ) if row: - return { - "membership": "invite", - } + return {"membership": "invite"} return {} return self.runInteraction( - "get_users_membership_info_in_group", _get_users_membership_in_group_txn, + "get_users_membership_info_in_group", _get_users_membership_in_group_txn ) - def add_user_to_group(self, group_id, user_id, is_admin=False, is_public=True, - local_attestation=None, remote_attestation=None): + def add_user_to_group( + self, + group_id, + user_id, + is_admin=False, + is_public=True, + local_attestation=None, + remote_attestation=None, + ): """Add a user to the group server. Args: @@ -774,6 +726,7 @@ class GroupServerStore(SQLBaseStore): remote_attestation (dict): The attestation given to GS by remote server. Optional if the user and group are on the same server """ + def _add_user_to_group_txn(txn): self._simple_insert_txn( txn, @@ -789,10 +742,7 @@ class GroupServerStore(SQLBaseStore): self._simple_delete_txn( txn, table="group_invites", - keyvalues={ - "group_id": group_id, - "user_id": user_id, - }, + keyvalues={"group_id": group_id, "user_id": user_id}, ) if local_attestation: @@ -817,75 +767,52 @@ class GroupServerStore(SQLBaseStore): }, ) - return self.runInteraction( - "add_user_to_group", _add_user_to_group_txn - ) + return self.runInteraction("add_user_to_group", _add_user_to_group_txn) def remove_user_from_group(self, group_id, user_id): def _remove_user_from_group_txn(txn): self._simple_delete_txn( txn, table="group_users", - keyvalues={ - "group_id": group_id, - "user_id": user_id, - }, + keyvalues={"group_id": group_id, "user_id": user_id}, ) self._simple_delete_txn( txn, table="group_invites", - keyvalues={ - "group_id": group_id, - "user_id": user_id, - }, + keyvalues={"group_id": group_id, "user_id": user_id}, ) self._simple_delete_txn( txn, table="group_attestations_renewals", - keyvalues={ - "group_id": group_id, - "user_id": user_id, - }, + keyvalues={"group_id": group_id, "user_id": user_id}, ) self._simple_delete_txn( txn, table="group_attestations_remote", - keyvalues={ - "group_id": group_id, - "user_id": user_id, - }, + keyvalues={"group_id": group_id, "user_id": user_id}, ) self._simple_delete_txn( txn, table="group_summary_users", - keyvalues={ - "group_id": group_id, - "user_id": user_id, - }, + keyvalues={"group_id": group_id, "user_id": user_id}, ) - return self.runInteraction("remove_user_from_group", _remove_user_from_group_txn) + + return self.runInteraction( + "remove_user_from_group", _remove_user_from_group_txn + ) def add_room_to_group(self, group_id, room_id, is_public): return self._simple_insert( table="group_rooms", - values={ - "group_id": group_id, - "room_id": room_id, - "is_public": is_public, - }, + values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, desc="add_room_to_group", ) def update_room_in_group_visibility(self, group_id, room_id, is_public): return self._simple_update( table="group_rooms", - keyvalues={ - "group_id": group_id, - "room_id": room_id, - }, - updatevalues={ - "is_public": is_public, - }, + keyvalues={"group_id": group_id, "room_id": room_id}, + updatevalues={"is_public": is_public}, desc="update_room_in_group_visibility", ) @@ -894,22 +821,17 @@ class GroupServerStore(SQLBaseStore): self._simple_delete_txn( txn, table="group_rooms", - keyvalues={ - "group_id": group_id, - "room_id": room_id, - }, + keyvalues={"group_id": group_id, "room_id": room_id}, ) self._simple_delete_txn( txn, table="group_summary_rooms", - keyvalues={ - "group_id": group_id, - "room_id": room_id, - }, + keyvalues={"group_id": group_id, "room_id": room_id}, ) + return self.runInteraction( - "remove_room_from_group", _remove_room_from_group_txn, + "remove_room_from_group", _remove_room_from_group_txn ) def get_publicised_groups_for_user(self, user_id): @@ -917,11 +839,7 @@ class GroupServerStore(SQLBaseStore): """ return self._simple_select_onecol( table="local_group_membership", - keyvalues={ - "user_id": user_id, - "membership": "join", - "is_publicised": True, - }, + keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, retcol="group_id", desc="get_publicised_groups_for_user", ) @@ -931,23 +849,23 @@ class GroupServerStore(SQLBaseStore): """ return self._simple_update_one( table="local_group_membership", - keyvalues={ - "group_id": group_id, - "user_id": user_id, - }, - updatevalues={ - "is_publicised": publicise, - }, - desc="update_group_publicity" + keyvalues={"group_id": group_id, "user_id": user_id}, + updatevalues={"is_publicised": publicise}, + desc="update_group_publicity", ) @defer.inlineCallbacks - def register_user_group_membership(self, group_id, user_id, membership, - is_admin=False, content={}, - local_attestation=None, - remote_attestation=None, - is_publicised=False, - ): + def register_user_group_membership( + self, + group_id, + user_id, + membership, + is_admin=False, + content={}, + local_attestation=None, + remote_attestation=None, + is_publicised=False, + ): """Registers that a local user is a member of a (local or remote) group. Args: @@ -962,15 +880,13 @@ class GroupServerStore(SQLBaseStore): remote_attestation (dict): If remote group then store the remote attestation from the group, else None. """ + def _register_user_group_membership_txn(txn, next_id): # TODO: Upsert? self._simple_delete_txn( txn, table="local_group_membership", - keyvalues={ - "group_id": group_id, - "user_id": user_id, - }, + keyvalues={"group_id": group_id, "user_id": user_id}, ) self._simple_insert_txn( txn, @@ -993,8 +909,10 @@ class GroupServerStore(SQLBaseStore): "group_id": group_id, "user_id": user_id, "type": "membership", - "content": json.dumps({"membership": membership, "content": content}), - } + "content": json.dumps( + {"membership": membership, "content": content} + ), + }, ) self._group_updates_stream_cache.entity_has_changed(user_id, next_id) @@ -1009,7 +927,7 @@ class GroupServerStore(SQLBaseStore): "group_id": group_id, "user_id": user_id, "valid_until_ms": local_attestation["valid_until_ms"], - } + }, ) if remote_attestation: self._simple_insert_txn( @@ -1020,24 +938,18 @@ class GroupServerStore(SQLBaseStore): "user_id": user_id, "valid_until_ms": remote_attestation["valid_until_ms"], "attestation_json": json.dumps(remote_attestation), - } + }, ) else: self._simple_delete_txn( txn, table="group_attestations_renewals", - keyvalues={ - "group_id": group_id, - "user_id": user_id, - }, + keyvalues={"group_id": group_id, "user_id": user_id}, ) self._simple_delete_txn( txn, table="group_attestations_remote", - keyvalues={ - "group_id": group_id, - "user_id": user_id, - }, + keyvalues={"group_id": group_id, "user_id": user_id}, ) return next_id @@ -1045,13 +957,15 @@ class GroupServerStore(SQLBaseStore): with self._group_updates_id_gen.get_next() as next_id: res = yield self.runInteraction( "register_user_group_membership", - _register_user_group_membership_txn, next_id, + _register_user_group_membership_txn, + next_id, ) defer.returnValue(res) @defer.inlineCallbacks - def create_group(self, group_id, user_id, name, avatar_url, short_description, - long_description,): + def create_group( + self, group_id, user_id, name, avatar_url, short_description, long_description + ): yield self._simple_insert( table="groups", values={ @@ -1066,12 +980,10 @@ class GroupServerStore(SQLBaseStore): ) @defer.inlineCallbacks - def update_group_profile(self, group_id, profile,): + def update_group_profile(self, group_id, profile): yield self._simple_update_one( table="groups", - keyvalues={ - "group_id": group_id, - }, + keyvalues={"group_id": group_id}, updatevalues=profile, desc="update_group_profile", ) @@ -1079,6 +991,7 @@ class GroupServerStore(SQLBaseStore): def get_attestations_need_renewals(self, valid_until_ms): """Get all attestations that need to be renewed until givent time """ + def _get_attestations_need_renewals_txn(txn): sql = """ SELECT group_id, user_id FROM group_attestations_renewals @@ -1086,6 +999,7 @@ class GroupServerStore(SQLBaseStore): """ txn.execute(sql, (valid_until_ms,)) return self.cursor_to_dict(txn) + return self.runInteraction( "get_attestations_need_renewals", _get_attestations_need_renewals_txn ) @@ -1095,13 +1009,8 @@ class GroupServerStore(SQLBaseStore): """ return self._simple_update_one( table="group_attestations_renewals", - keyvalues={ - "group_id": group_id, - "user_id": user_id, - }, - updatevalues={ - "valid_until_ms": attestation["valid_until_ms"], - }, + keyvalues={"group_id": group_id, "user_id": user_id}, + updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, desc="update_attestation_renewal", ) @@ -1110,13 +1019,10 @@ class GroupServerStore(SQLBaseStore): """ return self._simple_update_one( table="group_attestations_remote", - keyvalues={ - "group_id": group_id, - "user_id": user_id, - }, + keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={ "valid_until_ms": attestation["valid_until_ms"], - "attestation_json": json.dumps(attestation) + "attestation_json": json.dumps(attestation), }, desc="update_remote_attestion", ) @@ -1132,10 +1038,7 @@ class GroupServerStore(SQLBaseStore): """ return self._simple_delete( table="group_attestations_renewals", - keyvalues={ - "group_id": group_id, - "user_id": user_id, - }, + keyvalues={"group_id": group_id, "user_id": user_id}, desc="remove_attestation_renewal", ) @@ -1146,10 +1049,7 @@ class GroupServerStore(SQLBaseStore): """ row = yield self._simple_select_one( table="group_attestations_remote", - keyvalues={ - "group_id": group_id, - "user_id": user_id, - }, + keyvalues={"group_id": group_id, "user_id": user_id}, retcols=("valid_until_ms", "attestation_json"), desc="get_remote_attestation", allow_none=True, @@ -1164,10 +1064,7 @@ class GroupServerStore(SQLBaseStore): def get_joined_groups(self, user_id): return self._simple_select_onecol( table="local_group_membership", - keyvalues={ - "user_id": user_id, - "membership": "join", - }, + keyvalues={"user_id": user_id, "membership": "join"}, retcol="group_id", desc="get_joined_groups", ) @@ -1181,7 +1078,7 @@ class GroupServerStore(SQLBaseStore): WHERE user_id = ? AND membership != 'leave' AND stream_id <= ? """ - txn.execute(sql, (user_id, now_token,)) + txn.execute(sql, (user_id, now_token)) return [ { "group_id": row[0], @@ -1191,14 +1088,15 @@ class GroupServerStore(SQLBaseStore): } for row in txn ] + return self.runInteraction( - "get_all_groups_for_user", _get_all_groups_for_user_txn, + "get_all_groups_for_user", _get_all_groups_for_user_txn ) def get_groups_changes_for_user(self, user_id, from_token, to_token): from_token = int(from_token) has_changed = self._group_updates_stream_cache.has_entity_changed( - user_id, from_token, + user_id, from_token ) if not has_changed: return [] @@ -1210,21 +1108,25 @@ class GroupServerStore(SQLBaseStore): INNER JOIN local_group_membership USING (group_id, user_id) WHERE user_id = ? AND ? < stream_id AND stream_id <= ? """ - txn.execute(sql, (user_id, from_token, to_token,)) - return [{ - "group_id": group_id, - "membership": membership, - "type": gtype, - "content": json.loads(content_json), - } for group_id, membership, gtype, content_json in txn] + txn.execute(sql, (user_id, from_token, to_token)) + return [ + { + "group_id": group_id, + "membership": membership, + "type": gtype, + "content": json.loads(content_json), + } + for group_id, membership, gtype, content_json in txn + ] + return self.runInteraction( - "get_groups_changes_for_user", _get_groups_changes_for_user_txn, + "get_groups_changes_for_user", _get_groups_changes_for_user_txn ) def get_all_groups_changes(self, from_token, to_token, limit): from_token = int(from_token) has_changed = self._group_updates_stream_cache.has_any_entity_changed( - from_token, + from_token ) if not has_changed: return [] @@ -1236,17 +1138,52 @@ class GroupServerStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? LIMIT ? """ - txn.execute(sql, (from_token, to_token, limit,)) - return [( - stream_id, - group_id, - user_id, - gtype, - json.loads(content_json), - ) for stream_id, group_id, user_id, gtype, content_json in txn] + txn.execute(sql, (from_token, to_token, limit)) + return [ + (stream_id, group_id, user_id, gtype, json.loads(content_json)) + for stream_id, group_id, user_id, gtype, content_json in txn + ] + return self.runInteraction( - "get_all_groups_changes", _get_all_groups_changes_txn, + "get_all_groups_changes", _get_all_groups_changes_txn ) def get_group_stream_token(self): return self._group_updates_id_gen.get_current_token() + + def delete_group(self, group_id): + """Deletes a group fully from the database. + + Args: + group_id (str) + + Returns: + Deferred + """ + + def _delete_group_txn(txn): + tables = [ + "groups", + "group_users", + "group_invites", + "group_rooms", + "group_summary_rooms", + "group_summary_room_categories", + "group_room_categories", + "group_summary_users", + "group_summary_roles", + "group_roles", + "group_attestations_renewals", + "group_attestations_remote", + ] + + for table in tables: + self._simple_delete_txn( + txn, + table=table, + keyvalues={"group_id": group_id}, + ) + + return self.runInteraction( + "delete_group", _delete_group_txn + ) diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 8af17921e3..7036541792 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,17 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import hashlib +import itertools import logging import six from signedjson.key import decode_verify_key_bytes -import OpenSSL -from twisted.internet import defer - -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util import batch_iter +from synapse.util.caches.descriptors import cached, cachedList from ._base import SQLBaseStore @@ -38,93 +37,56 @@ else: class KeyStore(SQLBaseStore): - """Persistence for signature verification keys and tls X.509 certificates + """Persistence for signature verification keys """ - @defer.inlineCallbacks - def get_server_certificate(self, server_name): - """Retrieve the TLS X.509 certificate for the given server + @cached() + def _get_server_verify_key(self, server_name_and_key_id): + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids" + ) + def get_server_verify_keys(self, server_name_and_key_ids): + """ Args: - server_name (bytes): The name of the server. + server_name_and_key_ids (iterable[Tuple[str, str]]): + iterable of (server_name, key-id) tuples to fetch keys for + Returns: - (OpenSSL.crypto.X509): The tls certificate. + Deferred: resolves to dict[Tuple[str, str], VerifyKey|None]: + map from (server_name, key_id) -> VerifyKey, or None if the key is + unknown """ - tls_certificate_bytes, = yield self._simple_select_one( - table="server_tls_certificates", - keyvalues={"server_name": server_name}, - retcols=("tls_certificate",), - desc="get_server_certificate", - ) - tls_certificate = OpenSSL.crypto.load_certificate( - OpenSSL.crypto.FILETYPE_ASN1, tls_certificate_bytes, - ) - defer.returnValue(tls_certificate) + keys = {} - def store_server_certificate(self, server_name, from_server, time_now_ms, - tls_certificate): - """Stores the TLS X.509 certificate for the given server - Args: - server_name (str): The name of the server. - from_server (str): Where the certificate was looked up - time_now_ms (int): The time now in milliseconds - tls_certificate (OpenSSL.crypto.X509): The X.509 certificate. - """ - tls_certificate_bytes = OpenSSL.crypto.dump_certificate( - OpenSSL.crypto.FILETYPE_ASN1, tls_certificate - ) - fingerprint = hashlib.sha256(tls_certificate_bytes).hexdigest() - return self._simple_upsert( - table="server_tls_certificates", - keyvalues={ - "server_name": server_name, - "fingerprint": fingerprint, - }, - values={ - "from_server": from_server, - "ts_added_ms": time_now_ms, - "tls_certificate": db_binary_type(tls_certificate_bytes), - }, - desc="store_server_certificate", - ) + def _get_keys(txn, batch): + """Processes a batch of keys to fetch, and adds the result to `keys`.""" - @cachedInlineCallbacks() - def _get_server_verify_key(self, server_name, key_id): - verify_key_bytes = yield self._simple_select_one_onecol( - table="server_signature_keys", - keyvalues={ - "server_name": server_name, - "key_id": key_id, - }, - retcol="verify_key", - desc="_get_server_verify_key", - allow_none=True, - ) + # batch_iter always returns tuples so it's safe to do len(batch) + sql = ( + "SELECT server_name, key_id, verify_key FROM server_signature_keys " + "WHERE 1=0" + ) + " OR (server_name=? AND key_id=?)" * len(batch) - if verify_key_bytes: - defer.returnValue(decode_verify_key_bytes( - key_id, bytes(verify_key_bytes) - )) + txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) - @defer.inlineCallbacks - def get_server_verify_keys(self, server_name, key_ids): - """Retrieve the NACL verification key for a given server for the given - key_ids - Args: - server_name (str): The name of the server. - key_ids (iterable[str]): key_ids to try and look up. - Returns: - Deferred: resolves to dict[str, VerifyKey]: map from - key_id to verification key. - """ - keys = {} - for key_id in key_ids: - key = yield self._get_server_verify_key(server_name, key_id) - if key: - keys[key_id] = key - defer.returnValue(keys) - - def store_server_verify_key(self, server_name, from_server, time_now_ms, - verify_key): + for row in txn: + server_name, key_id, key_bytes = row + keys[(server_name, key_id)] = decode_verify_key_bytes( + key_id, bytes(key_bytes) + ) + + def _txn(txn): + for batch in batch_iter(server_name_and_key_ids, 50): + _get_keys(txn, batch) + return keys + + return self.runInteraction("get_server_verify_keys", _txn) + + def store_server_verify_key( + self, server_name, from_server, time_now_ms, verify_key + ): """Stores a NACL verification key for the given server. Args: server_name (str): The name of the server. @@ -139,25 +101,25 @@ class KeyStore(SQLBaseStore): self._simple_upsert_txn( txn, table="server_signature_keys", - keyvalues={ - "server_name": server_name, - "key_id": key_id, - }, + keyvalues={"server_name": server_name, "key_id": key_id}, values={ "from_server": from_server, "ts_added_ms": time_now_ms, "verify_key": db_binary_type(verify_key.encode()), }, ) + # invalidate takes a tuple corresponding to the params of + # _get_server_verify_key. _get_server_verify_key only takes one + # param, which is itself the 2-tuple (server_name, key_id). txn.call_after( - self._get_server_verify_key.invalidate, - (server_name, key_id) + self._get_server_verify_key.invalidate, ((server_name, key_id),) ) return self.runInteraction("store_server_verify_key", _txn) - def store_server_keys_json(self, server_name, key_id, from_server, - ts_now_ms, ts_expires_ms, key_json_bytes): + def store_server_keys_json( + self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes + ): """Stores the JSON bytes for a set of keys from a server The JSON should be signed by the originating server, the intermediate server, and by this server. Updates the value for the @@ -197,9 +159,10 @@ class KeyStore(SQLBaseStore): Args: server_keys (list): List of (server_name, key_id, source) triplets. Returns: - Dict mapping (server_name, key_id, source) triplets to dicts with - "ts_valid_until_ms" and "key_json" keys. + Deferred[dict[Tuple[str, str, str|None], list[dict]]]: + Dict mapping (server_name, key_id, source) triplets to lists of dicts """ + def _get_server_keys_json_txn(txn): results = {} for server_name, key_id, from_server in server_keys: @@ -222,6 +185,5 @@ class KeyStore(SQLBaseStore): ) results[(server_name, key_id, from_server)] = rows return results - return self.runInteraction( - "get_server_keys_json", _get_server_keys_json_txn - ) + + return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn) diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py index e6cdbb0545..3ecf47e7a7 100644 --- a/synapse/storage/media_repository.py +++ b/synapse/storage/media_repository.py @@ -38,15 +38,27 @@ class MediaRepositoryStore(BackgroundUpdateStore): "local_media_repository", {"media_id": media_id}, ( - "media_type", "media_length", "upload_name", "created_ts", - "quarantined_by", "url_cache", + "media_type", + "media_length", + "upload_name", + "created_ts", + "quarantined_by", + "url_cache", ), allow_none=True, desc="get_local_media", ) - def store_local_media(self, media_id, media_type, time_now_ms, upload_name, - media_length, user_id, url_cache=None): + def store_local_media( + self, + media_id, + media_type, + time_now_ms, + upload_name, + media_length, + user_id, + url_cache=None, + ): return self._simple_insert( "local_media_repository", { @@ -66,6 +78,7 @@ class MediaRepositoryStore(BackgroundUpdateStore): Returns: None if the URL isn't cached. """ + def get_url_cache_txn(txn): # get the most recently cached result (relative to the given ts) sql = ( @@ -92,16 +105,25 @@ class MediaRepositoryStore(BackgroundUpdateStore): if not row: return None - return dict(zip(( - 'response_code', 'etag', 'expires_ts', 'og', 'media_id', 'download_ts' - ), row)) + return dict( + zip( + ( + 'response_code', + 'etag', + 'expires_ts', + 'og', + 'media_id', + 'download_ts', + ), + row, + ) + ) - return self.runInteraction( - "get_url_cache", get_url_cache_txn - ) + return self.runInteraction("get_url_cache", get_url_cache_txn) - def store_url_cache(self, url, response_code, etag, expires_ts, og, media_id, - download_ts): + def store_url_cache( + self, url, response_code, etag, expires_ts, og, media_id, download_ts + ): return self._simple_insert( "local_media_repository_url_cache", { @@ -121,15 +143,24 @@ class MediaRepositoryStore(BackgroundUpdateStore): "local_media_repository_thumbnails", {"media_id": media_id}, ( - "thumbnail_width", "thumbnail_height", "thumbnail_method", - "thumbnail_type", "thumbnail_length", + "thumbnail_width", + "thumbnail_height", + "thumbnail_method", + "thumbnail_type", + "thumbnail_length", ), desc="get_local_media_thumbnails", ) - def store_local_thumbnail(self, media_id, thumbnail_width, - thumbnail_height, thumbnail_type, - thumbnail_method, thumbnail_length): + def store_local_thumbnail( + self, + media_id, + thumbnail_width, + thumbnail_height, + thumbnail_type, + thumbnail_method, + thumbnail_length, + ): return self._simple_insert( "local_media_repository_thumbnails", { @@ -148,16 +179,27 @@ class MediaRepositoryStore(BackgroundUpdateStore): "remote_media_cache", {"media_origin": origin, "media_id": media_id}, ( - "media_type", "media_length", "upload_name", "created_ts", - "filesystem_id", "quarantined_by", + "media_type", + "media_length", + "upload_name", + "created_ts", + "filesystem_id", + "quarantined_by", ), allow_none=True, desc="get_cached_remote_media", ) - def store_cached_remote_media(self, origin, media_id, media_type, - media_length, time_now_ms, upload_name, - filesystem_id): + def store_cached_remote_media( + self, + origin, + media_id, + media_type, + media_length, + time_now_ms, + upload_name, + filesystem_id, + ): return self._simple_insert( "remote_media_cache", { @@ -181,26 +223,27 @@ class MediaRepositoryStore(BackgroundUpdateStore): remote_media (iterable[(str, str)]): Set of (server_name, media_id) time_ms: Current time in milliseconds """ + def update_cache_txn(txn): sql = ( "UPDATE remote_media_cache SET last_access_ts = ?" " WHERE media_origin = ? AND media_id = ?" ) - txn.executemany(sql, ( - (time_ms, media_origin, media_id) - for media_origin, media_id in remote_media - )) + txn.executemany( + sql, + ( + (time_ms, media_origin, media_id) + for media_origin, media_id in remote_media + ), + ) sql = ( "UPDATE local_media_repository SET last_access_ts = ?" " WHERE media_id = ?" ) - txn.executemany(sql, ( - (time_ms, media_id) - for media_id in local_media - )) + txn.executemany(sql, ((time_ms, media_id) for media_id in local_media)) return self.runInteraction("update_cached_last_access_time", update_cache_txn) @@ -209,16 +252,27 @@ class MediaRepositoryStore(BackgroundUpdateStore): "remote_media_cache_thumbnails", {"media_origin": origin, "media_id": media_id}, ( - "thumbnail_width", "thumbnail_height", "thumbnail_method", - "thumbnail_type", "thumbnail_length", "filesystem_id", + "thumbnail_width", + "thumbnail_height", + "thumbnail_method", + "thumbnail_type", + "thumbnail_length", + "filesystem_id", ), desc="get_remote_media_thumbnails", ) - def store_remote_media_thumbnail(self, origin, media_id, filesystem_id, - thumbnail_width, thumbnail_height, - thumbnail_type, thumbnail_method, - thumbnail_length): + def store_remote_media_thumbnail( + self, + origin, + media_id, + filesystem_id, + thumbnail_width, + thumbnail_height, + thumbnail_type, + thumbnail_method, + thumbnail_length, + ): return self._simple_insert( "remote_media_cache_thumbnails", { @@ -250,17 +304,14 @@ class MediaRepositoryStore(BackgroundUpdateStore): self._simple_delete_txn( txn, "remote_media_cache", - keyvalues={ - "media_origin": media_origin, "media_id": media_id - }, + keyvalues={"media_origin": media_origin, "media_id": media_id}, ) self._simple_delete_txn( txn, "remote_media_cache_thumbnails", - keyvalues={ - "media_origin": media_origin, "media_id": media_id - }, + keyvalues={"media_origin": media_origin, "media_id": media_id}, ) + return self.runInteraction("delete_remote_media", delete_remote_media_txn) def get_expired_url_cache(self, now_ts): @@ -281,10 +332,7 @@ class MediaRepositoryStore(BackgroundUpdateStore): if len(media_ids) == 0: return - sql = ( - "DELETE FROM local_media_repository_url_cache" - " WHERE media_id = ?" - ) + sql = "DELETE FROM local_media_repository_url_cache" " WHERE media_id = ?" def _delete_url_cache_txn(txn): txn.executemany(sql, [(media_id,) for media_id in media_ids]) @@ -304,7 +352,7 @@ class MediaRepositoryStore(BackgroundUpdateStore): return [row[0] for row in txn] return self.runInteraction( - "get_url_cache_media_before", _get_url_cache_media_before_txn, + "get_url_cache_media_before", _get_url_cache_media_before_txn ) def delete_url_cache_media(self, media_ids): @@ -312,20 +360,14 @@ class MediaRepositoryStore(BackgroundUpdateStore): return def _delete_url_cache_media_txn(txn): - sql = ( - "DELETE FROM local_media_repository" - " WHERE media_id = ?" - ) + sql = "DELETE FROM local_media_repository" " WHERE media_id = ?" txn.executemany(sql, [(media_id,) for media_id in media_ids]) - sql = ( - "DELETE FROM local_media_repository_thumbnails" - " WHERE media_id = ?" - ) + sql = "DELETE FROM local_media_repository_thumbnails" " WHERE media_id = ?" txn.executemany(sql, [(media_id,) for media_id in media_ids]) return self.runInteraction( - "delete_url_cache_media", _delete_url_cache_media_txn, + "delete_url_cache_media", _delete_url_cache_media_txn ) diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 9e7e09b8c1..8aa8abc470 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -35,9 +35,12 @@ class MonthlyActiveUsersStore(SQLBaseStore): self.reserved_users = () # Do not add more reserved users than the total allowable number self._new_transaction( - dbconn, "initialise_mau_threepids", [], [], + dbconn, + "initialise_mau_threepids", + [], + [], self._initialise_reserved_users, - hs.config.mau_limits_reserved_threepids[:self.hs.config.max_mau_value], + hs.config.mau_limits_reserved_threepids[: self.hs.config.max_mau_value], ) def _initialise_reserved_users(self, txn, threepids): @@ -51,10 +54,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): reserved_user_list = [] for tp in threepids: - user_id = self.get_user_id_by_threepid_txn( - txn, - tp["medium"], tp["address"] - ) + user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) if user_id: is_support = self.is_support_user_txn(txn, user_id) @@ -62,9 +62,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): self.upsert_monthly_active_user_txn(txn, user_id) reserved_user_list.append(user_id) else: - logger.warning( - "mau limit reserved threepid %s not found in db" % tp - ) + logger.warning("mau limit reserved threepid %s not found in db" % tp) self.reserved_users = tuple(reserved_user_list) @defer.inlineCallbacks @@ -75,12 +73,11 @@ class MonthlyActiveUsersStore(SQLBaseStore): Returns: Deferred[] """ + def _reap_users(txn): # Purge stale users - thirty_days_ago = ( - int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - ) + thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) query_args = [thirty_days_ago] base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?" @@ -158,6 +155,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): txn.execute(sql) count, = txn.fetchone() return count + return self.runInteraction("count_users", _count_users) @defer.inlineCallbacks @@ -198,14 +196,11 @@ class MonthlyActiveUsersStore(SQLBaseStore): return yield self.runInteraction( - "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, - user_id + "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id ) user_in_mau = self.user_last_seen_monthly_active.cache.get( - (user_id,), - None, - update_metrics=False + (user_id,), None, update_metrics=False ) if user_in_mau is None: self.get_monthly_active_count.invalidate(()) @@ -247,12 +242,8 @@ class MonthlyActiveUsersStore(SQLBaseStore): is_insert = self._simple_upsert_txn( txn, table="monthly_active_users", - keyvalues={ - "user_id": user_id, - }, - values={ - "timestamp": int(self._clock.time_msec()), - }, + keyvalues={"user_id": user_id}, + values={"timestamp": int(self._clock.time_msec())}, ) return is_insert @@ -268,15 +259,13 @@ class MonthlyActiveUsersStore(SQLBaseStore): """ - return(self._simple_select_one_onecol( + return self._simple_select_one_onecol( table="monthly_active_users", - keyvalues={ - "user_id": user_id, - }, + keyvalues={"user_id": user_id}, retcol="timestamp", allow_none=True, desc="user_last_seen_monthly_active", - )) + ) @defer.inlineCallbacks def populate_monthly_active_users(self, user_id): diff --git a/synapse/storage/openid.py b/synapse/storage/openid.py index 5dabb607bd..b3318045ee 100644 --- a/synapse/storage/openid.py +++ b/synapse/storage/openid.py @@ -10,7 +10,7 @@ class OpenIdStore(SQLBaseStore): "ts_valid_until_ms": ts_valid_until_ms, "user_id": user_id, }, - desc="insert_open_id_token" + desc="insert_open_id_token", ) def get_user_id_for_open_id_token(self, token, ts_now_ms): @@ -27,6 +27,5 @@ class OpenIdStore(SQLBaseStore): return None else: return rows[0][0] - return self.runInteraction( - "get_user_id_for_token", get_user_id_for_token_txn - ) + + return self.runInteraction("get_user_id_for_token", get_user_id_for_token_txn) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index fa36daac52..c1711bc8bd 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 53 +SCHEMA_VERSION = 54 dir_path = os.path.abspath(os.path.dirname(__file__)) @@ -143,10 +143,9 @@ def _setup_new_database(cur, database_engine): cur.execute( database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded)" - " VALUES (?,?)" + "INSERT INTO schema_version (version, upgraded)" " VALUES (?,?)" ), - (max_current_ver, False,) + (max_current_ver, False), ) _upgrade_existing_database( @@ -160,8 +159,15 @@ def _setup_new_database(cur, database_engine): ) -def _upgrade_existing_database(cur, current_version, applied_delta_files, - upgraded, database_engine, config, is_empty=False): +def _upgrade_existing_database( + cur, + current_version, + applied_delta_files, + upgraded, + database_engine, + config, + is_empty=False, +): """Upgrades an existing database. Delta files can either be SQL stored in *.sql files, or python modules @@ -209,8 +215,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, if current_version > SCHEMA_VERSION: raise ValueError( - "Cannot use this database as it is too " + - "new for the server to understand" + "Cannot use this database as it is too " + + "new for the server to understand" ) start_ver = current_version @@ -239,20 +245,14 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, if relative_path in applied_delta_files: continue - absolute_path = os.path.join( - dir_path, "schema", "delta", relative_path, - ) + absolute_path = os.path.join(dir_path, "schema", "delta", relative_path) root_name, ext = os.path.splitext(file_name) if ext == ".py": # This is a python upgrade module. We need to import into some # package and then execute its `run_upgrade` function. - module_name = "synapse.storage.v%d_%s" % ( - v, root_name - ) + module_name = "synapse.storage.v%d_%s" % (v, root_name) with open(absolute_path) as python_file: - module = imp.load_source( - module_name, absolute_path, python_file - ) + module = imp.load_source(module_name, absolute_path, python_file) logger.info("Running script %s", relative_path) module.run_create(cur, database_engine) if not is_empty: @@ -269,8 +269,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, else: # Not a valid delta file. logger.warn( - "Found directory entry that did not end in .py or" - " .sql: %s", + "Found directory entry that did not end in .py or" " .sql: %s", relative_path, ) continue @@ -278,19 +277,17 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, # Mark as done. cur.execute( database_engine.convert_param_style( - "INSERT INTO applied_schema_deltas (version, file)" - " VALUES (?,?)", + "INSERT INTO applied_schema_deltas (version, file)" " VALUES (?,?)" ), - (v, relative_path) + (v, relative_path), ) cur.execute("DELETE FROM schema_version") cur.execute( database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded)" - " VALUES (?,?)", + "INSERT INTO schema_version (version, upgraded)" " VALUES (?,?)" ), - (v, True) + (v, True), ) @@ -308,7 +305,7 @@ def _apply_module_schemas(txn, database_engine, config): continue modname = ".".join((mod.__module__, mod.__name__)) _apply_module_schema_files( - txn, database_engine, modname, mod.get_db_schema_files(), + txn, database_engine, modname, mod.get_db_schema_files() ) @@ -326,7 +323,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) database_engine.convert_param_style( "SELECT file FROM applied_module_schemas WHERE module_name = ?" ), - (modname,) + (modname,), ) applied_deltas = set(d for d, in cur) for (name, stream) in names_and_streams: @@ -336,7 +333,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) root_name, ext = os.path.splitext(name) if ext != '.sql': raise PrepareDatabaseException( - "only .sql files are currently supported for module schemas", + "only .sql files are currently supported for module schemas" ) logger.info("applying schema %s for %s", name, modname) @@ -346,10 +343,9 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) # Mark as done. cur.execute( database_engine.convert_param_style( - "INSERT INTO applied_module_schemas (module_name, file)" - " VALUES (?,?)", + "INSERT INTO applied_module_schemas (module_name, file)" " VALUES (?,?)" ), - (modname, name) + (modname, name), ) @@ -386,10 +382,7 @@ def get_statements(f): statements = line.split(";") # We must prepend statement_buffer to the first statement - first_statement = "%s %s" % ( - statement_buffer.strip(), - statements[0].strip() - ) + first_statement = "%s %s" % (statement_buffer.strip(), statements[0].strip()) statements[0] = first_statement # Every entry, except the last, is a full statement @@ -409,9 +402,7 @@ def executescript(txn, schema_path): def _get_or_create_schema_state(txn, database_engine): # Bluntly try creating the schema_version tables. - schema_path = os.path.join( - dir_path, "schema", "schema_version.sql", - ) + schema_path = os.path.join(dir_path, "schema", "schema_version.sql") executescript(txn, schema_path) txn.execute("SELECT version, upgraded FROM schema_version") @@ -424,7 +415,7 @@ def _get_or_create_schema_state(txn, database_engine): database_engine.convert_param_style( "SELECT file FROM applied_schema_deltas WHERE version >= ?" ), - (current_version,) + (current_version,), ) applied_deltas = [d for d, in txn] return current_version, applied_deltas, upgraded diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index a0c7a0dc87..42ec8c6bb8 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -19,15 +19,25 @@ from twisted.internet import defer from synapse.api.constants import PresenceState from synapse.util import batch_iter -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.caches.descriptors import cached, cachedList from ._base import SQLBaseStore -class UserPresenceState(namedtuple("UserPresenceState", - ("user_id", "state", "last_active_ts", - "last_federation_update_ts", "last_user_sync_ts", - "status_msg", "currently_active"))): +class UserPresenceState( + namedtuple( + "UserPresenceState", + ( + "user_id", + "state", + "last_active_ts", + "last_federation_update_ts", + "last_user_sync_ts", + "status_msg", + "currently_active", + ), + ) +): """Represents the current presence state of the user. user_id (str) @@ -75,22 +85,21 @@ class PresenceStore(SQLBaseStore): with stream_ordering_manager as stream_orderings: yield self.runInteraction( "update_presence", - self._update_presence_txn, stream_orderings, presence_states, + self._update_presence_txn, + stream_orderings, + presence_states, ) - defer.returnValue(( - stream_orderings[-1], self._presence_id_gen.get_current_token() - )) + defer.returnValue( + (stream_orderings[-1], self._presence_id_gen.get_current_token()) + ) def _update_presence_txn(self, txn, stream_orderings, presence_states): for stream_id, state in zip(stream_orderings, presence_states): txn.call_after( - self.presence_stream_cache.entity_has_changed, - state.user_id, stream_id, - ) - txn.call_after( - self._get_presence_for_user.invalidate, (state.user_id,) + self.presence_stream_cache.entity_has_changed, state.user_id, stream_id ) + txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,)) # Actually insert new rows self._simple_insert_many_txn( @@ -113,18 +122,13 @@ class PresenceStore(SQLBaseStore): # Delete old rows to stop database from getting really big sql = ( - "DELETE FROM presence_stream WHERE" - " stream_id < ?" - " AND user_id IN (%s)" + "DELETE FROM presence_stream WHERE" " stream_id < ?" " AND user_id IN (%s)" ) for states in batch_iter(presence_states, 50): args = [stream_id] args.extend(s.user_id for s in states) - txn.execute( - sql % (",".join("?" for _ in states),), - args - ) + txn.execute(sql % (",".join("?" for _ in states),), args) def get_all_presence_updates(self, last_id, current_id): if last_id == current_id: @@ -149,8 +153,12 @@ class PresenceStore(SQLBaseStore): def _get_presence_for_user(self, user_id): raise NotImplementedError() - @cachedList(cached_method_name="_get_presence_for_user", list_name="user_ids", - num_args=1, inlineCallbacks=True) + @cachedList( + cached_method_name="_get_presence_for_user", + list_name="user_ids", + num_args=1, + inlineCallbacks=True, + ) def get_presence_for_users(self, user_ids): rows = yield self._simple_select_many_batch( table="presence_stream", @@ -180,8 +188,10 @@ class PresenceStore(SQLBaseStore): def allow_presence_visible(self, observed_localpart, observer_userid): return self._simple_insert( table="presence_allow_inbound", - values={"observed_user_id": observed_localpart, - "observer_user_id": observer_userid}, + values={ + "observed_user_id": observed_localpart, + "observer_user_id": observer_userid, + }, desc="allow_presence_visible", or_ignore=True, ) @@ -189,89 +199,9 @@ class PresenceStore(SQLBaseStore): def disallow_presence_visible(self, observed_localpart, observer_userid): return self._simple_delete_one( table="presence_allow_inbound", - keyvalues={"observed_user_id": observed_localpart, - "observer_user_id": observer_userid}, + keyvalues={ + "observed_user_id": observed_localpart, + "observer_user_id": observer_userid, + }, desc="disallow_presence_visible", ) - - def add_presence_list_pending(self, observer_localpart, observed_userid): - return self._simple_insert( - table="presence_list", - values={"user_id": observer_localpart, - "observed_user_id": observed_userid, - "accepted": False}, - desc="add_presence_list_pending", - ) - - def set_presence_list_accepted(self, observer_localpart, observed_userid): - def update_presence_list_txn(txn): - result = self._simple_update_one_txn( - txn, - table="presence_list", - keyvalues={ - "user_id": observer_localpart, - "observed_user_id": observed_userid - }, - updatevalues={"accepted": True}, - ) - - self._invalidate_cache_and_stream( - txn, self.get_presence_list_accepted, (observer_localpart,) - ) - self._invalidate_cache_and_stream( - txn, self.get_presence_list_observers_accepted, (observed_userid,) - ) - - return result - - return self.runInteraction( - "set_presence_list_accepted", update_presence_list_txn, - ) - - def get_presence_list(self, observer_localpart, accepted=None): - if accepted: - return self.get_presence_list_accepted(observer_localpart) - else: - keyvalues = {"user_id": observer_localpart} - if accepted is not None: - keyvalues["accepted"] = accepted - - return self._simple_select_list( - table="presence_list", - keyvalues=keyvalues, - retcols=["observed_user_id", "accepted"], - desc="get_presence_list", - ) - - @cached() - def get_presence_list_accepted(self, observer_localpart): - return self._simple_select_list( - table="presence_list", - keyvalues={"user_id": observer_localpart, "accepted": True}, - retcols=["observed_user_id", "accepted"], - desc="get_presence_list_accepted", - ) - - @cachedInlineCallbacks() - def get_presence_list_observers_accepted(self, observed_userid): - user_localparts = yield self._simple_select_onecol( - table="presence_list", - keyvalues={"observed_user_id": observed_userid, "accepted": True}, - retcol="user_id", - desc="get_presence_list_accepted", - ) - - defer.returnValue([ - "@%s:%s" % (u, self.hs.hostname,) for u in user_localparts - ]) - - @defer.inlineCallbacks - def del_presence_list(self, observer_localpart, observed_userid): - yield self._simple_delete_one( - table="presence_list", - keyvalues={"user_id": observer_localpart, - "observed_user_id": observed_userid}, - desc="del_presence_list", - ) - self.get_presence_list_accepted.invalidate((observer_localpart,)) - self.get_presence_list_observers_accepted.invalidate((observed_userid,)) diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py index 88b50f33b5..aeec2f57c4 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py @@ -41,8 +41,7 @@ class ProfileWorkerStore(SQLBaseStore): defer.returnValue( ProfileInfo( - avatar_url=profile['avatar_url'], - display_name=profile['displayname'], + avatar_url=profile['avatar_url'], display_name=profile['displayname'] ) ) @@ -66,16 +65,14 @@ class ProfileWorkerStore(SQLBaseStore): return self._simple_select_one( table="remote_profile_cache", keyvalues={"user_id": user_id}, - retcols=("displayname", "avatar_url",), + retcols=("displayname", "avatar_url"), allow_none=True, desc="get_from_remote_profile_cache", ) def create_profile(self, user_localpart): return self._simple_insert( - table="profiles", - values={"user_id": user_localpart}, - desc="create_profile", + table="profiles", values={"user_id": user_localpart}, desc="create_profile" ) def set_profile_displayname(self, user_localpart, new_displayname): @@ -141,6 +138,7 @@ class ProfileStore(ProfileWorkerStore): def get_remote_profile_cache_entries_that_expire(self, last_checked): """Get all users who haven't been checked since `last_checked` """ + def _get_remote_profile_cache_entries_that_expire_txn(txn): sql = """ SELECT user_id, displayname, avatar_url diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 4b8438c3e9..9e406baafa 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -57,11 +57,13 @@ def _load_rules(rawrules, enabled_map): return rules -class PushRulesWorkerStore(ApplicationServiceWorkerStore, - ReceiptsWorkerStore, - PusherWorkerStore, - RoomMemberWorkerStore, - SQLBaseStore): +class PushRulesWorkerStore( + ApplicationServiceWorkerStore, + ReceiptsWorkerStore, + PusherWorkerStore, + RoomMemberWorkerStore, + SQLBaseStore, +): """This is an abstract base class where subclasses must implement `get_max_push_rules_stream_id` which can be called in the initializer. """ @@ -74,14 +76,16 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore, super(PushRulesWorkerStore, self).__init__(db_conn, hs) push_rules_prefill, push_rules_id = self._get_cache_dict( - db_conn, "push_rules_stream", + db_conn, + "push_rules_stream", entity_column="user_id", stream_column="stream_id", max_value=self.get_max_push_rules_stream_id(), ) self.push_rules_stream_cache = StreamChangeCache( - "PushRulesStreamChangeCache", push_rules_id, + "PushRulesStreamChangeCache", + push_rules_id, prefilled_cache=push_rules_prefill, ) @@ -98,19 +102,19 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore, def get_push_rules_for_user(self, user_id): rows = yield self._simple_select_list( table="push_rules", - keyvalues={ - "user_name": user_id, - }, + keyvalues={"user_name": user_id}, retcols=( - "user_name", "rule_id", "priority_class", "priority", - "conditions", "actions", + "user_name", + "rule_id", + "priority_class", + "priority", + "conditions", + "actions", ), desc="get_push_rules_enabled_for_user", ) - rows.sort( - key=lambda row: (-int(row["priority_class"]), -int(row["priority"])) - ) + rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) enabled_map = yield self.get_push_rules_enabled_for_user(user_id) @@ -122,22 +126,19 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore, def get_push_rules_enabled_for_user(self, user_id): results = yield self._simple_select_list( table="push_rules_enable", - keyvalues={ - 'user_name': user_id - }, - retcols=( - "user_name", "rule_id", "enabled", - ), + keyvalues={'user_name': user_id}, + retcols=("user_name", "rule_id", "enabled"), desc="get_push_rules_enabled_for_user", ) - defer.returnValue({ - r['rule_id']: False if r['enabled'] == 0 else True for r in results - }) + defer.returnValue( + {r['rule_id']: False if r['enabled'] == 0 else True for r in results} + ) def have_push_rules_changed_for_user(self, user_id, last_id): if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): return defer.succeed(False) else: + def have_push_rules_changed_txn(txn): sql = ( "SELECT COUNT(stream_id) FROM push_rules_stream" @@ -146,20 +147,22 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore, txn.execute(sql, (user_id, last_id)) count, = txn.fetchone() return bool(count) + return self.runInteraction( "have_push_rules_changed", have_push_rules_changed_txn ) - @cachedList(cached_method_name="get_push_rules_for_user", - list_name="user_ids", num_args=1, inlineCallbacks=True) + @cachedList( + cached_method_name="get_push_rules_for_user", + list_name="user_ids", + num_args=1, + inlineCallbacks=True, + ) def bulk_get_push_rules(self, user_ids): if not user_ids: defer.returnValue({}) - results = { - user_id: [] - for user_id in user_ids - } + results = {user_id: [] for user_id in user_ids} rows = yield self._simple_select_many_batch( table="push_rules", @@ -169,9 +172,7 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore, desc="bulk_get_push_rules", ) - rows.sort( - key=lambda row: (-int(row["priority_class"]), -int(row["priority"])) - ) + rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) for row in rows: results.setdefault(row['user_name'], []).append(row) @@ -179,16 +180,12 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore, enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) for user_id, rules in results.items(): - results[user_id] = _load_rules( - rules, enabled_map_by_user.get(user_id, {}) - ) + results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) defer.returnValue(results) @defer.inlineCallbacks - def move_push_rule_from_room_to_room( - self, new_room_id, user_id, rule, - ): + def move_push_rule_from_room_to_room(self, new_room_id, user_id, rule): """Move a single push rule from one room to another for a specific user. Args: @@ -219,7 +216,7 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore, @defer.inlineCallbacks def move_push_rules_from_room_to_room_for_user( - self, old_room_id, new_room_id, user_id, + self, old_room_id, new_room_id, user_id ): """Move all of the push rules from one room to another for a specific user. @@ -236,11 +233,11 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore, # delete them from the old room for rule in user_push_rules: conditions = rule.get("conditions", []) - if any((c.get("key") == "room_id" and - c.get("pattern") == old_room_id) for c in conditions): - self.move_push_rule_from_room_to_room( - new_room_id, user_id, rule, - ) + if any( + (c.get("key") == "room_id" and c.get("pattern") == old_room_id) + for c in conditions + ): + self.move_push_rule_from_room_to_room(new_room_id, user_id, rule) @defer.inlineCallbacks def bulk_get_push_rules_for_room(self, event, context): @@ -259,8 +256,9 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore, defer.returnValue(result) @cachedInlineCallbacks(num_args=2, cache_context=True) - def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids, - cache_context, event=None): + def _bulk_get_push_rules_for_room( + self, room_id, state_group, current_state_ids, cache_context, event=None + ): # We don't use `state_group`, its there so that we can cache based # on it. However, its important that its never None, since two current_state's # with a state_group of None are likely to be different. @@ -273,7 +271,9 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore, # sent a read receipt into the room. users_in_room = yield self._get_joined_users_from_context( - room_id, state_group, current_state_ids, + room_id, + state_group, + current_state_ids, on_invalidate=cache_context.invalidate, event=event, ) @@ -282,7 +282,8 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore, # up the `get_if_users_have_pushers` cache with AS entries that we # know don't have pushers, nor even read receipts. local_users_in_room = set( - u for u in users_in_room + u + for u in users_in_room if self.hs.is_mine_id(u) and not self.get_if_app_services_interested_in_user(u) ) @@ -290,15 +291,14 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore, # users in the room who have pushers need to get push rules run because # that's how their pushers work if_users_with_pushers = yield self.get_if_users_have_pushers( - local_users_in_room, - on_invalidate=cache_context.invalidate, + local_users_in_room, on_invalidate=cache_context.invalidate ) user_ids = set( uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher ) users_with_receipts = yield self.get_users_with_read_receipts_in_room( - room_id, on_invalidate=cache_context.invalidate, + room_id, on_invalidate=cache_context.invalidate ) # any users with pushers must be ours: they have pushers @@ -307,29 +307,30 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore, user_ids.add(uid) rules_by_user = yield self.bulk_get_push_rules( - user_ids, on_invalidate=cache_context.invalidate, + user_ids, on_invalidate=cache_context.invalidate ) rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} defer.returnValue(rules_by_user) - @cachedList(cached_method_name="get_push_rules_enabled_for_user", - list_name="user_ids", num_args=1, inlineCallbacks=True) + @cachedList( + cached_method_name="get_push_rules_enabled_for_user", + list_name="user_ids", + num_args=1, + inlineCallbacks=True, + ) def bulk_get_push_rules_enabled(self, user_ids): if not user_ids: defer.returnValue({}) - results = { - user_id: {} - for user_id in user_ids - } + results = {user_id: {} for user_id in user_ids} rows = yield self._simple_select_many_batch( table="push_rules_enable", column="user_name", iterable=user_ids, - retcols=("user_name", "rule_id", "enabled",), + retcols=("user_name", "rule_id", "enabled"), desc="bulk_get_push_rules_enabled", ) for row in rows: @@ -341,8 +342,14 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore, class PushRuleStore(PushRulesWorkerStore): @defer.inlineCallbacks def add_push_rule( - self, user_id, rule_id, priority_class, conditions, actions, - before=None, after=None + self, + user_id, + rule_id, + priority_class, + conditions, + actions, + before=None, + after=None, ): conditions_json = json.dumps(conditions) actions_json = json.dumps(actions) @@ -352,20 +359,41 @@ class PushRuleStore(PushRulesWorkerStore): yield self.runInteraction( "_add_push_rule_relative_txn", self._add_push_rule_relative_txn, - stream_id, event_stream_ordering, user_id, rule_id, priority_class, - conditions_json, actions_json, before, after, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + conditions_json, + actions_json, + before, + after, ) else: yield self.runInteraction( "_add_push_rule_highest_priority_txn", self._add_push_rule_highest_priority_txn, - stream_id, event_stream_ordering, user_id, rule_id, priority_class, - conditions_json, actions_json, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + conditions_json, + actions_json, ) def _add_push_rule_relative_txn( - self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class, - conditions_json, actions_json, before, after + self, + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + conditions_json, + actions_json, + before, + after, ): # Lock the table since otherwise we'll have annoying races between the # SELECT here and the UPSERT below. @@ -376,10 +404,7 @@ class PushRuleStore(PushRulesWorkerStore): res = self._simple_select_one_txn( txn, table="push_rules", - keyvalues={ - "user_name": user_id, - "rule_id": relative_to_rule, - }, + keyvalues={"user_name": user_id, "rule_id": relative_to_rule}, retcols=["priority_class", "priority"], allow_none=True, ) @@ -416,13 +441,27 @@ class PushRuleStore(PushRulesWorkerStore): txn.execute(sql, (user_id, priority_class, new_rule_priority)) self._upsert_push_rule_txn( - txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class, - new_rule_priority, conditions_json, actions_json, + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + new_rule_priority, + conditions_json, + actions_json, ) def _add_push_rule_highest_priority_txn( - self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class, - conditions_json, actions_json + self, + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + conditions_json, + actions_json, ): # Lock the table since otherwise we'll have annoying races between the # SELECT here and the UPSERT below. @@ -443,13 +482,28 @@ class PushRuleStore(PushRulesWorkerStore): self._upsert_push_rule_txn( txn, - stream_id, event_stream_ordering, user_id, rule_id, priority_class, new_prio, - conditions_json, actions_json, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + new_prio, + conditions_json, + actions_json, ) def _upsert_push_rule_txn( - self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class, - priority, conditions_json, actions_json, update_stream=True + self, + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + priority, + conditions_json, + actions_json, + update_stream=True, ): """Specialised version of _simple_upsert_txn that picks a push_rule_id using the _push_rule_id_gen if it needs to insert the rule. It assumes @@ -461,10 +515,10 @@ class PushRuleStore(PushRulesWorkerStore): " WHERE user_name = ? AND rule_id = ?" ) - txn.execute(sql, ( - priority_class, priority, conditions_json, actions_json, - user_id, rule_id, - )) + txn.execute( + sql, + (priority_class, priority, conditions_json, actions_json, user_id, rule_id), + ) if txn.rowcount == 0: # We didn't update a row with the given rule_id so insert one @@ -486,14 +540,18 @@ class PushRuleStore(PushRulesWorkerStore): if update_stream: self._insert_push_rules_update_txn( - txn, stream_id, event_stream_ordering, user_id, rule_id, + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, op="ADD", data={ "priority_class": priority_class, "priority": priority, "conditions": conditions_json, "actions": actions_json, - } + }, ) @defer.inlineCallbacks @@ -507,22 +565,23 @@ class PushRuleStore(PushRulesWorkerStore): user_id (str): The matrix ID of the push rule owner rule_id (str): The rule_id of the rule to be deleted """ + def delete_push_rule_txn(txn, stream_id, event_stream_ordering): self._simple_delete_one_txn( - txn, - "push_rules", - {'user_name': user_id, 'rule_id': rule_id}, + txn, "push_rules", {'user_name': user_id, 'rule_id': rule_id} ) self._insert_push_rules_update_txn( - txn, stream_id, event_stream_ordering, user_id, rule_id, - op="DELETE" + txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" ) with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids yield self.runInteraction( - "delete_push_rule", delete_push_rule_txn, stream_id, event_stream_ordering + "delete_push_rule", + delete_push_rule_txn, + stream_id, + event_stream_ordering, ) @defer.inlineCallbacks @@ -532,7 +591,11 @@ class PushRuleStore(PushRulesWorkerStore): yield self.runInteraction( "_set_push_rule_enabled_txn", self._set_push_rule_enabled_txn, - stream_id, event_stream_ordering, user_id, rule_id, enabled + stream_id, + event_stream_ordering, + user_id, + rule_id, + enabled, ) def _set_push_rule_enabled_txn( @@ -548,8 +611,12 @@ class PushRuleStore(PushRulesWorkerStore): ) self._insert_push_rules_update_txn( - txn, stream_id, event_stream_ordering, user_id, rule_id, - op="ENABLE" if enabled else "DISABLE" + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + op="ENABLE" if enabled else "DISABLE", ) @defer.inlineCallbacks @@ -563,9 +630,16 @@ class PushRuleStore(PushRulesWorkerStore): priority_class = -1 priority = 1 self._upsert_push_rule_txn( - txn, stream_id, event_stream_ordering, user_id, rule_id, - priority_class, priority, "[]", actions_json, - update_stream=False + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + priority, + "[]", + actions_json, + update_stream=False, ) else: self._simple_update_one_txn( @@ -576,15 +650,22 @@ class PushRuleStore(PushRulesWorkerStore): ) self._insert_push_rules_update_txn( - txn, stream_id, event_stream_ordering, user_id, rule_id, - op="ACTIONS", data={"actions": actions_json} + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + op="ACTIONS", + data={"actions": actions_json}, ) with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids yield self.runInteraction( - "set_push_rule_actions", set_push_rule_actions_txn, - stream_id, event_stream_ordering + "set_push_rule_actions", + set_push_rule_actions_txn, + stream_id, + event_stream_ordering, ) def _insert_push_rules_update_txn( @@ -602,12 +683,8 @@ class PushRuleStore(PushRulesWorkerStore): self._simple_insert_txn(txn, "push_rules_stream", values=values) - txn.call_after( - self.get_push_rules_for_user.invalidate, (user_id,) - ) - txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, (user_id,) - ) + txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,)) + txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,)) txn.call_after( self.push_rules_stream_cache.entity_has_changed, user_id, stream_id ) @@ -627,6 +704,7 @@ class PushRuleStore(PushRulesWorkerStore): ) txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() + return self.runInteraction( "get_all_push_rule_updates", get_all_push_rule_updates_txn ) diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 134297e284..1567e1df48 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -47,7 +47,9 @@ class PusherWorkerStore(SQLBaseStore): except Exception as e: logger.warn( "Invalid JSON in data for pusher %d: %s, %s", - r['id'], dataJson, e.args[0], + r['id'], + dataJson, + e.args[0], ) pass @@ -64,20 +66,16 @@ class PusherWorkerStore(SQLBaseStore): defer.returnValue(ret is not None) def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey): - return self.get_pushers_by({ - "app_id": app_id, - "pushkey": pushkey, - }) + return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey}) def get_pushers_by_user_id(self, user_id): - return self.get_pushers_by({ - "user_name": user_id, - }) + return self.get_pushers_by({"user_name": user_id}) @defer.inlineCallbacks def get_pushers_by(self, keyvalues): ret = yield self._simple_select_list( - "pushers", keyvalues, + "pushers", + keyvalues, [ "id", "user_name", @@ -94,7 +92,8 @@ class PusherWorkerStore(SQLBaseStore): "last_stream_ordering", "last_success", "failing_since", - ], desc="get_pushers_by" + ], + desc="get_pushers_by", ) defer.returnValue(self._decode_pushers_rows(ret)) @@ -135,6 +134,7 @@ class PusherWorkerStore(SQLBaseStore): deleted = txn.fetchall() return (updated, deleted) + return self.runInteraction( "get_all_updated_pushers", get_all_updated_pushers_txn ) @@ -177,6 +177,7 @@ class PusherWorkerStore(SQLBaseStore): results.sort() # Sort so that they're ordered by stream id return results + return self.runInteraction( "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn ) @@ -186,15 +187,19 @@ class PusherWorkerStore(SQLBaseStore): # This only exists for the cachedList decorator raise NotImplementedError() - @cachedList(cached_method_name="get_if_user_has_pusher", - list_name="user_ids", num_args=1, inlineCallbacks=True) + @cachedList( + cached_method_name="get_if_user_has_pusher", + list_name="user_ids", + num_args=1, + inlineCallbacks=True, + ) def get_if_users_have_pushers(self, user_ids): rows = yield self._simple_select_many_batch( table='pushers', column='user_name', iterable=user_ids, retcols=['user_name'], - desc='get_if_users_have_pushers' + desc='get_if_users_have_pushers', ) result = {user_id: False for user_id in user_ids} @@ -208,20 +213,27 @@ class PusherStore(PusherWorkerStore): return self._pushers_id_gen.get_current_token() @defer.inlineCallbacks - def add_pusher(self, user_id, access_token, kind, app_id, - app_display_name, device_display_name, - pushkey, pushkey_ts, lang, data, last_stream_ordering, - profile_tag=""): + def add_pusher( + self, + user_id, + access_token, + kind, + app_id, + app_display_name, + device_display_name, + pushkey, + pushkey_ts, + lang, + data, + last_stream_ordering, + profile_tag="", + ): with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on # (app_id, pushkey, user_name) so _simple_upsert will retry yield self._simple_upsert( table="pushers", - keyvalues={ - "app_id": app_id, - "pushkey": pushkey, - "user_name": user_id, - }, + keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, values={ "access_token": access_token, "kind": kind, @@ -247,7 +259,8 @@ class PusherStore(PusherWorkerStore): yield self.runInteraction( "add_pusher", self._invalidate_cache_and_stream, - self.get_if_user_has_pusher, (user_id,) + self.get_if_user_has_pusher, + (user_id,), ) @defer.inlineCallbacks @@ -260,7 +273,7 @@ class PusherStore(PusherWorkerStore): self._simple_delete_one_txn( txn, "pushers", - {"app_id": app_id, "pushkey": pushkey, "user_name": user_id} + {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, ) # it's possible for us to end up with duplicate rows for @@ -278,13 +291,12 @@ class PusherStore(PusherWorkerStore): ) with self._pushers_id_gen.get_next() as stream_id: - yield self.runInteraction( - "delete_pusher", delete_pusher_txn, stream_id - ) + yield self.runInteraction("delete_pusher", delete_pusher_txn, stream_id) @defer.inlineCallbacks - def update_pusher_last_stream_ordering(self, app_id, pushkey, user_id, - last_stream_ordering): + def update_pusher_last_stream_ordering( + self, app_id, pushkey, user_id, last_stream_ordering + ): yield self._simple_update_one( "pushers", {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id}, @@ -293,23 +305,21 @@ class PusherStore(PusherWorkerStore): ) @defer.inlineCallbacks - def update_pusher_last_stream_ordering_and_success(self, app_id, pushkey, - user_id, - last_stream_ordering, - last_success): + def update_pusher_last_stream_ordering_and_success( + self, app_id, pushkey, user_id, last_stream_ordering, last_success + ): yield self._simple_update_one( "pushers", {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id}, { 'last_stream_ordering': last_stream_ordering, - 'last_success': last_success + 'last_success': last_success, }, desc="update_pusher_last_stream_ordering_and_success", ) @defer.inlineCallbacks - def update_pusher_failing_since(self, app_id, pushkey, user_id, - failing_since): + def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): yield self._simple_update_one( "pushers", {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id}, @@ -323,14 +333,14 @@ class PusherStore(PusherWorkerStore): "pusher_throttle", {"pusher": pusher_id}, ["room_id", "last_sent_ts", "throttle_ms"], - desc="get_throttle_params_by_room" + desc="get_throttle_params_by_room", ) params_by_room = {} for row in res: params_by_room[row["room_id"]] = { "last_sent_ts": row["last_sent_ts"], - "throttle_ms": row["throttle_ms"] + "throttle_ms": row["throttle_ms"], } defer.returnValue(params_by_room) diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 89a1f7e3d7..a1647e50a1 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -64,10 +64,7 @@ class ReceiptsWorkerStore(SQLBaseStore): def get_receipts_for_room(self, room_id, receipt_type): return self._simple_select_list( table="receipts_linearized", - keyvalues={ - "room_id": room_id, - "receipt_type": receipt_type, - }, + keyvalues={"room_id": room_id, "receipt_type": receipt_type}, retcols=("user_id", "event_id"), desc="get_receipts_for_room", ) @@ -79,7 +76,7 @@ class ReceiptsWorkerStore(SQLBaseStore): keyvalues={ "room_id": room_id, "receipt_type": receipt_type, - "user_id": user_id + "user_id": user_id, }, retcol="event_id", desc="get_own_receipt_for_user", @@ -90,10 +87,7 @@ class ReceiptsWorkerStore(SQLBaseStore): def get_receipts_for_user(self, user_id, receipt_type): rows = yield self._simple_select_list( table="receipts_linearized", - keyvalues={ - "user_id": user_id, - "receipt_type": receipt_type, - }, + keyvalues={"user_id": user_id, "receipt_type": receipt_type}, retcols=("room_id", "event_id"), desc="get_receipts_for_user", ) @@ -114,16 +108,18 @@ class ReceiptsWorkerStore(SQLBaseStore): ) txn.execute(sql, (user_id,)) return txn.fetchall() - rows = yield self.runInteraction( - "get_receipts_for_user_with_orderings", f + + rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f) + defer.returnValue( + { + row[0]: { + "event_id": row[1], + "topological_ordering": row[2], + "stream_ordering": row[3], + } + for row in rows + } ) - defer.returnValue({ - row[0]: { - "event_id": row[1], - "topological_ordering": row[2], - "stream_ordering": row[3], - } for row in rows - }) @defer.inlineCallbacks def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): @@ -177,6 +173,7 @@ class ReceiptsWorkerStore(SQLBaseStore): def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): """See get_linearized_receipts_for_room """ + def f(txn): if from_key: sql = ( @@ -184,48 +181,40 @@ class ReceiptsWorkerStore(SQLBaseStore): " room_id = ? AND stream_id > ? AND stream_id <= ?" ) - txn.execute( - sql, - (room_id, from_key, to_key) - ) + txn.execute(sql, (room_id, from_key, to_key)) else: sql = ( "SELECT * FROM receipts_linearized WHERE" " room_id = ? AND stream_id <= ?" ) - txn.execute( - sql, - (room_id, to_key) - ) + txn.execute(sql, (room_id, to_key)) rows = self.cursor_to_dict(txn) return rows - rows = yield self.runInteraction( - "get_linearized_receipts_for_room", f - ) + rows = yield self.runInteraction("get_linearized_receipts_for_room", f) if not rows: defer.returnValue([]) content = {} for row in rows: - content.setdefault( - row["event_id"], {} - ).setdefault( - row["receipt_type"], {} - )[row["user_id"]] = json.loads(row["data"]) - - defer.returnValue([{ - "type": "m.receipt", - "room_id": room_id, - "content": content, - }]) - - @cachedList(cached_method_name="_get_linearized_receipts_for_room", - list_name="room_ids", num_args=3, inlineCallbacks=True) + content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[ + row["user_id"] + ] = json.loads(row["data"]) + + defer.returnValue( + [{"type": "m.receipt", "room_id": room_id, "content": content}] + ) + + @cachedList( + cached_method_name="_get_linearized_receipts_for_room", + list_name="room_ids", + num_args=3, + inlineCallbacks=True, + ) def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): if not room_ids: defer.returnValue({}) @@ -235,9 +224,7 @@ class ReceiptsWorkerStore(SQLBaseStore): sql = ( "SELECT * FROM receipts_linearized WHERE" " room_id IN (%s) AND stream_id > ? AND stream_id <= ?" - ) % ( - ",".join(["?"] * len(room_ids)) - ) + ) % (",".join(["?"] * len(room_ids))) args = list(room_ids) args.extend([from_key, to_key]) @@ -246,9 +233,7 @@ class ReceiptsWorkerStore(SQLBaseStore): sql = ( "SELECT * FROM receipts_linearized WHERE" " room_id IN (%s) AND stream_id <= ?" - ) % ( - ",".join(["?"] * len(room_ids)) - ) + ) % (",".join(["?"] * len(room_ids))) args = list(room_ids) args.append(to_key) @@ -257,19 +242,16 @@ class ReceiptsWorkerStore(SQLBaseStore): return self.cursor_to_dict(txn) - txn_results = yield self.runInteraction( - "_get_linearized_receipts_for_rooms", f - ) + txn_results = yield self.runInteraction("_get_linearized_receipts_for_rooms", f) results = {} for row in txn_results: # We want a single event per room, since we want to batch the # receipts by room, event and type. - room_event = results.setdefault(row["room_id"], { - "type": "m.receipt", - "room_id": row["room_id"], - "content": {}, - }) + room_event = results.setdefault( + row["room_id"], + {"type": "m.receipt", "room_id": row["room_id"], "content": {}}, + ) # The content is of the form: # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. } @@ -301,21 +283,21 @@ class ReceiptsWorkerStore(SQLBaseStore): args.append(limit) txn.execute(sql, args) - return ( - r[0:5] + (json.loads(r[5]), ) for r in txn - ) + return (r[0:5] + (json.loads(r[5]),) for r in txn) + return self.runInteraction( "get_all_updated_receipts", get_all_updated_receipts_txn ) - def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type, - user_id): + def _invalidate_get_users_with_receipts_in_room( + self, room_id, receipt_type, user_id + ): if receipt_type != "m.read": return # Returns either an ObservableDeferred or the raw result res = self.get_users_with_read_receipts_in_room.cache.get( - room_id, None, update_metrics=False, + room_id, None, update_metrics=False ) # first handle the Deferred case @@ -346,8 +328,9 @@ class ReceiptsStore(ReceiptsWorkerStore): def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() - def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, - user_id, event_id, data, stream_id): + def insert_linearized_receipt_txn( + self, txn, room_id, receipt_type, user_id, event_id, data, stream_id + ): """Inserts a read-receipt into the database if it's newer than the current RR Returns: int|None @@ -360,7 +343,7 @@ class ReceiptsStore(ReceiptsWorkerStore): table="events", retcols=["stream_ordering", "received_ts"], keyvalues={"event_id": event_id}, - allow_none=True + allow_none=True, ) stream_ordering = int(res["stream_ordering"]) if res else None @@ -381,31 +364,31 @@ class ReceiptsStore(ReceiptsWorkerStore): logger.debug( "Ignoring new receipt for %s in favour of existing " "one for later event %s", - event_id, eid, + event_id, + eid, ) return None - txn.call_after( - self.get_receipts_for_room.invalidate, (room_id, receipt_type) - ) + txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) txn.call_after( self._invalidate_get_users_with_receipts_in_room, - room_id, receipt_type, user_id, + room_id, + receipt_type, + user_id, ) + txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type)) + # FIXME: This shouldn't invalidate the whole cache txn.call_after( - self.get_receipts_for_user.invalidate, (user_id, receipt_type) + self._get_linearized_receipts_for_room.invalidate_many, (room_id,) ) - # FIXME: This shouldn't invalidate the whole cache - txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,)) txn.call_after( - self._receipts_stream_cache.entity_has_changed, - room_id, stream_id + self._receipts_stream_cache.entity_has_changed, room_id, stream_id ) txn.call_after( self.get_last_receipt_event_id_for_user.invalidate, - (user_id, room_id, receipt_type) + (user_id, room_id, receipt_type), ) self._simple_delete_txn( @@ -415,7 +398,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "room_id": room_id, "receipt_type": receipt_type, "user_id": user_id, - } + }, ) self._simple_insert_txn( @@ -428,15 +411,12 @@ class ReceiptsStore(ReceiptsWorkerStore): "user_id": user_id, "event_id": event_id, "data": json.dumps(data), - } + }, ) if receipt_type == "m.read" and stream_ordering is not None: self._remove_old_push_actions_before_txn( - txn, - room_id=room_id, - user_id=user_id, - stream_ordering=stream_ordering, + txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering ) return rx_ts @@ -479,7 +459,10 @@ class ReceiptsStore(ReceiptsWorkerStore): event_ts = yield self.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, - room_id, receipt_type, user_id, linearized_event_id, + room_id, + receipt_type, + user_id, + linearized_event_id, data, stream_id=stream_id, ) @@ -490,39 +473,43 @@ class ReceiptsStore(ReceiptsWorkerStore): now = self._clock.time_msec() logger.debug( "RR for event %s in %s (%i ms old)", - linearized_event_id, room_id, now - event_ts, + linearized_event_id, + room_id, + now - event_ts, ) - yield self.insert_graph_receipt( - room_id, receipt_type, user_id, event_ids, data - ) + yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data) max_persisted_id = self._receipts_id_gen.get_current_token() defer.returnValue((stream_id, max_persisted_id)) - def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, - data): + def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data): return self.runInteraction( "insert_graph_receipt", self.insert_graph_receipt_txn, - room_id, receipt_type, user_id, event_ids, data + room_id, + receipt_type, + user_id, + event_ids, + data, ) - def insert_graph_receipt_txn(self, txn, room_id, receipt_type, - user_id, event_ids, data): - txn.call_after( - self.get_receipts_for_room.invalidate, (room_id, receipt_type) - ) + def insert_graph_receipt_txn( + self, txn, room_id, receipt_type, user_id, event_ids, data + ): + txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) txn.call_after( self._invalidate_get_users_with_receipts_in_room, - room_id, receipt_type, user_id, + room_id, + receipt_type, + user_id, ) + txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type)) + # FIXME: This shouldn't invalidate the whole cache txn.call_after( - self.get_receipts_for_user.invalidate, (user_id, receipt_type) + self._get_linearized_receipts_for_room.invalidate_many, (room_id,) ) - # FIXME: This shouldn't invalidate the whole cache - txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,)) self._simple_delete_txn( txn, @@ -531,7 +518,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "room_id": room_id, "receipt_type": receipt_type, "user_id": user_id, - } + }, ) self._simple_insert_txn( txn, @@ -542,5 +529,5 @@ class ReceiptsStore(ReceiptsWorkerStore): "user_id": user_id, "event_ids": json.dumps(event_ids), "data": json.dumps(data), - } + }, ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 9b6c28892c..03a06a83d6 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -32,18 +32,21 @@ class RegistrationWorkerStore(SQLBaseStore): super(RegistrationWorkerStore, self).__init__(db_conn, hs) self.config = hs.config + self.clock = hs.get_clock() @cached() def get_user_by_id(self, user_id): return self._simple_select_one( table="users", - keyvalues={ - "name": user_id, - }, + keyvalues={"name": user_id}, retcols=[ - "name", "password_hash", "is_guest", - "consent_version", "consent_server_notice_sent", - "appservice_id", "creation_ts", + "name", + "password_hash", + "is_guest", + "consent_version", + "consent_server_notice_sent", + "appservice_id", + "creation_ts", ], allow_none=True, desc="get_user_by_id", @@ -81,9 +84,163 @@ class RegistrationWorkerStore(SQLBaseStore): including the keys `name`, `is_guest`, `device_id`, `token_id`. """ return self.runInteraction( - "get_user_by_access_token", - self._query_for_auth, - token + "get_user_by_access_token", self._query_for_auth, token + ) + + @cachedInlineCallbacks() + def get_expiration_ts_for_user(self, user_id): + """Get the expiration timestamp for the account bearing a given user ID. + + Args: + user_id (str): The ID of the user. + Returns: + defer.Deferred: None, if the account has no expiration timestamp, + otherwise int representation of the timestamp (as a number of + milliseconds since epoch). + """ + res = yield self._simple_select_one_onecol( + table="account_validity", + keyvalues={"user_id": user_id}, + retcol="expiration_ts_ms", + allow_none=True, + desc="get_expiration_ts_for_user", + ) + defer.returnValue(res) + + @defer.inlineCallbacks + def set_account_validity_for_user(self, user_id, expiration_ts, email_sent, + renewal_token=None): + """Updates the account validity properties of the given account, with the + given values. + + Args: + user_id (str): ID of the account to update properties for. + expiration_ts (int): New expiration date, as a timestamp in milliseconds + since epoch. + email_sent (bool): True means a renewal email has been sent for this + account and there's no need to send another one for the current validity + period. + renewal_token (str): Renewal token the user can use to extend the validity + of their account. Defaults to no token. + """ + def set_account_validity_for_user_txn(txn): + self._simple_update_txn( + txn=txn, + table="account_validity", + keyvalues={"user_id": user_id}, + updatevalues={ + "expiration_ts_ms": expiration_ts, + "email_sent": email_sent, + "renewal_token": renewal_token, + }, + ) + self._invalidate_cache_and_stream( + txn, self.get_expiration_ts_for_user, (user_id,), + ) + + yield self.runInteraction( + "set_account_validity_for_user", + set_account_validity_for_user_txn, + ) + + @defer.inlineCallbacks + def set_renewal_token_for_user(self, user_id, renewal_token): + """Defines a renewal token for a given user. + + Args: + user_id (str): ID of the user to set the renewal token for. + renewal_token (str): Random unique string that will be used to renew the + user's account. + + Raises: + StoreError: The provided token is already set for another user. + """ + yield self._simple_update_one( + table="account_validity", + keyvalues={"user_id": user_id}, + updatevalues={"renewal_token": renewal_token}, + desc="set_renewal_token_for_user", + ) + + @defer.inlineCallbacks + def get_user_from_renewal_token(self, renewal_token): + """Get a user ID from a renewal token. + + Args: + renewal_token (str): The renewal token to perform the lookup with. + + Returns: + defer.Deferred[str]: The ID of the user to which the token belongs. + """ + res = yield self._simple_select_one_onecol( + table="account_validity", + keyvalues={"renewal_token": renewal_token}, + retcol="user_id", + desc="get_user_from_renewal_token", + ) + + defer.returnValue(res) + + @defer.inlineCallbacks + def get_renewal_token_for_user(self, user_id): + """Get the renewal token associated with a given user ID. + + Args: + user_id (str): The user ID to lookup a token for. + + Returns: + defer.Deferred[str]: The renewal token associated with this user ID. + """ + res = yield self._simple_select_one_onecol( + table="account_validity", + keyvalues={"user_id": user_id}, + retcol="renewal_token", + desc="get_renewal_token_for_user", + ) + + defer.returnValue(res) + + @defer.inlineCallbacks + def get_users_expiring_soon(self): + """Selects users whose account will expire in the [now, now + renew_at] time + window (see configuration for account_validity for information on what renew_at + refers to). + + Returns: + Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]] + """ + def select_users_txn(txn, now_ms, renew_at): + sql = ( + "SELECT user_id, expiration_ts_ms FROM account_validity" + " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?" + ) + values = [False, now_ms, renew_at] + txn.execute(sql, values) + return self.cursor_to_dict(txn) + + res = yield self.runInteraction( + "get_users_expiring_soon", + select_users_txn, + self.clock.time_msec(), self.config.account_validity.renew_at, + ) + + defer.returnValue(res) + + @defer.inlineCallbacks + def set_renewal_mail_status(self, user_id, email_sent): + """Sets or unsets the flag that indicates whether a renewal email has been sent + to the user (and the user hasn't renewed their account yet). + + Args: + user_id (str): ID of the user to set/unset the flag for. + email_sent (bool): Flag which indicates whether a renewal email has been sent + to this user. + """ + yield self._simple_update_one( + table="account_validity", + keyvalues={"user_id": user_id}, + updatevalues={"email_sent": email_sent}, + desc="set_renewal_mail_status", ) @defer.inlineCallbacks @@ -143,10 +300,10 @@ class RegistrationWorkerStore(SQLBaseStore): """Gets users that match user_id case insensitively. Returns a mapping of user_id -> password_hash. """ + def f(txn): sql = ( - "SELECT name, password_hash FROM users" - " WHERE lower(name) = lower(?)" + "SELECT name, password_hash FROM users" " WHERE lower(name) = lower(?)" ) txn.execute(sql, (user_id,)) return dict(txn) @@ -156,6 +313,7 @@ class RegistrationWorkerStore(SQLBaseStore): @defer.inlineCallbacks def count_all_users(self): """Counts all users registered on the homeserver.""" + def _count_users(txn): txn.execute("SELECT COUNT(*) AS users FROM users") rows = self.cursor_to_dict(txn) @@ -173,6 +331,7 @@ class RegistrationWorkerStore(SQLBaseStore): 3) bridged users who registered on the homeserver in the past 24 hours """ + def _count_daily_user_type(txn): yesterday = int(self._clock.time()) - (60 * 60 * 24) @@ -193,15 +352,18 @@ class RegistrationWorkerStore(SQLBaseStore): for row in txn: results[row[0]] = row[1] return results + return self.runInteraction("count_daily_user_type", _count_daily_user_type) @defer.inlineCallbacks def count_nonbridged_users(self): def _count_users(txn): - txn.execute(""" + txn.execute( + """ SELECT COALESCE(COUNT(*), 0) FROM users WHERE appservice_id IS NULL - """) + """ + ) count, = txn.fetchone() return count @@ -220,6 +382,7 @@ class RegistrationWorkerStore(SQLBaseStore): avoid the case of ID 10000000 being pre-allocated, so us wasting the first (and shortest) many generated user IDs. """ + def _find_next_generated_user_id(txn): txn.execute("SELECT name FROM users") @@ -227,7 +390,7 @@ class RegistrationWorkerStore(SQLBaseStore): found = set() - for user_id, in txn: + for (user_id,) in txn: match = regex.search(user_id) if match: found.add(int(match.group(1))) @@ -235,20 +398,22 @@ class RegistrationWorkerStore(SQLBaseStore): if i not in found: return i - defer.returnValue((yield self.runInteraction( - "find_next_generated_user_id", - _find_next_generated_user_id - ))) + defer.returnValue( + ( + yield self.runInteraction( + "find_next_generated_user_id", _find_next_generated_user_id + ) + ) + ) @defer.inlineCallbacks def get_3pid_guest_access_token(self, medium, address): ret = yield self._simple_select_one( "threepid_guest_access_tokens", - { - "medium": medium, - "address": address - }, - ["guest_access_token"], True, 'get_3pid_guest_access_token' + {"medium": medium, "address": address}, + ["guest_access_token"], + True, + 'get_3pid_guest_access_token', ) if ret: defer.returnValue(ret["guest_access_token"]) @@ -266,8 +431,7 @@ class RegistrationWorkerStore(SQLBaseStore): Deferred[str|None]: user id or None if no user id/threepid mapping exists """ user_id = yield self.runInteraction( - "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, - medium, address + "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address ) defer.returnValue(user_id) @@ -285,11 +449,9 @@ class RegistrationWorkerStore(SQLBaseStore): ret = self._simple_select_one_txn( txn, "user_threepids", - { - "medium": medium, - "address": address - }, - ['user_id'], True + {"medium": medium, "address": address}, + ['user_id'], + True, ) if ret: return ret['user_id'] @@ -297,41 +459,110 @@ class RegistrationWorkerStore(SQLBaseStore): @defer.inlineCallbacks def user_add_threepid(self, user_id, medium, address, validated_at, added_at): - yield self._simple_upsert("user_threepids", { - "medium": medium, - "address": address, - }, { - "user_id": user_id, - "validated_at": validated_at, - "added_at": added_at, - }) + yield self._simple_upsert( + "user_threepids", + {"medium": medium, "address": address}, + {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, + ) @defer.inlineCallbacks def user_get_threepids(self, user_id): ret = yield self._simple_select_list( - "user_threepids", { - "user_id": user_id - }, + "user_threepids", + {"user_id": user_id}, ['medium', 'address', 'validated_at', 'added_at'], - 'user_get_threepids' + 'user_get_threepids', ) defer.returnValue(ret) def user_delete_threepid(self, user_id, medium, address): return self._simple_delete( "user_threepids", + keyvalues={"user_id": user_id, "medium": medium, "address": address}, + desc="user_delete_threepids", + ) + + def add_user_bound_threepid(self, user_id, medium, address, id_server): + """The server proxied a bind request to the given identity server on + behalf of the given user. We need to remember this in case the user + asks us to unbind the threepid. + + Args: + user_id (str) + medium (str) + address (str) + id_server (str) + + Returns: + Deferred + """ + # We need to use an upsert, in case they user had already bound the + # threepid + return self._simple_upsert( + table="user_threepid_id_server", keyvalues={ "user_id": user_id, "medium": medium, "address": address, + "id_server": id_server, }, - desc="user_delete_threepids", + values={}, + insertion_values={}, + desc="add_user_bound_threepid", + ) + + def remove_user_bound_threepid(self, user_id, medium, address, id_server): + """The server proxied an unbind request to the given identity server on + behalf of the given user, so we remove the mapping of threepid to + identity server. + + Args: + user_id (str) + medium (str) + address (str) + id_server (str) + + Returns: + Deferred + """ + return self._simple_delete( + table="user_threepid_id_server", + keyvalues={ + "user_id": user_id, + "medium": medium, + "address": address, + "id_server": id_server, + }, + desc="remove_user_bound_threepid", ) + def get_id_servers_user_bound(self, user_id, medium, address): + """Get the list of identity servers that the server proxied bind + requests to for given user and threepid + + Args: + user_id (str) + medium (str) + address (str) + + Returns: + Deferred[list[str]]: Resolves to a list of identity servers + """ + return self._simple_select_onecol( + table="user_threepid_id_server", + keyvalues={ + "user_id": user_id, + "medium": medium, + "address": address, + }, + retcol="id_server", + desc="get_id_servers_user_bound", + ) -class RegistrationStore(RegistrationWorkerStore, - background_updates.BackgroundUpdateStore): +class RegistrationStore( + RegistrationWorkerStore, background_updates.BackgroundUpdateStore +): def __init__(self, db_conn, hs): super(RegistrationStore, self).__init__(db_conn, hs) @@ -351,11 +582,17 @@ class RegistrationStore(RegistrationWorkerStore, columns=["creation_ts"], ) + self._account_validity = hs.config.account_validity + # we no longer use refresh tokens, but it's possible that some people # might have a background update queued to build this index. Just # clear the background update. self.register_noop_background_update("refresh_tokens_device_index") + self.register_background_update_handler( + "user_threepids_grandfather", self._bg_user_threepids_grandfather, + ) + @defer.inlineCallbacks def add_access_token_to_user(self, user_id, token, device_id=None): """Adds an access token for the given user. @@ -372,18 +609,22 @@ class RegistrationStore(RegistrationWorkerStore, yield self._simple_insert( "access_tokens", - { - "id": next_id, - "user_id": user_id, - "token": token, - "device_id": device_id, - }, + {"id": next_id, "user_id": user_id, "token": token, "device_id": device_id}, desc="add_access_token_to_user", ) - def register(self, user_id, token=None, password_hash=None, - was_guest=False, make_guest=False, appservice_id=None, - create_profile_with_displayname=None, admin=False, user_type=None): + def register( + self, + user_id, + token=None, + password_hash=None, + was_guest=False, + make_guest=False, + appservice_id=None, + create_profile_with_displayname=None, + admin=False, + user_type=None, + ): """Attempts to register an account. Args: @@ -417,7 +658,7 @@ class RegistrationStore(RegistrationWorkerStore, appservice_id, create_profile_with_displayname, admin, - user_type + user_type, ) def _register( @@ -447,10 +688,7 @@ class RegistrationStore(RegistrationWorkerStore, self._simple_select_one_txn( txn, "users", - keyvalues={ - "name": user_id, - "is_guest": 1, - }, + keyvalues={"name": user_id, "is_guest": 1}, retcols=("name",), allow_none=False, ) @@ -458,10 +696,7 @@ class RegistrationStore(RegistrationWorkerStore, self._simple_update_one_txn( txn, "users", - keyvalues={ - "name": user_id, - "is_guest": 1, - }, + keyvalues={"name": user_id, "is_guest": 1}, updatevalues={ "password_hash": password_hash, "upgrade_ts": now, @@ -469,7 +704,7 @@ class RegistrationStore(RegistrationWorkerStore, "appservice_id": appservice_id, "admin": 1 if admin else 0, "user_type": user_type, - } + }, ) else: self._simple_insert_txn( @@ -483,20 +718,31 @@ class RegistrationStore(RegistrationWorkerStore, "appservice_id": appservice_id, "admin": 1 if admin else 0, "user_type": user_type, - } + }, ) + except self.database_engine.module.IntegrityError: - raise StoreError( - 400, "User ID already taken.", errcode=Codes.USER_IN_USE + raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) + + if self._account_validity.enabled: + now_ms = self.clock.time_msec() + expiration_ts = now_ms + self._account_validity.period + self._simple_insert_txn( + txn, + "account_validity", + values={ + "user_id": user_id, + "expiration_ts_ms": expiration_ts, + "email_sent": False, + } ) if token: # it's possible for this to get a conflict, but only for a single user # since tokens are namespaced based on their user ID txn.execute( - "INSERT INTO access_tokens(id, user_id, token)" - " VALUES (?,?,?)", - (next_id, user_id, token,) + "INSERT INTO access_tokens(id, user_id, token)" " VALUES (?,?,?)", + (next_id, user_id, token), ) if create_profile_with_displayname: @@ -507,12 +753,10 @@ class RegistrationStore(RegistrationWorkerStore, # while everything else uses the full mxid. txn.execute( "INSERT INTO profiles(user_id, displayname) VALUES (?,?)", - (user_id_obj.localpart, create_profile_with_displayname) + (user_id_obj.localpart, create_profile_with_displayname), ) - self._invalidate_cache_and_stream( - txn, self.get_user_by_id, (user_id,) - ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,)) def user_set_password_hash(self, user_id, password_hash): @@ -521,22 +765,14 @@ class RegistrationStore(RegistrationWorkerStore, removes most of the entries subsequently anyway so it would be pointless. Use flush_user separately. """ + def user_set_password_hash_txn(txn): self._simple_update_one_txn( - txn, - 'users', { - 'name': user_id - }, - { - 'password_hash': password_hash - } - ) - self._invalidate_cache_and_stream( - txn, self.get_user_by_id, (user_id,) + txn, 'users', {'name': user_id}, {'password_hash': password_hash} ) - return self.runInteraction( - "user_set_password_hash", user_set_password_hash_txn - ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + + return self.runInteraction("user_set_password_hash", user_set_password_hash_txn) def user_set_consent_version(self, user_id, consent_version): """Updates the user table to record privacy policy consent @@ -549,16 +785,16 @@ class RegistrationStore(RegistrationWorkerStore, Raises: StoreError(404) if user not found """ + def f(txn): self._simple_update_one_txn( txn, table='users', - keyvalues={'name': user_id, }, - updatevalues={'consent_version': consent_version, }, - ) - self._invalidate_cache_and_stream( - txn, self.get_user_by_id, (user_id,) + keyvalues={'name': user_id}, + updatevalues={'consent_version': consent_version}, ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + return self.runInteraction("user_set_consent_version", f) def user_set_consent_server_notice_sent(self, user_id, consent_version): @@ -573,20 +809,19 @@ class RegistrationStore(RegistrationWorkerStore, Raises: StoreError(404) if user not found """ + def f(txn): self._simple_update_one_txn( txn, table='users', - keyvalues={'name': user_id, }, - updatevalues={'consent_server_notice_sent': consent_version, }, - ) - self._invalidate_cache_and_stream( - txn, self.get_user_by_id, (user_id,) + keyvalues={'name': user_id}, + updatevalues={'consent_server_notice_sent': consent_version}, ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + return self.runInteraction("user_set_consent_server_notice_sent", f) - def user_delete_access_tokens(self, user_id, except_token_id=None, - device_id=None): + def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None): """ Invalidate access tokens belonging to a user @@ -601,10 +836,9 @@ class RegistrationStore(RegistrationWorkerStore, defer.Deferred[list[str, int, str|None, int]]: a list of (token, token id, device id) for each of the deleted tokens """ + def f(txn): - keyvalues = { - "user_id": user_id, - } + keyvalues = {"user_id": user_id} if device_id is not None: keyvalues["device_id"] = device_id @@ -616,8 +850,9 @@ class RegistrationStore(RegistrationWorkerStore, values.append(except_token_id) txn.execute( - "SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause, - values + "SELECT token, id, device_id FROM access_tokens WHERE %s" + % where_clause, + values, ) tokens_and_devices = [(r[0], r[1], r[2]) for r in txn] @@ -626,25 +861,16 @@ class RegistrationStore(RegistrationWorkerStore, txn, self.get_user_by_access_token, (token,) ) - txn.execute( - "DELETE FROM access_tokens WHERE %s" % where_clause, - values - ) + txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values) return tokens_and_devices - return self.runInteraction( - "user_delete_access_tokens", f, - ) + return self.runInteraction("user_delete_access_tokens", f) def delete_access_token(self, access_token): def f(txn): self._simple_delete_one_txn( - txn, - table="access_tokens", - keyvalues={ - "token": access_token - }, + txn, table="access_tokens", keyvalues={"token": access_token} ) self._invalidate_cache_and_stream( @@ -667,7 +893,7 @@ class RegistrationStore(RegistrationWorkerStore, @defer.inlineCallbacks def save_or_get_3pid_guest_access_token( - self, medium, address, access_token, inviter_user_id + self, medium, address, access_token, inviter_user_id ): """ Gets the 3pid's guest access token if exists, else saves access_token. @@ -683,12 +909,13 @@ class RegistrationStore(RegistrationWorkerStore, deferred str: Whichever access token is persisted at the end of this function call. """ + def insert(txn): txn.execute( "INSERT INTO threepid_guest_access_tokens " "(medium, address, guest_access_token, first_inviter) " "VALUES (?, ?, ?, ?)", - (medium, address, access_token, inviter_user_id) + (medium, address, access_token, inviter_user_id), ) try: @@ -705,9 +932,7 @@ class RegistrationStore(RegistrationWorkerStore, """ return self._simple_insert( "users_pending_deactivation", - values={ - "user_id": user_id, - }, + values={"user_id": user_id}, desc="add_user_pending_deactivation", ) @@ -720,9 +945,7 @@ class RegistrationStore(RegistrationWorkerStore, # the table, so somehow duplicate entries have ended up in it. return self._simple_delete( "users_pending_deactivation", - keyvalues={ - "user_id": user_id, - }, + keyvalues={"user_id": user_id}, desc="del_user_pending_deactivation", ) @@ -738,3 +961,34 @@ class RegistrationStore(RegistrationWorkerStore, allow_none=True, desc="get_users_pending_deactivation", ) + + @defer.inlineCallbacks + def _bg_user_threepids_grandfather(self, progress, batch_size): + """We now track which identity servers a user binds their 3PID to, so + we need to handle the case of existing bindings where we didn't track + this. + + We do this by grandfathering in existing user threepids assuming that + they used one of the server configured trusted identity servers. + """ + + id_servers = set(self.config.trusted_third_party_id_servers) + + def _bg_user_threepids_grandfather_txn(txn): + sql = """ + INSERT INTO user_threepid_id_server + (user_id, medium, address, id_server) + SELECT user_id, medium, address, ? + FROM user_threepids + """ + + txn.executemany(sql, [(id_server,) for id_server in id_servers]) + + if id_servers: + yield self.runInteraction( + "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn, + ) + + yield self._end_background_update("user_threepids_grandfather") + + defer.returnValue(1) diff --git a/synapse/storage/rejections.py b/synapse/storage/rejections.py index 880f047adb..f4c1c2a457 100644 --- a/synapse/storage/rejections.py +++ b/synapse/storage/rejections.py @@ -36,9 +36,7 @@ class RejectionsStore(SQLBaseStore): return self._simple_select_one_onecol( table="rejections", retcol="reason", - keyvalues={ - "event_id": event_id, - }, + keyvalues={"event_id": event_id}, allow_none=True, desc="get_rejection_reason", ) diff --git a/synapse/storage/room.py b/synapse/storage/room.py index a979d4860a..fe9d79d792 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -30,13 +30,11 @@ logger = logging.getLogger(__name__) OpsLevel = collections.namedtuple( - "OpsLevel", - ("ban_level", "kick_level", "redact_level",) + "OpsLevel", ("ban_level", "kick_level", "redact_level") ) RatelimitOverride = collections.namedtuple( - "RatelimitOverride", - ("messages_per_second", "burst_count",) + "RatelimitOverride", ("messages_per_second", "burst_count") ) @@ -60,9 +58,7 @@ class RoomWorkerStore(SQLBaseStore): def get_public_room_ids(self): return self._simple_select_onecol( table="rooms", - keyvalues={ - "is_public": True, - }, + keyvalues={"is_public": True}, retcol="room_id", desc="get_public_room_ids", ) @@ -79,11 +75,11 @@ class RoomWorkerStore(SQLBaseStore): return self.runInteraction( "get_public_room_ids_at_stream_id", self.get_public_room_ids_at_stream_id_txn, - stream_id, network_tuple=network_tuple + stream_id, + network_tuple=network_tuple, ) - def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, - network_tuple): + def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, network_tuple): return { rm for rm, vis in self.get_published_at_stream_id_txn( @@ -96,7 +92,7 @@ class RoomWorkerStore(SQLBaseStore): if network_tuple: # We want to get from a particular list. No aggregation required. - sql = (""" + sql = """ SELECT room_id, visibility FROM public_room_list_stream INNER JOIN ( SELECT room_id, max(stream_id) AS stream_id @@ -104,25 +100,22 @@ class RoomWorkerStore(SQLBaseStore): WHERE stream_id <= ? %s GROUP BY room_id ) grouped USING (room_id, stream_id) - """) + """ if network_tuple.appservice_id is not None: txn.execute( sql % ("AND appservice_id = ? AND network_id = ?",), - (stream_id, network_tuple.appservice_id, network_tuple.network_id,) + (stream_id, network_tuple.appservice_id, network_tuple.network_id), ) else: - txn.execute( - sql % ("AND appservice_id IS NULL",), - (stream_id,) - ) + txn.execute(sql % ("AND appservice_id IS NULL",), (stream_id,)) return dict(txn) else: # We want to get from all lists, so we need to aggregate the results logger.info("Executing full list") - sql = (""" + sql = """ SELECT room_id, visibility FROM public_room_list_stream INNER JOIN ( @@ -133,12 +126,9 @@ class RoomWorkerStore(SQLBaseStore): WHERE stream_id <= ? GROUP BY room_id, appservice_id, network_id ) grouped USING (room_id, stream_id) - """) + """ - txn.execute( - sql, - (stream_id,) - ) + txn.execute(sql, (stream_id,)) results = {} # A room is visible if its visible on any list. @@ -147,8 +137,7 @@ class RoomWorkerStore(SQLBaseStore): return results - def get_public_room_changes(self, prev_stream_id, new_stream_id, - network_tuple): + def get_public_room_changes(self, prev_stream_id, new_stream_id, network_tuple): def get_public_room_changes_txn(txn): then_rooms = self.get_public_room_ids_at_stream_id_txn( txn, prev_stream_id, network_tuple @@ -158,9 +147,7 @@ class RoomWorkerStore(SQLBaseStore): txn, new_stream_id, network_tuple ) - now_rooms_visible = set( - rm for rm, vis in now_rooms_dict.items() if vis - ) + now_rooms_visible = set(rm for rm, vis in now_rooms_dict.items() if vis) now_rooms_not_visible = set( rm for rm, vis in now_rooms_dict.items() if not vis ) @@ -178,9 +165,7 @@ class RoomWorkerStore(SQLBaseStore): def is_room_blocked(self, room_id): return self._simple_select_one_onecol( table="blocked_rooms", - keyvalues={ - "room_id": room_id, - }, + keyvalues={"room_id": room_id}, retcol="1", allow_none=True, desc="is_room_blocked", @@ -208,16 +193,17 @@ class RoomWorkerStore(SQLBaseStore): ) if row: - defer.returnValue(RatelimitOverride( - messages_per_second=row["messages_per_second"], - burst_count=row["burst_count"], - )) + defer.returnValue( + RatelimitOverride( + messages_per_second=row["messages_per_second"], + burst_count=row["burst_count"], + ) + ) else: defer.returnValue(None) class RoomStore(RoomWorkerStore, SearchStore): - @defer.inlineCallbacks def store_room(self, room_id, room_creator_user_id, is_public): """Stores a room. @@ -231,6 +217,7 @@ class RoomStore(RoomWorkerStore, SearchStore): StoreError if the room could not be stored. """ try: + def store_room_txn(txn, next_id): self._simple_insert_txn( txn, @@ -249,13 +236,11 @@ class RoomStore(RoomWorkerStore, SearchStore): "stream_id": next_id, "room_id": room_id, "visibility": is_public, - } + }, ) + with self._public_room_id_gen.get_next() as next_id: - yield self.runInteraction( - "store_room_txn", - store_room_txn, next_id, - ) + yield self.runInteraction("store_room_txn", store_room_txn, next_id) except Exception as e: logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") @@ -297,19 +282,19 @@ class RoomStore(RoomWorkerStore, SearchStore): "visibility": is_public, "appservice_id": None, "network_id": None, - } + }, ) with self._public_room_id_gen.get_next() as next_id: yield self.runInteraction( - "set_room_is_public", - set_room_is_public_txn, next_id, + "set_room_is_public", set_room_is_public_txn, next_id ) self.hs.get_notifier().on_new_replication_data() @defer.inlineCallbacks - def set_room_is_public_appservice(self, room_id, appservice_id, network_id, - is_public): + def set_room_is_public_appservice( + self, room_id, appservice_id, network_id, is_public + ): """Edit the appservice/network specific public room list. Each appservice can have a number of published room lists associated @@ -324,6 +309,7 @@ class RoomStore(RoomWorkerStore, SearchStore): is_public (bool): Whether to publish or unpublish the room from the list. """ + def set_room_is_public_appservice_txn(txn, next_id): if is_public: try: @@ -333,7 +319,7 @@ class RoomStore(RoomWorkerStore, SearchStore): values={ "appservice_id": appservice_id, "network_id": network_id, - "room_id": room_id + "room_id": room_id, }, ) except self.database_engine.module.IntegrityError: @@ -346,7 +332,7 @@ class RoomStore(RoomWorkerStore, SearchStore): keyvalues={ "appservice_id": appservice_id, "network_id": network_id, - "room_id": room_id + "room_id": room_id, }, ) @@ -377,13 +363,14 @@ class RoomStore(RoomWorkerStore, SearchStore): "visibility": is_public, "appservice_id": appservice_id, "network_id": network_id, - } + }, ) with self._public_room_id_gen.get_next() as next_id: yield self.runInteraction( "set_room_is_public_appservice", - set_room_is_public_appservice_txn, next_id, + set_room_is_public_appservice_txn, + next_id, ) self.hs.get_notifier().on_new_replication_data() @@ -397,9 +384,7 @@ class RoomStore(RoomWorkerStore, SearchStore): row = txn.fetchone() return row[0] or 0 - return self.runInteraction( - "get_rooms", f - ) + return self.runInteraction("get_rooms", f) def _store_room_topic_txn(self, txn, event): if hasattr(event, "content") and "topic" in event.content: @@ -414,7 +399,7 @@ class RoomStore(RoomWorkerStore, SearchStore): ) self.store_event_search_txn( - txn, event, "content.topic", event.content["topic"], + txn, event, "content.topic", event.content["topic"] ) def _store_room_name_txn(self, txn, event): @@ -426,17 +411,17 @@ class RoomStore(RoomWorkerStore, SearchStore): "event_id": event.event_id, "room_id": event.room_id, "name": event.content["name"], - } + }, ) self.store_event_search_txn( - txn, event, "content.name", event.content["name"], + txn, event, "content.name", event.content["name"] ) def _store_room_message_txn(self, txn, event): if hasattr(event, "content") and "body" in event.content: self.store_event_search_txn( - txn, event, "content.body", event.content["body"], + txn, event, "content.body", event.content["body"] ) def _store_history_visibility_txn(self, txn, event): @@ -452,14 +437,11 @@ class RoomStore(RoomWorkerStore, SearchStore): " (event_id, room_id, %(key)s)" " VALUES (?, ?, ?)" % {"key": key} ) - txn.execute(sql, ( - event.event_id, - event.room_id, - event.content[key] - )) - - def add_event_report(self, room_id, event_id, user_id, reason, content, - received_ts): + txn.execute(sql, (event.event_id, event.room_id, event.content[key])) + + def add_event_report( + self, room_id, event_id, user_id, reason, content, received_ts + ): next_id = self._event_reports_id_gen.get_next() return self._simple_insert( table="event_reports", @@ -472,7 +454,7 @@ class RoomStore(RoomWorkerStore, SearchStore): "reason": reason, "content": json.dumps(content), }, - desc="add_event_report" + desc="add_event_report", ) def get_current_public_room_stream_id(self): @@ -480,23 +462,21 @@ class RoomStore(RoomWorkerStore, SearchStore): def get_all_new_public_rooms(self, prev_id, current_id, limit): def get_all_new_public_rooms(txn): - sql = (""" + sql = """ SELECT stream_id, room_id, visibility, appservice_id, network_id FROM public_room_list_stream WHERE stream_id > ? AND stream_id <= ? ORDER BY stream_id ASC LIMIT ? - """) + """ - txn.execute(sql, (prev_id, current_id, limit,)) + txn.execute(sql, (prev_id, current_id, limit)) return txn.fetchall() if prev_id == current_id: return defer.succeed([]) - return self.runInteraction( - "get_all_new_public_rooms", get_all_new_public_rooms - ) + return self.runInteraction("get_all_new_public_rooms", get_all_new_public_rooms) @defer.inlineCallbacks def block_room(self, room_id, user_id): @@ -511,19 +491,16 @@ class RoomStore(RoomWorkerStore, SearchStore): """ yield self._simple_upsert( table="blocked_rooms", - keyvalues={ - "room_id": room_id, - }, + keyvalues={"room_id": room_id}, values={}, - insertion_values={ - "user_id": user_id, - }, + insertion_values={"user_id": user_id}, desc="block_room", ) yield self.runInteraction( "block_room_invalidation", self._invalidate_cache_and_stream, - self.is_room_blocked, (room_id,), + self.is_room_blocked, + (room_id,), ) def get_media_mxcs_in_room(self, room_id): @@ -536,6 +513,7 @@ class RoomStore(RoomWorkerStore, SearchStore): The local and remote media as a lists of tuples where the key is the hostname and the value is the media ID. """ + def _get_media_mxcs_in_room_txn(txn): local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) local_media_mxcs = [] @@ -548,23 +526,28 @@ class RoomStore(RoomWorkerStore, SearchStore): remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id)) return local_media_mxcs, remote_media_mxcs + return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn) def quarantine_media_ids_in_room(self, room_id, quarantined_by): """For a room loops through all events with media and quarantines the associated media """ + def _quarantine_media_in_room_txn(txn): local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) total_media_quarantined = 0 # Now update all the tables to set the quarantined_by flag - txn.executemany(""" + txn.executemany( + """ UPDATE local_media_repository SET quarantined_by = ? WHERE media_id = ? - """, ((quarantined_by, media_id) for media_id in local_mxcs)) + """, + ((quarantined_by, media_id) for media_id in local_mxcs), + ) txn.executemany( """ @@ -575,7 +558,7 @@ class RoomStore(RoomWorkerStore, SearchStore): ( (quarantined_by, origin, media_id) for origin, media_id in remote_mxcs - ) + ), ) total_media_quarantined += len(local_mxcs) @@ -584,8 +567,7 @@ class RoomStore(RoomWorkerStore, SearchStore): return total_media_quarantined return self.runInteraction( - "quarantine_media_in_room", - _quarantine_media_in_room_txn, + "quarantine_media_in_room", _quarantine_media_in_room_txn ) def _get_media_mxcs_in_room_txn(self, txn, room_id): diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 592c1bcd33..57df17bcc2 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -35,28 +35,22 @@ logger = logging.getLogger(__name__) RoomsForUser = namedtuple( - "RoomsForUser", - ("room_id", "sender", "membership", "event_id", "stream_ordering") + "RoomsForUser", ("room_id", "sender", "membership", "event_id", "stream_ordering") ) GetRoomsForUserWithStreamOrdering = namedtuple( - "_GetRoomsForUserWithStreamOrdering", - ("room_id", "stream_ordering",) + "_GetRoomsForUserWithStreamOrdering", ("room_id", "stream_ordering") ) # We store this using a namedtuple so that we save about 3x space over using a # dict. -ProfileInfo = namedtuple( - "ProfileInfo", ("avatar_url", "display_name") -) +ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name")) # "members" points to a truncated list of (user_id, event_id) tuples for users of # a given membership type, suitable for use in calculating heroes for a room. # "count" points to the total numberr of users of a given membership type. -MemberSummary = namedtuple( - "MemberSummary", ("members", "count") -) +MemberSummary = namedtuple("MemberSummary", ("members", "count")) _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" @@ -67,7 +61,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """Returns the set of all hosts currently in the room """ user_ids = yield self.get_users_in_room( - room_id, on_invalidate=cache_context.invalidate, + room_id, on_invalidate=cache_context.invalidate ) hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids) defer.returnValue(hosts) @@ -84,8 +78,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?" ) - txn.execute(sql, (room_id, Membership.JOIN,)) + txn.execute(sql, (room_id, Membership.JOIN)) return [to_ascii(r[0]) for r in txn] + return self.runInteraction("get_users_in_room", f) @cached(max_entries=100000) @@ -156,9 +151,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): A deferred list of RoomsForUser. """ - return self.get_rooms_for_user_where_membership_is( - user_id, [Membership.INVITE] - ) + return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE]) @defer.inlineCallbacks def get_invite_for_user_in_room(self, user_id, room_id): @@ -196,11 +189,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): return self.runInteraction( "get_rooms_for_user_where_membership_is", self._get_rooms_for_user_where_membership_is_txn, - user_id, membership_list + user_id, + membership_list, ) - def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id, - membership_list): + def _get_rooms_for_user_where_membership_is_txn( + self, txn, user_id, membership_list + ): do_invite = Membership.INVITE in membership_list membership_list = [m for m in membership_list if m != Membership.INVITE] @@ -227,9 +222,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) % (where_clause,) txn.execute(sql, args) - results = [ - RoomsForUser(**r) for r in self.cursor_to_dict(txn) - ] + results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)] if do_invite: sql = ( @@ -241,13 +234,16 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) txn.execute(sql, (user_id,)) - results.extend(RoomsForUser( - room_id=r["room_id"], - sender=r["inviter"], - event_id=r["event_id"], - stream_ordering=r["stream_ordering"], - membership=Membership.INVITE, - ) for r in self.cursor_to_dict(txn)) + results.extend( + RoomsForUser( + room_id=r["room_id"], + sender=r["inviter"], + event_id=r["event_id"], + stream_ordering=r["stream_ordering"], + membership=Membership.INVITE, + ) + for r in self.cursor_to_dict(txn) + ) return results @@ -264,19 +260,21 @@ class RoomMemberWorkerStore(EventsWorkerStore): of the most recent join for that user and room. """ rooms = yield self.get_rooms_for_user_where_membership_is( - user_id, membership_list=[Membership.JOIN], + user_id, membership_list=[Membership.JOIN] + ) + defer.returnValue( + frozenset( + GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering) + for r in rooms + ) ) - defer.returnValue(frozenset( - GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering) - for r in rooms - )) @defer.inlineCallbacks def get_rooms_for_user(self, user_id, on_invalidate=None): """Returns a set of room_ids the user is currently joined to """ rooms = yield self.get_rooms_for_user_with_stream_ordering( - user_id, on_invalidate=on_invalidate, + user_id, on_invalidate=on_invalidate ) defer.returnValue(frozenset(r.room_id for r in rooms)) @@ -285,13 +283,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): """Returns the set of users who share a room with `user_id` """ room_ids = yield self.get_rooms_for_user( - user_id, on_invalidate=cache_context.invalidate, + user_id, on_invalidate=cache_context.invalidate ) user_who_share_room = set() for room_id in room_ids: user_ids = yield self.get_users_in_room( - room_id, on_invalidate=cache_context.invalidate, + room_id, on_invalidate=cache_context.invalidate ) user_who_share_room.update(user_ids) @@ -309,9 +307,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): current_state_ids = yield context.get_current_state_ids(self) result = yield self._get_joined_users_from_context( - event.room_id, state_group, current_state_ids, - event=event, - context=context, + event.room_id, state_group, current_state_ids, event=event, context=context ) defer.returnValue(result) @@ -325,13 +321,21 @@ class RoomMemberWorkerStore(EventsWorkerStore): state_group = object() return self._get_joined_users_from_context( - room_id, state_group, state_entry.state, context=state_entry, + room_id, state_group, state_entry.state, context=state_entry ) - @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, - max_entries=100000) - def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, - cache_context, event=None, context=None): + @cachedInlineCallbacks( + num_args=2, cache_context=True, iterable=True, max_entries=100000 + ) + def _get_joined_users_from_context( + self, + room_id, + state_group, + current_state_ids, + cache_context, + event=None, + context=None, + ): # We don't use `state_group`, it's there so that we can cache based # on it. However, it's important that it's never None, since two current_states # with a state_group of None are likely to be different. @@ -371,9 +375,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # the hit ratio counts. After all, we don't populate the cache if we # miss it here event_map = self._get_events_from_cache( - member_event_ids, - allow_rejected=False, - update_metrics=False, + member_event_ids, allow_rejected=False, update_metrics=False ) missing_member_event_ids = [] @@ -397,21 +399,21 @@ class RoomMemberWorkerStore(EventsWorkerStore): table="room_memberships", column="event_id", iterable=missing_member_event_ids, - retcols=('user_id', 'display_name', 'avatar_url',), - keyvalues={ - "membership": Membership.JOIN, - }, + retcols=('user_id', 'display_name', 'avatar_url'), + keyvalues={"membership": Membership.JOIN}, batch_size=500, desc="_get_joined_users_from_context", ) - users_in_room.update({ - to_ascii(row["user_id"]): ProfileInfo( - avatar_url=to_ascii(row["avatar_url"]), - display_name=to_ascii(row["display_name"]), - ) - for row in rows - }) + users_in_room.update( + { + to_ascii(row["user_id"]): ProfileInfo( + avatar_url=to_ascii(row["avatar_url"]), + display_name=to_ascii(row["display_name"]), + ) + for row in rows + } + ) if event is not None and event.type == EventTypes.Member: if event.membership == Membership.JOIN: @@ -505,7 +507,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): state_group = object() return self._get_joined_hosts( - room_id, state_group, state_entry.state, state_entry=state_entry, + room_id, state_group, state_entry.state, state_entry=state_entry ) @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) @@ -531,6 +533,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """Returns whether user_id has elected to discard history for room_id. Returns False if they have since re-joined.""" + def f(txn): sql = ( "SELECT" @@ -547,6 +550,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(sql, (user_id, room_id)) rows = txn.fetchall() return rows[0][0] + count = yield self.runInteraction("did_forget_membership", f) defer.returnValue(count == 0) @@ -575,13 +579,14 @@ class RoomMemberStore(RoomMemberWorkerStore): "avatar_url": event.content.get("avatar_url", None), } for event in events - ] + ], ) for event in events: txn.call_after( self._membership_stream_cache.entity_has_changed, - event.state_key, event.internal_metadata.stream_ordering + event.state_key, + event.internal_metadata.stream_ordering, ) txn.call_after( self.get_invited_rooms_for_user.invalidate, (event.state_key,) @@ -607,7 +612,7 @@ class RoomMemberStore(RoomMemberWorkerStore): "inviter": event.sender, "room_id": event.room_id, "stream_id": event.internal_metadata.stream_ordering, - } + }, ) else: sql = ( @@ -616,12 +621,15 @@ class RoomMemberStore(RoomMemberWorkerStore): " AND replaced_by is NULL" ) - txn.execute(sql, ( - event.internal_metadata.stream_ordering, - event.event_id, - event.room_id, - event.state_key, - )) + txn.execute( + sql, + ( + event.internal_metadata.stream_ordering, + event.event_id, + event.room_id, + event.state_key, + ), + ) @defer.inlineCallbacks def locally_reject_invite(self, user_id, room_id): @@ -632,18 +640,14 @@ class RoomMemberStore(RoomMemberWorkerStore): ) def f(txn, stream_ordering): - txn.execute(sql, ( - stream_ordering, - True, - room_id, - user_id, - )) + txn.execute(sql, (stream_ordering, True, room_id, user_id)) with self._stream_id_gen.get_next() as stream_ordering: yield self.runInteraction("locally_reject_invite", f, stream_ordering) def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" + def f(txn): sql = ( "UPDATE" @@ -657,9 +661,8 @@ class RoomMemberStore(RoomMemberWorkerStore): ) txn.execute(sql, (user_id, room_id)) - self._invalidate_cache_and_stream( - txn, self.did_forget, (user_id, room_id,), - ) + self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id)) + return self.runInteraction("forget_membership", f) @defer.inlineCallbacks @@ -674,7 +677,7 @@ class RoomMemberStore(RoomMemberWorkerStore): INSERT_CLUMP_SIZE = 1000 def add_membership_profile_txn(txn): - sql = (""" + sql = """ SELECT stream_ordering, event_id, events.room_id, event_json.json FROM events INNER JOIN event_json USING (event_id) @@ -683,7 +686,7 @@ class RoomMemberStore(RoomMemberWorkerStore): AND type = 'm.room.member' ORDER BY stream_ordering DESC LIMIT ? - """) + """ txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) @@ -707,16 +710,14 @@ class RoomMemberStore(RoomMemberWorkerStore): avatar_url = content.get("avatar_url", None) if display_name or avatar_url: - to_update.append(( - display_name, avatar_url, event_id, room_id - )) + to_update.append((display_name, avatar_url, event_id, room_id)) - to_update_sql = (""" + to_update_sql = """ UPDATE room_memberships SET display_name = ?, avatar_url = ? WHERE event_id = ? AND room_id = ? - """) + """ for index in range(0, len(to_update), INSERT_CLUMP_SIZE): - clump = to_update[index:index + INSERT_CLUMP_SIZE] + clump = to_update[index : index + INSERT_CLUMP_SIZE] txn.executemany(to_update_sql, clump) progress = { @@ -789,7 +790,7 @@ class _JoinedHostsCache(object): self.hosts_to_joined_users.pop(host, None) else: joined_users = yield self.store.get_joined_users_from_state( - self.room_id, state_entry, + self.room_id, state_entry ) self.hosts_to_joined_users = {} diff --git a/synapse/storage/schema/delta/13/v13.sql b/synapse/storage/schema/delta/13/v13.sql index 5eb93b38b2..f8649e5d99 100644 --- a/synapse/storage/schema/delta/13/v13.sql +++ b/synapse/storage/schema/delta/13/v13.sql @@ -13,19 +13,7 @@ * limitations under the License. */ -CREATE TABLE IF NOT EXISTS application_services( - id INTEGER PRIMARY KEY AUTOINCREMENT, - url TEXT, - token TEXT, - hs_token TEXT, - sender TEXT, - UNIQUE(token) -); - -CREATE TABLE IF NOT EXISTS application_services_regex( - id INTEGER PRIMARY KEY AUTOINCREMENT, - as_id BIGINT UNSIGNED NOT NULL, - namespace INTEGER, /* enum[room_id|room_alias|user_id] */ - regex TEXT, - FOREIGN KEY(as_id) REFERENCES application_services(id) -); +/* We used to create a tables called application_services and + * application_services_regex, but these are no longer used and are removed in + * delta 54. + */ diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py deleted file mode 100644 index 4d725b92fe..0000000000 --- a/synapse/storage/schema/delta/14/upgrade_appservice_db.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -import simplejson as json - -logger = logging.getLogger(__name__) - - -def run_create(cur, *args, **kwargs): - cur.execute("SELECT id, regex FROM application_services_regex") - for row in cur.fetchall(): - try: - logger.debug("Checking %s..." % row[0]) - json.loads(row[1]) - except ValueError: - # row isn't in json, make it so. - string_regex = row[1] - new_regex = json.dumps({ - "regex": string_regex, - "exclusive": True - }) - cur.execute( - "UPDATE application_services_regex SET regex=? WHERE id=?", - (new_regex, row[0]) - ) - - -def run_upgrade(*args, **kwargs): - pass diff --git a/synapse/storage/schema/delta/16/unique_constraints.sql b/synapse/storage/schema/delta/16/unique_constraints.sql index fecf11118c..5b8de52c33 100644 --- a/synapse/storage/schema/delta/16/unique_constraints.sql +++ b/synapse/storage/schema/delta/16/unique_constraints.sql @@ -18,14 +18,6 @@ DROP INDEX IF EXISTS room_memberships_event_id; CREATE UNIQUE INDEX room_memberships_event_id ON room_memberships(event_id); -- -DELETE FROM feedback WHERE rowid not in ( - SELECT MIN(rowid) FROM feedback GROUP BY event_id -); - -DROP INDEX IF EXISTS feedback_event_id; -CREATE UNIQUE INDEX feedback_event_id ON feedback(event_id); - --- DELETE FROM topics WHERE rowid not in ( SELECT MIN(rowid) FROM topics GROUP BY event_id ); diff --git a/synapse/storage/schema/delta/24/stats_reporting.sql b/synapse/storage/schema/delta/24/stats_reporting.sql index 5f508af7a9..acea7483bd 100644 --- a/synapse/storage/schema/delta/24/stats_reporting.sql +++ b/synapse/storage/schema/delta/24/stats_reporting.sql @@ -1,4 +1,4 @@ -/* Copyright 2015, 2016 OpenMarket Ltd +/* Copyright 2019 New Vector Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,10 +13,6 @@ * limitations under the License. */ --- Should only ever contain one row -CREATE TABLE IF NOT EXISTS stats_reporting( - -- The stream ordering token which was most recently reported as stats - reported_stream_token INTEGER, - -- The time (seconds since epoch) stats were most recently reported - reported_time BIGINT -); + /* We used to create a table called stats_reporting, but this is no longer + * used and is removed in delta 54. + */ \ No newline at end of file diff --git a/synapse/storage/schema/delta/30/state_stream.sql b/synapse/storage/schema/delta/30/state_stream.sql index 706fe1dcf4..e85699e82e 100644 --- a/synapse/storage/schema/delta/30/state_stream.sql +++ b/synapse/storage/schema/delta/30/state_stream.sql @@ -14,15 +14,10 @@ */ -/** - * The positions in the event stream_ordering when the current_state was - * replaced by the state at the event. +/* We used to create a table called current_state_resets, but this is no + * longer used and is removed in delta 54. */ -CREATE TABLE IF NOT EXISTS current_state_resets( - event_stream_ordering BIGINT PRIMARY KEY NOT NULL -); - /* The outlier events that have aquired a state group typically through * backfill. This is tracked separately to the events table, as assigning a * state group change the position of the existing event in the stream diff --git a/synapse/storage/schema/delta/32/remove_indices.sql b/synapse/storage/schema/delta/32/remove_indices.sql index f859be46a6..4219cdd06a 100644 --- a/synapse/storage/schema/delta/32/remove_indices.sql +++ b/synapse/storage/schema/delta/32/remove_indices.sql @@ -24,13 +24,9 @@ DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY DROP INDEX IF EXISTS event_to_state_groups_id; -- Duplicate of PRIMARY KEY DROP INDEX IF EXISTS event_push_actions_room_id_event_id_user_id_profile_tag; -- Duplicate of UNIQUE CONSTRAINT -DROP INDEX IF EXISTS event_destinations_id; -- Prefix of UNIQUE CONSTRAINT DROP INDEX IF EXISTS st_extrem_id; -- Prefix of UNIQUE CONSTRAINT -DROP INDEX IF EXISTS event_content_hashes_id; -- Prefix of UNIQUE CONSTRAINT DROP INDEX IF EXISTS event_signatures_id; -- Prefix of UNIQUE CONSTRAINT -DROP INDEX IF EXISTS event_edge_hashes_id; -- Prefix of UNIQUE CONSTRAINT DROP INDEX IF EXISTS redactions_event_id; -- Duplicate of UNIQUE CONSTRAINT -DROP INDEX IF EXISTS room_hosts_room_id; -- Prefix of UNIQUE CONSTRAINT -- The following indices were unused DROP INDEX IF EXISTS remote_media_cache_thumbnails_media_id; diff --git a/synapse/storage/schema/full_schemas/11/room_aliases.sql b/synapse/storage/schema/delta/53/user_threepid_id.sql index 71a91f8ec9..80c2c573b6 100644 --- a/synapse/storage/schema/full_schemas/11/room_aliases.sql +++ b/synapse/storage/schema/delta/53/user_threepid_id.sql @@ -1,4 +1,4 @@ -/* Copyright 2014-2016 OpenMarket Ltd +/* Copyright 2019 New Vector Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,12 +13,17 @@ * limitations under the License. */ -CREATE TABLE IF NOT EXISTS room_aliases( - room_alias TEXT NOT NULL, - room_id TEXT NOT NULL +-- Tracks which identity server a user bound their threepid via. +CREATE TABLE user_threepid_id_server ( + user_id TEXT NOT NULL, + medium TEXT NOT NULL, + address TEXT NOT NULL, + id_server TEXT NOT NULL ); -CREATE TABLE IF NOT EXISTS room_alias_servers( - room_alias TEXT NOT NULL, - server TEXT NOT NULL +CREATE UNIQUE INDEX user_threepid_id_server_idx ON user_threepid_id_server( + user_id, medium, address, id_server ); + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('user_threepids_grandfather', '{}'); diff --git a/synapse/storage/schema/delta/54/account_validity.sql b/synapse/storage/schema/delta/54/account_validity.sql new file mode 100644 index 0000000000..2357626000 --- /dev/null +++ b/synapse/storage/schema/delta/54/account_validity.sql @@ -0,0 +1,27 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +DROP TABLE IF EXISTS account_validity; + +-- Track what users are in public rooms. +CREATE TABLE IF NOT EXISTS account_validity ( + user_id TEXT PRIMARY KEY, + expiration_ts_ms BIGINT NOT NULL, + email_sent BOOLEAN NOT NULL, + renewal_token TEXT +); + +CREATE INDEX account_validity_email_sent_idx ON account_validity(email_sent, expiration_ts_ms) +CREATE UNIQUE INDEX account_validity_renewal_string_idx ON account_validity(renewal_token) diff --git a/synapse/storage/schema/delta/54/drop_legacy_tables.sql b/synapse/storage/schema/delta/54/drop_legacy_tables.sql new file mode 100644 index 0000000000..dbbe682697 --- /dev/null +++ b/synapse/storage/schema/delta/54/drop_legacy_tables.sql @@ -0,0 +1,30 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- we need to do this first due to foreign constraints +DROP TABLE IF EXISTS application_services_regex; + +DROP TABLE IF EXISTS application_services; +DROP TABLE IF EXISTS transaction_id_to_pdu; +DROP TABLE IF EXISTS stats_reporting; +DROP TABLE IF EXISTS current_state_resets; +DROP TABLE IF EXISTS event_content_hashes; +DROP TABLE IF EXISTS event_destinations; +DROP TABLE IF EXISTS event_edge_hashes; +DROP TABLE IF EXISTS event_signatures; +DROP TABLE IF EXISTS feedback; +DROP TABLE IF EXISTS room_hosts; +DROP TABLE IF EXISTS server_tls_certificates; +DROP TABLE IF EXISTS state_forward_extremities; diff --git a/synapse/storage/schema/full_schemas/11/profiles.sql b/synapse/storage/schema/delta/54/drop_presence_list.sql index b314e6df75..e6ee70c623 100644 --- a/synapse/storage/schema/full_schemas/11/profiles.sql +++ b/synapse/storage/schema/delta/54/drop_presence_list.sql @@ -1,4 +1,4 @@ -/* Copyright 2014-2016 OpenMarket Ltd +/* Copyright 2019 New Vector Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -12,8 +12,5 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -CREATE TABLE IF NOT EXISTS profiles( - user_id TEXT NOT NULL, - displayname TEXT, - avatar_url TEXT -); + +DROP TABLE IF EXISTS presence_list; diff --git a/synapse/storage/schema/full_schemas/11/event_edges.sql b/synapse/storage/schema/full_schemas/11/event_edges.sql deleted file mode 100644 index bccd1c6f74..0000000000 --- a/synapse/storage/schema/full_schemas/11/event_edges.sql +++ /dev/null @@ -1,91 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS event_forward_extremities( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - UNIQUE (event_id, room_id) -); - -CREATE INDEX ev_extrem_room ON event_forward_extremities(room_id); -CREATE INDEX ev_extrem_id ON event_forward_extremities(event_id); - - -CREATE TABLE IF NOT EXISTS event_backward_extremities( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - UNIQUE (event_id, room_id) -); - -CREATE INDEX ev_b_extrem_room ON event_backward_extremities(room_id); -CREATE INDEX ev_b_extrem_id ON event_backward_extremities(event_id); - - -CREATE TABLE IF NOT EXISTS event_edges( - event_id TEXT NOT NULL, - prev_event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - -- We no longer insert prev_state into this table, so all new rows will have - -- is_state as false. - is_state BOOL NOT NULL, - UNIQUE (event_id, prev_event_id, room_id, is_state) -); - -CREATE INDEX ev_edges_id ON event_edges(event_id); -CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id); - - -CREATE TABLE IF NOT EXISTS room_depth( - room_id TEXT NOT NULL, - min_depth INTEGER NOT NULL, - UNIQUE (room_id) -); - -CREATE INDEX room_depth_room ON room_depth(room_id); - - -create TABLE IF NOT EXISTS event_destinations( - event_id TEXT NOT NULL, - destination TEXT NOT NULL, - delivered_ts BIGINT DEFAULT 0, -- or 0 if not delivered - UNIQUE (event_id, destination) -); - -CREATE INDEX event_destinations_id ON event_destinations(event_id); - - -CREATE TABLE IF NOT EXISTS state_forward_extremities( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - type TEXT NOT NULL, - state_key TEXT NOT NULL, - UNIQUE (event_id, room_id) -); - -CREATE INDEX st_extrem_keys ON state_forward_extremities( - room_id, type, state_key -); -CREATE INDEX st_extrem_id ON state_forward_extremities(event_id); - - -CREATE TABLE IF NOT EXISTS event_auth( - event_id TEXT NOT NULL, - auth_id TEXT NOT NULL, - room_id TEXT NOT NULL, - UNIQUE (event_id, auth_id, room_id) -); - -CREATE INDEX evauth_edges_id ON event_auth(event_id); -CREATE INDEX evauth_edges_auth_id ON event_auth(auth_id); diff --git a/synapse/storage/schema/full_schemas/11/event_signatures.sql b/synapse/storage/schema/full_schemas/11/event_signatures.sql deleted file mode 100644 index 00ce85980e..0000000000 --- a/synapse/storage/schema/full_schemas/11/event_signatures.sql +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS event_content_hashes ( - event_id TEXT, - algorithm TEXT, - hash bytea, - UNIQUE (event_id, algorithm) -); - -CREATE INDEX event_content_hashes_id ON event_content_hashes(event_id); - - -CREATE TABLE IF NOT EXISTS event_reference_hashes ( - event_id TEXT, - algorithm TEXT, - hash bytea, - UNIQUE (event_id, algorithm) -); - -CREATE INDEX event_reference_hashes_id ON event_reference_hashes(event_id); - - -CREATE TABLE IF NOT EXISTS event_signatures ( - event_id TEXT, - signature_name TEXT, - key_id TEXT, - signature bytea, - UNIQUE (event_id, signature_name, key_id) -); - -CREATE INDEX event_signatures_id ON event_signatures(event_id); - - -CREATE TABLE IF NOT EXISTS event_edge_hashes( - event_id TEXT, - prev_event_id TEXT, - algorithm TEXT, - hash bytea, - UNIQUE (event_id, prev_event_id, algorithm) -); - -CREATE INDEX event_edge_hashes_id ON event_edge_hashes(event_id); diff --git a/synapse/storage/schema/full_schemas/11/im.sql b/synapse/storage/schema/full_schemas/11/im.sql deleted file mode 100644 index dfbbf9fd54..0000000000 --- a/synapse/storage/schema/full_schemas/11/im.sql +++ /dev/null @@ -1,123 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS events( - stream_ordering INTEGER PRIMARY KEY AUTOINCREMENT, - topological_ordering BIGINT NOT NULL, - event_id TEXT NOT NULL, - type TEXT NOT NULL, - room_id TEXT NOT NULL, - content TEXT NOT NULL, - unrecognized_keys TEXT, - processed BOOL NOT NULL, - outlier BOOL NOT NULL, - depth BIGINT DEFAULT 0 NOT NULL, - UNIQUE (event_id) -); - -CREATE INDEX events_stream_ordering ON events (stream_ordering); -CREATE INDEX events_topological_ordering ON events (topological_ordering); -CREATE INDEX events_room_id ON events (room_id); - - -CREATE TABLE IF NOT EXISTS event_json( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - internal_metadata TEXT NOT NULL, - json TEXT NOT NULL, - UNIQUE (event_id) -); - -CREATE INDEX event_json_room_id ON event_json(room_id); - - -CREATE TABLE IF NOT EXISTS state_events( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - type TEXT NOT NULL, - state_key TEXT NOT NULL, - prev_state TEXT, - UNIQUE (event_id) -); - -CREATE INDEX state_events_room_id ON state_events (room_id); -CREATE INDEX state_events_type ON state_events (type); -CREATE INDEX state_events_state_key ON state_events (state_key); - - -CREATE TABLE IF NOT EXISTS current_state_events( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - type TEXT NOT NULL, - state_key TEXT NOT NULL, - UNIQUE (room_id, type, state_key) -); - -CREATE INDEX curr_events_event_id ON current_state_events (event_id); -CREATE INDEX current_state_events_room_id ON current_state_events (room_id); -CREATE INDEX current_state_events_type ON current_state_events (type); -CREATE INDEX current_state_events_state_key ON current_state_events (state_key); - -CREATE TABLE IF NOT EXISTS room_memberships( - event_id TEXT NOT NULL, - user_id TEXT NOT NULL, - sender TEXT NOT NULL, - room_id TEXT NOT NULL, - membership TEXT NOT NULL -); - -CREATE INDEX room_memberships_event_id ON room_memberships (event_id); -CREATE INDEX room_memberships_room_id ON room_memberships (room_id); -CREATE INDEX room_memberships_user_id ON room_memberships (user_id); - -CREATE TABLE IF NOT EXISTS feedback( - event_id TEXT NOT NULL, - feedback_type TEXT, - target_event_id TEXT, - sender TEXT, - room_id TEXT -); - -CREATE TABLE IF NOT EXISTS topics( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - topic TEXT NOT NULL -); - -CREATE INDEX topics_event_id ON topics(event_id); -CREATE INDEX topics_room_id ON topics(room_id); - -CREATE TABLE IF NOT EXISTS room_names( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - name TEXT NOT NULL -); - -CREATE INDEX room_names_event_id ON room_names(event_id); -CREATE INDEX room_names_room_id ON room_names(room_id); - -CREATE TABLE IF NOT EXISTS rooms( - room_id TEXT PRIMARY KEY NOT NULL, - is_public BOOL, - creator TEXT -); - -CREATE TABLE IF NOT EXISTS room_hosts( - room_id TEXT NOT NULL, - host TEXT NOT NULL, - UNIQUE (room_id, host) -); - -CREATE INDEX room_hosts_room_id ON room_hosts (room_id); diff --git a/synapse/storage/schema/full_schemas/11/keys.sql b/synapse/storage/schema/full_schemas/11/keys.sql deleted file mode 100644 index ca0ca1b694..0000000000 --- a/synapse/storage/schema/full_schemas/11/keys.sql +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -CREATE TABLE IF NOT EXISTS server_tls_certificates( - server_name TEXT, -- Server name. - fingerprint TEXT, -- Certificate fingerprint. - from_server TEXT, -- Which key server the certificate was fetched from. - ts_added_ms BIGINT, -- When the certifcate was added. - tls_certificate bytea, -- DER encoded x509 certificate. - UNIQUE (server_name, fingerprint) -); - -CREATE TABLE IF NOT EXISTS server_signature_keys( - server_name TEXT, -- Server name. - key_id TEXT, -- Key version. - from_server TEXT, -- Which key server the key was fetched form. - ts_added_ms BIGINT, -- When the key was added. - verify_key bytea, -- NACL verification key. - UNIQUE (server_name, key_id) -); diff --git a/synapse/storage/schema/full_schemas/11/media_repository.sql b/synapse/storage/schema/full_schemas/11/media_repository.sql deleted file mode 100644 index 9c264d6ece..0000000000 --- a/synapse/storage/schema/full_schemas/11/media_repository.sql +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS local_media_repository ( - media_id TEXT, -- The id used to refer to the media. - media_type TEXT, -- The MIME-type of the media. - media_length INTEGER, -- Length of the media in bytes. - created_ts BIGINT, -- When the content was uploaded in ms. - upload_name TEXT, -- The name the media was uploaded with. - user_id TEXT, -- The user who uploaded the file. - UNIQUE (media_id) -); - -CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails ( - media_id TEXT, -- The id used to refer to the media. - thumbnail_width INTEGER, -- The width of the thumbnail in pixels. - thumbnail_height INTEGER, -- The height of the thumbnail in pixels. - thumbnail_type TEXT, -- The MIME-type of the thumbnail. - thumbnail_method TEXT, -- The method used to make the thumbnail. - thumbnail_length INTEGER, -- The length of the thumbnail in bytes. - UNIQUE ( - media_id, thumbnail_width, thumbnail_height, thumbnail_type - ) -); - -CREATE INDEX local_media_repository_thumbnails_media_id - ON local_media_repository_thumbnails (media_id); - -CREATE TABLE IF NOT EXISTS remote_media_cache ( - media_origin TEXT, -- The remote HS the media came from. - media_id TEXT, -- The id used to refer to the media on that server. - media_type TEXT, -- The MIME-type of the media. - created_ts BIGINT, -- When the content was uploaded in ms. - upload_name TEXT, -- The name the media was uploaded with. - media_length INTEGER, -- Length of the media in bytes. - filesystem_id TEXT, -- The name used to store the media on disk. - UNIQUE (media_origin, media_id) -); - -CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails ( - media_origin TEXT, -- The remote HS the media came from. - media_id TEXT, -- The id used to refer to the media. - thumbnail_width INTEGER, -- The width of the thumbnail in pixels. - thumbnail_height INTEGER, -- The height of the thumbnail in pixels. - thumbnail_method TEXT, -- The method used to make the thumbnail - thumbnail_type TEXT, -- The MIME-type of the thumbnail. - thumbnail_length INTEGER, -- The length of the thumbnail in bytes. - filesystem_id TEXT, -- The name used to store the media on disk. - UNIQUE ( - media_origin, media_id, thumbnail_width, thumbnail_height, - thumbnail_type - ) -); diff --git a/synapse/storage/schema/full_schemas/11/presence.sql b/synapse/storage/schema/full_schemas/11/presence.sql deleted file mode 100644 index 492725994c..0000000000 --- a/synapse/storage/schema/full_schemas/11/presence.sql +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -CREATE TABLE IF NOT EXISTS presence( - user_id TEXT NOT NULL, - state VARCHAR(20), - status_msg TEXT, - mtime BIGINT -- miliseconds since last state change -); - --- For each of /my/ users which possibly-remote users are allowed to see their --- presence state -CREATE TABLE IF NOT EXISTS presence_allow_inbound( - observed_user_id TEXT NOT NULL, - observer_user_id TEXT NOT NULL -- a UserID, -); - --- For each of /my/ users (watcher), which possibly-remote users are they --- watching? -CREATE TABLE IF NOT EXISTS presence_list( - user_id TEXT NOT NULL, - observed_user_id TEXT NOT NULL, -- a UserID, - accepted BOOLEAN NOT NULL -); diff --git a/synapse/storage/schema/full_schemas/11/redactions.sql b/synapse/storage/schema/full_schemas/11/redactions.sql deleted file mode 100644 index 318f0d9aa5..0000000000 --- a/synapse/storage/schema/full_schemas/11/redactions.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -CREATE TABLE IF NOT EXISTS redactions ( - event_id TEXT NOT NULL, - redacts TEXT NOT NULL, - UNIQUE (event_id) -); - -CREATE INDEX redactions_event_id ON redactions (event_id); -CREATE INDEX redactions_redacts ON redactions (redacts); diff --git a/synapse/storage/schema/full_schemas/11/state.sql b/synapse/storage/schema/full_schemas/11/state.sql deleted file mode 100644 index b901e0f017..0000000000 --- a/synapse/storage/schema/full_schemas/11/state.sql +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS state_groups( - id INTEGER PRIMARY KEY, - room_id TEXT NOT NULL, - event_id TEXT NOT NULL -); - -CREATE TABLE IF NOT EXISTS state_groups_state( - state_group INTEGER NOT NULL, - room_id TEXT NOT NULL, - type TEXT NOT NULL, - state_key TEXT NOT NULL, - event_id TEXT NOT NULL -); - -CREATE TABLE IF NOT EXISTS event_to_state_groups( - event_id TEXT NOT NULL, - state_group INTEGER NOT NULL, - UNIQUE (event_id) -); - -CREATE INDEX state_groups_id ON state_groups(id); - -CREATE INDEX state_groups_state_id ON state_groups_state(state_group); -CREATE INDEX state_groups_state_tuple ON state_groups_state(room_id, type, state_key); -CREATE INDEX event_to_state_groups_id ON event_to_state_groups(event_id); diff --git a/synapse/storage/schema/full_schemas/11/transactions.sql b/synapse/storage/schema/full_schemas/11/transactions.sql deleted file mode 100644 index f6a058832e..0000000000 --- a/synapse/storage/schema/full_schemas/11/transactions.sql +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ --- Stores what transaction ids we have received and what our response was -CREATE TABLE IF NOT EXISTS received_transactions( - transaction_id TEXT, - origin TEXT, - ts BIGINT, - response_code INTEGER, - response_json bytea, - has_been_referenced SMALLINT DEFAULT 0, -- Whether thishas been referenced by a prev_tx - UNIQUE (transaction_id, origin) -); - -CREATE INDEX transactions_have_ref ON received_transactions(origin, has_been_referenced);-- WHERE has_been_referenced = 0; - --- For sent transactions only. -CREATE TABLE IF NOT EXISTS transaction_id_to_pdu( - transaction_id INTEGER, - destination TEXT, - pdu_id TEXT, - pdu_origin TEXT -); - -CREATE INDEX transaction_id_to_pdu_tx ON transaction_id_to_pdu(transaction_id, destination); -CREATE INDEX transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination); - --- To track destination health -CREATE TABLE IF NOT EXISTS destinations( - destination TEXT PRIMARY KEY, - retry_last_ts BIGINT, - retry_interval INTEGER -); diff --git a/synapse/storage/schema/full_schemas/11/users.sql b/synapse/storage/schema/full_schemas/11/users.sql deleted file mode 100644 index 6c1d4c34a1..0000000000 --- a/synapse/storage/schema/full_schemas/11/users.sql +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -CREATE TABLE IF NOT EXISTS users( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT, - password_hash TEXT, - creation_ts BIGINT, - admin SMALLINT DEFAULT 0 NOT NULL, - UNIQUE(name) -); - -CREATE TABLE IF NOT EXISTS access_tokens( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id TEXT NOT NULL, - device_id TEXT, - token TEXT NOT NULL, - last_used BIGINT, - UNIQUE(token) -); - -CREATE TABLE IF NOT EXISTS user_ips ( - user TEXT NOT NULL, - access_token TEXT NOT NULL, - device_id TEXT, - ip TEXT NOT NULL, - user_agent TEXT NOT NULL, - last_seen BIGINT NOT NULL, - UNIQUE (user, access_token, ip, user_agent) -); - -CREATE INDEX user_ips_user ON user_ips(user); diff --git a/synapse/storage/schema/full_schemas/16/application_services.sql b/synapse/storage/schema/full_schemas/16/application_services.sql index aee0e68473..883fcd10b2 100644 --- a/synapse/storage/schema/full_schemas/16/application_services.sql +++ b/synapse/storage/schema/full_schemas/16/application_services.sql @@ -13,22 +13,11 @@ * limitations under the License. */ -CREATE TABLE IF NOT EXISTS application_services( - id BIGINT PRIMARY KEY, - url TEXT, - token TEXT, - hs_token TEXT, - sender TEXT, - UNIQUE(token) -); +/* We used to create tables called application_services and + * application_services_regex, but these are no longer used and are removed in + * delta 54. + */ -CREATE TABLE IF NOT EXISTS application_services_regex( - id BIGINT PRIMARY KEY, - as_id BIGINT NOT NULL, - namespace INTEGER, /* enum[room_id|room_alias|user_id] */ - regex TEXT, - FOREIGN KEY(as_id) REFERENCES application_services(id) -); CREATE TABLE IF NOT EXISTS application_services_state( as_id TEXT PRIMARY KEY, diff --git a/synapse/storage/schema/full_schemas/16/event_edges.sql b/synapse/storage/schema/full_schemas/16/event_edges.sql index 6b5a5a88fa..10ce2aa7a0 100644 --- a/synapse/storage/schema/full_schemas/16/event_edges.sql +++ b/synapse/storage/schema/full_schemas/16/event_edges.sql @@ -13,6 +13,11 @@ * limitations under the License. */ +/* We used to create tables called event_destinations and + * state_forward_extremities, but these are no longer used and are removed in + * delta 54. + */ + CREATE TABLE IF NOT EXISTS event_forward_extremities( event_id TEXT NOT NULL, room_id TEXT NOT NULL, @@ -54,31 +59,6 @@ CREATE TABLE IF NOT EXISTS room_depth( CREATE INDEX room_depth_room ON room_depth(room_id); - -create TABLE IF NOT EXISTS event_destinations( - event_id TEXT NOT NULL, - destination TEXT NOT NULL, - delivered_ts BIGINT DEFAULT 0, -- or 0 if not delivered - UNIQUE (event_id, destination) -); - -CREATE INDEX event_destinations_id ON event_destinations(event_id); - - -CREATE TABLE IF NOT EXISTS state_forward_extremities( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - type TEXT NOT NULL, - state_key TEXT NOT NULL, - UNIQUE (event_id, room_id) -); - -CREATE INDEX st_extrem_keys ON state_forward_extremities( - room_id, type, state_key -); -CREATE INDEX st_extrem_id ON state_forward_extremities(event_id); - - CREATE TABLE IF NOT EXISTS event_auth( event_id TEXT NOT NULL, auth_id TEXT NOT NULL, diff --git a/synapse/storage/schema/full_schemas/16/event_signatures.sql b/synapse/storage/schema/full_schemas/16/event_signatures.sql index 00ce85980e..95826da431 100644 --- a/synapse/storage/schema/full_schemas/16/event_signatures.sql +++ b/synapse/storage/schema/full_schemas/16/event_signatures.sql @@ -13,15 +13,9 @@ * limitations under the License. */ -CREATE TABLE IF NOT EXISTS event_content_hashes ( - event_id TEXT, - algorithm TEXT, - hash bytea, - UNIQUE (event_id, algorithm) -); - -CREATE INDEX event_content_hashes_id ON event_content_hashes(event_id); - + /* We used to create tables called event_content_hashes and event_edge_hashes, + * but these are no longer used and are removed in delta 54. + */ CREATE TABLE IF NOT EXISTS event_reference_hashes ( event_id TEXT, @@ -42,14 +36,3 @@ CREATE TABLE IF NOT EXISTS event_signatures ( ); CREATE INDEX event_signatures_id ON event_signatures(event_id); - - -CREATE TABLE IF NOT EXISTS event_edge_hashes( - event_id TEXT, - prev_event_id TEXT, - algorithm TEXT, - hash bytea, - UNIQUE (event_id, prev_event_id, algorithm) -); - -CREATE INDEX event_edge_hashes_id ON event_edge_hashes(event_id); diff --git a/synapse/storage/schema/full_schemas/16/im.sql b/synapse/storage/schema/full_schemas/16/im.sql index 5f5cb8d01d..a1a2aa8e5b 100644 --- a/synapse/storage/schema/full_schemas/16/im.sql +++ b/synapse/storage/schema/full_schemas/16/im.sql @@ -13,6 +13,10 @@ * limitations under the License. */ +/* We used to create tables called room_hosts and feedback, + * but these are no longer used and are removed in delta 54. + */ + CREATE TABLE IF NOT EXISTS events( stream_ordering INTEGER PRIMARY KEY, topological_ordering BIGINT NOT NULL, @@ -91,15 +95,6 @@ CREATE TABLE IF NOT EXISTS room_memberships( CREATE INDEX room_memberships_room_id ON room_memberships (room_id); CREATE INDEX room_memberships_user_id ON room_memberships (user_id); -CREATE TABLE IF NOT EXISTS feedback( - event_id TEXT NOT NULL, - feedback_type TEXT, - target_event_id TEXT, - sender TEXT, - room_id TEXT, - UNIQUE (event_id) -); - CREATE TABLE IF NOT EXISTS topics( event_id TEXT NOT NULL, room_id TEXT NOT NULL, @@ -123,11 +118,3 @@ CREATE TABLE IF NOT EXISTS rooms( is_public BOOL, creator TEXT ); - -CREATE TABLE IF NOT EXISTS room_hosts( - room_id TEXT NOT NULL, - host TEXT NOT NULL, - UNIQUE (room_id, host) -); - -CREATE INDEX room_hosts_room_id ON room_hosts (room_id); diff --git a/synapse/storage/schema/full_schemas/16/keys.sql b/synapse/storage/schema/full_schemas/16/keys.sql index ca0ca1b694..11cdffdbb3 100644 --- a/synapse/storage/schema/full_schemas/16/keys.sql +++ b/synapse/storage/schema/full_schemas/16/keys.sql @@ -12,14 +12,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -CREATE TABLE IF NOT EXISTS server_tls_certificates( - server_name TEXT, -- Server name. - fingerprint TEXT, -- Certificate fingerprint. - from_server TEXT, -- Which key server the certificate was fetched from. - ts_added_ms BIGINT, -- When the certifcate was added. - tls_certificate bytea, -- DER encoded x509 certificate. - UNIQUE (server_name, fingerprint) -); + +-- we used to create a table called server_tls_certificates, but this is no +-- longer used, and is removed in delta 54. CREATE TABLE IF NOT EXISTS server_signature_keys( server_name TEXT, -- Server name. diff --git a/synapse/storage/schema/full_schemas/16/presence.sql b/synapse/storage/schema/full_schemas/16/presence.sql index 283136df20..01d2d8f833 100644 --- a/synapse/storage/schema/full_schemas/16/presence.sql +++ b/synapse/storage/schema/full_schemas/16/presence.sql @@ -28,13 +28,5 @@ CREATE TABLE IF NOT EXISTS presence_allow_inbound( UNIQUE (observed_user_id, observer_user_id) ); --- For each of /my/ users (watcher), which possibly-remote users are they --- watching? -CREATE TABLE IF NOT EXISTS presence_list( - user_id TEXT NOT NULL, - observed_user_id TEXT NOT NULL, -- a UserID, - accepted BOOLEAN NOT NULL, - UNIQUE (user_id, observed_user_id) -); - -CREATE INDEX presence_list_user_id ON presence_list (user_id); +-- We used to create a table called presence_list, but this is no longer used +-- and is removed in delta 54. \ No newline at end of file diff --git a/synapse/storage/search.py b/synapse/storage/search.py index c6420b2374..226f8f1b7e 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -30,10 +30,10 @@ from .background_updates import BackgroundUpdateStore logger = logging.getLogger(__name__) -SearchEntry = namedtuple('SearchEntry', [ - 'key', 'value', 'event_id', 'room_id', 'stream_ordering', - 'origin_server_ts', -]) +SearchEntry = namedtuple( + 'SearchEntry', + ['key', 'value', 'event_id', 'room_id', 'stream_ordering', 'origin_server_ts'], +) class SearchStore(BackgroundUpdateStore): @@ -53,8 +53,7 @@ class SearchStore(BackgroundUpdateStore): self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search ) self.register_background_update_handler( - self.EVENT_SEARCH_ORDER_UPDATE_NAME, - self._background_reindex_search_order + self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order ) # we used to have a background update to turn the GIN index into a @@ -62,13 +61,10 @@ class SearchStore(BackgroundUpdateStore): # a GIN index. However, it's possible that some people might still have # the background update queued, so we register a handler to clear the # background update. - self.register_noop_background_update( - self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME, - ) + self.register_noop_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME) self.register_background_update_handler( - self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, - self._background_reindex_gin_search + self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search ) @defer.inlineCallbacks @@ -138,21 +134,23 @@ class SearchStore(BackgroundUpdateStore): # then skip over it continue - event_search_rows.append(SearchEntry( - key=key, - value=value, - event_id=event_id, - room_id=room_id, - stream_ordering=stream_ordering, - origin_server_ts=origin_server_ts, - )) + event_search_rows.append( + SearchEntry( + key=key, + value=value, + event_id=event_id, + room_id=room_id, + stream_ordering=stream_ordering, + origin_server_ts=origin_server_ts, + ) + ) self.store_search_entries_txn(txn, event_search_rows) progress = { "target_min_stream_id_inclusive": target_min_stream_id, "max_stream_id_exclusive": min_stream_id, - "rows_inserted": rows_inserted + len(event_search_rows) + "rows_inserted": rows_inserted + len(event_search_rows), } self._background_update_progress_txn( @@ -191,6 +189,7 @@ class SearchStore(BackgroundUpdateStore): # doesn't support CREATE INDEX IF EXISTS so we just catch the # exception and ignore it. import psycopg2 + try: c.execute( "CREATE INDEX CONCURRENTLY event_search_fts_idx" @@ -198,14 +197,11 @@ class SearchStore(BackgroundUpdateStore): ) except psycopg2.ProgrammingError as e: logger.warn( - "Ignoring error %r when trying to switch from GIST to GIN", - e + "Ignoring error %r when trying to switch from GIST to GIN", e ) # we should now be able to delete the GIST index. - c.execute( - "DROP INDEX IF EXISTS event_search_fts_idx_gist" - ) + c.execute("DROP INDEX IF EXISTS event_search_fts_idx_gist") finally: conn.set_session(autocommit=False) @@ -223,6 +219,7 @@ class SearchStore(BackgroundUpdateStore): have_added_index = progress['have_added_indexes'] if not have_added_index: + def create_index(conn): conn.rollback() conn.set_session(autocommit=True) @@ -248,7 +245,8 @@ class SearchStore(BackgroundUpdateStore): yield self.runInteraction( self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_update_progress_txn, - self.EVENT_SEARCH_ORDER_UPDATE_NAME, pg, + self.EVENT_SEARCH_ORDER_UPDATE_NAME, + pg, ) def reindex_search_txn(txn): @@ -302,14 +300,16 @@ class SearchStore(BackgroundUpdateStore): """ self.store_search_entries_txn( txn, - (SearchEntry( - key=key, - value=value, - event_id=event.event_id, - room_id=event.room_id, - stream_ordering=event.internal_metadata.stream_ordering, - origin_server_ts=event.origin_server_ts, - ),), + ( + SearchEntry( + key=key, + value=value, + event_id=event.event_id, + room_id=event.room_id, + stream_ordering=event.internal_metadata.stream_ordering, + origin_server_ts=event.origin_server_ts, + ), + ), ) def store_search_entries_txn(self, txn, entries): @@ -329,10 +329,17 @@ class SearchStore(BackgroundUpdateStore): " VALUES (?,?,?,to_tsvector('english', ?),?,?)" ) - args = (( - entry.event_id, entry.room_id, entry.key, entry.value, - entry.stream_ordering, entry.origin_server_ts, - ) for entry in entries) + args = ( + ( + entry.event_id, + entry.room_id, + entry.key, + entry.value, + entry.stream_ordering, + entry.origin_server_ts, + ) + for entry in entries + ) # inserts to a GIN index are normally batched up into a pending # list, and then all committed together once the list gets to a @@ -363,9 +370,10 @@ class SearchStore(BackgroundUpdateStore): "INSERT INTO event_search (event_id, room_id, key, value)" " VALUES (?,?,?,?)" ) - args = (( - entry.event_id, entry.room_id, entry.key, entry.value, - ) for entry in entries) + args = ( + (entry.event_id, entry.room_id, entry.key, entry.value) + for entry in entries + ) txn.executemany(sql, args) else: @@ -394,9 +402,7 @@ class SearchStore(BackgroundUpdateStore): # Make sure we don't explode because the person is in too many rooms. # We filter the results below regardless. if len(room_ids) < 500: - clauses.append( - "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),) - ) + clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)) args.extend(room_ids) local_clauses = [] @@ -404,9 +410,7 @@ class SearchStore(BackgroundUpdateStore): local_clauses.append("key = ?") args.append(key) - clauses.append( - "(%s)" % (" OR ".join(local_clauses),) - ) + clauses.append("(%s)" % (" OR ".join(local_clauses),)) count_args = args count_clauses = clauses @@ -452,18 +456,13 @@ class SearchStore(BackgroundUpdateStore): # entire table from the database. sql += " ORDER BY rank DESC LIMIT 500" - results = yield self._execute( - "search_msgs", self.cursor_to_dict, sql, *args - ) + results = yield self._execute("search_msgs", self.cursor_to_dict, sql, *args) results = list(filter(lambda row: row["room_id"] in room_ids, results)) events = yield self._get_events([r["event_id"] for r in results]) - event_map = { - ev.event_id: ev - for ev in events - } + event_map = {ev.event_id: ev for ev in events} highlights = None if isinstance(self.database_engine, PostgresEngine): @@ -477,18 +476,17 @@ class SearchStore(BackgroundUpdateStore): count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) - defer.returnValue({ - "results": [ - { - "event": event_map[r["event_id"]], - "rank": r["rank"], - } - for r in results - if r["event_id"] in event_map - ], - "highlights": highlights, - "count": count, - }) + defer.returnValue( + { + "results": [ + {"event": event_map[r["event_id"]], "rank": r["rank"]} + for r in results + if r["event_id"] in event_map + ], + "highlights": highlights, + "count": count, + } + ) @defer.inlineCallbacks def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None): @@ -513,9 +511,7 @@ class SearchStore(BackgroundUpdateStore): # Make sure we don't explode because the person is in too many rooms. # We filter the results below regardless. if len(room_ids) < 500: - clauses.append( - "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),) - ) + clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)) args.extend(room_ids) local_clauses = [] @@ -523,9 +519,7 @@ class SearchStore(BackgroundUpdateStore): local_clauses.append("key = ?") args.append(key) - clauses.append( - "(%s)" % (" OR ".join(local_clauses),) - ) + clauses.append("(%s)" % (" OR ".join(local_clauses),)) # take copies of the current args and clauses lists, before adding # pagination clauses to main query. @@ -607,18 +601,13 @@ class SearchStore(BackgroundUpdateStore): args.append(limit) - results = yield self._execute( - "search_rooms", self.cursor_to_dict, sql, *args - ) + results = yield self._execute("search_rooms", self.cursor_to_dict, sql, *args) results = list(filter(lambda row: row["room_id"] in room_ids, results)) events = yield self._get_events([r["event_id"] for r in results]) - event_map = { - ev.event_id: ev - for ev in events - } + event_map = {ev.event_id: ev for ev in events} highlights = None if isinstance(self.database_engine, PostgresEngine): @@ -632,21 +621,22 @@ class SearchStore(BackgroundUpdateStore): count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) - defer.returnValue({ - "results": [ - { - "event": event_map[r["event_id"]], - "rank": r["rank"], - "pagination_token": "%s,%s" % ( - r["origin_server_ts"], r["stream_ordering"] - ), - } - for r in results - if r["event_id"] in event_map - ], - "highlights": highlights, - "count": count, - }) + defer.returnValue( + { + "results": [ + { + "event": event_map[r["event_id"]], + "rank": r["rank"], + "pagination_token": "%s,%s" + % (r["origin_server_ts"], r["stream_ordering"]), + } + for r in results + if r["event_id"] in event_map + ], + "highlights": highlights, + "count": count, + } + ) def _find_highlights_in_postgres(self, search_query, events): """Given a list of events and a search term, return a list of words @@ -662,6 +652,7 @@ class SearchStore(BackgroundUpdateStore): Returns: deferred : A set of strings. """ + def f(txn): highlight_words = set() for event in events: @@ -689,13 +680,15 @@ class SearchStore(BackgroundUpdateStore): stop_sel += ">" query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % ( - _to_postgres_options({ - "StartSel": start_sel, - "StopSel": stop_sel, - "MaxFragments": "50", - }) + _to_postgres_options( + { + "StartSel": start_sel, + "StopSel": stop_sel, + "MaxFragments": "50", + } + ) ) - txn.execute(query, (value, search_query,)) + txn.execute(query, (value, search_query)) headline, = txn.fetchall()[0] # Now we need to pick the possible highlights out of the haedline @@ -714,9 +707,7 @@ class SearchStore(BackgroundUpdateStore): def _to_postgres_options(options_dict): - return "'%s'" % ( - ",".join("%s=%s" % (k, v) for k, v in options_dict.items()), - ) + return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),) def _parse_query(database_engine, search_term): diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index 158e9dbe7b..6bd81e84ad 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -39,8 +39,9 @@ class SignatureWorkerStore(SQLBaseStore): # to use its cache raise NotImplementedError() - @cachedList(cached_method_name="get_event_reference_hash", - list_name="event_ids", num_args=1) + @cachedList( + cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1 + ) def get_event_reference_hashes(self, event_ids): def f(txn): return { @@ -48,21 +49,13 @@ class SignatureWorkerStore(SQLBaseStore): for event_id in event_ids } - return self.runInteraction( - "get_event_reference_hashes", - f - ) + return self.runInteraction("get_event_reference_hashes", f) @defer.inlineCallbacks def add_event_hashes(self, event_ids): - hashes = yield self.get_event_reference_hashes( - event_ids - ) + hashes = yield self.get_event_reference_hashes(event_ids) hashes = { - e_id: { - k: encode_base64(v) for k, v in h.items() - if k == "sha256" - } + e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"} for e_id, h in hashes.items() } @@ -81,7 +74,7 @@ class SignatureWorkerStore(SQLBaseStore): " FROM event_reference_hashes" " WHERE event_id = ?" ) - txn.execute(query, (event_id, )) + txn.execute(query, (event_id,)) return {k: v for k, v in txn} @@ -98,14 +91,12 @@ class SignatureStore(SignatureWorkerStore): vals = [] for event in events: ref_alg, ref_hash_bytes = compute_event_reference_hash(event) - vals.append({ - "event_id": event.event_id, - "algorithm": ref_alg, - "hash": db_binary_type(ref_hash_bytes), - }) - - self._simple_insert_many_txn( - txn, - table="event_reference_hashes", - values=vals, - ) + vals.append( + { + "event_id": event.event_id, + "algorithm": ref_alg, + "hash": db_binary_type(ref_hash_bytes), + } + ) + + self._simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 6ddc4055d2..0bfe1b4550 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -40,10 +40,13 @@ logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 -class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))): +class _GetStateGroupDelta( + namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) +): """Return type of get_state_group_delta that implements __len__, which lets us use the itrable flag when caching """ + __slots__ = [] def __len__(self): @@ -70,10 +73,7 @@ class StateFilter(object): # If `include_others` is set we canonicalise the filter by removing # wildcards from the types dictionary if self.include_others: - self.types = { - k: v for k, v in iteritems(self.types) - if v is not None - } + self.types = {k: v for k, v in iteritems(self.types) if v is not None} @staticmethod def all(): @@ -130,10 +130,7 @@ class StateFilter(object): Returns: StateFilter """ - return StateFilter( - types={EventTypes.Member: set(members)}, - include_others=True, - ) + return StateFilter(types={EventTypes.Member: set(members)}, include_others=True) def return_expanded(self): """Creates a new StateFilter where type wild cards have been removed @@ -243,9 +240,7 @@ class StateFilter(object): if where_clause: where_clause += " OR " - where_clause += "type NOT IN (%s)" % ( - ",".join(["?"] * len(self.types)), - ) + where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),) where_args.extend(self.types) return where_clause, where_args @@ -305,12 +300,8 @@ class StateFilter(object): bool """ - return ( - self.include_others - or any( - state_keys is None - for state_keys in itervalues(self.types) - ) + return self.include_others or any( + state_keys is None for state_keys in itervalues(self.types) ) def concrete_types(self): @@ -406,11 +397,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): self._state_group_cache = DictionaryCache( "*stateGroupCache*", # TODO: this hasn't been tuned yet - 50000 * get_cache_factor_for("stateGroupCache") + 50000 * get_cache_factor_for("stateGroupCache"), ) self._state_group_members_cache = DictionaryCache( "*stateGroupMembersCache*", - 500000 * get_cache_factor_for("stateGroupMembersCache") + 500000 * get_cache_factor_for("stateGroupMembersCache"), ) @defer.inlineCallbacks @@ -488,22 +479,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: deferred: dict of (type, state_key) -> event_id """ + def _get_current_state_ids_txn(txn): txn.execute( """SELECT type, state_key, event_id FROM current_state_events WHERE room_id = ? """, - (room_id,) + (room_id,), ) return { (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn } - return self.runInteraction( - "get_current_state_ids", - _get_current_state_ids_txn, - ) + return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn) # FIXME: how should this be cached? def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()): @@ -544,8 +533,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return results return self.runInteraction( - "get_filtered_current_state_ids", - _get_filtered_current_state_ids_txn, + "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn ) @defer.inlineCallbacks @@ -559,9 +547,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Deferred[str|None]: The canonical alias, if any """ - state = yield self.get_filtered_current_state_ids(room_id, StateFilter.from_types( - [(EventTypes.CanonicalAlias, "")] - )) + state = yield self.get_filtered_current_state_ids( + room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) + ) event_id = state.get((EventTypes.CanonicalAlias, "")) if not event_id: @@ -581,13 +569,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: (prev_group, delta_ids), where both may be None. """ + def _get_state_group_delta_txn(txn): prev_group = self._simple_select_one_onecol_txn( txn, table="state_group_edges", - keyvalues={ - "state_group": state_group, - }, + keyvalues={"state_group": state_group}, retcol="prev_state_group", allow_none=True, ) @@ -598,20 +585,16 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): delta_ids = self._simple_select_list_txn( txn, table="state_groups_state", - keyvalues={ - "state_group": state_group, - }, - retcols=("type", "state_key", "event_id",) + keyvalues={"state_group": state_group}, + retcols=("type", "state_key", "event_id"), ) - return _GetStateGroupDelta(prev_group, { - (row["type"], row["state_key"]): row["event_id"] - for row in delta_ids - }) - return self.runInteraction( - "get_state_group_delta", - _get_state_group_delta_txn, - ) + return _GetStateGroupDelta( + prev_group, + {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, + ) + + return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn) @defer.inlineCallbacks def get_state_groups_ids(self, _room_id, event_ids): @@ -628,9 +611,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): if not event_ids: defer.returnValue({}) - event_to_groups = yield self._get_state_group_for_events( - event_ids, - ) + event_to_groups = yield self._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) group_to_state = yield self._get_state_for_groups(groups) @@ -666,19 +647,23 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): state_event_map = yield self.get_events( [ - ev_id for group_ids in itervalues(group_to_ids) + ev_id + for group_ids in itervalues(group_to_ids) for ev_id in itervalues(group_ids) ], - get_prev_content=False + get_prev_content=False, ) - defer.returnValue({ - group: [ - state_event_map[v] for v in itervalues(event_id_map) - if v in state_event_map - ] - for group, event_id_map in iteritems(group_to_ids) - }) + defer.returnValue( + { + group: [ + state_event_map[v] + for v in itervalues(event_id_map) + if v in state_event_map + ] + for group, event_id_map in iteritems(group_to_ids) + } + ) @defer.inlineCallbacks def _get_state_groups_from_groups(self, groups, state_filter): @@ -695,18 +680,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """ results = {} - chunks = [groups[i:i + 100] for i in range(0, len(groups), 100)] + chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] for chunk in chunks: res = yield self.runInteraction( "_get_state_groups_from_groups", - self._get_state_groups_from_groups_txn, chunk, state_filter, + self._get_state_groups_from_groups_txn, + chunk, + state_filter, ) results.update(res) defer.returnValue(results) def _get_state_groups_from_groups_txn( - self, txn, groups, state_filter=StateFilter.all(), + self, txn, groups, state_filter=StateFilter.all() ): results = {group: {} for group in groups} @@ -776,7 +763,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): txn.execute( "SELECT type, state_key, event_id FROM state_groups_state" " WHERE state_group = ? " + where_clause, - args + args, ) results[group].update( ((typ, state_key), event_id) @@ -791,8 +778,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # wildcards (i.e. Nones) in which case we have to do an exhaustive # search if ( - max_entries_returned is not None and - len(results[group]) == max_entries_returned + max_entries_returned is not None + and len(results[group]) == max_entries_returned ): break @@ -819,16 +806,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: deferred: A dict of (event_id) -> (type, state_key) -> [state_events] """ - event_to_groups = yield self._get_state_group_for_events( - event_ids, - ) + event_to_groups = yield self._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) group_to_state = yield self._get_state_for_groups(groups, state_filter) state_event_map = yield self.get_events( [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], - get_prev_content=False + get_prev_content=False, ) event_to_state = { @@ -856,9 +841,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: A deferred dict from event_id -> (type, state_key) -> event_id """ - event_to_groups = yield self._get_state_group_for_events( - event_ids, - ) + event_to_groups = yield self._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) group_to_state = yield self._get_state_for_groups(groups, state_filter) @@ -906,16 +889,18 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def _get_state_group_for_event(self, event_id): return self._simple_select_one_onecol( table="event_to_state_groups", - keyvalues={ - "event_id": event_id, - }, + keyvalues={"event_id": event_id}, retcol="state_group", allow_none=True, desc="_get_state_group_for_event", ) - @cachedList(cached_method_name="_get_state_group_for_event", - list_name="event_ids", num_args=1, inlineCallbacks=True) + @cachedList( + cached_method_name="_get_state_group_for_event", + list_name="event_ids", + num_args=1, + inlineCallbacks=True, + ) def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ @@ -924,7 +909,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): column="event_id", iterable=event_ids, keyvalues={}, - retcols=("event_id", "state_group",), + retcols=("event_id", "state_group"), desc="_get_state_group_for_events", ) @@ -989,15 +974,13 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # Now we look them up in the member and non-member caches non_member_state, incomplete_groups_nm, = ( yield self._get_state_for_groups_using_cache( - groups, self._state_group_cache, - state_filter=non_member_filter, + groups, self._state_group_cache, state_filter=non_member_filter ) ) member_state, incomplete_groups_m, = ( yield self._get_state_for_groups_using_cache( - groups, self._state_group_members_cache, - state_filter=member_filter, + groups, self._state_group_members_cache, state_filter=member_filter ) ) @@ -1019,8 +1002,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): db_state_filter = state_filter.return_expanded() group_to_state_dict = yield self._get_state_groups_from_groups( - list(incomplete_groups), - state_filter=db_state_filter, + list(incomplete_groups), state_filter=db_state_filter ) # Now lets update the caches @@ -1040,9 +1022,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): defer.returnValue(state) - def _get_state_for_groups_using_cache( - self, groups, cache, state_filter, - ): + def _get_state_for_groups_using_cache(self, groups, cache, state_filter): """Gets the state at each of a list of state groups, optionally filtering by type/state_key, querying from a specific cache. @@ -1074,8 +1054,13 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return results, incomplete_groups - def _insert_into_cache(self, group_to_state_dict, state_filter, - cache_seq_num_members, cache_seq_num_non_members): + def _insert_into_cache( + self, + group_to_state_dict, + state_filter, + cache_seq_num_members, + cache_seq_num_non_members, + ): """Inserts results from querying the database into the relevant cache. Args: @@ -1132,8 +1117,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): fetched_keys=non_member_types, ) - def store_state_group(self, event_id, room_id, prev_group, delta_ids, - current_state_ids): + def store_state_group( + self, event_id, room_id, prev_group, delta_ids, current_state_ids + ): """Store a new set of state, returning a newly assigned state group. Args: @@ -1149,6 +1135,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: Deferred[int]: The state group ID """ + def _store_state_group_txn(txn): if current_state_ids is None: # AFAIK, this can never happen @@ -1159,11 +1146,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): self._simple_insert_txn( txn, table="state_groups", - values={ - "id": state_group, - "room_id": room_id, - "event_id": event_id, - }, + values={"id": state_group, "room_id": room_id, "event_id": event_id}, ) # We persist as a delta if we can, while also ensuring the chain @@ -1182,17 +1165,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): % (prev_group,) ) - potential_hops = self._count_state_group_hops_txn( - txn, prev_group - ) + potential_hops = self._count_state_group_hops_txn(txn, prev_group) if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: self._simple_insert_txn( txn, table="state_group_edges", - values={ - "state_group": state_group, - "prev_state_group": prev_group, - }, + values={"state_group": state_group, "prev_state_group": prev_group}, ) self._simple_insert_many_txn( @@ -1264,7 +1242,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): This is used to ensure the delta chains don't get too long. """ if isinstance(self.database_engine, PostgresEngine): - sql = (""" + sql = """ WITH RECURSIVE state(state_group) AS ( VALUES(?::bigint) UNION ALL @@ -1272,7 +1250,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): WHERE s.state_group = e.state_group ) SELECT count(*) FROM state; - """) + """ txn.execute(sql, (state_group,)) row = txn.fetchone() @@ -1331,8 +1309,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): self._background_deduplicate_state, ) self.register_background_update_handler( - self.STATE_GROUP_INDEX_UPDATE_NAME, - self._background_index_state, + self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state ) self.register_background_index_update( self.CURRENT_STATE_INDEX_UPDATE_NAME, @@ -1366,18 +1343,14 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): txn, table="event_to_state_groups", values=[ - { - "state_group": state_group_id, - "event_id": event_id, - } + {"state_group": state_group_id, "event_id": event_id} for event_id, state_group_id in iteritems(state_groups) ], ) for event_id, state_group_id in iteritems(state_groups): txn.call_after( - self._get_state_group_for_event.prefill, - (event_id,), state_group_id + self._get_state_group_for_event.prefill, (event_id,), state_group_id ) @defer.inlineCallbacks @@ -1395,7 +1368,8 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): if max_group is None: rows = yield self._execute( - "_background_deduplicate_state", None, + "_background_deduplicate_state", + None, "SELECT coalesce(max(id), 0) FROM state_groups", ) max_group = rows[0][0] @@ -1408,7 +1382,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): " WHERE ? < id AND id <= ?" " ORDER BY id ASC" " LIMIT 1", - (new_last_state_group, max_group,) + (new_last_state_group, max_group), ) row = txn.fetchone() if row: @@ -1420,7 +1394,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): txn.execute( "SELECT state_group FROM state_group_edges" " WHERE state_group = ?", - (state_group,) + (state_group,), ) # If we reach a point where we've already started inserting @@ -1431,27 +1405,25 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): txn.execute( "SELECT coalesce(max(id), 0) FROM state_groups" " WHERE id < ? AND room_id = ?", - (state_group, room_id,) + (state_group, room_id), ) prev_group, = txn.fetchone() new_last_state_group = state_group if prev_group: - potential_hops = self._count_state_group_hops_txn( - txn, prev_group - ) + potential_hops = self._count_state_group_hops_txn(txn, prev_group) if potential_hops >= MAX_STATE_DELTA_HOPS: # We want to ensure chains are at most this long,# # otherwise read performance degrades. continue prev_state = self._get_state_groups_from_groups_txn( - txn, [prev_group], + txn, [prev_group] ) prev_state = prev_state[prev_group] curr_state = self._get_state_groups_from_groups_txn( - txn, [state_group], + txn, [state_group] ) curr_state = curr_state[state_group] @@ -1460,16 +1432,15 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): # of keys delta_state = { - key: value for key, value in iteritems(curr_state) + key: value + for key, value in iteritems(curr_state) if prev_state.get(key, None) != value } self._simple_delete_txn( txn, table="state_group_edges", - keyvalues={ - "state_group": state_group, - } + keyvalues={"state_group": state_group}, ) self._simple_insert_txn( @@ -1478,15 +1449,13 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): values={ "state_group": state_group, "prev_state_group": prev_group, - } + }, ) self._simple_delete_txn( txn, table="state_groups_state", - keyvalues={ - "state_group": state_group, - } + keyvalues={"state_group": state_group}, ) self._simple_insert_many_txn( @@ -1521,7 +1490,9 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): ) if finished: - yield self._end_background_update(self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME) + yield self._end_background_update( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME + ) defer.returnValue(result * BATCH_SIZE_SCALE_FACTOR) @@ -1538,9 +1509,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" " ON state_groups_state(state_group, type, state_key)" ) - txn.execute( - "DROP INDEX IF EXISTS state_groups_state_id" - ) + txn.execute("DROP INDEX IF EXISTS state_groups_state_id") finally: conn.set_session(autocommit=False) else: @@ -1549,9 +1518,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): "CREATE INDEX state_groups_state_type_idx" " ON state_groups_state(state_group, type, state_key)" ) - txn.execute( - "DROP INDEX IF EXISTS state_groups_state_id" - ) + txn.execute("DROP INDEX IF EXISTS state_groups_state_id") yield self.runWithConnection(reindex_txn) diff --git a/synapse/storage/state_deltas.py b/synapse/storage/state_deltas.py index 57bc45cdb9..31a0279b18 100644 --- a/synapse/storage/state_deltas.py +++ b/synapse/storage/state_deltas.py @@ -21,10 +21,29 @@ logger = logging.getLogger(__name__) class StateDeltasStore(SQLBaseStore): - def get_current_state_deltas(self, prev_stream_id): + """Fetch a list of room state changes since the given stream id + + Each entry in the result contains the following fields: + - stream_id (int) + - room_id (str) + - type (str): event type + - state_key (str): + - event_id (str|None): new event_id for this state key. None if the + state has been deleted. + - prev_event_id (str|None): previous event_id for this state key. None + if it's new state. + + Args: + prev_stream_id (int): point to get changes since (exclusive) + + Returns: + Deferred[list[dict]]: results + """ prev_stream_id = int(prev_stream_id) - if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id): + if not self._curr_state_delta_stream_cache.has_any_entity_changed( + prev_stream_id + ): return [] def get_current_state_deltas_txn(txn): @@ -58,7 +77,7 @@ class StateDeltasStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC """ - txn.execute(sql, (prev_stream_id, max_stream_id,)) + txn.execute(sql, (prev_stream_id, max_stream_id)) return self.cursor_to_dict(txn) return self.runInteraction( diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 580fafeb3a..9cd1e0f9fe 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -59,9 +59,9 @@ _TOPOLOGICAL_TOKEN = "topological" # Used as return values for pagination APIs -_EventDictReturn = namedtuple("_EventDictReturn", ( - "event_id", "topological_ordering", "stream_ordering", -)) +_EventDictReturn = namedtuple( + "_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering") +) def lower_bound(token, engine, inclusive=False): @@ -74,13 +74,20 @@ def lower_bound(token, engine, inclusive=False): # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we # use the later form when running against postgres. return "((%d,%d) <%s (%s,%s))" % ( - token.topological, token.stream, inclusive, - "topological_ordering", "stream_ordering", + token.topological, + token.stream, + inclusive, + "topological_ordering", + "stream_ordering", ) return "(%d < %s OR (%d = %s AND %d <%s %s))" % ( - token.topological, "topological_ordering", - token.topological, "topological_ordering", - token.stream, inclusive, "stream_ordering", + token.topological, + "topological_ordering", + token.topological, + "topological_ordering", + token.stream, + inclusive, + "stream_ordering", ) @@ -94,13 +101,20 @@ def upper_bound(token, engine, inclusive=True): # as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we # use the later form when running against postgres. return "((%d,%d) >%s (%s,%s))" % ( - token.topological, token.stream, inclusive, - "topological_ordering", "stream_ordering", + token.topological, + token.stream, + inclusive, + "topological_ordering", + "stream_ordering", ) return "(%d > %s OR (%d = %s AND %d >%s %s))" % ( - token.topological, "topological_ordering", - token.topological, "topological_ordering", - token.stream, inclusive, "stream_ordering", + token.topological, + "topological_ordering", + token.topological, + "topological_ordering", + token.stream, + inclusive, + "stream_ordering", ) @@ -116,9 +130,7 @@ def filter_to_clause(event_filter): args = [] if event_filter.types: - clauses.append( - "(%s)" % " OR ".join("type = ?" for _ in event_filter.types) - ) + clauses.append("(%s)" % " OR ".join("type = ?" for _ in event_filter.types)) args.extend(event_filter.types) for typ in event_filter.not_types: @@ -126,9 +138,7 @@ def filter_to_clause(event_filter): args.append(typ) if event_filter.senders: - clauses.append( - "(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders) - ) + clauses.append("(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders)) args.extend(event_filter.senders) for sender in event_filter.not_senders: @@ -136,9 +146,7 @@ def filter_to_clause(event_filter): args.append(sender) if event_filter.rooms: - clauses.append( - "(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms) - ) + clauses.append("(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms)) args.extend(event_filter.rooms) for room_id in event_filter.not_rooms: @@ -165,17 +173,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): events_max = self.get_room_max_stream_ordering() event_cache_prefill, min_event_val = self._get_cache_dict( - db_conn, "events", + db_conn, + "events", entity_column="room_id", stream_column="stream_ordering", max_value=events_max, ) self._events_stream_cache = StreamChangeCache( - "EventsRoomStreamChangeCache", min_event_val, + "EventsRoomStreamChangeCache", + min_event_val, prefilled_cache=event_cache_prefill, ) self._membership_stream_cache = StreamChangeCache( - "MembershipStreamChangeCache", events_max, + "MembershipStreamChangeCache", events_max ) self._stream_order_on_start = self.get_room_max_stream_ordering() @@ -189,8 +199,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): raise NotImplementedError() @defer.inlineCallbacks - def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0, - order='DESC'): + def get_room_events_stream_for_rooms( + self, room_ids, from_key, to_key, limit=0, order='DESC' + ): """Get new room events in stream ordering since `from_key`. Args: @@ -221,14 +232,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): results = {} room_ids = list(room_ids) - for rm_ids in (room_ids[i:i + 20] for i in range(0, len(room_ids), 20)): - res = yield make_deferred_yieldable(defer.gatherResults([ - run_in_background( - self.get_room_events_stream_for_room, - room_id, from_key, to_key, limit, order=order, + for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)): + res = yield make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background( + self.get_room_events_stream_for_room, + room_id, + from_key, + to_key, + limit, + order=order, + ) + for room_id in rm_ids + ], + consumeErrors=True, ) - for room_id in rm_ids - ], consumeErrors=True)) + ) results.update(dict(zip(rm_ids, res))) defer.returnValue(results) @@ -243,13 +263,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): """ from_key = RoomStreamToken.parse_stream_token(from_key).stream return set( - room_id for room_id in room_ids + room_id + for room_id in room_ids if self._events_stream_cache.has_entity_changed(room_id, from_key) ) @defer.inlineCallbacks - def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0, - order='DESC'): + def get_room_events_stream_for_room( + self, room_id, from_key, to_key, limit=0, order='DESC' + ): """Get new room events in stream ordering since `from_key`. @@ -297,10 +319,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): rows = yield self.runInteraction("get_room_events_stream_for_room", f) - ret = yield self._get_events( - [r.event_id for r in rows], - get_prev_content=True - ) + ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True) self._set_before_and_after(ret, rows, topo_order=from_id is None) @@ -340,7 +359,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): " AND e.stream_ordering > ? AND e.stream_ordering <= ?" " ORDER BY e.stream_ordering ASC" ) - txn.execute(sql, (user_id, from_id, to_id,)) + txn.execute(sql, (user_id, from_id, to_id)) rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] @@ -348,10 +367,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): rows = yield self.runInteraction("get_membership_changes_for_user", f) - ret = yield self._get_events( - [r.event_id for r in rows], - get_prev_content=True - ) + ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True) self._set_before_and_after(ret, rows, topo_order=False) @@ -374,13 +390,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): """ rows, token = yield self.get_recent_event_ids_for_room( - room_id, limit, end_token, + room_id, limit, end_token ) logger.debug("stream before") events = yield self._get_events( - [r.event_id for r in rows], - get_prev_content=True + [r.event_id for r in rows], get_prev_content=True ) logger.debug("stream after") @@ -410,8 +425,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): end_token = RoomStreamToken.parse(end_token) rows, token = yield self.runInteraction( - "get_recent_event_ids_for_room", self._paginate_room_events_txn, - room_id, from_token=end_token, limit=limit, + "get_recent_event_ids_for_room", + self._paginate_room_events_txn, + room_id, + from_token=end_token, + limit=limit, ) # We want to return the results in ascending order. @@ -430,6 +448,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Deferred[(int, int, str)]: (stream ordering, topological ordering, event_id) """ + def _f(txn): sql = ( "SELECT stream_ordering, topological_ordering, event_id" @@ -439,12 +458,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): " ORDER BY stream_ordering" " LIMIT 1" ) - txn.execute(sql, (room_id, stream_ordering, )) + txn.execute(sql, (room_id, stream_ordering)) return txn.fetchone() - return self.runInteraction( - "get_room_event_after_stream_ordering", _f, - ) + return self.runInteraction("get_room_event_after_stream_ordering", _f) @defer.inlineCallbacks def get_room_events_max_id(self, room_id=None): @@ -459,8 +476,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): defer.returnValue("s%d" % (token,)) else: topo = yield self.runInteraction( - "_get_max_topological_txn", self._get_max_topological_txn, - room_id, + "_get_max_topological_txn", self._get_max_topological_txn, room_id ) defer.returnValue("t%d-%d" % (topo, token)) @@ -474,9 +490,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): A deferred "s%d" stream token. """ return self._simple_select_one_onecol( - table="events", - keyvalues={"event_id": event_id}, - retcol="stream_ordering", + table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" ).addCallback(lambda row: "s%d" % (row,)) def get_topological_token_for_event(self, event_id): @@ -493,8 +507,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): keyvalues={"event_id": event_id}, retcols=("stream_ordering", "topological_ordering"), desc="get_topological_token_for_event", - ).addCallback(lambda row: "t%d-%d" % ( - row["topological_ordering"], row["stream_ordering"],) + ).addCallback( + lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"]) ) def get_max_topological_token(self, room_id, stream_key): @@ -503,17 +517,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): " WHERE room_id = ? AND stream_ordering < ?" ) return self._execute( - "get_max_topological_token", None, - sql, room_id, stream_key, - ).addCallback( - lambda r: r[0][0] if r else 0 - ) + "get_max_topological_token", None, sql, room_id, stream_key + ).addCallback(lambda r: r[0][0] if r else 0) def _get_max_topological_txn(self, txn, room_id): txn.execute( - "SELECT MAX(topological_ordering) FROM events" - " WHERE room_id = ?", - (room_id,) + "SELECT MAX(topological_ordering) FROM events" " WHERE room_id = ?", + (room_id,), ) rows = txn.fetchall() @@ -540,14 +550,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): internal = event.internal_metadata internal.before = str(RoomStreamToken(topo, stream - 1)) internal.after = str(RoomStreamToken(topo, stream)) - internal.order = ( - int(topo) if topo else 0, - int(stream), - ) + internal.order = (int(topo) if topo else 0, int(stream)) @defer.inlineCallbacks def get_events_around( - self, room_id, event_id, before_limit, after_limit, event_filter=None, + self, room_id, event_id, before_limit, after_limit, event_filter=None ): """Retrieve events and pagination tokens around a given event in a room. @@ -564,29 +571,34 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): """ results = yield self.runInteraction( - "get_events_around", self._get_events_around_txn, - room_id, event_id, before_limit, after_limit, event_filter, + "get_events_around", + self._get_events_around_txn, + room_id, + event_id, + before_limit, + after_limit, + event_filter, ) events_before = yield self._get_events( - [e for e in results["before"]["event_ids"]], - get_prev_content=True + [e for e in results["before"]["event_ids"]], get_prev_content=True ) events_after = yield self._get_events( - [e for e in results["after"]["event_ids"]], - get_prev_content=True + [e for e in results["after"]["event_ids"]], get_prev_content=True ) - defer.returnValue({ - "events_before": events_before, - "events_after": events_after, - "start": results["before"]["token"], - "end": results["after"]["token"], - }) + defer.returnValue( + { + "events_before": events_before, + "events_after": events_after, + "start": results["before"]["token"], + "end": results["after"]["token"], + } + ) def _get_events_around_txn( - self, txn, room_id, event_id, before_limit, after_limit, event_filter, + self, txn, room_id, event_id, before_limit, after_limit, event_filter ): """Retrieves event_ids and pagination tokens around a given event in a room. @@ -605,46 +617,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): results = self._simple_select_one_txn( txn, "events", - keyvalues={ - "event_id": event_id, - "room_id": room_id, - }, + keyvalues={"event_id": event_id, "room_id": room_id}, retcols=["stream_ordering", "topological_ordering"], ) # Paginating backwards includes the event at the token, but paginating # forward doesn't. before_token = RoomStreamToken( - results["topological_ordering"] - 1, - results["stream_ordering"], + results["topological_ordering"] - 1, results["stream_ordering"] ) after_token = RoomStreamToken( - results["topological_ordering"], - results["stream_ordering"], + results["topological_ordering"], results["stream_ordering"] ) rows, start_token = self._paginate_room_events_txn( - txn, room_id, before_token, direction='b', limit=before_limit, + txn, + room_id, + before_token, + direction='b', + limit=before_limit, event_filter=event_filter, ) events_before = [r.event_id for r in rows] rows, end_token = self._paginate_room_events_txn( - txn, room_id, after_token, direction='f', limit=after_limit, + txn, + room_id, + after_token, + direction='f', + limit=after_limit, event_filter=event_filter, ) events_after = [r.event_id for r in rows] return { - "before": { - "event_ids": events_before, - "token": start_token, - }, - "after": { - "event_ids": events_after, - "token": end_token, - }, + "before": {"event_ids": events_before, "token": start_token}, + "after": {"event_ids": events_after, "token": end_token}, } @defer.inlineCallbacks @@ -685,7 +694,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, [row[1] for row in rows] upper_bound, event_ids = yield self.runInteraction( - "get_all_new_events_stream", get_all_new_events_stream_txn, + "get_all_new_events_stream", get_all_new_events_stream_txn ) events = yield self._get_events(event_ids) @@ -697,7 +706,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): table="federation_stream_position", retcol="stream_id", keyvalues={"type": typ}, - desc="get_federation_out_pos" + desc="get_federation_out_pos", ) def update_federation_out_pos(self, typ, stream_id): @@ -711,8 +720,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): def has_room_changed_since(self, room_id, stream_id): return self._events_stream_cache.has_entity_changed(room_id, stream_id) - def _paginate_room_events_txn(self, txn, room_id, from_token, to_token=None, - direction='b', limit=-1, event_filter=None): + def _paginate_room_events_txn( + self, + txn, + room_id, + from_token, + to_token=None, + direction='b', + limit=-1, + event_filter=None, + ): """Returns list of events before or after a given token. Args: @@ -741,22 +758,20 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): args = [False, room_id] if direction == 'b': order = "DESC" - bounds = upper_bound( - from_token, self.database_engine - ) + bounds = upper_bound(from_token, self.database_engine) if to_token: - bounds = "%s AND %s" % (bounds, lower_bound( - to_token, self.database_engine - )) + bounds = "%s AND %s" % ( + bounds, + lower_bound(to_token, self.database_engine), + ) else: order = "ASC" - bounds = lower_bound( - from_token, self.database_engine - ) + bounds = lower_bound(from_token, self.database_engine) if to_token: - bounds = "%s AND %s" % (bounds, upper_bound( - to_token, self.database_engine - )) + bounds = "%s AND %s" % ( + bounds, + upper_bound(to_token, self.database_engine), + ) filter_clause, filter_args = filter_to_clause(event_filter) @@ -772,10 +787,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): " WHERE outlier = ? AND room_id = ? AND %(bounds)s" " ORDER BY topological_ordering %(order)s," " stream_ordering %(order)s LIMIT ?" - ) % { - "bounds": bounds, - "order": order, - } + ) % {"bounds": bounds, "order": order} txn.execute(sql, args) @@ -796,11 +808,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # TODO (erikj): We should work out what to do here instead. next_token = to_token if to_token else from_token - return rows, str(next_token), + return rows, str(next_token) @defer.inlineCallbacks - def paginate_room_events(self, room_id, from_key, to_key=None, - direction='b', limit=-1, event_filter=None): + def paginate_room_events( + self, room_id, from_key, to_key=None, direction='b', limit=-1, event_filter=None + ): """Returns list of events before or after a given token. Args: @@ -826,13 +839,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): to_key = RoomStreamToken.parse(to_key) rows, token = yield self.runInteraction( - "paginate_room_events", self._paginate_room_events_txn, - room_id, from_key, to_key, direction, limit, event_filter, + "paginate_room_events", + self._paginate_room_events_txn, + room_id, + from_key, + to_key, + direction, + limit, + event_filter, ) events = yield self._get_events( - [r.event_id for r in rows], - get_prev_content=True + [r.event_id for r in rows], get_prev_content=True ) self._set_before_and_after(events, rows) diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index 0f657b2bd3..e88f8ea35f 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -84,9 +84,7 @@ class TagsWorkerStore(AccountDataWorkerStore): def get_tag_content(txn, tag_ids): sql = ( - "SELECT tag, content" - " FROM room_tags" - " WHERE user_id=? AND room_id=?" + "SELECT tag, content" " FROM room_tags" " WHERE user_id=? AND room_id=?" ) results = [] for stream_id, user_id, room_id in tag_ids: @@ -105,7 +103,7 @@ class TagsWorkerStore(AccountDataWorkerStore): tags = yield self.runInteraction( "get_all_updated_tag_content", get_tag_content, - tag_ids[i:i + batch_size], + tag_ids[i : i + batch_size], ) results.extend(tags) @@ -123,6 +121,7 @@ class TagsWorkerStore(AccountDataWorkerStore): A deferred dict mapping from room_id strings to lists of tag strings for all the rooms that changed since the stream_id token. """ + def get_updated_tags_txn(txn): sql = ( "SELECT room_id from room_tags_revisions" @@ -138,9 +137,7 @@ class TagsWorkerStore(AccountDataWorkerStore): if not changed: defer.returnValue({}) - room_ids = yield self.runInteraction( - "get_updated_tags", get_updated_tags_txn - ) + room_ids = yield self.runInteraction("get_updated_tags", get_updated_tags_txn) results = {} if room_ids: @@ -163,9 +160,9 @@ class TagsWorkerStore(AccountDataWorkerStore): keyvalues={"user_id": user_id, "room_id": room_id}, retcols=("tag", "content"), desc="get_tags_for_room", - ).addCallback(lambda rows: { - row["tag"]: json.loads(row["content"]) for row in rows - }) + ).addCallback( + lambda rows: {row["tag"]: json.loads(row["content"]) for row in rows} + ) class TagsStore(TagsWorkerStore): @@ -186,14 +183,8 @@ class TagsStore(TagsWorkerStore): self._simple_upsert_txn( txn, table="room_tags", - keyvalues={ - "user_id": user_id, - "room_id": room_id, - "tag": tag, - }, - values={ - "content": content_json, - } + keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag}, + values={"content": content_json}, ) self._update_revision_txn(txn, user_id, room_id, next_id) @@ -211,6 +202,7 @@ class TagsStore(TagsWorkerStore): Returns: A deferred that completes once the tag has been removed """ + def remove_tag_txn(txn, next_id): sql = ( "DELETE FROM room_tags " @@ -238,8 +230,7 @@ class TagsStore(TagsWorkerStore): """ txn.call_after( - self._account_data_stream_cache.entity_has_changed, - user_id, next_id + self._account_data_stream_cache.entity_has_changed, user_id, next_id ) update_max_id_sql = ( diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index d8bf953ec0..b1188f6bcb 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -38,16 +38,12 @@ logger = logging.getLogger(__name__) _TransactionRow = namedtuple( - "_TransactionRow", ( - "id", "transaction_id", "destination", "ts", "response_code", - "response_json", - ) + "_TransactionRow", + ("id", "transaction_id", "destination", "ts", "response_code", "response_json"), ) _UpdateTransactionRow = namedtuple( - "_TransactionRow", ( - "response_code", "response_json", - ) + "_TransactionRow", ("response_code", "response_json") ) SENTINEL = object() @@ -84,19 +80,22 @@ class TransactionStore(SQLBaseStore): return self.runInteraction( "get_received_txn_response", - self._get_received_txn_response, transaction_id, origin + self._get_received_txn_response, + transaction_id, + origin, ) def _get_received_txn_response(self, txn, transaction_id, origin): result = self._simple_select_one_txn( txn, table="received_transactions", - keyvalues={ - "transaction_id": transaction_id, - "origin": origin, - }, + keyvalues={"transaction_id": transaction_id, "origin": origin}, retcols=( - "transaction_id", "origin", "ts", "response_code", "response_json", + "transaction_id", + "origin", + "ts", + "response_code", + "response_json", "has_been_referenced", ), allow_none=True, @@ -108,8 +107,7 @@ class TransactionStore(SQLBaseStore): else: return None - def set_received_txn_response(self, transaction_id, origin, code, - response_dict): + def set_received_txn_response(self, transaction_id, origin, code, response_dict): """Persist the response we returened for an incoming transaction, and should return for subsequent transactions with the same transaction_id and origin. @@ -135,8 +133,7 @@ class TransactionStore(SQLBaseStore): desc="set_received_txn_response", ) - def prep_send_transaction(self, transaction_id, destination, - origin_server_ts): + def prep_send_transaction(self, transaction_id, destination, origin_server_ts): """Persists an outgoing transaction and calculates the values for the previous transaction id list. @@ -182,7 +179,9 @@ class TransactionStore(SQLBaseStore): result = yield self.runInteraction( "get_destination_retry_timings", - self._get_destination_retry_timings, destination) + self._get_destination_retry_timings, + destination, + ) # We don't hugely care about race conditions between getting and # invalidating the cache, since we time out fairly quickly anyway. @@ -193,9 +192,7 @@ class TransactionStore(SQLBaseStore): result = self._simple_select_one_txn( txn, table="destinations", - keyvalues={ - "destination": destination, - }, + keyvalues={"destination": destination}, retcols=("destination", "retry_last_ts", "retry_interval"), allow_none=True, ) @@ -205,8 +202,7 @@ class TransactionStore(SQLBaseStore): else: return None - def set_destination_retry_timings(self, destination, - retry_last_ts, retry_interval): + def set_destination_retry_timings(self, destination, retry_last_ts, retry_interval): """Sets the current retry timings for a given destination. Both timings should be zero if retrying is no longer occuring. @@ -225,8 +221,9 @@ class TransactionStore(SQLBaseStore): retry_interval, ) - def _set_destination_retry_timings(self, txn, destination, - retry_last_ts, retry_interval): + def _set_destination_retry_timings( + self, txn, destination, retry_last_ts, retry_interval + ): self.database_engine.lock_table(txn, "destinations") # We need to be careful here as the data may have changed from under us @@ -235,9 +232,7 @@ class TransactionStore(SQLBaseStore): prev_row = self._simple_select_one_txn( txn, table="destinations", - keyvalues={ - "destination": destination, - }, + keyvalues={"destination": destination}, retcols=("retry_last_ts", "retry_interval"), allow_none=True, ) @@ -250,15 +245,13 @@ class TransactionStore(SQLBaseStore): "destination": destination, "retry_last_ts": retry_last_ts, "retry_interval": retry_interval, - } + }, ) elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval: self._simple_update_one_txn( txn, "destinations", - keyvalues={ - "destination": destination, - }, + keyvalues={"destination": destination}, updatevalues={ "retry_last_ts": retry_last_ts, "retry_interval": retry_interval, @@ -273,8 +266,7 @@ class TransactionStore(SQLBaseStore): """ return self.runInteraction( - "get_destinations_needing_retry", - self._get_destinations_needing_retry + "get_destinations_needing_retry", self._get_destinations_needing_retry ) def _get_destinations_needing_retry(self, txn): @@ -288,7 +280,7 @@ class TransactionStore(SQLBaseStore): def _start_cleanup_transactions(self): return run_as_background_process( - "cleanup_transactions", self._cleanup_transactions, + "cleanup_transactions", self._cleanup_transactions ) def _cleanup_transactions(self): diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py index 4d60a5726f..83466e25d9 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/user_directory.py @@ -194,7 +194,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): room_id ) - users_with_profile = yield state.get_current_user_in_room(room_id) + users_with_profile = yield state.get_current_users_in_room(room_id) user_ids = set(users_with_profile) # Update each user in the user directory. diff --git a/synapse/storage/user_erasure_store.py b/synapse/storage/user_erasure_store.py index be013f4427..1815fdc0dd 100644 --- a/synapse/storage/user_erasure_store.py +++ b/synapse/storage/user_erasure_store.py @@ -40,9 +40,7 @@ class UserErasureWorkerStore(SQLBaseStore): ).addCallback(operator.truth) @cachedList( - cached_method_name="is_user_erased", - list_name="user_ids", - inlineCallbacks=True, + cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True ) def are_users_erased(self, user_ids): """ @@ -61,16 +59,13 @@ class UserErasureWorkerStore(SQLBaseStore): def _get_erased_users(txn): txn.execute( - "SELECT user_id FROM erased_users WHERE user_id IN (%s)" % ( - ",".join("?" * len(user_ids)) - ), + "SELECT user_id FROM erased_users WHERE user_id IN (%s)" + % (",".join("?" * len(user_ids))), user_ids, ) return set(r[0] for r in txn) - erased_users = yield self.runInteraction( - "are_users_erased", _get_erased_users, - ) + erased_users = yield self.runInteraction("are_users_erased", _get_erased_users) res = dict((u, u in erased_users) for u in user_ids) defer.returnValue(res) @@ -82,22 +77,16 @@ class UserErasureStore(UserErasureWorkerStore): Args: user_id (str): full user_id to be erased """ + def f(txn): # first check if they are already in the list - txn.execute( - "SELECT 1 FROM erased_users WHERE user_id = ?", - (user_id, ) - ) + txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,)) if txn.fetchone(): return # they are not already there: do the insert. - txn.execute( - "INSERT INTO erased_users (user_id) VALUES (?)", - (user_id, ) - ) + txn.execute("INSERT INTO erased_users (user_id) VALUES (?)", (user_id,)) + + self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - self._invalidate_cache_and_stream( - txn, self.is_user_erased, (user_id,) - ) return self.runInteraction("mark_user_erased", f) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index d6160d5e4d..f1c8d99419 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -43,9 +43,9 @@ def _load_current_id(db_conn, table, column, step=1): """ cur = db_conn.cursor() if step == 1: - cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) + cur.execute("SELECT MAX(%s) FROM %s" % (column, table)) else: - cur.execute("SELECT MIN(%s) FROM %s" % (column, table,)) + cur.execute("SELECT MIN(%s) FROM %s" % (column, table)) val, = cur.fetchone() cur.close() current_id = int(val) if val else step @@ -77,6 +77,7 @@ class StreamIdGenerator(object): with stream_id_gen.get_next() as stream_id: # ... persist event ... """ + def __init__(self, db_conn, table, column, extra_tables=[], step=1): assert step != 0 self._lock = threading.Lock() @@ -84,8 +85,7 @@ class StreamIdGenerator(object): self._current = _load_current_id(db_conn, table, column, step) for table, column in extra_tables: self._current = (max if step > 0 else min)( - self._current, - _load_current_id(db_conn, table, column, step) + self._current, _load_current_id(db_conn, table, column, step) ) self._unfinished_ids = deque() @@ -121,7 +121,7 @@ class StreamIdGenerator(object): next_ids = range( self._current + self._step, self._current + self._step * (n + 1), - self._step + self._step, ) self._current += n * self._step diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index f0e4a0e10c..2f16f23d91 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd. +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index 9cb7e9c9ab..628a2962d9 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -1,4 +1,5 @@ # Copyright 2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,10 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys +import traceback from twisted.conch import manhole_ssh from twisted.conch.insults import insults -from twisted.conch.manhole import ColoredManhole +from twisted.conch.manhole import ColoredManhole, ManholeInterpreter from twisted.conch.ssh.keys import Key from twisted.cred import checkers, portal @@ -79,7 +82,7 @@ def manhole(username, password, globals): rlm = manhole_ssh.TerminalRealm() rlm.chainedProtocolFactory = lambda: insults.ServerProtocol( - ColoredManhole, + SynapseManhole, dict(globals, __name__="__console__") ) @@ -88,3 +91,55 @@ def manhole(username, password, globals): factory.privateKeys[b'ssh-rsa'] = Key.fromString(PRIVATE_KEY) return factory + + +class SynapseManhole(ColoredManhole): + """Overrides connectionMade to create our own ManholeInterpreter""" + def connectionMade(self): + super(SynapseManhole, self).connectionMade() + + # replace the manhole interpreter with our own impl + self.interpreter = SynapseManholeInterpreter(self, self.namespace) + + # this would also be a good place to add more keyHandlers. + + +class SynapseManholeInterpreter(ManholeInterpreter): + def showsyntaxerror(self, filename=None): + """Display the syntax error that just occurred. + + Overrides the base implementation, ignoring sys.excepthook. We always want + any syntax errors to be sent to the terminal, rather than sentry. + """ + type, value, tb = sys.exc_info() + sys.last_type = type + sys.last_value = value + sys.last_traceback = tb + if filename and type is SyntaxError: + # Work hard to stuff the correct filename in the exception + try: + msg, (dummy_filename, lineno, offset, line) = value.args + except ValueError: + # Not the format we expect; leave it alone + pass + else: + # Stuff in the right filename + value = SyntaxError(msg, (filename, lineno, offset, line)) + sys.last_value = value + lines = traceback.format_exception_only(type, value) + self.write(''.join(lines)) + + def showtraceback(self): + """Display the exception that just occurred. + + Overrides the base implementation, ignoring sys.excepthook. We always want + any syntax errors to be sent to the terminal, rather than sentry. + """ + sys.last_type, sys.last_value, last_tb = ei = sys.exc_info() + sys.last_traceback = last_tb + try: + # We remove the first stack item because it is our own code. + lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next) + self.write(''.join(lines)) + finally: + last_tb = ei = None |