diff options
Diffstat (limited to 'synapse')
39 files changed, 1177 insertions, 697 deletions
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 04751a6a5e..51a909419f 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -45,7 +45,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.slave.storage.room import RoomStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.logcontext import LoggingContext from synapse.util.versionstring import get_version_string @@ -229,14 +228,10 @@ def start(config_options): synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = AdminCmdServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 02b900f382..e82e0f11e3 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -34,7 +34,6 @@ from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -143,8 +142,6 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - if config.notify_appservices: sys.stderr.write( "\nThe appservices must be disabled in the main synapse process" @@ -159,10 +156,8 @@ def start(config_options): ps = AppserviceServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ps, config, use_worker_options=True) diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index dadb487d5f..3edfe19567 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -62,7 +62,6 @@ from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.rest.client.versions import VersionsRestServlet from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -181,14 +180,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = ClientReaderServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index d110599a35..d0ddbe38fc 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -57,7 +57,6 @@ from synapse.rest.client.v1.room import ( ) from synapse.server import HomeServer from synapse.storage.data_stores.main.user_directory import UserDirectoryStore -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -180,14 +179,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = EventCreatorServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 418c086254..311523e0ed 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -46,7 +46,6 @@ from synapse.replication.slave.storage.transactions import SlavedTransactionStor from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.key.v2 import KeyApiV2Resource from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -162,14 +161,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = FederationReaderServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index f24920a7d6..83c436229c 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -41,7 +41,6 @@ from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.streams._base import ReceiptsStream from synapse.server import HomeServer from synapse.storage.database import Database -from synapse.storage.engines import create_engine from synapse.types import ReadReceipt from synapse.util.async_helpers import Linearizer from synapse.util.httpresourcetree import create_resource_tree @@ -174,8 +173,6 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - if config.send_federation: sys.stderr.write( "\nThe send_federation must be disabled in the main synapse process" @@ -190,10 +187,8 @@ def start(config_options): ss = FederationSenderServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index e647459d0e..30e435eead 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -39,7 +39,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.client.v2_alpha._base import client_patterns from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -234,14 +233,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = FrontendProxyServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index df65d0a989..b8661457e2 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -69,7 +69,7 @@ from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer from synapse.storage import DataStore -from synapse.storage.engines import IncorrectDatabaseSetup, create_engine +from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.prepare_database import UpgradeDatabaseException from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.httpresourcetree import create_resource_tree @@ -328,15 +328,10 @@ def setup(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection - hs = SynapseHomeServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) synapse.config.logger.setup_logging(hs, config, use_worker_options=False) @@ -519,8 +514,10 @@ def phone_stats_home(hs, stats, stats_process=_stats_process): # Database version # - stats["database_engine"] = hs.database_engine.module.__name__ - stats["database_server_version"] = hs.database_engine.server_version + # This only reports info about the *main* database. + stats["database_engine"] = hs.get_datastore().db.engine.module.__name__ + stats["database_server_version"] = hs.get_datastore().db.engine.server_version + logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) try: yield hs.get_proxied_http_client().put_json( diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index 2c6dd3ef02..4c80f257e2 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -40,7 +40,6 @@ from synapse.rest.admin import register_servlets_for_media_repo from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.server import HomeServer from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -157,14 +156,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = MediaRepositoryServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index dd52a9fc2d..09e639040a 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -37,7 +37,6 @@ from synapse.replication.slave.storage.room import RoomStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer from synapse.storage import DataStore -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -203,14 +202,10 @@ def start(config_options): # Force the pushers to start since they will be disabled in the main config config.start_pushers = True - database_engine = create_engine(config.database_config) - ps = PusherServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ps, config, use_worker_options=True) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 288ee64b42..dd2132e608 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -55,7 +55,6 @@ from synapse.rest.client.v1.room import RoomInitialSyncRestServlet from synapse.rest.client.v2_alpha import sync from synapse.server import HomeServer from synapse.storage.data_stores.main.presence import UserPresenceState -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.stringutils import random_string @@ -437,14 +436,10 @@ def start(config_options): synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = SynchrotronServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, application_service_handler=SynchrotronApplicationService(), ) diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index c01fb34a9b..1257098f92 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -44,7 +44,6 @@ from synapse.rest.client.v2_alpha import user_directory from synapse.server import HomeServer from synapse.storage.data_stores.main.user_directory import UserDirectoryStore from synapse.storage.database import Database -from synapse.storage.engines import create_engine from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole @@ -200,8 +199,6 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - if config.update_user_directory: sys.stderr.write( "\nThe update_user_directory must be disabled in the main synapse process" @@ -216,10 +213,8 @@ def start(config_options): ss = UserDirectoryServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index c5ea2d43a1..b91414aa35 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -14,17 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re +import logging from synapse.python_dependencies import DependencyException, check_requirements -from synapse.types import ( - map_username_to_mxid_localpart, - mxid_localpart_allowed_characters, -) -from synapse.util.module_loader import load_python_module +from synapse.util.module_loader import load_module, load_python_module from ._base import Config, ConfigError +logger = logging.getLogger(__name__) + +DEFAULT_USER_MAPPING_PROVIDER = ( + "synapse.handlers.saml_handler.DefaultSamlMappingProvider" +) + def _dict_merge(merge_dict, into_dict): """Do a deep merge of two dicts @@ -75,15 +77,69 @@ class SAML2Config(Config): self.saml2_enabled = True - self.saml2_mxid_source_attribute = saml2_config.get( - "mxid_source_attribute", "uid" - ) - self.saml2_grandfathered_mxid_source_attribute = saml2_config.get( "grandfathered_mxid_source_attribute", "uid" ) - saml2_config_dict = self._default_saml_config_dict() + # user_mapping_provider may be None if the key is present but has no value + ump_dict = saml2_config.get("user_mapping_provider") or {} + + # Use the default user mapping provider if not set + ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) + + # Ensure a config is present + ump_dict["config"] = ump_dict.get("config") or {} + + if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER: + # Load deprecated options for use by the default module + old_mxid_source_attribute = saml2_config.get("mxid_source_attribute") + if old_mxid_source_attribute: + logger.warning( + "The config option saml2_config.mxid_source_attribute is deprecated. " + "Please use saml2_config.user_mapping_provider.config" + ".mxid_source_attribute instead." + ) + ump_dict["config"]["mxid_source_attribute"] = old_mxid_source_attribute + + old_mxid_mapping = saml2_config.get("mxid_mapping") + if old_mxid_mapping: + logger.warning( + "The config option saml2_config.mxid_mapping is deprecated. Please " + "use saml2_config.user_mapping_provider.config.mxid_mapping instead." + ) + ump_dict["config"]["mxid_mapping"] = old_mxid_mapping + + # Retrieve an instance of the module's class + # Pass the config dictionary to the module for processing + ( + self.saml2_user_mapping_provider_class, + self.saml2_user_mapping_provider_config, + ) = load_module(ump_dict) + + # Ensure loaded user mapping module has defined all necessary methods + # Note parse_config() is already checked during the call to load_module + required_methods = [ + "get_saml_attributes", + "saml_response_to_user_attributes", + ] + missing_methods = [ + method + for method in required_methods + if not hasattr(self.saml2_user_mapping_provider_class, method) + ] + if missing_methods: + raise ConfigError( + "Class specified by saml2_config." + "user_mapping_provider.module is missing required " + "methods: %s" % (", ".join(missing_methods),) + ) + + # Get the desired saml auth response attributes from the module + saml2_config_dict = self._default_saml_config_dict( + *self.saml2_user_mapping_provider_class.get_saml_attributes( + self.saml2_user_mapping_provider_config + ) + ) _dict_merge( merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict ) @@ -103,22 +159,27 @@ class SAML2Config(Config): saml2_config.get("saml_session_lifetime", "5m") ) - mapping = saml2_config.get("mxid_mapping", "hexencode") - try: - self.saml2_mxid_mapper = MXID_MAPPER_MAP[mapping] - except KeyError: - raise ConfigError("%s is not a known mxid_mapping" % (mapping,)) - - def _default_saml_config_dict(self): + def _default_saml_config_dict( + self, required_attributes: set, optional_attributes: set + ): + """Generate a configuration dictionary with required and optional attributes that + will be needed to process new user registration + + Args: + required_attributes: SAML auth response attributes that are + necessary to function + optional_attributes: SAML auth response attributes that can be used to add + additional information to Synapse user accounts, but are not required + + Returns: + dict: A SAML configuration dictionary + """ import saml2 public_baseurl = self.public_baseurl if public_baseurl is None: raise ConfigError("saml2_config requires a public_baseurl to be set") - required_attributes = {"uid", self.saml2_mxid_source_attribute} - - optional_attributes = {"displayName"} if self.saml2_grandfathered_mxid_source_attribute: optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute) optional_attributes -= required_attributes @@ -207,33 +268,58 @@ class SAML2Config(Config): # #config_path: "%(config_dir_path)s/sp_conf.py" - # the lifetime of a SAML session. This defines how long a user has to + # The lifetime of a SAML session. This defines how long a user has to # complete the authentication process, if allow_unsolicited is unset. # The default is 5 minutes. # #saml_session_lifetime: 5m - # The SAML attribute (after mapping via the attribute maps) to use to derive - # the Matrix ID from. 'uid' by default. + # An external module can be provided here as a custom solution to + # mapping attributes returned from a saml provider onto a matrix user. # - #mxid_source_attribute: displayName - - # The mapping system to use for mapping the saml attribute onto a matrix ID. - # Options include: - # * 'hexencode' (which maps unpermitted characters to '=xx') - # * 'dotreplace' (which replaces unpermitted characters with '.'). - # The default is 'hexencode'. - # - #mxid_mapping: dotreplace - - # In previous versions of synapse, the mapping from SAML attribute to MXID was - # always calculated dynamically rather than stored in a table. For backwards- - # compatibility, we will look for user_ids matching such a pattern before - # creating a new account. + user_mapping_provider: + # The custom module's class. Uncomment to use a custom module. + # + #module: mapping_provider.SamlMappingProvider + + # Custom configuration values for the module. Below options are + # intended for the built-in provider, they should be changed if + # using a custom module. This section will be passed as a Python + # dictionary to the module's `parse_config` method. + # + config: + # The SAML attribute (after mapping via the attribute maps) to use + # to derive the Matrix ID from. 'uid' by default. + # + # Note: This used to be configured by the + # saml2_config.mxid_source_attribute option. If that is still + # defined, its value will be used instead. + # + #mxid_source_attribute: displayName + + # The mapping system to use for mapping the saml attribute onto a + # matrix ID. + # + # Options include: + # * 'hexencode' (which maps unpermitted characters to '=xx') + # * 'dotreplace' (which replaces unpermitted characters with + # '.'). + # The default is 'hexencode'. + # + # Note: This used to be configured by the + # saml2_config.mxid_mapping option. If that is still defined, its + # value will be used instead. + # + #mxid_mapping: dotreplace + + # In previous versions of synapse, the mapping from SAML attribute to + # MXID was always calculated dynamically rather than stored in a + # table. For backwards- compatibility, we will look for user_ids + # matching such a pattern before creating a new account. # # This setting controls the SAML attribute which will be used for this - # backwards-compatibility lookup. Typically it should be 'uid', but if the - # attribute maps are changed, it may be necessary to change it. + # backwards-compatibility lookup. Typically it should be 'uid', but if + # the attribute maps are changed, it may be necessary to change it. # # The default is 'uid'. # @@ -241,23 +327,3 @@ class SAML2Config(Config): """ % { "config_dir_path": config_dir_path } - - -DOT_REPLACE_PATTERN = re.compile( - ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) -) - - -def dot_replace_for_mxid(username: str) -> str: - username = username.lower() - username = DOT_REPLACE_PATTERN.sub(".", username) - - # regular mxids aren't allowed to start with an underscore either - username = re.sub("^_", "", username) - return username - - -MXID_MAPPER_MAP = { - "hexencode": map_username_to_mxid_localpart, - "dotreplace": dot_replace_for_mxid, -} diff --git a/synapse/event_auth.py b/synapse/event_auth.py index ec3243b27b..c940b84470 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -42,6 +42,8 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru Returns: if the auth checks pass. """ + assert isinstance(auth_events, dict) + if do_size_check: _check_size_limits(event) @@ -74,12 +76,6 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru if not event.signatures.get(event_id_domain): raise AuthError(403, "Event not signed by sending server") - if auth_events is None: - # Oh, we don't know what the state of the room was, so we - # are trusting that this is allowed (at least for now) - logger.warning("Trusting event: %s", event.event_id) - return - if event.type == EventTypes.Create: sender_domain = get_domain_from_id(event.sender) room_id_domain = get_domain_from_id(event.room_id) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 709449c9e3..af652a7659 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -18,8 +18,6 @@ import copy import itertools import logging -from six.moves import range - from prometheus_client import Counter from twisted.internet import defer @@ -39,7 +37,7 @@ from synapse.api.room_versions import ( ) from synapse.events import builder, room_version_to_event_format from synapse.federation.federation_base import FederationBase, event_from_pdu_json -from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.logging.context import make_deferred_yieldable from synapse.logging.utils import log_function from synapse.util import unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache @@ -310,19 +308,12 @@ class FederationClient(FederationBase): return signed_pdu @defer.inlineCallbacks - @log_function - def get_state_for_room(self, destination, room_id, event_id): - """Requests all of the room state at a given event from a remote homeserver. - - Args: - destination (str): The remote homeserver to query for the state. - room_id (str): The id of the room we're interested in. - event_id (str): The id of the event we want the state at. + def get_room_state_ids(self, destination: str, room_id: str, event_id: str): + """Calls the /state_ids endpoint to fetch the state at a particular point + in the room, and the auth events for the given event Returns: - Deferred[Tuple[List[EventBase], List[EventBase]]]: - A list of events in the state, and a list of events in the auth chain - for the given event. + Tuple[List[str], List[str]]: a tuple of (state event_ids, auth event_ids) """ result = yield self.transport_layer.get_room_state_ids( destination, room_id, event_id=event_id @@ -331,86 +322,12 @@ class FederationClient(FederationBase): state_event_ids = result["pdu_ids"] auth_event_ids = result.get("auth_chain_ids", []) - fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest( - destination, room_id, set(state_event_ids + auth_event_ids) - ) - - if failed_to_fetch: - logger.warning( - "Failed to fetch missing state/auth events for %s: %s", - room_id, - failed_to_fetch, - ) - - event_map = {ev.event_id: ev for ev in fetched_events} + if not isinstance(state_event_ids, list) or not isinstance( + auth_event_ids, list + ): + raise Exception("invalid response from /state_ids") - pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] - auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] - - auth_chain.sort(key=lambda e: e.depth) - - return pdus, auth_chain - - @defer.inlineCallbacks - def get_events_from_store_or_dest(self, destination, room_id, event_ids): - """Fetch events from a remote destination, checking if we already have them. - - Args: - destination (str) - room_id (str) - event_ids (list) - - Returns: - Deferred: A deferred resolving to a 2-tuple where the first is a list of - events and the second is a list of event ids that we failed to fetch. - """ - seen_events = yield self.store.get_events(event_ids, allow_rejected=True) - signed_events = list(seen_events.values()) - - failed_to_fetch = set() - - missing_events = set(event_ids) - for k in seen_events: - missing_events.discard(k) - - if not missing_events: - return signed_events, failed_to_fetch - - logger.debug( - "Fetching unknown state/auth events %s for room %s", - missing_events, - event_ids, - ) - - room_version = yield self.store.get_room_version(room_id) - - batch_size = 20 - missing_events = list(missing_events) - for i in range(0, len(missing_events), batch_size): - batch = set(missing_events[i : i + batch_size]) - - deferreds = [ - run_in_background( - self.get_pdu, - destinations=[destination], - event_id=e_id, - room_version=room_version, - ) - for e_id in batch - ] - - res = yield make_deferred_yieldable( - defer.DeferredList(deferreds, consumeErrors=True) - ) - for success, result in res: - if success and result: - signed_events.append(result) - batch.discard(result.event_id) - - # We removed all events we successfully fetched from `batch` - failed_to_fetch.update(batch) - - return signed_events, failed_to_fetch + return state_event_ids, auth_event_ids @defer.inlineCallbacks @log_function @@ -609,13 +526,7 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def send_request(destination): - time_now = self._clock.time_msec() - _, content = yield self.transport_layer.send_join( - destination=destination, - room_id=pdu.room_id, - event_id=pdu.event_id, - content=pdu.get_pdu_json(time_now), - ) + content = yield self._do_send_join(destination, pdu) logger.debug("Got content: %s", content) @@ -683,6 +594,44 @@ class FederationClient(FederationBase): return self._try_destination_list("send_join", destinations, send_request) @defer.inlineCallbacks + def _do_send_join(self, destination, pdu): + time_now = self._clock.time_msec() + + try: + content = yield self.transport_layer.send_join_v2( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) + + return content + except HttpResponseException as e: + if e.code in [400, 404]: + err = e.to_synapse_error() + + # If we receive an error response that isn't a generic error, or an + # unrecognised endpoint error, we assume that the remote understands + # the v2 invite API and this is a legitimate error. + if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]: + raise err + else: + raise e.to_synapse_error() + + logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API") + + resp = yield self.transport_layer.send_join_v1( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) + + # We expect the v1 API to respond with [200, content], so we only return the + # content. + return resp[1] + + @defer.inlineCallbacks def send_invite(self, destination, room_id, event_id, pdu): room_version = yield self.store.get_room_version(room_id) @@ -791,18 +740,50 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def send_request(destination): - time_now = self._clock.time_msec() - _, content = yield self.transport_layer.send_leave( + content = yield self._do_send_leave(destination, pdu) + + logger.debug("Got content: %s", content) + return None + + return self._try_destination_list("send_leave", destinations, send_request) + + @defer.inlineCallbacks + def _do_send_leave(self, destination, pdu): + time_now = self._clock.time_msec() + + try: + content = yield self.transport_layer.send_leave_v2( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, content=pdu.get_pdu_json(time_now), ) - logger.debug("Got content: %s", content) - return None + return content + except HttpResponseException as e: + if e.code in [400, 404]: + err = e.to_synapse_error() - return self._try_destination_list("send_leave", destinations, send_request) + # If we receive an error response that isn't a generic error, or an + # unrecognised endpoint error, we assume that the remote understands + # the v2 invite API and this is a legitimate error. + if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]: + raise err + else: + raise e.to_synapse_error() + + logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API") + + resp = yield self.transport_layer.send_leave_v1( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) + + # We expect the v1 API to respond with [200, content], so we only return the + # content. + return resp[1] def get_public_rooms( self, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 84d4eca041..d7ce333822 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -384,15 +384,10 @@ class FederationServer(FederationBase): res_pdus = await self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() - return ( - 200, - { - "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], - "auth_chain": [ - p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] - ], - }, - ) + return { + "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], + "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]], + } async def on_make_leave_request(self, origin, room_id, user_id): origin_host, _ = parse_server_name(origin) @@ -419,7 +414,7 @@ class FederationServer(FederationBase): pdu = await self._check_sigs_and_hash(room_version, pdu) await self.handler.on_send_leave_request(origin, pdu) - return 200, {} + return {} async def on_event_auth(self, origin, room_id, event_id): with (await self._server_linearizer.queue((origin, room_id))): diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 46dba84cac..198257414b 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -243,7 +243,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function - def send_join(self, destination, room_id, event_id, content): + def send_join_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/send_join/%s/%s", room_id, event_id) response = yield self.client.put_json( @@ -254,7 +254,18 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function - def send_leave(self, destination, room_id, event_id, content): + def send_join_v2(self, destination, room_id, event_id, content): + path = _create_v2_path("/send_join/%s/%s", room_id, event_id) + + response = yield self.client.put_json( + destination=destination, path=path, data=content + ) + + return response + + @defer.inlineCallbacks + @log_function + def send_leave_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/send_leave/%s/%s", room_id, event_id) response = yield self.client.put_json( @@ -272,6 +283,24 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function + def send_leave_v2(self, destination, room_id, event_id, content): + path = _create_v2_path("/send_leave/%s/%s", room_id, event_id) + + response = yield self.client.put_json( + destination=destination, + path=path, + data=content, + # we want to do our best to send this through. The problem is + # that if it fails, we won't retry it later, so if the remote + # server was just having a momentary blip, the room will be out of + # sync. + ignore_backoff=True, + ) + + return response + + @defer.inlineCallbacks + @log_function def send_invite_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/invite/%s/%s", room_id, event_id) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index fefc789c85..b4cbf23394 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -506,11 +506,21 @@ class FederationMakeLeaveServlet(BaseFederationServlet): return 200, content -class FederationSendLeaveServlet(BaseFederationServlet): +class FederationV1SendLeaveServlet(BaseFederationServlet): PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" async def on_PUT(self, origin, content, query, room_id, event_id): content = await self.handler.on_send_leave_request(origin, content, room_id) + return 200, (200, content) + + +class FederationV2SendLeaveServlet(BaseFederationServlet): + PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" + + PREFIX = FEDERATION_V2_PREFIX + + async def on_PUT(self, origin, content, query, room_id, event_id): + content = await self.handler.on_send_leave_request(origin, content, room_id) return 200, content @@ -521,9 +531,21 @@ class FederationEventAuthServlet(BaseFederationServlet): return await self.handler.on_event_auth(origin, context, event_id) -class FederationSendJoinServlet(BaseFederationServlet): +class FederationV1SendJoinServlet(BaseFederationServlet): + PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)" + + async def on_PUT(self, origin, content, query, context, event_id): + # TODO(paul): assert that context/event_id parsed from path actually + # match those given in content + content = await self.handler.on_send_join_request(origin, content, context) + return 200, (200, content) + + +class FederationV2SendJoinServlet(BaseFederationServlet): PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)" + PREFIX = FEDERATION_V2_PREFIX + async def on_PUT(self, origin, content, query, context, event_id): # TODO(paul): assert that context/event_id parsed from path actually # match those given in content @@ -1367,8 +1389,10 @@ FEDERATION_SERVLET_CLASSES = ( FederationMakeJoinServlet, FederationMakeLeaveServlet, FederationEventServlet, - FederationSendJoinServlet, - FederationSendLeaveServlet, + FederationV1SendJoinServlet, + FederationV2SendJoinServlet, + FederationV1SendLeaveServlet, + FederationV2SendLeaveServlet, FederationV1InviteServlet, FederationV2InviteServlet, FederationQueryAuthServlet, diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 2d7e6df6e4..20ec1ca01b 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - class AccountDataEventSource(object): def __init__(self, hs): @@ -23,15 +21,14 @@ class AccountDataEventSource(object): def get_current_key(self, direction="f"): return self.store.get_max_account_data_stream_id() - @defer.inlineCallbacks - def get_new_events(self, user, from_key, **kwargs): + async def get_new_events(self, user, from_key, **kwargs): user_id = user.to_string() last_stream_id = from_key - current_stream_id = yield self.store.get_max_account_data_stream_id() + current_stream_id = self.store.get_max_account_data_stream_id() results = [] - tags = yield self.store.get_updated_tags(user_id, last_stream_id) + tags = await self.store.get_updated_tags(user_id, last_stream_id) for room_id, room_tags in tags.items(): results.append( @@ -41,7 +38,7 @@ class AccountDataEventSource(object): ( account_data, room_account_data, - ) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id) + ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id) for account_data_type, content in account_data.items(): results.append({"type": account_data_type, "content": content}) @@ -54,6 +51,5 @@ class AccountDataEventSource(object): return results, current_stream_id - @defer.inlineCallbacks - def get_pagination_rows(self, user, config, key): + async def get_pagination_rows(self, user, config, key): return [], config.to_id diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index d04e0fe576..829f52eca1 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -18,8 +18,7 @@ import email.utils import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText - -from twisted.internet import defer +from typing import List from synapse.api.errors import StoreError from synapse.logging.context import make_deferred_yieldable @@ -78,42 +77,39 @@ class AccountValidityHandler(object): # run as a background process to make sure that the database transactions # have a logcontext to report to return run_as_background_process( - "send_renewals", self.send_renewal_emails + "send_renewals", self._send_renewal_emails ) self.clock.looping_call(send_emails, 30 * 60 * 1000) - @defer.inlineCallbacks - def send_renewal_emails(self): + async def _send_renewal_emails(self): """Gets the list of users whose account is expiring in the amount of time configured in the ``renew_at`` parameter from the ``account_validity`` configuration, and sends renewal emails to all of these users as long as they have an email 3PID attached to their account. """ - expiring_users = yield self.store.get_users_expiring_soon() + expiring_users = await self.store.get_users_expiring_soon() if expiring_users: for user in expiring_users: - yield self._send_renewal_email( + await self._send_renewal_email( user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"] ) - @defer.inlineCallbacks - def send_renewal_email_to_user(self, user_id): - expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) - yield self._send_renewal_email(user_id, expiration_ts) + async def send_renewal_email_to_user(self, user_id: str): + expiration_ts = await self.store.get_expiration_ts_for_user(user_id) + await self._send_renewal_email(user_id, expiration_ts) - @defer.inlineCallbacks - def _send_renewal_email(self, user_id, expiration_ts): + async def _send_renewal_email(self, user_id: str, expiration_ts: int): """Sends out a renewal email to every email address attached to the given user with a unique link allowing them to renew their account. Args: - user_id (str): ID of the user to send email(s) to. - expiration_ts (int): Timestamp in milliseconds for the expiration date of + user_id: ID of the user to send email(s) to. + expiration_ts: Timestamp in milliseconds for the expiration date of this user's account (used in the email templates). """ - addresses = yield self._get_email_addresses_for_user(user_id) + addresses = await self._get_email_addresses_for_user(user_id) # Stop right here if the user doesn't have at least one email address. # In this case, they will have to ask their server admin to renew their @@ -125,7 +121,7 @@ class AccountValidityHandler(object): return try: - user_display_name = yield self.store.get_profile_displayname( + user_display_name = await self.store.get_profile_displayname( UserID.from_string(user_id).localpart ) if user_display_name is None: @@ -133,7 +129,7 @@ class AccountValidityHandler(object): except StoreError: user_display_name = user_id - renewal_token = yield self._get_renewal_token(user_id) + renewal_token = await self._get_renewal_token(user_id) url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % ( self.hs.config.public_baseurl, renewal_token, @@ -165,7 +161,7 @@ class AccountValidityHandler(object): logger.info("Sending renewal email to %s", address) - yield make_deferred_yieldable( + await make_deferred_yieldable( self.sendmail( self.hs.config.email_smtp_host, self._raw_from, @@ -180,19 +176,18 @@ class AccountValidityHandler(object): ) ) - yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True) + await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True) - @defer.inlineCallbacks - def _get_email_addresses_for_user(self, user_id): + async def _get_email_addresses_for_user(self, user_id: str) -> List[str]: """Retrieve the list of email addresses attached to a user's account. Args: - user_id (str): ID of the user to lookup email addresses for. + user_id: ID of the user to lookup email addresses for. Returns: - defer.Deferred[list[str]]: Email addresses for this account. + Email addresses for this account. """ - threepids = yield self.store.user_get_threepids(user_id) + threepids = await self.store.user_get_threepids(user_id) addresses = [] for threepid in threepids: @@ -201,16 +196,15 @@ class AccountValidityHandler(object): return addresses - @defer.inlineCallbacks - def _get_renewal_token(self, user_id): + async def _get_renewal_token(self, user_id: str) -> str: """Generates a 32-byte long random string that will be inserted into the user's renewal email's unique link, then saves it into the database. Args: - user_id (str): ID of the user to generate a string for. + user_id: ID of the user to generate a string for. Returns: - defer.Deferred[str]: The generated string. + The generated string. Raises: StoreError(500): Couldn't generate a unique string after 5 attempts. @@ -219,52 +213,52 @@ class AccountValidityHandler(object): while attempts < 5: try: renewal_token = stringutils.random_string(32) - yield self.store.set_renewal_token_for_user(user_id, renewal_token) + await self.store.set_renewal_token_for_user(user_id, renewal_token) return renewal_token except StoreError: attempts += 1 raise StoreError(500, "Couldn't generate a unique string as refresh string.") - @defer.inlineCallbacks - def renew_account(self, renewal_token): + async def renew_account(self, renewal_token: str) -> bool: """Renews the account attached to a given renewal token by pushing back the expiration date by the current validity period in the server's configuration. Args: - renewal_token (str): Token sent with the renewal request. + renewal_token: Token sent with the renewal request. Returns: - bool: Whether the provided token is valid. + Whether the provided token is valid. """ try: - user_id = yield self.store.get_user_from_renewal_token(renewal_token) + user_id = await self.store.get_user_from_renewal_token(renewal_token) except StoreError: - defer.returnValue(False) + return False logger.debug("Renewing an account for user %s", user_id) - yield self.renew_account_for_user(user_id) + await self.renew_account_for_user(user_id) - defer.returnValue(True) + return True - @defer.inlineCallbacks - def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False): + async def renew_account_for_user( + self, user_id: str, expiration_ts: int = None, email_sent: bool = False + ) -> int: """Renews the account attached to a given user by pushing back the expiration date by the current validity period in the server's configuration. Args: - renewal_token (str): Token sent with the renewal request. - expiration_ts (int): New expiration date. Defaults to now + validity period. - email_sent (bool): Whether an email has been sent for this validity period. + renewal_token: Token sent with the renewal request. + expiration_ts: New expiration date. Defaults to now + validity period. + email_sen: Whether an email has been sent for this validity period. Defaults to False. Returns: - defer.Deferred[int]: New expiration date for this account, as a timestamp - in milliseconds since epoch. + New expiration date for this account, as a timestamp in + milliseconds since epoch. """ if expiration_ts is None: expiration_ts = self.clock.time_msec() + self._account_validity.period - yield self.store.set_account_validity_for_user( + await self.store.set_account_validity_for_user( user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent ) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 57a10daefd..2d889364d4 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -264,6 +264,7 @@ class E2eKeysHandler(object): return ret + @defer.inlineCallbacks def get_cross_signing_keys_from_cache(self, query, from_user_id): """Get cross-signing keys for users from the database @@ -283,14 +284,32 @@ class E2eKeysHandler(object): self_signing_keys = {} user_signing_keys = {} - # Currently a stub, implementation coming in https://github.com/matrix-org/synapse/pull/6486 - return defer.succeed( - { - "master_keys": master_keys, - "self_signing_keys": self_signing_keys, - "user_signing_keys": user_signing_keys, - } - ) + user_ids = list(query) + + keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id) + + for user_id, user_info in keys.items(): + if user_info is None: + continue + if "master" in user_info: + master_keys[user_id] = user_info["master"] + if "self_signing" in user_info: + self_signing_keys[user_id] = user_info["self_signing"] + + if ( + from_user_id in keys + and keys[from_user_id] is not None + and "user_signing" in keys[from_user_id] + ): + # users can see other users' master and self-signing keys, but can + # only see their own user-signing keys + user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"] + + return { + "master_keys": master_keys, + "self_signing_keys": self_signing_keys, + "user_signing_keys": user_signing_keys, + } @trace @defer.inlineCallbacks diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index bc26921768..2ea69c5468 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -19,7 +19,7 @@ import itertools import logging -from typing import Dict, Iterable, Optional, Sequence, Tuple +from typing import Dict, Iterable, List, Optional, Sequence, Tuple import six from six import iteritems, itervalues @@ -63,8 +63,9 @@ from synapse.replication.http.federation import ( ) from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import StateResolutionStore, resolve_events_with_store +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.types import UserID, get_domain_from_id -from synapse.util import unwrapFirstError +from synapse.util import batch_iter, unwrapFirstError from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room from synapse.util.retryutils import NotRetryingDestination @@ -164,8 +165,7 @@ class FederationHandler(BaseHandler): self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages - @defer.inlineCallbacks - def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): + async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: """ Process a PDU received via a federation /send/ transaction, or via backfill of missing prev_events @@ -175,17 +175,15 @@ class FederationHandler(BaseHandler): pdu (FrozenEvent): received PDU sent_to_us_directly (bool): True if this event was pushed to us; False if we pulled it as the result of a missing prev_event. - - Returns (Deferred): completes with None """ room_id = pdu.room_id event_id = pdu.event_id - logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu) + logger.info("handling received PDU: %s", pdu) # We reprocess pdus when we have seen them only as outliers - existing = yield self.store.get_event( + existing = await self.store.get_event( event_id, allow_none=True, allow_rejected=True ) @@ -229,7 +227,7 @@ class FederationHandler(BaseHandler): # # Note that if we were never in the room then we would have already # dropped the event, since we wouldn't know the room version. - is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name) + is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) if not is_in_room: logger.info( "[%s %s] Ignoring PDU from %s as we're not in the room", @@ -245,12 +243,12 @@ class FederationHandler(BaseHandler): # Get missing pdus if necessary. if not pdu.internal_metadata.is_outlier(): # We only backfill backwards to the min depth. - min_depth = yield self.get_min_depth_for_context(pdu.room_id) + min_depth = await self.get_min_depth_for_context(pdu.room_id) logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth) prevs = set(pdu.prev_event_ids()) - seen = yield self.store.have_seen_events(prevs) + seen = await self.store.have_seen_events(prevs) if min_depth and pdu.depth < min_depth: # This is so that we don't notify the user about this @@ -270,7 +268,7 @@ class FederationHandler(BaseHandler): len(missing_prevs), shortstr(missing_prevs), ) - with (yield self._room_pdu_linearizer.queue(pdu.room_id)): + with (await self._room_pdu_linearizer.queue(pdu.room_id)): logger.info( "[%s %s] Acquired room lock to fetch %d missing prev_events", room_id, @@ -278,13 +276,19 @@ class FederationHandler(BaseHandler): len(missing_prevs), ) - yield self._get_missing_events_for_pdu( - origin, pdu, prevs, min_depth - ) + try: + await self._get_missing_events_for_pdu( + origin, pdu, prevs, min_depth + ) + except Exception as e: + raise Exception( + "Error fetching missing prev_events for %s: %s" + % (event_id, e) + ) # Update the set of things we've seen after trying to # fetch the missing stuff - seen = yield self.store.have_seen_events(prevs) + seen = await self.store.have_seen_events(prevs) if not prevs - seen: logger.info( @@ -292,14 +296,6 @@ class FederationHandler(BaseHandler): room_id, event_id, ) - elif missing_prevs: - logger.info( - "[%s %s] Not recursively fetching %d missing prev_events: %s", - room_id, - event_id, - len(missing_prevs), - shortstr(missing_prevs), - ) if prevs - seen: # We've still not been able to get all of the prev_events for this event. @@ -344,13 +340,19 @@ class FederationHandler(BaseHandler): affected=pdu.event_id, ) + logger.info( + "Event %s is missing prev_events: calculating state for a " + "backwards extremity", + event_id, + ) + # Calculate the state after each of the previous events, and # resolve them to find the correct state at the current event. auth_chains = set() event_map = {event_id: pdu} try: # Get the state of the events we know about - ours = yield self.state_store.get_state_groups_ids(room_id, seen) + ours = await self.state_store.get_state_groups_ids(room_id, seen) # state_maps is a list of mappings from (type, state_key) to event_id state_maps = list( @@ -364,13 +366,10 @@ class FederationHandler(BaseHandler): # know about for p in prevs - seen: logger.info( - "[%s %s] Requesting state at missing prev_event %s", - room_id, - event_id, - p, + "Requesting state at missing prev_event %s", event_id, ) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) with nested_logging_context(p): # note that if any of the missing prevs share missing state or @@ -379,24 +378,10 @@ class FederationHandler(BaseHandler): ( remote_state, got_auth_chain, - ) = yield self.federation_client.get_state_for_room( - origin, room_id, p + ) = await self._get_state_for_room( + origin, room_id, p, include_event_in_state=True ) - # we want the state *after* p; get_state_for_room returns the - # state *before* p. - remote_event = yield self.federation_client.get_pdu( - [origin], p, room_version, outlier=True - ) - - if remote_event is None: - raise Exception( - "Unable to get missing prev_event %s" % (p,) - ) - - if remote_event.is_state(): - remote_state.append(remote_event) - # XXX hrm I'm not convinced that duplicate events will compare # for equality, so I'm not sure this does what the author # hoped. @@ -410,7 +395,7 @@ class FederationHandler(BaseHandler): for x in remote_state: event_map[x.event_id] = x - state_map = yield resolve_events_with_store( + state_map = await resolve_events_with_store( room_version, state_maps, event_map, @@ -422,10 +407,10 @@ class FederationHandler(BaseHandler): # First though we need to fetch all the events that are in # state_map, so we can build up the state below. - evs = yield self.store.get_events( + evs = await self.store.get_events( list(state_map.values()), get_prev_content=False, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, ) event_map.update(evs) @@ -446,12 +431,11 @@ class FederationHandler(BaseHandler): affected=event_id, ) - yield self._process_received_pdu( + await self._process_received_pdu( origin, pdu, state=state, auth_chain=auth_chain ) - @defer.inlineCallbacks - def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): + async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): """ Args: origin (str): Origin of the pdu. Will be called to get the missing events @@ -463,12 +447,12 @@ class FederationHandler(BaseHandler): room_id = pdu.room_id event_id = pdu.event_id - seen = yield self.store.have_seen_events(prevs) + seen = await self.store.have_seen_events(prevs) if not prevs - seen: return - latest = yield self.store.get_latest_event_ids_in_room(room_id) + latest = await self.store.get_latest_event_ids_in_room(room_id) # We add the prev events that we have seen to the latest # list to ensure the remote server doesn't give them to us @@ -532,7 +516,7 @@ class FederationHandler(BaseHandler): # All that said: Let's try increasing the timout to 60s and see what happens. try: - missing_events = yield self.federation_client.get_missing_events( + missing_events = await self.federation_client.get_missing_events( origin, room_id, earliest_events_ids=list(latest), @@ -571,7 +555,7 @@ class FederationHandler(BaseHandler): ) with nested_logging_context(ev.event_id): try: - yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False) + await self.on_receive_pdu(origin, ev, sent_to_us_directly=False) except FederationError as e: if e.code == 403: logger.warning( @@ -583,8 +567,144 @@ class FederationHandler(BaseHandler): else: raise - @defer.inlineCallbacks - def _process_received_pdu(self, origin, event, state, auth_chain): + async def _get_state_for_room( + self, + destination: str, + room_id: str, + event_id: str, + include_event_in_state: bool = False, + ) -> Tuple[List[EventBase], List[EventBase]]: + """Requests all of the room state at a given event from a remote homeserver. + + Args: + destination: The remote homeserver to query for the state. + room_id: The id of the room we're interested in. + event_id: The id of the event we want the state at. + include_event_in_state: if true, the event itself will be included in the + returned state event list. + + Returns: + A list of events in the state, possibly including the event itself, and + a list of events in the auth chain for the given event. + """ + ( + state_event_ids, + auth_event_ids, + ) = await self.federation_client.get_room_state_ids( + destination, room_id, event_id=event_id + ) + + desired_events = set(state_event_ids + auth_event_ids) + + if include_event_in_state: + desired_events.add(event_id) + + event_map = await self._get_events_from_store_or_dest( + destination, room_id, desired_events + ) + + failed_to_fetch = desired_events - event_map.keys() + if failed_to_fetch: + logger.warning( + "Failed to fetch missing state/auth events for %s %s", + event_id, + failed_to_fetch, + ) + + remote_state = [ + event_map[e_id] for e_id in state_event_ids if e_id in event_map + ] + + if include_event_in_state: + remote_event = event_map.get(event_id) + if not remote_event: + raise Exception("Unable to get missing prev_event %s" % (event_id,)) + if remote_event.is_state(): + remote_state.append(remote_event) + + auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] + auth_chain.sort(key=lambda e: e.depth) + + return remote_state, auth_chain + + async def _get_events_from_store_or_dest( + self, destination: str, room_id: str, event_ids: Iterable[str] + ) -> Dict[str, EventBase]: + """Fetch events from a remote destination, checking if we already have them. + + Args: + destination + room_id + event_ids + + If we fail to fetch any of the events, a warning will be logged, and the event + will be omitted from the result. Likewise, any events which turn out not to + be in the given room. + + Returns: + map from event_id to event + """ + fetched_events = await self.store.get_events(event_ids, allow_rejected=True) + + missing_events = set(event_ids) - fetched_events.keys() + + if missing_events: + logger.debug( + "Fetching unknown state/auth events %s for room %s", + missing_events, + room_id, + ) + + room_version = await self.store.get_room_version(room_id) + + # XXX 20 requests at once? really? + for batch in batch_iter(missing_events, 20): + deferreds = [ + run_in_background( + self.federation_client.get_pdu, + destinations=[destination], + event_id=e_id, + room_version=room_version, + ) + for e_id in batch + ] + + res = await make_deferred_yieldable( + defer.DeferredList(deferreds, consumeErrors=True) + ) + + for success, result in res: + if success and result: + fetched_events[result.event_id] = result + + # check for events which were in the wrong room. + # + # this can happen if a remote server claims that the state or + # auth_events at an event in room A are actually events in room B + + bad_events = list( + (event_id, event.room_id) + for event_id, event in fetched_events.items() + if event.room_id != room_id + ) + + for bad_event_id, bad_room_id in bad_events: + # This is a bogus situation, but since we may only discover it a long time + # after it happened, we try our best to carry on, by just omitting the + # bad events from the returned auth/state set. + logger.warning( + "Remote server %s claims event %s in room %s is an auth/state " + "event in room %s", + destination, + bad_event_id, + bad_room_id, + room_id, + ) + del fetched_events[bad_event_id] + + return fetched_events + + async def _process_received_pdu(self, origin, event, state, auth_chain): """ Called when we have a new pdu. We need to do auth checks and put it through the StateHandler. """ @@ -599,7 +719,7 @@ class FederationHandler(BaseHandler): if auth_chain: event_ids |= {e.event_id for e in auth_chain} - seen_ids = yield self.store.have_seen_events(event_ids) + seen_ids = await self.store.have_seen_events(event_ids) if state and auth_chain is not None: # If we have any state or auth_chain given to us by the replication @@ -626,18 +746,18 @@ class FederationHandler(BaseHandler): event_id, [e.event.event_id for e in event_infos], ) - yield self._handle_new_events(origin, event_infos) + await self._handle_new_events(origin, event_infos) try: - context = yield self._handle_new_event(origin, event, state=state) + context = await self._handle_new_event(origin, event, state=state) except AuthError as e: raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) - room = yield self.store.get_room(room_id) + room = await self.store.get_room(room_id) if not room: try: - yield self.store.store_room( + await self.store.store_room( room_id=room_id, room_creator_user_id="", is_public=False ) except StoreError: @@ -650,11 +770,11 @@ class FederationHandler(BaseHandler): # changing their profile info. newly_joined = True - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = await context.get_prev_state_ids(self.store) prev_state_id = prev_state_ids.get((event.type, event.state_key)) if prev_state_id: - prev_state = yield self.store.get_event( + prev_state = await self.store.get_event( prev_state_id, allow_none=True ) if prev_state and prev_state.membership == Membership.JOIN: @@ -662,11 +782,10 @@ class FederationHandler(BaseHandler): if newly_joined: user = UserID.from_string(event.state_key) - yield self.user_joined_room(user, room_id) + await self.user_joined_room(user, room_id) @log_function - @defer.inlineCallbacks - def backfill(self, dest, room_id, limit, extremities): + async def backfill(self, dest, room_id, limit, extremities): """ Trigger a backfill request to `dest` for the given `room_id` This will attempt to get more events from the remote. If the other side @@ -683,9 +802,9 @@ class FederationHandler(BaseHandler): if dest == self.server_name: raise SynapseError(400, "Can't backfill from self.") - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) - events = yield self.federation_client.backfill( + events = await self.federation_client.backfill( dest, room_id, limit=limit, extremities=extremities ) @@ -700,7 +819,7 @@ class FederationHandler(BaseHandler): # self._sanity_check_event(ev) # Don't bother processing events we already have. - seen_events = yield self.store.have_events_in_timeline( + seen_events = await self.store.have_events_in_timeline( set(e.event_id for e in events) ) @@ -723,7 +842,7 @@ class FederationHandler(BaseHandler): state_events = {} events_to_state = {} for e_id in edges: - state, auth = yield self.federation_client.get_state_for_room( + state, auth = await self._get_state_for_room( destination=dest, room_id=room_id, event_id=e_id ) auth_events.update({a.event_id: a for a in auth}) @@ -748,7 +867,7 @@ class FederationHandler(BaseHandler): # We repeatedly do this until we stop finding new auth events. while missing_auth - failed_to_fetch: logger.info("Missing auth for backfill: %r", missing_auth) - ret_events = yield self.store.get_events(missing_auth - failed_to_fetch) + ret_events = await self.store.get_events(missing_auth - failed_to_fetch) auth_events.update(ret_events) required_auth.update( @@ -762,7 +881,7 @@ class FederationHandler(BaseHandler): missing_auth - failed_to_fetch, ) - results = yield make_deferred_yieldable( + results = await make_deferred_yieldable( defer.gatherResults( [ run_in_background( @@ -789,7 +908,7 @@ class FederationHandler(BaseHandler): failed_to_fetch = missing_auth - set(auth_events) - seen_events = yield self.store.have_seen_events( + seen_events = await self.store.have_seen_events( set(auth_events.keys()) | set(state_events.keys()) ) @@ -851,7 +970,7 @@ class FederationHandler(BaseHandler): ) ) - yield self._handle_new_events(dest, ev_infos, backfilled=True) + await self._handle_new_events(dest, ev_infos, backfilled=True) # Step 2: Persist the rest of the events in the chunk one by one events.sort(key=lambda e: e.depth) @@ -867,16 +986,15 @@ class FederationHandler(BaseHandler): # We store these one at a time since each event depends on the # previous to work out the state. # TODO: We can probably do something more clever here. - yield self._handle_new_event(dest, event, backfilled=True) + await self._handle_new_event(dest, event, backfilled=True) return events - @defer.inlineCallbacks - def maybe_backfill(self, room_id, current_depth): + async def maybe_backfill(self, room_id, current_depth): """Checks the database to see if we should backfill before paginating, and if so do. """ - extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id) + extremities = await self.store.get_oldest_events_with_depth_in_room(room_id) if not extremities: logger.debug("Not backfilling as no extremeties found.") @@ -908,15 +1026,17 @@ class FederationHandler(BaseHandler): # state *before* the event, ignoring the special casing certain event # types have. - forward_events = yield self.store.get_successor_events(list(extremities)) + forward_events = await self.store.get_successor_events(list(extremities)) - extremities_events = yield self.store.get_events( - forward_events, check_redacted=False, get_prev_content=False + extremities_events = await self.store.get_events( + forward_events, + redact_behaviour=EventRedactBehaviour.AS_IS, + get_prev_content=False, ) # We set `check_history_visibility_only` as we might otherwise get false # positives from users having been erased. - filtered_extremities = yield filter_events_for_server( + filtered_extremities = await filter_events_for_server( self.storage, self.server_name, list(extremities_events.values()), @@ -946,7 +1066,7 @@ class FederationHandler(BaseHandler): # First we try hosts that are already in the room # TODO: HEURISTIC ALERT. - curr_state = yield self.state_handler.get_current_state(room_id) + curr_state = await self.state_handler.get_current_state(room_id) def get_domains_from_state(state): """Get joined domains from state @@ -985,12 +1105,11 @@ class FederationHandler(BaseHandler): domain for domain, depth in curr_domains if domain != self.server_name ] - @defer.inlineCallbacks - def try_backfill(domains): + async def try_backfill(domains): # TODO: Should we try multiple of these at a time? for dom in domains: try: - yield self.backfill( + await self.backfill( dom, room_id, limit=100, extremities=extremities ) # If this succeeded then we probably already have the @@ -1021,7 +1140,7 @@ class FederationHandler(BaseHandler): return False - success = yield try_backfill(likely_domains) + success = await try_backfill(likely_domains) if success: return True @@ -1035,7 +1154,7 @@ class FederationHandler(BaseHandler): logger.debug("calling resolve_state_groups in _maybe_backfill") resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events) - states = yield make_deferred_yieldable( + states = await make_deferred_yieldable( defer.gatherResults( [resolve(room_id, [e]) for e in event_ids], consumeErrors=True ) @@ -1045,7 +1164,7 @@ class FederationHandler(BaseHandler): # event_ids. states = dict(zip(event_ids, [s.state for s in states])) - state_map = yield self.store.get_events( + state_map = await self.store.get_events( [e_id for ids in itervalues(states) for e_id in itervalues(ids)], get_prev_content=False, ) @@ -1061,7 +1180,7 @@ class FederationHandler(BaseHandler): for e_id, _ in sorted_extremeties_tuple: likely_domains = get_domains_from_state(states[e_id]) - success = yield try_backfill( + success = await try_backfill( [dom for dom, _ in likely_domains if dom not in tried_domains] ) if success: @@ -1210,7 +1329,7 @@ class FederationHandler(BaseHandler): # Check whether this room is the result of an upgrade of a room we already know # about. If so, migrate over user information predecessor = yield self.store.get_room_predecessor(room_id) - if not predecessor: + if not predecessor or not isinstance(predecessor.get("room_id"), str): return old_room_id = predecessor["room_id"] logger.debug( @@ -1238,8 +1357,7 @@ class FederationHandler(BaseHandler): return True - @defer.inlineCallbacks - def _handle_queued_pdus(self, room_queue): + async def _handle_queued_pdus(self, room_queue): """Process PDUs which got queued up while we were busy send_joining. Args: @@ -1255,7 +1373,7 @@ class FederationHandler(BaseHandler): p.room_id, ) with nested_logging_context(p.event_id): - yield self.on_receive_pdu(origin, p, sent_to_us_directly=True) + await self.on_receive_pdu(origin, p, sent_to_us_directly=True) except Exception as e: logger.warning( "Error handling queued PDU %s from %s: %s", p.event_id, origin, e @@ -1453,7 +1571,7 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content): origin, event, event_format_version = yield self._make_and_verify_event( - target_hosts, room_id, user_id, "leave", content=content, + target_hosts, room_id, user_id, "leave", content=content ) # Mark as outlier as we don't have any state for this event; we're not # even in the room. @@ -2814,7 +2932,7 @@ class FederationHandler(BaseHandler): room_id=room_id, user_id=user.to_string(), change="joined" ) else: - return user_joined_room(self.distributor, user, room_id) + return defer.succeed(user_joined_room(self.distributor, user, room_id)) @defer.inlineCallbacks def get_room_complexity(self, remote_room_hosts, room_id): diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 81dce96f4b..44ec3e66ae 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -26,7 +26,7 @@ from synapse.streams.config import PaginationConfig from synapse.types import StreamToken, UserID from synapse.util import unwrapFirstError from synapse.util.async_helpers import concurrently_execute -from synapse.util.caches.snapshot_cache import SnapshotCache +from synapse.util.caches.response_cache import ResponseCache from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -41,7 +41,7 @@ class InitialSyncHandler(BaseHandler): self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() - self.snapshot_cache = SnapshotCache() + self.snapshot_cache = ResponseCache(hs, "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() self.state_store = self.storage.state @@ -79,21 +79,17 @@ class InitialSyncHandler(BaseHandler): as_client_event, include_archived, ) - now_ms = self.clock.time_msec() - result = self.snapshot_cache.get(now_ms, key) - if result is not None: - return result - return self.snapshot_cache.set( - now_ms, + return self.snapshot_cache.wrap( key, - self._snapshot_all_rooms( - user_id, pagin_config, as_client_event, include_archived - ), + self._snapshot_all_rooms, + user_id, + pagin_config, + as_client_event, + include_archived, ) - @defer.inlineCallbacks - def _snapshot_all_rooms( + async def _snapshot_all_rooms( self, user_id=None, pagin_config=None, @@ -105,7 +101,7 @@ class InitialSyncHandler(BaseHandler): if include_archived: memberships.append(Membership.LEAVE) - room_list = yield self.store.get_rooms_for_user_where_membership_is( + room_list = await self.store.get_rooms_for_user_where_membership_is( user_id=user_id, membership_list=memberships ) @@ -113,33 +109,32 @@ class InitialSyncHandler(BaseHandler): rooms_ret = [] - now_token = yield self.hs.get_event_sources().get_current_token() + now_token = await self.hs.get_event_sources().get_current_token() presence_stream = self.hs.get_event_sources().sources["presence"] pagination_config = PaginationConfig(from_token=now_token) - presence, _ = yield presence_stream.get_pagination_rows( + presence, _ = await presence_stream.get_pagination_rows( user, pagination_config.get_source_config("presence"), None ) receipt_stream = self.hs.get_event_sources().sources["receipt"] - receipt, _ = yield receipt_stream.get_pagination_rows( + receipt, _ = await receipt_stream.get_pagination_rows( user, pagination_config.get_source_config("receipt"), None ) - tags_by_room = yield self.store.get_tags_for_user(user_id) + tags_by_room = await self.store.get_tags_for_user(user_id) - account_data, account_data_by_room = yield self.store.get_account_data_for_user( + account_data, account_data_by_room = await self.store.get_account_data_for_user( user_id ) - public_room_ids = yield self.store.get_public_room_ids() + public_room_ids = await self.store.get_public_room_ids() limit = pagin_config.limit if limit is None: limit = 10 - @defer.inlineCallbacks - def handle_room(event): + async def handle_room(event): d = { "room_id": event.room_id, "membership": event.membership, @@ -152,8 +147,8 @@ class InitialSyncHandler(BaseHandler): time_now = self.clock.time_msec() d["inviter"] = event.sender - invite_event = yield self.store.get_event(event.event_id) - d["invite"] = yield self._event_serializer.serialize_event( + invite_event = await self.store.get_event(event.event_id) + d["invite"] = await self._event_serializer.serialize_event( invite_event, time_now, as_client_event ) @@ -177,7 +172,7 @@ class InitialSyncHandler(BaseHandler): lambda states: states[event.event_id] ) - (messages, token), current_state = yield make_deferred_yieldable( + (messages, token), current_state = await make_deferred_yieldable( defer.gatherResults( [ run_in_background( @@ -191,7 +186,7 @@ class InitialSyncHandler(BaseHandler): ) ).addErrback(unwrapFirstError) - messages = yield filter_events_for_client( + messages = await filter_events_for_client( self.storage, user_id, messages ) @@ -201,7 +196,7 @@ class InitialSyncHandler(BaseHandler): d["messages"] = { "chunk": ( - yield self._event_serializer.serialize_events( + await self._event_serializer.serialize_events( messages, time_now=time_now, as_client_event=as_client_event ) ), @@ -209,7 +204,7 @@ class InitialSyncHandler(BaseHandler): "end": end_token.to_string(), } - d["state"] = yield self._event_serializer.serialize_events( + d["state"] = await self._event_serializer.serialize_events( current_state.values(), time_now=time_now, as_client_event=as_client_event, @@ -232,7 +227,7 @@ class InitialSyncHandler(BaseHandler): except Exception: logger.exception("Failed to get snapshot") - yield concurrently_execute(handle_room, room_list, 10) + await concurrently_execute(handle_room, room_list, 10) account_data_events = [] for account_data_type, content in account_data.items(): @@ -256,8 +251,7 @@ class InitialSyncHandler(BaseHandler): return ret - @defer.inlineCallbacks - def room_initial_sync(self, requester, room_id, pagin_config=None): + async def room_initial_sync(self, requester, room_id, pagin_config=None): """Capture the a snapshot of a room. If user is currently a member of the room this will be what is currently in the room. If the user left the room this will be what was in the room when they left. @@ -274,32 +268,32 @@ class InitialSyncHandler(BaseHandler): A JSON serialisable dict with the snapshot of the room. """ - blocked = yield self.store.is_room_blocked(room_id) + blocked = await self.store.is_room_blocked(room_id) if blocked: raise SynapseError(403, "This room has been blocked on this server") user_id = requester.user.to_string() - membership, member_event_id = yield self._check_in_room_or_world_readable( + membership, member_event_id = await self._check_in_room_or_world_readable( room_id, user_id ) is_peeking = member_event_id is None if membership == Membership.JOIN: - result = yield self._room_initial_sync_joined( + result = await self._room_initial_sync_joined( user_id, room_id, pagin_config, membership, is_peeking ) elif membership == Membership.LEAVE: - result = yield self._room_initial_sync_parted( + result = await self._room_initial_sync_parted( user_id, room_id, pagin_config, membership, member_event_id, is_peeking ) account_data_events = [] - tags = yield self.store.get_tags_for_room(user_id, room_id) + tags = await self.store.get_tags_for_room(user_id, room_id) if tags: account_data_events.append({"type": "m.tag", "content": {"tags": tags}}) - account_data = yield self.store.get_account_data_for_room(user_id, room_id) + account_data = await self.store.get_account_data_for_room(user_id, room_id) for account_data_type, content in account_data.items(): account_data_events.append({"type": account_data_type, "content": content}) @@ -307,11 +301,10 @@ class InitialSyncHandler(BaseHandler): return result - @defer.inlineCallbacks - def _room_initial_sync_parted( + async def _room_initial_sync_parted( self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking ): - room_state = yield self.state_store.get_state_for_events([member_event_id]) + room_state = await self.state_store.get_state_for_events([member_event_id]) room_state = room_state[member_event_id] @@ -319,13 +312,13 @@ class InitialSyncHandler(BaseHandler): if limit is None: limit = 10 - stream_token = yield self.store.get_stream_token_for_event(member_event_id) + stream_token = await self.store.get_stream_token_for_event(member_event_id) - messages, token = yield self.store.get_recent_events_for_room( + messages, token = await self.store.get_recent_events_for_room( room_id, limit=limit, end_token=stream_token ) - messages = yield filter_events_for_client( + messages = await filter_events_for_client( self.storage, user_id, messages, is_peeking=is_peeking ) @@ -339,13 +332,13 @@ class InitialSyncHandler(BaseHandler): "room_id": room_id, "messages": { "chunk": ( - yield self._event_serializer.serialize_events(messages, time_now) + await self._event_serializer.serialize_events(messages, time_now) ), "start": start_token.to_string(), "end": end_token.to_string(), }, "state": ( - yield self._event_serializer.serialize_events( + await self._event_serializer.serialize_events( room_state.values(), time_now ) ), @@ -353,19 +346,18 @@ class InitialSyncHandler(BaseHandler): "receipts": [], } - @defer.inlineCallbacks - def _room_initial_sync_joined( + async def _room_initial_sync_joined( self, user_id, room_id, pagin_config, membership, is_peeking ): - current_state = yield self.state.get_current_state(room_id=room_id) + current_state = await self.state.get_current_state(room_id=room_id) # TODO: These concurrently time_now = self.clock.time_msec() - state = yield self._event_serializer.serialize_events( + state = await self._event_serializer.serialize_events( current_state.values(), time_now ) - now_token = yield self.hs.get_event_sources().get_current_token() + now_token = await self.hs.get_event_sources().get_current_token() limit = pagin_config.limit if pagin_config else None if limit is None: @@ -380,28 +372,26 @@ class InitialSyncHandler(BaseHandler): presence_handler = self.hs.get_presence_handler() - @defer.inlineCallbacks - def get_presence(): + async def get_presence(): # If presence is disabled, return an empty list if not self.hs.config.use_presence: return [] - states = yield presence_handler.get_states( + states = await presence_handler.get_states( [m.user_id for m in room_members], as_event=True ) return states - @defer.inlineCallbacks - def get_receipts(): - receipts = yield self.store.get_linearized_receipts_for_room( + async def get_receipts(): + receipts = await self.store.get_linearized_receipts_for_room( room_id, to_key=now_token.receipt_key ) if not receipts: receipts = [] return receipts - presence, receipts, (messages, token) = yield make_deferred_yieldable( + presence, receipts, (messages, token) = await make_deferred_yieldable( defer.gatherResults( [ run_in_background(get_presence), @@ -417,7 +407,7 @@ class InitialSyncHandler(BaseHandler): ).addErrback(unwrapFirstError) ) - messages = yield filter_events_for_client( + messages = await filter_events_for_client( self.storage, user_id, messages, is_peeking=is_peeking ) @@ -430,7 +420,7 @@ class InitialSyncHandler(BaseHandler): "room_id": room_id, "messages": { "chunk": ( - yield self._event_serializer.serialize_events(messages, time_now) + await self._event_serializer.serialize_events(messages, time_now) ), "start": start_token.to_string(), "end": end_token.to_string(), @@ -444,18 +434,17 @@ class InitialSyncHandler(BaseHandler): return ret - @defer.inlineCallbacks - def _check_in_room_or_world_readable(self, room_id, user_id): + async def _check_in_room_or_world_readable(self, room_id, user_id): try: # check_user_was_in_room will return the most recent membership # event for the user if: # * The user is a non-guest user, and was ever in the room # * The user is a guest user, and has joined the room # else it will throw. - member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + member_event = await self.auth.check_user_was_in_room(room_id, user_id) return member_event.membership, member_event.event_id except AuthError: - visibility = yield self.state_handler.get_current_state( + visibility = await self.state_handler.get_current_state( room_id, EventTypes.RoomHistoryVisibility, "" ) if ( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 54fa216d83..bf9add7fe2 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -46,6 +46,7 @@ from synapse.events.validator import EventValidator from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID, create_requester from synapse.util.async_helpers import Linearizer @@ -875,7 +876,7 @@ class EventCreationHandler(object): if event.type == EventTypes.Redaction: original_event = yield self.store.get_event( event.redacts, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, allow_rejected=False, allow_none=True, @@ -952,7 +953,7 @@ class EventCreationHandler(object): if event.type == EventTypes.Redaction: original_event = yield self.store.get_event( event.redacts, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, allow_rejected=False, allow_none=True, diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 8514ddc600..00a6afc963 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -280,8 +280,7 @@ class PaginationHandler(object): await self.storage.purge_events.purge_room(room_id) - @defer.inlineCallbacks - def get_messages( + async def get_messages( self, requester, room_id=None, @@ -307,7 +306,7 @@ class PaginationHandler(object): room_token = pagin_config.from_token.room_key else: pagin_config.from_token = ( - yield self.hs.get_event_sources().get_current_token_for_pagination() + await self.hs.get_event_sources().get_current_token_for_pagination() ) room_token = pagin_config.from_token.room_key @@ -319,11 +318,11 @@ class PaginationHandler(object): source_config = pagin_config.get_source_config("room") - with (yield self.pagination_lock.read(room_id)): + with (await self.pagination_lock.read(room_id)): ( membership, member_event_id, - ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id) + ) = await self.auth.check_in_room_or_world_readable(room_id, user_id) if source_config.direction == "b": # if we're going backwards, we might need to backfill. This @@ -331,7 +330,7 @@ class PaginationHandler(object): if room_token.topological: max_topo = room_token.topological else: - max_topo = yield self.store.get_max_topological_token( + max_topo = await self.store.get_max_topological_token( room_id, room_token.stream ) @@ -339,18 +338,18 @@ class PaginationHandler(object): # If they have left the room then clamp the token to be before # they left the room, to save the effort of loading from the # database. - leave_token = yield self.store.get_topological_token_for_event( + leave_token = await self.store.get_topological_token_for_event( member_event_id ) leave_token = RoomStreamToken.parse(leave_token) if leave_token.topological < max_topo: source_config.from_key = str(leave_token) - yield self.hs.get_handlers().federation_handler.maybe_backfill( + await self.hs.get_handlers().federation_handler.maybe_backfill( room_id, max_topo ) - events, next_key = yield self.store.paginate_room_events( + events, next_key = await self.store.paginate_room_events( room_id=room_id, from_key=source_config.from_key, to_key=source_config.to_key, @@ -365,7 +364,7 @@ class PaginationHandler(object): if event_filter: events = event_filter.filter(events) - events = yield filter_events_for_client( + events = await filter_events_for_client( self.storage, user_id, events, is_peeking=(member_event_id is None) ) @@ -385,19 +384,19 @@ class PaginationHandler(object): (EventTypes.Member, event.sender) for event in events ) - state_ids = yield self.state_store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( events[0].event_id, state_filter=state_filter ) if state_ids: - state = yield self.store.get_events(list(state_ids.values())) + state = await self.store.get_events(list(state_ids.values())) state = state.values() time_now = self.clock.time_msec() chunk = { "chunk": ( - yield self._event_serializer.serialize_events( + await self._event_serializer.serialize_events( events, time_now, as_client_event=as_client_event ) ), @@ -406,7 +405,7 @@ class PaginationHandler(object): } if state: - chunk["state"] = yield self._event_serializer.serialize_events( + chunk["state"] = await self._event_serializer.serialize_events( state, time_now, as_client_event=as_client_event ) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index cc9e6b9bd0..0082f85c26 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -13,20 +13,36 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import re +from typing import Tuple import attr import saml2 +import saml2.response from saml2.client import Saml2Client from synapse.api.errors import SynapseError +from synapse.config import ConfigError from synapse.http.servlet import parse_string from synapse.rest.client.v1.login import SSOAuthHandler -from synapse.types import UserID, map_username_to_mxid_localpart +from synapse.types import ( + UserID, + map_username_to_mxid_localpart, + mxid_localpart_allowed_characters, +) from synapse.util.async_helpers import Linearizer logger = logging.getLogger(__name__) +@attr.s +class Saml2SessionData: + """Data we track about SAML2 sessions""" + + # time the session was created, in milliseconds + creation_time = attr.ib() + + class SamlHandler: def __init__(self, hs): self._saml_client = Saml2Client(hs.config.saml2_sp_config) @@ -37,11 +53,14 @@ class SamlHandler: self._datastore = hs.get_datastore() self._hostname = hs.hostname self._saml2_session_lifetime = hs.config.saml2_session_lifetime - self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute self._grandfathered_mxid_source_attribute = ( hs.config.saml2_grandfathered_mxid_source_attribute ) - self._mxid_mapper = hs.config.saml2_mxid_mapper + + # plugin to do custom mapping from saml response to mxid + self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class( + hs.config.saml2_user_mapping_provider_config + ) # identifier for the external_ids table self._auth_provider_id = "saml" @@ -118,22 +137,10 @@ class SamlHandler: remote_user_id = saml2_auth.ava["uid"][0] except KeyError: logger.warning("SAML2 response lacks a 'uid' attestation") - raise SynapseError(400, "uid not in SAML2 response") - - try: - mxid_source = saml2_auth.ava[self._mxid_source_attribute][0] - except KeyError: - logger.warning( - "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute - ) - raise SynapseError( - 400, "%s not in SAML2 response" % (self._mxid_source_attribute,) - ) + raise SynapseError(400, "'uid' not in SAML2 response") self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None) - displayName = saml2_auth.ava.get("displayName", [None])[0] - with (await self._mapping_lock.queue(self._auth_provider_id)): # first of all, check if we already have a mapping for this user logger.info( @@ -173,22 +180,46 @@ class SamlHandler: ) return registered_user_id - # figure out a new mxid for this user - base_mxid_localpart = self._mxid_mapper(mxid_source) + # Map saml response to user attributes using the configured mapping provider + for i in range(1000): + attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes( + saml2_auth, i + ) + + logger.debug( + "Retrieved SAML attributes from user mapping provider: %s " + "(attempt %d)", + attribute_dict, + i, + ) + + localpart = attribute_dict.get("mxid_localpart") + if not localpart: + logger.error( + "SAML mapping provider plugin did not return a " + "mxid_localpart object" + ) + raise SynapseError(500, "Error parsing SAML2 response") - suffix = 0 - while True: - localpart = base_mxid_localpart + (str(suffix) if suffix else "") + displayname = attribute_dict.get("displayname") + + # Check if this mxid already exists if not await self._datastore.get_users_by_id_case_insensitive( UserID(localpart, self._hostname).to_string() ): + # This mxid is free break - suffix += 1 - logger.info("Allocating mxid for new user with localpart %s", localpart) + else: + # Unable to generate a username in 1000 iterations + # Break and return error to the user + raise SynapseError( + 500, "Unable to generate a Matrix ID from the SAML response" + ) registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=displayName + localpart=localpart, default_display_name=displayname ) + await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id ) @@ -205,9 +236,120 @@ class SamlHandler: del self._outstanding_requests_dict[reqid] +DOT_REPLACE_PATTERN = re.compile( + ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) +) + + +def dot_replace_for_mxid(username: str) -> str: + username = username.lower() + username = DOT_REPLACE_PATTERN.sub(".", username) + + # regular mxids aren't allowed to start with an underscore either + username = re.sub("^_", "", username) + return username + + +MXID_MAPPER_MAP = { + "hexencode": map_username_to_mxid_localpart, + "dotreplace": dot_replace_for_mxid, +} + + @attr.s -class Saml2SessionData: - """Data we track about SAML2 sessions""" +class SamlConfig(object): + mxid_source_attribute = attr.ib() + mxid_mapper = attr.ib() - # time the session was created, in milliseconds - creation_time = attr.ib() + +class DefaultSamlMappingProvider(object): + __version__ = "0.0.1" + + def __init__(self, parsed_config: SamlConfig): + """The default SAML user mapping provider + + Args: + parsed_config: Module configuration + """ + self._mxid_source_attribute = parsed_config.mxid_source_attribute + self._mxid_mapper = parsed_config.mxid_mapper + + def saml_response_to_user_attributes( + self, saml_response: saml2.response.AuthnResponse, failures: int = 0, + ) -> dict: + """Maps some text from a SAML response to attributes of a new user + + Args: + saml_response: A SAML auth response object + + failures: How many times a call to this function with this + saml_response has resulted in a failure + + Returns: + dict: A dict containing new user attributes. Possible keys: + * mxid_localpart (str): Required. The localpart of the user's mxid + * displayname (str): The displayname of the user + """ + try: + mxid_source = saml_response.ava[self._mxid_source_attribute][0] + except KeyError: + logger.warning( + "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute, + ) + raise SynapseError( + 400, "%s not in SAML2 response" % (self._mxid_source_attribute,) + ) + + # Use the configured mapper for this mxid_source + base_mxid_localpart = self._mxid_mapper(mxid_source) + + # Append suffix integer if last call to this function failed to produce + # a usable mxid + localpart = base_mxid_localpart + (str(failures) if failures else "") + + # Retrieve the display name from the saml response + # If displayname is None, the mxid_localpart will be used instead + displayname = saml_response.ava.get("displayName", [None])[0] + + return { + "mxid_localpart": localpart, + "displayname": displayname, + } + + @staticmethod + def parse_config(config: dict) -> SamlConfig: + """Parse the dict provided by the homeserver's config + Args: + config: A dictionary containing configuration options for this provider + Returns: + SamlConfig: A custom config object for this module + """ + # Parse config options and use defaults where necessary + mxid_source_attribute = config.get("mxid_source_attribute", "uid") + mapping_type = config.get("mxid_mapping", "hexencode") + + # Retrieve the associating mapping function + try: + mxid_mapper = MXID_MAPPER_MAP[mapping_type] + except KeyError: + raise ConfigError( + "saml2_config.user_mapping_provider.config: '%s' is not a valid " + "mxid_mapping value" % (mapping_type,) + ) + + return SamlConfig(mxid_source_attribute, mxid_mapper) + + @staticmethod + def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]: + """Returns the required attributes of a SAML + + Args: + config: A SamlConfig object containing configuration params for this provider + + Returns: + tuple[set,set]: The first set equates to the saml auth response + attributes that are required for the module to function, whereas the + second set consists of those attributes which can be used if + available, but are not necessary + """ + return {"uid", config.mxid_source_attribute}, {"displayName"} diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 56ed262a1f..ef750d1497 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -21,7 +21,7 @@ from unpaddedbase64 import decode_base64, encode_base64 from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import SynapseError +from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter from synapse.storage.state import StateFilter from synapse.visibility import filter_events_for_client @@ -37,6 +37,7 @@ class SearchHandler(BaseHandler): self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() self.state_store = self.storage.state + self.auth = hs.get_auth() @defer.inlineCallbacks def get_old_rooms_from_upgraded_room(self, room_id): @@ -53,23 +54,38 @@ class SearchHandler(BaseHandler): room_id (str): id of the room to search through. Returns: - Deferred[iterable[unicode]]: predecessor room ids + Deferred[iterable[str]]: predecessor room ids """ historical_room_ids = [] - while True: - predecessor = yield self.store.get_room_predecessor(room_id) + # The initial room must have been known for us to get this far + predecessor = yield self.store.get_room_predecessor(room_id) - # If no predecessor, assume we've hit a dead end + while True: if not predecessor: + # We have reached the end of the chain of predecessors + break + + if not isinstance(predecessor.get("room_id"), str): + # This predecessor object is malformed. Exit here + break + + predecessor_room_id = predecessor["room_id"] + + # Don't add it to the list until we have checked that we are in the room + try: + next_predecessor_room = yield self.store.get_room_predecessor( + predecessor_room_id + ) + except NotFoundError: + # The predecessor is not a known room, so we are done here break - # Add predecessor's room ID - historical_room_ids.append(predecessor["room_id"]) + historical_room_ids.append(predecessor_room_id) - # Scan through the old room for further predecessors - room_id = predecessor["room_id"] + # And repeat + predecessor = next_predecessor_room return historical_room_ids diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 2c1fb9ddac..6747f29e6a 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -23,6 +23,7 @@ them. See doc/log_contexts.rst for details on how this works. """ +import inspect import logging import threading import types @@ -612,7 +613,8 @@ def run_in_background(f, *args, **kwargs): def make_deferred_yieldable(deferred): - """Given a deferred, make it follow the Synapse logcontext rules: + """Given a deferred (or coroutine), make it follow the Synapse logcontext + rules: If the deferred has completed (or is not actually a Deferred), essentially does nothing (just returns another completed deferred with the @@ -624,6 +626,13 @@ def make_deferred_yieldable(deferred): (This is more-or-less the opposite operation to run_in_background.) """ + if inspect.isawaitable(deferred): + # If we're given a coroutine we convert it to a deferred so that we + # run it and find out if it immediately finishes, it it does then we + # don't need to fiddle with log contexts at all and can return + # immediately. + deferred = defer.ensureDeferred(deferred) + if not isinstance(deferred, defer.Deferred): return deferred diff --git a/synapse/server.py b/synapse/server.py index 2db3dab221..5021068ce0 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -34,6 +34,7 @@ from synapse.api.filtering import Filtering from synapse.api.ratelimiting import Ratelimiter from synapse.appservice.api import ApplicationServiceApi from synapse.appservice.scheduler import ApplicationServiceScheduler +from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.crypto.keyring import Keyring from synapse.events.builder import EventBuilderFactory @@ -97,6 +98,7 @@ from synapse.server_notices.worker_server_notices_sender import ( ) from synapse.state import StateHandler, StateResolutionHandler from synapse.storage import DataStores, Storage +from synapse.storage.engines import create_engine from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor @@ -209,16 +211,18 @@ class HomeServer(object): # instantiated during setup() for future return by get_datastore() DATASTORE_CLASS = abc.abstractproperty() - def __init__(self, hostname, reactor=None, **kwargs): + def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwargs): """ Args: hostname : The hostname for the server. + config: The full config for the homeserver. """ if not reactor: from twisted.internet import reactor self._reactor = reactor self.hostname = hostname + self.config = config self._building = {} self._listening_services = [] self.start_time = None @@ -229,6 +233,12 @@ class HomeServer(object): self.admin_redaction_ratelimiter = Ratelimiter() self.registration_ratelimiter = Ratelimiter() + self.database_engine = create_engine(config.database_config) + config.database_config.setdefault("args", {})[ + "cp_openfun" + ] = self.database_engine.on_new_connection + self.db_config = config.database_config + self.datastores = None # Other kwargs are explicit dependencies diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 139beef8ed..3e6d62eef1 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -32,6 +32,7 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.logging.utils import log_function from synapse.state import v1, v2 +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.util.async_helpers import Linearizer from synapse.util.caches import get_cache_factor_for from synapse.util.caches.expiringcache import ExpiringCache @@ -645,7 +646,7 @@ class StateResolutionStore(object): return self.store.get_events( event_ids, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, allow_rejected=allow_rejected, ) diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 320c5b0f07..add3037b69 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -451,16 +451,18 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): # Technically an access token might not be associated with # a device so we need to check. if device_id: - self.db.simple_upsert_txn( + # this is always an update rather than an upsert: the row should + # already exist, and if it doesn't, that may be because it has been + # deleted, and we don't want to re-create it. + self.db.simple_update_txn( txn, table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, - values={ + updatevalues={ "user_agent": user_agent, "last_seen": last_seen, "ip": ip, }, - lock=False, ) except Exception as e: # Failed to upsert, log and continue diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index 38cd0ca9b8..e551606f9d 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -14,15 +14,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, List + from six import iteritems from canonicaljson import encode_canonical_json, json +from twisted.enterprise.adbapi import Connection from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import cached, cachedList class EndToEndKeyWorkerStore(SQLBaseStore): @@ -271,7 +274,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): Args: txn (twisted.enterprise.adbapi.Connection): db connection user_id (str): the user whose key is being requested - key_type (str): the type of key that is being set: either 'master' + key_type (str): the type of key that is being requested: either 'master' for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key from_user_id (str): if specified, signatures made by this user on @@ -316,8 +319,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore): """Returns a user's cross-signing key. Args: - user_id (str): the user whose self-signing key is being requested - key_type (str): the type of cross-signing key to get + user_id (str): the user whose key is being requested + key_type (str): the type of key that is being requested: either 'master' + for a master key, 'self_signing' for a self-signing key, or + 'user_signing' for a user-signing key from_user_id (str): if specified, signatures made by this user on the self-signing key will be included in the result @@ -332,6 +337,206 @@ class EndToEndKeyWorkerStore(SQLBaseStore): from_user_id, ) + @cached(num_args=1) + def _get_bare_e2e_cross_signing_keys(self, user_id): + """Dummy function. Only used to make a cache for + _get_bare_e2e_cross_signing_keys_bulk. + """ + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_bare_e2e_cross_signing_keys", + list_name="user_ids", + num_args=1, + ) + def _get_bare_e2e_cross_signing_keys_bulk( + self, user_ids: List[str] + ) -> Dict[str, Dict[str, dict]]: + """Returns the cross-signing keys for a set of users. The output of this + function should be passed to _get_e2e_cross_signing_signatures_txn if + the signatures for the calling user need to be fetched. + + Args: + user_ids (list[str]): the users whose keys are being requested + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. If a user's cross-signing keys were not found, either + their user ID will not be in the dict, or their user ID will map + to None. + + """ + return self.db.runInteraction( + "get_bare_e2e_cross_signing_keys_bulk", + self._get_bare_e2e_cross_signing_keys_bulk_txn, + user_ids, + ) + + def _get_bare_e2e_cross_signing_keys_bulk_txn( + self, txn: Connection, user_ids: List[str], + ) -> Dict[str, Dict[str, dict]]: + """Returns the cross-signing keys for a set of users. The output of this + function should be passed to _get_e2e_cross_signing_signatures_txn if + the signatures for the calling user need to be fetched. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + user_ids (list[str]): the users whose keys are being requested + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. If a user's cross-signing keys were not found, their user + ID will not be in the dict. + + """ + result = {} + + batch_size = 100 + chunks = [ + user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size) + ] + for user_chunk in chunks: + sql = """ + SELECT k.user_id, k.keytype, k.keydata, k.stream_id + FROM e2e_cross_signing_keys k + INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id + FROM e2e_cross_signing_keys + GROUP BY user_id, keytype) s + USING (user_id, stream_id, keytype) + WHERE k.user_id IN (%s) + """ % ( + ",".join("?" for u in user_chunk), + ) + query_params = [] + query_params.extend(user_chunk) + + txn.execute(sql, query_params) + rows = self.db.cursor_to_dict(txn) + + for row in rows: + user_id = row["user_id"] + key_type = row["keytype"] + key = json.loads(row["keydata"]) + user_info = result.setdefault(user_id, {}) + user_info[key_type] = key + + return result + + def _get_e2e_cross_signing_signatures_txn( + self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str, + ) -> Dict[str, Dict[str, dict]]: + """Returns the cross-signing signatures made by a user on a set of keys. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + keys (dict[str, dict[str, dict]]): a map of user ID to key type to + key data. This dict will be modified to add signatures. + from_user_id (str): fetch the signatures made by this user + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. The return value will be the same as the keys argument, + with the modifications included. + """ + + # find out what cross-signing keys (a.k.a. devices) we need to get + # signatures for. This is a map of (user_id, device_id) to key type + # (device_id is the key's public part). + devices = {} + + for user_id, user_info in keys.items(): + if user_info is None: + continue + for key_type, key in user_info.items(): + device_id = None + for k in key["keys"].values(): + device_id = k + devices[(user_id, device_id)] = key_type + + device_list = list(devices) + + # split into batches + batch_size = 100 + chunks = [ + device_list[i : i + batch_size] + for i in range(0, len(device_list), batch_size) + ] + for user_chunk in chunks: + sql = """ + SELECT target_user_id, target_device_id, key_id, signature + FROM e2e_cross_signing_signatures + WHERE user_id = ? + AND (%s) + """ % ( + " OR ".join( + "(target_user_id = ? AND target_device_id = ?)" for d in devices + ) + ) + query_params = [from_user_id] + for item in devices: + # item is a (user_id, device_id) tuple + query_params.extend(item) + + txn.execute(sql, query_params) + rows = self.db.cursor_to_dict(txn) + + # and add the signatures to the appropriate keys + for row in rows: + key_id = row["key_id"] + target_user_id = row["target_user_id"] + target_device_id = row["target_device_id"] + key_type = devices[(target_user_id, target_device_id)] + # We need to copy everything, because the result may have come + # from the cache. dict.copy only does a shallow copy, so we + # need to recursively copy the dicts that will be modified. + user_info = keys[target_user_id] = keys[target_user_id].copy() + target_user_key = user_info[key_type] = user_info[key_type].copy() + if "signatures" in target_user_key: + signatures = target_user_key["signatures"] = target_user_key[ + "signatures" + ].copy() + if from_user_id in signatures: + user_sigs = signatures[from_user_id] = signatures[from_user_id] + user_sigs[key_id] = row["signature"] + else: + signatures[from_user_id] = {key_id: row["signature"]} + else: + target_user_key["signatures"] = { + from_user_id: {key_id: row["signature"]} + } + + return keys + + @defer.inlineCallbacks + def get_e2e_cross_signing_keys_bulk( + self, user_ids: List[str], from_user_id: str = None + ) -> defer.Deferred: + """Returns the cross-signing keys for a set of users. + + Args: + user_ids (list[str]): the users whose keys are being requested + from_user_id (str): if specified, signatures made by this user on + the self-signing keys will be included in the result + + Returns: + Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to + key data. If a user's cross-signing keys were not found, either + their user ID will not be in the dict, or their user ID will map + to None. + """ + + result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids) + + if from_user_id: + result = yield self.db.runInteraction( + "get_e2e_cross_signing_signatures", + self._get_e2e_cross_signing_signatures_txn, + result, + from_user_id, + ) + + return result + def get_all_user_signature_changes_for_remotes(self, from_key, to_key): """Return a list of changes from the user signature stream to notify remotes. Note that the user signature stream represents when a user signs their @@ -520,6 +725,10 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): }, ) + self._invalidate_cache_and_stream( + txn, self._get_bare_e2e_cross_signing_keys, (user_id,) + ) + def set_e2e_cross_signing_key(self, user_id, key_type, key): """Set a user's cross-signing key. diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 9ee117ce0f..2c9142814c 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -19,8 +19,10 @@ import itertools import logging import threading from collections import namedtuple +from typing import List, Optional from canonicaljson import json +from constantly import NamedConstant, Names from twisted.internet import defer @@ -55,6 +57,16 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) +class EventRedactBehaviour(Names): + """ + What to do when retrieving a redacted event from the database. + """ + + AS_IS = NamedConstant() + REDACT = NamedConstant() + BLOCK = NamedConstant() + + class EventsWorkerStore(SQLBaseStore): def __init__(self, database: Database, db_conn, hs): super(EventsWorkerStore, self).__init__(database, db_conn, hs) @@ -125,25 +137,27 @@ class EventsWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_event( self, - event_id, - check_redacted=True, - get_prev_content=False, - allow_rejected=False, - allow_none=False, - check_room_id=None, + event_id: List[str], + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: bool = False, + check_room_id: Optional[str] = None, ): """Get an event from the database by event_id. Args: - event_id (str): The event_id of the event to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, + event_id: The event_id of the event to fetch + redact_behaviour: Determine what to do with a redacted event. Possible values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events + get_prev_content: If True and event is a state event, include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. - allow_none (bool): If True, return None if no event found, if + allow_rejected: If True return rejected events. + allow_none: If True, return None if no event found, if False throw a NotFoundError - check_room_id (str|None): if not None, check the room of the found event. + check_room_id: if not None, check the room of the found event. If there is a mismatch, behave as per allow_none. Returns: @@ -154,7 +168,7 @@ class EventsWorkerStore(SQLBaseStore): events = yield self.get_events_as_list( [event_id], - check_redacted=check_redacted, + redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, allow_rejected=allow_rejected, ) @@ -173,27 +187,30 @@ class EventsWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_events( self, - event_ids, - check_redacted=True, - get_prev_content=False, - allow_rejected=False, + event_ids: List[str], + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, ): """Get events from the database Args: - event_ids (list): The event_ids of the events to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, + event_ids: The event_ids of the events to fetch + redact_behaviour: Determine what to do with a redacted event. Possible + values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events + get_prev_content: If True and event is a state event, include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. + allow_rejected: If True return rejected events. Returns: Deferred : Dict from event_id to event. """ events = yield self.get_events_as_list( event_ids, - check_redacted=check_redacted, + redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, allow_rejected=allow_rejected, ) @@ -203,21 +220,23 @@ class EventsWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_events_as_list( self, - event_ids, - check_redacted=True, - get_prev_content=False, - allow_rejected=False, + event_ids: List[str], + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, ): """Get events from the database and return in a list in the same order as given by `event_ids` arg. Args: - event_ids (list): The event_ids of the events to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, + event_ids: The event_ids of the events to fetch + redact_behaviour: Determine what to do with a redacted event. Possible values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events + get_prev_content: If True and event is a state event, include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. + allow_rejected: If True, return rejected events. Returns: Deferred[list[EventBase]]: List of events fetched from the database. The @@ -319,10 +338,14 @@ class EventsWorkerStore(SQLBaseStore): # Update the cache to save doing the checks again. entry.event.internal_metadata.recheck_redaction = False - if check_redacted and entry.redacted_event: - event = entry.redacted_event - else: - event = entry.event + event = entry.event + + if entry.redacted_event: + if redact_behaviour == EventRedactBehaviour.BLOCK: + # Skip this event + continue + elif redact_behaviour == EventRedactBehaviour.REDACT: + event = entry.redacted_event events.append(event) diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/data_stores/main/schema/full_schemas/README.md new file mode 100644 index 0000000000..bbd3f18604 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/full_schemas/README.md @@ -0,0 +1,13 @@ +# Building full schema dumps + +These schemas need to be made from a database that has had all background updates run. + +To do so, use `scripts-dev/make_full_schema.sh`. This will produce +`full.sql.postgres ` and `full.sql.sqlite` files. + +Ensure postgres is installed and your user has the ability to run bash commands +such as `createdb`. + +``` +./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/ +``` diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.txt b/synapse/storage/data_stores/main/schema/full_schemas/README.txt deleted file mode 100644 index d3f6401344..0000000000 --- a/synapse/storage/data_stores/main/schema/full_schemas/README.txt +++ /dev/null @@ -1,19 +0,0 @@ -Building full schema dumps -========================== - -These schemas need to be made from a database that has had all background updates run. - -Postgres --------- - -$ pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner $DATABASE_NAME| sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > full.sql.postgres - -SQLite ------- - -$ sqlite3 $DATABASE_FILE ".schema" > full.sql.sqlite - -After ------ - -Delete the CREATE statements for "sqlite_stat1", "schema_version", "applied_schema_deltas", and "applied_module_schemas". \ No newline at end of file diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index 4eec2fae5e..47ebb8a214 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -25,6 +25,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine, Sqlite3Engine @@ -384,7 +385,7 @@ class SearchStore(SearchBackgroundUpdateStore): """ clauses = [] - search_query = search_query = _parse_query(self.database_engine, search_term) + search_query = _parse_query(self.database_engine, search_term) args = [] @@ -453,7 +454,12 @@ class SearchStore(SearchBackgroundUpdateStore): results = list(filter(lambda row: row["room_id"] in room_ids, results)) - events = yield self.get_events_as_list([r["event_id"] for r in results]) + # We set redact_behaviour to BLOCK here to prevent redacted events being returned in + # search results (which is a data leak) + events = yield self.get_events_as_list( + [r["event_id"] for r in results], + redact_behaviour=EventRedactBehaviour.BLOCK, + ) event_map = {ev.event_id: ev for ev in events} @@ -495,7 +501,7 @@ class SearchStore(SearchBackgroundUpdateStore): """ clauses = [] - search_query = search_query = _parse_query(self.database_engine, search_term) + search_query = _parse_query(self.database_engine, search_term) args = [] @@ -600,7 +606,12 @@ class SearchStore(SearchBackgroundUpdateStore): results = list(filter(lambda row: row["room_id"] in room_ids, results)) - events = yield self.get_events_as_list([r["event_id"] for r in results]) + # We set redact_behaviour to BLOCK here to prevent redacted events being returned in + # search results (which is a data leak) + events = yield self.get_events_as_list( + [r["event_id"] for r in results], + redact_behaviour=EventRedactBehaviour.BLOCK, + ) event_map = {ev.event_id: ev for ev in events} diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 9ef7b48c74..dcc6b43cdf 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -278,7 +278,7 @@ class StateGroupWorkerStore( @defer.inlineCallbacks def get_room_predecessor(self, room_id): - """Get the predecessor room of an upgraded room if one exists. + """Get the predecessor of an upgraded room if it exists. Otherwise return None. Args: @@ -291,14 +291,22 @@ class StateGroupWorkerStore( * room_id (str): The room ID of the predecessor room * event_id (str): The ID of the tombstone event in the predecessor room + None if a predecessor key is not found, or is not a dictionary. + Raises: - NotFoundError if the room is unknown + NotFoundError if the given room is unknown """ # Retrieve the room's create event create_event = yield self.get_create_event_for_room(room_id) - # Return predecessor if present - return create_event.content.get("predecessor", None) + # Retrieve the predecessor key of the create event + predecessor = create_event.content.get("predecessor", None) + + # Ensure the key is a dictionary + if not isinstance(predecessor, dict): + return None + + return predecessor @defer.inlineCallbacks def get_create_event_for_room(self, room_id): @@ -318,7 +326,7 @@ class StateGroupWorkerStore( # If we can't find the create event, assume we've hit a dead end if not create_id: - raise NotFoundError("Unknown room %s" % (room_id)) + raise NotFoundError("Unknown room %s" % (room_id,)) # Retrieve the room's create event and return create_event = yield self.get_event(create_id) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 84f5ae22c3..2e8f6543e5 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -271,7 +271,7 @@ class _CacheDescriptorBase(object): else: self.function_to_call = orig - arg_spec = inspect.getargspec(orig) + arg_spec = inspect.getfullargspec(orig) all_args = arg_spec.args if "cache_context" in all_args: diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py deleted file mode 100644 index 8318db8d2c..0000000000 --- a/synapse/util/caches/snapshot_cache.py +++ /dev/null @@ -1,94 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.util.async_helpers import ObservableDeferred - - -class SnapshotCache(object): - """Cache for snapshots like the response of /initialSync. - The response of initialSync only has to be a recent snapshot of the - server state. It shouldn't matter to clients if it is a few minutes out - of date. - - This caches a deferred response. Until the deferred completes it will be - returned from the cache. This means that if the client retries the request - while the response is still being computed, that original response will be - used rather than trying to compute a new response. - - Once the deferred completes it will removed from the cache after 5 minutes. - We delay removing it from the cache because a client retrying its request - could race with us finishing computing the response. - - Rather than tracking precisely how long something has been in the cache we - keep two generations of completed responses. Every 5 minutes discard the - old generation, move the new generation to the old generation, and set the - new generation to be empty. This means that a result will be in the cache - somewhere between 5 and 10 minutes. - """ - - DURATION_MS = 5 * 60 * 1000 # Cache results for 5 minutes. - - def __init__(self): - self.pending_result_cache = {} # Request that haven't finished yet. - self.prev_result_cache = {} # The older requests that have finished. - self.next_result_cache = {} # The newer requests that have finished. - self.time_last_rotated_ms = 0 - - def rotate(self, time_now_ms): - # Rotate once if the cache duration has passed since the last rotation. - if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS: - self.prev_result_cache = self.next_result_cache - self.next_result_cache = {} - self.time_last_rotated_ms += self.DURATION_MS - - # Rotate again if the cache duration has passed twice since the last - # rotation. - if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS: - self.prev_result_cache = self.next_result_cache - self.next_result_cache = {} - self.time_last_rotated_ms = time_now_ms - - def get(self, time_now_ms, key): - self.rotate(time_now_ms) - # This cache is intended to deduplicate requests, so we expect it to be - # missed most of the time. So we just lookup the key in all of the - # dictionaries rather than trying to short circuit the lookup if the - # key is found. - result = self.prev_result_cache.get(key) - result = self.next_result_cache.get(key, result) - result = self.pending_result_cache.get(key, result) - if result is not None: - return result.observe() - else: - return None - - def set(self, time_now_ms, key, deferred): - self.rotate(time_now_ms) - - result = ObservableDeferred(deferred) - - self.pending_result_cache[key] = result - - def shuffle_along(r): - # When the deferred completes we shuffle it along to the first - # generation of the result cache. So that it will eventually - # expire from the rotation of that cache. - self.next_result_cache[key] = result - self.pending_result_cache.pop(key, None) - return r - - result.addBoth(shuffle_along) - - return result.observe() |