diff options
author | Andrew Morgan <andrew@amorgan.xyz> | 2020-06-24 12:07:41 +0100 |
---|---|---|
committer | Andrew Morgan <andrew@amorgan.xyz> | 2020-06-24 12:07:41 +0100 |
commit | a7d49db74fdc303bcd295db501644d54846f1fd5 (patch) | |
tree | ec564c03c6b642fb7ea9d830a26156bfd44f0460 /synapse | |
parent | Prevent M_USER_IN_USE from being raised by registration methods until after e... (diff) | |
parent | 1.15.0 (diff) | |
download | synapse-a7d49db74fdc303bcd295db501644d54846f1fd5.tar.xz |
Merge branch 'release-v1.15.0' of github.com:matrix-org/synapse into dinsic-release-v1.15.x
* 'release-v1.15.0' of github.com:matrix-org/synapse: (55 commits) 1.15.0 Fix some attributions Update CHANGES.md 1.15.0rc1 Revert "1.15.0rc1" 1.15.0rc1 Fix bug in account data replication stream. (#7656) Convert the registration handler to async/await. (#7649) Accept device information at the login fallback endpoint. (#7629) Convert user directory handler and related classes to async/await. (#7640) Add an option to disable autojoin for guest accounts (#6637) Clarifications to the admin api documentation (#7647) Update to the stable SSO prefix for UI Auth. (#7630) Fix type information on `assert_*_is_admin` methods (#7645) Remove some unused constants. (#7644) Typo fixes. Allow new users to be registered via the admin API even if the monthly active user limit has been reached (#7263) Add device management to admin API (#7481) Attempt to fix PhoneHomeStatsTestCase.test_performance_100 being flaky. (#7634) Support CS API v0.6.0 (#6585) ...
Diffstat (limited to 'synapse')
56 files changed, 1390 insertions, 1012 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index f0105d3e2f..1d9d85a727 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -36,7 +36,7 @@ try: except ImportError: pass -__version__ = "1.14.0" +__version__ = "1.15.0" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/api/auth.py b/synapse/api/auth.py index b2d26f5915..d9e7736b8d 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -515,16 +515,16 @@ class Auth(object): request.authenticated_entity = service.sender return defer.succeed(service) - def is_server_admin(self, user): + async def is_server_admin(self, user: UserID) -> bool: """ Check if the given user is a local server admin. Args: - user (UserID): user to check + user: user to check Returns: - bool: True if the user is an admin + True if the user is an admin """ - return self.store.is_server_admin(user) + return await self.store.is_server_admin(user) def compute_auth_events( self, event, current_state_ids: StateMap[str], for_verification: bool = False, diff --git a/synapse/api/constants.py b/synapse/api/constants.py index bcaf2c3600..5ec4a77ccd 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -61,13 +61,9 @@ class LoginType(object): MSISDN = "m.login.msisdn" RECAPTCHA = "m.login.recaptcha" TERMS = "m.login.terms" - SSO = "org.matrix.login.sso" + SSO = "m.login.sso" DUMMY = "m.login.dummy" - # Only for C/S API v1 - APPLICATION_SERVICE = "m.login.application_service" - SHARED_SECRET = "org.matrix.login.shared_secret" - class EventTypes(object): Member = "m.room.member" diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 7a049b3af7..ec6b3a69a2 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -1,4 +1,5 @@ # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,75 +17,157 @@ from collections import OrderedDict from typing import Any, Optional, Tuple from synapse.api.errors import LimitExceededError +from synapse.util import Clock class Ratelimiter(object): """ - Ratelimit message sending by user. + Ratelimit actions marked by arbitrary keys. + + Args: + clock: A homeserver clock, for retrieving the current time + rate_hz: The long term number of actions that can be performed in a second. + burst_count: How many actions that can be performed before being limited. """ - def __init__(self): - self.message_counts = ( - OrderedDict() - ) # type: OrderedDict[Any, Tuple[float, int, Optional[float]]] + def __init__(self, clock: Clock, rate_hz: float, burst_count: int): + self.clock = clock + self.rate_hz = rate_hz + self.burst_count = burst_count + + # A ordered dictionary keeping track of actions, when they were last + # performed and how often. Each entry is a mapping from a key of arbitrary type + # to a tuple representing: + # * How many times an action has occurred since a point in time + # * The point in time + # * The rate_hz of this particular entry. This can vary per request + self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]] - def can_do_action(self, key, time_now_s, rate_hz, burst_count, update=True): + def can_do_action( + self, + key: Any, + rate_hz: Optional[float] = None, + burst_count: Optional[int] = None, + update: bool = True, + _time_now_s: Optional[int] = None, + ) -> Tuple[bool, float]: """Can the entity (e.g. user or IP address) perform the action? + Args: key: The key we should use when rate limiting. Can be a user ID (when sending events), an IP address, etc. - time_now_s: The time now. - rate_hz: The long term number of messages a user can send in a - second. - burst_count: How many messages the user can send before being - limited. - update (bool): Whether to update the message rates or not. This is - useful to check if a message would be allowed to be sent before - its ready to be actually sent. + rate_hz: The long term number of actions that can be performed in a second. + Overrides the value set during instantiation if set. + burst_count: How many actions that can be performed before being limited. + Overrides the value set during instantiation if set. + update: Whether to count this check as performing the action + _time_now_s: The current time. Optional, defaults to the current time according + to self.clock. Only used by tests. + Returns: - A pair of a bool indicating if they can send a message now and a - time in seconds of when they can next send a message. + A tuple containing: + * A bool indicating if they can perform the action now + * The reactor timestamp for when the action can be performed next. + -1 if rate_hz is less than or equal to zero """ - self.prune_message_counts(time_now_s) - message_count, time_start, _ignored = self.message_counts.get( - key, (0.0, time_now_s, None) - ) + # Override default values if set + time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() + rate_hz = rate_hz if rate_hz is not None else self.rate_hz + burst_count = burst_count if burst_count is not None else self.burst_count + + # Remove any expired entries + self._prune_message_counts(time_now_s) + + # Check if there is an existing count entry for this key + action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, 0.0)) + + # Check whether performing another action is allowed time_delta = time_now_s - time_start - sent_count = message_count - time_delta * rate_hz - if sent_count < 0: + performed_count = action_count - time_delta * rate_hz + if performed_count < 0: + # Allow, reset back to count 1 allowed = True time_start = time_now_s - message_count = 1.0 - elif sent_count > burst_count - 1.0: + action_count = 1.0 + elif performed_count > burst_count - 1.0: + # Deny, we have exceeded our burst count allowed = False else: + # We haven't reached our limit yet allowed = True - message_count += 1 + action_count += 1.0 if update: - self.message_counts[key] = (message_count, time_start, rate_hz) + self.actions[key] = (action_count, time_start, rate_hz) if rate_hz > 0: - time_allowed = time_start + (message_count - burst_count + 1) / rate_hz + # Find out when the count of existing actions expires + time_allowed = time_start + (action_count - burst_count + 1) / rate_hz + + # Don't give back a time in the past if time_allowed < time_now_s: time_allowed = time_now_s + else: + # XXX: Why is this -1? This seems to only be used in + # self.ratelimit. I guess so that clients get a time in the past and don't + # feel afraid to try again immediately time_allowed = -1 return allowed, time_allowed - def prune_message_counts(self, time_now_s): - for key in list(self.message_counts.keys()): - message_count, time_start, rate_hz = self.message_counts[key] + def _prune_message_counts(self, time_now_s: int): + """Remove message count entries that have not exceeded their defined + rate_hz limit + + Args: + time_now_s: The current time + """ + # We create a copy of the key list here as the dictionary is modified during + # the loop + for key in list(self.actions.keys()): + action_count, time_start, rate_hz = self.actions[key] + + # Rate limit = "seconds since we started limiting this action" * rate_hz + # If this limit has not been exceeded, wipe our record of this action time_delta = time_now_s - time_start - if message_count - time_delta * rate_hz > 0: - break + if action_count - time_delta * rate_hz > 0: + continue else: - del self.message_counts[key] + del self.actions[key] + + def ratelimit( + self, + key: Any, + rate_hz: Optional[float] = None, + burst_count: Optional[int] = None, + update: bool = True, + _time_now_s: Optional[int] = None, + ): + """Checks if an action can be performed. If not, raises a LimitExceededError + + Args: + key: An arbitrary key used to classify an action + rate_hz: The long term number of actions that can be performed in a second. + Overrides the value set during instantiation if set. + burst_count: How many actions that can be performed before being limited. + Overrides the value set during instantiation if set. + update: Whether to count this check as performing the action + _time_now_s: The current time. Optional, defaults to the current time according + to self.clock. Only used by tests. + + Raises: + LimitExceededError: If an action could not be performed, along with the time in + milliseconds until the action can be performed again + """ + time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() - def ratelimit(self, key, time_now_s, rate_hz, burst_count, update=True): allowed, time_allowed = self.can_do_action( - key, time_now_s, rate_hz, burst_count, update + key, + rate_hz=rate_hz, + burst_count=burst_count, + update=update, + _time_now_s=time_now_s, ) if not allowed: diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 5afe52f8d4..f3ec2a34ec 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -863,9 +863,24 @@ class FederationSenderHandler(object): a FEDERATION_ACK back to the master, and stores the token that we have processed in `federation_stream_position` so that we can restart where we left off. """ - try: - self.federation_position = token + self.federation_position = token + + # We save and send the ACK to master asynchronously, so we don't block + # processing on persistence. We don't need to do this operation for + # every single RDATA we receive, we just need to do it periodically. + + if self._fed_position_linearizer.is_queued(None): + # There is already a task queued up to save and send the token, so + # no need to queue up another task. + return + + run_as_background_process("_save_and_send_ack", self._save_and_send_ack) + async def _save_and_send_ack(self): + """Save the current federation position in the database and send an ACK + to master with where we're up to. + """ + try: # We linearize here to ensure we don't have races updating the token # # XXX this appears to be redundant, since the ReplicationCommandHandler @@ -875,16 +890,18 @@ class FederationSenderHandler(object): # we're not being re-entered? with (await self._fed_position_linearizer.queue(None)): + # We persist and ack the same position, so we take a copy of it + # here as otherwise it can get modified from underneath us. + current_position = self.federation_position + await self.store.update_federation_out_pos( - "federation", self.federation_position + "federation", current_position ) # We ACK this token over replication so that the master can drop # its in memory queues - self._hs.get_tcp_replication().send_federation_ack( - self.federation_position - ) - self._last_ack = self.federation_position + self._hs.get_tcp_replication().send_federation_ack(current_position) + self._last_ack = current_position except Exception: logger.exception("Error updating federation stream position") diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 93a5ba2100..8454d74858 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -488,6 +488,29 @@ def phone_stats_home(hs, stats, stats_process=_stats_process): if uptime < 0: uptime = 0 + # + # Performance statistics. Keep this early in the function to maintain reliability of `test_performance_100` test. + # + old = stats_process[0] + new = (now, resource.getrusage(resource.RUSAGE_SELF)) + stats_process[0] = new + + # Get RSS in bytes + stats["memory_rss"] = new[1].ru_maxrss + + # Get CPU time in % of a single core, not % of all cores + used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - ( + old[1].ru_utime + old[1].ru_stime + ) + if used_cpu_time == 0 or new[0] == old[0]: + stats["cpu_average"] = 0 + else: + stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100) + + # + # General statistics + # + stats["homeserver"] = hs.config.server_name stats["server_context"] = hs.config.server_context stats["timestamp"] = now @@ -523,25 +546,6 @@ def phone_stats_home(hs, stats, stats_process=_stats_process): stats["event_cache_size"] = hs.config.caches.event_cache_size # - # Performance statistics - # - old = stats_process[0] - new = (now, resource.getrusage(resource.RUSAGE_SELF)) - stats_process[0] = new - - # Get RSS in bytes - stats["memory_rss"] = new[1].ru_maxrss - - # Get CPU time in % of a single core, not % of all cores - used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - ( - old[1].ru_utime + old[1].ru_stime - ) - if used_cpu_time == 0 or new[0] == old[0]: - stats["cpu_average"] = 0 - else: - stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100) - - # # Database version # @@ -617,18 +621,17 @@ def run(hs): clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60) reap_monthly_active_users() - @defer.inlineCallbacks - def generate_monthly_active_users(): + async def generate_monthly_active_users(): current_mau_count = 0 current_mau_count_by_service = {} reserved_users = () store = hs.get_datastore() if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: - current_mau_count = yield store.get_monthly_active_count() + current_mau_count = await store.get_monthly_active_count() current_mau_count_by_service = ( - yield store.get_monthly_active_count_by_service() + await store.get_monthly_active_count_by_service() ) - reserved_users = yield store.get_registered_reserved_users() + reserved_users = await store.get_registered_reserved_users() current_mau_gauge.set(float(current_mau_count)) for app_service, count in current_mau_count_by_service.items(): diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 586038078f..e24dd637bc 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -55,7 +55,6 @@ class OIDCConfig(Config): self.oidc_token_endpoint = oidc_config.get("token_endpoint") self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint") self.oidc_jwks_uri = oidc_config.get("jwks_uri") - self.oidc_subject_claim = oidc_config.get("subject_claim", "sub") self.oidc_skip_verification = oidc_config.get("skip_verification", False) ump_config = oidc_config.get("user_mapping_provider", {}) @@ -86,92 +85,119 @@ class OIDCConfig(Config): def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ - # Enable OpenID Connect for registration and login. Uses authlib. + # OpenID Connect integration. The following settings can be used to make Synapse + # use an OpenID Connect Provider for authentication, instead of its internal + # password database. + # + # See https://github.com/matrix-org/synapse/blob/master/openid.md. # oidc_config: - # enable OpenID Connect. Defaults to false. - # - #enabled: true - - # use the OIDC discovery mechanism to discover endpoints. Defaults to true. - # - #discover: true - - # the OIDC issuer. Used to validate tokens and discover the providers endpoints. Required. - # - #issuer: "https://accounts.example.com/" - - # oauth2 client id to use. Required. - # - #client_id: "provided-by-your-issuer" - - # oauth2 client secret to use. Required. - # - #client_secret: "provided-by-your-issuer" - - # auth method to use when exchanging the token. - # Valid values are "client_secret_basic" (default), "client_secret_post" and "none". - # - #client_auth_method: "client_secret_basic" - - # list of scopes to ask. This should include the "openid" scope. Defaults to ["openid"]. - # - #scopes: ["openid"] - - # the oauth2 authorization endpoint. Required if provider discovery is disabled. + # Uncomment the following to enable authorization against an OpenID Connect + # server. Defaults to false. + # + #enabled: true + + # Uncomment the following to disable use of the OIDC discovery mechanism to + # discover endpoints. Defaults to true. + # + #discover: false + + # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to + # discover the provider's endpoints. + # + # Required if 'enabled' is true. + # + #issuer: "https://accounts.example.com/" + + # oauth2 client id to use. + # + # Required if 'enabled' is true. + # + #client_id: "provided-by-your-issuer" + + # oauth2 client secret to use. + # + # Required if 'enabled' is true. + # + #client_secret: "provided-by-your-issuer" + + # auth method to use when exchanging the token. + # Valid values are 'client_secret_basic' (default), 'client_secret_post' and + # 'none'. + # + #client_auth_method: client_secret_post + + # list of scopes to request. This should normally include the "openid" scope. + # Defaults to ["openid"]. + # + #scopes: ["openid", "profile"] + + # the oauth2 authorization endpoint. Required if provider discovery is disabled. + # + #authorization_endpoint: "https://accounts.example.com/oauth2/auth" + + # the oauth2 token endpoint. Required if provider discovery is disabled. + # + #token_endpoint: "https://accounts.example.com/oauth2/token" + + # the OIDC userinfo endpoint. Required if discovery is disabled and the + # "openid" scope is not requested. + # + #userinfo_endpoint: "https://accounts.example.com/userinfo" + + # URI where to fetch the JWKS. Required if discovery is disabled and the + # "openid" scope is used. + # + #jwks_uri: "https://accounts.example.com/.well-known/jwks.json" + + # Uncomment to skip metadata verification. Defaults to false. + # + # Use this if you are connecting to a provider that is not OpenID Connect + # compliant. + # Avoid this in production. + # + #skip_verification: true + + # An external module can be provided here as a custom solution to mapping + # attributes returned from a OIDC provider onto a matrix user. + # + user_mapping_provider: + # The custom module's class. Uncomment to use a custom module. + # Default is {mapping_provider!r}. # - #authorization_endpoint: "https://accounts.example.com/oauth2/auth" - - # the oauth2 token endpoint. Required if provider discovery is disabled. - # - #token_endpoint: "https://accounts.example.com/oauth2/token" - - # the OIDC userinfo endpoint. Required if discovery is disabled and the "openid" scope is not asked. + # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers + # for information on implementing a custom mapping provider. # - #userinfo_endpoint: "https://accounts.example.com/userinfo" + #module: mapping_provider.OidcMappingProvider - # URI where to fetch the JWKS. Required if discovery is disabled and the "openid" scope is used. + # Custom configuration values for the module. This section will be passed as + # a Python dictionary to the user mapping provider module's `parse_config` + # method. # - #jwks_uri: "https://accounts.example.com/.well-known/jwks.json" - - # skip metadata verification. Defaults to false. - # Use this if you are connecting to a provider that is not OpenID Connect compliant. - # Avoid this in production. + # The examples below are intended for the default provider: they should be + # changed if using a custom provider. # - #skip_verification: false - + config: + # name of the claim containing a unique identifier for the user. + # Defaults to `sub`, which OpenID Connect compliant providers should provide. + # + #subject_claim: "sub" - # An external module can be provided here as a custom solution to mapping - # attributes returned from a OIDC provider onto a matrix user. - # - user_mapping_provider: - # The custom module's class. Uncomment to use a custom module. - # Default is {mapping_provider!r}. + # Jinja2 template for the localpart of the MXID. + # + # When rendering, this template is given the following variables: + # * user: The claims returned by the UserInfo Endpoint and/or in the ID + # Token + # + # This must be configured if using the default mapping provider. # - #module: mapping_provider.OidcMappingProvider + localpart_template: "{{{{ user.preferred_username }}}}" - # Custom configuration values for the module. Below options are intended - # for the built-in provider, they should be changed if using a custom - # module. This section will be passed as a Python dictionary to the - # module's `parse_config` method. + # Jinja2 template for the display name to set on first login. # - # Below is the config of the default mapping provider, based on Jinja2 - # templates. Those templates are used to render user attributes, where the - # userinfo object is available through the `user` variable. + # If unset, no displayname will be set. # - config: - # name of the claim containing a unique identifier for the user. - # Defaults to `sub`, which OpenID Connect compliant providers should provide. - # - #subject_claim: "sub" - - # Jinja2 template for the localpart of the MXID - # - localpart_template: "{{{{ user.preferred_username }}}}" - - # Jinja2 template for the display name to set on first login. Optional. - # - #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}" + #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}" """.format( mapping_provider=DEFAULT_USER_MAPPING_PROVIDER ) diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index dbc3dd7a2c..b1981d4d15 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -12,11 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict + from ._base import Config class RateLimitConfig(object): - def __init__(self, config, defaults={"per_second": 0.17, "burst_count": 3.0}): + def __init__( + self, + config: Dict[str, float], + defaults={"per_second": 0.17, "burst_count": 3.0}, + ): self.per_second = config.get("per_second", defaults["per_second"]) self.burst_count = config.get("burst_count", defaults["burst_count"]) diff --git a/synapse/config/registration.py b/synapse/config/registration.py index ac71c09775..a46b3ef53e 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -153,6 +153,7 @@ class RegistrationConfig(Config): if not RoomAlias.is_valid(room_alias): raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,)) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) + self.auto_join_rooms_for_guests = config.get("auto_join_rooms_for_guests", True) self.enable_set_displayname = config.get("enable_set_displayname", True) self.enable_set_avatar_url = config.get("enable_set_avatar_url", True) @@ -455,6 +456,13 @@ class RegistrationConfig(Config): # #autocreate_auto_join_rooms: true + # When auto_join_rooms is specified, setting this flag to false prevents + # guest accounts from being automatically joined to the rooms. + # + # Defaults to true. + # + #auto_join_rooms_for_guests: false + # Rewrite identity server URLs with a map from one URL to another. Applies to URLs # provided by clients (which have https:// prepended) and those specified # in `account_threepid_delegates`. URLs should not feature a trailing slash. diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 944ea80e17..0ad09feef4 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -70,6 +70,7 @@ def parse_thumbnail_requirements(thumbnail_sizes): jpeg_thumbnail = ThumbnailRequirement(width, height, method, "image/jpeg") png_thumbnail = ThumbnailRequirement(width, height, method, "image/png") requirements.setdefault("image/jpeg", []).append(jpeg_thumbnail) + requirements.setdefault("image/webp", []).append(jpeg_thumbnail) requirements.setdefault("image/gif", []).append(png_thumbnail) requirements.setdefault("image/png", []).append(png_thumbnail) return { diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index 726a27d7b2..d0a19751e8 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -15,8 +15,8 @@ # limitations under the License. import logging -import os +import jinja2 import pkg_resources from synapse.python_dependencies import DependencyException, check_requirements @@ -167,9 +167,11 @@ class SAML2Config(Config): if not template_dir: template_dir = pkg_resources.resource_filename("synapse", "res/templates",) - self.saml2_error_html_content = self.read_file( - os.path.join(template_dir, "saml_error.html"), "saml2_config.saml_error", - ) + loader = jinja2.FileSystemLoader(template_dir) + # enable auto-escape here, to having to remember to escape manually in the + # template + env = jinja2.Environment(loader=loader, autoescape=True) + self.saml2_error_html_template = env.get_template("saml_error.html") def _default_saml_config_dict( self, required_attributes: set, optional_attributes: set @@ -216,6 +218,8 @@ class SAML2Config(Config): def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ + ## Single sign-on integration ## + # Enable SAML2 for registration and login. Uses pysaml2. # # At least one of `sp_config` or `config_path` must be set in this section to @@ -349,7 +353,13 @@ class SAML2Config(Config): # * HTML page to display to users if something goes wrong during the # authentication process: 'saml_error.html'. # - # This template doesn't currently need any variable to render. + # When rendering, this template is given the following variables: + # * code: an HTML error code corresponding to the error that is being + # returned (typically 400 or 500) + # + # * msg: a textual message describing the error. + # + # The variables will automatically be HTML-escaped. # # You can see the default templates at: # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates diff --git a/synapse/config/sso.py b/synapse/config/sso.py index aff642f015..73b7296399 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -61,7 +61,8 @@ class SSOConfig(Config): def generate_config_section(self, **kwargs): return """\ - # Additional settings to use with single-sign on systems such as SAML2 and CAS. + # Additional settings to use with single-sign on systems such as OpenID Connect, + # SAML2 and CAS. # sso: # A list of client URLs which are whitelisted so that the user does not diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 4acb4fa489..8a9de913b3 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -19,8 +19,6 @@ import logging from six import string_types -from twisted.internet import defer - from synapse.api.errors import Codes, SynapseError from synapse.types import GroupID, RoomID, UserID, get_domain_from_id from synapse.util.async_helpers import concurrently_execute @@ -51,8 +49,7 @@ class GroupsServerWorkerHandler(object): self.transport_client = hs.get_federation_transport_client() self.profile_handler = hs.get_profile_handler() - @defer.inlineCallbacks - def check_group_is_ours( + async def check_group_is_ours( self, group_id, requester_user_id, and_exists=False, and_is_admin=None ): """Check that the group is ours, and optionally if it exists. @@ -68,25 +65,24 @@ class GroupsServerWorkerHandler(object): if not self.is_mine_id(group_id): raise SynapseError(400, "Group not on this server") - group = yield self.store.get_group(group_id) + group = await self.store.get_group(group_id) if and_exists and not group: raise SynapseError(404, "Unknown group") - is_user_in_group = yield self.store.is_user_in_group( + is_user_in_group = await self.store.is_user_in_group( requester_user_id, group_id ) if group and not is_user_in_group and not group["is_public"]: raise SynapseError(404, "Unknown group") if and_is_admin: - is_admin = yield self.store.is_user_admin_in_group(group_id, and_is_admin) + is_admin = await self.store.is_user_admin_in_group(group_id, and_is_admin) if not is_admin: raise SynapseError(403, "User is not admin in group") return group - @defer.inlineCallbacks - def get_group_summary(self, group_id, requester_user_id): + async def get_group_summary(self, group_id, requester_user_id): """Get the summary for a group as seen by requester_user_id. The group summary consists of the profile of the room, and a curated @@ -95,28 +91,28 @@ class GroupsServerWorkerHandler(object): A user/room may appear in multiple roles/categories. """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - is_user_in_group = yield self.store.is_user_in_group( + is_user_in_group = await self.store.is_user_in_group( requester_user_id, group_id ) - profile = yield self.get_group_profile(group_id, requester_user_id) + profile = await self.get_group_profile(group_id, requester_user_id) - users, roles = yield self.store.get_users_for_summary_by_role( + users, roles = await self.store.get_users_for_summary_by_role( group_id, include_private=is_user_in_group ) # TODO: Add profiles to users - rooms, categories = yield self.store.get_rooms_for_summary_by_category( + rooms, categories = await self.store.get_rooms_for_summary_by_category( group_id, include_private=is_user_in_group ) for room_entry in rooms: room_id = room_entry["room_id"] - joined_users = yield self.store.get_users_in_room(room_id) - entry = yield self.room_list_handler.generate_room_entry( + joined_users = await self.store.get_users_in_room(room_id) + entry = await self.room_list_handler.generate_room_entry( room_id, len(joined_users), with_alias=False, allow_private=True ) entry = dict(entry) # so we don't change whats cached @@ -130,7 +126,7 @@ class GroupsServerWorkerHandler(object): user_id = entry["user_id"] if not self.is_mine_id(requester_user_id): - attestation = yield self.store.get_remote_attestation(group_id, user_id) + attestation = await self.store.get_remote_attestation(group_id, user_id) if not attestation: continue @@ -140,12 +136,12 @@ class GroupsServerWorkerHandler(object): group_id, user_id ) - user_profile = yield self.profile_handler.get_profile_from_cache(user_id) + user_profile = await self.profile_handler.get_profile_from_cache(user_id) entry.update(user_profile) users.sort(key=lambda e: e.get("order", 0)) - membership_info = yield self.store.get_users_membership_info_in_group( + membership_info = await self.store.get_users_membership_info_in_group( group_id, requester_user_id ) @@ -164,22 +160,20 @@ class GroupsServerWorkerHandler(object): "user": membership_info, } - @defer.inlineCallbacks - def get_group_categories(self, group_id, requester_user_id): + async def get_group_categories(self, group_id, requester_user_id): """Get all categories in a group (as seen by user) """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - categories = yield self.store.get_group_categories(group_id=group_id) + categories = await self.store.get_group_categories(group_id=group_id) return {"categories": categories} - @defer.inlineCallbacks - def get_group_category(self, group_id, requester_user_id, category_id): + async def get_group_category(self, group_id, requester_user_id, category_id): """Get a specific category in a group (as seen by user) """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - res = yield self.store.get_group_category( + res = await self.store.get_group_category( group_id=group_id, category_id=category_id ) @@ -187,32 +181,29 @@ class GroupsServerWorkerHandler(object): return res - @defer.inlineCallbacks - def get_group_roles(self, group_id, requester_user_id): + async def get_group_roles(self, group_id, requester_user_id): """Get all roles in a group (as seen by user) """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - roles = yield self.store.get_group_roles(group_id=group_id) + roles = await self.store.get_group_roles(group_id=group_id) return {"roles": roles} - @defer.inlineCallbacks - def get_group_role(self, group_id, requester_user_id, role_id): + async def get_group_role(self, group_id, requester_user_id, role_id): """Get a specific role in a group (as seen by user) """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - res = yield self.store.get_group_role(group_id=group_id, role_id=role_id) + res = await self.store.get_group_role(group_id=group_id, role_id=role_id) return res - @defer.inlineCallbacks - def get_group_profile(self, group_id, requester_user_id): + async def get_group_profile(self, group_id, requester_user_id): """Get the group profile as seen by requester_user_id """ - yield self.check_group_is_ours(group_id, requester_user_id) + await self.check_group_is_ours(group_id, requester_user_id) - group = yield self.store.get_group(group_id) + group = await self.store.get_group(group_id) if group: cols = [ @@ -229,20 +220,19 @@ class GroupsServerWorkerHandler(object): else: raise SynapseError(404, "Unknown group") - @defer.inlineCallbacks - def get_users_in_group(self, group_id, requester_user_id): + async def get_users_in_group(self, group_id, requester_user_id): """Get the users in group as seen by requester_user_id. The ordering is arbitrary at the moment """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - is_user_in_group = yield self.store.is_user_in_group( + is_user_in_group = await self.store.is_user_in_group( requester_user_id, group_id ) - user_results = yield self.store.get_users_in_group( + user_results = await self.store.get_users_in_group( group_id, include_private=is_user_in_group ) @@ -254,14 +244,14 @@ class GroupsServerWorkerHandler(object): entry = {"user_id": g_user_id} - profile = yield self.profile_handler.get_profile_from_cache(g_user_id) + profile = await self.profile_handler.get_profile_from_cache(g_user_id) entry.update(profile) entry["is_public"] = bool(is_public) entry["is_privileged"] = bool(is_privileged) if not self.is_mine_id(g_user_id): - attestation = yield self.store.get_remote_attestation( + attestation = await self.store.get_remote_attestation( group_id, g_user_id ) if not attestation: @@ -279,30 +269,29 @@ class GroupsServerWorkerHandler(object): return {"chunk": chunk, "total_user_count_estimate": len(user_results)} - @defer.inlineCallbacks - def get_invited_users_in_group(self, group_id, requester_user_id): + async def get_invited_users_in_group(self, group_id, requester_user_id): """Get the users that have been invited to a group as seen by requester_user_id. The ordering is arbitrary at the moment """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - is_user_in_group = yield self.store.is_user_in_group( + is_user_in_group = await self.store.is_user_in_group( requester_user_id, group_id ) if not is_user_in_group: raise SynapseError(403, "User not in group") - invited_users = yield self.store.get_invited_users_in_group(group_id) + invited_users = await self.store.get_invited_users_in_group(group_id) user_profiles = [] for user_id in invited_users: user_profile = {"user_id": user_id} try: - profile = yield self.profile_handler.get_profile_from_cache(user_id) + profile = await self.profile_handler.get_profile_from_cache(user_id) user_profile.update(profile) except Exception as e: logger.warning("Error getting profile for %s: %s", user_id, e) @@ -310,20 +299,19 @@ class GroupsServerWorkerHandler(object): return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)} - @defer.inlineCallbacks - def get_rooms_in_group(self, group_id, requester_user_id): + async def get_rooms_in_group(self, group_id, requester_user_id): """Get the rooms in group as seen by requester_user_id This returns rooms in order of decreasing number of joined users """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - is_user_in_group = yield self.store.is_user_in_group( + is_user_in_group = await self.store.is_user_in_group( requester_user_id, group_id ) - room_results = yield self.store.get_rooms_in_group( + room_results = await self.store.get_rooms_in_group( group_id, include_private=is_user_in_group ) @@ -331,8 +319,8 @@ class GroupsServerWorkerHandler(object): for room_result in room_results: room_id = room_result["room_id"] - joined_users = yield self.store.get_users_in_room(room_id) - entry = yield self.room_list_handler.generate_room_entry( + joined_users = await self.store.get_users_in_room(room_id) + entry = await self.room_list_handler.generate_room_entry( room_id, len(joined_users), with_alias=False, allow_private=True ) @@ -355,13 +343,12 @@ class GroupsServerHandler(GroupsServerWorkerHandler): # Ensure attestations get renewed hs.get_groups_attestation_renewer() - @defer.inlineCallbacks - def update_group_summary_room( + async def update_group_summary_room( self, group_id, requester_user_id, room_id, category_id, content ): """Add/update a room to the group summary """ - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -371,7 +358,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): is_public = _parse_visibility_from_contents(content) - yield self.store.add_room_to_summary( + await self.store.add_room_to_summary( group_id=group_id, room_id=room_id, category_id=category_id, @@ -381,31 +368,29 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return {} - @defer.inlineCallbacks - def delete_group_summary_room( + async def delete_group_summary_room( self, group_id, requester_user_id, room_id, category_id ): """Remove a room from the summary """ - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) - yield self.store.remove_room_from_summary( + await self.store.remove_room_from_summary( group_id=group_id, room_id=room_id, category_id=category_id ) return {} - @defer.inlineCallbacks - def set_group_join_policy(self, group_id, requester_user_id, content): + async def set_group_join_policy(self, group_id, requester_user_id, content): """Sets the group join policy. Currently supported policies are: - "invite": an invite must be received and accepted in order to join. - "open": anyone can join. """ - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -413,22 +398,23 @@ class GroupsServerHandler(GroupsServerWorkerHandler): if join_policy is None: raise SynapseError(400, "No value specified for 'm.join_policy'") - yield self.store.set_group_join_policy(group_id, join_policy=join_policy) + await self.store.set_group_join_policy(group_id, join_policy=join_policy) return {} - @defer.inlineCallbacks - def update_group_category(self, group_id, requester_user_id, category_id, content): + async def update_group_category( + self, group_id, requester_user_id, category_id, content + ): """Add/Update a group category """ - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) is_public = _parse_visibility_from_contents(content) profile = content.get("profile") - yield self.store.upsert_group_category( + await self.store.upsert_group_category( group_id=group_id, category_id=category_id, is_public=is_public, @@ -437,25 +423,23 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return {} - @defer.inlineCallbacks - def delete_group_category(self, group_id, requester_user_id, category_id): + async def delete_group_category(self, group_id, requester_user_id, category_id): """Delete a group category """ - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) - yield self.store.remove_group_category( + await self.store.remove_group_category( group_id=group_id, category_id=category_id ) return {} - @defer.inlineCallbacks - def update_group_role(self, group_id, requester_user_id, role_id, content): + async def update_group_role(self, group_id, requester_user_id, role_id, content): """Add/update a role in a group """ - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -463,31 +447,29 @@ class GroupsServerHandler(GroupsServerWorkerHandler): profile = content.get("profile") - yield self.store.upsert_group_role( + await self.store.upsert_group_role( group_id=group_id, role_id=role_id, is_public=is_public, profile=profile ) return {} - @defer.inlineCallbacks - def delete_group_role(self, group_id, requester_user_id, role_id): + async def delete_group_role(self, group_id, requester_user_id, role_id): """Remove role from group """ - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) - yield self.store.remove_group_role(group_id=group_id, role_id=role_id) + await self.store.remove_group_role(group_id=group_id, role_id=role_id) return {} - @defer.inlineCallbacks - def update_group_summary_user( + async def update_group_summary_user( self, group_id, requester_user_id, user_id, role_id, content ): """Add/update a users entry in the group summary """ - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -495,7 +477,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): is_public = _parse_visibility_from_contents(content) - yield self.store.add_user_to_summary( + await self.store.add_user_to_summary( group_id=group_id, user_id=user_id, role_id=role_id, @@ -505,25 +487,25 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return {} - @defer.inlineCallbacks - def delete_group_summary_user(self, group_id, requester_user_id, user_id, role_id): + async def delete_group_summary_user( + self, group_id, requester_user_id, user_id, role_id + ): """Remove a user from the group summary """ - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) - yield self.store.remove_user_from_summary( + await self.store.remove_user_from_summary( group_id=group_id, user_id=user_id, role_id=role_id ) return {} - @defer.inlineCallbacks - def update_group_profile(self, group_id, requester_user_id, content): + async def update_group_profile(self, group_id, requester_user_id, content): """Update the group profile """ - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -535,40 +517,38 @@ class GroupsServerHandler(GroupsServerWorkerHandler): raise SynapseError(400, "%r value is not a string" % (keyname,)) profile[keyname] = value - yield self.store.update_group_profile(group_id, profile) + await self.store.update_group_profile(group_id, profile) - @defer.inlineCallbacks - def add_room_to_group(self, group_id, requester_user_id, room_id, content): + async def add_room_to_group(self, group_id, requester_user_id, room_id, content): """Add room to group """ RoomID.from_string(room_id) # Ensure valid room id - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) is_public = _parse_visibility_from_contents(content) - yield self.store.add_room_to_group(group_id, room_id, is_public=is_public) + await self.store.add_room_to_group(group_id, room_id, is_public=is_public) return {} - @defer.inlineCallbacks - def update_room_in_group( + async def update_room_in_group( self, group_id, requester_user_id, room_id, config_key, content ): """Update room in group """ RoomID.from_string(room_id) # Ensure valid room id - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) if config_key == "m.visibility": is_public = _parse_visibility_dict(content) - yield self.store.update_room_in_group_visibility( + await self.store.update_room_in_group_visibility( group_id, room_id, is_public=is_public ) else: @@ -576,36 +556,34 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return {} - @defer.inlineCallbacks - def remove_room_from_group(self, group_id, requester_user_id, room_id): + async def remove_room_from_group(self, group_id, requester_user_id, room_id): """Remove room from group """ - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) - yield self.store.remove_room_from_group(group_id, room_id) + await self.store.remove_room_from_group(group_id, room_id) return {} - @defer.inlineCallbacks - def invite_to_group(self, group_id, user_id, requester_user_id, content): + async def invite_to_group(self, group_id, user_id, requester_user_id, content): """Invite user to group """ - group = yield self.check_group_is_ours( + group = await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) # TODO: Check if user knocked - invited_users = yield self.store.get_invited_users_in_group(group_id) + invited_users = await self.store.get_invited_users_in_group(group_id) if user_id in invited_users: raise SynapseError( 400, "User already invited to group", errcode=Codes.BAD_STATE ) - user_results = yield self.store.get_users_in_group( + user_results = await self.store.get_users_in_group( group_id, include_private=True ) if user_id in (user_result["user_id"] for user_result in user_results): @@ -618,18 +596,18 @@ class GroupsServerHandler(GroupsServerWorkerHandler): if self.hs.is_mine_id(user_id): groups_local = self.hs.get_groups_local_handler() - res = yield groups_local.on_invite(group_id, user_id, content) + res = await groups_local.on_invite(group_id, user_id, content) local_attestation = None else: local_attestation = self.attestations.create_attestation(group_id, user_id) content.update({"attestation": local_attestation}) - res = yield self.transport_client.invite_to_group_notification( + res = await self.transport_client.invite_to_group_notification( get_domain_from_id(user_id), group_id, user_id, content ) user_profile = res.get("user_profile", {}) - yield self.store.add_remote_profile_cache( + await self.store.add_remote_profile_cache( user_id, displayname=user_profile.get("displayname"), avatar_url=user_profile.get("avatar_url"), @@ -639,13 +617,13 @@ class GroupsServerHandler(GroupsServerWorkerHandler): if not self.hs.is_mine_id(user_id): remote_attestation = res["attestation"] - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( remote_attestation, user_id=user_id, group_id=group_id ) else: remote_attestation = None - yield self.store.add_user_to_group( + await self.store.add_user_to_group( group_id, user_id, is_admin=False, @@ -654,15 +632,14 @@ class GroupsServerHandler(GroupsServerWorkerHandler): remote_attestation=remote_attestation, ) elif res["state"] == "invite": - yield self.store.add_group_invite(group_id, user_id) + await self.store.add_group_invite(group_id, user_id) return {"state": "invite"} elif res["state"] == "reject": return {"state": "reject"} else: raise SynapseError(502, "Unknown state returned by HS") - @defer.inlineCallbacks - def _add_user(self, group_id, user_id, content): + async def _add_user(self, group_id, user_id, content): """Add a user to a group based on a content dict. See accept_invite, join_group. @@ -672,7 +649,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): remote_attestation = content["attestation"] - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( remote_attestation, user_id=user_id, group_id=group_id ) else: @@ -681,7 +658,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): is_public = _parse_visibility_from_contents(content) - yield self.store.add_user_to_group( + await self.store.add_user_to_group( group_id, user_id, is_admin=False, @@ -692,59 +669,55 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return local_attestation - @defer.inlineCallbacks - def accept_invite(self, group_id, requester_user_id, content): + async def accept_invite(self, group_id, requester_user_id, content): """User tries to accept an invite to the group. This is different from them asking to join, and so should error if no invite exists (and they're not a member of the group) """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - is_invited = yield self.store.is_user_invited_to_local_group( + is_invited = await self.store.is_user_invited_to_local_group( group_id, requester_user_id ) if not is_invited: raise SynapseError(403, "User not invited to group") - local_attestation = yield self._add_user(group_id, requester_user_id, content) + local_attestation = await self._add_user(group_id, requester_user_id, content) return {"state": "join", "attestation": local_attestation} - @defer.inlineCallbacks - def join_group(self, group_id, requester_user_id, content): + async def join_group(self, group_id, requester_user_id, content): """User tries to join the group. This will error if the group requires an invite/knock to join """ - group_info = yield self.check_group_is_ours( + group_info = await self.check_group_is_ours( group_id, requester_user_id, and_exists=True ) if group_info["join_policy"] != "open": raise SynapseError(403, "Group is not publicly joinable") - local_attestation = yield self._add_user(group_id, requester_user_id, content) + local_attestation = await self._add_user(group_id, requester_user_id, content) return {"state": "join", "attestation": local_attestation} - @defer.inlineCallbacks - def knock(self, group_id, requester_user_id, content): + async def knock(self, group_id, requester_user_id, content): """A user requests becoming a member of the group """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) raise NotImplementedError() - @defer.inlineCallbacks - def accept_knock(self, group_id, requester_user_id, content): + async def accept_knock(self, group_id, requester_user_id, content): """Accept a users knock to the room. Errors if the user hasn't knocked, rather than inviting them. """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) raise NotImplementedError() @@ -872,8 +845,6 @@ class GroupsServerHandler(GroupsServerWorkerHandler): group_id (str) request_user_id (str) - Returns: - Deferred """ await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 3b781d9836..61dc4beafe 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -19,7 +19,7 @@ from twisted.internet import defer import synapse.types from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import LimitExceededError +from synapse.api.ratelimiting import Ratelimiter from synapse.types import UserID logger = logging.getLogger(__name__) @@ -44,11 +44,26 @@ class BaseHandler(object): self.notifier = hs.get_notifier() self.state_handler = hs.get_state_handler() self.distributor = hs.get_distributor() - self.ratelimiter = hs.get_ratelimiter() - self.admin_redaction_ratelimiter = hs.get_admin_redaction_ratelimiter() self.clock = hs.get_clock() self.hs = hs + # The rate_hz and burst_count are overridden on a per-user basis + self.request_ratelimiter = Ratelimiter( + clock=self.clock, rate_hz=0, burst_count=0 + ) + self._rc_message = self.hs.config.rc_message + + # Check whether ratelimiting room admin message redaction is enabled + # by the presence of rate limits in the config + if self.hs.config.rc_admin_redaction: + self.admin_redaction_ratelimiter = Ratelimiter( + clock=self.clock, + rate_hz=self.hs.config.rc_admin_redaction.per_second, + burst_count=self.hs.config.rc_admin_redaction.burst_count, + ) + else: + self.admin_redaction_ratelimiter = None + self.server_name = hs.hostname self.event_builder_factory = hs.get_event_builder_factory() @@ -70,7 +85,6 @@ class BaseHandler(object): Raises: LimitExceededError if the request should be ratelimited """ - time_now = self.clock.time() user_id = requester.user.to_string() # The AS user itself is never rate limited. @@ -83,48 +97,32 @@ class BaseHandler(object): if requester.app_service and not requester.app_service.is_rate_limited(): return + messages_per_second = self._rc_message.per_second + burst_count = self._rc_message.burst_count + # Check if there is a per user override in the DB. override = yield self.store.get_ratelimit_for_user(user_id) if override: - # If overriden with a null Hz then ratelimiting has been entirely + # If overridden with a null Hz then ratelimiting has been entirely # disabled for the user if not override.messages_per_second: return messages_per_second = override.messages_per_second burst_count = override.burst_count + + if is_admin_redaction and self.admin_redaction_ratelimiter: + # If we have separate config for admin redactions, use a separate + # ratelimiter as to not have user_ids clash + self.admin_redaction_ratelimiter.ratelimit(user_id, update=update) else: - # We default to different values if this is an admin redaction and - # the config is set - if is_admin_redaction and self.hs.config.rc_admin_redaction: - messages_per_second = self.hs.config.rc_admin_redaction.per_second - burst_count = self.hs.config.rc_admin_redaction.burst_count - else: - messages_per_second = self.hs.config.rc_message.per_second - burst_count = self.hs.config.rc_message.burst_count - - if is_admin_redaction and self.hs.config.rc_admin_redaction: - # If we have separate config for admin redactions we use a separate - # ratelimiter - allowed, time_allowed = self.admin_redaction_ratelimiter.can_do_action( - user_id, - time_now, - rate_hz=messages_per_second, - burst_count=burst_count, - update=update, - ) - else: - allowed, time_allowed = self.ratelimiter.can_do_action( + # Override rate and burst count per-user + self.request_ratelimiter.ratelimit( user_id, - time_now, rate_hz=messages_per_second, burst_count=burst_count, update=update, ) - if not allowed: - raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)) - ) async def maybe_kick_guest_users(self, event, context=None): # Technically this function invalidates current_state by changing it. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 75b39e878c..119678e67b 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -108,7 +108,11 @@ class AuthHandler(BaseHandler): # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. - self._failed_uia_attempts_ratelimiter = Ratelimiter() + self._failed_uia_attempts_ratelimiter = Ratelimiter( + clock=self.clock, + rate_hz=self.hs.config.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.rc_login_failed_attempts.burst_count, + ) self._clock = self.hs.get_clock() @@ -196,13 +200,7 @@ class AuthHandler(BaseHandler): user_id = requester.user.to_string() # Check if we should be ratelimited due to too many previous failed attempts - self._failed_uia_attempts_ratelimiter.ratelimit( - user_id, - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=False, - ) + self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False) # build a list of supported flows flows = [[login_type] for login_type in self._supported_ui_auth_types] @@ -212,14 +210,8 @@ class AuthHandler(BaseHandler): flows, request, request_body, clientip, description ) except LoginError: - # Update the ratelimite to say we failed (`can_do_action` doesn't raise). - self._failed_uia_attempts_ratelimiter.can_do_action( - user_id, - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=True, - ) + # Update the ratelimiter to say we failed (`can_do_action` doesn't raise). + self._failed_uia_attempts_ratelimiter.can_do_action(user_id) raise # find the completed login type diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 29a19b4572..230d170258 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Any, Dict, Optional from six import iteritems, itervalues @@ -30,7 +31,11 @@ from synapse.api.errors import ( ) from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import RoomStreamToken, get_domain_from_id +from synapse.types import ( + RoomStreamToken, + get_domain_from_id, + get_verify_key_from_cross_signing_key, +) from synapse.util import stringutils from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache @@ -704,22 +709,27 @@ class DeviceListUpdater(object): need_resync = yield self.store.get_user_ids_requiring_device_list_resync() # Iterate over the set of user IDs. for user_id in need_resync: - # Try to resync the current user's devices list. Exception handling - # isn't necessary here, since user_device_resync catches all instances - # of "Exception" that might be raised from the federation request. This - # means that if an exception is raised by this function, it must be - # because of a database issue, which means _maybe_retry_device_resync - # probably won't be able to go much further anyway. - result = yield self.user_device_resync( - user_id=user_id, mark_failed_as_stale=False, - ) - # user_device_resync only returns a result if it managed to successfully - # resync and update the database. Updating the table of users requiring - # resync isn't necessary here as user_device_resync already does it - # (through self.store.update_remote_device_list_cache). - if result: + try: + # Try to resync the current user's devices list. + result = yield self.user_device_resync( + user_id=user_id, mark_failed_as_stale=False, + ) + + # user_device_resync only returns a result if it managed to + # successfully resync and update the database. Updating the table + # of users requiring resync isn't necessary here as + # user_device_resync already does it (through + # self.store.update_remote_device_list_cache). + if result: + logger.debug( + "Successfully resynced the device list for %s", user_id, + ) + except Exception as e: + # If there was an issue resyncing this user, e.g. if the remote + # server sent a malformed result, just log the error instead of + # aborting all the subsequent resyncs. logger.debug( - "Successfully resynced the device list for %s" % user_id, + "Could not resync the device list for %s: %s", user_id, e, ) finally: # Allow future calls to retry resyncinc out of sync device lists. @@ -738,6 +748,7 @@ class DeviceListUpdater(object): request: https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid """ + logger.debug("Attempting to resync the device list for %s", user_id) log_kv({"message": "Doing resync to update device list."}) # Fetch all devices for the user. origin = get_domain_from_id(user_id) @@ -789,6 +800,13 @@ class DeviceListUpdater(object): stream_id = result["stream_id"] devices = result["devices"] + # Get the master key and the self-signing key for this user if provided in the + # response (None if not in the response). + # The response will not contain the user signing key, as this key is only used by + # its owner, thus it doesn't make sense to send it over federation. + master_key = result.get("master_key") + self_signing_key = result.get("self_signing_key") + # If the remote server has more than ~1000 devices for this user # we assume that something is going horribly wrong (e.g. a bot # that logs in and creates a new device every time it tries to @@ -818,6 +836,13 @@ class DeviceListUpdater(object): yield self.store.update_remote_device_list_cache(user_id, devices, stream_id) device_ids = [device["device_id"] for device in devices] + + # Handle cross-signing keys. + cross_signing_device_ids = yield self.process_cross_signing_key_update( + user_id, master_key, self_signing_key, + ) + device_ids = device_ids + cross_signing_device_ids + yield self.device_handler.notify_device_update(user_id, device_ids) # We clobber the seen updates since we've re-synced from a given @@ -825,3 +850,40 @@ class DeviceListUpdater(object): self._seen_updates[user_id] = {stream_id} defer.returnValue(result) + + @defer.inlineCallbacks + def process_cross_signing_key_update( + self, + user_id: str, + master_key: Optional[Dict[str, Any]], + self_signing_key: Optional[Dict[str, Any]], + ) -> list: + """Process the given new master and self-signing key for the given remote user. + + Args: + user_id: The ID of the user these keys are for. + master_key: The dict of the cross-signing master key as returned by the + remote server. + self_signing_key: The dict of the cross-signing self-signing key as returned + by the remote server. + + Return: + The device IDs for the given keys. + """ + device_ids = [] + + if master_key: + yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key) + _, verify_key = get_verify_key_from_cross_signing_key(master_key) + # verify_key is a VerifyKey from signedjson, which uses + # .version to denote the portion of the key ID after the + # algorithm and colon, which is the device ID + device_ids.append(verify_key.version) + if self_signing_key: + yield self.store.set_e2e_cross_signing_key( + user_id, "self_signing", self_signing_key + ) + _, verify_key = get_verify_key_from_cross_signing_key(self_signing_key) + device_ids.append(verify_key.version) + + return device_ids diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 8f1bc0323c..774a252619 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -1291,6 +1291,7 @@ class SigningKeyEduUpdater(object): """ device_handler = self.e2e_keys_handler.device_handler + device_list_updater = device_handler.device_list_updater with (yield self._remote_edu_linearizer.queue(user_id)): pending_updates = self._pending_updates.pop(user_id, []) @@ -1303,22 +1304,9 @@ class SigningKeyEduUpdater(object): logger.info("pending updates: %r", pending_updates) for master_key, self_signing_key in pending_updates: - if master_key: - yield self.store.set_e2e_cross_signing_key( - user_id, "master", master_key - ) - _, verify_key = get_verify_key_from_cross_signing_key(master_key) - # verify_key is a VerifyKey from signedjson, which uses - # .version to denote the portion of the key ID after the - # algorithm and colon, which is the device ID - device_ids.append(verify_key.version) - if self_signing_key: - yield self.store.set_e2e_cross_signing_key( - user_id, "self_signing", self_signing_key - ) - _, verify_key = get_verify_key_from_cross_signing_key( - self_signing_key - ) - device_ids.append(verify_key.version) + new_device_ids = yield device_list_updater.process_cross_signing_key_update( + user_id, master_key, self_signing_key, + ) + device_ids = device_ids + new_device_ids yield device_handler.notify_device_update(user_id, device_ids) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index eec8066eeb..bbf23345e2 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -504,7 +504,7 @@ class FederationHandler(BaseHandler): min_depth=min_depth, timeout=60000, ) - except RequestSendFailed as e: + except (RequestSendFailed, HttpResponseException, NotRetryingDestination) as e: # We failed to get the missing events, but since we need to handle # the case of `get_missing_events` not returning the necessary # events anyway, it is safe to simply log the error and continue. diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index ca5c83811a..ebe8d25bd8 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -18,8 +18,6 @@ import logging from six import iteritems -from twisted.internet import defer - from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError from synapse.types import get_domain_from_id @@ -92,19 +90,18 @@ class GroupsLocalWorkerHandler(object): get_group_role = _create_rerouter("get_group_role") get_group_roles = _create_rerouter("get_group_roles") - @defer.inlineCallbacks - def get_group_summary(self, group_id, requester_user_id): + async def get_group_summary(self, group_id, requester_user_id): """Get the group summary for a group. If the group is remote we check that the users have valid attestations. """ if self.is_mine_id(group_id): - res = yield self.groups_server_handler.get_group_summary( + res = await self.groups_server_handler.get_group_summary( group_id, requester_user_id ) else: try: - res = yield self.transport_client.get_group_summary( + res = await self.transport_client.get_group_summary( get_domain_from_id(group_id), group_id, requester_user_id ) except HttpResponseException as e: @@ -122,7 +119,7 @@ class GroupsLocalWorkerHandler(object): attestation = entry.pop("attestation", {}) try: if get_domain_from_id(g_user_id) != group_server_name: - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( attestation, group_id=group_id, user_id=g_user_id, @@ -139,19 +136,18 @@ class GroupsLocalWorkerHandler(object): # Add `is_publicised` flag to indicate whether the user has publicised their # membership of the group on their profile - result = yield self.store.get_publicised_groups_for_user(requester_user_id) + result = await self.store.get_publicised_groups_for_user(requester_user_id) is_publicised = group_id in result res.setdefault("user", {})["is_publicised"] = is_publicised return res - @defer.inlineCallbacks - def get_users_in_group(self, group_id, requester_user_id): + async def get_users_in_group(self, group_id, requester_user_id): """Get users in a group """ if self.is_mine_id(group_id): - res = yield self.groups_server_handler.get_users_in_group( + res = await self.groups_server_handler.get_users_in_group( group_id, requester_user_id ) return res @@ -159,7 +155,7 @@ class GroupsLocalWorkerHandler(object): group_server_name = get_domain_from_id(group_id) try: - res = yield self.transport_client.get_users_in_group( + res = await self.transport_client.get_users_in_group( get_domain_from_id(group_id), group_id, requester_user_id ) except HttpResponseException as e: @@ -174,7 +170,7 @@ class GroupsLocalWorkerHandler(object): attestation = entry.pop("attestation", {}) try: if get_domain_from_id(g_user_id) != group_server_name: - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( attestation, group_id=group_id, user_id=g_user_id, @@ -188,15 +184,13 @@ class GroupsLocalWorkerHandler(object): return res - @defer.inlineCallbacks - def get_joined_groups(self, user_id): - group_ids = yield self.store.get_joined_groups(user_id) + async def get_joined_groups(self, user_id): + group_ids = await self.store.get_joined_groups(user_id) return {"groups": group_ids} - @defer.inlineCallbacks - def get_publicised_groups_for_user(self, user_id): + async def get_publicised_groups_for_user(self, user_id): if self.hs.is_mine_id(user_id): - result = yield self.store.get_publicised_groups_for_user(user_id) + result = await self.store.get_publicised_groups_for_user(user_id) # Check AS associated groups for this user - this depends on the # RegExps in the AS registration file (under `users`) @@ -206,7 +200,7 @@ class GroupsLocalWorkerHandler(object): return {"groups": result} else: try: - bulk_result = yield self.transport_client.bulk_get_publicised_groups( + bulk_result = await self.transport_client.bulk_get_publicised_groups( get_domain_from_id(user_id), [user_id] ) except HttpResponseException as e: @@ -218,8 +212,7 @@ class GroupsLocalWorkerHandler(object): # TODO: Verify attestations return {"groups": result} - @defer.inlineCallbacks - def bulk_get_publicised_groups(self, user_ids, proxy=True): + async def bulk_get_publicised_groups(self, user_ids, proxy=True): destinations = {} local_users = set() @@ -236,7 +229,7 @@ class GroupsLocalWorkerHandler(object): failed_results = [] for destination, dest_user_ids in iteritems(destinations): try: - r = yield self.transport_client.bulk_get_publicised_groups( + r = await self.transport_client.bulk_get_publicised_groups( destination, list(dest_user_ids) ) results.update(r["users"]) @@ -244,7 +237,7 @@ class GroupsLocalWorkerHandler(object): failed_results.extend(dest_user_ids) for uid in local_users: - results[uid] = yield self.store.get_publicised_groups_for_user(uid) + results[uid] = await self.store.get_publicised_groups_for_user(uid) # Check AS associated groups for this user - this depends on the # RegExps in the AS registration file (under `users`) @@ -333,12 +326,11 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return res - @defer.inlineCallbacks - def join_group(self, group_id, user_id, content): + async def join_group(self, group_id, user_id, content): """Request to join a group """ if self.is_mine_id(group_id): - yield self.groups_server_handler.join_group(group_id, user_id, content) + await self.groups_server_handler.join_group(group_id, user_id, content) local_attestation = None remote_attestation = None else: @@ -346,7 +338,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): content["attestation"] = local_attestation try: - res = yield self.transport_client.join_group( + res = await self.transport_client.join_group( get_domain_from_id(group_id), group_id, user_id, content ) except HttpResponseException as e: @@ -356,7 +348,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): remote_attestation = res["attestation"] - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( remote_attestation, group_id=group_id, user_id=user_id, @@ -366,7 +358,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): # TODO: Check that the group is public and we're being added publically is_publicised = content.get("publicise", False) - token = yield self.store.register_user_group_membership( + token = await self.store.register_user_group_membership( group_id, user_id, membership="join", @@ -379,12 +371,11 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return {} - @defer.inlineCallbacks - def accept_invite(self, group_id, user_id, content): + async def accept_invite(self, group_id, user_id, content): """Accept an invite to a group """ if self.is_mine_id(group_id): - yield self.groups_server_handler.accept_invite(group_id, user_id, content) + await self.groups_server_handler.accept_invite(group_id, user_id, content) local_attestation = None remote_attestation = None else: @@ -392,7 +383,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): content["attestation"] = local_attestation try: - res = yield self.transport_client.accept_group_invite( + res = await self.transport_client.accept_group_invite( get_domain_from_id(group_id), group_id, user_id, content ) except HttpResponseException as e: @@ -402,7 +393,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): remote_attestation = res["attestation"] - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( remote_attestation, group_id=group_id, user_id=user_id, @@ -412,7 +403,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): # TODO: Check that the group is public and we're being added publically is_publicised = content.get("publicise", False) - token = yield self.store.register_user_group_membership( + token = await self.store.register_user_group_membership( group_id, user_id, membership="join", @@ -425,18 +416,17 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return {} - @defer.inlineCallbacks - def invite(self, group_id, user_id, requester_user_id, config): + async def invite(self, group_id, user_id, requester_user_id, config): """Invite a user to a group """ content = {"requester_user_id": requester_user_id, "config": config} if self.is_mine_id(group_id): - res = yield self.groups_server_handler.invite_to_group( + res = await self.groups_server_handler.invite_to_group( group_id, user_id, requester_user_id, content ) else: try: - res = yield self.transport_client.invite_to_group( + res = await self.transport_client.invite_to_group( get_domain_from_id(group_id), group_id, user_id, @@ -450,8 +440,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return res - @defer.inlineCallbacks - def on_invite(self, group_id, user_id, content): + async def on_invite(self, group_id, user_id, content): """One of our users were invited to a group """ # TODO: Support auto join and rejection @@ -466,7 +455,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): if "avatar_url" in content["profile"]: local_profile["avatar_url"] = content["profile"]["avatar_url"] - token = yield self.store.register_user_group_membership( + token = await self.store.register_user_group_membership( group_id, user_id, membership="invite", @@ -474,7 +463,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): ) self.notifier.on_new_event("groups_key", token, users=[user_id]) try: - user_profile = yield self.profile_handler.get_profile(user_id) + user_profile = await self.profile_handler.get_profile(user_id) except Exception as e: logger.warning("No profile for user %s: %s", user_id, e) user_profile = {} @@ -516,12 +505,11 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return res - @defer.inlineCallbacks - def user_removed_from_group(self, group_id, user_id, content): + async def user_removed_from_group(self, group_id, user_id, content): """One of our users was removed/kicked from a group """ # TODO: Check if user in group - token = yield self.store.register_user_group_membership( + token = await self.store.register_user_group_membership( group_id, user_id, membership="leave" ) self.notifier.on_new_event("groups_key", token, users=[user_id]) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index e0c426a13b..6039034c00 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -66,8 +66,7 @@ class IdentityHandler(BaseHandler): self.rewrite_identity_server_urls = hs.config.rewrite_identity_server_urls self._enable_lookup = hs.config.enable_3pid_lookup - @defer.inlineCallbacks - def threepid_from_creds(self, id_server_url, creds): + async def threepid_from_creds(self, id_server_url, creds): """ Retrieve and validate a threepid identifier from a "credentials" dictionary against a given identity server @@ -110,7 +109,7 @@ class IdentityHandler(BaseHandler): ) try: - data = yield self.http_client.get_json(url, query_params) + data = await self.http_client.get_json(url, query_params) except TimeoutError: raise SynapseError(500, "Timed out contacting identity server") except HttpResponseException as e: @@ -133,8 +132,7 @@ class IdentityHandler(BaseHandler): logger.info("%s reported non-validated threepid: %s", id_server_url, creds) return None - @defer.inlineCallbacks - def bind_threepid( + async def bind_threepid( self, client_secret, sid, mxid, id_server, id_access_token=None, use_v2=True ): """Bind a 3PID to an identity server @@ -179,12 +177,12 @@ class IdentityHandler(BaseHandler): try: # Use the blacklisting http client as this call is only to identity servers # provided by a client - data = yield self.blacklisting_http_client.post_json_get_json( + data = await self.blacklisting_http_client.post_json_get_json( bind_url, bind_data, headers=headers ) # Remember where we bound the threepid - yield self.store.add_user_bound_threepid( + await self.store.add_user_bound_threepid( user_id=mxid, medium=data["medium"], address=data["address"], @@ -203,13 +201,12 @@ class IdentityHandler(BaseHandler): return data logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url) - res = yield self.bind_threepid( + res = await self.bind_threepid( client_secret, sid, mxid, id_server, id_access_token, use_v2=False ) return res - @defer.inlineCallbacks - def try_unbind_threepid(self, mxid, threepid): + async def try_unbind_threepid(self, mxid, threepid): """Attempt to remove a 3PID from an identity server, or if one is not provided, all identity servers we're aware the binding is present on @@ -229,7 +226,7 @@ class IdentityHandler(BaseHandler): if threepid.get("id_server"): id_servers = [threepid["id_server"]] else: - id_servers = yield self.store.get_id_servers_user_bound( + id_servers = await self.store.get_id_servers_user_bound( user_id=mxid, medium=threepid["medium"], address=threepid["address"] ) @@ -239,14 +236,13 @@ class IdentityHandler(BaseHandler): changed = True for id_server in id_servers: - changed &= yield self.try_unbind_threepid_with_id_server( + changed &= await self.try_unbind_threepid_with_id_server( mxid, threepid, id_server ) return changed - @defer.inlineCallbacks - def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server): + async def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server): """Removes a binding from an identity server Args: @@ -291,7 +287,7 @@ class IdentityHandler(BaseHandler): try: # Use the blacklisting http client as this call is only to identity servers # provided by a client - yield self.blacklisting_http_client.post_json_get_json( + await self.blacklisting_http_client.post_json_get_json( url, content, headers ) changed = True @@ -306,7 +302,7 @@ class IdentityHandler(BaseHandler): except TimeoutError: raise SynapseError(500, "Timed out contacting identity server") - yield self.store.remove_user_bound_threepid( + await self.store.remove_user_bound_threepid( user_id=mxid, medium=threepid["medium"], address=threepid["address"], @@ -420,8 +416,7 @@ class IdentityHandler(BaseHandler): logger.debug("Rewriting identity server rule from %s to %s", url, rewritten_url) return rewritten_url - @defer.inlineCallbacks - def requestEmailToken( + async def requestEmailToken( self, id_server_url, email, client_secret, send_attempt, next_link=None ): """ @@ -461,7 +456,7 @@ class IdentityHandler(BaseHandler): ) try: - data = yield self.http_client.post_json_get_json( + data = await self.http_client.post_json_get_json( "%s/_matrix/identity/api/v1/validate/email/requestToken" % (id_server_url,), params, @@ -473,8 +468,7 @@ class IdentityHandler(BaseHandler): except TimeoutError: raise SynapseError(500, "Timed out contacting identity server") - @defer.inlineCallbacks - def requestMsisdnToken( + async def requestMsisdnToken( self, id_server_url, country, @@ -519,7 +513,7 @@ class IdentityHandler(BaseHandler): # apply it now. id_server_url = self.rewrite_id_server_url(id_server_url) try: - data = yield self.http_client.post_json_get_json( + data = await self.http_client.post_json_get_json( "%s/_matrix/identity/api/v1/validate/msisdn/requestToken" % (id_server_url,), params, @@ -541,8 +535,7 @@ class IdentityHandler(BaseHandler): ) return data - @defer.inlineCallbacks - def validate_threepid_session(self, client_secret, sid): + async def validate_threepid_session(self, client_secret, sid): """Validates a threepid session with only the client secret and session ID Tries validating against any configured account_threepid_delegates as well as locally. @@ -564,12 +557,12 @@ class IdentityHandler(BaseHandler): # Try to validate as email if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: # Ask our delegated email identity server - validation_session = yield self.threepid_from_creds( + validation_session = await self.threepid_from_creds( self.hs.config.account_threepid_delegate_email, threepid_creds ) elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: # Get a validated session matching these details - validation_session = yield self.store.get_threepid_validation_session( + validation_session = await self.store.get_threepid_validation_session( "email", client_secret, sid=sid, validated=True ) @@ -579,14 +572,13 @@ class IdentityHandler(BaseHandler): # Try to validate as msisdn if self.hs.config.account_threepid_delegate_msisdn: # Ask our delegated msisdn identity server - validation_session = yield self.threepid_from_creds( + validation_session = await self.threepid_from_creds( self.hs.config.account_threepid_delegate_msisdn, threepid_creds ) return validation_session - @defer.inlineCallbacks - def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token): + async def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token): """Proxy a POST submitToken request to an identity server for verification purposes Args: @@ -607,11 +599,9 @@ class IdentityHandler(BaseHandler): body = {"client_secret": client_secret, "sid": sid, "token": token} try: - return ( - yield self.http_client.post_json_get_json( - id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken", - body, - ) + return await self.http_client.post_json_get_json( + id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken", + body, ) except TimeoutError: raise SynapseError(500, "Timed out contacting identity server") @@ -663,7 +653,7 @@ class IdentityHandler(BaseHandler): logger.info("Failed to contact %s: %s", id_server, e) raise ProxiedRequestError(503, "Failed to contact identity server") - defer.returnValue(data) + return data @defer.inlineCallbacks def proxy_bulk_lookup_3pid(self, id_server, threepids): @@ -702,8 +692,7 @@ class IdentityHandler(BaseHandler): defer.returnValue(data) - @defer.inlineCallbacks - def lookup_3pid(self, id_server, medium, address, id_access_token=None): + async def lookup_3pid(self, id_server, medium, address, id_access_token=None): """Looks up a 3pid in the passed identity server. Args: @@ -722,7 +711,7 @@ class IdentityHandler(BaseHandler): if id_access_token is not None: try: - results = yield self._lookup_3pid_v2( + results = await self._lookup_3pid_v2( id_server_url, id_access_token, medium, address ) return results @@ -741,10 +730,9 @@ class IdentityHandler(BaseHandler): logger.warning("Error when looking up hashing details: %s", e) return None - return (yield self._lookup_3pid_v1(id_server, id_server_url, medium, address)) + return await self._lookup_3pid_v1(id_server, id_server_url, medium, address) - @defer.inlineCallbacks - def _lookup_3pid_v1(self, id_server, id_server_url, medium, address): + async def _lookup_3pid_v1(self, id_server, id_server_url, medium, address): """Looks up a 3pid in the passed identity server using v1 lookup. Args: @@ -758,7 +746,7 @@ class IdentityHandler(BaseHandler): str: the matrix ID of the 3pid, or None if it is not recognized. """ try: - data = yield self.http_client.get_json( + data = await self.http_client.get_json( "%s/_matrix/identity/api/v1/lookup" % (id_server_url,), {"medium": medium, "address": address}, ) @@ -766,7 +754,7 @@ class IdentityHandler(BaseHandler): if "mxid" in data: if "signatures" not in data: raise AuthError(401, "No signatures on 3pid binding") - yield self._verify_any_signature(data, id_server) + await self._verify_any_signature(data, id_server) return data["mxid"] except TimeoutError: raise SynapseError(500, "Timed out contacting identity server") @@ -775,8 +763,7 @@ class IdentityHandler(BaseHandler): return None - @defer.inlineCallbacks - def _lookup_3pid_v2(self, id_server_url, id_access_token, medium, address): + async def _lookup_3pid_v2(self, id_server_url, id_access_token, medium, address): """Looks up a 3pid in the passed identity server using v2 lookup. Args: @@ -790,7 +777,7 @@ class IdentityHandler(BaseHandler): """ # Check what hashing details are supported by this identity server try: - hash_details = yield self.http_client.get_json( + hash_details = await self.http_client.get_json( "%s/_matrix/identity/v2/hash_details" % (id_server_url,), {"access_token": id_access_token}, ) @@ -856,7 +843,7 @@ class IdentityHandler(BaseHandler): headers = {"Authorization": create_id_access_token_header(id_access_token)} try: - lookup_results = yield self.http_client.post_json_get_json( + lookup_results = await self.http_client.post_json_get_json( "%s/_matrix/identity/v2/lookup" % (id_server_url,), { "addresses": [lookup_value], @@ -884,15 +871,14 @@ class IdentityHandler(BaseHandler): mxid = lookup_results["mappings"].get(lookup_value) return mxid - @defer.inlineCallbacks - def _verify_any_signature(self, data, id_server): + async def _verify_any_signature(self, data, id_server): if id_server not in data["signatures"]: raise AuthError(401, "No signature from server %s" % (id_server,)) for key_name, signature in data["signatures"][id_server].items(): id_server_url = self.rewrite_id_server_url(id_server, add_https=True) - key_data = yield self.http_client.get_json( + key_data = await self.http_client.get_json( "%s/_matrix/identity/api/v1/pubkey/%s" % (id_server_url, key_name) ) if "public_key" not in key_data: @@ -910,8 +896,7 @@ class IdentityHandler(BaseHandler): raise AuthError(401, "No signature from server %s" % (id_server,)) - @defer.inlineCallbacks - def ask_id_server_for_third_party_invite( + async def ask_id_server_for_third_party_invite( self, requester, id_server, @@ -986,7 +971,7 @@ class IdentityHandler(BaseHandler): # Attempt a v2 lookup url = base_url + "/v2/store-invite" try: - data = yield self.blacklisting_http_client.post_json_get_json( + data = await self.blacklisting_http_client.post_json_get_json( url, invite_config, {"Authorization": create_id_access_token_header(id_access_token)}, @@ -1005,7 +990,7 @@ class IdentityHandler(BaseHandler): url = base_url + "/api/v1/store-invite" try: - data = yield self.blacklisting_http_client.post_json_get_json( + data = await self.blacklisting_http_client.post_json_get_json( url, invite_config ) except TimeoutError: @@ -1020,7 +1005,7 @@ class IdentityHandler(BaseHandler): # types. This is especially true with old instances of Sydent, see # https://github.com/matrix-org/sydent/pull/170 try: - data = yield self.blacklisting_http_client.post_urlencoded_get_json( + data = await self.blacklisting_http_client.post_urlencoded_get_json( url, invite_config ) except HttpResponseException as e: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 681f92cafd..649ca1f08a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -362,7 +362,6 @@ class EventCreationHandler(object): self.profile_handler = hs.get_profile_handler() self.event_builder_factory = hs.get_event_builder_factory() self.server_name = hs.hostname - self.ratelimiter = hs.get_ratelimiter() self.notifier = hs.get_notifier() self.config = hs.config self.require_membership_for_aliases = hs.config.require_membership_for_aliases diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 4ba8c7fda5..9c08eb5399 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -37,6 +37,7 @@ from twisted.web.client import readBody from synapse.config import ConfigError from synapse.http.server import finish_request from synapse.http.site import SynapseRequest +from synapse.logging.context import make_deferred_yieldable from synapse.push.mailer import load_jinja2_templates from synapse.server import HomeServer from synapse.types import UserID, map_username_to_mxid_localpart @@ -99,7 +100,6 @@ class OidcHandler: hs.config.oidc_client_auth_method, ) # type: ClientAuth self._client_auth_method = hs.config.oidc_client_auth_method # type: str - self._subject_claim = hs.config.oidc_subject_claim self._provider_metadata = OpenIDProviderMetadata( issuer=hs.config.oidc_issuer, authorization_endpoint=hs.config.oidc_authorization_endpoint, @@ -310,6 +310,10 @@ class OidcHandler: received in the callback to exchange it for a token. The call uses the ``ClientAuth`` to authenticate with the client with its ID and secret. + See: + https://tools.ietf.org/html/rfc6749#section-3.2 + https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint + Args: code: The authorization code we got from the callback. @@ -362,7 +366,7 @@ class OidcHandler: code=response.code, phrase=response.phrase.decode("utf-8") ) - resp_body = await readBody(response) + resp_body = await make_deferred_yieldable(readBody(response)) if response.code >= 500: # In case of a server error, we should first try to decode the body @@ -484,6 +488,7 @@ class OidcHandler: claims_params=claims_params, ) except ValueError: + logger.info("Reloading JWKS after decode error") jwk_set = await self.load_jwks(force=True) # try reloading the jwks claims = jwt.decode( token["id_token"], @@ -592,6 +597,9 @@ class OidcHandler: # The provider might redirect with an error. # In that case, just display it as-is. if b"error" in request.args: + # error response from the auth server. see: + # https://tools.ietf.org/html/rfc6749#section-4.1.2.1 + # https://openid.net/specs/openid-connect-core-1_0.html#AuthError error = request.args[b"error"][0].decode() description = request.args.get(b"error_description", [b""])[0].decode() @@ -605,8 +613,11 @@ class OidcHandler: self._render_error(request, error, description) return + # otherwise, it is presumably a successful response. see: + # https://tools.ietf.org/html/rfc6749#section-4.1.2 + # Fetch the session cookie - session = request.getCookie(SESSION_COOKIE_NAME) + session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes] if session is None: logger.info("No session cookie found") self._render_error(request, "missing_session", "No session cookie found") @@ -654,7 +665,7 @@ class OidcHandler: self._render_error(request, "invalid_request", "Code parameter is missing") return - logger.info("Exchanging code") + logger.debug("Exchanging code") code = request.args[b"code"][0].decode() try: token = await self._exchange_code(code) @@ -663,10 +674,12 @@ class OidcHandler: self._render_error(request, e.error, e.error_description) return + logger.debug("Successfully obtained OAuth2 access token") + # Now that we have a token, get the userinfo, either by decoding the # `id_token` or by fetching the `userinfo_endpoint`. if self._uses_userinfo: - logger.info("Fetching userinfo") + logger.debug("Fetching userinfo") try: userinfo = await self._fetch_userinfo(token) except Exception as e: @@ -674,7 +687,7 @@ class OidcHandler: self._render_error(request, "fetch_error", str(e)) return else: - logger.info("Extracting userinfo from id_token") + logger.debug("Extracting userinfo from id_token") try: userinfo = await self._parse_id_token(token, nonce=nonce) except Exception as e: @@ -750,7 +763,7 @@ class OidcHandler: return macaroon.serialize() def _verify_oidc_session_token( - self, session: str, state: str + self, session: bytes, state: str ) -> Tuple[str, str, Optional[str]]: """Verifies and extract an OIDC session token. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 8c6f61d9d1..d5d44de8d0 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -16,8 +16,6 @@ """Contains functions for registering clients.""" import logging -from twisted.internet import defer - from synapse import types from synapse.api.constants import MAX_USERID_LENGTH, LoginType from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError @@ -78,8 +76,7 @@ class RegistrationHandler(BaseHandler): self.session_lifetime = hs.config.session_lifetime - @defer.inlineCallbacks - def check_username( + async def check_username( self, localpart, guest_access_token=None, assigned_user_id=None, ): """ @@ -128,7 +125,7 @@ class RegistrationHandler(BaseHandler): Codes.INVALID_USERNAME, ) - users = yield self.store.get_users_by_id_case_insensitive(user_id) + users = await self.store.get_users_by_id_case_insensitive(user_id) if users: if not guest_access_token: raise SynapseError( @@ -136,7 +133,7 @@ class RegistrationHandler(BaseHandler): ) # Retrieve guest user information from provided access token - user_data = yield self.auth.get_user_by_access_token(guest_access_token) + user_data = await self.auth.get_user_by_access_token(guest_access_token) if not user_data["is_guest"] or user_data["user"].localpart != localpart: raise AuthError( 403, @@ -145,8 +142,16 @@ class RegistrationHandler(BaseHandler): errcode=Codes.FORBIDDEN, ) - @defer.inlineCallbacks - def register_user( + if guest_access_token is None: + try: + int(localpart) + raise SynapseError( + 400, "Numeric user IDs are reserved for guest users." + ) + except ValueError: + pass + + async def register_user( self, localpart=None, password_hash=None, @@ -158,6 +163,7 @@ class RegistrationHandler(BaseHandler): default_display_name=None, address=None, bind_emails=[], + by_admin=False, ): """Registers a new client on the server. @@ -173,29 +179,24 @@ class RegistrationHandler(BaseHandler): will be set to this. Defaults to 'localpart'. address (str|None): the IP address used to perform the registration. bind_emails (List[str]): list of emails to bind to this account. + by_admin (bool): True if this registration is being made via the + admin api, otherwise False. Returns: - Deferred[str]: user_id + str: user_id Raises: SynapseError if there was a problem registering. """ - yield self.check_registration_ratelimit(address) + self.check_registration_ratelimit(address) - yield self.auth.check_auth_blocking(threepid=threepid) + # do not check_auth_blocking if the call is coming through the Admin API + if not by_admin: + await self.auth.check_auth_blocking(threepid=threepid) if localpart is not None: - yield self.check_username(localpart, guest_access_token=guest_access_token) + await self.check_username(localpart, guest_access_token=guest_access_token) was_guest = guest_access_token is not None - if not was_guest: - try: - int(localpart) - raise SynapseError( - 400, "Numeric user IDs are reserved for guest users." - ) - except ValueError: - pass - user = UserID(localpart, self.hs.hostname) user_id = user.to_string() @@ -206,7 +207,7 @@ class RegistrationHandler(BaseHandler): elif default_display_name is None: default_display_name = localpart - yield self.register_with_store( + await self.register_with_store( user_id=user_id, password_hash=password_hash, was_guest=was_guest, @@ -218,15 +219,13 @@ class RegistrationHandler(BaseHandler): ) if default_display_name: - yield defer.ensureDeferred( - self.profile_handler.set_displayname( - user, None, default_display_name, by_admin=True - ) + await self.profile_handler.set_displayname( + user, None, default_display_name, by_admin=True ) if self.hs.config.user_directory_search_all_users: - profile = yield self.store.get_profileinfo(localpart) - yield self.user_directory_handler.handle_local_profile_change( + profile = await self.store.get_profileinfo(localpart) + await self.user_directory_handler.handle_local_profile_change( user_id, profile ) @@ -239,14 +238,14 @@ class RegistrationHandler(BaseHandler): if fail_count > 10: raise SynapseError(500, "Unable to find a suitable guest user ID") - localpart = yield self._generate_user_id() + localpart = await self._generate_user_id() user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - yield self.check_user_id_not_appservice_exclusive(user_id) + self.check_user_id_not_appservice_exclusive(user_id) if default_display_name is None: default_display_name = localpart try: - yield self.register_with_store( + await self.register_with_store( user_id=user_id, password_hash=password_hash, make_guest=make_guest, @@ -254,10 +253,8 @@ class RegistrationHandler(BaseHandler): address=address, ) - yield defer.ensureDeferred( - self.profile_handler.set_displayname( - user, None, default_display_name, by_admin=True - ) + await self.profile_handler.set_displayname( + user, None, default_display_name, by_admin=True ) # Successfully registered @@ -269,7 +266,13 @@ class RegistrationHandler(BaseHandler): fail_count += 1 if not self.hs.config.user_consent_at_registration: - yield defer.ensureDeferred(self._auto_join_rooms(user_id)) + if not self.hs.config.auto_join_rooms_for_guests and make_guest: + logger.info( + "Skipping auto-join for %s because auto-join for guests is disabled", + user_id, + ) + else: + await self._auto_join_rooms(user_id) else: logger.info( "Skipping auto-join for %s because consent is required at registration", @@ -287,15 +290,15 @@ class RegistrationHandler(BaseHandler): } # Bind email to new account - yield self.register_email_threepid(user_id, threepid_dict, None) + await self.register_email_threepid(user_id, threepid_dict, None) # Prevent the new user from showing up in the user directory if the server # mandates it. if not self._show_in_user_directory: - yield self.store.add_account_data_for_user( + await self.store.add_account_data_for_user( user_id, "im.vector.hide_profile", {"hide_profile": True} ) - yield self.profile_handler.set_active([user], False, True) + await self.profile_handler.set_active([user], False, True) return user_id @@ -360,12 +363,10 @@ class RegistrationHandler(BaseHandler): """ await self._auto_join_rooms(user_id) - @defer.inlineCallbacks - def appservice_register( + async def appservice_register( self, user_localpart, as_token, password_hash, display_name ): # FIXME: this should be factored out and merged with normal register() - user = UserID(user_localpart, self.hs.hostname) user_id = user.to_string() service = self.store.get_app_service_by_token(as_token) @@ -380,28 +381,24 @@ class RegistrationHandler(BaseHandler): service_id = service.id if service.is_exclusive_user(user_id) else None - yield self.check_user_id_not_appservice_exclusive( - user_id, allowed_appservice=service - ) + self.check_user_id_not_appservice_exclusive(user_id, allowed_appservice=service) display_name = display_name or user.localpart - yield self.register_with_store( + await self.register_with_store( user_id=user_id, password_hash=password_hash, appservice_id=service_id, create_profile_with_displayname=display_name, ) - yield defer.ensureDeferred( - self.profile_handler.set_displayname( - user, None, display_name, by_admin=True - ) + await self.profile_handler.set_displayname( + user, None, display_name, by_admin=True ) if self.hs.config.user_directory_search_all_users: - profile = yield self.store.get_profileinfo(user_localpart) - yield self.user_directory_handler.handle_local_profile_change( + profile = await self.store.get_profileinfo(user_localpart) + await self.user_directory_handler.handle_local_profile_change( user_id, profile ) @@ -431,8 +428,7 @@ class RegistrationHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) - @defer.inlineCallbacks - def shadow_register(self, localpart, display_name, auth_result, params): + async def shadow_register(self, localpart, display_name, auth_result, params): """Invokes the current registration on another server, using shared secret registration, passing in any auth_results from other registration UI auth flows (e.g. validated 3pids) @@ -443,7 +439,7 @@ class RegistrationHandler(BaseHandler): shadow_hs_url = self.hs.config.shadow_server.get("hs_url") as_token = self.hs.config.shadow_server.get("as_token") - yield self.http_client.post_json_get_json( + await self.http_client.post_json_get_json( "%s/_matrix/client/r0/register?access_token=%s" % (shadow_hs_url, as_token), { # XXX: auth_result is an unspecified extension for shadow registration @@ -463,13 +459,12 @@ class RegistrationHandler(BaseHandler): }, ) - @defer.inlineCallbacks - def _generate_user_id(self): + async def _generate_user_id(self): if self._next_generated_user_id is None: - with (yield self._generate_user_id_linearizer.queue(())): + with await self._generate_user_id_linearizer.queue(()): if self._next_generated_user_id is None: self._next_generated_user_id = ( - yield self.store.find_next_generated_user_id_localpart() + await self.store.find_next_generated_user_id_localpart() ) id = self._next_generated_user_id @@ -514,14 +509,7 @@ class RegistrationHandler(BaseHandler): if not address: return - time_now = self.clock.time() - - self.ratelimiter.ratelimit( - address, - time_now_s=time_now, - rate_hz=self.hs.config.rc_registration.per_second, - burst_count=self.hs.config.rc_registration.burst_count, - ) + self.ratelimiter.ratelimit(address) def register_with_store( self, @@ -579,8 +567,9 @@ class RegistrationHandler(BaseHandler): user_type=user_type, ) - @defer.inlineCallbacks - def register_device(self, user_id, device_id, initial_display_name, is_guest=False): + async def register_device( + self, user_id, device_id, initial_display_name, is_guest=False + ): """Register a device for a user and generate an access token. The access token will be limited by the homeserver's session_lifetime config. @@ -594,11 +583,11 @@ class RegistrationHandler(BaseHandler): is_guest (bool): Whether this is a guest account Returns: - defer.Deferred[tuple[str, str]]: Tuple of device ID and access token + tuple[str, str]: Tuple of device ID and access token """ if self.hs.config.worker_app: - r = yield self._register_device_client( + r = await self._register_device_client( user_id=user_id, device_id=device_id, initial_display_name=initial_display_name, @@ -614,7 +603,7 @@ class RegistrationHandler(BaseHandler): ) valid_until_ms = self.clock.time_msec() + self.session_lifetime - device_id = yield self.device_handler.check_device_registered( + device_id = await self.device_handler.check_device_registered( user_id, device_id, initial_display_name ) if is_guest: @@ -623,10 +612,8 @@ class RegistrationHandler(BaseHandler): user_id, ["guest = true"] ) else: - access_token = yield defer.ensureDeferred( - self._auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id, valid_until_ms=valid_until_ms - ) + access_token = await self._auth_handler.get_access_token_for_user_id( + user_id, device_id=device_id, valid_until_ms=valid_until_ms ) return (device_id, access_token) @@ -706,8 +693,7 @@ class RegistrationHandler(BaseHandler): await self.store.user_set_consent_version(user_id, consent_version) await self.post_consent_actions(user_id) - @defer.inlineCallbacks - def register_email_threepid(self, user_id, threepid, token): + async def register_email_threepid(self, user_id, threepid, token): """Add an email address as a 3pid identifier Also adds an email pusher for the email address, if configured in the @@ -720,8 +706,6 @@ class RegistrationHandler(BaseHandler): threepid (object): m.login.email.identity auth response token (str|None): access_token for the user, or None if not logged in. - Returns: - defer.Deferred: """ reqd = ("medium", "address", "validated_at") if any(x not in threepid for x in reqd): @@ -729,13 +713,8 @@ class RegistrationHandler(BaseHandler): logger.info("Can't add incomplete 3pid") return - yield defer.ensureDeferred( - self._auth_handler.add_threepid( - user_id, - threepid["medium"], - threepid["address"], - threepid["validated_at"], - ) + await self._auth_handler.add_threepid( + user_id, threepid["medium"], threepid["address"], threepid["validated_at"], ) # And we add an email pusher for them by default, but only @@ -751,10 +730,10 @@ class RegistrationHandler(BaseHandler): # It would really make more sense for this to be passed # up when the access token is saved, but that's quite an # invasive change I'd rather do separately. - user_tuple = yield self.store.get_user_by_access_token(token) + user_tuple = await self.store.get_user_by_access_token(token) token_id = user_tuple["token_id"] - yield self.pusher_pool.add_pusher( + await self.pusher_pool.add_pusher( user_id=user_id, access_token=token_id, kind="email", @@ -766,8 +745,7 @@ class RegistrationHandler(BaseHandler): data={}, ) - @defer.inlineCallbacks - def _register_msisdn_threepid(self, user_id, threepid): + async def _register_msisdn_threepid(self, user_id, threepid): """Add a phone number as a 3pid identifier Must be called on master. @@ -775,8 +753,6 @@ class RegistrationHandler(BaseHandler): Args: user_id (str): id of user threepid (object): m.login.msisdn auth response - Returns: - defer.Deferred: """ try: assert_params_in_dict(threepid, ["medium", "address", "validated_at"]) @@ -787,11 +763,6 @@ class RegistrationHandler(BaseHandler): return None raise - yield defer.ensureDeferred( - self._auth_handler.add_threepid( - user_id, - threepid["medium"], - threepid["address"], - threepid["validated_at"], - ) + await self._auth_handler.add_threepid( + user_id, threepid["medium"], threepid["address"], threepid["validated_at"], ) diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index e75dabcd77..4cbc02b0d0 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -253,10 +253,21 @@ class RoomListHandler(BaseHandler): """ result = {"room_id": room_id, "num_joined_members": num_joined_users} + if with_alias: + aliases = yield self.store.get_aliases_for_room( + room_id, on_invalidate=cache_context.invalidate + ) + if aliases: + result["aliases"] = aliases + current_state_ids = yield self.store.get_current_state_ids( room_id, on_invalidate=cache_context.invalidate ) + if not current_state_ids: + # We're not in the room, so may as well bail out here. + return result + event_map = yield self.store.get_events( [ event_id @@ -289,14 +300,7 @@ class RoomListHandler(BaseHandler): create_event = current_state.get((EventTypes.Create, "")) result["m.federate"] = create_event.content.get("m.federate", True) - if with_alias: - aliases = yield self.store.get_aliases_for_room( - room_id, on_invalidate=cache_context.invalidate - ) - if aliases: - result["aliases"] = aliases - - name_event = yield current_state.get((EventTypes.Name, "")) + name_event = current_state.get((EventTypes.Name, "")) if name_event: name = name_event.content.get("name", None) if name: diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index e7015c704f..abecaa8313 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -23,11 +23,9 @@ from saml2.client import Saml2Client from synapse.api.errors import SynapseError from synapse.config import ConfigError -from synapse.http.server import finish_request from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest from synapse.module_api import ModuleApi -from synapse.module_api.errors import RedirectException from synapse.types import ( UserID, map_username_to_mxid_localpart, @@ -80,8 +78,6 @@ class SamlHandler: # a lock on the mappings self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock) - self._error_html_content = hs.config.saml2_error_html_content - def handle_redirect_request( self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None ) -> bytes: @@ -129,26 +125,9 @@ class SamlHandler: # the dict. self.expire_sessions() - try: - user_id, current_session = await self._map_saml_response_to_user( - resp_bytes, relay_state - ) - except RedirectException: - # Raise the exception as per the wishes of the SAML module response - raise - except Exception as e: - # If decoding the response or mapping it to a user failed, then log the - # error and tell the user that something went wrong. - logger.error(e) - - request.setResponseCode(400) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader( - b"Content-Length", b"%d" % (len(self._error_html_content),) - ) - request.write(self._error_html_content.encode("utf8")) - finish_request(request) - return + user_id, current_session = await self._map_saml_response_to_user( + resp_bytes, relay_state + ) # Complete the interactive auth session or the login. if current_session and current_session.ui_auth_session_id: @@ -171,6 +150,11 @@ class SamlHandler: Returns: Tuple of the user ID and SAML session associated with this response. + + Raises: + SynapseError if there was a problem with the response. + RedirectException: some mapping providers may raise this if they need + to redirect to an interstitial page. """ try: saml2_auth = self._saml_client.parse_authn_request_response( @@ -179,11 +163,9 @@ class SamlHandler: outstanding=self._outstanding_requests_dict, ) except Exception as e: - logger.warning("Exception parsing SAML2 response: %s", e) raise SynapseError(400, "Unable to parse SAML2 response: %s" % (e,)) if saml2_auth.not_signed: - logger.warning("SAML2 response was not signed") raise SynapseError(400, "SAML2 response was not signed") logger.debug("SAML2 response: %s", saml2_auth.origxml) @@ -264,13 +246,13 @@ class SamlHandler: localpart = attribute_dict.get("mxid_localpart") if not localpart: - logger.error( - "SAML mapping provider plugin did not return a " - "mxid_localpart object" + raise Exception( + "Error parsing SAML2 response: SAML mapping provider plugin " + "did not return a mxid_localpart value" ) - raise SynapseError(500, "Error parsing SAML2 response") displayname = attribute_dict.get("displayname") + emails = attribute_dict.get("emails", []) # Check if this mxid already exists if not await self._datastore.get_users_by_id_case_insensitive( @@ -288,7 +270,9 @@ class SamlHandler: logger.info("Mapped SAML user to local part %s", localpart) registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=displayname + localpart=localpart, + default_display_name=displayname, + bind_emails=emails, ) await self._datastore.record_user_external_id( @@ -381,6 +365,7 @@ class DefaultSamlMappingProvider(object): dict: A dict containing new user attributes. Possible keys: * mxid_localpart (str): Required. The localpart of the user's mxid * displayname (str): The displayname of the user + * emails (list[str]): Any emails for the user """ try: mxid_source = saml_response.ava[self._mxid_source_attribute][0] @@ -403,9 +388,13 @@ class DefaultSamlMappingProvider(object): # If displayname is None, the mxid_localpart will be used instead displayname = saml_response.ava.get("displayName", [None])[0] + # Retrieve any emails present in the saml response + emails = saml_response.ava.get("email", []) + return { "mxid_localpart": localpart, "displayname": displayname, + "emails": emails, } @staticmethod @@ -444,4 +433,4 @@ class DefaultSamlMappingProvider(object): second set consists of those attributes which can be used if available, but are not necessary """ - return {"uid", config.mxid_source_attribute}, {"displayName"} + return {"uid", config.mxid_source_attribute}, {"displayName", "email"} diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py index f065970c40..8590c1eff4 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - logger = logging.getLogger(__name__) @@ -24,8 +22,7 @@ class StateDeltasHandler(object): def __init__(self, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks - def _get_key_change(self, prev_event_id, event_id, key_name, public_value): + async def _get_key_change(self, prev_event_id, event_id, key_name, public_value): """Given two events check if the `key_name` field in content changed from not matching `public_value` to doing so. @@ -41,10 +38,10 @@ class StateDeltasHandler(object): prev_event = None event = None if prev_event_id: - prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + prev_event = await self.store.get_event(prev_event_id, allow_none=True) if event_id: - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) if not event and not prev_event: logger.debug("Neither event exists: %r %r", prev_event_id, event_id) diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index d93a276693..149f861239 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -16,17 +16,14 @@ import logging from collections import Counter -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership -from synapse.handlers.state_deltas import StateDeltasHandler from synapse.metrics import event_processing_positions from synapse.metrics.background_process_metrics import run_as_background_process logger = logging.getLogger(__name__) -class StatsHandler(StateDeltasHandler): +class StatsHandler: """Handles keeping the *_stats tables updated with a simple time-series of information about the users, rooms and media on the server, such that admins have some idea of who is consuming their resources. @@ -35,7 +32,6 @@ class StatsHandler(StateDeltasHandler): """ def __init__(self, hs): - super(StatsHandler, self).__init__(hs) self.hs = hs self.store = hs.get_datastore() self.state = hs.get_state_handler() @@ -68,20 +64,18 @@ class StatsHandler(StateDeltasHandler): self._is_processing = True - @defer.inlineCallbacks - def process(): + async def process(): try: - yield self._unsafe_process() + await self._unsafe_process() finally: self._is_processing = False run_as_background_process("stats.notify_new_event", process) - @defer.inlineCallbacks - def _unsafe_process(self): + async def _unsafe_process(self): # If self.pos is None then means we haven't fetched it from DB if self.pos is None: - self.pos = yield self.store.get_stats_positions() + self.pos = await self.store.get_stats_positions() # Loop round handling deltas until we're up to date @@ -96,13 +90,13 @@ class StatsHandler(StateDeltasHandler): logger.debug( "Processing room stats %s->%s", self.pos, room_max_stream_ordering ) - max_pos, deltas = yield self.store.get_current_state_deltas( + max_pos, deltas = await self.store.get_current_state_deltas( self.pos, room_max_stream_ordering ) if deltas: logger.debug("Handling %d state deltas", len(deltas)) - room_deltas, user_deltas = yield self._handle_deltas(deltas) + room_deltas, user_deltas = await self._handle_deltas(deltas) else: room_deltas = {} user_deltas = {} @@ -111,7 +105,7 @@ class StatsHandler(StateDeltasHandler): ( room_count, user_count, - ) = yield self.store.get_changes_room_total_events_and_bytes( + ) = await self.store.get_changes_room_total_events_and_bytes( self.pos, max_pos ) @@ -125,7 +119,7 @@ class StatsHandler(StateDeltasHandler): logger.debug("user_deltas: %s", user_deltas) # Always call this so that we update the stats position. - yield self.store.bulk_update_stats_delta( + await self.store.bulk_update_stats_delta( self.clock.time_msec(), updates={"room": room_deltas, "user": user_deltas}, stream_id=max_pos, @@ -137,13 +131,12 @@ class StatsHandler(StateDeltasHandler): self.pos = max_pos - @defer.inlineCallbacks - def _handle_deltas(self, deltas): + async def _handle_deltas(self, deltas): """Called with the state deltas to process Returns: - Deferred[tuple[dict[str, Counter], dict[str, counter]]] - Resovles to two dicts, the room deltas and the user deltas, + tuple[dict[str, Counter], dict[str, counter]] + Two dicts: the room deltas and the user deltas, mapping from room/user ID to changes in the various fields. """ @@ -162,7 +155,7 @@ class StatsHandler(StateDeltasHandler): logger.debug("Handling: %r, %r %r, %s", room_id, typ, state_key, event_id) - token = yield self.store.get_earliest_token_for_stats("room", room_id) + token = await self.store.get_earliest_token_for_stats("room", room_id) # If the earliest token to begin from is larger than our current # stream ID, skip processing this delta. @@ -184,7 +177,7 @@ class StatsHandler(StateDeltasHandler): sender = None if event_id is not None: - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) if event: event_content = event.content or {} sender = event.sender @@ -200,16 +193,16 @@ class StatsHandler(StateDeltasHandler): room_stats_delta["current_state_events"] += 1 if typ == EventTypes.Member: - # we could use _get_key_change here but it's a bit inefficient - # given we're not testing for a specific result; might as well - # just grab the prev_membership and membership strings and - # compare them. + # we could use StateDeltasHandler._get_key_change here but it's + # a bit inefficient given we're not testing for a specific + # result; might as well just grab the prev_membership and + # membership strings and compare them. # We take None rather than leave as a previous membership # in the absence of a previous event because we do not want to # reduce the leave count when a new-to-the-room user joins. prev_membership = None if prev_event_id is not None: - prev_event = yield self.store.get_event( + prev_event = await self.store.get_event( prev_event_id, allow_none=True ) if prev_event: @@ -301,6 +294,6 @@ class StatsHandler(StateDeltasHandler): for room_id, state in room_to_state_updates.items(): logger.debug("Updating room_stats_state for %s: %s", room_id, state) - yield self.store.update_room_state(room_id, state) + await self.store.update_room_state(room_id, state) return room_to_stats_deltas, user_to_stats_deltas diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 00718d7f2d..6bdb24baff 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1370,7 +1370,7 @@ class SyncHandler(object): sync_result_builder.now_token = now_token # We check up front if anything has changed, if it hasn't then there is - # no point in going futher. + # no point in going further. since_token = sync_result_builder.since_token if not sync_result_builder.full_state: if since_token and not ephemeral_by_room and not account_data_by_room: diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 8363d887a9..8b24a73319 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -138,8 +138,7 @@ class _BaseThreepidAuthChecker: self.hs = hs self.store = hs.get_datastore() - @defer.inlineCallbacks - def _check_threepid(self, medium, authdict): + async def _check_threepid(self, medium, authdict): if "threepid_creds" not in authdict: raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) @@ -155,18 +154,18 @@ class _BaseThreepidAuthChecker: raise SynapseError( 400, "Phone number verification is not enabled on this homeserver" ) - threepid = yield identity_handler.threepid_from_creds( + threepid = await identity_handler.threepid_from_creds( self.hs.config.account_threepid_delegate_msisdn, threepid_creds ) elif medium == "email": if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: assert self.hs.config.account_threepid_delegate_email - threepid = yield identity_handler.threepid_from_creds( + threepid = await identity_handler.threepid_from_creds( self.hs.config.account_threepid_delegate_email, threepid_creds ) elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: threepid = None - row = yield self.store.get_threepid_validation_session( + row = await self.store.get_threepid_validation_session( medium, threepid_creds["client_secret"], sid=threepid_creds["sid"], @@ -181,7 +180,7 @@ class _BaseThreepidAuthChecker: } # Valid threepid returned, delete from the db - yield self.store.delete_threepid_session(threepid_creds["sid"]) + await self.store.delete_threepid_session(threepid_creds["sid"]) else: raise SynapseError( 400, "Email address verification is not enabled on this homeserver" @@ -220,7 +219,7 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec ) def check_auth(self, authdict, clientip): - return self._check_threepid("email", authdict) + return defer.ensureDeferred(self._check_threepid("email", authdict)) class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): @@ -234,7 +233,7 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): return bool(self.hs.config.account_threepid_delegate_msisdn) def check_auth(self, authdict, clientip): - return self._check_threepid("msisdn", authdict) + return defer.ensureDeferred(self._check_threepid("msisdn", authdict)) INTERACTIVE_AUTH_CHECKERS = [ diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 722760c59d..12423b909a 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -17,14 +17,11 @@ import logging from six import iteritems, iterkeys -from twisted.internet import defer - import synapse.metrics from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.handlers.state_deltas import StateDeltasHandler from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.roommember import ProfileInfo -from synapse.types import get_localpart_from_id from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -103,43 +100,39 @@ class UserDirectoryHandler(StateDeltasHandler): if self._is_processing: return - @defer.inlineCallbacks - def process(): + async def process(): try: - yield self._unsafe_process() + await self._unsafe_process() finally: self._is_processing = False self._is_processing = True run_as_background_process("user_directory.notify_new_event", process) - @defer.inlineCallbacks - def handle_local_profile_change(self, user_id, profile): + async def handle_local_profile_change(self, user_id, profile): """Called to update index of our local user profiles when they change irrespective of any rooms the user may be in. """ # FIXME(#3714): We should probably do this in the same worker as all # the other changes. - is_support = yield self.store.is_support_user(user_id) + is_support = await self.store.is_support_user(user_id) # Support users are for diagnostics and should not appear in the user directory. if not is_support: - yield self.store.update_profile_in_user_dir( + await self.store.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) - @defer.inlineCallbacks - def handle_user_deactivated(self, user_id): + async def handle_user_deactivated(self, user_id): """Called when a user ID is deactivated """ # FIXME(#3714): We should probably do this in the same worker as all # the other changes. - yield self.store.remove_from_user_dir(user_id) + await self.store.remove_from_user_dir(user_id) - @defer.inlineCallbacks - def _unsafe_process(self): + async def _unsafe_process(self): # If self.pos is None then means we haven't fetched it from DB if self.pos is None: - self.pos = yield self.store.get_user_directory_stream_pos() + self.pos = await self.store.get_user_directory_stream_pos() # If still None then the initial background update hasn't happened yet if self.pos is None: @@ -155,12 +148,12 @@ class UserDirectoryHandler(StateDeltasHandler): logger.debug( "Processing user stats %s->%s", self.pos, room_max_stream_ordering ) - max_pos, deltas = yield self.store.get_current_state_deltas( + max_pos, deltas = await self.store.get_current_state_deltas( self.pos, room_max_stream_ordering ) logger.debug("Handling %d state deltas", len(deltas)) - yield self._handle_deltas(deltas) + await self._handle_deltas(deltas) self.pos = max_pos @@ -169,10 +162,9 @@ class UserDirectoryHandler(StateDeltasHandler): max_pos ) - yield self.store.update_user_directory_stream_pos(max_pos) + await self.store.update_user_directory_stream_pos(max_pos) - @defer.inlineCallbacks - def _handle_deltas(self, deltas): + async def _handle_deltas(self, deltas): """Called with the state deltas to process """ for delta in deltas: @@ -187,11 +179,11 @@ class UserDirectoryHandler(StateDeltasHandler): # For join rule and visibility changes we need to check if the room # may have become public or not and add/remove the users in said room if typ in (EventTypes.RoomHistoryVisibility, EventTypes.JoinRules): - yield self._handle_room_publicity_change( + await self._handle_room_publicity_change( room_id, prev_event_id, event_id, typ ) elif typ == EventTypes.Member: - change = yield self._get_key_change( + change = await self._get_key_change( prev_event_id, event_id, key_name="membership", @@ -201,7 +193,7 @@ class UserDirectoryHandler(StateDeltasHandler): if change is False: # Need to check if the server left the room entirely, if so # we might need to remove all the users in that room - is_in_room = yield self.store.is_host_joined( + is_in_room = await self.store.is_host_joined( room_id, self.server_name ) if not is_in_room: @@ -209,40 +201,41 @@ class UserDirectoryHandler(StateDeltasHandler): # Fetch all the users that we marked as being in user # directory due to being in the room and then check if # need to remove those users or not - user_ids = yield self.store.get_users_in_dir_due_to_room( + user_ids = await self.store.get_users_in_dir_due_to_room( room_id ) for user_id in user_ids: - yield self._handle_remove_user(room_id, user_id) + await self._handle_remove_user(room_id, user_id) return else: logger.debug("Server is still in room: %r", room_id) - is_support = yield self.store.is_support_user(state_key) + is_support = await self.store.is_support_user(state_key) if not is_support: if change is None: # Handle any profile changes - yield self._handle_profile_change( + await self._handle_profile_change( state_key, room_id, prev_event_id, event_id ) continue if change: # The user joined - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) profile = ProfileInfo( avatar_url=event.content.get("avatar_url"), display_name=event.content.get("displayname"), ) - yield self._handle_new_user(room_id, state_key, profile) + await self._handle_new_user(room_id, state_key, profile) else: # The user left - yield self._handle_remove_user(room_id, state_key) + await self._handle_remove_user(room_id, state_key) else: logger.debug("Ignoring irrelevant type: %r", typ) - @defer.inlineCallbacks - def _handle_room_publicity_change(self, room_id, prev_event_id, event_id, typ): + async def _handle_room_publicity_change( + self, room_id, prev_event_id, event_id, typ + ): """Handle a room having potentially changed from/to world_readable/publically joinable. @@ -255,14 +248,14 @@ class UserDirectoryHandler(StateDeltasHandler): logger.debug("Handling change for %s: %s", typ, room_id) if typ == EventTypes.RoomHistoryVisibility: - change = yield self._get_key_change( + change = await self._get_key_change( prev_event_id, event_id, key_name="history_visibility", public_value="world_readable", ) elif typ == EventTypes.JoinRules: - change = yield self._get_key_change( + change = await self._get_key_change( prev_event_id, event_id, key_name="join_rule", @@ -278,7 +271,7 @@ class UserDirectoryHandler(StateDeltasHandler): # There's been a change to or from being world readable. - is_public = yield self.store.is_room_world_readable_or_publicly_joinable( + is_public = await self.store.is_room_world_readable_or_publicly_joinable( room_id ) @@ -293,11 +286,11 @@ class UserDirectoryHandler(StateDeltasHandler): # ignore the change return - users_with_profile = yield self.state.get_current_users_in_room(room_id) + users_with_profile = await self.state.get_current_users_in_room(room_id) # Remove every user from the sharing tables for that room. for user_id in iterkeys(users_with_profile): - yield self.store.remove_user_who_share_room(user_id, room_id) + await self.store.remove_user_who_share_room(user_id, room_id) # Then, re-add them to the tables. # NOTE: this is not the most efficient method, as handle_new_user sets @@ -306,26 +299,9 @@ class UserDirectoryHandler(StateDeltasHandler): # being added multiple times. The batching upserts shouldn't make this # too bad, though. for user_id, profile in iteritems(users_with_profile): - yield self._handle_new_user(room_id, user_id, profile) - - @defer.inlineCallbacks - def _handle_local_user(self, user_id): - """Adds a new local roomless user into the user_directory_search table. - Used to populate up the user index when we have an - user_directory_search_all_users specified. - """ - logger.debug("Adding new local user to dir, %r", user_id) - - profile = yield self.store.get_profileinfo(get_localpart_from_id(user_id)) - - row = yield self.store.get_user_in_directory(user_id) - if not row: - yield self.store.update_profile_in_user_dir( - user_id, profile.display_name, profile.avatar_url - ) + await self._handle_new_user(room_id, user_id, profile) - @defer.inlineCallbacks - def _handle_new_user(self, room_id, user_id, profile): + async def _handle_new_user(self, room_id, user_id, profile): """Called when we might need to add user to directory Args: @@ -334,18 +310,18 @@ class UserDirectoryHandler(StateDeltasHandler): """ logger.debug("Adding new user to dir, %r", user_id) - yield self.store.update_profile_in_user_dir( + await self.store.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) - is_public = yield self.store.is_room_world_readable_or_publicly_joinable( + is_public = await self.store.is_room_world_readable_or_publicly_joinable( room_id ) # Now we update users who share rooms with users. - users_with_profile = yield self.state.get_current_users_in_room(room_id) + users_with_profile = await self.state.get_current_users_in_room(room_id) if is_public: - yield self.store.add_users_in_public_rooms(room_id, (user_id,)) + await self.store.add_users_in_public_rooms(room_id, (user_id,)) else: to_insert = set() @@ -376,10 +352,9 @@ class UserDirectoryHandler(StateDeltasHandler): to_insert.add((other_user_id, user_id)) if to_insert: - yield self.store.add_users_who_share_private_room(room_id, to_insert) + await self.store.add_users_who_share_private_room(room_id, to_insert) - @defer.inlineCallbacks - def _handle_remove_user(self, room_id, user_id): + async def _handle_remove_user(self, room_id, user_id): """Called when we might need to remove user from directory Args: @@ -389,24 +364,23 @@ class UserDirectoryHandler(StateDeltasHandler): logger.debug("Removing user %r", user_id) # Remove user from sharing tables - yield self.store.remove_user_who_share_room(user_id, room_id) + await self.store.remove_user_who_share_room(user_id, room_id) # Are they still in any rooms? If not, remove them entirely. - rooms_user_is_in = yield self.store.get_user_dir_rooms_user_is_in(user_id) + rooms_user_is_in = await self.store.get_user_dir_rooms_user_is_in(user_id) if len(rooms_user_is_in) == 0: - yield self.store.remove_from_user_dir(user_id) + await self.store.remove_from_user_dir(user_id) - @defer.inlineCallbacks - def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id): + async def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id): """Check member event changes for any profile changes and update the database if there are. """ if not prev_event_id or not event_id: return - prev_event = yield self.store.get_event(prev_event_id, allow_none=True) - event = yield self.store.get_event(event_id, allow_none=True) + prev_event = await self.store.get_event(prev_event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) if not prev_event or not event: return @@ -421,4 +395,4 @@ class UserDirectoryHandler(StateDeltasHandler): new_avatar = event.content.get("avatar_url") if prev_name != new_name or prev_avatar != new_avatar: - yield self.store.update_profile_in_user_dir(user_id, new_name, new_avatar) + await self.store.update_profile_in_user_dir(user_id, new_name, new_avatar) diff --git a/synapse/http/server.py b/synapse/http/server.py index 9cc2e2e154..2487a72171 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -21,13 +21,15 @@ import logging import types import urllib from io import BytesIO +from typing import Awaitable, Callable, TypeVar, Union +import jinja2 from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json from twisted.internet import defer from twisted.python import failure from twisted.web import resource -from twisted.web.server import NOT_DONE_YET +from twisted.web.server import NOT_DONE_YET, Request from twisted.web.static import NoRangeStaticProducer from twisted.web.util import redirectTo @@ -40,6 +42,7 @@ from synapse.api.errors import ( SynapseError, UnrecognizedRequestError, ) +from synapse.http.site import SynapseRequest from synapse.logging.context import preserve_fn from synapse.logging.opentracing import trace_servlet from synapse.util.caches import intern_dict @@ -130,7 +133,12 @@ def wrap_json_request_handler(h): return wrap_async_request_handler(wrapped_request_handler) -def wrap_html_request_handler(h): +TV = TypeVar("TV") + + +def wrap_html_request_handler( + h: Callable[[TV, SynapseRequest], Awaitable] +) -> Callable[[TV, SynapseRequest], Awaitable[None]]: """Wraps a request handler method with exception handling. Also does the wrapping with request.processing as per wrap_async_request_handler. @@ -141,20 +149,26 @@ def wrap_html_request_handler(h): async def wrapped_request_handler(self, request): try: - return await h(self, request) + await h(self, request) except Exception: f = failure.Failure() - return _return_html_error(f, request) + return_html_error(f, request, HTML_ERROR_TEMPLATE) return wrap_async_request_handler(wrapped_request_handler) -def _return_html_error(f, request): - """Sends an HTML error page corresponding to the given failure +def return_html_error( + f: failure.Failure, request: Request, error_template: Union[str, jinja2.Template], +) -> None: + """Sends an HTML error page corresponding to the given failure. + + Handles RedirectException and other CodeMessageExceptions (such as SynapseError) Args: - f (twisted.python.failure.Failure): - request (twisted.web.server.Request): + f: the error to report + request: the failing request + error_template: the HTML template. Can be either a string (with `{code}`, + `{msg}` placeholders), or a jinja2 template """ if f.check(CodeMessageException): cme = f.value @@ -174,7 +188,7 @@ def _return_html_error(f, request): exc_info=(f.type, f.value, f.getTracebackObject()), ) else: - code = http.client.INTERNAL_SERVER_ERROR + code = http.HTTPStatus.INTERNAL_SERVER_ERROR msg = "Internal server error" logger.error( @@ -183,11 +197,16 @@ def _return_html_error(f, request): exc_info=(f.type, f.value, f.getTracebackObject()), ) - body = HTML_ERROR_TEMPLATE.format(code=code, msg=html.escape(msg)).encode("utf-8") + if isinstance(error_template, str): + body = error_template.format(code=code, msg=html.escape(msg)) + else: + body = error_template.render(code=code, msg=msg) + + body_bytes = body.encode("utf-8") request.setResponseCode(code) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%i" % (len(body),)) - request.write(body) + request.setHeader(b"Content-Length", b"%i" % (len(body_bytes),)) + request.write(body_bytes) finish_request(request) diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 8449ef82f7..13785038ad 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -17,16 +17,18 @@ import logging import threading from asyncio import iscoroutine from functools import wraps -from typing import Dict, Set +from typing import TYPE_CHECKING, Dict, Optional, Set -import six - -from prometheus_client.core import REGISTRY, Counter, GaugeMetricFamily +from prometheus_client.core import REGISTRY, Counter, Gauge from twisted.internet import defer from synapse.logging.context import LoggingContext, PreserveLoggingContext +if TYPE_CHECKING: + import resource + + logger = logging.getLogger(__name__) @@ -36,6 +38,12 @@ _background_process_start_count = Counter( ["name"], ) +_background_process_in_flight_count = Gauge( + "synapse_background_process_in_flight_count", + "Number of background processes in flight", + labelnames=["name"], +) + # we set registry=None in all of these to stop them getting registered with # the default registry. Instead we collect them all via the CustomCollector, # which ensures that we can update them before they are collected. @@ -83,13 +91,17 @@ _background_process_db_sched_duration = Counter( # it's much simpler to do so than to try to combine them.) _background_process_counts = {} # type: Dict[str, int] -# map from description to the currently running background processes. +# Set of all running background processes that became active active since the +# last time metrics were scraped (i.e. background processes that performed some +# work since the last scrape.) # -# it's kept as a dict of sets rather than a big set so that we can keep track -# of process descriptions that no longer have any active processes. -_background_processes = {} # type: Dict[str, Set[_BackgroundProcess]] +# We do it like this to handle the case where we have a large number of +# background processes stacking up behind a lock or linearizer, where we then +# only need to iterate over and update metrics for the process that have +# actually been active and can ignore the idle ones. +_background_processes_active_since_last_scrape = set() # type: Set[_BackgroundProcess] -# A lock that covers the above dicts +# A lock that covers the above set and dict _bg_metrics_lock = threading.Lock() @@ -101,25 +113,16 @@ class _Collector(object): """ def collect(self): - background_process_in_flight_count = GaugeMetricFamily( - "synapse_background_process_in_flight_count", - "Number of background processes in flight", - labels=["name"], - ) + global _background_processes_active_since_last_scrape - # We copy the dict so that it doesn't change from underneath us. - # We also copy the process lists as that can also change + # We swap out the _background_processes set with an empty one so that + # we can safely iterate over the set without holding the lock. with _bg_metrics_lock: - _background_processes_copy = { - k: list(v) for k, v in six.iteritems(_background_processes) - } + _background_processes_copy = _background_processes_active_since_last_scrape + _background_processes_active_since_last_scrape = set() - for desc, processes in six.iteritems(_background_processes_copy): - background_process_in_flight_count.add_metric((desc,), len(processes)) - for process in processes: - process.update_metrics() - - yield background_process_in_flight_count + for process in _background_processes_copy: + process.update_metrics() # now we need to run collect() over each of the static Counters, and # yield each metric they return. @@ -191,13 +194,10 @@ def run_as_background_process(desc, func, *args, **kwargs): _background_process_counts[desc] = count + 1 _background_process_start_count.labels(desc).inc() + _background_process_in_flight_count.labels(desc).inc() - with LoggingContext(desc) as context: + with BackgroundProcessLoggingContext(desc) as context: context.request = "%s-%i" % (desc, count) - proc = _BackgroundProcess(desc, context) - - with _bg_metrics_lock: - _background_processes.setdefault(desc, set()).add(proc) try: result = func(*args, **kwargs) @@ -214,10 +214,7 @@ def run_as_background_process(desc, func, *args, **kwargs): except Exception: logger.exception("Background process '%s' threw an exception", desc) finally: - proc.update_metrics() - - with _bg_metrics_lock: - _background_processes[desc].remove(proc) + _background_process_in_flight_count.labels(desc).dec() with PreserveLoggingContext(): return run() @@ -238,3 +235,42 @@ def wrap_as_background_process(desc): return wrap_as_background_process_inner_2 return wrap_as_background_process_inner + + +class BackgroundProcessLoggingContext(LoggingContext): + """A logging context that tracks in flight metrics for background + processes. + """ + + __slots__ = ["_proc"] + + def __init__(self, name: str): + super().__init__(name) + + self._proc = _BackgroundProcess(name, self) + + def start(self, rusage: "Optional[resource._RUsage]"): + """Log context has started running (again). + """ + + super().start(rusage) + + # We've become active again so we make sure we're in the list of active + # procs. (Note that "start" here means we've become active, as opposed + # to starting for the first time.) + with _bg_metrics_lock: + _background_processes_active_since_last_scrape.add(self._proc) + + def __exit__(self, type, value, traceback) -> None: + """Log context has finished. + """ + + super().__exit__(type, value, traceback) + + # The background process has finished. We explictly remove and manually + # update the metrics here so that if nothing is scraping metrics the set + # doesn't infinitely grow. + with _bg_metrics_lock: + _background_processes_active_since_last_scrape.discard(self._proc) + + self._proc.update_metrics() diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index d678c0eb9b..ecdf1ad69f 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -128,8 +128,12 @@ class ModuleApi(object): Returns: Deferred[str]: user_id """ - return self._hs.get_registration_handler().register_user( - localpart=localpart, default_display_name=displayname, bind_emails=emails + return defer.ensureDeferred( + self._hs.get_registration_handler().register_user( + localpart=localpart, + default_display_name=displayname, + bind_emails=emails, + ) ) def register_device(self, user_id, device_id=None, initial_display_name=None): diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index 2a4f5c7cfd..9db6c62bc7 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -24,7 +24,13 @@ from synapse.storage.database import Database class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): def __init__(self, database: Database, db_conn, hs): self._account_data_id_gen = SlavedIdTracker( - db_conn, "account_data_max_stream_id", "stream_id" + db_conn, + "account_data", + "stream_id", + extra_tables=[ + ("room_account_data", "stream_id"), + ("room_tags_revisions", "stream_id"), + ], ) super(SlavedAccountDataStore, self).__init__(database, db_conn, hs) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 508ad1b720..df29732f51 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -171,7 +171,7 @@ class ReplicationDataHandler: pass else: # The list is sorted by position so we don't need to continue - # checking any futher entries in the list. + # checking any further entries in the list. index_of_first_deferred_not_called = idx break diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index d42aaff055..4acefc8a96 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -600,8 +600,14 @@ class AccountDataStream(Stream): for stream_id, user_id, room_id, account_data_type in room_results ) - # we need to return a sorted list, so merge them together. - updates = list(heapq.merge(room_rows, global_rows)) + # We need to return a sorted list, so merge them together. + # + # Note: We order only by the stream ID to work around a bug where the + # same stream ID could appear in both `global_rows` and `room_rows`, + # leading to a comparison between the data tuples. The comparison could + # fail due to attempting to compare the `room_id` which results in a + # `TypeError` from comparing a `str` vs `None`. + updates = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0])) return updates, to_token, limited diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 6b85148a32..9eda592de9 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -26,6 +26,11 @@ from synapse.rest.admin._base import ( assert_requester_is_admin, historical_admin_path_patterns, ) +from synapse.rest.admin.devices import ( + DeleteDevicesRestServlet, + DeviceRestServlet, + DevicesRestServlet, +) from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet @@ -202,6 +207,9 @@ def register_servlets(hs, http_server): UserAdminServlet(hs).register(http_server) UserRestServletV2(hs).register(http_server) UsersRestServletV2(hs).register(http_server) + DeviceRestServlet(hs).register(http_server) + DevicesRestServlet(hs).register(http_server) + DeleteDevicesRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py index a96f75ce26..d82eaf5e38 100644 --- a/synapse/rest/admin/_base.py +++ b/synapse/rest/admin/_base.py @@ -15,7 +15,11 @@ import re +import twisted.web.server + +import synapse.api.auth from synapse.api.errors import AuthError +from synapse.types import UserID def historical_admin_path_patterns(path_regex): @@ -55,41 +59,32 @@ def admin_patterns(path_regex: str): return patterns -async def assert_requester_is_admin(auth, request): +async def assert_requester_is_admin( + auth: synapse.api.auth.Auth, request: twisted.web.server.Request +) -> None: """Verify that the requester is an admin user - WARNING: MAKE SURE YOU YIELD ON THE RESULT! - Args: - auth (synapse.api.auth.Auth): - request (twisted.web.server.Request): incoming request - - Returns: - Deferred + auth: api.auth.Auth singleton + request: incoming request Raises: - AuthError if the requester is not an admin + AuthError if the requester is not a server admin """ requester = await auth.get_user_by_req(request) await assert_user_is_admin(auth, requester.user) -async def assert_user_is_admin(auth, user_id): +async def assert_user_is_admin(auth: synapse.api.auth.Auth, user_id: UserID) -> None: """Verify that the given user is an admin user - WARNING: MAKE SURE YOU YIELD ON THE RESULT! - Args: - auth (synapse.api.auth.Auth): - user_id (UserID): - - Returns: - Deferred + auth: api.auth.Auth singleton + user_id: user to check Raises: - AuthError if the user is not an admin + AuthError if the user is not a server admin """ - is_admin = await auth.is_server_admin(user_id) if not is_admin: raise AuthError(403, "You are not a server admin") diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py new file mode 100644 index 0000000000..8d32677339 --- /dev/null +++ b/synapse/rest/admin/devices.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import re + +from synapse.api.errors import NotFoundError, SynapseError +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) +from synapse.rest.admin._base import assert_requester_is_admin +from synapse.types import UserID + +logger = logging.getLogger(__name__) + + +class DeviceRestServlet(RestServlet): + """ + Get, update or delete the given user's device + """ + + PATTERNS = ( + re.compile( + "^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$" + ), + ) + + def __init__(self, hs): + super(DeviceRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + self.store = hs.get_datastore() + + async def on_GET(self, request, user_id, device_id): + await assert_requester_is_admin(self.auth, request) + + target_user = UserID.from_string(user_id) + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only lookup local users") + + u = await self.store.get_user_by_id(target_user.to_string()) + if u is None: + raise NotFoundError("Unknown user") + + device = await self.device_handler.get_device( + target_user.to_string(), device_id + ) + return 200, device + + async def on_DELETE(self, request, user_id, device_id): + await assert_requester_is_admin(self.auth, request) + + target_user = UserID.from_string(user_id) + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only lookup local users") + + u = await self.store.get_user_by_id(target_user.to_string()) + if u is None: + raise NotFoundError("Unknown user") + + await self.device_handler.delete_device(target_user.to_string(), device_id) + return 200, {} + + async def on_PUT(self, request, user_id, device_id): + await assert_requester_is_admin(self.auth, request) + + target_user = UserID.from_string(user_id) + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only lookup local users") + + u = await self.store.get_user_by_id(target_user.to_string()) + if u is None: + raise NotFoundError("Unknown user") + + body = parse_json_object_from_request(request, allow_empty_body=True) + await self.device_handler.update_device( + target_user.to_string(), device_id, body + ) + return 200, {} + + +class DevicesRestServlet(RestServlet): + """ + Retrieve the given user's devices + """ + + PATTERNS = (re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/devices$"),) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + self.store = hs.get_datastore() + + async def on_GET(self, request, user_id): + await assert_requester_is_admin(self.auth, request) + + target_user = UserID.from_string(user_id) + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only lookup local users") + + u = await self.store.get_user_by_id(target_user.to_string()) + if u is None: + raise NotFoundError("Unknown user") + + devices = await self.device_handler.get_devices_by_user(target_user.to_string()) + return 200, {"devices": devices} + + +class DeleteDevicesRestServlet(RestServlet): + """ + API for bulk deletion of devices. Accepts a JSON object with a devices + key which lists the device_ids to delete. + """ + + PATTERNS = ( + re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/delete_devices$"), + ) + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + self.store = hs.get_datastore() + + async def on_POST(self, request, user_id): + await assert_requester_is_admin(self.auth, request) + + target_user = UserID.from_string(user_id) + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only lookup local users") + + u = await self.store.get_user_by_id(target_user.to_string()) + if u is None: + raise NotFoundError("Unknown user") + + body = parse_json_object_from_request(request, allow_empty_body=False) + assert_params_in_dict(body, ["devices"]) + + await self.device_handler.delete_devices( + target_user.to_string(), body["devices"] + ) + return 200, {} diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index e7f6928c85..fefc8f71fa 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -142,6 +142,7 @@ class UserRestServletV2(RestServlet): self.set_password_handler = hs.get_set_password_handler() self.deactivate_account_handler = hs.get_deactivate_account_handler() self.registration_handler = hs.get_registration_handler() + self.pusher_pool = hs.get_pusherpool() async def on_GET(self, request, user_id): await assert_requester_is_admin(self.auth, request) @@ -269,6 +270,7 @@ class UserRestServletV2(RestServlet): admin=bool(admin), default_display_name=displayname, user_type=user_type, + by_admin=True, ) if "threepids" in body: @@ -281,6 +283,21 @@ class UserRestServletV2(RestServlet): await self.auth_handler.add_threepid( user_id, threepid["medium"], threepid["address"], current_time ) + if ( + self.hs.config.email_enable_notifs + and self.hs.config.email_notif_for_new_users + ): + await self.pusher_pool.add_pusher( + user_id=user_id, + access_token=None, + kind="email", + app_id="m.email", + app_display_name="Email Notifications", + device_display_name=threepid["address"], + pushkey=threepid["address"], + lang=None, # We don't know a user's language here + data={}, + ) if "avatar_url" in body and type(body["avatar_url"]) == str: await self.profile_handler.set_avatar_url( @@ -416,6 +433,7 @@ class UserRegisterServlet(RestServlet): password_hash=password_hash, admin=bool(admin), user_type=user_type, + by_admin=True, ) result = await register._create_registration_details(user_id, body) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index d89b2e5532..dceb2792fa 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -87,11 +87,22 @@ class LoginRestServlet(RestServlet): self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.handlers = hs.get_handlers() - self._clock = hs.get_clock() self._well_known_builder = WellKnownBuilder(hs) - self._address_ratelimiter = Ratelimiter() - self._account_ratelimiter = Ratelimiter() - self._failed_attempts_ratelimiter = Ratelimiter() + self._address_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_address.per_second, + burst_count=self.hs.config.rc_login_address.burst_count, + ) + self._account_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_account.per_second, + burst_count=self.hs.config.rc_login_account.burst_count, + ) + self._failed_attempts_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.rc_login_failed_attempts.burst_count, + ) def on_GET(self, request): flows = [] @@ -99,25 +110,20 @@ class LoginRestServlet(RestServlet): flows.append({"type": LoginRestServlet.JWT_TYPE}) if self.cas_enabled: - flows.append({"type": LoginRestServlet.SSO_TYPE}) - # we advertise CAS for backwards compat, though MSC1721 renamed it # to SSO. flows.append({"type": LoginRestServlet.CAS_TYPE}) + if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: + flows.append({"type": LoginRestServlet.SSO_TYPE}) # While its valid for us to advertise this login type generally, # synapse currently only gives out these tokens as part of the - # CAS login flow. + # SSO login flow. # Generally we don't want to advertise login flows that clients # don't know how to implement, since they (currently) will always # fall back to the fallback API if they don't understand one of the # login flow types returned. flows.append({"type": LoginRestServlet.TOKEN_TYPE}) - elif self.saml2_enabled: - flows.append({"type": LoginRestServlet.SSO_TYPE}) - flows.append({"type": LoginRestServlet.TOKEN_TYPE}) - elif self.oidc_enabled: - flows.append({"type": LoginRestServlet.SSO_TYPE}) flows.extend( ({"type": t} for t in self.auth_handler.get_supported_login_types()) @@ -129,13 +135,7 @@ class LoginRestServlet(RestServlet): return 200, {} async def on_POST(self, request): - self._address_ratelimiter.ratelimit( - request.getClientIP(), - time_now_s=self.hs.clock.time(), - rate_hz=self.hs.config.rc_login_address.per_second, - burst_count=self.hs.config.rc_login_address.burst_count, - update=True, - ) + self._address_ratelimiter.ratelimit(request.getClientIP()) login_submission = parse_json_object_from_request(request) try: @@ -203,13 +203,7 @@ class LoginRestServlet(RestServlet): # We also apply account rate limiting using the 3PID as a key, as # otherwise using 3PID bypasses the ratelimiting based on user ID. - self._failed_attempts_ratelimiter.ratelimit( - (medium, address), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=False, - ) + self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False) # Check for login providers that support 3pid login types ( @@ -243,13 +237,7 @@ class LoginRestServlet(RestServlet): # If it returned None but the 3PID was bound then we won't hit # this code path, which is fine as then the per-user ratelimit # will kick in below. - self._failed_attempts_ratelimiter.can_do_action( - (medium, address), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=True, - ) + self._failed_attempts_ratelimiter.can_do_action((medium, address)) raise LoginError(403, "", errcode=Codes.FORBIDDEN) identifier = {"type": "m.id.user", "user": user_id} @@ -268,11 +256,7 @@ class LoginRestServlet(RestServlet): # Check if we've hit the failed ratelimit (but don't update it) self._failed_attempts_ratelimiter.ratelimit( - qualified_user_id.lower(), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=False, + qualified_user_id.lower(), update=False ) try: @@ -284,13 +268,7 @@ class LoginRestServlet(RestServlet): # limiter. Using `can_do_action` avoids us raising a ratelimit # exception and masking the LoginError. The actual ratelimiting # should have happened above. - self._failed_attempts_ratelimiter.can_do_action( - qualified_user_id.lower(), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=True, - ) + self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower()) raise result = await self._complete_login( @@ -299,7 +277,7 @@ class LoginRestServlet(RestServlet): return result async def _complete_login( - self, user_id, login_submission, callback=None, create_non_existant_users=False + self, user_id, login_submission, callback=None, create_non_existent_users=False ): """Called when we've successfully authed the user and now need to actually login them in (e.g. create devices). This gets called on @@ -312,7 +290,7 @@ class LoginRestServlet(RestServlet): user_id (str): ID of the user to register. login_submission (dict): Dictionary of login information. callback (func|None): Callback function to run after registration. - create_non_existant_users (bool): Whether to create the user if + create_non_existent_users (bool): Whether to create the user if they don't exist. Defaults to False. Returns: @@ -323,20 +301,15 @@ class LoginRestServlet(RestServlet): # Before we actually log them in we check if they've already logged in # too often. This happens here rather than before as we don't # necessarily know the user before now. - self._account_ratelimiter.ratelimit( - user_id.lower(), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_account.per_second, - burst_count=self.hs.config.rc_login_account.burst_count, - update=True, - ) + self._account_ratelimiter.ratelimit(user_id.lower()) - if create_non_existant_users: - user_id = await self.auth_handler.check_user_exists(user_id) - if not user_id: - user_id = await self.registration_handler.register_user( + if create_non_existent_users: + canonical_uid = await self.auth_handler.check_user_exists(user_id) + if not canonical_uid: + canonical_uid = await self.registration_handler.register_user( localpart=UserID.from_string(user_id).localpart ) + user_id = canonical_uid device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") @@ -391,7 +364,7 @@ class LoginRestServlet(RestServlet): user_id = UserID(user, self.hs.hostname).to_string() result = await self._complete_login( - user_id, login_submission, create_non_existant_users=True + user_id, login_submission, create_non_existent_users=True ) return result diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 8f41a3edbf..24bb090822 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -42,7 +42,7 @@ class KeyUploadServlet(RestServlet): "device_id": "<device_id>", "valid_until_ts": <millisecond_timestamp>, "algorithms": [ - "m.olm.curve25519-aes-sha256", + "m.olm.curve25519-aes-sha2", ] "keys": { "<algorithm>:<device_id>": "<key_base64>", @@ -124,7 +124,7 @@ class KeyQueryServlet(RestServlet): "device_id": "<device_id>", // Duplicated to be signed "valid_until_ts": <millisecond_timestamp>, "algorithms": [ // List of supported algorithms - "m.olm.curve25519-aes-sha256", + "m.olm.curve25519-aes-sha2", ], "keys": { // Must include a ed25519 signing key "<algorithm>:<key_id>": "<key_base64>", @@ -285,8 +285,8 @@ class SignaturesUploadServlet(RestServlet): "user_id": "<user_id>", "device_id": "<device_id>", "algorithms": [ - "m.olm.curve25519-aes-sha256", - "m.megolm.v1.aes-sha" + "m.olm.curve25519-aes-sha2", + "m.megolm.v1.aes-sha2" ], "keys": { "<algorithm>:<device_id>": "<key_base64>", diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index c35ee81448..5f7c7d0081 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -28,7 +28,6 @@ import synapse.types from synapse.api.constants import LoginType from synapse.api.errors import ( Codes, - LimitExceededError, SynapseError, ThreepidValidationError, UnrecognizedRequestError, @@ -394,20 +393,7 @@ class RegisterRestServlet(RestServlet): client_addr = request.getClientIP() - time_now = self.clock.time() - - allowed, time_allowed = self.ratelimiter.can_do_action( - client_addr, - time_now_s=time_now, - rate_hz=self.hs.config.rc_registration.per_second, - burst_count=self.hs.config.rc_registration.burst_count, - update=False, - ) - - if not allowed: - raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)) - ) + self.ratelimiter.ratelimit(client_addr, update=False) kind = b"user" if b"kind" in request.args: diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index c99250c2ee..b1999d051b 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -49,24 +49,10 @@ class VersionsRestServlet(RestServlet): "r0.3.0", "r0.4.0", "r0.5.0", + "r0.6.0", ], # as per MSC1497: "unstable_features": { - # as per MSC2190, as amended by MSC2264 - # to be removed in r0.6.0 - # "m.id_access_token": True, - # Advertise to clients that they need not include an `id_server` - # parameter during registration or password reset, as Synapse now decides - # itself which identity server to use (or none at all). - # - # This is also used by a client when they wish to bind a 3PID to their - # account, but not bind it to an identity server, the endpoint for which - # also requires `id_server`. If the homeserver is handling 3PID - # verification itself, there is no need to ask the user for `id_server` to - # be supplied. - # "m.require_identity_server": False, - # as per MSC2290 - # "m.separate_add_and_bind": True, # Implements support for label-based filtering as described in # MSC2326. "org.matrix.label_based_filtering": True, diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py index a545c13db7..75e58043b4 100644 --- a/synapse/rest/saml2/response_resource.py +++ b/synapse/rest/saml2/response_resource.py @@ -13,12 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from twisted.python import failure -from synapse.http.server import ( - DirectServeResource, - finish_request, - wrap_html_request_handler, -) +from synapse.api.errors import SynapseError +from synapse.http.server import DirectServeResource, return_html_error class SAML2ResponseResource(DirectServeResource): @@ -28,20 +26,22 @@ class SAML2ResponseResource(DirectServeResource): def __init__(self, hs): super().__init__() - self._error_html_content = hs.config.saml2_error_html_content self._saml_handler = hs.get_saml_handler() + self._error_html_template = hs.config.saml2.saml2_error_html_template async def _async_render_GET(self, request): # We're not expecting any GET request on that resource if everything goes right, # but some IdPs sometimes end up responding with a 302 redirect on this endpoint. # In this case, just tell the user that something went wrong and they should # try to authenticate again. - request.setResponseCode(400) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(self._error_html_content),)) - request.write(self._error_html_content.encode("utf8")) - finish_request(request) + f = failure.Failure( + SynapseError(400, "Unexpected GET request on /saml2/authn_response") + ) + return_html_error(f, request, self._error_html_template) - @wrap_html_request_handler async def _async_render_POST(self, request): - return await self._saml_handler.handle_saml_response(request) + try: + await self._saml_handler.handle_saml_response(request) + except Exception: + f = failure.Failure() + return_html_error(f, request, self._error_html_template) diff --git a/synapse/server.py b/synapse/server.py index 6a4eb8294f..d14b6b722c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -244,9 +244,12 @@ class HomeServer(object): self.clock = Clock(reactor) self.distributor = Distributor() - self.ratelimiter = Ratelimiter() - self.admin_redaction_ratelimiter = Ratelimiter() - self.registration_ratelimiter = Ratelimiter() + + self.registration_ratelimiter = Ratelimiter( + clock=self.clock, + rate_hz=config.rc_registration.per_second, + burst_count=config.rc_registration.burst_count, + ) self.datastores = None @@ -316,15 +319,9 @@ class HomeServer(object): def get_distributor(self): return self.distributor - def get_ratelimiter(self): - return self.ratelimiter - - def get_registration_ratelimiter(self): + def get_registration_ratelimiter(self) -> Ratelimiter: return self.registration_ratelimiter - def get_admin_redaction_ratelimiter(self): - return self.admin_redaction_ratelimiter - def build_federation_client(self): return FederationClient(self) diff --git a/synapse/static/client/login/js/login.js b/synapse/static/client/login/js/login.js index 611df36d3f..ba8048b23f 100644 --- a/synapse/static/client/login/js/login.js +++ b/synapse/static/client/login/js/login.js @@ -4,30 +4,41 @@ window.matrixLogin = { serverAcceptsSso: false, }; -var title_pre_auth = "Log in with one of the following methods"; -var title_post_auth = "Logging in..."; - -var submitPassword = function(user, pwd) { - console.log("Logging in with password..."); - set_title(title_post_auth); - var data = { - type: "m.login.password", - user: user, - password: pwd, - }; - $.post(matrixLogin.endpoint, JSON.stringify(data), function(response) { - matrixLogin.onLogin(response); - }).fail(errorFunc); -}; +// Titles get updated through the process to give users feedback. +var TITLE_PRE_AUTH = "Log in with one of the following methods"; +var TITLE_POST_AUTH = "Logging in..."; + +// The cookie used to store the original query parameters when using SSO. +var COOKIE_KEY = "synapse_login_fallback_qs"; + +/* + * Submit a login request. + * + * type: The login type as a string (e.g. "m.login.foo"). + * data: An object of data specific to the login type. + * extra: (Optional) An object to search for extra information to send with the + * login request, e.g. device_id. + * callback: (Optional) Function to call on successful login. + */ +var submitLogin = function(type, data, extra, callback) { + console.log("Logging in with " + type); + set_title(TITLE_POST_AUTH); + + // Add the login type. + data.type = type; + + // Add the device information, if it was provided. + if (extra.device_id) { + data.device_id = extra.device_id; + } + if (extra.initial_device_display_name) { + data.initial_device_display_name = extra.initial_device_display_name; + } -var submitToken = function(loginToken) { - console.log("Logging in with login token..."); - set_title(title_post_auth); - var data = { - type: "m.login.token", - token: loginToken - }; $.post(matrixLogin.endpoint, JSON.stringify(data), function(response) { + if (callback) { + callback(); + } matrixLogin.onLogin(response); }).fail(errorFunc); }; @@ -50,12 +61,19 @@ var setFeedbackString = function(text) { }; var show_login = function(inhibit_redirect) { + // Set the redirect to come back to this page, a login token will get added + // and handled after the redirect. var this_page = window.location.origin + window.location.pathname; $("#sso_redirect_url").val(this_page); - // If inhibit_redirect is false, and SSO is the only supported login method, we can - // redirect straight to the SSO page + // If inhibit_redirect is false, and SSO is the only supported login method, + // we can redirect straight to the SSO page. if (matrixLogin.serverAcceptsSso) { + // Before submitting SSO, set the current query parameters into a cookie + // for retrieval later. + var qs = parseQsFromUrl(); + setCookie(COOKIE_KEY, JSON.stringify(qs)); + if (!inhibit_redirect && !matrixLogin.serverAcceptsPassword) { $("#sso_form").submit(); return; @@ -73,7 +91,7 @@ var show_login = function(inhibit_redirect) { $("#no_login_types").show(); } - set_title(title_pre_auth); + set_title(TITLE_PRE_AUTH); $("#loading").hide(); }; @@ -123,7 +141,10 @@ matrixLogin.password_login = function() { setFeedbackString(""); show_spinner(); - submitPassword(user, pwd); + submitLogin( + "m.login.password", + {user: user, password: pwd}, + parseQsFromUrl()); }; matrixLogin.onLogin = function(response) { @@ -131,7 +152,16 @@ matrixLogin.onLogin = function(response) { console.warn("onLogin - This function should be replaced to proceed."); }; -var parseQsFromUrl = function(query) { +/* + * Process the query parameters from the current URL into an object. + */ +var parseQsFromUrl = function() { + var pos = window.location.href.indexOf("?"); + if (pos == -1) { + return {}; + } + var query = window.location.href.substr(pos + 1); + var result = {}; query.split("&").forEach(function(part) { var item = part.split("="); @@ -141,25 +171,80 @@ var parseQsFromUrl = function(query) { if (val) { val = decodeURIComponent(val); } - result[key] = val + result[key] = val; }); return result; }; +/* + * Process the cookies and return an object. + */ +var parseCookies = function() { + var allCookies = document.cookie; + var result = {}; + allCookies.split(";").forEach(function(part) { + var item = part.split("="); + // Cookies might have arbitrary whitespace between them. + var key = item[0].trim(); + // You can end up with a broken cookie that doesn't have an equals sign + // in it. Set to an empty value. + var val = (item[1] || "").trim(); + // Values might be URI encoded. + if (val) { + val = decodeURIComponent(val); + } + result[key] = val; + }); + return result; +}; + +/* + * Set a cookie that is valid for 1 hour. + */ +var setCookie = function(key, value) { + // The maximum age is set in seconds. + var maxAge = 60 * 60; + // Set the cookie, this defaults to the current domain and path. + document.cookie = key + "=" + encodeURIComponent(value) + ";max-age=" + maxAge + ";sameSite=lax"; +}; + +/* + * Removes a cookie by key. + */ +var deleteCookie = function(key) { + // Delete a cookie by setting the expiration to 0. (Note that the value + // doesn't matter.) + document.cookie = key + "=deleted;expires=0"; +}; + +/* + * Submits the login token if one is found in the query parameters. Returns a + * boolean of whether the login token was found or not. + */ var try_token = function() { - var pos = window.location.href.indexOf("?"); - if (pos == -1) { - return false; - } - var qs = parseQsFromUrl(window.location.href.substr(pos+1)); + // Check if the login token is in the query parameters. + var qs = parseQsFromUrl(); var loginToken = qs.loginToken; - if (!loginToken) { return false; } - submitToken(loginToken); + // Retrieve the original query parameters (from before the SSO redirect). + // They are stored as JSON in a cookie. + var cookies = parseCookies(); + var original_query_params = JSON.parse(cookies[COOKIE_KEY] || "{}") + + // If the login is successful, delete the cookie. + var callback = function() { + deleteCookie(COOKIE_KEY); + } + + submitLogin( + "m.login.token", + {token: loginToken}, + original_query_params, + callback); return true; }; diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py index f9eef1b78e..b58f04d00d 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py @@ -297,7 +297,13 @@ class AccountDataWorkerStore(SQLBaseStore): class AccountDataStore(AccountDataWorkerStore): def __init__(self, database: Database, db_conn, hs): self._account_data_id_gen = StreamIdGenerator( - db_conn, "account_data_max_stream_id", "stream_id" + db_conn, + "account_data_max_stream_id", + "stream_id", + extra_tables=[ + ("room_account_data", "stream_id"), + ("room_tags_revisions", "stream_id"), + ], ) super(AccountDataStore, self).__init__(database, db_conn, hs) @@ -387,6 +393,10 @@ class AccountDataStore(AccountDataWorkerStore): # doesn't sound any worse than the whole update getting lost, # which is what would happen if we combined the two into one # transaction. + # + # Note: This is only here for backwards compat to allow admins to + # roll back to a previous Synapse version. Next time we update the + # database version we can remove this table. yield self._update_max_stream_id(next_id) self._account_data_stream_cache.entity_has_changed(user_id, next_id) @@ -405,6 +415,10 @@ class AccountDataStore(AccountDataWorkerStore): next_id(int): The the revision to advance to. """ + # Note: This is only here for backwards compat to allow admins to + # roll back to a previous Synapse version. Next time we update the + # database version we can remove this table. + def _update(txn): update_max_id_sql = ( "UPDATE account_data_max_stream_id" diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py index 1310d39069..e459cf49a0 100644 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ b/synapse/storage/data_stores/main/monthly_active_users.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import List from twisted.internet import defer @@ -77,20 +78,19 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): return self.db.runInteraction("count_users_by_service", _count_users_by_service) - @defer.inlineCallbacks - def get_registered_reserved_users(self): - """Of the reserved threepids defined in config, which are associated - with registered users? + async def get_registered_reserved_users(self) -> List[str]: + """Of the reserved threepids defined in config, retrieve those that are associated + with registered users Returns: - Defered[list]: Real reserved users + User IDs of actual users that are reserved """ users = [] for tp in self.hs.config.mau_limits_reserved_threepids[ : self.hs.config.max_mau_value ]: - user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + user_id = await self.hs.get_datastore().get_user_id_by_threepid( tp["medium"], tp["address"] ) if user_id: @@ -171,13 +171,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): else: logger.warning("mau limit reserved threepid %s not found in db" % tp) - @defer.inlineCallbacks - def reap_monthly_active_users(self): + async def reap_monthly_active_users(self): """Cleans out monthly active user table to ensure that no stale entries exist. - - Returns: - Deferred[] """ def _reap_users(txn, reserved_users): @@ -249,8 +245,8 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): ) self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) - reserved_users = yield self.get_registered_reserved_users() - yield self.db.runInteraction( + reserved_users = await self.get_registered_reserved_users() + await self.db.runInteraction( "reap_monthly_active_users", _reap_users, reserved_users ) @@ -261,6 +257,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): Args: user_id (str): user to add/update + + Returns: + Deferred """ # Support user never to be included in MAU stats. Note I can't easily call this # from upsert_monthly_active_user_txn because then I need a _txn form of diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py index 0d932a0672..cebdcd409f 100644 --- a/synapse/storage/data_stores/main/receipts.py +++ b/synapse/storage/data_stores/main/receipts.py @@ -391,7 +391,7 @@ class ReceiptsStore(ReceiptsWorkerStore): (user_id, room_id, receipt_type), ) - self.db.simple_delete_txn( + self.db.simple_upsert_txn( txn, table="receipts_linearized", keyvalues={ @@ -399,19 +399,14 @@ class ReceiptsStore(ReceiptsWorkerStore): "receipt_type": receipt_type, "user_id": user_id, }, - ) - - self.db.simple_insert_txn( - txn, - table="receipts_linearized", values={ "stream_id": stream_id, - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, "event_id": event_id, "data": json.dumps(data), }, + # receipts_linearized has a unique constraint on + # (user_id, room_id, receipt_type), so no need to lock + lock=False, ) if receipt_type == "m.read" and stream_ordering is not None: diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 1f1a7b4e36..ab70776977 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -17,7 +17,7 @@ import logging import re -from typing import List +from typing import List, Optional from six import iterkeys @@ -423,7 +423,7 @@ class RegistrationWorkerStore(SQLBaseStore): ) return res - @cachedInlineCallbacks() + @cached() def is_support_user(self, user_id): """Determines if the user is of type UserTypes.SUPPORT @@ -433,10 +433,9 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred[bool]: True if user is of type UserTypes.SUPPORT """ - res = yield self.db.runInteraction( + return self.db.runInteraction( "is_support_user", self.is_support_user_txn, user_id ) - return res def is_real_user_txn(self, txn, user_id): res = self.db.simple_select_one_onecol_txn( @@ -597,18 +596,17 @@ class RegistrationWorkerStore(SQLBaseStore): ) ) - @defer.inlineCallbacks - def get_user_id_by_threepid(self, medium, address): + async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]: """Returns user id from threepid Args: - medium (str): threepid medium e.g. email - address (str): threepid address e.g. me@example.com + medium: threepid medium e.g. email + address: threepid address e.g. me@example.com Returns: - Deferred[str|None]: user id or None if no user id/threepid mapping exists + The user ID or None if no user id/threepid mapping exists """ - user_id = yield self.db.runInteraction( + user_id = await self.db.runInteraction( "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address ) return user_id @@ -1074,7 +1072,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): Args: user_id (str): The desired user ID to register. - password_hash (str): Optional. The password hash for this user. + password_hash (str|None): Optional. The password hash for this user. was_guest (bool): Optional. Whether this is a guest account being upgraded to a non-guest account. make_guest (boolean): True if the the new user should be guest, @@ -1088,6 +1086,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): Raises: StoreError if the user_id could not be registered. + + Returns: + Deferred """ return self.db.runInteraction( "register_user", diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py index 2aa1bafd48..4219018302 100644 --- a/synapse/storage/data_stores/main/tags.py +++ b/synapse/storage/data_stores/main/tags.py @@ -233,6 +233,9 @@ class TagsStore(TagsWorkerStore): self._account_data_stream_cache.entity_has_changed, user_id, next_id ) + # Note: This is only here for backwards compat to allow admins to + # roll back to a previous Synapse version. Next time we update the + # database version we can remove this table. update_max_id_sql = ( "UPDATE account_data_max_stream_id" " SET stream_id = ?" diff --git a/synapse/storage/data_stores/state/bg_updates.py b/synapse/storage/data_stores/state/bg_updates.py index e8edaf9f7b..ff000bc9ec 100644 --- a/synapse/storage/data_stores/state/bg_updates.py +++ b/synapse/storage/data_stores/state/bg_updates.py @@ -109,20 +109,20 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): SELECT prev_state_group FROM state_group_edges e, state s WHERE s.state_group = e.state_group ) - SELECT DISTINCT type, state_key, last_value(event_id) OVER ( - PARTITION BY type, state_key ORDER BY state_group ASC - ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING - ) AS event_id FROM state_groups_state + SELECT DISTINCT ON (type, state_key) + type, state_key, event_id + FROM state_groups_state WHERE state_group IN ( SELECT state_group FROM state - ) + ) %s + ORDER BY type, state_key, state_group DESC """ for group in groups: args = [group] args.extend(where_args) - txn.execute(sql + where_clause, args) + txn.execute(sql % (where_clause,), args) for row in txn: typ, state_key, event_id = row key = (typ, state_key) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 9afc145340..9cc3b51fe6 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -33,6 +33,7 @@ logger = logging.getLogger(__name__) # schema files, so the users will be informed on server restarts. # XXX: If you're about to bump this to 59 (or higher) please create an update # that drops the unused `cache_invalidation_stream` table, as per #7436! +# XXX: Also add an update to drop `account_data_max_stream_id` as per #7656! SCHEMA_VERSION = 58 dir_path = os.path.abspath(os.path.dirname(__file__)) @@ -366,9 +367,8 @@ def _upgrade_existing_database( if duplicates: # We don't support using the same file name in the same delta version. raise PrepareDatabaseException( - "Found multiple delta files with the same name in v%d: %s", - v, - duplicates, + "Found multiple delta files with the same name in v%d: %s" + % (v, duplicates,) ) # We sort to ensure that we apply the delta files in a consistent diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 581dffd8a0..f7af2bca7f 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -225,6 +225,18 @@ class Linearizer(object): {} ) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]] + def is_queued(self, key) -> bool: + """Checks whether there is a process queued up waiting + """ + entry = self.key_to_defer.get(key) + if not entry: + # No entry so nothing is waiting. + return False + + # There are waiting deferreds only in the OrderedDict of deferreds is + # non-empty. + return bool(entry[1]) + def queue(self, key): # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly. # (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 5ca4521ce3..e5efdfcd02 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -43,7 +43,7 @@ class FederationRateLimiter(object): self.ratelimiters = collections.defaultdict(new_limiter) def ratelimit(self, host): - """Used to ratelimit an incoming request from given host + """Used to ratelimit an incoming request from a given host Example usage: |