diff options
Diffstat (limited to 'synapse')
97 files changed, 2603 insertions, 2030 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 9fd52a8c77..abbc7079a3 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -79,7 +79,7 @@ class Auth(object): @defer.inlineCallbacks def check_from_context(self, room_version, event, context, do_sig_check=True): - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() auth_events_ids = yield self.compute_auth_events( event, prev_state_ids, for_verification=True ) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 9c96816096..0e8b467a3e 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -237,6 +237,12 @@ def start(hs, listeners=None): """ Start a Synapse server or worker. + Should be called once the reactor is running and (if we're using ACME) the + TLS certificates are in place. + + Will start the main HTTP listeners and do some other startup tasks, and then + notify systemd. + Args: hs (synapse.server.HomeServer) listeners (list[dict]): Listener configuration ('listeners' in homeserver.yaml) @@ -311,9 +317,7 @@ def setup_sdnotify(hs): # Tell systemd our state, if we're using it. This will silently fail if # we're not using systemd. - hs.get_reactor().addSystemEventTrigger( - "after", "startup", sdnotify, b"READY=1\nMAINPID=%i" % (os.getpid(),) - ) + sdnotify(b"READY=1\nMAINPID=%i" % (os.getpid(),)) hs.get_reactor().addSystemEventTrigger( "before", "shutdown", sdnotify, b"STOPPING=1" diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 04751a6a5e..8e36bc57d3 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -45,7 +45,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.slave.storage.room import RoomStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.logcontext import LoggingContext from synapse.util.versionstring import get_version_string @@ -105,8 +104,10 @@ def export_data_command(hs, args): user_id = args.user_id directory = args.output_directory - res = yield hs.get_handlers().admin_handler.export_user_data( - user_id, FileExfiltrationWriter(user_id, directory=directory) + res = yield defer.ensureDeferred( + hs.get_handlers().admin_handler.export_user_data( + user_id, FileExfiltrationWriter(user_id, directory=directory) + ) ) print(res) @@ -229,14 +230,10 @@ def start(config_options): synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = AdminCmdServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 02b900f382..e82e0f11e3 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -34,7 +34,6 @@ from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -143,8 +142,6 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - if config.notify_appservices: sys.stderr.write( "\nThe appservices must be disabled in the main synapse process" @@ -159,10 +156,8 @@ def start(config_options): ps = AppserviceServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ps, config, use_worker_options=True) diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index dadb487d5f..3edfe19567 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -62,7 +62,6 @@ from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.rest.client.versions import VersionsRestServlet from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -181,14 +180,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = ClientReaderServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index d110599a35..d0ddbe38fc 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -57,7 +57,6 @@ from synapse.rest.client.v1.room import ( ) from synapse.server import HomeServer from synapse.storage.data_stores.main.user_directory import UserDirectoryStore -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -180,14 +179,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = EventCreatorServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 418c086254..311523e0ed 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -46,7 +46,6 @@ from synapse.replication.slave.storage.transactions import SlavedTransactionStor from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.key.v2 import KeyApiV2Resource from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -162,14 +161,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = FederationReaderServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index f24920a7d6..83c436229c 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -41,7 +41,6 @@ from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.streams._base import ReceiptsStream from synapse.server import HomeServer from synapse.storage.database import Database -from synapse.storage.engines import create_engine from synapse.types import ReadReceipt from synapse.util.async_helpers import Linearizer from synapse.util.httpresourcetree import create_resource_tree @@ -174,8 +173,6 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - if config.send_federation: sys.stderr.write( "\nThe send_federation must be disabled in the main synapse process" @@ -190,10 +187,8 @@ def start(config_options): ss = FederationSenderServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index e647459d0e..30e435eead 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -39,7 +39,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.client.v2_alpha._base import client_patterns from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -234,14 +233,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = FrontendProxyServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index df65d0a989..0e9bf7f53a 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -69,7 +69,7 @@ from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer from synapse.storage import DataStore -from synapse.storage.engines import IncorrectDatabaseSetup, create_engine +from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.prepare_database import UpgradeDatabaseException from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.httpresourcetree import create_resource_tree @@ -328,15 +328,10 @@ def setup(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection - hs = SynapseHomeServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) synapse.config.logger.setup_logging(hs, config, use_worker_options=False) @@ -347,13 +342,8 @@ def setup(config_options): hs.setup() except IncorrectDatabaseSetup as e: quit_with_error(str(e)) - except UpgradeDatabaseException: - sys.stderr.write( - "\nFailed to upgrade database.\n" - "Have you checked for version specific instructions in" - " UPGRADES.rst?\n" - ) - sys.exit(1) + except UpgradeDatabaseException as e: + quit_with_error("Failed to upgrade database: %s" % (e,)) hs.setup_master() @@ -519,8 +509,10 @@ def phone_stats_home(hs, stats, stats_process=_stats_process): # Database version # - stats["database_engine"] = hs.database_engine.module.__name__ - stats["database_server_version"] = hs.database_engine.server_version + # This only reports info about the *main* database. + stats["database_engine"] = hs.get_datastore().db.engine.module.__name__ + stats["database_server_version"] = hs.get_datastore().db.engine.server_version + logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) try: yield hs.get_proxied_http_client().put_json( diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index 2c6dd3ef02..4c80f257e2 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -40,7 +40,6 @@ from synapse.rest.admin import register_servlets_for_media_repo from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.server import HomeServer from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -157,14 +156,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = MediaRepositoryServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index dd52a9fc2d..09e639040a 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -37,7 +37,6 @@ from synapse.replication.slave.storage.room import RoomStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer from synapse.storage import DataStore -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -203,14 +202,10 @@ def start(config_options): # Force the pushers to start since they will be disabled in the main config config.start_pushers = True - database_engine = create_engine(config.database_config) - ps = PusherServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ps, config, use_worker_options=True) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 288ee64b42..dd2132e608 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -55,7 +55,6 @@ from synapse.rest.client.v1.room import RoomInitialSyncRestServlet from synapse.rest.client.v2_alpha import sync from synapse.server import HomeServer from synapse.storage.data_stores.main.presence import UserPresenceState -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.stringutils import random_string @@ -437,14 +436,10 @@ def start(config_options): synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = SynchrotronServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, application_service_handler=SynchrotronApplicationService(), ) diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index c01fb34a9b..1257098f92 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -44,7 +44,6 @@ from synapse.rest.client.v2_alpha import user_directory from synapse.server import HomeServer from synapse.storage.data_stores.main.user_directory import UserDirectoryStore from synapse.storage.database import Database -from synapse.storage.engines import create_engine from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole @@ -200,8 +199,6 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - if config.update_user_directory: sys.stderr.write( "\nThe update_user_directory must be disabled in the main synapse process" @@ -216,10 +213,8 @@ def start(config_options): ss = UserDirectoryServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/config/database.py b/synapse/config/database.py index 0e2509f0b1..134824789c 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -12,12 +12,45 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import os from textwrap import indent +from typing import List import yaml -from ._base import Config +from synapse.config._base import Config, ConfigError + +logger = logging.getLogger(__name__) + + +class DatabaseConnectionConfig: + """Contains the connection config for a particular database. + + Args: + name: A label for the database, used for logging. + db_config: The config for a particular database, as per `database` + section of main config. Has two fields: `name` for database + module name, and `args` for the args to give to the database + connector. + data_stores: The list of data stores that should be provisioned on the + database. Defaults to all data stores. + """ + + def __init__( + self, name: str, db_config: dict, data_stores: List[str] = ["main", "state"] + ): + if db_config["name"] not in ("sqlite3", "psycopg2"): + raise ConfigError("Unsupported database type %r" % (db_config["name"],)) + + if db_config["name"] == "sqlite3": + db_config.setdefault("args", {}).update( + {"cp_min": 1, "cp_max": 1, "check_same_thread": False} + ) + + self.name = name + self.config = db_config + self.data_stores = data_stores class DatabaseConfig(Config): @@ -26,20 +59,12 @@ class DatabaseConfig(Config): def read_config(self, config, **kwargs): self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K")) - self.database_config = config.get("database") + database_config = config.get("database") - if self.database_config is None: - self.database_config = {"name": "sqlite3", "args": {}} + if database_config is None: + database_config = {"name": "sqlite3", "args": {}} - name = self.database_config.get("name", None) - if name == "psycopg2": - pass - elif name == "sqlite3": - self.database_config.setdefault("args", {}).update( - {"cp_min": 1, "cp_max": 1, "check_same_thread": False} - ) - else: - raise RuntimeError("Unsupported database type '%s'" % (name,)) + self.databases = [DatabaseConnectionConfig("master", database_config)] self.set_databasepath(config.get("database_path")) @@ -76,11 +101,24 @@ class DatabaseConfig(Config): self.set_databasepath(args.database_path) def set_databasepath(self, database_path): + if database_path is None: + return + if database_path != ":memory:": database_path = self.abspath(database_path) - if self.database_config.get("name", None) == "sqlite3": - if database_path is not None: - self.database_config["args"]["database"] = database_path + + # We only support setting a database path if we have a single sqlite3 + # database. + if len(self.databases) != 1: + raise ConfigError("Cannot specify 'database_path' with multiple databases") + + database = self.get_single_database() + if database.config["name"] != "sqlite3": + # We don't raise here as we haven't done so before for this case. + logger.warn("Ignoring 'database_path' for non-sqlite3 database") + return + + database.config["args"]["database"] = database_path @staticmethod def add_arguments(parser): @@ -91,3 +129,11 @@ class DatabaseConfig(Config): metavar="SQLITE_DATABASE_PATH", help="The path to a sqlite database to use.", ) + + def get_single_database(self) -> DatabaseConnectionConfig: + """Returns the database if there is only one, useful for e.g. tests + """ + if len(self.databases) != 1: + raise Exception("More than one database exists") + + return self.databases[0] diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 18f42a87f9..35756bed87 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -21,6 +21,7 @@ from __future__ import print_function import email.utils import os from enum import Enum +from typing import Optional import pkg_resources @@ -101,7 +102,7 @@ class EmailConfig(Config): # both in RegistrationConfig and here. We should factor this bit out self.account_threepid_delegate_email = self.trusted_third_party_id_servers[ 0 - ] + ] # type: Optional[str] self.using_identity_server_from_trusted_list = True else: raise ConfigError( diff --git a/synapse/config/key.py b/synapse/config/key.py index 52ff1b2621..066e7838c3 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -108,7 +108,7 @@ class KeyConfig(Config): self.signing_key = self.read_signing_keys(signing_key_path, "signing_key") self.old_signing_keys = self.read_old_signing_keys( - config.get("old_signing_keys", {}) + config.get("old_signing_keys") ) self.key_refresh_interval = self.parse_duration( config.get("key_refresh_interval", "1d") @@ -199,14 +199,19 @@ class KeyConfig(Config): signing_key_path: "%(base_key_name)s.signing.key" # The keys that the server used to sign messages with but won't use - # to sign new messages. E.g. it has lost its private key + # to sign new messages. # - #old_signing_keys: - # "ed25519:auto": - # # Base64 encoded public key - # key: "The public part of your old signing key." - # # Millisecond POSIX timestamp when the key expired. - # expired_ts: 123456789123 + old_signing_keys: + # For each key, `key` should be the base64-encoded public key, and + # `expired_ts`should be the time (in milliseconds since the unix epoch) that + # it was last used. + # + # It is possible to build an entry from an old signing.key file using the + # `export_signing_key` script which is provided with synapse. + # + # For example: + # + #"ed25519:id": { key: "base64string", expired_ts: 123456789123 } # How long key response published by this server is valid for. # Used to set the valid_until_ts in /key/v2 APIs. @@ -290,6 +295,8 @@ class KeyConfig(Config): raise ConfigError("Error reading %s: %s" % (name, str(e))) def read_old_signing_keys(self, old_signing_keys): + if old_signing_keys is None: + return {} keys = {} for key_id, key_data in old_signing_keys.items(): if is_signing_algorithm_supported(key_id): diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 947f653e03..4a3bfc4354 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -83,10 +83,9 @@ class RatelimitConfig(Config): ) rc_admin_redaction = config.get("rc_admin_redaction") + self.rc_admin_redaction = None if rc_admin_redaction: self.rc_admin_redaction = RateLimitConfig(rc_admin_redaction) - else: - self.rc_admin_redaction = None def generate_config_section(self, **kwargs): return """\ diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index c5ea2d43a1..b91414aa35 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -14,17 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re +import logging from synapse.python_dependencies import DependencyException, check_requirements -from synapse.types import ( - map_username_to_mxid_localpart, - mxid_localpart_allowed_characters, -) -from synapse.util.module_loader import load_python_module +from synapse.util.module_loader import load_module, load_python_module from ._base import Config, ConfigError +logger = logging.getLogger(__name__) + +DEFAULT_USER_MAPPING_PROVIDER = ( + "synapse.handlers.saml_handler.DefaultSamlMappingProvider" +) + def _dict_merge(merge_dict, into_dict): """Do a deep merge of two dicts @@ -75,15 +77,69 @@ class SAML2Config(Config): self.saml2_enabled = True - self.saml2_mxid_source_attribute = saml2_config.get( - "mxid_source_attribute", "uid" - ) - self.saml2_grandfathered_mxid_source_attribute = saml2_config.get( "grandfathered_mxid_source_attribute", "uid" ) - saml2_config_dict = self._default_saml_config_dict() + # user_mapping_provider may be None if the key is present but has no value + ump_dict = saml2_config.get("user_mapping_provider") or {} + + # Use the default user mapping provider if not set + ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) + + # Ensure a config is present + ump_dict["config"] = ump_dict.get("config") or {} + + if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER: + # Load deprecated options for use by the default module + old_mxid_source_attribute = saml2_config.get("mxid_source_attribute") + if old_mxid_source_attribute: + logger.warning( + "The config option saml2_config.mxid_source_attribute is deprecated. " + "Please use saml2_config.user_mapping_provider.config" + ".mxid_source_attribute instead." + ) + ump_dict["config"]["mxid_source_attribute"] = old_mxid_source_attribute + + old_mxid_mapping = saml2_config.get("mxid_mapping") + if old_mxid_mapping: + logger.warning( + "The config option saml2_config.mxid_mapping is deprecated. Please " + "use saml2_config.user_mapping_provider.config.mxid_mapping instead." + ) + ump_dict["config"]["mxid_mapping"] = old_mxid_mapping + + # Retrieve an instance of the module's class + # Pass the config dictionary to the module for processing + ( + self.saml2_user_mapping_provider_class, + self.saml2_user_mapping_provider_config, + ) = load_module(ump_dict) + + # Ensure loaded user mapping module has defined all necessary methods + # Note parse_config() is already checked during the call to load_module + required_methods = [ + "get_saml_attributes", + "saml_response_to_user_attributes", + ] + missing_methods = [ + method + for method in required_methods + if not hasattr(self.saml2_user_mapping_provider_class, method) + ] + if missing_methods: + raise ConfigError( + "Class specified by saml2_config." + "user_mapping_provider.module is missing required " + "methods: %s" % (", ".join(missing_methods),) + ) + + # Get the desired saml auth response attributes from the module + saml2_config_dict = self._default_saml_config_dict( + *self.saml2_user_mapping_provider_class.get_saml_attributes( + self.saml2_user_mapping_provider_config + ) + ) _dict_merge( merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict ) @@ -103,22 +159,27 @@ class SAML2Config(Config): saml2_config.get("saml_session_lifetime", "5m") ) - mapping = saml2_config.get("mxid_mapping", "hexencode") - try: - self.saml2_mxid_mapper = MXID_MAPPER_MAP[mapping] - except KeyError: - raise ConfigError("%s is not a known mxid_mapping" % (mapping,)) - - def _default_saml_config_dict(self): + def _default_saml_config_dict( + self, required_attributes: set, optional_attributes: set + ): + """Generate a configuration dictionary with required and optional attributes that + will be needed to process new user registration + + Args: + required_attributes: SAML auth response attributes that are + necessary to function + optional_attributes: SAML auth response attributes that can be used to add + additional information to Synapse user accounts, but are not required + + Returns: + dict: A SAML configuration dictionary + """ import saml2 public_baseurl = self.public_baseurl if public_baseurl is None: raise ConfigError("saml2_config requires a public_baseurl to be set") - required_attributes = {"uid", self.saml2_mxid_source_attribute} - - optional_attributes = {"displayName"} if self.saml2_grandfathered_mxid_source_attribute: optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute) optional_attributes -= required_attributes @@ -207,33 +268,58 @@ class SAML2Config(Config): # #config_path: "%(config_dir_path)s/sp_conf.py" - # the lifetime of a SAML session. This defines how long a user has to + # The lifetime of a SAML session. This defines how long a user has to # complete the authentication process, if allow_unsolicited is unset. # The default is 5 minutes. # #saml_session_lifetime: 5m - # The SAML attribute (after mapping via the attribute maps) to use to derive - # the Matrix ID from. 'uid' by default. + # An external module can be provided here as a custom solution to + # mapping attributes returned from a saml provider onto a matrix user. # - #mxid_source_attribute: displayName - - # The mapping system to use for mapping the saml attribute onto a matrix ID. - # Options include: - # * 'hexencode' (which maps unpermitted characters to '=xx') - # * 'dotreplace' (which replaces unpermitted characters with '.'). - # The default is 'hexencode'. - # - #mxid_mapping: dotreplace - - # In previous versions of synapse, the mapping from SAML attribute to MXID was - # always calculated dynamically rather than stored in a table. For backwards- - # compatibility, we will look for user_ids matching such a pattern before - # creating a new account. + user_mapping_provider: + # The custom module's class. Uncomment to use a custom module. + # + #module: mapping_provider.SamlMappingProvider + + # Custom configuration values for the module. Below options are + # intended for the built-in provider, they should be changed if + # using a custom module. This section will be passed as a Python + # dictionary to the module's `parse_config` method. + # + config: + # The SAML attribute (after mapping via the attribute maps) to use + # to derive the Matrix ID from. 'uid' by default. + # + # Note: This used to be configured by the + # saml2_config.mxid_source_attribute option. If that is still + # defined, its value will be used instead. + # + #mxid_source_attribute: displayName + + # The mapping system to use for mapping the saml attribute onto a + # matrix ID. + # + # Options include: + # * 'hexencode' (which maps unpermitted characters to '=xx') + # * 'dotreplace' (which replaces unpermitted characters with + # '.'). + # The default is 'hexencode'. + # + # Note: This used to be configured by the + # saml2_config.mxid_mapping option. If that is still defined, its + # value will be used instead. + # + #mxid_mapping: dotreplace + + # In previous versions of synapse, the mapping from SAML attribute to + # MXID was always calculated dynamically rather than stored in a + # table. For backwards- compatibility, we will look for user_ids + # matching such a pattern before creating a new account. # # This setting controls the SAML attribute which will be used for this - # backwards-compatibility lookup. Typically it should be 'uid', but if the - # attribute maps are changed, it may be necessary to change it. + # backwards-compatibility lookup. Typically it should be 'uid', but if + # the attribute maps are changed, it may be necessary to change it. # # The default is 'uid'. # @@ -241,23 +327,3 @@ class SAML2Config(Config): """ % { "config_dir_path": config_dir_path } - - -DOT_REPLACE_PATTERN = re.compile( - ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) -) - - -def dot_replace_for_mxid(username: str) -> str: - username = username.lower() - username = DOT_REPLACE_PATTERN.sub(".", username) - - # regular mxids aren't allowed to start with an underscore either - username = re.sub("^_", "", username) - return username - - -MXID_MAPPER_MAP = { - "hexencode": map_username_to_mxid_localpart, - "dotreplace": dot_replace_for_mxid, -} diff --git a/synapse/config/server.py b/synapse/config/server.py index a4bef00936..38f6ff9edc 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -102,6 +102,12 @@ class ServerConfig(Config): "require_auth_for_profile_requests", False ) + # Whether to require sharing a room with a user to retrieve their + # profile data + self.limit_profile_requests_to_users_who_share_rooms = config.get( + "limit_profile_requests_to_users_who_share_rooms", False, + ) + if "restrict_public_rooms_to_local_users" in config and ( "allow_public_rooms_without_auth" in config or "allow_public_rooms_over_federation" in config @@ -200,7 +206,7 @@ class ServerConfig(Config): self.admin_contact = config.get("admin_contact", None) # FIXME: federation_domain_whitelist needs sytests - self.federation_domain_whitelist = None + self.federation_domain_whitelist = None # type: Optional[dict] federation_domain_whitelist = config.get("federation_domain_whitelist", None) if federation_domain_whitelist is not None: @@ -621,6 +627,13 @@ class ServerConfig(Config): # #require_auth_for_profile_requests: true + # Uncomment to require a user to share a room with another user in order + # to retrieve their profile information. Only checked on Client-Server + # requests. Profile requests from other servers should be checked by the + # requesting server. Defaults to 'false'. + # + #limit_profile_requests_to_users_who_share_rooms: true + # If set to 'true', removes the need for authentication to access the server's # public rooms directory through the client API, meaning that anyone can # query the room directory. Defaults to 'false'. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 350ed9351f..1033e5e121 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -43,6 +43,8 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru Returns: if the auth checks pass. """ + assert isinstance(auth_events, dict) + if do_size_check: _check_size_limits(event) @@ -87,12 +89,6 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru if not event.signatures.get(event_id_domain): raise AuthError(403, "Event not signed by sending server") - if auth_events is None: - # Oh, we don't know what the state of the room was, so we - # are trusting that this is allowed (at least for now) - logger.warning("Trusting event: %s", event.event_id) - return - if event.type == EventTypes.Create: sender_domain = get_domain_from_id(event.sender) room_id_domain = get_domain_from_id(event.room_id) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 64e898f40c..a44baea365 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -149,7 +149,7 @@ class EventContext: # the prev_state_ids, so if we're a state event we include the event # id that we replaced in the state. if event.is_state(): - prev_state_ids = yield self.get_prev_state_ids(store) + prev_state_ids = yield self.get_prev_state_ids() prev_state_id = prev_state_ids.get((event.type, event.state_key)) else: prev_state_id = None @@ -167,12 +167,13 @@ class EventContext: } @staticmethod - def deserialize(store, input): + def deserialize(storage, input): """Converts a dict that was produced by `serialize` back into a EventContext. Args: - store (DataStore): Used to convert AS ID to AS object + storage (Storage): Used to convert AS ID to AS object and fetch + state. input (dict): A dict produced by `serialize` Returns: @@ -181,6 +182,7 @@ class EventContext: context = _AsyncEventContextImpl( # We use the state_group and prev_state_id stuff to pull the # current_state_ids out of the DB and construct prev_state_ids. + storage=storage, prev_state_id=input["prev_state_id"], event_type=input["event_type"], event_state_key=input["event_state_key"], @@ -193,7 +195,7 @@ class EventContext: app_service_id = input["app_service_id"] if app_service_id: - context.app_service = store.get_app_service_by_id(app_service_id) + context.app_service = storage.main.get_app_service_by_id(app_service_id) return context @@ -216,7 +218,7 @@ class EventContext: return self._state_group @defer.inlineCallbacks - def get_current_state_ids(self, store): + def get_current_state_ids(self): """ Gets the room state map, including this event - ie, the state in ``state_group`` @@ -234,11 +236,11 @@ class EventContext: if self.rejected: raise RuntimeError("Attempt to access state_ids of rejected event") - yield self._ensure_fetched(store) + yield self._ensure_fetched() return self._current_state_ids @defer.inlineCallbacks - def get_prev_state_ids(self, store): + def get_prev_state_ids(self): """ Gets the room state map, excluding this event. @@ -250,7 +252,7 @@ class EventContext: Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - yield self._ensure_fetched(store) + yield self._ensure_fetched() return self._prev_state_ids def get_cached_current_state_ids(self): @@ -270,7 +272,7 @@ class EventContext: return self._current_state_ids - def _ensure_fetched(self, store): + def _ensure_fetched(self): return defer.succeed(None) @@ -282,6 +284,8 @@ class _AsyncEventContextImpl(EventContext): Attributes: + _storage (Storage) + _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have been calculated. None if we haven't started calculating yet @@ -295,28 +299,30 @@ class _AsyncEventContextImpl(EventContext): that was replaced. """ + # This needs to have a default as we're inheriting + _storage = attr.ib(default=None) _prev_state_id = attr.ib(default=None) _event_type = attr.ib(default=None) _event_state_key = attr.ib(default=None) _fetching_state_deferred = attr.ib(default=None) - def _ensure_fetched(self, store): + def _ensure_fetched(self): if not self._fetching_state_deferred: - self._fetching_state_deferred = run_in_background( - self._fill_out_state, store - ) + self._fetching_state_deferred = run_in_background(self._fill_out_state) return make_deferred_yieldable(self._fetching_state_deferred) @defer.inlineCallbacks - def _fill_out_state(self, store): + def _fill_out_state(self): """Called to populate the _current_state_ids and _prev_state_ids attributes by loading from the database. """ if self.state_group is None: return - self._current_state_ids = yield store.get_state_ids_for_group(self.state_group) + self._current_state_ids = yield self._storage.state.get_state_ids_for_group( + self.state_group + ) if self._prev_state_id and self._event_state_key is not None: self._prev_state_ids = dict(self._current_state_ids) diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 714a9b1579..86f7e5f8aa 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -53,7 +53,7 @@ class ThirdPartyEventRules(object): if self.third_party_rules is None: return True - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() # Retrieve the state events from the database. state_events = {} diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index d396e6564f..af652a7659 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -526,13 +526,7 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def send_request(destination): - time_now = self._clock.time_msec() - _, content = yield self.transport_layer.send_join( - destination=destination, - room_id=pdu.room_id, - event_id=pdu.event_id, - content=pdu.get_pdu_json(time_now), - ) + content = yield self._do_send_join(destination, pdu) logger.debug("Got content: %s", content) @@ -600,6 +594,44 @@ class FederationClient(FederationBase): return self._try_destination_list("send_join", destinations, send_request) @defer.inlineCallbacks + def _do_send_join(self, destination, pdu): + time_now = self._clock.time_msec() + + try: + content = yield self.transport_layer.send_join_v2( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) + + return content + except HttpResponseException as e: + if e.code in [400, 404]: + err = e.to_synapse_error() + + # If we receive an error response that isn't a generic error, or an + # unrecognised endpoint error, we assume that the remote understands + # the v2 invite API and this is a legitimate error. + if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]: + raise err + else: + raise e.to_synapse_error() + + logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API") + + resp = yield self.transport_layer.send_join_v1( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) + + # We expect the v1 API to respond with [200, content], so we only return the + # content. + return resp[1] + + @defer.inlineCallbacks def send_invite(self, destination, room_id, event_id, pdu): room_version = yield self.store.get_room_version(room_id) @@ -708,18 +740,50 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def send_request(destination): - time_now = self._clock.time_msec() - _, content = yield self.transport_layer.send_leave( + content = yield self._do_send_leave(destination, pdu) + + logger.debug("Got content: %s", content) + return None + + return self._try_destination_list("send_leave", destinations, send_request) + + @defer.inlineCallbacks + def _do_send_leave(self, destination, pdu): + time_now = self._clock.time_msec() + + try: + content = yield self.transport_layer.send_leave_v2( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, content=pdu.get_pdu_json(time_now), ) - logger.debug("Got content: %s", content) - return None + return content + except HttpResponseException as e: + if e.code in [400, 404]: + err = e.to_synapse_error() - return self._try_destination_list("send_leave", destinations, send_request) + # If we receive an error response that isn't a generic error, or an + # unrecognised endpoint error, we assume that the remote understands + # the v2 invite API and this is a legitimate error. + if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]: + raise err + else: + raise e.to_synapse_error() + + logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API") + + resp = yield self.transport_layer.send_leave_v1( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) + + # We expect the v1 API to respond with [200, content], so we only return the + # content. + return resp[1] def get_public_rooms( self, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 84d4eca041..d7ce333822 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -384,15 +384,10 @@ class FederationServer(FederationBase): res_pdus = await self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() - return ( - 200, - { - "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], - "auth_chain": [ - p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] - ], - }, - ) + return { + "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], + "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]], + } async def on_make_leave_request(self, origin, room_id, user_id): origin_host, _ = parse_server_name(origin) @@ -419,7 +414,7 @@ class FederationServer(FederationBase): pdu = await self._check_sigs_and_hash(room_version, pdu) await self.handler.on_send_leave_request(origin, pdu) - return 200, {} + return {} async def on_event_auth(self, origin, room_id, event_id): with (await self._server_linearizer.queue((origin, room_id))): diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 46dba84cac..198257414b 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -243,7 +243,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function - def send_join(self, destination, room_id, event_id, content): + def send_join_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/send_join/%s/%s", room_id, event_id) response = yield self.client.put_json( @@ -254,7 +254,18 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function - def send_leave(self, destination, room_id, event_id, content): + def send_join_v2(self, destination, room_id, event_id, content): + path = _create_v2_path("/send_join/%s/%s", room_id, event_id) + + response = yield self.client.put_json( + destination=destination, path=path, data=content + ) + + return response + + @defer.inlineCallbacks + @log_function + def send_leave_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/send_leave/%s/%s", room_id, event_id) response = yield self.client.put_json( @@ -272,6 +283,24 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function + def send_leave_v2(self, destination, room_id, event_id, content): + path = _create_v2_path("/send_leave/%s/%s", room_id, event_id) + + response = yield self.client.put_json( + destination=destination, + path=path, + data=content, + # we want to do our best to send this through. The problem is + # that if it fails, we won't retry it later, so if the remote + # server was just having a momentary blip, the room will be out of + # sync. + ignore_backoff=True, + ) + + return response + + @defer.inlineCallbacks + @log_function def send_invite_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/invite/%s/%s", room_id, event_id) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index fefc789c85..b4cbf23394 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -506,11 +506,21 @@ class FederationMakeLeaveServlet(BaseFederationServlet): return 200, content -class FederationSendLeaveServlet(BaseFederationServlet): +class FederationV1SendLeaveServlet(BaseFederationServlet): PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" async def on_PUT(self, origin, content, query, room_id, event_id): content = await self.handler.on_send_leave_request(origin, content, room_id) + return 200, (200, content) + + +class FederationV2SendLeaveServlet(BaseFederationServlet): + PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" + + PREFIX = FEDERATION_V2_PREFIX + + async def on_PUT(self, origin, content, query, room_id, event_id): + content = await self.handler.on_send_leave_request(origin, content, room_id) return 200, content @@ -521,9 +531,21 @@ class FederationEventAuthServlet(BaseFederationServlet): return await self.handler.on_event_auth(origin, context, event_id) -class FederationSendJoinServlet(BaseFederationServlet): +class FederationV1SendJoinServlet(BaseFederationServlet): + PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)" + + async def on_PUT(self, origin, content, query, context, event_id): + # TODO(paul): assert that context/event_id parsed from path actually + # match those given in content + content = await self.handler.on_send_join_request(origin, content, context) + return 200, (200, content) + + +class FederationV2SendJoinServlet(BaseFederationServlet): PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)" + PREFIX = FEDERATION_V2_PREFIX + async def on_PUT(self, origin, content, query, context, event_id): # TODO(paul): assert that context/event_id parsed from path actually # match those given in content @@ -1367,8 +1389,10 @@ FEDERATION_SERVLET_CLASSES = ( FederationMakeJoinServlet, FederationMakeLeaveServlet, FederationEventServlet, - FederationSendJoinServlet, - FederationSendLeaveServlet, + FederationV1SendJoinServlet, + FederationV2SendJoinServlet, + FederationV1SendLeaveServlet, + FederationV2SendLeaveServlet, FederationV1InviteServlet, FederationV2InviteServlet, FederationQueryAuthServlet, diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 29e8ffc295..0ec9be3cb5 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -773,6 +773,11 @@ class GroupsServerHandler(object): if not self.hs.is_mine_id(user_id): yield self.store.maybe_delete_remote_profile_cache(user_id) + # Delete group if the last user has left + users = yield self.store.get_users_in_group(group_id, include_private=True) + if not users: + yield self.store.delete_group(group_id) + return {} @defer.inlineCallbacks diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index d15c6282fb..51413d910e 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -134,7 +134,7 @@ class BaseHandler(object): guest_access = event.content.get("guest_access", "forbidden") if guest_access != "can_join": if context: - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() current_state = yield self.store.get_events( list(current_state_ids.values()) ) diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 2d7e6df6e4..a8d3fbc6de 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - class AccountDataEventSource(object): def __init__(self, hs): @@ -23,15 +21,14 @@ class AccountDataEventSource(object): def get_current_key(self, direction="f"): return self.store.get_max_account_data_stream_id() - @defer.inlineCallbacks - def get_new_events(self, user, from_key, **kwargs): + async def get_new_events(self, user, from_key, **kwargs): user_id = user.to_string() last_stream_id = from_key - current_stream_id = yield self.store.get_max_account_data_stream_id() + current_stream_id = self.store.get_max_account_data_stream_id() results = [] - tags = yield self.store.get_updated_tags(user_id, last_stream_id) + tags = await self.store.get_updated_tags(user_id, last_stream_id) for room_id, room_tags in tags.items(): results.append( @@ -41,7 +38,7 @@ class AccountDataEventSource(object): ( account_data, room_account_data, - ) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id) + ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id) for account_data_type, content in account_data.items(): results.append({"type": account_data_type, "content": content}) @@ -53,7 +50,3 @@ class AccountDataEventSource(object): ) return results, current_stream_id - - @defer.inlineCallbacks - def get_pagination_rows(self, user, config, key): - return [], config.to_id diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index d04e0fe576..829f52eca1 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -18,8 +18,7 @@ import email.utils import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText - -from twisted.internet import defer +from typing import List from synapse.api.errors import StoreError from synapse.logging.context import make_deferred_yieldable @@ -78,42 +77,39 @@ class AccountValidityHandler(object): # run as a background process to make sure that the database transactions # have a logcontext to report to return run_as_background_process( - "send_renewals", self.send_renewal_emails + "send_renewals", self._send_renewal_emails ) self.clock.looping_call(send_emails, 30 * 60 * 1000) - @defer.inlineCallbacks - def send_renewal_emails(self): + async def _send_renewal_emails(self): """Gets the list of users whose account is expiring in the amount of time configured in the ``renew_at`` parameter from the ``account_validity`` configuration, and sends renewal emails to all of these users as long as they have an email 3PID attached to their account. """ - expiring_users = yield self.store.get_users_expiring_soon() + expiring_users = await self.store.get_users_expiring_soon() if expiring_users: for user in expiring_users: - yield self._send_renewal_email( + await self._send_renewal_email( user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"] ) - @defer.inlineCallbacks - def send_renewal_email_to_user(self, user_id): - expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) - yield self._send_renewal_email(user_id, expiration_ts) + async def send_renewal_email_to_user(self, user_id: str): + expiration_ts = await self.store.get_expiration_ts_for_user(user_id) + await self._send_renewal_email(user_id, expiration_ts) - @defer.inlineCallbacks - def _send_renewal_email(self, user_id, expiration_ts): + async def _send_renewal_email(self, user_id: str, expiration_ts: int): """Sends out a renewal email to every email address attached to the given user with a unique link allowing them to renew their account. Args: - user_id (str): ID of the user to send email(s) to. - expiration_ts (int): Timestamp in milliseconds for the expiration date of + user_id: ID of the user to send email(s) to. + expiration_ts: Timestamp in milliseconds for the expiration date of this user's account (used in the email templates). """ - addresses = yield self._get_email_addresses_for_user(user_id) + addresses = await self._get_email_addresses_for_user(user_id) # Stop right here if the user doesn't have at least one email address. # In this case, they will have to ask their server admin to renew their @@ -125,7 +121,7 @@ class AccountValidityHandler(object): return try: - user_display_name = yield self.store.get_profile_displayname( + user_display_name = await self.store.get_profile_displayname( UserID.from_string(user_id).localpart ) if user_display_name is None: @@ -133,7 +129,7 @@ class AccountValidityHandler(object): except StoreError: user_display_name = user_id - renewal_token = yield self._get_renewal_token(user_id) + renewal_token = await self._get_renewal_token(user_id) url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % ( self.hs.config.public_baseurl, renewal_token, @@ -165,7 +161,7 @@ class AccountValidityHandler(object): logger.info("Sending renewal email to %s", address) - yield make_deferred_yieldable( + await make_deferred_yieldable( self.sendmail( self.hs.config.email_smtp_host, self._raw_from, @@ -180,19 +176,18 @@ class AccountValidityHandler(object): ) ) - yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True) + await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True) - @defer.inlineCallbacks - def _get_email_addresses_for_user(self, user_id): + async def _get_email_addresses_for_user(self, user_id: str) -> List[str]: """Retrieve the list of email addresses attached to a user's account. Args: - user_id (str): ID of the user to lookup email addresses for. + user_id: ID of the user to lookup email addresses for. Returns: - defer.Deferred[list[str]]: Email addresses for this account. + Email addresses for this account. """ - threepids = yield self.store.user_get_threepids(user_id) + threepids = await self.store.user_get_threepids(user_id) addresses = [] for threepid in threepids: @@ -201,16 +196,15 @@ class AccountValidityHandler(object): return addresses - @defer.inlineCallbacks - def _get_renewal_token(self, user_id): + async def _get_renewal_token(self, user_id: str) -> str: """Generates a 32-byte long random string that will be inserted into the user's renewal email's unique link, then saves it into the database. Args: - user_id (str): ID of the user to generate a string for. + user_id: ID of the user to generate a string for. Returns: - defer.Deferred[str]: The generated string. + The generated string. Raises: StoreError(500): Couldn't generate a unique string after 5 attempts. @@ -219,52 +213,52 @@ class AccountValidityHandler(object): while attempts < 5: try: renewal_token = stringutils.random_string(32) - yield self.store.set_renewal_token_for_user(user_id, renewal_token) + await self.store.set_renewal_token_for_user(user_id, renewal_token) return renewal_token except StoreError: attempts += 1 raise StoreError(500, "Couldn't generate a unique string as refresh string.") - @defer.inlineCallbacks - def renew_account(self, renewal_token): + async def renew_account(self, renewal_token: str) -> bool: """Renews the account attached to a given renewal token by pushing back the expiration date by the current validity period in the server's configuration. Args: - renewal_token (str): Token sent with the renewal request. + renewal_token: Token sent with the renewal request. Returns: - bool: Whether the provided token is valid. + Whether the provided token is valid. """ try: - user_id = yield self.store.get_user_from_renewal_token(renewal_token) + user_id = await self.store.get_user_from_renewal_token(renewal_token) except StoreError: - defer.returnValue(False) + return False logger.debug("Renewing an account for user %s", user_id) - yield self.renew_account_for_user(user_id) + await self.renew_account_for_user(user_id) - defer.returnValue(True) + return True - @defer.inlineCallbacks - def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False): + async def renew_account_for_user( + self, user_id: str, expiration_ts: int = None, email_sent: bool = False + ) -> int: """Renews the account attached to a given user by pushing back the expiration date by the current validity period in the server's configuration. Args: - renewal_token (str): Token sent with the renewal request. - expiration_ts (int): New expiration date. Defaults to now + validity period. - email_sent (bool): Whether an email has been sent for this validity period. + renewal_token: Token sent with the renewal request. + expiration_ts: New expiration date. Defaults to now + validity period. + email_sen: Whether an email has been sent for this validity period. Defaults to False. Returns: - defer.Deferred[int]: New expiration date for this account, as a timestamp - in milliseconds since epoch. + New expiration date for this account, as a timestamp in + milliseconds since epoch. """ if expiration_ts is None: expiration_ts = self.clock.time_msec() + self._account_validity.period - yield self.store.set_account_validity_for_user( + await self.store.set_account_validity_for_user( user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent ) diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 14449b9a1e..1a4ba12385 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.constants import Membership from synapse.types import RoomStreamToken from synapse.visibility import filter_events_for_client @@ -33,11 +31,10 @@ class AdminHandler(BaseHandler): self.storage = hs.get_storage() self.state_store = self.storage.state - @defer.inlineCallbacks - def get_whois(self, user): + async def get_whois(self, user): connections = [] - sessions = yield self.store.get_user_ip_and_agents(user) + sessions = await self.store.get_user_ip_and_agents(user) for session in sessions: connections.append( { @@ -54,20 +51,18 @@ class AdminHandler(BaseHandler): return ret - @defer.inlineCallbacks - def get_users(self): + async def get_users(self): """Function to retrieve a list of users in users table. Args: Returns: defer.Deferred: resolves to list[dict[str, Any]] """ - ret = yield self.store.get_users() + ret = await self.store.get_users() return ret - @defer.inlineCallbacks - def get_users_paginate(self, start, limit, name, guests, deactivated): + async def get_users_paginate(self, start, limit, name, guests, deactivated): """Function to retrieve a paginated list of users from users list. This will return a json list of users. @@ -80,14 +75,13 @@ class AdminHandler(BaseHandler): Returns: defer.Deferred: resolves to json list[dict[str, Any]] """ - ret = yield self.store.get_users_paginate( + ret = await self.store.get_users_paginate( start, limit, name, guests, deactivated ) return ret - @defer.inlineCallbacks - def search_users(self, term): + async def search_users(self, term): """Function to search users list for one or more users with the matched term. @@ -96,7 +90,7 @@ class AdminHandler(BaseHandler): Returns: defer.Deferred: resolves to list[dict[str, Any]] """ - ret = yield self.store.search_users(term) + ret = await self.store.search_users(term) return ret @@ -119,8 +113,7 @@ class AdminHandler(BaseHandler): """ return self.store.set_server_admin(user, admin) - @defer.inlineCallbacks - def export_user_data(self, user_id, writer): + async def export_user_data(self, user_id, writer): """Write all data we have on the user to the given writer. Args: @@ -132,7 +125,7 @@ class AdminHandler(BaseHandler): The returned value is that returned by `writer.finished()`. """ # Get all rooms the user is in or has been in - rooms = yield self.store.get_rooms_for_user_where_membership_is( + rooms = await self.store.get_rooms_for_user_where_membership_is( user_id, membership_list=( Membership.JOIN, @@ -145,7 +138,7 @@ class AdminHandler(BaseHandler): # We only try and fetch events for rooms the user has been in. If # they've been e.g. invited to a room without joining then we handle # those seperately. - rooms_user_has_been_in = yield self.store.get_rooms_user_has_been_in(user_id) + rooms_user_has_been_in = await self.store.get_rooms_user_has_been_in(user_id) for index, room in enumerate(rooms): room_id = room.room_id @@ -154,7 +147,7 @@ class AdminHandler(BaseHandler): "[%s] Handling room %s, %d/%d", user_id, room_id, index + 1, len(rooms) ) - forgotten = yield self.store.did_forget(user_id, room_id) + forgotten = await self.store.did_forget(user_id, room_id) if forgotten: logger.info("[%s] User forgot room %d, ignoring", user_id, room_id) continue @@ -166,7 +159,7 @@ class AdminHandler(BaseHandler): if room.membership == Membership.INVITE: event_id = room.event_id - invite = yield self.store.get_event(event_id, allow_none=True) + invite = await self.store.get_event(event_id, allow_none=True) if invite: invited_state = invite.unsigned["invite_room_state"] writer.write_invite(room_id, invite, invited_state) @@ -177,7 +170,7 @@ class AdminHandler(BaseHandler): # were joined. We estimate that point by looking at the # stream_ordering of the last membership if it wasn't a join. if room.membership == Membership.JOIN: - stream_ordering = yield self.store.get_room_max_stream_ordering() + stream_ordering = self.store.get_room_max_stream_ordering() else: stream_ordering = room.stream_ordering @@ -203,7 +196,7 @@ class AdminHandler(BaseHandler): # events that we have and then filtering, this isn't the most # efficient method perhaps but it does guarantee we get everything. while True: - events, _ = yield self.store.paginate_room_events( + events, _ = await self.store.paginate_room_events( room_id, from_key, to_key, limit=100, direction="f" ) if not events: @@ -211,7 +204,7 @@ class AdminHandler(BaseHandler): from_key = events[-1].internal_metadata.after - events = yield filter_events_for_client(self.storage, user_id, events) + events = await filter_events_for_client(self.storage, user_id, events) writer.write_events(room_id, events) @@ -247,7 +240,7 @@ class AdminHandler(BaseHandler): for event_id in extremities: if not event_to_unseen_prevs[event_id]: continue - state = yield self.state_store.get_state_for_event(event_id) + state = await self.state_store.get_state_for_event(event_id) writer.write_state(room_id, event_id, state) return writer.finished() diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 6dedaaff8d..4426967f88 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -15,8 +15,6 @@ # limitations under the License. import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import UserID, create_requester @@ -46,8 +44,7 @@ class DeactivateAccountHandler(BaseHandler): self._account_validity_enabled = hs.config.account_validity.enabled - @defer.inlineCallbacks - def deactivate_account(self, user_id, erase_data, id_server=None): + async def deactivate_account(self, user_id, erase_data, id_server=None): """Deactivate a user's account Args: @@ -74,11 +71,11 @@ class DeactivateAccountHandler(BaseHandler): identity_server_supports_unbinding = True # Retrieve the 3PIDs this user has bound to an identity server - threepids = yield self.store.user_get_bound_threepids(user_id) + threepids = await self.store.user_get_bound_threepids(user_id) for threepid in threepids: try: - result = yield self._identity_handler.try_unbind_threepid( + result = await self._identity_handler.try_unbind_threepid( user_id, { "medium": threepid["medium"], @@ -91,33 +88,33 @@ class DeactivateAccountHandler(BaseHandler): # Do we want this to be a fatal error or should we carry on? logger.exception("Failed to remove threepid from ID server") raise SynapseError(400, "Failed to remove threepid from ID server") - yield self.store.user_delete_threepid( + await self.store.user_delete_threepid( user_id, threepid["medium"], threepid["address"] ) # Remove all 3PIDs this user has bound to the homeserver - yield self.store.user_delete_threepids(user_id) + await self.store.user_delete_threepids(user_id) # delete any devices belonging to the user, which will also # delete corresponding access tokens. - yield self._device_handler.delete_all_devices_for_user(user_id) + await self._device_handler.delete_all_devices_for_user(user_id) # then delete any remaining access tokens which weren't associated with # a device. - yield self._auth_handler.delete_access_tokens_for_user(user_id) + await self._auth_handler.delete_access_tokens_for_user(user_id) - yield self.store.user_set_password_hash(user_id, None) + await self.store.user_set_password_hash(user_id, None) # Add the user to a table of users pending deactivation (ie. # removal from all the rooms they're a member of) - yield self.store.add_user_pending_deactivation(user_id) + await self.store.add_user_pending_deactivation(user_id) # delete from user directory - yield self.user_directory_handler.handle_user_deactivated(user_id) + await self.user_directory_handler.handle_user_deactivated(user_id) # Mark the user as erased, if they asked for that if erase_data: logger.info("Marking %s as erased", user_id) - yield self.store.mark_user_erased(user_id) + await self.store.mark_user_erased(user_id) # Now start the process that goes through that list and # parts users from rooms (if it isn't already running) @@ -125,30 +122,29 @@ class DeactivateAccountHandler(BaseHandler): # Reject all pending invites for the user, so that the user doesn't show up in the # "invited" section of rooms' members list. - yield self._reject_pending_invites_for_user(user_id) + await self._reject_pending_invites_for_user(user_id) # Remove all information on the user from the account_validity table. if self._account_validity_enabled: - yield self.store.delete_account_validity_for_user(user_id) + await self.store.delete_account_validity_for_user(user_id) # Mark the user as deactivated. - yield self.store.set_user_deactivated_status(user_id, True) + await self.store.set_user_deactivated_status(user_id, True) return identity_server_supports_unbinding - @defer.inlineCallbacks - def _reject_pending_invites_for_user(self, user_id): + async def _reject_pending_invites_for_user(self, user_id): """Reject pending invites addressed to a given user ID. Args: user_id (str): The user ID to reject pending invites for. """ user = UserID.from_string(user_id) - pending_invites = yield self.store.get_invited_rooms_for_user(user_id) + pending_invites = await self.store.get_invited_rooms_for_user(user_id) for room in pending_invites: try: - yield self._room_member_handler.update_membership( + await self._room_member_handler.update_membership( create_requester(user), user, room.room_id, @@ -180,8 +176,7 @@ class DeactivateAccountHandler(BaseHandler): if not self._user_parter_running: run_as_background_process("user_parter_loop", self._user_parter_loop) - @defer.inlineCallbacks - def _user_parter_loop(self): + async def _user_parter_loop(self): """Loop that parts deactivated users from rooms Returns: @@ -191,19 +186,18 @@ class DeactivateAccountHandler(BaseHandler): logger.info("Starting user parter") try: while True: - user_id = yield self.store.get_user_pending_deactivation() + user_id = await self.store.get_user_pending_deactivation() if user_id is None: break logger.info("User parter parting %r", user_id) - yield self._part_user(user_id) - yield self.store.del_user_pending_deactivation(user_id) + await self._part_user(user_id) + await self.store.del_user_pending_deactivation(user_id) logger.info("User parter finished parting %r", user_id) logger.info("User parter finished: stopping") finally: self._user_parter_running = False - @defer.inlineCallbacks - def _part_user(self, user_id): + async def _part_user(self, user_id): """Causes the given user_id to leave all the rooms they're joined to Returns: @@ -211,11 +205,11 @@ class DeactivateAccountHandler(BaseHandler): """ user = UserID.from_string(user_id) - rooms_for_user = yield self.store.get_rooms_for_user(user_id) + rooms_for_user = await self.store.get_rooms_for_user(user_id) for room_id in rooms_for_user: logger.info("User parter parting %r from %r", user_id, room_id) try: - yield self._room_member_handler.update_membership( + await self._room_member_handler.update_membership( create_requester(user), user, room_id, diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 57a10daefd..2d889364d4 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -264,6 +264,7 @@ class E2eKeysHandler(object): return ret + @defer.inlineCallbacks def get_cross_signing_keys_from_cache(self, query, from_user_id): """Get cross-signing keys for users from the database @@ -283,14 +284,32 @@ class E2eKeysHandler(object): self_signing_keys = {} user_signing_keys = {} - # Currently a stub, implementation coming in https://github.com/matrix-org/synapse/pull/6486 - return defer.succeed( - { - "master_keys": master_keys, - "self_signing_keys": self_signing_keys, - "user_signing_keys": user_signing_keys, - } - ) + user_ids = list(query) + + keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id) + + for user_id, user_info in keys.items(): + if user_info is None: + continue + if "master" in user_info: + master_keys[user_id] = user_info["master"] + if "self_signing" in user_info: + self_signing_keys[user_id] = user_info["self_signing"] + + if ( + from_user_id in keys + and keys[from_user_id] is not None + and "user_signing" in keys[from_user_id] + ): + # users can see other users' master and self-signing keys, but can + # only see their own user-signing keys + user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"] + + return { + "master_keys": master_keys, + "self_signing_keys": self_signing_keys, + "user_signing_keys": user_signing_keys, + } @trace @defer.inlineCallbacks diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 6fb453ce60..72a0febc2b 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -19,7 +19,7 @@ import itertools import logging -from typing import Dict, Iterable, Optional, Sequence, Tuple +from typing import Dict, Iterable, List, Optional, Sequence, Tuple import six from six import iteritems, itervalues @@ -63,6 +63,7 @@ from synapse.replication.http.federation import ( ) from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import StateResolutionStore, resolve_events_with_store +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.types import UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.distributor import user_joined_room @@ -163,8 +164,7 @@ class FederationHandler(BaseHandler): self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages - @defer.inlineCallbacks - def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): + async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: """ Process a PDU received via a federation /send/ transaction, or via backfill of missing prev_events @@ -174,17 +174,15 @@ class FederationHandler(BaseHandler): pdu (FrozenEvent): received PDU sent_to_us_directly (bool): True if this event was pushed to us; False if we pulled it as the result of a missing prev_event. - - Returns (Deferred): completes with None """ room_id = pdu.room_id event_id = pdu.event_id - logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu) + logger.info("handling received PDU: %s", pdu) # We reprocess pdus when we have seen them only as outliers - existing = yield self.store.get_event( + existing = await self.store.get_event( event_id, allow_none=True, allow_rejected=True ) @@ -228,7 +226,7 @@ class FederationHandler(BaseHandler): # # Note that if we were never in the room then we would have already # dropped the event, since we wouldn't know the room version. - is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name) + is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) if not is_in_room: logger.info( "[%s %s] Ignoring PDU from %s as we're not in the room", @@ -243,12 +241,12 @@ class FederationHandler(BaseHandler): # Get missing pdus if necessary. if not pdu.internal_metadata.is_outlier(): # We only backfill backwards to the min depth. - min_depth = yield self.get_min_depth_for_context(pdu.room_id) + min_depth = await self.get_min_depth_for_context(pdu.room_id) logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth) prevs = set(pdu.prev_event_ids()) - seen = yield self.store.have_seen_events(prevs) + seen = await self.store.have_seen_events(prevs) if min_depth and pdu.depth < min_depth: # This is so that we don't notify the user about this @@ -268,7 +266,7 @@ class FederationHandler(BaseHandler): len(missing_prevs), shortstr(missing_prevs), ) - with (yield self._room_pdu_linearizer.queue(pdu.room_id)): + with (await self._room_pdu_linearizer.queue(pdu.room_id)): logger.info( "[%s %s] Acquired room lock to fetch %d missing prev_events", room_id, @@ -276,13 +274,19 @@ class FederationHandler(BaseHandler): len(missing_prevs), ) - yield self._get_missing_events_for_pdu( - origin, pdu, prevs, min_depth - ) + try: + await self._get_missing_events_for_pdu( + origin, pdu, prevs, min_depth + ) + except Exception as e: + raise Exception( + "Error fetching missing prev_events for %s: %s" + % (event_id, e) + ) # Update the set of things we've seen after trying to # fetch the missing stuff - seen = yield self.store.have_seen_events(prevs) + seen = await self.store.have_seen_events(prevs) if not prevs - seen: logger.info( @@ -290,14 +294,6 @@ class FederationHandler(BaseHandler): room_id, event_id, ) - elif missing_prevs: - logger.info( - "[%s %s] Not recursively fetching %d missing prev_events: %s", - room_id, - event_id, - len(missing_prevs), - shortstr(missing_prevs), - ) if prevs - seen: # We've still not been able to get all of the prev_events for this event. @@ -342,12 +338,18 @@ class FederationHandler(BaseHandler): affected=pdu.event_id, ) + logger.info( + "Event %s is missing prev_events: calculating state for a " + "backwards extremity", + event_id, + ) + # Calculate the state after each of the previous events, and # resolve them to find the correct state at the current event. event_map = {event_id: pdu} try: # Get the state of the events we know about - ours = yield self.state_store.get_state_groups_ids(room_id, seen) + ours = await self.state_store.get_state_groups_ids(room_id, seen) # state_maps is a list of mappings from (type, state_key) to event_id state_maps = list( @@ -361,17 +363,14 @@ class FederationHandler(BaseHandler): # know about for p in prevs - seen: logger.info( - "[%s %s] Requesting state at missing prev_event %s", - room_id, - event_id, - p, + "Requesting state at missing prev_event %s", event_id, ) with nested_logging_context(p): # note that if any of the missing prevs share missing state or # auth events, the requests to fetch those events are deduped # by the get_pdu_cache in federation_client. - (remote_state, _,) = yield self._get_state_for_room( + (remote_state, _,) = await self._get_state_for_room( origin, room_id, p, include_event_in_state=True ) @@ -383,8 +382,8 @@ class FederationHandler(BaseHandler): for x in remote_state: event_map[x.event_id] = x - room_version = yield self.store.get_room_version(room_id) - state_map = yield resolve_events_with_store( + room_version = await self.store.get_room_version(room_id) + state_map = await resolve_events_with_store( room_id, room_version, state_maps, @@ -397,10 +396,10 @@ class FederationHandler(BaseHandler): # First though we need to fetch all the events that are in # state_map, so we can build up the state below. - evs = yield self.store.get_events( + evs = await self.store.get_events( list(state_map.values()), get_prev_content=False, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, ) event_map.update(evs) @@ -420,10 +419,9 @@ class FederationHandler(BaseHandler): affected=event_id, ) - yield self._process_received_pdu(origin, pdu, state=state) + await self._process_received_pdu(origin, pdu, state=state) - @defer.inlineCallbacks - def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): + async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): """ Args: origin (str): Origin of the pdu. Will be called to get the missing events @@ -435,12 +433,12 @@ class FederationHandler(BaseHandler): room_id = pdu.room_id event_id = pdu.event_id - seen = yield self.store.have_seen_events(prevs) + seen = await self.store.have_seen_events(prevs) if not prevs - seen: return - latest = yield self.store.get_latest_event_ids_in_room(room_id) + latest = await self.store.get_latest_event_ids_in_room(room_id) # We add the prev events that we have seen to the latest # list to ensure the remote server doesn't give them to us @@ -504,7 +502,7 @@ class FederationHandler(BaseHandler): # All that said: Let's try increasing the timout to 60s and see what happens. try: - missing_events = yield self.federation_client.get_missing_events( + missing_events = await self.federation_client.get_missing_events( origin, room_id, earliest_events_ids=list(latest), @@ -543,7 +541,7 @@ class FederationHandler(BaseHandler): ) with nested_logging_context(ev.event_id): try: - yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False) + await self.on_receive_pdu(origin, ev, sent_to_us_directly=False) except FederationError as e: if e.code == 403: logger.warning( @@ -555,29 +553,30 @@ class FederationHandler(BaseHandler): else: raise - @defer.inlineCallbacks - @log_function - def _get_state_for_room( - self, destination, room_id, event_id, include_event_in_state - ): + async def _get_state_for_room( + self, + destination: str, + room_id: str, + event_id: str, + include_event_in_state: bool = False, + ) -> Tuple[List[EventBase], List[EventBase]]: """Requests all of the room state at a given event from a remote homeserver. Args: - destination (str): The remote homeserver to query for the state. - room_id (str): The id of the room we're interested in. - event_id (str): The id of the event we want the state at. + destination: The remote homeserver to query for the state. + room_id: The id of the room we're interested in. + event_id: The id of the event we want the state at. include_event_in_state: if true, the event itself will be included in the returned state event list. Returns: - Deferred[Tuple[List[EventBase], List[EventBase]]]: - A list of events in the state, and a list of events in the auth chain - for the given event. + A list of events in the state, possibly including the event itself, and + a list of events in the auth chain for the given event. """ ( state_event_ids, auth_event_ids, - ) = yield self.federation_client.get_room_state_ids( + ) = await self.federation_client.get_room_state_ids( destination, room_id, event_id=event_id ) @@ -586,15 +585,15 @@ class FederationHandler(BaseHandler): if include_event_in_state: desired_events.add(event_id) - event_map = yield self._get_events_from_store_or_dest( + event_map = await self._get_events_from_store_or_dest( destination, room_id, desired_events ) failed_to_fetch = desired_events - event_map.keys() if failed_to_fetch: logger.warning( - "Failed to fetch missing state/auth events for %s: %s", - room_id, + "Failed to fetch missing state/auth events for %s %s", + event_id, failed_to_fetch, ) @@ -614,15 +613,11 @@ class FederationHandler(BaseHandler): return remote_state, auth_chain - @defer.inlineCallbacks - def _get_events_from_store_or_dest(self, destination, room_id, event_ids): + async def _get_events_from_store_or_dest( + self, destination: str, room_id: str, event_ids: Iterable[str] + ) -> Dict[str, EventBase]: """Fetch events from a remote destination, checking if we already have them. - Args: - destination (str) - room_id (str) - event_ids (Iterable[str]) - Persists any events we don't already have as outliers. If we fail to fetch any of the events, a warning will be logged, and the event @@ -630,10 +625,9 @@ class FederationHandler(BaseHandler): be in the given room. Returns: - Deferred[dict[str, EventBase]]: A deferred resolving to a map - from event_id to event + map from event_id to event """ - fetched_events = yield self.store.get_events(event_ids, allow_rejected=True) + fetched_events = await self.store.get_events(event_ids, allow_rejected=True) missing_events = set(event_ids) - fetched_events.keys() @@ -644,14 +638,14 @@ class FederationHandler(BaseHandler): room_id, ) - yield self._get_events_and_persist( + await self._get_events_and_persist( destination=destination, room_id=room_id, events=missing_events ) # we need to make sure we re-load from the database to get the rejected # state correct. fetched_events.update( - (yield self.store.get_events(missing_events, allow_rejected=True)) + (await self.store.get_events(missing_events, allow_rejected=True)) ) # check for events which were in the wrong room. @@ -677,12 +671,14 @@ class FederationHandler(BaseHandler): bad_room_id, room_id, ) + del fetched_events[bad_event_id] return fetched_events - @defer.inlineCallbacks - def _process_received_pdu(self, origin, event, state): + async def _process_received_pdu( + self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]], + ): """ Called when we have a new pdu. We need to do auth checks and put it through the StateHandler. @@ -701,15 +697,15 @@ class FederationHandler(BaseHandler): logger.debug("[%s %s] Processing event: %s", room_id, event_id, event) try: - context = yield self._handle_new_event(origin, event, state=state) + context = await self._handle_new_event(origin, event, state=state) except AuthError as e: raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) - room = yield self.store.get_room(room_id) + room = await self.store.get_room(room_id) if not room: try: - yield self.store.store_room( + await self.store.store_room( room_id=room_id, room_creator_user_id="", is_public=False ) except StoreError: @@ -722,11 +718,11 @@ class FederationHandler(BaseHandler): # changing their profile info. newly_joined = True - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = await context.get_prev_state_ids() prev_state_id = prev_state_ids.get((event.type, event.state_key)) if prev_state_id: - prev_state = yield self.store.get_event( + prev_state = await self.store.get_event( prev_state_id, allow_none=True ) if prev_state and prev_state.membership == Membership.JOIN: @@ -734,11 +730,10 @@ class FederationHandler(BaseHandler): if newly_joined: user = UserID.from_string(event.state_key) - yield self.user_joined_room(user, room_id) + await self.user_joined_room(user, room_id) @log_function - @defer.inlineCallbacks - def backfill(self, dest, room_id, limit, extremities): + async def backfill(self, dest, room_id, limit, extremities): """ Trigger a backfill request to `dest` for the given `room_id` This will attempt to get more events from the remote. If the other side @@ -755,7 +750,7 @@ class FederationHandler(BaseHandler): if dest == self.server_name: raise SynapseError(400, "Can't backfill from self.") - events = yield self.federation_client.backfill( + events = await self.federation_client.backfill( dest, room_id, limit=limit, extremities=extremities ) @@ -770,7 +765,7 @@ class FederationHandler(BaseHandler): # self._sanity_check_event(ev) # Don't bother processing events we already have. - seen_events = yield self.store.have_events_in_timeline( + seen_events = await self.store.have_events_in_timeline( set(e.event_id for e in events) ) @@ -796,7 +791,7 @@ class FederationHandler(BaseHandler): state_events = {} events_to_state = {} for e_id in edges: - state, auth = yield self._get_state_for_room( + state, auth = await self._get_state_for_room( destination=dest, room_id=room_id, event_id=e_id, @@ -843,7 +838,7 @@ class FederationHandler(BaseHandler): ) ) - yield self._handle_new_events(dest, ev_infos, backfilled=True) + await self._handle_new_events(dest, ev_infos, backfilled=True) # Step 2: Persist the rest of the events in the chunk one by one events.sort(key=lambda e: e.depth) @@ -859,16 +854,15 @@ class FederationHandler(BaseHandler): # We store these one at a time since each event depends on the # previous to work out the state. # TODO: We can probably do something more clever here. - yield self._handle_new_event(dest, event, backfilled=True) + await self._handle_new_event(dest, event, backfilled=True) return events - @defer.inlineCallbacks - def maybe_backfill(self, room_id, current_depth): + async def maybe_backfill(self, room_id, current_depth): """Checks the database to see if we should backfill before paginating, and if so do. """ - extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id) + extremities = await self.store.get_oldest_events_with_depth_in_room(room_id) if not extremities: logger.debug("Not backfilling as no extremeties found.") @@ -900,15 +894,17 @@ class FederationHandler(BaseHandler): # state *before* the event, ignoring the special casing certain event # types have. - forward_events = yield self.store.get_successor_events(list(extremities)) + forward_events = await self.store.get_successor_events(list(extremities)) - extremities_events = yield self.store.get_events( - forward_events, check_redacted=False, get_prev_content=False + extremities_events = await self.store.get_events( + forward_events, + redact_behaviour=EventRedactBehaviour.AS_IS, + get_prev_content=False, ) # We set `check_history_visibility_only` as we might otherwise get false # positives from users having been erased. - filtered_extremities = yield filter_events_for_server( + filtered_extremities = await filter_events_for_server( self.storage, self.server_name, list(extremities_events.values()), @@ -938,7 +934,7 @@ class FederationHandler(BaseHandler): # First we try hosts that are already in the room # TODO: HEURISTIC ALERT. - curr_state = yield self.state_handler.get_current_state(room_id) + curr_state = await self.state_handler.get_current_state(room_id) def get_domains_from_state(state): """Get joined domains from state @@ -977,12 +973,11 @@ class FederationHandler(BaseHandler): domain for domain, depth in curr_domains if domain != self.server_name ] - @defer.inlineCallbacks - def try_backfill(domains): + async def try_backfill(domains): # TODO: Should we try multiple of these at a time? for dom in domains: try: - yield self.backfill( + await self.backfill( dom, room_id, limit=100, extremities=extremities ) # If this succeeded then we probably already have the @@ -1013,7 +1008,7 @@ class FederationHandler(BaseHandler): return False - success = yield try_backfill(likely_domains) + success = await try_backfill(likely_domains) if success: return True @@ -1027,7 +1022,7 @@ class FederationHandler(BaseHandler): logger.debug("calling resolve_state_groups in _maybe_backfill") resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events) - states = yield make_deferred_yieldable( + states = await make_deferred_yieldable( defer.gatherResults( [resolve(room_id, [e]) for e in event_ids], consumeErrors=True ) @@ -1037,7 +1032,7 @@ class FederationHandler(BaseHandler): # event_ids. states = dict(zip(event_ids, [s.state for s in states])) - state_map = yield self.store.get_events( + state_map = await self.store.get_events( [e_id for ids in itervalues(states) for e_id in itervalues(ids)], get_prev_content=False, ) @@ -1053,7 +1048,7 @@ class FederationHandler(BaseHandler): for e_id, _ in sorted_extremeties_tuple: likely_domains = get_domains_from_state(states[e_id]) - success = yield try_backfill( + success = await try_backfill( [dom for dom, _ in likely_domains if dom not in tried_domains] ) if success: @@ -1063,8 +1058,7 @@ class FederationHandler(BaseHandler): return False - @defer.inlineCallbacks - def _get_events_and_persist( + async def _get_events_and_persist( self, destination: str, room_id: str, events: Iterable[str] ): """Fetch the given events from a server, and persist them as outliers. @@ -1072,7 +1066,7 @@ class FederationHandler(BaseHandler): Logs a warning if we can't find the given event. """ - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) event_infos = [] @@ -1108,9 +1102,9 @@ class FederationHandler(BaseHandler): e, ) - yield concurrently_execute(get_event, events, 5) + await concurrently_execute(get_event, events, 5) - yield self._handle_new_events( + await self._handle_new_events( destination, event_infos, ) @@ -1253,7 +1247,7 @@ class FederationHandler(BaseHandler): # Check whether this room is the result of an upgrade of a room we already know # about. If so, migrate over user information predecessor = yield self.store.get_room_predecessor(room_id) - if not predecessor: + if not predecessor or not isinstance(predecessor.get("room_id"), str): return old_room_id = predecessor["room_id"] logger.debug( @@ -1281,8 +1275,7 @@ class FederationHandler(BaseHandler): return True - @defer.inlineCallbacks - def _handle_queued_pdus(self, room_queue): + async def _handle_queued_pdus(self, room_queue): """Process PDUs which got queued up while we were busy send_joining. Args: @@ -1298,7 +1291,7 @@ class FederationHandler(BaseHandler): p.room_id, ) with nested_logging_context(p.event_id): - yield self.on_receive_pdu(origin, p, sent_to_us_directly=True) + await self.on_receive_pdu(origin, p, sent_to_us_directly=True) except Exception as e: logger.warning( "Error handling queued PDU %s from %s: %s", p.event_id, origin, e @@ -1428,7 +1421,7 @@ class FederationHandler(BaseHandler): user = UserID.from_string(event.state_key) yield self.user_joined_room(user, event.room_id) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() state_ids = list(prev_state_ids.values()) auth_chain = yield self.store.get_auth_chain(state_ids) @@ -1496,7 +1489,7 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content): origin, event, event_format_version = yield self._make_and_verify_event( - target_hosts, room_id, user_id, "leave", content=content, + target_hosts, room_id, user_id, "leave", content=content ) # Mark as outlier as we don't have any state for this event; we're not # even in the room. @@ -1937,7 +1930,7 @@ class FederationHandler(BaseHandler): context = yield self.state_handler.compute_event_context(event, old_state=state) if not auth_events: - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() auth_events_ids = yield self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) @@ -2346,12 +2339,12 @@ class FederationHandler(BaseHandler): k: a.event_id for k, a in iteritems(auth_events) if k != event_key } - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() current_state_ids = dict(current_state_ids) current_state_ids.update(state_updates) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() prev_state_ids = dict(prev_state_ids) prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)}) @@ -2635,7 +2628,7 @@ class FederationHandler(BaseHandler): event.content["third_party_invite"]["signed"]["token"], ) original_invite = None - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() original_invite_id = prev_state_ids.get(key) if original_invite_id: original_invite = yield self.store.get_event( @@ -2683,7 +2676,7 @@ class FederationHandler(BaseHandler): signed = event.content["third_party_invite"]["signed"] token = signed["token"] - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token)) invite_event = None @@ -2857,7 +2850,7 @@ class FederationHandler(BaseHandler): room_id=room_id, user_id=user.to_string(), change="joined" ) else: - return user_joined_room(self.distributor, user, room_id) + return defer.succeed(user_joined_room(self.distributor, user, room_id)) @defer.inlineCallbacks def get_room_complexity(self, remote_room_hosts, room_id): diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 81dce96f4b..44ec3e66ae 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -26,7 +26,7 @@ from synapse.streams.config import PaginationConfig from synapse.types import StreamToken, UserID from synapse.util import unwrapFirstError from synapse.util.async_helpers import concurrently_execute -from synapse.util.caches.snapshot_cache import SnapshotCache +from synapse.util.caches.response_cache import ResponseCache from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -41,7 +41,7 @@ class InitialSyncHandler(BaseHandler): self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() - self.snapshot_cache = SnapshotCache() + self.snapshot_cache = ResponseCache(hs, "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() self.state_store = self.storage.state @@ -79,21 +79,17 @@ class InitialSyncHandler(BaseHandler): as_client_event, include_archived, ) - now_ms = self.clock.time_msec() - result = self.snapshot_cache.get(now_ms, key) - if result is not None: - return result - return self.snapshot_cache.set( - now_ms, + return self.snapshot_cache.wrap( key, - self._snapshot_all_rooms( - user_id, pagin_config, as_client_event, include_archived - ), + self._snapshot_all_rooms, + user_id, + pagin_config, + as_client_event, + include_archived, ) - @defer.inlineCallbacks - def _snapshot_all_rooms( + async def _snapshot_all_rooms( self, user_id=None, pagin_config=None, @@ -105,7 +101,7 @@ class InitialSyncHandler(BaseHandler): if include_archived: memberships.append(Membership.LEAVE) - room_list = yield self.store.get_rooms_for_user_where_membership_is( + room_list = await self.store.get_rooms_for_user_where_membership_is( user_id=user_id, membership_list=memberships ) @@ -113,33 +109,32 @@ class InitialSyncHandler(BaseHandler): rooms_ret = [] - now_token = yield self.hs.get_event_sources().get_current_token() + now_token = await self.hs.get_event_sources().get_current_token() presence_stream = self.hs.get_event_sources().sources["presence"] pagination_config = PaginationConfig(from_token=now_token) - presence, _ = yield presence_stream.get_pagination_rows( + presence, _ = await presence_stream.get_pagination_rows( user, pagination_config.get_source_config("presence"), None ) receipt_stream = self.hs.get_event_sources().sources["receipt"] - receipt, _ = yield receipt_stream.get_pagination_rows( + receipt, _ = await receipt_stream.get_pagination_rows( user, pagination_config.get_source_config("receipt"), None ) - tags_by_room = yield self.store.get_tags_for_user(user_id) + tags_by_room = await self.store.get_tags_for_user(user_id) - account_data, account_data_by_room = yield self.store.get_account_data_for_user( + account_data, account_data_by_room = await self.store.get_account_data_for_user( user_id ) - public_room_ids = yield self.store.get_public_room_ids() + public_room_ids = await self.store.get_public_room_ids() limit = pagin_config.limit if limit is None: limit = 10 - @defer.inlineCallbacks - def handle_room(event): + async def handle_room(event): d = { "room_id": event.room_id, "membership": event.membership, @@ -152,8 +147,8 @@ class InitialSyncHandler(BaseHandler): time_now = self.clock.time_msec() d["inviter"] = event.sender - invite_event = yield self.store.get_event(event.event_id) - d["invite"] = yield self._event_serializer.serialize_event( + invite_event = await self.store.get_event(event.event_id) + d["invite"] = await self._event_serializer.serialize_event( invite_event, time_now, as_client_event ) @@ -177,7 +172,7 @@ class InitialSyncHandler(BaseHandler): lambda states: states[event.event_id] ) - (messages, token), current_state = yield make_deferred_yieldable( + (messages, token), current_state = await make_deferred_yieldable( defer.gatherResults( [ run_in_background( @@ -191,7 +186,7 @@ class InitialSyncHandler(BaseHandler): ) ).addErrback(unwrapFirstError) - messages = yield filter_events_for_client( + messages = await filter_events_for_client( self.storage, user_id, messages ) @@ -201,7 +196,7 @@ class InitialSyncHandler(BaseHandler): d["messages"] = { "chunk": ( - yield self._event_serializer.serialize_events( + await self._event_serializer.serialize_events( messages, time_now=time_now, as_client_event=as_client_event ) ), @@ -209,7 +204,7 @@ class InitialSyncHandler(BaseHandler): "end": end_token.to_string(), } - d["state"] = yield self._event_serializer.serialize_events( + d["state"] = await self._event_serializer.serialize_events( current_state.values(), time_now=time_now, as_client_event=as_client_event, @@ -232,7 +227,7 @@ class InitialSyncHandler(BaseHandler): except Exception: logger.exception("Failed to get snapshot") - yield concurrently_execute(handle_room, room_list, 10) + await concurrently_execute(handle_room, room_list, 10) account_data_events = [] for account_data_type, content in account_data.items(): @@ -256,8 +251,7 @@ class InitialSyncHandler(BaseHandler): return ret - @defer.inlineCallbacks - def room_initial_sync(self, requester, room_id, pagin_config=None): + async def room_initial_sync(self, requester, room_id, pagin_config=None): """Capture the a snapshot of a room. If user is currently a member of the room this will be what is currently in the room. If the user left the room this will be what was in the room when they left. @@ -274,32 +268,32 @@ class InitialSyncHandler(BaseHandler): A JSON serialisable dict with the snapshot of the room. """ - blocked = yield self.store.is_room_blocked(room_id) + blocked = await self.store.is_room_blocked(room_id) if blocked: raise SynapseError(403, "This room has been blocked on this server") user_id = requester.user.to_string() - membership, member_event_id = yield self._check_in_room_or_world_readable( + membership, member_event_id = await self._check_in_room_or_world_readable( room_id, user_id ) is_peeking = member_event_id is None if membership == Membership.JOIN: - result = yield self._room_initial_sync_joined( + result = await self._room_initial_sync_joined( user_id, room_id, pagin_config, membership, is_peeking ) elif membership == Membership.LEAVE: - result = yield self._room_initial_sync_parted( + result = await self._room_initial_sync_parted( user_id, room_id, pagin_config, membership, member_event_id, is_peeking ) account_data_events = [] - tags = yield self.store.get_tags_for_room(user_id, room_id) + tags = await self.store.get_tags_for_room(user_id, room_id) if tags: account_data_events.append({"type": "m.tag", "content": {"tags": tags}}) - account_data = yield self.store.get_account_data_for_room(user_id, room_id) + account_data = await self.store.get_account_data_for_room(user_id, room_id) for account_data_type, content in account_data.items(): account_data_events.append({"type": account_data_type, "content": content}) @@ -307,11 +301,10 @@ class InitialSyncHandler(BaseHandler): return result - @defer.inlineCallbacks - def _room_initial_sync_parted( + async def _room_initial_sync_parted( self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking ): - room_state = yield self.state_store.get_state_for_events([member_event_id]) + room_state = await self.state_store.get_state_for_events([member_event_id]) room_state = room_state[member_event_id] @@ -319,13 +312,13 @@ class InitialSyncHandler(BaseHandler): if limit is None: limit = 10 - stream_token = yield self.store.get_stream_token_for_event(member_event_id) + stream_token = await self.store.get_stream_token_for_event(member_event_id) - messages, token = yield self.store.get_recent_events_for_room( + messages, token = await self.store.get_recent_events_for_room( room_id, limit=limit, end_token=stream_token ) - messages = yield filter_events_for_client( + messages = await filter_events_for_client( self.storage, user_id, messages, is_peeking=is_peeking ) @@ -339,13 +332,13 @@ class InitialSyncHandler(BaseHandler): "room_id": room_id, "messages": { "chunk": ( - yield self._event_serializer.serialize_events(messages, time_now) + await self._event_serializer.serialize_events(messages, time_now) ), "start": start_token.to_string(), "end": end_token.to_string(), }, "state": ( - yield self._event_serializer.serialize_events( + await self._event_serializer.serialize_events( room_state.values(), time_now ) ), @@ -353,19 +346,18 @@ class InitialSyncHandler(BaseHandler): "receipts": [], } - @defer.inlineCallbacks - def _room_initial_sync_joined( + async def _room_initial_sync_joined( self, user_id, room_id, pagin_config, membership, is_peeking ): - current_state = yield self.state.get_current_state(room_id=room_id) + current_state = await self.state.get_current_state(room_id=room_id) # TODO: These concurrently time_now = self.clock.time_msec() - state = yield self._event_serializer.serialize_events( + state = await self._event_serializer.serialize_events( current_state.values(), time_now ) - now_token = yield self.hs.get_event_sources().get_current_token() + now_token = await self.hs.get_event_sources().get_current_token() limit = pagin_config.limit if pagin_config else None if limit is None: @@ -380,28 +372,26 @@ class InitialSyncHandler(BaseHandler): presence_handler = self.hs.get_presence_handler() - @defer.inlineCallbacks - def get_presence(): + async def get_presence(): # If presence is disabled, return an empty list if not self.hs.config.use_presence: return [] - states = yield presence_handler.get_states( + states = await presence_handler.get_states( [m.user_id for m in room_members], as_event=True ) return states - @defer.inlineCallbacks - def get_receipts(): - receipts = yield self.store.get_linearized_receipts_for_room( + async def get_receipts(): + receipts = await self.store.get_linearized_receipts_for_room( room_id, to_key=now_token.receipt_key ) if not receipts: receipts = [] return receipts - presence, receipts, (messages, token) = yield make_deferred_yieldable( + presence, receipts, (messages, token) = await make_deferred_yieldable( defer.gatherResults( [ run_in_background(get_presence), @@ -417,7 +407,7 @@ class InitialSyncHandler(BaseHandler): ).addErrback(unwrapFirstError) ) - messages = yield filter_events_for_client( + messages = await filter_events_for_client( self.storage, user_id, messages, is_peeking=is_peeking ) @@ -430,7 +420,7 @@ class InitialSyncHandler(BaseHandler): "room_id": room_id, "messages": { "chunk": ( - yield self._event_serializer.serialize_events(messages, time_now) + await self._event_serializer.serialize_events(messages, time_now) ), "start": start_token.to_string(), "end": end_token.to_string(), @@ -444,18 +434,17 @@ class InitialSyncHandler(BaseHandler): return ret - @defer.inlineCallbacks - def _check_in_room_or_world_readable(self, room_id, user_id): + async def _check_in_room_or_world_readable(self, room_id, user_id): try: # check_user_was_in_room will return the most recent membership # event for the user if: # * The user is a non-guest user, and was ever in the room # * The user is a guest user, and has joined the room # else it will throw. - member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + member_event = await self.auth.check_user_was_in_room(room_id, user_id) return member_event.membership, member_event.event_id except AuthError: - visibility = yield self.state_handler.get_current_state( + visibility = await self.state_handler.get_current_state( room_id, EventTypes.RoomHistoryVisibility, "" ) if ( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 54fa216d83..4ad752205f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -46,6 +46,7 @@ from synapse.events.validator import EventValidator from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID, create_requester from synapse.util.async_helpers import Linearizer @@ -514,7 +515,7 @@ class EventCreationHandler(object): # federation as well as those created locally. As of room v3, aliases events # can be created by users that are not in the room, therefore we have to # tolerate them in event_auth.check(). - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) prev_event = ( yield self.store.get_event(prev_event_id, allow_none=True) @@ -664,7 +665,7 @@ class EventCreationHandler(object): If so, returns the version of the event in context. Otherwise, returns None. """ - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() prev_event_id = prev_state_ids.get((event.type, event.state_key)) if not prev_event_id: return @@ -875,7 +876,7 @@ class EventCreationHandler(object): if event.type == EventTypes.Redaction: original_event = yield self.store.get_event( event.redacts, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, allow_rejected=False, allow_none=True, @@ -913,7 +914,7 @@ class EventCreationHandler(object): def is_inviter_member_event(e): return e.type == EventTypes.Member and e.sender == event.sender - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() state_to_include_ids = [ e_id @@ -952,7 +953,7 @@ class EventCreationHandler(object): if event.type == EventTypes.Redaction: original_event = yield self.store.get_event( event.redacts, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, allow_rejected=False, allow_none=True, @@ -966,7 +967,7 @@ class EventCreationHandler(object): if original_event.room_id != event.room_id: raise SynapseError(400, "Cannot redact event from a different room") - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() auth_events_ids = yield self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) @@ -988,7 +989,7 @@ class EventCreationHandler(object): event.internal_metadata.recheck_redaction = False if event.type == EventTypes.Create: - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 8514ddc600..00a6afc963 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -280,8 +280,7 @@ class PaginationHandler(object): await self.storage.purge_events.purge_room(room_id) - @defer.inlineCallbacks - def get_messages( + async def get_messages( self, requester, room_id=None, @@ -307,7 +306,7 @@ class PaginationHandler(object): room_token = pagin_config.from_token.room_key else: pagin_config.from_token = ( - yield self.hs.get_event_sources().get_current_token_for_pagination() + await self.hs.get_event_sources().get_current_token_for_pagination() ) room_token = pagin_config.from_token.room_key @@ -319,11 +318,11 @@ class PaginationHandler(object): source_config = pagin_config.get_source_config("room") - with (yield self.pagination_lock.read(room_id)): + with (await self.pagination_lock.read(room_id)): ( membership, member_event_id, - ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id) + ) = await self.auth.check_in_room_or_world_readable(room_id, user_id) if source_config.direction == "b": # if we're going backwards, we might need to backfill. This @@ -331,7 +330,7 @@ class PaginationHandler(object): if room_token.topological: max_topo = room_token.topological else: - max_topo = yield self.store.get_max_topological_token( + max_topo = await self.store.get_max_topological_token( room_id, room_token.stream ) @@ -339,18 +338,18 @@ class PaginationHandler(object): # If they have left the room then clamp the token to be before # they left the room, to save the effort of loading from the # database. - leave_token = yield self.store.get_topological_token_for_event( + leave_token = await self.store.get_topological_token_for_event( member_event_id ) leave_token = RoomStreamToken.parse(leave_token) if leave_token.topological < max_topo: source_config.from_key = str(leave_token) - yield self.hs.get_handlers().federation_handler.maybe_backfill( + await self.hs.get_handlers().federation_handler.maybe_backfill( room_id, max_topo ) - events, next_key = yield self.store.paginate_room_events( + events, next_key = await self.store.paginate_room_events( room_id=room_id, from_key=source_config.from_key, to_key=source_config.to_key, @@ -365,7 +364,7 @@ class PaginationHandler(object): if event_filter: events = event_filter.filter(events) - events = yield filter_events_for_client( + events = await filter_events_for_client( self.storage, user_id, events, is_peeking=(member_event_id is None) ) @@ -385,19 +384,19 @@ class PaginationHandler(object): (EventTypes.Member, event.sender) for event in events ) - state_ids = yield self.state_store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( events[0].event_id, state_filter=state_filter ) if state_ids: - state = yield self.store.get_events(list(state_ids.values())) + state = await self.store.get_events(list(state_ids.values())) state = state.values() time_now = self.clock.time_msec() chunk = { "chunk": ( - yield self._event_serializer.serialize_events( + await self._event_serializer.serialize_events( events, time_now, as_client_event=as_client_event ) ), @@ -406,7 +405,7 @@ class PaginationHandler(object): } if state: - chunk["state"] = yield self._event_serializer.serialize_events( + chunk["state"] = await self._event_serializer.serialize_events( state, time_now, as_client_event=as_client_event ) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index eda15bc623..240c4add12 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -230,7 +230,7 @@ class PresenceHandler(object): is some spurious presence changes that will self-correct. """ # If the DB pool has already terminated, don't try updating - if not self.hs.get_db_pool().running: + if not self.store.database.is_running(): return logger.info( diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 1e5a4613c9..f9579d69ee 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -295,12 +295,16 @@ class BaseProfileHandler(BaseHandler): be found to be in any room the server is in, and therefore the query is denied. """ + # Implementation of MSC1301: don't allow looking up profiles if the # requester isn't in the same room as the target. We expect requester to # be None when this function is called outside of a profile query, e.g. # when building a membership event. In this case, we must allow the # lookup. - if not self.hs.config.require_auth_for_profile_requests or not requester: + if ( + not self.hs.config.limit_profile_requests_to_users_who_share_rooms + or not requester + ): return # Always allow the user to query their own profile. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 60b8bbc7a5..89c9118b26 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -184,7 +184,7 @@ class RoomCreationHandler(BaseHandler): requester, tombstone_event, tombstone_context ) - old_room_state = yield tombstone_context.get_current_state_ids(self.store) + old_room_state = yield tombstone_context.get_current_state_ids() # update any aliases yield self._move_aliases_to_new_room( @@ -1011,15 +1011,3 @@ class RoomEventSource(object): def get_current_key_for_room(self, room_id): return self.store.get_room_events_max_id(room_id) - - @defer.inlineCallbacks - def get_pagination_rows(self, user, config, key): - events, next_key = yield self.store.paginate_room_events( - room_id=key, - from_key=config.from_key, - to_key=config.to_key, - direction=config.direction, - limit=config.limit, - ) - - return (events, next_key) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 7b7270fc61..44c5e3239c 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -193,7 +193,7 @@ class RoomMemberHandler(object): requester, event, context, extra_users=[target], ratelimit=ratelimit ) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) @@ -601,7 +601,7 @@ class RoomMemberHandler(object): if prev_event is not None: return - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() if event.membership == Membership.JOIN: if requester.is_guest: guest_can_join = yield self._can_guest_join(prev_state_ids) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index cc9e6b9bd0..0082f85c26 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -13,20 +13,36 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import re +from typing import Tuple import attr import saml2 +import saml2.response from saml2.client import Saml2Client from synapse.api.errors import SynapseError +from synapse.config import ConfigError from synapse.http.servlet import parse_string from synapse.rest.client.v1.login import SSOAuthHandler -from synapse.types import UserID, map_username_to_mxid_localpart +from synapse.types import ( + UserID, + map_username_to_mxid_localpart, + mxid_localpart_allowed_characters, +) from synapse.util.async_helpers import Linearizer logger = logging.getLogger(__name__) +@attr.s +class Saml2SessionData: + """Data we track about SAML2 sessions""" + + # time the session was created, in milliseconds + creation_time = attr.ib() + + class SamlHandler: def __init__(self, hs): self._saml_client = Saml2Client(hs.config.saml2_sp_config) @@ -37,11 +53,14 @@ class SamlHandler: self._datastore = hs.get_datastore() self._hostname = hs.hostname self._saml2_session_lifetime = hs.config.saml2_session_lifetime - self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute self._grandfathered_mxid_source_attribute = ( hs.config.saml2_grandfathered_mxid_source_attribute ) - self._mxid_mapper = hs.config.saml2_mxid_mapper + + # plugin to do custom mapping from saml response to mxid + self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class( + hs.config.saml2_user_mapping_provider_config + ) # identifier for the external_ids table self._auth_provider_id = "saml" @@ -118,22 +137,10 @@ class SamlHandler: remote_user_id = saml2_auth.ava["uid"][0] except KeyError: logger.warning("SAML2 response lacks a 'uid' attestation") - raise SynapseError(400, "uid not in SAML2 response") - - try: - mxid_source = saml2_auth.ava[self._mxid_source_attribute][0] - except KeyError: - logger.warning( - "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute - ) - raise SynapseError( - 400, "%s not in SAML2 response" % (self._mxid_source_attribute,) - ) + raise SynapseError(400, "'uid' not in SAML2 response") self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None) - displayName = saml2_auth.ava.get("displayName", [None])[0] - with (await self._mapping_lock.queue(self._auth_provider_id)): # first of all, check if we already have a mapping for this user logger.info( @@ -173,22 +180,46 @@ class SamlHandler: ) return registered_user_id - # figure out a new mxid for this user - base_mxid_localpart = self._mxid_mapper(mxid_source) + # Map saml response to user attributes using the configured mapping provider + for i in range(1000): + attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes( + saml2_auth, i + ) + + logger.debug( + "Retrieved SAML attributes from user mapping provider: %s " + "(attempt %d)", + attribute_dict, + i, + ) + + localpart = attribute_dict.get("mxid_localpart") + if not localpart: + logger.error( + "SAML mapping provider plugin did not return a " + "mxid_localpart object" + ) + raise SynapseError(500, "Error parsing SAML2 response") - suffix = 0 - while True: - localpart = base_mxid_localpart + (str(suffix) if suffix else "") + displayname = attribute_dict.get("displayname") + + # Check if this mxid already exists if not await self._datastore.get_users_by_id_case_insensitive( UserID(localpart, self._hostname).to_string() ): + # This mxid is free break - suffix += 1 - logger.info("Allocating mxid for new user with localpart %s", localpart) + else: + # Unable to generate a username in 1000 iterations + # Break and return error to the user + raise SynapseError( + 500, "Unable to generate a Matrix ID from the SAML response" + ) registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=displayName + localpart=localpart, default_display_name=displayname ) + await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id ) @@ -205,9 +236,120 @@ class SamlHandler: del self._outstanding_requests_dict[reqid] +DOT_REPLACE_PATTERN = re.compile( + ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) +) + + +def dot_replace_for_mxid(username: str) -> str: + username = username.lower() + username = DOT_REPLACE_PATTERN.sub(".", username) + + # regular mxids aren't allowed to start with an underscore either + username = re.sub("^_", "", username) + return username + + +MXID_MAPPER_MAP = { + "hexencode": map_username_to_mxid_localpart, + "dotreplace": dot_replace_for_mxid, +} + + @attr.s -class Saml2SessionData: - """Data we track about SAML2 sessions""" +class SamlConfig(object): + mxid_source_attribute = attr.ib() + mxid_mapper = attr.ib() - # time the session was created, in milliseconds - creation_time = attr.ib() + +class DefaultSamlMappingProvider(object): + __version__ = "0.0.1" + + def __init__(self, parsed_config: SamlConfig): + """The default SAML user mapping provider + + Args: + parsed_config: Module configuration + """ + self._mxid_source_attribute = parsed_config.mxid_source_attribute + self._mxid_mapper = parsed_config.mxid_mapper + + def saml_response_to_user_attributes( + self, saml_response: saml2.response.AuthnResponse, failures: int = 0, + ) -> dict: + """Maps some text from a SAML response to attributes of a new user + + Args: + saml_response: A SAML auth response object + + failures: How many times a call to this function with this + saml_response has resulted in a failure + + Returns: + dict: A dict containing new user attributes. Possible keys: + * mxid_localpart (str): Required. The localpart of the user's mxid + * displayname (str): The displayname of the user + """ + try: + mxid_source = saml_response.ava[self._mxid_source_attribute][0] + except KeyError: + logger.warning( + "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute, + ) + raise SynapseError( + 400, "%s not in SAML2 response" % (self._mxid_source_attribute,) + ) + + # Use the configured mapper for this mxid_source + base_mxid_localpart = self._mxid_mapper(mxid_source) + + # Append suffix integer if last call to this function failed to produce + # a usable mxid + localpart = base_mxid_localpart + (str(failures) if failures else "") + + # Retrieve the display name from the saml response + # If displayname is None, the mxid_localpart will be used instead + displayname = saml_response.ava.get("displayName", [None])[0] + + return { + "mxid_localpart": localpart, + "displayname": displayname, + } + + @staticmethod + def parse_config(config: dict) -> SamlConfig: + """Parse the dict provided by the homeserver's config + Args: + config: A dictionary containing configuration options for this provider + Returns: + SamlConfig: A custom config object for this module + """ + # Parse config options and use defaults where necessary + mxid_source_attribute = config.get("mxid_source_attribute", "uid") + mapping_type = config.get("mxid_mapping", "hexencode") + + # Retrieve the associating mapping function + try: + mxid_mapper = MXID_MAPPER_MAP[mapping_type] + except KeyError: + raise ConfigError( + "saml2_config.user_mapping_provider.config: '%s' is not a valid " + "mxid_mapping value" % (mapping_type,) + ) + + return SamlConfig(mxid_source_attribute, mxid_mapper) + + @staticmethod + def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]: + """Returns the required attributes of a SAML + + Args: + config: A SamlConfig object containing configuration params for this provider + + Returns: + tuple[set,set]: The first set equates to the saml auth response + attributes that are required for the module to function, whereas the + second set consists of those attributes which can be used if + available, but are not necessary + """ + return {"uid", config.mxid_source_attribute}, {"displayName"} diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 56ed262a1f..ef750d1497 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -21,7 +21,7 @@ from unpaddedbase64 import decode_base64, encode_base64 from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import SynapseError +from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter from synapse.storage.state import StateFilter from synapse.visibility import filter_events_for_client @@ -37,6 +37,7 @@ class SearchHandler(BaseHandler): self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() self.state_store = self.storage.state + self.auth = hs.get_auth() @defer.inlineCallbacks def get_old_rooms_from_upgraded_room(self, room_id): @@ -53,23 +54,38 @@ class SearchHandler(BaseHandler): room_id (str): id of the room to search through. Returns: - Deferred[iterable[unicode]]: predecessor room ids + Deferred[iterable[str]]: predecessor room ids """ historical_room_ids = [] - while True: - predecessor = yield self.store.get_room_predecessor(room_id) + # The initial room must have been known for us to get this far + predecessor = yield self.store.get_room_predecessor(room_id) - # If no predecessor, assume we've hit a dead end + while True: if not predecessor: + # We have reached the end of the chain of predecessors + break + + if not isinstance(predecessor.get("room_id"), str): + # This predecessor object is malformed. Exit here + break + + predecessor_room_id = predecessor["room_id"] + + # Don't add it to the list until we have checked that we are in the room + try: + next_predecessor_room = yield self.store.get_room_predecessor( + predecessor_room_id + ) + except NotFoundError: + # The predecessor is not a known room, so we are done here break - # Add predecessor's room ID - historical_room_ids.append(predecessor["room_id"]) + historical_room_ids.append(predecessor_room_id) - # Scan through the old room for further predecessors - room_id = predecessor["room_id"] + # And repeat + predecessor = next_predecessor_room return historical_room_ids diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 6f78454322..b635c339ed 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -317,6 +317,3 @@ class TypingNotificationEventSource(object): def get_current_key(self): return self.get_typing_handler()._latest_room_serial - - def get_pagination_rows(self, user, pagination_config, key): - return [], pagination_config.from_key diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py index 03934956f4..c0b9384189 100644 --- a/synapse/logging/_terse_json.py +++ b/synapse/logging/_terse_json.py @@ -171,7 +171,7 @@ class LogProducer(object): def stopProducing(self): self._paused = True - self._buffer = None + self._buffer = deque() def resumeProducing(self): self._paused = False diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 2c1fb9ddac..33b322209d 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -23,6 +23,7 @@ them. See doc/log_contexts.rst for details on how this works. """ +import inspect import logging import threading import types @@ -404,6 +405,9 @@ class LoggingContext(object): """ current = get_thread_resource_usage() + # Indicate to mypy that we know that self.usage_start is None. + assert self.usage_start is not None + utime_delta = current.ru_utime - self.usage_start.ru_utime stime_delta = current.ru_stime - self.usage_start.ru_stime @@ -612,7 +616,8 @@ def run_in_background(f, *args, **kwargs): def make_deferred_yieldable(deferred): - """Given a deferred, make it follow the Synapse logcontext rules: + """Given a deferred (or coroutine), make it follow the Synapse logcontext + rules: If the deferred has completed (or is not actually a Deferred), essentially does nothing (just returns another completed deferred with the @@ -624,6 +629,13 @@ def make_deferred_yieldable(deferred): (This is more-or-less the opposite operation to run_in_background.) """ + if inspect.isawaitable(deferred): + # If we're given a coroutine we convert it to a deferred so that we + # run it and find out if it immediately finishes, it it does then we + # don't need to fiddle with log contexts at all and can return + # immediately. + deferred = defer.ensureDeferred(deferred) + if not isinstance(deferred, defer.Deferred): return deferred diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 7881780760..7d9f5a38d9 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -116,7 +116,7 @@ class BulkPushRuleEvaluator(object): @defer.inlineCallbacks def _get_power_levels_and_sender_level(self, event, context): - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() pl_event_id = prev_state_ids.get(POWER_KEY) if pl_event_id: # fastpath: if there's a power level event, that's all we need, and @@ -304,7 +304,7 @@ class RulesForRoom(object): push_rules_delta_state_cache_metric.inc_hits() else: - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() push_rules_delta_state_cache_metric.inc_misses() push_rules_state_size_counter.inc(len(current_state_ids)) diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index f277aeb131..8ad0bf5936 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -80,9 +80,11 @@ class PusherFactory(object): return EmailPusher(self.hs, pusherdict, mailer) def _app_name_from_pusherdict(self, pusherdict): - if "data" in pusherdict and "brand" in pusherdict["data"]: - app_name = pusherdict["data"]["brand"] - else: - app_name = self.config.email_app_name + data = pusherdict["data"] - return app_name + if isinstance(data, dict): + brand = data.get("brand") + if isinstance(brand, str): + return brand + + return self.config.email_app_name diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 0f6992202d..b9dca5bc63 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -232,7 +232,6 @@ class PusherPool: Deferred """ pushers = yield self.store.get_all_pushers() - logger.info("Starting %d pushers", len(pushers)) # Stagger starting up the pushers so we don't completely drown the # process on start up. @@ -245,7 +244,7 @@ class PusherPool: """Start the given pusher Args: - pusherdict (dict): + pusherdict (dict): dict with the values pulled from the db table Returns: Deferred[EmailPusher|HttpPusher] @@ -254,7 +253,8 @@ class PusherPool: p = self.pusher_factory.create_pusher(pusherdict) except PusherConfigException as e: logger.warning( - "Pusher incorrectly configured user=%s, appid=%s, pushkey=%s: %s", + "Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s", + pusherdict["id"], pusherdict.get("user_name"), pusherdict.get("app_id"), pusherdict.get("pushkey"), @@ -262,7 +262,9 @@ class PusherPool: ) return except Exception: - logger.exception("Couldn't start a pusher: caught Exception") + logger.exception( + "Couldn't start pusher id %i: caught Exception", pusherdict["id"], + ) return if not p: diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 9af4e7e173..49a3251372 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -51,6 +51,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): super(ReplicationFederationSendEventsRestServlet, self).__init__(hs) self.store = hs.get_datastore() + self.storage = hs.get_storage() self.clock = hs.get_clock() self.federation_handler = hs.get_handlers().federation_handler @@ -100,7 +101,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): EventType = event_type_from_format_version(format_ver) event = EventType(event_dict, internal_metadata, rejected_reason) - context = EventContext.deserialize(self.store, event_payload["context"]) + context = EventContext.deserialize( + self.storage, event_payload["context"] + ) event_and_contexts.append((event, context)) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 9bafd60b14..84b92f16ad 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -54,6 +54,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): self.event_creation_handler = hs.get_event_creation_handler() self.store = hs.get_datastore() + self.storage = hs.get_storage() self.clock = hs.get_clock() @staticmethod @@ -100,7 +101,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): event = EventType(event_dict, internal_metadata, rejected_reason) requester = Requester.deserialize(self.store, content["requester"]) - context = EventContext.deserialize(self.store, content["context"]) + context = EventContext.deserialize(self.storage, content["context"]) ratelimit = content["ratelimit"] extra_users = [UserID.from_string(u) for u in content["extra_users"]] diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 0791866f55..6f6b7aed6e 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -28,6 +28,17 @@ from synapse.rest.client.v2_alpha._base import client_patterns logger = logging.getLogger(__name__) +ALLOWED_KEYS = { + "app_display_name", + "app_id", + "data", + "device_display_name", + "kind", + "lang", + "profile_tag", + "pushkey", +} + class PushersRestServlet(RestServlet): PATTERNS = client_patterns("/pushers$", v1=True) @@ -43,23 +54,11 @@ class PushersRestServlet(RestServlet): pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) - allowed_keys = [ - "app_display_name", - "app_id", - "data", - "device_display_name", - "kind", - "lang", - "profile_tag", - "pushkey", - ] - - for p in pushers: - for k, v in list(p.items()): - if k not in allowed_keys: - del p[k] - - return 200, {"pushers": pushers} + filtered_pushers = list( + {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers + ) + + return 200, {"pushers": filtered_pushers} def on_OPTIONS(self, _): return 200, {} diff --git a/synapse/server.py b/synapse/server.py index 2db3dab221..7926867b77 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -25,7 +25,6 @@ import abc import logging import os -from twisted.enterprise import adbapi from twisted.mail.smtp import sendmail from twisted.web.client import BrowserLikePolicyForHTTPS @@ -34,6 +33,7 @@ from synapse.api.filtering import Filtering from synapse.api.ratelimiting import Ratelimiter from synapse.appservice.api import ApplicationServiceApi from synapse.appservice.scheduler import ApplicationServiceScheduler +from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.crypto.keyring import Keyring from synapse.events.builder import EventBuilderFactory @@ -132,7 +132,6 @@ class HomeServer(object): DEPENDENCIES = [ "http_client", - "db_pool", "federation_client", "federation_server", "handlers", @@ -209,16 +208,18 @@ class HomeServer(object): # instantiated during setup() for future return by get_datastore() DATASTORE_CLASS = abc.abstractproperty() - def __init__(self, hostname, reactor=None, **kwargs): + def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwargs): """ Args: hostname : The hostname for the server. + config: The full config for the homeserver. """ if not reactor: from twisted.internet import reactor self._reactor = reactor self.hostname = hostname + self.config = config self._building = {} self._listening_services = [] self.start_time = None @@ -237,10 +238,8 @@ class HomeServer(object): def setup(self): logger.info("Setting up.") - with self.get_db_conn() as conn: - self.datastores = DataStores(self.DATASTORE_CLASS, conn, self) - conn.commit() self.start_time = int(self.get_clock().time()) + self.datastores = DataStores(self.DATASTORE_CLASS, self) logger.info("Finished setting up.") def setup_master(self): @@ -274,6 +273,9 @@ class HomeServer(object): def get_datastore(self): return self.datastores.main + def get_datastores(self): + return self.datastores + def get_config(self): return self.config @@ -423,31 +425,6 @@ class HomeServer(object): ) return MatrixFederationHttpClient(self, tls_client_options_factory) - def build_db_pool(self): - name = self.db_config["name"] - - return adbapi.ConnectionPool( - name, cp_reactor=self.get_reactor(), **self.db_config.get("args", {}) - ) - - def get_db_conn(self, run_new_connection=True): - """Makes a new connection to the database, skipping the db pool - - Returns: - Connection: a connection object implementing the PEP-249 spec - """ - # Any param beginning with cp_ is a parameter for adbapi, and should - # not be passed to the database engine. - db_params = { - k: v - for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.database_engine.module.connect(**db_params) - if run_new_connection: - self.database_engine.on_new_connection(db_conn) - return db_conn - def build_media_repository_resource(self): # build the media repo resource. This indirects through the HomeServer # to ensure that we only have a single instance of diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 0e75e94c6f..5accc071ab 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -32,6 +32,7 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.logging.utils import log_function from synapse.state import v1, v2 +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.util.async_helpers import Linearizer from synapse.util.caches import get_cache_factor_for from synapse.util.caches.expiringcache import ExpiringCache @@ -655,7 +656,7 @@ class StateResolutionStore(object): return self.store.get_events( event_ids, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, allow_rejected=allow_rejected, ) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b7637b5dc0..88546ad614 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -40,7 +40,7 @@ class SQLBaseStore(object): def __init__(self, database: Database, db_conn, hs): self.hs = hs self._clock = hs.get_clock() - self.database_engine = hs.database_engine + self.database_engine = database.engine self.db = database self.rand = random.SystemRandom() diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index cafedd5c0d..d20df5f076 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -13,24 +13,55 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.database import Database +import logging + +from synapse.storage.data_stores.state import StateGroupDataStore +from synapse.storage.database import Database, make_conn +from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database +logger = logging.getLogger(__name__) + class DataStores(object): """The various data stores. These are low level interfaces to physical databases. + + Attributes: + main (DataStore) """ - def __init__(self, main_store_class, db_conn, hs): + def __init__(self, main_store_class, hs): # Note we pass in the main store class here as workers use a different main # store. - database = Database(hs) - # Check that db is correctly configured. - database.engine.check_database(db_conn.cursor()) + self.databases = [] + + for database_config in hs.config.database.databases: + db_name = database_config.name + engine = create_engine(database_config.config) + + with make_conn(database_config, engine) as db_conn: + logger.info("Preparing database %r...", db_name) + + engine.check_database(db_conn.cursor()) + prepare_database( + db_conn, engine, hs.config, data_stores=database_config.data_stores, + ) + + database = Database(hs, database_config, engine) + + if "main" in database_config.data_stores: + logger.info("Starting 'main' data store") + self.main = main_store_class(database, db_conn, hs) + + if "state" in database_config.data_stores: + logger.info("Starting 'state' data store") + self.state = StateGroupDataStore(database, db_conn, hs) + + db_conn.commit() - prepare_database(db_conn, database.engine, config=hs.config) + self.databases.append(database) - self.main = main_store_class(database, db_conn, hs) + logger.info("Database %r prepared", db_name) diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 320c5b0f07..13f4c9c72e 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -412,7 +412,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): def _update_client_ips_batch(self): # If the DB pool has already terminated, don't try updating - if not self.hs.get_db_pool().running: + if not self.db.is_running(): return to_update = self._batch_row_update @@ -451,16 +451,18 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): # Technically an access token might not be associated with # a device so we need to check. if device_id: - self.db.simple_upsert_txn( + # this is always an update rather than an upsert: the row should + # already exist, and if it doesn't, that may be because it has been + # deleted, and we don't want to re-create it. + self.db.simple_update_txn( txn, table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, - values={ + updatevalues={ "user_agent": user_agent, "last_seen": last_seen, "ip": ip, }, - lock=False, ) except Exception as e: # Failed to upsert, log and continue diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index 85cfa16850..0613b49f4a 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -358,21 +358,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) def _add_messages_to_local_device_inbox_txn( self, txn, stream_id, messages_by_user_then_device ): - # Compatible method of performing an upsert - sql = "SELECT stream_id FROM device_max_stream_id" - - txn.execute(sql) - rows = txn.fetchone() - if rows: - db_stream_id = rows[0] - if db_stream_id < stream_id: - # Insert the new stream_id - sql = "UPDATE device_max_stream_id SET stream_id = ?" - else: - # No rows, perform an insert - sql = "INSERT INTO device_max_stream_id (stream_id) VALUES (?)" - - txn.execute(sql, (stream_id,)) + sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?" + txn.execute(sql, (stream_id, stream_id)) local_by_user_then_device = {} for user_id, messages_by_device in messages_by_user_then_device.items(): diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index 38cd0ca9b8..e551606f9d 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -14,15 +14,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, List + from six import iteritems from canonicaljson import encode_canonical_json, json +from twisted.enterprise.adbapi import Connection from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import cached, cachedList class EndToEndKeyWorkerStore(SQLBaseStore): @@ -271,7 +274,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): Args: txn (twisted.enterprise.adbapi.Connection): db connection user_id (str): the user whose key is being requested - key_type (str): the type of key that is being set: either 'master' + key_type (str): the type of key that is being requested: either 'master' for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key from_user_id (str): if specified, signatures made by this user on @@ -316,8 +319,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore): """Returns a user's cross-signing key. Args: - user_id (str): the user whose self-signing key is being requested - key_type (str): the type of cross-signing key to get + user_id (str): the user whose key is being requested + key_type (str): the type of key that is being requested: either 'master' + for a master key, 'self_signing' for a self-signing key, or + 'user_signing' for a user-signing key from_user_id (str): if specified, signatures made by this user on the self-signing key will be included in the result @@ -332,6 +337,206 @@ class EndToEndKeyWorkerStore(SQLBaseStore): from_user_id, ) + @cached(num_args=1) + def _get_bare_e2e_cross_signing_keys(self, user_id): + """Dummy function. Only used to make a cache for + _get_bare_e2e_cross_signing_keys_bulk. + """ + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_bare_e2e_cross_signing_keys", + list_name="user_ids", + num_args=1, + ) + def _get_bare_e2e_cross_signing_keys_bulk( + self, user_ids: List[str] + ) -> Dict[str, Dict[str, dict]]: + """Returns the cross-signing keys for a set of users. The output of this + function should be passed to _get_e2e_cross_signing_signatures_txn if + the signatures for the calling user need to be fetched. + + Args: + user_ids (list[str]): the users whose keys are being requested + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. If a user's cross-signing keys were not found, either + their user ID will not be in the dict, or their user ID will map + to None. + + """ + return self.db.runInteraction( + "get_bare_e2e_cross_signing_keys_bulk", + self._get_bare_e2e_cross_signing_keys_bulk_txn, + user_ids, + ) + + def _get_bare_e2e_cross_signing_keys_bulk_txn( + self, txn: Connection, user_ids: List[str], + ) -> Dict[str, Dict[str, dict]]: + """Returns the cross-signing keys for a set of users. The output of this + function should be passed to _get_e2e_cross_signing_signatures_txn if + the signatures for the calling user need to be fetched. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + user_ids (list[str]): the users whose keys are being requested + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. If a user's cross-signing keys were not found, their user + ID will not be in the dict. + + """ + result = {} + + batch_size = 100 + chunks = [ + user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size) + ] + for user_chunk in chunks: + sql = """ + SELECT k.user_id, k.keytype, k.keydata, k.stream_id + FROM e2e_cross_signing_keys k + INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id + FROM e2e_cross_signing_keys + GROUP BY user_id, keytype) s + USING (user_id, stream_id, keytype) + WHERE k.user_id IN (%s) + """ % ( + ",".join("?" for u in user_chunk), + ) + query_params = [] + query_params.extend(user_chunk) + + txn.execute(sql, query_params) + rows = self.db.cursor_to_dict(txn) + + for row in rows: + user_id = row["user_id"] + key_type = row["keytype"] + key = json.loads(row["keydata"]) + user_info = result.setdefault(user_id, {}) + user_info[key_type] = key + + return result + + def _get_e2e_cross_signing_signatures_txn( + self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str, + ) -> Dict[str, Dict[str, dict]]: + """Returns the cross-signing signatures made by a user on a set of keys. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + keys (dict[str, dict[str, dict]]): a map of user ID to key type to + key data. This dict will be modified to add signatures. + from_user_id (str): fetch the signatures made by this user + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. The return value will be the same as the keys argument, + with the modifications included. + """ + + # find out what cross-signing keys (a.k.a. devices) we need to get + # signatures for. This is a map of (user_id, device_id) to key type + # (device_id is the key's public part). + devices = {} + + for user_id, user_info in keys.items(): + if user_info is None: + continue + for key_type, key in user_info.items(): + device_id = None + for k in key["keys"].values(): + device_id = k + devices[(user_id, device_id)] = key_type + + device_list = list(devices) + + # split into batches + batch_size = 100 + chunks = [ + device_list[i : i + batch_size] + for i in range(0, len(device_list), batch_size) + ] + for user_chunk in chunks: + sql = """ + SELECT target_user_id, target_device_id, key_id, signature + FROM e2e_cross_signing_signatures + WHERE user_id = ? + AND (%s) + """ % ( + " OR ".join( + "(target_user_id = ? AND target_device_id = ?)" for d in devices + ) + ) + query_params = [from_user_id] + for item in devices: + # item is a (user_id, device_id) tuple + query_params.extend(item) + + txn.execute(sql, query_params) + rows = self.db.cursor_to_dict(txn) + + # and add the signatures to the appropriate keys + for row in rows: + key_id = row["key_id"] + target_user_id = row["target_user_id"] + target_device_id = row["target_device_id"] + key_type = devices[(target_user_id, target_device_id)] + # We need to copy everything, because the result may have come + # from the cache. dict.copy only does a shallow copy, so we + # need to recursively copy the dicts that will be modified. + user_info = keys[target_user_id] = keys[target_user_id].copy() + target_user_key = user_info[key_type] = user_info[key_type].copy() + if "signatures" in target_user_key: + signatures = target_user_key["signatures"] = target_user_key[ + "signatures" + ].copy() + if from_user_id in signatures: + user_sigs = signatures[from_user_id] = signatures[from_user_id] + user_sigs[key_id] = row["signature"] + else: + signatures[from_user_id] = {key_id: row["signature"]} + else: + target_user_key["signatures"] = { + from_user_id: {key_id: row["signature"]} + } + + return keys + + @defer.inlineCallbacks + def get_e2e_cross_signing_keys_bulk( + self, user_ids: List[str], from_user_id: str = None + ) -> defer.Deferred: + """Returns the cross-signing keys for a set of users. + + Args: + user_ids (list[str]): the users whose keys are being requested + from_user_id (str): if specified, signatures made by this user on + the self-signing keys will be included in the result + + Returns: + Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to + key data. If a user's cross-signing keys were not found, either + their user ID will not be in the dict, or their user ID will map + to None. + """ + + result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids) + + if from_user_id: + result = yield self.db.runInteraction( + "get_e2e_cross_signing_signatures", + self._get_e2e_cross_signing_signatures_txn, + result, + from_user_id, + ) + + return result + def get_all_user_signature_changes_for_remotes(self, from_key, to_key): """Return a list of changes from the user signature stream to notify remotes. Note that the user signature stream represents when a user signs their @@ -520,6 +725,10 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): }, ) + self._invalidate_cache_and_stream( + txn, self._get_bare_e2e_cross_signing_keys, (user_id,) + ) + def set_e2e_cross_signing_key(self, user_id, key_type, key): """Set a user's cross-signing key. diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 998bba1aad..58f35d7f56 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -1757,163 +1757,6 @@ class EventsStore( return state_groups - def purge_unreferenced_state_groups( - self, room_id: str, state_groups_to_delete - ) -> defer.Deferred: - """Deletes no longer referenced state groups and de-deltas any state - groups that reference them. - - Args: - room_id: The room the state groups belong to (must all be in the - same room). - state_groups_to_delete (Collection[int]): Set of all state groups - to delete. - """ - - return self.db.runInteraction( - "purge_unreferenced_state_groups", - self._purge_unreferenced_state_groups, - room_id, - state_groups_to_delete, - ) - - def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete): - logger.info( - "[purge] found %i state groups to delete", len(state_groups_to_delete) - ) - - rows = self.db.simple_select_many_txn( - txn, - table="state_group_edges", - column="prev_state_group", - iterable=state_groups_to_delete, - keyvalues={}, - retcols=("state_group",), - ) - - remaining_state_groups = set( - row["state_group"] - for row in rows - if row["state_group"] not in state_groups_to_delete - ) - - logger.info( - "[purge] de-delta-ing %i remaining state groups", - len(remaining_state_groups), - ) - - # Now we turn the state groups that reference to-be-deleted state - # groups to non delta versions. - for sg in remaining_state_groups: - logger.info("[purge] de-delta-ing remaining state group %s", sg) - curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) - curr_state = curr_state[sg] - - self.db.simple_delete_txn( - txn, table="state_groups_state", keyvalues={"state_group": sg} - ) - - self.db.simple_delete_txn( - txn, table="state_group_edges", keyvalues={"state_group": sg} - ) - - self.db.simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": sg, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(curr_state) - ], - ) - - logger.info("[purge] removing redundant state groups") - txn.executemany( - "DELETE FROM state_groups_state WHERE state_group = ?", - ((sg,) for sg in state_groups_to_delete), - ) - txn.executemany( - "DELETE FROM state_groups WHERE id = ?", - ((sg,) for sg in state_groups_to_delete), - ) - - @defer.inlineCallbacks - def get_previous_state_groups(self, state_groups): - """Fetch the previous groups of the given state groups. - - Args: - state_groups (Iterable[int]) - - Returns: - Deferred[dict[int, int]]: mapping from state group to previous - state group. - """ - - rows = yield self.db.simple_select_many_batch( - table="state_group_edges", - column="prev_state_group", - iterable=state_groups, - keyvalues={}, - retcols=("prev_state_group", "state_group"), - desc="get_previous_state_groups", - ) - - return {row["state_group"]: row["prev_state_group"] for row in rows} - - def purge_room_state(self, room_id, state_groups_to_delete): - """Deletes all record of a room from state tables - - Args: - room_id (str): - state_groups_to_delete (list[int]): State groups to delete - """ - - return self.db.runInteraction( - "purge_room_state", - self._purge_room_state_txn, - room_id, - state_groups_to_delete, - ) - - def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete): - # first we have to delete the state groups states - logger.info("[purge] removing %s from state_groups_state", room_id) - - self.db.simple_delete_many_txn( - txn, - table="state_groups_state", - column="state_group", - iterable=state_groups_to_delete, - keyvalues={}, - ) - - # ... and the state group edges - logger.info("[purge] removing %s from state_group_edges", room_id) - - self.db.simple_delete_many_txn( - txn, - table="state_group_edges", - column="state_group", - iterable=state_groups_to_delete, - keyvalues={}, - ) - - # ... and the state groups - logger.info("[purge] removing %s from state_groups", room_id) - - self.db.simple_delete_many_txn( - txn, - table="state_groups", - column="id", - iterable=state_groups_to_delete, - keyvalues={}, - ) - async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream """ diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 9ee117ce0f..2c9142814c 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -19,8 +19,10 @@ import itertools import logging import threading from collections import namedtuple +from typing import List, Optional from canonicaljson import json +from constantly import NamedConstant, Names from twisted.internet import defer @@ -55,6 +57,16 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) +class EventRedactBehaviour(Names): + """ + What to do when retrieving a redacted event from the database. + """ + + AS_IS = NamedConstant() + REDACT = NamedConstant() + BLOCK = NamedConstant() + + class EventsWorkerStore(SQLBaseStore): def __init__(self, database: Database, db_conn, hs): super(EventsWorkerStore, self).__init__(database, db_conn, hs) @@ -125,25 +137,27 @@ class EventsWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_event( self, - event_id, - check_redacted=True, - get_prev_content=False, - allow_rejected=False, - allow_none=False, - check_room_id=None, + event_id: List[str], + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: bool = False, + check_room_id: Optional[str] = None, ): """Get an event from the database by event_id. Args: - event_id (str): The event_id of the event to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, + event_id: The event_id of the event to fetch + redact_behaviour: Determine what to do with a redacted event. Possible values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events + get_prev_content: If True and event is a state event, include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. - allow_none (bool): If True, return None if no event found, if + allow_rejected: If True return rejected events. + allow_none: If True, return None if no event found, if False throw a NotFoundError - check_room_id (str|None): if not None, check the room of the found event. + check_room_id: if not None, check the room of the found event. If there is a mismatch, behave as per allow_none. Returns: @@ -154,7 +168,7 @@ class EventsWorkerStore(SQLBaseStore): events = yield self.get_events_as_list( [event_id], - check_redacted=check_redacted, + redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, allow_rejected=allow_rejected, ) @@ -173,27 +187,30 @@ class EventsWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_events( self, - event_ids, - check_redacted=True, - get_prev_content=False, - allow_rejected=False, + event_ids: List[str], + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, ): """Get events from the database Args: - event_ids (list): The event_ids of the events to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, + event_ids: The event_ids of the events to fetch + redact_behaviour: Determine what to do with a redacted event. Possible + values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events + get_prev_content: If True and event is a state event, include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. + allow_rejected: If True return rejected events. Returns: Deferred : Dict from event_id to event. """ events = yield self.get_events_as_list( event_ids, - check_redacted=check_redacted, + redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, allow_rejected=allow_rejected, ) @@ -203,21 +220,23 @@ class EventsWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_events_as_list( self, - event_ids, - check_redacted=True, - get_prev_content=False, - allow_rejected=False, + event_ids: List[str], + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, ): """Get events from the database and return in a list in the same order as given by `event_ids` arg. Args: - event_ids (list): The event_ids of the events to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, + event_ids: The event_ids of the events to fetch + redact_behaviour: Determine what to do with a redacted event. Possible values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events + get_prev_content: If True and event is a state event, include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. + allow_rejected: If True, return rejected events. Returns: Deferred[list[EventBase]]: List of events fetched from the database. The @@ -319,10 +338,14 @@ class EventsWorkerStore(SQLBaseStore): # Update the cache to save doing the checks again. entry.event.internal_metadata.recheck_redaction = False - if check_redacted and entry.redacted_event: - event = entry.redacted_event - else: - event = entry.event + event = entry.event + + if entry.redacted_event: + if redact_behaviour == EventRedactBehaviour.BLOCK: + # Skip this event + continue + elif redact_behaviour == EventRedactBehaviour.REDACT: + event = entry.redacted_event events.append(event) diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index 5ba13aa973..e2673ae073 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -244,7 +244,7 @@ class PushRulesWorkerStore( # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids(self) + current_state_ids = yield context.get_current_state_ids() result = yield self._bulk_get_push_rules_for_room( event.room_id, state_group, current_state_ids, event=event ) diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py index f07309ef09..6b03233262 100644 --- a/synapse/storage/data_stores/main/pusher.py +++ b/synapse/storage/data_stores/main/pusher.py @@ -15,8 +15,7 @@ # limitations under the License. import logging - -import six +from typing import Iterable, Iterator from canonicaljson import encode_canonical_json, json @@ -27,21 +26,16 @@ from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList logger = logging.getLogger(__name__) -if six.PY2: - db_binary_type = six.moves.builtins.buffer -else: - db_binary_type = memoryview - class PusherWorkerStore(SQLBaseStore): - def _decode_pushers_rows(self, rows): + def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]: + """JSON-decode the data in the rows returned from the `pushers` table + + Drops any rows whose data cannot be decoded + """ for r in rows: dataJson = r["data"] - r["data"] = None try: - if isinstance(dataJson, db_binary_type): - dataJson = str(dataJson).decode("UTF8") - r["data"] = json.loads(dataJson) except Exception as e: logger.warning( @@ -50,12 +44,9 @@ class PusherWorkerStore(SQLBaseStore): dataJson, e.args[0], ) - pass - - if isinstance(r["pushkey"], db_binary_type): - r["pushkey"] = str(r["pushkey"]).decode("UTF8") + continue - return rows + yield r @defer.inlineCallbacks def user_has_pusher(self, user_id): diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 92e3b9c512..70ff5751b6 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -477,7 +477,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids(self) + current_state_ids = yield context.get_current_state_ids() result = yield self._get_joined_users_from_context( event.room_id, state_group, current_state_ids, event=event, context=context ) diff --git a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql index 4219cdd06a..2de50d408c 100644 --- a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql +++ b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql @@ -20,7 +20,6 @@ DROP INDEX IF EXISTS events_room_id; -- Prefix of events_room_stream DROP INDEX IF EXISTS events_order; -- Prefix of events_order_topo_stream_room DROP INDEX IF EXISTS events_topological_ordering; -- Prefix of events_order_topo_stream_room DROP INDEX IF EXISTS events_stream_ordering; -- Duplicate of PRIMARY KEY -DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY DROP INDEX IF EXISTS event_to_state_groups_id; -- Duplicate of PRIMARY KEY DROP INDEX IF EXISTS event_push_actions_room_id_event_id_user_id_profile_tag; -- Duplicate of UNIQUE CONSTRAINT diff --git a/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql b/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql new file mode 100644 index 0000000000..c2f557fde9 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql @@ -0,0 +1,20 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This line already existed in deltas/35/device_stream_id but was not included in the +-- 54 full schema SQL. Add some SQL here to insert the missing row if it does not exist +INSERT INTO device_max_stream_id (stream_id) SELECT 0 WHERE NOT EXISTS ( + SELECT * from device_max_stream_id +); \ No newline at end of file diff --git a/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql b/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql new file mode 100644 index 0000000000..4f24c1405d --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql @@ -0,0 +1,29 @@ +/* Copyright 2019 Werner Sembach + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Groups/communities now get deleted when the last member leaves. This is a one time cleanup to remove old groups/communities that were already empty before that change was made. +DELETE FROM group_attestations_remote WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_attestations_renewals WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_invites WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_summary_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_summary_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_summary_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_summary_users WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM local_group_membership WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM local_group_updates WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM groups WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres index 4ad2929f32..889a9a0ce4 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres +++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres @@ -975,40 +975,6 @@ CREATE TABLE state_events ( -CREATE TABLE state_group_edges ( - state_group bigint NOT NULL, - prev_state_group bigint NOT NULL -); - - - -CREATE SEQUENCE state_group_id_seq - START WITH 1 - INCREMENT BY 1 - NO MINVALUE - NO MAXVALUE - CACHE 1; - - - -CREATE TABLE state_groups ( - id bigint NOT NULL, - room_id text NOT NULL, - event_id text NOT NULL -); - - - -CREATE TABLE state_groups_state ( - state_group bigint NOT NULL, - room_id text NOT NULL, - type text NOT NULL, - state_key text NOT NULL, - event_id text NOT NULL -); - - - CREATE TABLE stats_stream_pos ( lock character(1) DEFAULT 'X'::bpchar NOT NULL, stream_id bigint, @@ -1482,12 +1448,6 @@ ALTER TABLE ONLY state_events ADD CONSTRAINT state_events_event_id_key UNIQUE (event_id); - -ALTER TABLE ONLY state_groups - ADD CONSTRAINT state_groups_pkey PRIMARY KEY (id); - - - ALTER TABLE ONLY stats_stream_pos ADD CONSTRAINT stats_stream_pos_lock_key UNIQUE (lock); @@ -1928,18 +1888,6 @@ CREATE UNIQUE INDEX room_stats_room_ts ON room_stats USING btree (room_id, ts); -CREATE INDEX state_group_edges_idx ON state_group_edges USING btree (state_group); - - - -CREATE INDEX state_group_edges_prev_idx ON state_group_edges USING btree (prev_state_group); - - - -CREATE INDEX state_groups_state_type_idx ON state_groups_state USING btree (state_group, type, state_key); - - - CREATE INDEX stream_ordering_to_exterm_idx ON stream_ordering_to_exterm USING btree (stream_ordering); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite index bad33291e7..a0411ede7e 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite +++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite @@ -42,8 +42,6 @@ CREATE INDEX ev_edges_id ON event_edges(event_id); CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id); CREATE TABLE room_depth( room_id TEXT NOT NULL, min_depth INTEGER NOT NULL, UNIQUE (room_id) ); CREATE INDEX room_depth_room ON room_depth(room_id); -CREATE TABLE state_groups( id BIGINT PRIMARY KEY, room_id TEXT NOT NULL, event_id TEXT NOT NULL ); -CREATE TABLE state_groups_state( state_group BIGINT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, event_id TEXT NOT NULL ); CREATE TABLE event_to_state_groups( event_id TEXT NOT NULL, state_group BIGINT NOT NULL, UNIQUE (event_id) ); CREATE TABLE local_media_repository ( media_id TEXT, media_type TEXT, media_length INTEGER, created_ts BIGINT, upload_name TEXT, user_id TEXT, quarantined_by TEXT, url_cache TEXT, last_access_ts BIGINT, UNIQUE (media_id) ); CREATE TABLE local_media_repository_thumbnails ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type ) ); @@ -120,9 +118,6 @@ CREATE TABLE device_max_stream_id ( stream_id BIGINT NOT NULL ); CREATE TABLE public_room_list_stream ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, visibility BOOLEAN NOT NULL , appservice_id TEXT, network_id TEXT); CREATE INDEX public_room_list_stream_idx on public_room_list_stream( stream_id ); CREATE INDEX public_room_list_stream_rm_idx on public_room_list_stream( room_id, stream_id ); -CREATE TABLE state_group_edges( state_group BIGINT NOT NULL, prev_state_group BIGINT NOT NULL ); -CREATE INDEX state_group_edges_idx ON state_group_edges(state_group); -CREATE INDEX state_group_edges_prev_idx ON state_group_edges(prev_state_group); CREATE TABLE stream_ordering_to_exterm ( stream_ordering BIGINT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL ); CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm( stream_ordering ); CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm( room_id, stream_ordering ); @@ -254,6 +249,5 @@ CREATE INDEX user_ips_last_seen_only ON user_ips (last_seen); CREATE INDEX users_creation_ts ON users (creation_ts); CREATE INDEX event_to_state_groups_sg_index ON event_to_state_groups (state_group); CREATE UNIQUE INDEX device_lists_remote_cache_unique_id ON device_lists_remote_cache (user_id, device_id); -CREATE INDEX state_groups_state_type_idx ON state_groups_state(state_group, type, state_key); CREATE UNIQUE INDEX device_lists_remote_extremeties_unique_idx ON device_lists_remote_extremeties (user_id); CREATE UNIQUE INDEX user_ips_user_token_ip_unique_index ON user_ips (user_id, access_token, ip); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql index c265fd20e2..91d21b2921 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql @@ -5,3 +5,4 @@ INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coales INSERT INTO user_directory_stream_pos (stream_id) VALUES (0); INSERT INTO stats_stream_pos (stream_id) VALUES (0); INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0); +-- device_max_stream_id is handled separately in 56/device_stream_id_insert.sql \ No newline at end of file diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/data_stores/main/schema/full_schemas/README.md new file mode 100644 index 0000000000..bbd3f18604 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/full_schemas/README.md @@ -0,0 +1,13 @@ +# Building full schema dumps + +These schemas need to be made from a database that has had all background updates run. + +To do so, use `scripts-dev/make_full_schema.sh`. This will produce +`full.sql.postgres ` and `full.sql.sqlite` files. + +Ensure postgres is installed and your user has the ability to run bash commands +such as `createdb`. + +``` +./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/ +``` diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.txt b/synapse/storage/data_stores/main/schema/full_schemas/README.txt deleted file mode 100644 index d3f6401344..0000000000 --- a/synapse/storage/data_stores/main/schema/full_schemas/README.txt +++ /dev/null @@ -1,19 +0,0 @@ -Building full schema dumps -========================== - -These schemas need to be made from a database that has had all background updates run. - -Postgres --------- - -$ pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner $DATABASE_NAME| sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > full.sql.postgres - -SQLite ------- - -$ sqlite3 $DATABASE_FILE ".schema" > full.sql.sqlite - -After ------ - -Delete the CREATE statements for "sqlite_stat1", "schema_version", "applied_schema_deltas", and "applied_module_schemas". \ No newline at end of file diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index 4eec2fae5e..47ebb8a214 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -25,6 +25,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine, Sqlite3Engine @@ -384,7 +385,7 @@ class SearchStore(SearchBackgroundUpdateStore): """ clauses = [] - search_query = search_query = _parse_query(self.database_engine, search_term) + search_query = _parse_query(self.database_engine, search_term) args = [] @@ -453,7 +454,12 @@ class SearchStore(SearchBackgroundUpdateStore): results = list(filter(lambda row: row["room_id"] in room_ids, results)) - events = yield self.get_events_as_list([r["event_id"] for r in results]) + # We set redact_behaviour to BLOCK here to prevent redacted events being returned in + # search results (which is a data leak) + events = yield self.get_events_as_list( + [r["event_id"] for r in results], + redact_behaviour=EventRedactBehaviour.BLOCK, + ) event_map = {ev.event_id: ev for ev in events} @@ -495,7 +501,7 @@ class SearchStore(SearchBackgroundUpdateStore): """ clauses = [] - search_query = search_query = _parse_query(self.database_engine, search_term) + search_query = _parse_query(self.database_engine, search_term) args = [] @@ -600,7 +606,12 @@ class SearchStore(SearchBackgroundUpdateStore): results = list(filter(lambda row: row["room_id"] in room_ids, results)) - events = yield self.get_events_as_list([r["event_id"] for r in results]) + # We set redact_behaviour to BLOCK here to prevent redacted events being returned in + # search results (which is a data leak) + events = yield self.get_events_as_list( + [r["event_id"] for r in results], + redact_behaviour=EventRedactBehaviour.BLOCK, + ) event_map = {ev.event_id: ev for ev in events} diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 9ef7b48c74..0dc39f139c 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -17,8 +17,7 @@ import logging from collections import namedtuple from typing import Iterable, Tuple -from six import iteritems, itervalues -from six.moves import range +from six import iteritems from twisted.internet import defer @@ -29,11 +28,9 @@ from synapse.events.snapshot import EventContext from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.database import Database -from synapse.storage.engines import PostgresEngine from synapse.storage.state import StateFilter -from synapse.util.caches import get_cache_factor_for, intern_string +from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.stringutils import to_ascii logger = logging.getLogger(__name__) @@ -55,207 +52,14 @@ class _GetStateGroupDelta( return len(self.delta_ids) if self.delta_ids else 0 -class StateGroupBackgroundUpdateStore(SQLBaseStore): - """Defines functions related to state groups needed to run the state backgroud - updates. - """ - - def _count_state_group_hops_txn(self, txn, state_group): - """Given a state group, count how many hops there are in the tree. - - This is used to ensure the delta chains don't get too long. - """ - if isinstance(self.database_engine, PostgresEngine): - sql = """ - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT count(*) FROM state; - """ - - txn.execute(sql, (state_group,)) - row = txn.fetchone() - if row and row[0]: - return row[0] - else: - return 0 - else: - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - next_group = state_group - count = 0 - - while next_group: - next_group = self.db.simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - if next_group: - count += 1 - - return count - - def _get_state_groups_from_groups_txn( - self, txn, groups, state_filter=StateFilter.all() - ): - results = {group: {} for group in groups} - - where_clause, where_args = state_filter.make_sql_filter_clause() - - # Unless the filter clause is empty, we're going to append it after an - # existing where clause - if where_clause: - where_clause = " AND (%s)" % (where_clause,) - - if isinstance(self.database_engine, PostgresEngine): - # Temporarily disable sequential scans in this transaction. This is - # a temporary hack until we can add the right indices in - txn.execute("SET LOCAL enable_seqscan=off") - - # The below query walks the state_group tree so that the "state" - # table includes all state_groups in the tree. It then joins - # against `state_groups_state` to fetch the latest state. - # It assumes that previous state groups are always numerically - # lesser. - # The PARTITION is used to get the event_id in the greatest state - # group for the given type, state_key. - # This may return multiple rows per (type, state_key), but last_value - # should be the same. - sql = """ - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT DISTINCT type, state_key, last_value(event_id) OVER ( - PARTITION BY type, state_key ORDER BY state_group ASC - ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING - ) AS event_id FROM state_groups_state - WHERE state_group IN ( - SELECT state_group FROM state - ) - """ - - for group in groups: - args = [group] - args.extend(where_args) - - txn.execute(sql + where_clause, args) - for row in txn: - typ, state_key, event_id = row - key = (typ, state_key) - results[group][key] = event_id - else: - max_entries_returned = state_filter.max_entries_returned() - - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - for group in groups: - next_group = group - - while next_group: - # We did this before by getting the list of group ids, and - # then passing that list to sqlite to get latest event for - # each (type, state_key). However, that was terribly slow - # without the right indices (which we can't add until - # after we finish deduping state, which requires this func) - args = [next_group] - args.extend(where_args) - - txn.execute( - "SELECT type, state_key, event_id FROM state_groups_state" - " WHERE state_group = ? " + where_clause, - args, - ) - results[group].update( - ((typ, state_key), event_id) - for typ, state_key, event_id in txn - if (typ, state_key) not in results[group] - ) - - # If the number of entries in the (type,state_key)->event_id dict - # matches the number of (type,state_keys) types we were searching - # for, then we must have found them all, so no need to go walk - # further down the tree... UNLESS our types filter contained - # wildcards (i.e. Nones) in which case we have to do an exhaustive - # search - if ( - max_entries_returned is not None - and len(results[group]) == max_entries_returned - ): - break - - next_group = self.db.simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - - return results - - # this inherits from EventsWorkerStore because it calls self.get_events -class StateGroupWorkerStore( - EventsWorkerStore, StateGroupBackgroundUpdateStore, SQLBaseStore -): +class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers. """ - STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" - STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" - CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - def __init__(self, database: Database, db_conn, hs): super(StateGroupWorkerStore, self).__init__(database, db_conn, hs) - # Originally the state store used a single DictionaryCache to cache the - # event IDs for the state types in a given state group to avoid hammering - # on the state_group* tables. - # - # The point of using a DictionaryCache is that it can cache a subset - # of the state events for a given state group (i.e. a subset of the keys for a - # given dict which is an entry in the cache for a given state group ID). - # - # However, this poses problems when performing complicated queries - # on the store - for instance: "give me all the state for this group, but - # limit members to this subset of users", as DictionaryCache's API isn't - # rich enough to say "please cache any of these fields, apart from this subset". - # This is problematic when lazy loading members, which requires this behaviour, - # as without it the cache has no choice but to speculatively load all - # state events for the group, which negates the efficiency being sought. - # - # Rather than overcomplicating DictionaryCache's API, we instead split the - # state_group_cache into two halves - one for tracking non-member events, - # and the other for tracking member_events. This means that lazy loading - # queries can be made in a cache-friendly manner by querying both caches - # separately and then merging the result. So for the example above, you - # would query the members cache for a specific subset of state keys - # (which DictionaryCache will handle efficiently and fine) and the non-members - # cache for all state (which DictionaryCache will similarly handle fine) - # and then just merge the results together. - # - # We size the non-members cache to be smaller than the members cache as the - # vast majority of state in Matrix (today) is member events. - - self._state_group_cache = DictionaryCache( - "*stateGroupCache*", - # TODO: this hasn't been tuned yet - 50000 * get_cache_factor_for("stateGroupCache"), - ) - self._state_group_members_cache = DictionaryCache( - "*stateGroupMembersCache*", - 500000 * get_cache_factor_for("stateGroupMembersCache"), - ) - @defer.inlineCallbacks def get_room_version(self, room_id): """Get the room_version of a given room @@ -278,7 +82,7 @@ class StateGroupWorkerStore( @defer.inlineCallbacks def get_room_predecessor(self, room_id): - """Get the predecessor room of an upgraded room if one exists. + """Get the predecessor of an upgraded room if it exists. Otherwise return None. Args: @@ -291,14 +95,22 @@ class StateGroupWorkerStore( * room_id (str): The room ID of the predecessor room * event_id (str): The ID of the tombstone event in the predecessor room + None if a predecessor key is not found, or is not a dictionary. + Raises: - NotFoundError if the room is unknown + NotFoundError if the given room is unknown """ # Retrieve the room's create event create_event = yield self.get_create_event_for_room(room_id) - # Return predecessor if present - return create_event.content.get("predecessor", None) + # Retrieve the predecessor key of the create event + predecessor = create_event.content.get("predecessor", None) + + # Ensure the key is a dictionary + if not isinstance(predecessor, dict): + return None + + return predecessor @defer.inlineCallbacks def get_create_event_for_room(self, room_id): @@ -318,7 +130,7 @@ class StateGroupWorkerStore( # If we can't find the create event, assume we've hit a dead end if not create_id: - raise NotFoundError("Unknown room %s" % (room_id)) + raise NotFoundError("Unknown room %s" % (room_id,)) # Retrieve the room's create event and return create_event = yield self.get_event(create_id) @@ -423,229 +235,6 @@ class StateGroupWorkerStore( return event.content.get("canonical_alias") - @cached(max_entries=10000, iterable=True) - def get_state_group_delta(self, state_group): - """Given a state group try to return a previous group and a delta between - the old and the new. - - Returns: - (prev_group, delta_ids), where both may be None. - """ - - def _get_state_group_delta_txn(txn): - prev_group = self.db.simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": state_group}, - retcol="prev_state_group", - allow_none=True, - ) - - if not prev_group: - return _GetStateGroupDelta(None, None) - - delta_ids = self.db.simple_select_list_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": state_group}, - retcols=("type", "state_key", "event_id"), - ) - - return _GetStateGroupDelta( - prev_group, - {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, - ) - - return self.db.runInteraction( - "get_state_group_delta", _get_state_group_delta_txn - ) - - @defer.inlineCallbacks - def get_state_groups_ids(self, _room_id, event_ids): - """Get the event IDs of all the state for the state groups for the given events - - Args: - _room_id (str): id of the room for these events - event_ids (iterable[str]): ids of the events - - Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) - """ - if not event_ids: - return {} - - event_to_groups = yield self._get_state_group_for_events(event_ids) - - groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups) - - return group_to_state - - @defer.inlineCallbacks - def get_state_ids_for_group(self, state_group): - """Get the event IDs of all the state in the given state group - - Args: - state_group (int) - - Returns: - Deferred[dict]: Resolves to a map of (type, state_key) -> event_id - """ - group_to_state = yield self._get_state_for_groups((state_group,)) - - return group_to_state[state_group] - - @defer.inlineCallbacks - def get_state_groups(self, room_id, event_ids): - """ Get the state groups for the given list of event_ids - - Returns: - Deferred[dict[int, list[EventBase]]]: - dict of state_group_id -> list of state events. - """ - if not event_ids: - return {} - - group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) - - state_event_map = yield self.get_events( - [ - ev_id - for group_ids in itervalues(group_to_ids) - for ev_id in itervalues(group_ids) - ], - get_prev_content=False, - ) - - return { - group: [ - state_event_map[v] - for v in itervalues(event_id_map) - if v in state_event_map - ] - for group, event_id_map in iteritems(group_to_ids) - } - - @defer.inlineCallbacks - def _get_state_groups_from_groups(self, groups, state_filter): - """Returns the state groups for a given set of groups, filtering on - types of state events. - - Args: - groups(list[int]): list of state group IDs to query - state_filter (StateFilter): The state filter used to fetch state - from the database. - Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) - """ - results = {} - - chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] - for chunk in chunks: - res = yield self.db.runInteraction( - "_get_state_groups_from_groups", - self._get_state_groups_from_groups_txn, - chunk, - state_filter, - ) - results.update(res) - - return results - - @defer.inlineCallbacks - def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): - """Given a list of event_ids and type tuples, return a list of state - dicts for each event. - - Args: - event_ids (list[string]) - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - deferred: A dict of (event_id) -> (type, state_key) -> [state_events] - """ - event_to_groups = yield self._get_state_group_for_events(event_ids) - - groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, state_filter) - - state_event_map = yield self.get_events( - [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], - get_prev_content=False, - ) - - event_to_state = { - event_id: { - k: state_event_map[v] - for k, v in iteritems(group_to_state[group]) - if v in state_event_map - } - for event_id, group in iteritems(event_to_groups) - } - - return {event: event_to_state[event] for event in event_ids} - - @defer.inlineCallbacks - def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): - """ - Get the state dicts corresponding to a list of events, containing the event_ids - of the state events (as opposed to the events themselves) - - Args: - event_ids(list(str)): events whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - A deferred dict from event_id -> (type, state_key) -> event_id - """ - event_to_groups = yield self._get_state_group_for_events(event_ids) - - groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, state_filter) - - event_to_state = { - event_id: group_to_state[group] - for event_id, group in iteritems(event_to_groups) - } - - return {event: event_to_state[event] for event in event_ids} - - @defer.inlineCallbacks - def get_state_for_event(self, event_id, state_filter=StateFilter.all()): - """ - Get the state dict corresponding to a particular event - - Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - A deferred dict from (type, state_key) -> state_event - """ - state_map = yield self.get_state_for_events([event_id], state_filter) - return state_map[event_id] - - @defer.inlineCallbacks - def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): - """ - Get the state dict corresponding to a particular event - - Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - A deferred dict from (type, state_key) -> state_event - """ - state_map = yield self.get_state_ids_for_events([event_id], state_filter) - return state_map[event_id] - @cached(max_entries=50000) def _get_state_group_for_event(self, event_id): return self.db.simple_select_one_onecol( @@ -676,329 +265,6 @@ class StateGroupWorkerStore( return {row["event_id"]: row["state_group"] for row in rows} - def _get_state_for_group_using_cache(self, cache, group, state_filter): - """Checks if group is in cache. See `_get_state_for_groups` - - Args: - cache(DictionaryCache): the state group cache to use - group(int): The state group to lookup - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns 2-tuple (`state_dict`, `got_all`). - `got_all` is a bool indicating if we successfully retrieved all - requests state from the cache, if False we need to query the DB for the - missing state. - """ - is_all, known_absent, state_dict_ids = cache.get(group) - - if is_all or state_filter.is_full(): - # Either we have everything or want everything, either way - # `is_all` tells us whether we've gotten everything. - return state_filter.filter_state(state_dict_ids), is_all - - # tracks whether any of our requested types are missing from the cache - missing_types = False - - if state_filter.has_wildcards(): - # We don't know if we fetched all the state keys for the types in - # the filter that are wildcards, so we have to assume that we may - # have missed some. - missing_types = True - else: - # There aren't any wild cards, so `concrete_types()` returns the - # complete list of event types we're wanting. - for key in state_filter.concrete_types(): - if key not in state_dict_ids and key not in known_absent: - missing_types = True - break - - return state_filter.filter_state(state_dict_ids), not missing_types - - @defer.inlineCallbacks - def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): - """Gets the state at each of a list of state groups, optionally - filtering by type/state_key - - Args: - groups (iterable[int]): list of state groups for which we want - to get the state. - state_filter (StateFilter): The state filter used to fetch state - from the database. - Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) - """ - - member_filter, non_member_filter = state_filter.get_member_split() - - # Now we look them up in the member and non-member caches - ( - non_member_state, - incomplete_groups_nm, - ) = yield self._get_state_for_groups_using_cache( - groups, self._state_group_cache, state_filter=non_member_filter - ) - - ( - member_state, - incomplete_groups_m, - ) = yield self._get_state_for_groups_using_cache( - groups, self._state_group_members_cache, state_filter=member_filter - ) - - state = dict(non_member_state) - for group in groups: - state[group].update(member_state[group]) - - # Now fetch any missing groups from the database - - incomplete_groups = incomplete_groups_m | incomplete_groups_nm - - if not incomplete_groups: - return state - - cache_sequence_nm = self._state_group_cache.sequence - cache_sequence_m = self._state_group_members_cache.sequence - - # Help the cache hit ratio by expanding the filter a bit - db_state_filter = state_filter.return_expanded() - - group_to_state_dict = yield self._get_state_groups_from_groups( - list(incomplete_groups), state_filter=db_state_filter - ) - - # Now lets update the caches - self._insert_into_cache( - group_to_state_dict, - db_state_filter, - cache_seq_num_members=cache_sequence_m, - cache_seq_num_non_members=cache_sequence_nm, - ) - - # And finally update the result dict, by filtering out any extra - # stuff we pulled out of the database. - for group, group_state_dict in iteritems(group_to_state_dict): - # We just replace any existing entries, as we will have loaded - # everything we need from the database anyway. - state[group] = state_filter.filter_state(group_state_dict) - - return state - - def _get_state_for_groups_using_cache(self, groups, cache, state_filter): - """Gets the state at each of a list of state groups, optionally - filtering by type/state_key, querying from a specific cache. - - Args: - groups (iterable[int]): list of state groups for which we want - to get the state. - cache (DictionaryCache): the cache of group ids to state dicts which - we will pass through - either the normal state cache or the specific - members state cache. - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of - dict of state_group_id -> (dict of (type, state_key) -> event id) - of entries in the cache, and the state group ids either missing - from the cache or incomplete. - """ - results = {} - incomplete_groups = set() - for group in set(groups): - state_dict_ids, got_all = self._get_state_for_group_using_cache( - cache, group, state_filter - ) - results[group] = state_dict_ids - - if not got_all: - incomplete_groups.add(group) - - return results, incomplete_groups - - def _insert_into_cache( - self, - group_to_state_dict, - state_filter, - cache_seq_num_members, - cache_seq_num_non_members, - ): - """Inserts results from querying the database into the relevant cache. - - Args: - group_to_state_dict (dict): The new entries pulled from database. - Map from state group to state dict - state_filter (StateFilter): The state filter used to fetch state - from the database. - cache_seq_num_members (int): Sequence number of member cache since - last lookup in cache - cache_seq_num_non_members (int): Sequence number of member cache since - last lookup in cache - """ - - # We need to work out which types we've fetched from the DB for the - # member vs non-member caches. This should be as accurate as possible, - # but can be an underestimate (e.g. when we have wild cards) - - member_filter, non_member_filter = state_filter.get_member_split() - if member_filter.is_full(): - # We fetched all member events - member_types = None - else: - # `concrete_types()` will only return a subset when there are wild - # cards in the filter, but that's fine. - member_types = member_filter.concrete_types() - - if non_member_filter.is_full(): - # We fetched all non member events - non_member_types = None - else: - non_member_types = non_member_filter.concrete_types() - - for group, group_state_dict in iteritems(group_to_state_dict): - state_dict_members = {} - state_dict_non_members = {} - - for k, v in iteritems(group_state_dict): - if k[0] == EventTypes.Member: - state_dict_members[k] = v - else: - state_dict_non_members[k] = v - - self._state_group_members_cache.update( - cache_seq_num_members, - key=group, - value=state_dict_members, - fetched_keys=member_types, - ) - - self._state_group_cache.update( - cache_seq_num_non_members, - key=group, - value=state_dict_non_members, - fetched_keys=non_member_types, - ) - - def store_state_group( - self, event_id, room_id, prev_group, delta_ids, current_state_ids - ): - """Store a new set of state, returning a newly assigned state group. - - Args: - event_id (str): The event ID for which the state was calculated - room_id (str) - prev_group (int|None): A previous state group for the room, optional. - delta_ids (dict|None): The delta between state at `prev_group` and - `current_state_ids`, if `prev_group` was given. Same format as - `current_state_ids`. - current_state_ids (dict): The state to store. Map of (type, state_key) - to event_id. - - Returns: - Deferred[int]: The state group ID - """ - - def _store_state_group_txn(txn): - if current_state_ids is None: - # AFAIK, this can never happen - raise Exception("current_state_ids cannot be None") - - state_group = self.database_engine.get_next_state_group_id(txn) - - self.db.simple_insert_txn( - txn, - table="state_groups", - values={"id": state_group, "room_id": room_id, "event_id": event_id}, - ) - - # We persist as a delta if we can, while also ensuring the chain - # of deltas isn't tooo long, as otherwise read performance degrades. - if prev_group: - is_in_db = self.db.simple_select_one_onecol_txn( - txn, - table="state_groups", - keyvalues={"id": prev_group}, - retcol="id", - allow_none=True, - ) - if not is_in_db: - raise Exception( - "Trying to persist state with unpersisted prev_group: %r" - % (prev_group,) - ) - - potential_hops = self._count_state_group_hops_txn(txn, prev_group) - if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - self.db.simple_insert_txn( - txn, - table="state_group_edges", - values={"state_group": state_group, "prev_state_group": prev_group}, - ) - - self.db.simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(delta_ids) - ], - ) - else: - self.db.simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(current_state_ids) - ], - ) - - # Prefill the state group caches with this group. - # It's fine to use the sequence like this as the state group map - # is immutable. (If the map wasn't immutable then this prefill could - # race with another update) - - current_member_state_ids = { - s: ev - for (s, ev) in iteritems(current_state_ids) - if s[0] == EventTypes.Member - } - txn.call_after( - self._state_group_members_cache.update, - self._state_group_members_cache.sequence, - key=state_group, - value=dict(current_member_state_ids), - ) - - current_non_member_state_ids = { - s: ev - for (s, ev) in iteritems(current_state_ids) - if s[0] != EventTypes.Member - } - txn.call_after( - self._state_group_cache.update, - self._state_group_cache.sequence, - key=state_group, - value=dict(current_non_member_state_ids), - ) - - return state_group - - return self.db.runInteraction("store_state_group", _store_state_group_txn) - @defer.inlineCallbacks def get_referenced_state_groups(self, state_groups): """Check if the state groups are referenced by events. @@ -1023,22 +289,14 @@ class StateGroupWorkerStore( return set(row["state_group"] for row in rows) -class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): +class MainStateBackgroundUpdateStore(SQLBaseStore): - STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" - STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" def __init__(self, database: Database, db_conn, hs): - super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.db.updates.register_background_update_handler( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, - self._background_deduplicate_state, - ) - self.db.updates.register_background_update_handler( - self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state - ) + super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + self.db.updates.register_background_index_update( self.CURRENT_STATE_INDEX_UPDATE_NAME, index_name="current_state_events_member_index", @@ -1053,181 +311,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): columns=["state_group"], ) - @defer.inlineCallbacks - def _background_deduplicate_state(self, progress, batch_size): - """This background update will slowly deduplicate state by reencoding - them as deltas. - """ - last_state_group = progress.get("last_state_group", 0) - rows_inserted = progress.get("rows_inserted", 0) - max_group = progress.get("max_group", None) - - BATCH_SIZE_SCALE_FACTOR = 100 - batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) - - if max_group is None: - rows = yield self.db.execute( - "_background_deduplicate_state", - None, - "SELECT coalesce(max(id), 0) FROM state_groups", - ) - max_group = rows[0][0] - - def reindex_txn(txn): - new_last_state_group = last_state_group - for count in range(batch_size): - txn.execute( - "SELECT id, room_id FROM state_groups" - " WHERE ? < id AND id <= ?" - " ORDER BY id ASC" - " LIMIT 1", - (new_last_state_group, max_group), - ) - row = txn.fetchone() - if row: - state_group, room_id = row - - if not row or not state_group: - return True, count - - txn.execute( - "SELECT state_group FROM state_group_edges" - " WHERE state_group = ?", - (state_group,), - ) - - # If we reach a point where we've already started inserting - # edges we should stop. - if txn.fetchall(): - return True, count - - txn.execute( - "SELECT coalesce(max(id), 0) FROM state_groups" - " WHERE id < ? AND room_id = ?", - (state_group, room_id), - ) - (prev_group,) = txn.fetchone() - new_last_state_group = state_group - - if prev_group: - potential_hops = self._count_state_group_hops_txn(txn, prev_group) - if potential_hops >= MAX_STATE_DELTA_HOPS: - # We want to ensure chains are at most this long,# - # otherwise read performance degrades. - continue - - prev_state = self._get_state_groups_from_groups_txn( - txn, [prev_group] - ) - prev_state = prev_state[prev_group] - - curr_state = self._get_state_groups_from_groups_txn( - txn, [state_group] - ) - curr_state = curr_state[state_group] - - if not set(prev_state.keys()) - set(curr_state.keys()): - # We can only do a delta if the current has a strict super set - # of keys - - delta_state = { - key: value - for key, value in iteritems(curr_state) - if prev_state.get(key, None) != value - } - - self.db.simple_delete_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": state_group}, - ) - - self.db.simple_insert_txn( - txn, - table="state_group_edges", - values={ - "state_group": state_group, - "prev_state_group": prev_group, - }, - ) - - self.db.simple_delete_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": state_group}, - ) - - self.db.simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(delta_state) - ], - ) - - progress = { - "last_state_group": state_group, - "rows_inserted": rows_inserted + batch_size, - "max_group": max_group, - } - - self.db.updates._background_update_progress_txn( - txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress - ) - - return False, batch_size - - finished, result = yield self.db.runInteraction( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn - ) - - if finished: - yield self.db.updates._end_background_update( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME - ) - - return result * BATCH_SIZE_SCALE_FACTOR - - @defer.inlineCallbacks - def _background_index_state(self, progress, batch_size): - def reindex_txn(conn): - conn.rollback() - if isinstance(self.database_engine, PostgresEngine): - # postgres insists on autocommit for the index - conn.set_session(autocommit=True) - try: - txn = conn.cursor() - txn.execute( - "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" - " ON state_groups_state(state_group, type, state_key)" - ) - txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - finally: - conn.set_session(autocommit=False) - else: - txn = conn.cursor() - txn.execute( - "CREATE INDEX state_groups_state_type_idx" - " ON state_groups_state(state_group, type, state_key)" - ) - txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - - yield self.db.runWithConnection(reindex_txn) - - yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) - - return 1 - - -class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): +class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): """ Keeps track of the state at a given event. This is done by the concept of `state groups`. Every event is a assigned diff --git a/synapse/storage/data_stores/state/__init__.py b/synapse/storage/data_stores/state/__init__.py new file mode 100644 index 0000000000..86e09f6229 --- /dev/null +++ b/synapse/storage/data_stores/state/__init__.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.data_stores.state.store import StateGroupDataStore # noqa: F401 diff --git a/synapse/storage/data_stores/state/bg_updates.py b/synapse/storage/data_stores/state/bg_updates.py new file mode 100644 index 0000000000..e8edaf9f7b --- /dev/null +++ b/synapse/storage/data_stores/state/bg_updates.py @@ -0,0 +1,374 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from six import iteritems + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database +from synapse.storage.engines import PostgresEngine +from synapse.storage.state import StateFilter + +logger = logging.getLogger(__name__) + + +MAX_STATE_DELTA_HOPS = 100 + + +class StateGroupBackgroundUpdateStore(SQLBaseStore): + """Defines functions related to state groups needed to run the state backgroud + updates. + """ + + def _count_state_group_hops_txn(self, txn, state_group): + """Given a state group, count how many hops there are in the tree. + + This is used to ensure the delta chains don't get too long. + """ + if isinstance(self.database_engine, PostgresEngine): + sql = """ + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT count(*) FROM state; + """ + + txn.execute(sql, (state_group,)) + row = txn.fetchone() + if row and row[0]: + return row[0] + else: + return 0 + else: + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + next_group = state_group + count = 0 + + while next_group: + next_group = self.db.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + if next_group: + count += 1 + + return count + + def _get_state_groups_from_groups_txn( + self, txn, groups, state_filter=StateFilter.all() + ): + results = {group: {} for group in groups} + + where_clause, where_args = state_filter.make_sql_filter_clause() + + # Unless the filter clause is empty, we're going to append it after an + # existing where clause + if where_clause: + where_clause = " AND (%s)" % (where_clause,) + + if isinstance(self.database_engine, PostgresEngine): + # Temporarily disable sequential scans in this transaction. This is + # a temporary hack until we can add the right indices in + txn.execute("SET LOCAL enable_seqscan=off") + + # The below query walks the state_group tree so that the "state" + # table includes all state_groups in the tree. It then joins + # against `state_groups_state` to fetch the latest state. + # It assumes that previous state groups are always numerically + # lesser. + # The PARTITION is used to get the event_id in the greatest state + # group for the given type, state_key. + # This may return multiple rows per (type, state_key), but last_value + # should be the same. + sql = """ + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT DISTINCT type, state_key, last_value(event_id) OVER ( + PARTITION BY type, state_key ORDER BY state_group ASC + ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + ) AS event_id FROM state_groups_state + WHERE state_group IN ( + SELECT state_group FROM state + ) + """ + + for group in groups: + args = [group] + args.extend(where_args) + + txn.execute(sql + where_clause, args) + for row in txn: + typ, state_key, event_id = row + key = (typ, state_key) + results[group][key] = event_id + else: + max_entries_returned = state_filter.max_entries_returned() + + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + for group in groups: + next_group = group + + while next_group: + # We did this before by getting the list of group ids, and + # then passing that list to sqlite to get latest event for + # each (type, state_key). However, that was terribly slow + # without the right indices (which we can't add until + # after we finish deduping state, which requires this func) + args = [next_group] + args.extend(where_args) + + txn.execute( + "SELECT type, state_key, event_id FROM state_groups_state" + " WHERE state_group = ? " + where_clause, + args, + ) + results[group].update( + ((typ, state_key), event_id) + for typ, state_key, event_id in txn + if (typ, state_key) not in results[group] + ) + + # If the number of entries in the (type,state_key)->event_id dict + # matches the number of (type,state_keys) types we were searching + # for, then we must have found them all, so no need to go walk + # further down the tree... UNLESS our types filter contained + # wildcards (i.e. Nones) in which case we have to do an exhaustive + # search + if ( + max_entries_returned is not None + and len(results[group]) == max_entries_returned + ): + break + + next_group = self.db.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + + return results + + +class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): + + STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" + STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" + STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx" + + def __init__(self, database: Database, db_conn, hs): + super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + self.db.updates.register_background_update_handler( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, + self._background_deduplicate_state, + ) + self.db.updates.register_background_update_handler( + self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state + ) + self.db.updates.register_background_index_update( + self.STATE_GROUPS_ROOM_INDEX_UPDATE_NAME, + index_name="state_groups_room_id_idx", + table="state_groups", + columns=["room_id"], + ) + + @defer.inlineCallbacks + def _background_deduplicate_state(self, progress, batch_size): + """This background update will slowly deduplicate state by reencoding + them as deltas. + """ + last_state_group = progress.get("last_state_group", 0) + rows_inserted = progress.get("rows_inserted", 0) + max_group = progress.get("max_group", None) + + BATCH_SIZE_SCALE_FACTOR = 100 + + batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) + + if max_group is None: + rows = yield self.db.execute( + "_background_deduplicate_state", + None, + "SELECT coalesce(max(id), 0) FROM state_groups", + ) + max_group = rows[0][0] + + def reindex_txn(txn): + new_last_state_group = last_state_group + for count in range(batch_size): + txn.execute( + "SELECT id, room_id FROM state_groups" + " WHERE ? < id AND id <= ?" + " ORDER BY id ASC" + " LIMIT 1", + (new_last_state_group, max_group), + ) + row = txn.fetchone() + if row: + state_group, room_id = row + + if not row or not state_group: + return True, count + + txn.execute( + "SELECT state_group FROM state_group_edges" + " WHERE state_group = ?", + (state_group,), + ) + + # If we reach a point where we've already started inserting + # edges we should stop. + if txn.fetchall(): + return True, count + + txn.execute( + "SELECT coalesce(max(id), 0) FROM state_groups" + " WHERE id < ? AND room_id = ?", + (state_group, room_id), + ) + (prev_group,) = txn.fetchone() + new_last_state_group = state_group + + if prev_group: + potential_hops = self._count_state_group_hops_txn(txn, prev_group) + if potential_hops >= MAX_STATE_DELTA_HOPS: + # We want to ensure chains are at most this long,# + # otherwise read performance degrades. + continue + + prev_state = self._get_state_groups_from_groups_txn( + txn, [prev_group] + ) + prev_state = prev_state[prev_group] + + curr_state = self._get_state_groups_from_groups_txn( + txn, [state_group] + ) + curr_state = curr_state[state_group] + + if not set(prev_state.keys()) - set(curr_state.keys()): + # We can only do a delta if the current has a strict super set + # of keys + + delta_state = { + key: value + for key, value in iteritems(curr_state) + if prev_state.get(key, None) != value + } + + self.db.simple_delete_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": state_group}, + ) + + self.db.simple_insert_txn( + txn, + table="state_group_edges", + values={ + "state_group": state_group, + "prev_state_group": prev_group, + }, + ) + + self.db.simple_delete_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": state_group}, + ) + + self.db.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(delta_state) + ], + ) + + progress = { + "last_state_group": state_group, + "rows_inserted": rows_inserted + batch_size, + "max_group": max_group, + } + + self.db.updates._background_update_progress_txn( + txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress + ) + + return False, batch_size + + finished, result = yield self.db.runInteraction( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn + ) + + if finished: + yield self.db.updates._end_background_update( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME + ) + + return result * BATCH_SIZE_SCALE_FACTOR + + @defer.inlineCallbacks + def _background_index_state(self, progress, batch_size): + def reindex_txn(conn): + conn.rollback() + if isinstance(self.database_engine, PostgresEngine): + # postgres insists on autocommit for the index + conn.set_session(autocommit=True) + try: + txn = conn.cursor() + txn.execute( + "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" + " ON state_groups_state(state_group, type, state_key)" + ) + txn.execute("DROP INDEX IF EXISTS state_groups_state_id") + finally: + conn.set_session(autocommit=False) + else: + txn = conn.cursor() + txn.execute( + "CREATE INDEX state_groups_state_type_idx" + " ON state_groups_state(state_group, type, state_key)" + ) + txn.execute("DROP INDEX IF EXISTS state_groups_state_id") + + yield self.db.runWithConnection(reindex_txn) + + yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) + + return 1 diff --git a/synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql b/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql index ae09fa0065..ae09fa0065 100644 --- a/synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql +++ b/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/30/state_stream.sql b/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql index e85699e82e..e85699e82e 100644 --- a/synapse/storage/data_stores/main/schema/delta/30/state_stream.sql +++ b/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql diff --git a/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql new file mode 100644 index 0000000000..1450313bfa --- /dev/null +++ b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql @@ -0,0 +1,19 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- The following indices are redundant, other indices are equivalent or +-- supersets +DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY diff --git a/synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql b/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql index 33980d02f0..33980d02f0 100644 --- a/synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql +++ b/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/35/state.sql b/synapse/storage/data_stores/state/schema/delta/35/state.sql index 0f1fa68a89..0f1fa68a89 100644 --- a/synapse/storage/data_stores/main/schema/delta/35/state.sql +++ b/synapse/storage/data_stores/state/schema/delta/35/state.sql diff --git a/synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql b/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql index 97e5067ef4..97e5067ef4 100644 --- a/synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql +++ b/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql diff --git a/synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py b/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py index 9fd1ccf6f7..9fd1ccf6f7 100644 --- a/synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py +++ b/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py diff --git a/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql new file mode 100644 index 0000000000..7916ef18b2 --- /dev/null +++ b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('state_groups_room_id_idx', '{}'); diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql new file mode 100644 index 0000000000..35f97d6b3d --- /dev/null +++ b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql @@ -0,0 +1,37 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE state_groups ( + id BIGINT PRIMARY KEY, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE state_groups_state ( + state_group BIGINT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE state_group_edges ( + state_group BIGINT NOT NULL, + prev_state_group BIGINT NOT NULL +); + +CREATE INDEX state_group_edges_idx ON state_group_edges (state_group); +CREATE INDEX state_group_edges_prev_idx ON state_group_edges (prev_state_group); +CREATE INDEX state_groups_state_type_idx ON state_groups_state (state_group, type, state_key); diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres new file mode 100644 index 0000000000..fcd926c9fb --- /dev/null +++ b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres @@ -0,0 +1,21 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE state_group_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1; diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py new file mode 100644 index 0000000000..d53695f238 --- /dev/null +++ b/synapse/storage/data_stores/state/store.py @@ -0,0 +1,640 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections import namedtuple + +from six import iteritems +from six.moves import range + +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore +from synapse.storage.database import Database +from synapse.storage.state import StateFilter +from synapse.util.caches import get_cache_factor_for +from synapse.util.caches.descriptors import cached +from synapse.util.caches.dictionary_cache import DictionaryCache + +logger = logging.getLogger(__name__) + + +MAX_STATE_DELTA_HOPS = 100 + + +class _GetStateGroupDelta( + namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) +): + """Return type of get_state_group_delta that implements __len__, which lets + us use the itrable flag when caching + """ + + __slots__ = [] + + def __len__(self): + return len(self.delta_ids) if self.delta_ids else 0 + + +class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): + """A data store for fetching/storing state groups. + """ + + def __init__(self, database: Database, db_conn, hs): + super(StateGroupDataStore, self).__init__(database, db_conn, hs) + + # Originally the state store used a single DictionaryCache to cache the + # event IDs for the state types in a given state group to avoid hammering + # on the state_group* tables. + # + # The point of using a DictionaryCache is that it can cache a subset + # of the state events for a given state group (i.e. a subset of the keys for a + # given dict which is an entry in the cache for a given state group ID). + # + # However, this poses problems when performing complicated queries + # on the store - for instance: "give me all the state for this group, but + # limit members to this subset of users", as DictionaryCache's API isn't + # rich enough to say "please cache any of these fields, apart from this subset". + # This is problematic when lazy loading members, which requires this behaviour, + # as without it the cache has no choice but to speculatively load all + # state events for the group, which negates the efficiency being sought. + # + # Rather than overcomplicating DictionaryCache's API, we instead split the + # state_group_cache into two halves - one for tracking non-member events, + # and the other for tracking member_events. This means that lazy loading + # queries can be made in a cache-friendly manner by querying both caches + # separately and then merging the result. So for the example above, you + # would query the members cache for a specific subset of state keys + # (which DictionaryCache will handle efficiently and fine) and the non-members + # cache for all state (which DictionaryCache will similarly handle fine) + # and then just merge the results together. + # + # We size the non-members cache to be smaller than the members cache as the + # vast majority of state in Matrix (today) is member events. + + self._state_group_cache = DictionaryCache( + "*stateGroupCache*", + # TODO: this hasn't been tuned yet + 50000 * get_cache_factor_for("stateGroupCache"), + ) + self._state_group_members_cache = DictionaryCache( + "*stateGroupMembersCache*", + 500000 * get_cache_factor_for("stateGroupMembersCache"), + ) + + @cached(max_entries=10000, iterable=True) + def get_state_group_delta(self, state_group): + """Given a state group try to return a previous group and a delta between + the old and the new. + + Returns: + (prev_group, delta_ids), where both may be None. + """ + + def _get_state_group_delta_txn(txn): + prev_group = self.db.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": state_group}, + retcol="prev_state_group", + allow_none=True, + ) + + if not prev_group: + return _GetStateGroupDelta(None, None) + + delta_ids = self.db.simple_select_list_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": state_group}, + retcols=("type", "state_key", "event_id"), + ) + + return _GetStateGroupDelta( + prev_group, + {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, + ) + + return self.db.runInteraction( + "get_state_group_delta", _get_state_group_delta_txn + ) + + @defer.inlineCallbacks + def _get_state_groups_from_groups(self, groups, state_filter): + """Returns the state groups for a given set of groups, filtering on + types of state events. + + Args: + groups(list[int]): list of state group IDs to query + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + results = {} + + chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] + for chunk in chunks: + res = yield self.db.runInteraction( + "_get_state_groups_from_groups", + self._get_state_groups_from_groups_txn, + chunk, + state_filter, + ) + results.update(res) + + return results + + def _get_state_for_group_using_cache(self, cache, group, state_filter): + """Checks if group is in cache. See `_get_state_for_groups` + + Args: + cache(DictionaryCache): the state group cache to use + group(int): The state group to lookup + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns 2-tuple (`state_dict`, `got_all`). + `got_all` is a bool indicating if we successfully retrieved all + requests state from the cache, if False we need to query the DB for the + missing state. + """ + is_all, known_absent, state_dict_ids = cache.get(group) + + if is_all or state_filter.is_full(): + # Either we have everything or want everything, either way + # `is_all` tells us whether we've gotten everything. + return state_filter.filter_state(state_dict_ids), is_all + + # tracks whether any of our requested types are missing from the cache + missing_types = False + + if state_filter.has_wildcards(): + # We don't know if we fetched all the state keys for the types in + # the filter that are wildcards, so we have to assume that we may + # have missed some. + missing_types = True + else: + # There aren't any wild cards, so `concrete_types()` returns the + # complete list of event types we're wanting. + for key in state_filter.concrete_types(): + if key not in state_dict_ids and key not in known_absent: + missing_types = True + break + + return state_filter.filter_state(state_dict_ids), not missing_types + + @defer.inlineCallbacks + def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + + member_filter, non_member_filter = state_filter.get_member_split() + + # Now we look them up in the member and non-member caches + ( + non_member_state, + incomplete_groups_nm, + ) = yield self._get_state_for_groups_using_cache( + groups, self._state_group_cache, state_filter=non_member_filter + ) + + ( + member_state, + incomplete_groups_m, + ) = yield self._get_state_for_groups_using_cache( + groups, self._state_group_members_cache, state_filter=member_filter + ) + + state = dict(non_member_state) + for group in groups: + state[group].update(member_state[group]) + + # Now fetch any missing groups from the database + + incomplete_groups = incomplete_groups_m | incomplete_groups_nm + + if not incomplete_groups: + return state + + cache_sequence_nm = self._state_group_cache.sequence + cache_sequence_m = self._state_group_members_cache.sequence + + # Help the cache hit ratio by expanding the filter a bit + db_state_filter = state_filter.return_expanded() + + group_to_state_dict = yield self._get_state_groups_from_groups( + list(incomplete_groups), state_filter=db_state_filter + ) + + # Now lets update the caches + self._insert_into_cache( + group_to_state_dict, + db_state_filter, + cache_seq_num_members=cache_sequence_m, + cache_seq_num_non_members=cache_sequence_nm, + ) + + # And finally update the result dict, by filtering out any extra + # stuff we pulled out of the database. + for group, group_state_dict in iteritems(group_to_state_dict): + # We just replace any existing entries, as we will have loaded + # everything we need from the database anyway. + state[group] = state_filter.filter_state(group_state_dict) + + return state + + def _get_state_for_groups_using_cache(self, groups, cache, state_filter): + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key, querying from a specific cache. + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + cache (DictionaryCache): the cache of group ids to state dicts which + we will pass through - either the normal state cache or the specific + members state cache. + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of + dict of state_group_id -> (dict of (type, state_key) -> event id) + of entries in the cache, and the state group ids either missing + from the cache or incomplete. + """ + results = {} + incomplete_groups = set() + for group in set(groups): + state_dict_ids, got_all = self._get_state_for_group_using_cache( + cache, group, state_filter + ) + results[group] = state_dict_ids + + if not got_all: + incomplete_groups.add(group) + + return results, incomplete_groups + + def _insert_into_cache( + self, + group_to_state_dict, + state_filter, + cache_seq_num_members, + cache_seq_num_non_members, + ): + """Inserts results from querying the database into the relevant cache. + + Args: + group_to_state_dict (dict): The new entries pulled from database. + Map from state group to state dict + state_filter (StateFilter): The state filter used to fetch state + from the database. + cache_seq_num_members (int): Sequence number of member cache since + last lookup in cache + cache_seq_num_non_members (int): Sequence number of member cache since + last lookup in cache + """ + + # We need to work out which types we've fetched from the DB for the + # member vs non-member caches. This should be as accurate as possible, + # but can be an underestimate (e.g. when we have wild cards) + + member_filter, non_member_filter = state_filter.get_member_split() + if member_filter.is_full(): + # We fetched all member events + member_types = None + else: + # `concrete_types()` will only return a subset when there are wild + # cards in the filter, but that's fine. + member_types = member_filter.concrete_types() + + if non_member_filter.is_full(): + # We fetched all non member events + non_member_types = None + else: + non_member_types = non_member_filter.concrete_types() + + for group, group_state_dict in iteritems(group_to_state_dict): + state_dict_members = {} + state_dict_non_members = {} + + for k, v in iteritems(group_state_dict): + if k[0] == EventTypes.Member: + state_dict_members[k] = v + else: + state_dict_non_members[k] = v + + self._state_group_members_cache.update( + cache_seq_num_members, + key=group, + value=state_dict_members, + fetched_keys=member_types, + ) + + self._state_group_cache.update( + cache_seq_num_non_members, + key=group, + value=state_dict_non_members, + fetched_keys=non_member_types, + ) + + def store_state_group( + self, event_id, room_id, prev_group, delta_ids, current_state_ids + ): + """Store a new set of state, returning a newly assigned state group. + + Args: + event_id (str): The event ID for which the state was calculated + room_id (str) + prev_group (int|None): A previous state group for the room, optional. + delta_ids (dict|None): The delta between state at `prev_group` and + `current_state_ids`, if `prev_group` was given. Same format as + `current_state_ids`. + current_state_ids (dict): The state to store. Map of (type, state_key) + to event_id. + + Returns: + Deferred[int]: The state group ID + """ + + def _store_state_group_txn(txn): + if current_state_ids is None: + # AFAIK, this can never happen + raise Exception("current_state_ids cannot be None") + + state_group = self.database_engine.get_next_state_group_id(txn) + + self.db.simple_insert_txn( + txn, + table="state_groups", + values={"id": state_group, "room_id": room_id, "event_id": event_id}, + ) + + # We persist as a delta if we can, while also ensuring the chain + # of deltas isn't tooo long, as otherwise read performance degrades. + if prev_group: + is_in_db = self.db.simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (prev_group,) + ) + + potential_hops = self._count_state_group_hops_txn(txn, prev_group) + if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: + self.db.simple_insert_txn( + txn, + table="state_group_edges", + values={"state_group": state_group, "prev_state_group": prev_group}, + ) + + self.db.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(delta_ids) + ], + ) + else: + self.db.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(current_state_ids) + ], + ) + + # Prefill the state group caches with this group. + # It's fine to use the sequence like this as the state group map + # is immutable. (If the map wasn't immutable then this prefill could + # race with another update) + + current_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] == EventTypes.Member + } + txn.call_after( + self._state_group_members_cache.update, + self._state_group_members_cache.sequence, + key=state_group, + value=dict(current_member_state_ids), + ) + + current_non_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] != EventTypes.Member + } + txn.call_after( + self._state_group_cache.update, + self._state_group_cache.sequence, + key=state_group, + value=dict(current_non_member_state_ids), + ) + + return state_group + + return self.db.runInteraction("store_state_group", _store_state_group_txn) + + def purge_unreferenced_state_groups( + self, room_id: str, state_groups_to_delete + ) -> defer.Deferred: + """Deletes no longer referenced state groups and de-deltas any state + groups that reference them. + + Args: + room_id: The room the state groups belong to (must all be in the + same room). + state_groups_to_delete (Collection[int]): Set of all state groups + to delete. + """ + + return self.db.runInteraction( + "purge_unreferenced_state_groups", + self._purge_unreferenced_state_groups, + room_id, + state_groups_to_delete, + ) + + def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete): + logger.info( + "[purge] found %i state groups to delete", len(state_groups_to_delete) + ) + + rows = self.db.simple_select_many_txn( + txn, + table="state_group_edges", + column="prev_state_group", + iterable=state_groups_to_delete, + keyvalues={}, + retcols=("state_group",), + ) + + remaining_state_groups = set( + row["state_group"] + for row in rows + if row["state_group"] not in state_groups_to_delete + ) + + logger.info( + "[purge] de-delta-ing %i remaining state groups", + len(remaining_state_groups), + ) + + # Now we turn the state groups that reference to-be-deleted state + # groups to non delta versions. + for sg in remaining_state_groups: + logger.info("[purge] de-delta-ing remaining state group %s", sg) + curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) + curr_state = curr_state[sg] + + self.db.simple_delete_txn( + txn, table="state_groups_state", keyvalues={"state_group": sg} + ) + + self.db.simple_delete_txn( + txn, table="state_group_edges", keyvalues={"state_group": sg} + ) + + self.db.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": sg, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(curr_state) + ], + ) + + logger.info("[purge] removing redundant state groups") + txn.executemany( + "DELETE FROM state_groups_state WHERE state_group = ?", + ((sg,) for sg in state_groups_to_delete), + ) + txn.executemany( + "DELETE FROM state_groups WHERE id = ?", + ((sg,) for sg in state_groups_to_delete), + ) + + @defer.inlineCallbacks + def get_previous_state_groups(self, state_groups): + """Fetch the previous groups of the given state groups. + + Args: + state_groups (Iterable[int]) + + Returns: + Deferred[dict[int, int]]: mapping from state group to previous + state group. + """ + + rows = yield self.db.simple_select_many_batch( + table="state_group_edges", + column="prev_state_group", + iterable=state_groups, + keyvalues={}, + retcols=("prev_state_group", "state_group"), + desc="get_previous_state_groups", + ) + + return {row["state_group"]: row["prev_state_group"] for row in rows} + + def purge_room_state(self, room_id, state_groups_to_delete): + """Deletes all record of a room from state tables + + Args: + room_id (str): + state_groups_to_delete (list[int]): State groups to delete + """ + + return self.db.runInteraction( + "purge_room_state", + self._purge_room_state_txn, + room_id, + state_groups_to_delete, + ) + + def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete): + # first we have to delete the state groups states + logger.info("[purge] removing %s from state_groups_state", room_id) + + self.db.simple_delete_many_txn( + txn, + table="state_groups_state", + column="state_group", + iterable=state_groups_to_delete, + keyvalues={}, + ) + + # ... and the state group edges + logger.info("[purge] removing %s from state_group_edges", room_id) + + self.db.simple_delete_many_txn( + txn, + table="state_group_edges", + column="state_group", + iterable=state_groups_to_delete, + keyvalues={}, + ) + + # ... and the state groups + logger.info("[purge] removing %s from state_groups", room_id) + + self.db.simple_delete_many_txn( + txn, + table="state_groups", + column="id", + iterable=state_groups_to_delete, + keyvalues={}, + ) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index ec19ae1d9d..1003dd84a5 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -24,9 +24,11 @@ from six.moves import intern, range from prometheus_client import Histogram +from twisted.enterprise import adbapi from twisted.internet import defer from synapse.api.errors import StoreError +from synapse.config.database import DatabaseConnectionConfig from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater @@ -74,6 +76,37 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { } +def make_pool( + reactor, db_config: DatabaseConnectionConfig, engine +) -> adbapi.ConnectionPool: + """Get the connection pool for the database. + """ + + return adbapi.ConnectionPool( + db_config.config["name"], + cp_reactor=reactor, + cp_openfun=engine.on_new_connection, + **db_config.config.get("args", {}) + ) + + +def make_conn(db_config: DatabaseConnectionConfig, engine): + """Make a new connection to the database and return it. + + Returns: + Connection + """ + + db_params = { + k: v + for k, v in db_config.config.get("args", {}).items() + if not k.startswith("cp_") + } + db_conn = engine.module.connect(**db_params) + engine.on_new_connection(db_conn) + return db_conn + + class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() @@ -218,10 +251,11 @@ class Database(object): _TXN_ID = 0 - def __init__(self, hs): + def __init__(self, hs, database_config: DatabaseConnectionConfig, engine): self.hs = hs self._clock = hs.get_clock() - self._db_pool = hs.get_db_pool() + self._database_config = database_config + self._db_pool = make_pool(hs.get_reactor(), database_config, engine) self.updates = BackgroundUpdater(hs, self) @@ -234,7 +268,7 @@ class Database(object): # to watch it self._txn_perf_counters = PerformanceCounters() - self.engine = hs.database_engine + self.engine = engine # A set of tables that are not safe to use native upserts in. self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) @@ -255,6 +289,11 @@ class Database(object): self._check_safe_to_upsert, ) + def is_running(self): + """Is the database pool currently running + """ + return self._db_pool.running + @defer.inlineCallbacks def _check_safe_to_upsert(self): """ diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index cbc74cd302..df039a072d 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -16,8 +16,6 @@ import struct import threading -from synapse.storage.prepare_database import prepare_database - class Sqlite3Engine(object): single_threaded = True @@ -62,6 +60,10 @@ class Sqlite3Engine(object): return sql def on_new_connection(self, db_conn): + + # We need to import here to avoid an import loop. + from synapse.storage.prepare_database import prepare_database + if self._is_in_memory: # In memory databases need to be rebuilt each time. Ideally we'd # reuse the same connection as we do when starting up, but that diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index fa03ca9ff7..1ed44925fc 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -183,7 +183,7 @@ class EventsPersistenceStorage(object): # so we use separate variables here even though they point to the same # store for now. self.main_store = stores.main - self.state_store = stores.main + self.state_store = stores.state self._clock = hs.get_clock() self.is_mine_id = hs.is_mine_id diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 731e1c9d9c..e70026b80a 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -18,6 +18,7 @@ import imp import logging import os import re +from collections import Counter import attr @@ -41,7 +42,7 @@ class UpgradeDatabaseException(PrepareDatabaseException): pass -def prepare_database(db_conn, database_engine, config): +def prepare_database(db_conn, database_engine, config, data_stores=["main", "state"]): """Prepares a database for usage. Will either create all necessary tables or upgrade from an older schema version. @@ -54,11 +55,10 @@ def prepare_database(db_conn, database_engine, config): config (synapse.config.homeserver.HomeServerConfig|None): application config, or None if we are connecting to an existing database which we expect to be configured already + data_stores (list[str]): The name of the data stores that will be used + with this database. Defaults to all data stores. """ - # For now we only have the one datastore. - data_stores = ["main"] - try: cur = db_conn.cursor() version_info = _get_or_create_schema_state(cur, database_engine) @@ -70,7 +70,10 @@ def prepare_database(db_conn, database_engine, config): if user_version != SCHEMA_VERSION: # If we don't pass in a config file then we are expecting to # have already upgraded the DB. - raise UpgradeDatabaseException("Database needs to be upgraded") + raise UpgradeDatabaseException( + "Expected database schema version %i but got %i" + % (SCHEMA_VERSION, user_version) + ) else: _upgrade_existing_database( cur, @@ -313,6 +316,9 @@ def _upgrade_existing_database( ) ) + # Used to check if we have any duplicate file names + file_name_counter = Counter() + # Now find which directories have anything of interest. directory_entries = [] for directory in directories: @@ -323,6 +329,9 @@ def _upgrade_existing_database( _DirectoryListing(file_name, os.path.join(directory, file_name)) for file_name in file_names ) + + for file_name in file_names: + file_name_counter[file_name] += 1 except FileNotFoundError: # Data stores can have empty entries for a given version delta. pass @@ -331,6 +340,17 @@ def _upgrade_existing_database( "Could not open delta dir for version %d: %s" % (v, directory) ) + duplicates = set( + file_name for file_name, count in file_name_counter.items() if count > 1 + ) + if duplicates: + # We don't support using the same file name in the same delta version. + raise PrepareDatabaseException( + "Found multiple delta files with the same name in v%d: %s", + v, + duplicates, + ) + # We sort to ensure that we apply the delta files in a consistent # order (to avoid bugs caused by inconsistent directory listing order) directory_entries.sort() diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py index a368182034..d6a7bd7834 100644 --- a/synapse/storage/purge_events.py +++ b/synapse/storage/purge_events.py @@ -58,7 +58,7 @@ class PurgeEventsStorage(object): sg_to_delete = yield self._find_unreferenced_groups(state_groups) - yield self.stores.main.purge_unreferenced_state_groups(room_id, sg_to_delete) + yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) @defer.inlineCallbacks def _find_unreferenced_groups(self, state_groups): @@ -102,7 +102,7 @@ class PurgeEventsStorage(object): # groups that are referenced. current_search -= referenced - edges = yield self.stores.main.get_previous_state_groups(current_search) + edges = yield self.stores.state.get_previous_state_groups(current_search) prevs = set(edges.values()) # We don't bother re-handling groups we've already seen diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 3735846899..cbeb586014 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -342,7 +342,7 @@ class StateGroupStorage(object): (prev_group, delta_ids) """ - return self.stores.main.get_state_group_delta(state_group) + return self.stores.state.get_state_group_delta(state_group) @defer.inlineCallbacks def get_state_groups_ids(self, _room_id, event_ids): @@ -362,7 +362,7 @@ class StateGroupStorage(object): event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) - group_to_state = yield self.stores.main._get_state_for_groups(groups) + group_to_state = yield self.stores.state._get_state_for_groups(groups) return group_to_state @@ -423,7 +423,7 @@ class StateGroupStorage(object): dict of state_group_id -> (dict of (type, state_key) -> event id) """ - return self.stores.main._get_state_groups_from_groups(groups, state_filter) + return self.stores.state._get_state_groups_from_groups(groups, state_filter) @defer.inlineCallbacks def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): @@ -439,7 +439,7 @@ class StateGroupStorage(object): event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) - group_to_state = yield self.stores.main._get_state_for_groups( + group_to_state = yield self.stores.state._get_state_for_groups( groups, state_filter ) @@ -476,7 +476,7 @@ class StateGroupStorage(object): event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) - group_to_state = yield self.stores.main._get_state_for_groups( + group_to_state = yield self.stores.state._get_state_for_groups( groups, state_filter ) @@ -532,7 +532,7 @@ class StateGroupStorage(object): Deferred[dict[int, dict[tuple[str, str], str]]]: dict of state_group_id -> (dict of (type, state_key) -> event id) """ - return self.stores.main._get_state_for_groups(groups, state_filter) + return self.stores.state._get_state_for_groups(groups, state_filter) def store_state_group( self, event_id, room_id, prev_group, delta_ids, current_state_ids @@ -552,6 +552,6 @@ class StateGroupStorage(object): Returns: Deferred[int]: The state group ID """ - return self.stores.main.store_state_group( + return self.stores.state.store_state_group( event_id, room_id, prev_group, delta_ids, current_state_ids ) diff --git a/synapse/streams/events.py b/synapse/streams/events.py index b91fb2db7b..fcd2aaa9c9 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict + from twisted.internet import defer from synapse.handlers.account_data import AccountDataEventSource @@ -35,7 +37,7 @@ class EventSources(object): def __init__(self, hs): self.sources = { name: cls(hs) for name, cls in EventSources.SOURCE_TYPES.items() - } + } # type: Dict[str, Any] self.store = hs.get_datastore() @defer.inlineCallbacks diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 84f5ae22c3..2e8f6543e5 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -271,7 +271,7 @@ class _CacheDescriptorBase(object): else: self.function_to_call = orig - arg_spec = inspect.getargspec(orig) + arg_spec = inspect.getfullargspec(orig) all_args = arg_spec.args if "cache_context" in all_args: diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py deleted file mode 100644 index 8318db8d2c..0000000000 --- a/synapse/util/caches/snapshot_cache.py +++ /dev/null @@ -1,94 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.util.async_helpers import ObservableDeferred - - -class SnapshotCache(object): - """Cache for snapshots like the response of /initialSync. - The response of initialSync only has to be a recent snapshot of the - server state. It shouldn't matter to clients if it is a few minutes out - of date. - - This caches a deferred response. Until the deferred completes it will be - returned from the cache. This means that if the client retries the request - while the response is still being computed, that original response will be - used rather than trying to compute a new response. - - Once the deferred completes it will removed from the cache after 5 minutes. - We delay removing it from the cache because a client retrying its request - could race with us finishing computing the response. - - Rather than tracking precisely how long something has been in the cache we - keep two generations of completed responses. Every 5 minutes discard the - old generation, move the new generation to the old generation, and set the - new generation to be empty. This means that a result will be in the cache - somewhere between 5 and 10 minutes. - """ - - DURATION_MS = 5 * 60 * 1000 # Cache results for 5 minutes. - - def __init__(self): - self.pending_result_cache = {} # Request that haven't finished yet. - self.prev_result_cache = {} # The older requests that have finished. - self.next_result_cache = {} # The newer requests that have finished. - self.time_last_rotated_ms = 0 - - def rotate(self, time_now_ms): - # Rotate once if the cache duration has passed since the last rotation. - if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS: - self.prev_result_cache = self.next_result_cache - self.next_result_cache = {} - self.time_last_rotated_ms += self.DURATION_MS - - # Rotate again if the cache duration has passed twice since the last - # rotation. - if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS: - self.prev_result_cache = self.next_result_cache - self.next_result_cache = {} - self.time_last_rotated_ms = time_now_ms - - def get(self, time_now_ms, key): - self.rotate(time_now_ms) - # This cache is intended to deduplicate requests, so we expect it to be - # missed most of the time. So we just lookup the key in all of the - # dictionaries rather than trying to short circuit the lookup if the - # key is found. - result = self.prev_result_cache.get(key) - result = self.next_result_cache.get(key, result) - result = self.pending_result_cache.get(key, result) - if result is not None: - return result.observe() - else: - return None - - def set(self, time_now_ms, key, deferred): - self.rotate(time_now_ms) - - result = ObservableDeferred(deferred) - - self.pending_result_cache[key] = result - - def shuffle_along(r): - # When the deferred completes we shuffle it along to the first - # generation of the result cache. So that it will eventually - # expire from the rotation of that cache. - self.next_result_cache[key] = result - self.pending_result_cache.pop(key, None) - return r - - result.addBoth(shuffle_along) - - return result.observe() |