diff options
29 files changed, 653 insertions, 416 deletions
diff --git a/changelog.d/6411.feature b/changelog.d/6411.feature new file mode 100644 index 0000000000..ebea4a208d --- /dev/null +++ b/changelog.d/6411.feature @@ -0,0 +1 @@ +Allow custom SAML username mapping functinality through an external provider plugin. \ No newline at end of file diff --git a/changelog.d/6499.bugfix b/changelog.d/6499.bugfix new file mode 100644 index 0000000000..299feba0f8 --- /dev/null +++ b/changelog.d/6499.bugfix @@ -0,0 +1 @@ +Fix support for SQLite 3.7. diff --git a/changelog.d/6501.misc b/changelog.d/6501.misc new file mode 100644 index 0000000000..255f45a9c3 --- /dev/null +++ b/changelog.d/6501.misc @@ -0,0 +1 @@ +Refactor get_events_from_store_or_dest to return a dict. diff --git a/changelog.d/6502.removal b/changelog.d/6502.removal new file mode 100644 index 0000000000..0b72261d58 --- /dev/null +++ b/changelog.d/6502.removal @@ -0,0 +1 @@ +Remove redundant code from event authorisation implementation. diff --git a/changelog.d/6503.misc b/changelog.d/6503.misc new file mode 100644 index 0000000000..e4e9a5a3d4 --- /dev/null +++ b/changelog.d/6503.misc @@ -0,0 +1 @@ +Move get_state methods into FederationHandler. diff --git a/changelog.d/6505.misc b/changelog.d/6505.misc new file mode 100644 index 0000000000..3a75b2d9dd --- /dev/null +++ b/changelog.d/6505.misc @@ -0,0 +1 @@ +Make `make_deferred_yieldable` to work with async/await. diff --git a/changelog.d/6506.misc b/changelog.d/6506.misc new file mode 100644 index 0000000000..99d7a70bcf --- /dev/null +++ b/changelog.d/6506.misc @@ -0,0 +1 @@ +Remove `SnapshotCache` in favour of `ResponseCache`. diff --git a/changelog.d/6510.misc b/changelog.d/6510.misc new file mode 100644 index 0000000000..214f06539b --- /dev/null +++ b/changelog.d/6510.misc @@ -0,0 +1 @@ +Change phone home stats to not assume there is a single database and report information about the database used by the main data store. diff --git a/changelog.d/6514.bugfix b/changelog.d/6514.bugfix new file mode 100644 index 0000000000..6dc1985c24 --- /dev/null +++ b/changelog.d/6514.bugfix @@ -0,0 +1 @@ +Fix race which occasionally caused deleted devices to reappear. diff --git a/docs/saml_mapping_providers.md b/docs/saml_mapping_providers.md new file mode 100644 index 0000000000..92f2380488 --- /dev/null +++ b/docs/saml_mapping_providers.md @@ -0,0 +1,77 @@ +# SAML Mapping Providers + +A SAML mapping provider is a Python class (loaded via a Python module) that +works out how to map attributes of a SAML response object to Matrix-specific +user attributes. Details such as user ID localpart, displayname, and even avatar +URLs are all things that can be mapped from talking to a SSO service. + +As an example, a SSO service may return the email address +"john.smith@example.com" for a user, whereas Synapse will need to figure out how +to turn that into a displayname when creating a Matrix user for this individual. +It may choose `John Smith`, or `Smith, John [Example.com]` or any number of +variations. As each Synapse configuration may want something different, this is +where SAML mapping providers come into play. + +## Enabling Providers + +External mapping providers are provided to Synapse in the form of an external +Python module. Retrieve this module from [PyPi](https://pypi.org) or elsewhere, +then tell Synapse where to look for the handler class by editing the +`saml2_config.user_mapping_provider.module` config option. + +`saml2_config.user_mapping_provider.config` allows you to provide custom +configuration options to the module. Check with the module's documentation for +what options it provides (if any). The options listed by default are for the +user mapping provider built in to Synapse. If using a custom module, you should +comment these options out and use those specified by the module instead. + +## Building a Custom Mapping Provider + +A custom mapping provider must specify the following methods: + +* `__init__(self, parsed_config)` + - Arguments: + - `parsed_config` - A configuration object that is the return value of the + `parse_config` method. You should set any configuration options needed by + the module here. +* `saml_response_to_user_attributes(self, saml_response, failures)` + - Arguments: + - `saml_response` - A `saml2.response.AuthnResponse` object to extract user + information from. + - `failures` - An `int` that represents the amount of times the returned + mxid localpart mapping has failed. This should be used + to create a deduplicated mxid localpart which should be + returned instead. For example, if this method returns + `john.doe` as the value of `mxid_localpart` in the returned + dict, and that is already taken on the homeserver, this + method will be called again with the same parameters but + with failures=1. The method should then return a different + `mxid_localpart` value, such as `john.doe1`. + - This method must return a dictionary, which will then be used by Synapse + to build a new user. The following keys are allowed: + * `mxid_localpart` - Required. The mxid localpart of the new user. + * `displayname` - The displayname of the new user. If not provided, will default to + the value of `mxid_localpart`. +* `parse_config(config)` + - This method should have the `@staticmethod` decoration. + - Arguments: + - `config` - A `dict` representing the parsed content of the + `saml2_config.user_mapping_provider.config` homeserver config option. + Runs on homeserver startup. Providers should extract any option values + they need here. + - Whatever is returned will be passed back to the user mapping provider module's + `__init__` method during construction. +* `get_saml_attributes(config)` + - This method should have the `@staticmethod` decoration. + - Arguments: + - `config` - A object resulting from a call to `parse_config`. + - Returns a tuple of two sets. The first set equates to the saml auth + response attributes that are required for the module to function, whereas + the second set consists of those attributes which can be used if available, + but are not necessary. + +## Synapse's Default Provider + +Synapse has a built-in SAML mapping provider if a custom provider isn't +specified in the config. It is located at +[`synapse.handlers.saml_handler.DefaultSamlMappingProvider`](../synapse/handlers/saml_handler.py). diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 10664ae8f7..4d44e631d1 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1250,33 +1250,58 @@ saml2_config: # #config_path: "CONFDIR/sp_conf.py" - # the lifetime of a SAML session. This defines how long a user has to + # The lifetime of a SAML session. This defines how long a user has to # complete the authentication process, if allow_unsolicited is unset. # The default is 5 minutes. # #saml_session_lifetime: 5m - # The SAML attribute (after mapping via the attribute maps) to use to derive - # the Matrix ID from. 'uid' by default. + # An external module can be provided here as a custom solution to + # mapping attributes returned from a saml provider onto a matrix user. # - #mxid_source_attribute: displayName - - # The mapping system to use for mapping the saml attribute onto a matrix ID. - # Options include: - # * 'hexencode' (which maps unpermitted characters to '=xx') - # * 'dotreplace' (which replaces unpermitted characters with '.'). - # The default is 'hexencode'. - # - #mxid_mapping: dotreplace + user_mapping_provider: + # The custom module's class. Uncomment to use a custom module. + # + #module: mapping_provider.SamlMappingProvider - # In previous versions of synapse, the mapping from SAML attribute to MXID was - # always calculated dynamically rather than stored in a table. For backwards- - # compatibility, we will look for user_ids matching such a pattern before - # creating a new account. + # Custom configuration values for the module. Below options are + # intended for the built-in provider, they should be changed if + # using a custom module. This section will be passed as a Python + # dictionary to the module's `parse_config` method. + # + config: + # The SAML attribute (after mapping via the attribute maps) to use + # to derive the Matrix ID from. 'uid' by default. + # + # Note: This used to be configured by the + # saml2_config.mxid_source_attribute option. If that is still + # defined, its value will be used instead. + # + #mxid_source_attribute: displayName + + # The mapping system to use for mapping the saml attribute onto a + # matrix ID. + # + # Options include: + # * 'hexencode' (which maps unpermitted characters to '=xx') + # * 'dotreplace' (which replaces unpermitted characters with + # '.'). + # The default is 'hexencode'. + # + # Note: This used to be configured by the + # saml2_config.mxid_mapping option. If that is still defined, its + # value will be used instead. + # + #mxid_mapping: dotreplace + + # In previous versions of synapse, the mapping from SAML attribute to + # MXID was always calculated dynamically rather than stored in a + # table. For backwards- compatibility, we will look for user_ids + # matching such a pattern before creating a new account. # # This setting controls the SAML attribute which will be used for this - # backwards-compatibility lookup. Typically it should be 'uid', but if the - # attribute maps are changed, it may be necessary to change it. + # backwards-compatibility lookup. Typically it should be 'uid', but if + # the attribute maps are changed, it may be necessary to change it. # # The default is 'uid'. # diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index df65d0a989..032010600a 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -519,8 +519,10 @@ def phone_stats_home(hs, stats, stats_process=_stats_process): # Database version # - stats["database_engine"] = hs.database_engine.module.__name__ - stats["database_server_version"] = hs.database_engine.server_version + # This only reports info about the *main* database. + stats["database_engine"] = hs.get_datastore().db.engine.module.__name__ + stats["database_server_version"] = hs.get_datastore().db.engine.server_version + logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) try: yield hs.get_proxied_http_client().put_json( diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index c5ea2d43a1..b91414aa35 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -14,17 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re +import logging from synapse.python_dependencies import DependencyException, check_requirements -from synapse.types import ( - map_username_to_mxid_localpart, - mxid_localpart_allowed_characters, -) -from synapse.util.module_loader import load_python_module +from synapse.util.module_loader import load_module, load_python_module from ._base import Config, ConfigError +logger = logging.getLogger(__name__) + +DEFAULT_USER_MAPPING_PROVIDER = ( + "synapse.handlers.saml_handler.DefaultSamlMappingProvider" +) + def _dict_merge(merge_dict, into_dict): """Do a deep merge of two dicts @@ -75,15 +77,69 @@ class SAML2Config(Config): self.saml2_enabled = True - self.saml2_mxid_source_attribute = saml2_config.get( - "mxid_source_attribute", "uid" - ) - self.saml2_grandfathered_mxid_source_attribute = saml2_config.get( "grandfathered_mxid_source_attribute", "uid" ) - saml2_config_dict = self._default_saml_config_dict() + # user_mapping_provider may be None if the key is present but has no value + ump_dict = saml2_config.get("user_mapping_provider") or {} + + # Use the default user mapping provider if not set + ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) + + # Ensure a config is present + ump_dict["config"] = ump_dict.get("config") or {} + + if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER: + # Load deprecated options for use by the default module + old_mxid_source_attribute = saml2_config.get("mxid_source_attribute") + if old_mxid_source_attribute: + logger.warning( + "The config option saml2_config.mxid_source_attribute is deprecated. " + "Please use saml2_config.user_mapping_provider.config" + ".mxid_source_attribute instead." + ) + ump_dict["config"]["mxid_source_attribute"] = old_mxid_source_attribute + + old_mxid_mapping = saml2_config.get("mxid_mapping") + if old_mxid_mapping: + logger.warning( + "The config option saml2_config.mxid_mapping is deprecated. Please " + "use saml2_config.user_mapping_provider.config.mxid_mapping instead." + ) + ump_dict["config"]["mxid_mapping"] = old_mxid_mapping + + # Retrieve an instance of the module's class + # Pass the config dictionary to the module for processing + ( + self.saml2_user_mapping_provider_class, + self.saml2_user_mapping_provider_config, + ) = load_module(ump_dict) + + # Ensure loaded user mapping module has defined all necessary methods + # Note parse_config() is already checked during the call to load_module + required_methods = [ + "get_saml_attributes", + "saml_response_to_user_attributes", + ] + missing_methods = [ + method + for method in required_methods + if not hasattr(self.saml2_user_mapping_provider_class, method) + ] + if missing_methods: + raise ConfigError( + "Class specified by saml2_config." + "user_mapping_provider.module is missing required " + "methods: %s" % (", ".join(missing_methods),) + ) + + # Get the desired saml auth response attributes from the module + saml2_config_dict = self._default_saml_config_dict( + *self.saml2_user_mapping_provider_class.get_saml_attributes( + self.saml2_user_mapping_provider_config + ) + ) _dict_merge( merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict ) @@ -103,22 +159,27 @@ class SAML2Config(Config): saml2_config.get("saml_session_lifetime", "5m") ) - mapping = saml2_config.get("mxid_mapping", "hexencode") - try: - self.saml2_mxid_mapper = MXID_MAPPER_MAP[mapping] - except KeyError: - raise ConfigError("%s is not a known mxid_mapping" % (mapping,)) - - def _default_saml_config_dict(self): + def _default_saml_config_dict( + self, required_attributes: set, optional_attributes: set + ): + """Generate a configuration dictionary with required and optional attributes that + will be needed to process new user registration + + Args: + required_attributes: SAML auth response attributes that are + necessary to function + optional_attributes: SAML auth response attributes that can be used to add + additional information to Synapse user accounts, but are not required + + Returns: + dict: A SAML configuration dictionary + """ import saml2 public_baseurl = self.public_baseurl if public_baseurl is None: raise ConfigError("saml2_config requires a public_baseurl to be set") - required_attributes = {"uid", self.saml2_mxid_source_attribute} - - optional_attributes = {"displayName"} if self.saml2_grandfathered_mxid_source_attribute: optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute) optional_attributes -= required_attributes @@ -207,33 +268,58 @@ class SAML2Config(Config): # #config_path: "%(config_dir_path)s/sp_conf.py" - # the lifetime of a SAML session. This defines how long a user has to + # The lifetime of a SAML session. This defines how long a user has to # complete the authentication process, if allow_unsolicited is unset. # The default is 5 minutes. # #saml_session_lifetime: 5m - # The SAML attribute (after mapping via the attribute maps) to use to derive - # the Matrix ID from. 'uid' by default. + # An external module can be provided here as a custom solution to + # mapping attributes returned from a saml provider onto a matrix user. # - #mxid_source_attribute: displayName - - # The mapping system to use for mapping the saml attribute onto a matrix ID. - # Options include: - # * 'hexencode' (which maps unpermitted characters to '=xx') - # * 'dotreplace' (which replaces unpermitted characters with '.'). - # The default is 'hexencode'. - # - #mxid_mapping: dotreplace - - # In previous versions of synapse, the mapping from SAML attribute to MXID was - # always calculated dynamically rather than stored in a table. For backwards- - # compatibility, we will look for user_ids matching such a pattern before - # creating a new account. + user_mapping_provider: + # The custom module's class. Uncomment to use a custom module. + # + #module: mapping_provider.SamlMappingProvider + + # Custom configuration values for the module. Below options are + # intended for the built-in provider, they should be changed if + # using a custom module. This section will be passed as a Python + # dictionary to the module's `parse_config` method. + # + config: + # The SAML attribute (after mapping via the attribute maps) to use + # to derive the Matrix ID from. 'uid' by default. + # + # Note: This used to be configured by the + # saml2_config.mxid_source_attribute option. If that is still + # defined, its value will be used instead. + # + #mxid_source_attribute: displayName + + # The mapping system to use for mapping the saml attribute onto a + # matrix ID. + # + # Options include: + # * 'hexencode' (which maps unpermitted characters to '=xx') + # * 'dotreplace' (which replaces unpermitted characters with + # '.'). + # The default is 'hexencode'. + # + # Note: This used to be configured by the + # saml2_config.mxid_mapping option. If that is still defined, its + # value will be used instead. + # + #mxid_mapping: dotreplace + + # In previous versions of synapse, the mapping from SAML attribute to + # MXID was always calculated dynamically rather than stored in a + # table. For backwards- compatibility, we will look for user_ids + # matching such a pattern before creating a new account. # # This setting controls the SAML attribute which will be used for this - # backwards-compatibility lookup. Typically it should be 'uid', but if the - # attribute maps are changed, it may be necessary to change it. + # backwards-compatibility lookup. Typically it should be 'uid', but if + # the attribute maps are changed, it may be necessary to change it. # # The default is 'uid'. # @@ -241,23 +327,3 @@ class SAML2Config(Config): """ % { "config_dir_path": config_dir_path } - - -DOT_REPLACE_PATTERN = re.compile( - ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) -) - - -def dot_replace_for_mxid(username: str) -> str: - username = username.lower() - username = DOT_REPLACE_PATTERN.sub(".", username) - - # regular mxids aren't allowed to start with an underscore either - username = re.sub("^_", "", username) - return username - - -MXID_MAPPER_MAP = { - "hexencode": map_username_to_mxid_localpart, - "dotreplace": dot_replace_for_mxid, -} diff --git a/synapse/event_auth.py b/synapse/event_auth.py index ec3243b27b..c940b84470 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -42,6 +42,8 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru Returns: if the auth checks pass. """ + assert isinstance(auth_events, dict) + if do_size_check: _check_size_limits(event) @@ -74,12 +76,6 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru if not event.signatures.get(event_id_domain): raise AuthError(403, "Event not signed by sending server") - if auth_events is None: - # Oh, we don't know what the state of the room was, so we - # are trusting that this is allowed (at least for now) - logger.warning("Trusting event: %s", event.event_id) - return - if event.type == EventTypes.Create: sender_domain = get_domain_from_id(event.sender) room_id_domain = get_domain_from_id(event.room_id) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 709449c9e3..d396e6564f 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -18,8 +18,6 @@ import copy import itertools import logging -from six.moves import range - from prometheus_client import Counter from twisted.internet import defer @@ -39,7 +37,7 @@ from synapse.api.room_versions import ( ) from synapse.events import builder, room_version_to_event_format from synapse.federation.federation_base import FederationBase, event_from_pdu_json -from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.logging.context import make_deferred_yieldable from synapse.logging.utils import log_function from synapse.util import unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache @@ -310,19 +308,12 @@ class FederationClient(FederationBase): return signed_pdu @defer.inlineCallbacks - @log_function - def get_state_for_room(self, destination, room_id, event_id): - """Requests all of the room state at a given event from a remote homeserver. - - Args: - destination (str): The remote homeserver to query for the state. - room_id (str): The id of the room we're interested in. - event_id (str): The id of the event we want the state at. + def get_room_state_ids(self, destination: str, room_id: str, event_id: str): + """Calls the /state_ids endpoint to fetch the state at a particular point + in the room, and the auth events for the given event Returns: - Deferred[Tuple[List[EventBase], List[EventBase]]]: - A list of events in the state, and a list of events in the auth chain - for the given event. + Tuple[List[str], List[str]]: a tuple of (state event_ids, auth event_ids) """ result = yield self.transport_layer.get_room_state_ids( destination, room_id, event_id=event_id @@ -331,86 +322,12 @@ class FederationClient(FederationBase): state_event_ids = result["pdu_ids"] auth_event_ids = result.get("auth_chain_ids", []) - fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest( - destination, room_id, set(state_event_ids + auth_event_ids) - ) - - if failed_to_fetch: - logger.warning( - "Failed to fetch missing state/auth events for %s: %s", - room_id, - failed_to_fetch, - ) - - event_map = {ev.event_id: ev for ev in fetched_events} - - pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] - auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] - - auth_chain.sort(key=lambda e: e.depth) - - return pdus, auth_chain - - @defer.inlineCallbacks - def get_events_from_store_or_dest(self, destination, room_id, event_ids): - """Fetch events from a remote destination, checking if we already have them. - - Args: - destination (str) - room_id (str) - event_ids (list) - - Returns: - Deferred: A deferred resolving to a 2-tuple where the first is a list of - events and the second is a list of event ids that we failed to fetch. - """ - seen_events = yield self.store.get_events(event_ids, allow_rejected=True) - signed_events = list(seen_events.values()) - - failed_to_fetch = set() - - missing_events = set(event_ids) - for k in seen_events: - missing_events.discard(k) - - if not missing_events: - return signed_events, failed_to_fetch - - logger.debug( - "Fetching unknown state/auth events %s for room %s", - missing_events, - event_ids, - ) - - room_version = yield self.store.get_room_version(room_id) - - batch_size = 20 - missing_events = list(missing_events) - for i in range(0, len(missing_events), batch_size): - batch = set(missing_events[i : i + batch_size]) - - deferreds = [ - run_in_background( - self.get_pdu, - destinations=[destination], - event_id=e_id, - room_version=room_version, - ) - for e_id in batch - ] - - res = yield make_deferred_yieldable( - defer.DeferredList(deferreds, consumeErrors=True) - ) - for success, result in res: - if success and result: - signed_events.append(result) - batch.discard(result.event_id) - - # We removed all events we successfully fetched from `batch` - failed_to_fetch.update(batch) + if not isinstance(state_event_ids, list) or not isinstance( + auth_event_ids, list + ): + raise Exception("invalid response from /state_ids") - return signed_events, failed_to_fetch + return state_event_ids, auth_event_ids @defer.inlineCallbacks @log_function diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index bc26921768..c0dcf9abf8 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -64,7 +64,7 @@ from synapse.replication.http.federation import ( from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import StateResolutionStore, resolve_events_with_store from synapse.types import UserID, get_domain_from_id -from synapse.util import unwrapFirstError +from synapse.util import batch_iter, unwrapFirstError from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room from synapse.util.retryutils import NotRetryingDestination @@ -379,11 +379,9 @@ class FederationHandler(BaseHandler): ( remote_state, got_auth_chain, - ) = yield self.federation_client.get_state_for_room( - origin, room_id, p - ) + ) = yield self._get_state_for_room(origin, room_id, p) - # we want the state *after* p; get_state_for_room returns the + # we want the state *after* p; _get_state_for_room returns the # state *before* p. remote_event = yield self.federation_client.get_pdu( [origin], p, room_version, outlier=True @@ -584,6 +582,97 @@ class FederationHandler(BaseHandler): raise @defer.inlineCallbacks + @log_function + def _get_state_for_room(self, destination, room_id, event_id): + """Requests all of the room state at a given event from a remote homeserver. + + Args: + destination (str): The remote homeserver to query for the state. + room_id (str): The id of the room we're interested in. + event_id (str): The id of the event we want the state at. + + Returns: + Deferred[Tuple[List[EventBase], List[EventBase]]]: + A list of events in the state, and a list of events in the auth chain + for the given event. + """ + ( + state_event_ids, + auth_event_ids, + ) = yield self.federation_client.get_room_state_ids( + destination, room_id, event_id=event_id + ) + + desired_events = set(state_event_ids + auth_event_ids) + event_map = yield self._get_events_from_store_or_dest( + destination, room_id, desired_events + ) + + failed_to_fetch = desired_events - event_map.keys() + if failed_to_fetch: + logger.warning( + "Failed to fetch missing state/auth events for %s: %s", + room_id, + failed_to_fetch, + ) + + pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] + auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] + + auth_chain.sort(key=lambda e: e.depth) + + return pdus, auth_chain + + @defer.inlineCallbacks + def _get_events_from_store_or_dest(self, destination, room_id, event_ids): + """Fetch events from a remote destination, checking if we already have them. + + Args: + destination (str) + room_id (str) + event_ids (Iterable[str]) + + Returns: + Deferred[dict[str, EventBase]]: A deferred resolving to a map + from event_id to event + """ + fetched_events = yield self.store.get_events(event_ids, allow_rejected=True) + + missing_events = set(event_ids) - fetched_events.keys() + + if not missing_events: + return fetched_events + + logger.debug( + "Fetching unknown state/auth events %s for room %s", + missing_events, + event_ids, + ) + + room_version = yield self.store.get_room_version(room_id) + + # XXX 20 requests at once? really? + for batch in batch_iter(missing_events, 20): + deferreds = [ + run_in_background( + self.federation_client.get_pdu, + destinations=[destination], + event_id=e_id, + room_version=room_version, + ) + for e_id in batch + ] + + res = yield make_deferred_yieldable( + defer.DeferredList(deferreds, consumeErrors=True) + ) + for success, result in res: + if success and result: + fetched_events[result.event_id] = result + + return fetched_events + + @defer.inlineCallbacks def _process_received_pdu(self, origin, event, state, auth_chain): """ Called when we have a new pdu. We need to do auth checks and put it through the StateHandler. @@ -723,7 +812,7 @@ class FederationHandler(BaseHandler): state_events = {} events_to_state = {} for e_id in edges: - state, auth = yield self.federation_client.get_state_for_room( + state, auth = yield self._get_state_for_room( destination=dest, room_id=room_id, event_id=e_id ) auth_events.update({a.event_id: a for a in auth}) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 81dce96f4b..73c110a92b 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -26,7 +26,7 @@ from synapse.streams.config import PaginationConfig from synapse.types import StreamToken, UserID from synapse.util import unwrapFirstError from synapse.util.async_helpers import concurrently_execute -from synapse.util.caches.snapshot_cache import SnapshotCache +from synapse.util.caches.response_cache import ResponseCache from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -41,7 +41,7 @@ class InitialSyncHandler(BaseHandler): self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() - self.snapshot_cache = SnapshotCache() + self.snapshot_cache = ResponseCache(hs, "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() self.state_store = self.storage.state @@ -79,17 +79,14 @@ class InitialSyncHandler(BaseHandler): as_client_event, include_archived, ) - now_ms = self.clock.time_msec() - result = self.snapshot_cache.get(now_ms, key) - if result is not None: - return result - return self.snapshot_cache.set( - now_ms, + return self.snapshot_cache.wrap( key, - self._snapshot_all_rooms( - user_id, pagin_config, as_client_event, include_archived - ), + self._snapshot_all_rooms, + user_id, + pagin_config, + as_client_event, + include_archived, ) @defer.inlineCallbacks diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index cc9e6b9bd0..0082f85c26 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -13,20 +13,36 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import re +from typing import Tuple import attr import saml2 +import saml2.response from saml2.client import Saml2Client from synapse.api.errors import SynapseError +from synapse.config import ConfigError from synapse.http.servlet import parse_string from synapse.rest.client.v1.login import SSOAuthHandler -from synapse.types import UserID, map_username_to_mxid_localpart +from synapse.types import ( + UserID, + map_username_to_mxid_localpart, + mxid_localpart_allowed_characters, +) from synapse.util.async_helpers import Linearizer logger = logging.getLogger(__name__) +@attr.s +class Saml2SessionData: + """Data we track about SAML2 sessions""" + + # time the session was created, in milliseconds + creation_time = attr.ib() + + class SamlHandler: def __init__(self, hs): self._saml_client = Saml2Client(hs.config.saml2_sp_config) @@ -37,11 +53,14 @@ class SamlHandler: self._datastore = hs.get_datastore() self._hostname = hs.hostname self._saml2_session_lifetime = hs.config.saml2_session_lifetime - self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute self._grandfathered_mxid_source_attribute = ( hs.config.saml2_grandfathered_mxid_source_attribute ) - self._mxid_mapper = hs.config.saml2_mxid_mapper + + # plugin to do custom mapping from saml response to mxid + self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class( + hs.config.saml2_user_mapping_provider_config + ) # identifier for the external_ids table self._auth_provider_id = "saml" @@ -118,22 +137,10 @@ class SamlHandler: remote_user_id = saml2_auth.ava["uid"][0] except KeyError: logger.warning("SAML2 response lacks a 'uid' attestation") - raise SynapseError(400, "uid not in SAML2 response") - - try: - mxid_source = saml2_auth.ava[self._mxid_source_attribute][0] - except KeyError: - logger.warning( - "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute - ) - raise SynapseError( - 400, "%s not in SAML2 response" % (self._mxid_source_attribute,) - ) + raise SynapseError(400, "'uid' not in SAML2 response") self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None) - displayName = saml2_auth.ava.get("displayName", [None])[0] - with (await self._mapping_lock.queue(self._auth_provider_id)): # first of all, check if we already have a mapping for this user logger.info( @@ -173,22 +180,46 @@ class SamlHandler: ) return registered_user_id - # figure out a new mxid for this user - base_mxid_localpart = self._mxid_mapper(mxid_source) + # Map saml response to user attributes using the configured mapping provider + for i in range(1000): + attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes( + saml2_auth, i + ) + + logger.debug( + "Retrieved SAML attributes from user mapping provider: %s " + "(attempt %d)", + attribute_dict, + i, + ) + + localpart = attribute_dict.get("mxid_localpart") + if not localpart: + logger.error( + "SAML mapping provider plugin did not return a " + "mxid_localpart object" + ) + raise SynapseError(500, "Error parsing SAML2 response") - suffix = 0 - while True: - localpart = base_mxid_localpart + (str(suffix) if suffix else "") + displayname = attribute_dict.get("displayname") + + # Check if this mxid already exists if not await self._datastore.get_users_by_id_case_insensitive( UserID(localpart, self._hostname).to_string() ): + # This mxid is free break - suffix += 1 - logger.info("Allocating mxid for new user with localpart %s", localpart) + else: + # Unable to generate a username in 1000 iterations + # Break and return error to the user + raise SynapseError( + 500, "Unable to generate a Matrix ID from the SAML response" + ) registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=displayName + localpart=localpart, default_display_name=displayname ) + await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id ) @@ -205,9 +236,120 @@ class SamlHandler: del self._outstanding_requests_dict[reqid] +DOT_REPLACE_PATTERN = re.compile( + ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) +) + + +def dot_replace_for_mxid(username: str) -> str: + username = username.lower() + username = DOT_REPLACE_PATTERN.sub(".", username) + + # regular mxids aren't allowed to start with an underscore either + username = re.sub("^_", "", username) + return username + + +MXID_MAPPER_MAP = { + "hexencode": map_username_to_mxid_localpart, + "dotreplace": dot_replace_for_mxid, +} + + @attr.s -class Saml2SessionData: - """Data we track about SAML2 sessions""" +class SamlConfig(object): + mxid_source_attribute = attr.ib() + mxid_mapper = attr.ib() - # time the session was created, in milliseconds - creation_time = attr.ib() + +class DefaultSamlMappingProvider(object): + __version__ = "0.0.1" + + def __init__(self, parsed_config: SamlConfig): + """The default SAML user mapping provider + + Args: + parsed_config: Module configuration + """ + self._mxid_source_attribute = parsed_config.mxid_source_attribute + self._mxid_mapper = parsed_config.mxid_mapper + + def saml_response_to_user_attributes( + self, saml_response: saml2.response.AuthnResponse, failures: int = 0, + ) -> dict: + """Maps some text from a SAML response to attributes of a new user + + Args: + saml_response: A SAML auth response object + + failures: How many times a call to this function with this + saml_response has resulted in a failure + + Returns: + dict: A dict containing new user attributes. Possible keys: + * mxid_localpart (str): Required. The localpart of the user's mxid + * displayname (str): The displayname of the user + """ + try: + mxid_source = saml_response.ava[self._mxid_source_attribute][0] + except KeyError: + logger.warning( + "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute, + ) + raise SynapseError( + 400, "%s not in SAML2 response" % (self._mxid_source_attribute,) + ) + + # Use the configured mapper for this mxid_source + base_mxid_localpart = self._mxid_mapper(mxid_source) + + # Append suffix integer if last call to this function failed to produce + # a usable mxid + localpart = base_mxid_localpart + (str(failures) if failures else "") + + # Retrieve the display name from the saml response + # If displayname is None, the mxid_localpart will be used instead + displayname = saml_response.ava.get("displayName", [None])[0] + + return { + "mxid_localpart": localpart, + "displayname": displayname, + } + + @staticmethod + def parse_config(config: dict) -> SamlConfig: + """Parse the dict provided by the homeserver's config + Args: + config: A dictionary containing configuration options for this provider + Returns: + SamlConfig: A custom config object for this module + """ + # Parse config options and use defaults where necessary + mxid_source_attribute = config.get("mxid_source_attribute", "uid") + mapping_type = config.get("mxid_mapping", "hexencode") + + # Retrieve the associating mapping function + try: + mxid_mapper = MXID_MAPPER_MAP[mapping_type] + except KeyError: + raise ConfigError( + "saml2_config.user_mapping_provider.config: '%s' is not a valid " + "mxid_mapping value" % (mapping_type,) + ) + + return SamlConfig(mxid_source_attribute, mxid_mapper) + + @staticmethod + def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]: + """Returns the required attributes of a SAML + + Args: + config: A SamlConfig object containing configuration params for this provider + + Returns: + tuple[set,set]: The first set equates to the saml auth response + attributes that are required for the module to function, whereas the + second set consists of those attributes which can be used if + available, but are not necessary + """ + return {"uid", config.mxid_source_attribute}, {"displayName"} diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 2c1fb9ddac..6747f29e6a 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -23,6 +23,7 @@ them. See doc/log_contexts.rst for details on how this works. """ +import inspect import logging import threading import types @@ -612,7 +613,8 @@ def run_in_background(f, *args, **kwargs): def make_deferred_yieldable(deferred): - """Given a deferred, make it follow the Synapse logcontext rules: + """Given a deferred (or coroutine), make it follow the Synapse logcontext + rules: If the deferred has completed (or is not actually a Deferred), essentially does nothing (just returns another completed deferred with the @@ -624,6 +626,13 @@ def make_deferred_yieldable(deferred): (This is more-or-less the opposite operation to run_in_background.) """ + if inspect.isawaitable(deferred): + # If we're given a coroutine we convert it to a deferred so that we + # run it and find out if it immediately finishes, it it does then we + # don't need to fiddle with log contexts at all and can return + # immediately. + deferred = defer.ensureDeferred(deferred) + if not isinstance(deferred, defer.Deferred): return deferred diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 320c5b0f07..add3037b69 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -451,16 +451,18 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): # Technically an access token might not be associated with # a device so we need to check. if device_id: - self.db.simple_upsert_txn( + # this is always an update rather than an upsert: the row should + # already exist, and if it doesn't, that may be because it has been + # deleted, and we don't want to re-create it. + self.db.simple_update_txn( txn, table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, - values={ + updatevalues={ "user_agent": user_agent, "last_seen": last_seen, "ip": ip, }, - lock=False, ) except Exception as e: # Failed to upsert, log and continue diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index da1529f6ea..998bba1aad 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -1039,20 +1039,25 @@ class EventsStore( }, ) - @defer.inlineCallbacks - def _censor_redactions(self): + async def _censor_redactions(self): """Censors all redactions older than the configured period that haven't been censored yet. By censor we mean update the event_json table with the redacted event. - - Returns: - Deferred """ if self.hs.config.redaction_retention_period is None: return + if not ( + await self.db.updates.has_completed_background_update( + "redactions_have_censored_ts_idx" + ) + ): + # We don't want to run this until the appropriate index has been + # created. + return + before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period # We fetch all redactions that: @@ -1074,15 +1079,15 @@ class EventsStore( LIMIT ? """ - rows = yield self.db.execute( + rows = await self.db.execute( "_censor_redactions_fetch", None, sql, before_ts, 100 ) updates = [] for redaction_id, event_id in rows: - redaction_event = yield self.get_event(redaction_id, allow_none=True) - original_event = yield self.get_event( + redaction_event = await self.get_event(redaction_id, allow_none=True) + original_event = await self.get_event( event_id, allow_rejected=True, allow_none=True ) @@ -1115,7 +1120,7 @@ class EventsStore( updatevalues={"have_censored": True}, ) - yield self.db.runInteraction("_update_censor_txn", _update_censor_txn) + await self.db.runInteraction("_update_censor_txn", _update_censor_txn) def _censor_event_txn(self, txn, event_id, pruned_json): """Censor an event by replacing its JSON in the event_json table with the diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index efee17b929..5177b71016 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -90,6 +90,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): "event_store_labels", self._event_store_labels ) + self.db.updates.register_background_index_update( + "redactions_have_censored_ts_idx", + index_name="redactions_have_censored_ts", + table="redactions", + columns=["received_ts"], + where_clause="NOT have_censored", + ) + @defer.inlineCallbacks def _background_reindex_fields_sender(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql index fe51b02309..ea95db0ed7 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql @@ -14,4 +14,3 @@ */ ALTER TABLE redactions ADD COLUMN have_censored BOOL NOT NULL DEFAULT false; -CREATE INDEX redactions_have_censored ON redactions(event_id) WHERE not have_censored; diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql index 77a5eca499..49ce35d794 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql @@ -14,7 +14,9 @@ */ ALTER TABLE redactions ADD COLUMN received_ts BIGINT; -CREATE INDEX redactions_have_censored_ts ON redactions(received_ts) WHERE not have_censored; INSERT INTO background_updates (update_name, progress_json) VALUES ('redactions_received_ts', '{}'); + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('redactions_have_censored_ts_idx', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql new file mode 100644 index 0000000000..b7550f6f4e --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql @@ -0,0 +1,16 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +DROP INDEX IF EXISTS redactions_have_censored; diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py deleted file mode 100644 index 8318db8d2c..0000000000 --- a/synapse/util/caches/snapshot_cache.py +++ /dev/null @@ -1,94 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.util.async_helpers import ObservableDeferred - - -class SnapshotCache(object): - """Cache for snapshots like the response of /initialSync. - The response of initialSync only has to be a recent snapshot of the - server state. It shouldn't matter to clients if it is a few minutes out - of date. - - This caches a deferred response. Until the deferred completes it will be - returned from the cache. This means that if the client retries the request - while the response is still being computed, that original response will be - used rather than trying to compute a new response. - - Once the deferred completes it will removed from the cache after 5 minutes. - We delay removing it from the cache because a client retrying its request - could race with us finishing computing the response. - - Rather than tracking precisely how long something has been in the cache we - keep two generations of completed responses. Every 5 minutes discard the - old generation, move the new generation to the old generation, and set the - new generation to be empty. This means that a result will be in the cache - somewhere between 5 and 10 minutes. - """ - - DURATION_MS = 5 * 60 * 1000 # Cache results for 5 minutes. - - def __init__(self): - self.pending_result_cache = {} # Request that haven't finished yet. - self.prev_result_cache = {} # The older requests that have finished. - self.next_result_cache = {} # The newer requests that have finished. - self.time_last_rotated_ms = 0 - - def rotate(self, time_now_ms): - # Rotate once if the cache duration has passed since the last rotation. - if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS: - self.prev_result_cache = self.next_result_cache - self.next_result_cache = {} - self.time_last_rotated_ms += self.DURATION_MS - - # Rotate again if the cache duration has passed twice since the last - # rotation. - if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS: - self.prev_result_cache = self.next_result_cache - self.next_result_cache = {} - self.time_last_rotated_ms = time_now_ms - - def get(self, time_now_ms, key): - self.rotate(time_now_ms) - # This cache is intended to deduplicate requests, so we expect it to be - # missed most of the time. So we just lookup the key in all of the - # dictionaries rather than trying to short circuit the lookup if the - # key is found. - result = self.prev_result_cache.get(key) - result = self.next_result_cache.get(key, result) - result = self.pending_result_cache.get(key, result) - if result is not None: - return result.observe() - else: - return None - - def set(self, time_now_ms, key, deferred): - self.rotate(time_now_ms) - - result = ObservableDeferred(deferred) - - self.pending_result_cache[key] = result - - def shuffle_along(r): - # When the deferred completes we shuffle it along to the first - # generation of the result cache. So that it will eventually - # expire from the rotation of that cache. - self.next_result_cache[key] = result - self.pending_result_cache.pop(key, None) - return r - - result.addBoth(shuffle_along) - - return result.observe() diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index fc279340d4..bf674dd184 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -37,9 +37,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(12345678) user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success(self.store.store_device(user_id, device_id, "display name",)) self.get_success( self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + user_id, "access_token", "ip", "user_agent", device_id ) ) @@ -47,14 +51,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(10) result = self.get_success( - self.store.get_last_client_ip_by_device(user_id, "device_id") + self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, "device_id")] + r = result[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, - "device_id": "device_id", + "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 12345678000, @@ -209,14 +213,16 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.store.db.updates.do_next_background_update(100), by=0.1 ) - # Insert a user IP user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success(self.store.store_device(user_id, device_id, "display name",)) self.get_success( self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + user_id, "access_token", "ip", "user_agent", device_id ) ) - # Force persisting to disk self.reactor.advance(200) @@ -224,7 +230,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.get_success( self.store.db.simple_update( table="devices", - keyvalues={"user_id": user_id, "device_id": "device_id"}, + keyvalues={"user_id": user_id, "device_id": device_id}, updatevalues={"last_seen": None, "ip": None, "user_agent": None}, desc="test_devices_last_seen_bg_update", ) @@ -232,14 +238,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should now get nulls when querying result = self.get_success( - self.store.get_last_client_ip_by_device(user_id, "device_id") + self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, "device_id")] + r = result[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, - "device_id": "device_id", + "device_id": device_id, "ip": None, "user_agent": None, "last_seen": None, @@ -272,14 +278,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should now get the correct result again result = self.get_success( - self.store.get_last_client_ip_by_device(user_id, "device_id") + self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, "device_id")] + r = result[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, - "device_id": "device_id", + "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 0, @@ -296,11 +302,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.store.db.updates.do_next_background_update(100), by=0.1 ) - # Insert a user IP user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success(self.store.store_device(user_id, device_id, "display name",)) self.get_success( self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + user_id, "access_token", "ip", "user_agent", device_id ) ) @@ -324,7 +333,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): "access_token": "access_token", "ip": "ip", "user_agent": "user_agent", - "device_id": "device_id", + "device_id": device_id, "last_seen": 0, } ], @@ -347,14 +356,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # But we should still get the correct values for the device result = self.get_success( - self.store.get_last_client_ip_by_device(user_id, "device_id") + self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, "device_id")] + r = result[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, - "device_id": "device_id", + "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 0, diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 8b8455c8b7..281b32c4b8 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -179,6 +179,30 @@ class LoggingContextTestCase(unittest.TestCase): nested_context = nested_logging_context(suffix="bar") self.assertEqual(nested_context.request, "foo-bar") + @defer.inlineCallbacks + def test_make_deferred_yieldable_with_await(self): + # an async function which retuns an incomplete coroutine, but doesn't + # follow the synapse rules. + + async def blocking_function(): + d = defer.Deferred() + reactor.callLater(0, d.callback, None) + await d + + sentinel_context = LoggingContext.current_context() + + with LoggingContext() as context_one: + context_one.request = "one" + + d1 = make_deferred_yieldable(blocking_function()) + # make sure that the context was reset by make_deferred_yieldable + self.assertIs(LoggingContext.current_context(), sentinel_context) + + yield d1 + + # now it should be restored + self._check_test_key("one") + # a function which returns a deferred which has been "called", but # which had a function which returned another incomplete deferred on diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py deleted file mode 100644 index 1a44f72425..0000000000 --- a/tests/util/test_snapshot_cache.py +++ /dev/null @@ -1,63 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from twisted.internet.defer import Deferred - -from synapse.util.caches.snapshot_cache import SnapshotCache - -from .. import unittest - - -class SnapshotCacheTestCase(unittest.TestCase): - def setUp(self): - self.cache = SnapshotCache() - self.cache.DURATION_MS = 1 - - def test_get_set(self): - # Check that getting a missing key returns None - self.assertEquals(self.cache.get(0, "key"), None) - - # Check that setting a key with a deferred returns - # a deferred that resolves when the initial deferred does - d = Deferred() - set_result = self.cache.set(0, "key", d) - self.assertIsNotNone(set_result) - self.assertFalse(set_result.called) - - # Check that getting the key before the deferred has resolved - # returns a deferred that resolves when the initial deferred does. - get_result_at_10 = self.cache.get(10, "key") - self.assertIsNotNone(get_result_at_10) - self.assertFalse(get_result_at_10.called) - - # Check that the returned deferreds resolve when the initial deferred - # does. - d.callback("v") - self.assertTrue(set_result.called) - self.assertTrue(get_result_at_10.called) - - # Check that getting the key after the deferred has resolved - # before the cache expires returns a resolved deferred. - get_result_at_11 = self.cache.get(11, "key") - self.assertIsNotNone(get_result_at_11) - if isinstance(get_result_at_11, Deferred): - # The cache may return the actual result rather than a deferred - self.assertTrue(get_result_at_11.called) - - # Check that getting the key after the deferred has resolved - # after the cache expires returns None - get_result_at_12 = self.cache.get(12, "key") - self.assertIsNone(get_result_at_12) |